3.4_安全最佳实践

3.4 安全最佳实践

MCP服务通常会处理敏感数据和执行重要操作,因此安全性至关重要。本章将介绍MCP服务的安全最佳实践,包括访问控制、输入验证、加密和安全配置等方面。

认证与授权

实现认证机制

MCP服务应该实现强大的认证机制,以确保只有授权用户才能访问资源和执行工具。

JWT认证示例

import jwt
from datetime import datetime, timedelta
from mcp.server.fastmcp import FastMCP
from mcp.server.middleware import BaseMiddleware

# 创建MCP服务器
mcp = FastMCP("安全MCP服务")

# JWT配置
JWT_SECRET = "your-secret-key"  # 生产环境中应使用强随机密钥
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRE_MINUTES = 30

# 生成JWT令牌
def create_access_token(data: dict, expires_delta: timedelta = None):
    """创建JWT访问令牌"""
    to_encode = data.copy()
    
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=TOKEN_EXPIRE_MINUTES)
    
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=JWT_ALGORITHM)
    
    return encoded_jwt

# 认证中间件
class AuthMiddleware(BaseMiddleware):
    """JWT认证中间件"""
    
    async def __call__(self, context, call_next):
        # 获取并验证令牌
        auth_header = context.scope.get("headers", {}).get("authorization")
        
        if not auth_header or not auth_header.startswith("Bearer "):
            return {"error": "未提供有效的认证令牌"}, 401
        
        token = auth_header.replace("Bearer ", "")
        
        try:
            # 验证令牌
            payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
            user_id = payload.get("sub")
            
            if user_id is None:
                return {"error": "无效的认证令牌"}, 401
            
            # 设置用户ID到上下文
            context.state.user_id = user_id
            
        except jwt.PyJWTError:
            return {"error": "无效的认证令牌"}, 401
        
        # 继续处理请求
        return await call_next(context)

# 注册认证中间件
mcp.add_middleware(AuthMiddleware())

# 用户认证工具
@mcp.tool()
async def login(username: str, password: str):
    """用户登录
    
    参数:
        username: 用户名
        password: 密码
        
    返回:
        认证令牌
    """
    # 在实际应用中,这里应该验证用户凭据
    # 此处仅作为示例,不要在生产环境中使用
    if username == "admin" and password == "password":
        # 创建访问令牌
        access_token = create_access_token(
            data={"sub": username}
        )
        
        return {
            "access_token": access_token,
            "token_type": "bearer"
        }
    else:
        return {"error": "无效的用户名或密码"}, 401

# 需要认证的资源
@mcp.resource("secure://data")
async def get_secure_data(context):
    """获取安全数据"""
    # 检查用户ID(认证中间件已设置)
    user_id = getattr(context.state, "user_id", None)
    
    if not user_id:
        return {"error": "需要认证"}, 401
    
    # 返回安全数据
    return {
        "message": f"这是用户 {user_id} 的安全数据",
        "timestamp": datetime.utcnow().isoformat()
    }

基于角色的访问控制(RBAC)

除了认证之外,还应该实现授权机制,以控制用户对不同资源和工具的访问权限:

# 用户角色和权限示例
PERMISSIONS = {
    "user": ["read:basic", "execute:basic"],
    "editor": ["read:basic", "read:advanced", "execute:basic", "execute:advanced"],
    "admin": ["read:all", "execute:all"]
}

# 授权中间件
class AuthorizationMiddleware(BaseMiddleware):
    """授权中间件"""
    
    async def __call__(self, context, call_next):
        # 检查用户是否已认证
        user_id = getattr(context.state, "user_id", None)
        
        if not user_id:
            # 如果是公共资源,允许继续
            if self._is_public_resource(context):
                return await call_next(context)
            
            return {"error": "需要认证"}, 401
        
        # 获取用户角色(实际应用中应该从数据库获取)
        user_role = self._get_user_role(user_id)
        user_permissions = PERMISSIONS.get(user_role, [])
        
        # 检查是否有权限访问资源或执行工具
        required_permission = self._get_required_permission(context)
        
        if not self._has_permission(user_permissions, required_permission):
            return {"error": "权限不足"}, 403
        
        # 设置角色和权限到上下文
        context.state.user_role = user_role
        context.state.user_permissions = user_permissions
        
        # 继续处理请求
        return await call_next(context)
    
    def _is_public_resource(self, context):
        """检查是否是公共资源"""
        # 实现公共资源检查逻辑
        resource_uri = context.scope.get("resource_uri", "")
        return resource_uri.startswith("public://")
    
    def _get_user_role(self, user_id):
        """获取用户角色"""
        # 实际应用中应该从数据库获取
        # 此处仅作为示例
        if user_id == "admin":
            return "admin"
        elif user_id in ["editor1", "editor2"]:
            return "editor"
        else:
            return "user"
    
    def _get_required_permission(self, context):
        """获取请求所需的权限"""
        scope_type = context.scope.get("type")
        
        if scope_type == "resource":
            resource_uri = context.scope.get("resource_uri", "")
            # 基于资源URI确定所需权限
            if resource_uri.startswith("admin://"):
                return "read:all"
            elif resource_uri.startswith("advanced://"):
                return "read:advanced"
            else:
                return "read:basic"
        
        elif scope_type == "tool":
            tool_name = context.scope.get("tool_name", "")
            # 基于工具名称确定所需权限
            if tool_name.startswith("admin_"):
                return "execute:all"
            elif tool_name.startswith("advanced_"):
                return "execute:advanced"
            else:
                return "execute:basic"
        
        # 默认需要基本权限
        return "read:basic"
    
    def _has_permission(self, user_permissions, required_permission):
        """检查用户是否有所需权限"""
        # 特殊情况: 所有权限
        if "read:all" in user_permissions and required_permission.startswith("read:"):
            return True
        if "execute:all" in user_permissions and required_permission.startswith("execute:"):
            return True
        
        # 直接检查权限
        return required_permission in user_permissions

# 注册授权中间件
mcp.add_middleware(AuthorizationMiddleware())

输入验证与安全处理

参数验证与清理

MCP服务应该对所有输入进行严格验证,以防止注入攻击和其他安全问题:

from pydantic import BaseModel, Field, validator
from typing import List, Optional
import re

# 输入验证模型
class DocumentInput(BaseModel):
    """文档输入验证模型"""
    title: str = Field(..., min_length=1, max_length=100)
    content: str = Field(..., min_length=1, max_length=10000)
    tags: Optional[List[str]] = Field(default=[], max_items=10)
    
    # 自定义验证器
    @validator('title')
    def title_must_be_valid(cls, v):
        # 防止标题中包含危险字符
        if re.search(r'[<>]', v):
            raise ValueError('标题不能包含HTML标签')
        return v
    
    @validator('tags', each_item=True)
    def tags_must_be_valid(cls, v):
        # 验证标签格式
        if not re.match(r'^[a-zA-Z0-9_-]{1,20}$', v):
            raise ValueError('标签只能包含字母、数字、下划线和连字符')
        return v

# 使用验证模型的工具
@mcp.tool()
async def create_document_secure(document: DocumentInput):
    """安全地创建文档
    
    参数:
        document: 文档数据
    
    返回:
        创建的文档
    """
    # 由于已通过Pydantic验证,输入数据是安全的
    # 创建文档逻辑...
    
    return {
        "id": "doc123",
        "title": document.title,
        "content_length": len(document.content),
        "tags": document.tags
    }

防止SQL注入

在与数据库交互时,应该使用参数化查询防止SQL注入:

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def unsafe_query(session: AsyncSession, user_input: str):
    """不安全的查询(不要这样做!)"""
    # 危险:直接在SQL中使用用户输入
    query = f"SELECT * FROM users WHERE username = '{user_input}'"
    result = await session.execute(text(query))
    return result.fetchall()

async def safe_query(session: AsyncSession, user_input: str):
    """安全的参数化查询(推荐做法)"""
    # 安全:使用参数化查询
    query = text("SELECT * FROM users WHERE username = :username")
    result = await session.execute(query, {"username": user_input})
    return result.fetchall()

防止命令注入

如果MCP工具执行系统命令,必须采取预防措施防止命令注入:

import subprocess
import shlex

async def unsafe_command(user_input: str):
    """不安全的命令执行(不要这样做!)"""
    # 危险:直接在命令中使用用户输入
    command = f"ls {user_input}"
    return subprocess.check_output(command, shell=True)

async def safe_command(user_input: str):
    """安全的命令执行(推荐做法)"""
    # 安全:使用参数列表并避免shell=True
    command = ["ls", user_input]
    return subprocess.check_output(command)

# 更安全:使用shlex模块处理参数
async def safer_command(user_input: str):
    """更安全的命令执行"""
    command = f"ls {shlex.quote(user_input)}"
    return subprocess.check_output(command, shell=True)

数据加密与隐私保护

传输层安全(TLS/SSL)

MCP服务应该使用TLS/SSL加密所有通信:

from mcp.server.fastmcp import FastMCP

# 创建MCP服务器
mcp = FastMCP("加密MCP服务")

# 注册资源和工具
# ...

# 使用HTTPS运行服务器
if __name__ == "__main__":
    # 指定SSL证书和密钥
    ssl_certfile = "/path/to/cert.pem"
    ssl_keyfile = "/path/to/key.pem"
    
    mcp.run(
        host="0.0.0.0",
        port=8443,
        ssl_certfile=ssl_certfile,
        ssl_keyfile=ssl_keyfile
    )

敏感数据加密

敏感数据在存储前应该加密:

from cryptography.fernet import Fernet
import base64
import os

# 初始化加密密钥
def initialize_encryption():
    """初始化加密模块"""
    # 生成或加载密钥
    key_file = "encryption_key.key"
    
    if os.path.exists(key_file):
        # 从文件加载密钥
        with open(key_file, "rb") as f:
            key = f.read()
    else:
        # 生成新密钥
        key = Fernet.generate_key()
        # 保存密钥到文件(在生产环境中应该使用安全的密钥管理)
        with open(key_file, "wb") as f:
            f.write(key)
    
    # 创建Fernet实例
    return Fernet(key)

# 加密函数
def encrypt_data(fernet: Fernet, data: str) -> str:
    """加密数据"""
    return fernet.encrypt(data.encode()).decode()

# 解密函数
def decrypt_data(fernet: Fernet, encrypted_data: str) -> str:
    """解密数据"""
    return fernet.decrypt(encrypted_data.encode()).decode()

# 在MCP服务器中使用加密
@mcp.on_startup
async def setup_encryption():
    """设置加密"""
    mcp.state.fernet = initialize_encryption()

@mcp.tool()
async def store_sensitive_data(user_id: str, data: str):
    """安全存储敏感数据
    
    参数:
        user_id: 用户ID
        data: 敏感数据
        
    返回:
        存储状态
    """
    # 加密敏感数据
    encrypted_data = encrypt_data(mcp.state.fernet, data)
    
    # 存储加密数据(例如保存到数据库)
    # ...
    
    return {"status": "success", "message": "数据已安全存储"}

@mcp.resource("sensitive://{user_id}")
async def get_sensitive_data(context, user_id: str):
    """获取敏感数据
    
    参数:
        user_id: 用户ID
        
    返回:
        敏感数据
    """
    # 认证和授权检查
    if getattr(context.state, "user_id", None) != user_id:
        return {"error": "无权访问"}, 403
    
    # 获取加密数据(例如从数据库获取)
    encrypted_data = "..."  # 从数据库获取
    
    # 解密数据
    decrypted_data = decrypt_data(mcp.state.fernet, encrypted_data)
    
    return {"data": decrypted_data}

防止常见攻击

防止DoS攻击

防止拒绝服务攻击的一种方法是实现速率限制:

import time
from collections import defaultdict
from mcp.server.middleware import BaseMiddleware

class RateLimitMiddleware(BaseMiddleware):
    """速率限制中间件"""
    
    def __init__(self, requests_per_minute=60):
        self.requests_per_minute = requests_per_minute
        self.request_counts = defaultdict(list)
    
    async def __call__(self, context, call_next):
        # 获取客户端IP
        client_ip = context.scope.get("client", {}).get("host", "unknown")
        
        # 获取当前时间
        current_time = time.time()
        
        # 清理过期记录(保留最近1分钟的记录)
        self.request_counts[client_ip] = [
            t for t in self.request_counts[client_ip]
            if current_time - t < 60
        ]
        
        # 检查请求次数
        if len(self.request_counts[client_ip]) >= self.requests_per_minute:
            return {"error": "请求过于频繁,请稍后再试"}, 429
        
        # 记录请求时间
        self.request_counts[client_ip].append(current_time)
        
        # 继续处理请求
        return await call_next(context)

# 注册速率限制中间件
mcp.add_middleware(RateLimitMiddleware(requests_per_minute=100))

防止跨站请求伪造(CSRF)

当MCP服务与Web应用集成时,应该防止CSRF攻击:

import secrets
from mcp.server.middleware import BaseMiddleware

class CSRFMiddleware(BaseMiddleware):
    """CSRF保护中间件"""
    
    async def __call__(self, context, call_next):
        # 检查请求方法
        method = context.scope.get("method", "").upper()
        
        # 对于非GET/HEAD请求,检查CSRF令牌
        if method not in ["GET", "HEAD"]:
            # 获取CSRF令牌
            csrf_token = context.scope.get("headers", {}).get("x-csrf-token")
            session_token = context.state.session.get("csrf_token") if hasattr(context.state, "session") else None
            
            # 验证令牌
            if not csrf_token or not session_token or csrf_token != session_token:
                return {"error": "CSRF验证失败"}, 403
        
        # 继续处理请求
        return await call_next(context)

# 生成CSRF令牌
@mcp.tool()
async def get_csrf_token(context):
    """获取CSRF令牌"""
    # 确保会话存在
    if not hasattr(context.state, "session"):
        context.state.session = {}
    
    # 生成新令牌
    csrf_token = secrets.token_hex(16)
    context.state.session["csrf_token"] = csrf_token
    
    return {"csrf_token": csrf_token}

安全配置最佳实践

安全的环境变量管理

敏感配置(如密钥和密码)应该通过环境变量或配置文件管理,而不是硬编码在代码中:

import os
from dotenv import load_dotenv

# 加载环境变量
load_dotenv()

# 创建MCP服务器
mcp = FastMCP("安全配置MCP服务")

# 从环境变量获取配置
DATABASE_URL = os.getenv("DATABASE_URL")
JWT_SECRET = os.getenv("JWT_SECRET")
API_KEYS = os.getenv("API_KEYS", "").split(",")

# 验证配置
if not JWT_SECRET:
    raise ValueError("未设置JWT_SECRET环境变量")

if not DATABASE_URL:
    raise ValueError("未设置DATABASE_URL环境变量")

# 使用配置
# ...

安全日志记录

实现安全日志记录,但避免记录敏感信息:

import logging
from mcp.server.middleware import BaseMiddleware

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    filename='mcp_security.log'
)

security_logger = logging.getLogger("mcp.security")

class SecurityLoggingMiddleware(BaseMiddleware):
    """安全日志记录中间件"""
    
    async def __call__(self, context, call_next):
        # 记录请求信息(不记录敏感数据)
        request_type = context.scope.get("type")
        client_ip = context.scope.get("client", {}).get("host", "unknown")
        user_id = getattr(context.state, "user_id", "anonymous")
        
        security_logger.info(
            f"Request: type={request_type}, ip={client_ip}, user={user_id}"
        )
        
        # 捕获异常并记录
        try:
            response = await call_next(context)
            
            # 记录响应状态
            status = response[1] if isinstance(response, tuple) and len(response) > 1 else 200
            security_logger.info(f"Response: status={status}")
            
            return response
        
        except Exception as e:
            # 记录异常
            security_logger.error(f"Exception: {str(e)}", exc_info=True)
            # 重新抛出异常
            raise

# 注册安全日志中间件
mcp.add_middleware(SecurityLoggingMiddleware())

安全审计与定期检查

实施安全审计

定期审计MCP服务的安全状况是确保安全的重要步骤:

@mcp.tool(require_admin=True)
async def security_audit():
    """执行安全审计
    
    返回:
        审计结果
    """
    # 收集安全信息
    audit_results = {
        "timestamp": datetime.utcnow().isoformat(),
        "checks": []
    }
    
    # 检查1: JWT密钥强度
    jwt_secret = os.getenv("JWT_SECRET", "")
    audit_results["checks"].append({
        "name": "JWT密钥强度",
        "status": "pass" if len(jwt_secret) >= 32 else "fail",
        "details": f"JWT密钥长度为 {len(jwt_secret)} 字符"
    })
    
    # 检查2: TLS配置
    ssl_cert = os.getenv("SSL_CERT", "")
    ssl_key = os.getenv("SSL_KEY", "")
    audit_results["checks"].append({
        "name": "TLS配置",
        "status": "pass" if ssl_cert and ssl_key else "fail",
        "details": "已配置TLS" if ssl_cert and ssl_key else "未配置TLS"
    })
    
    # 检查3: 检查中间件
    middleware_names = [m.__class__.__name__ for m in mcp.middleware]
    
    # 检查认证中间件
    auth_middleware = "AuthMiddleware" in middleware_names
    audit_results["checks"].append({
        "name": "认证中间件",
        "status": "pass" if auth_middleware else "fail",
        "details": "已启用认证中间件" if auth_middleware else "未启用认证中间件"
    })
    
    # 检查速率限制中间件
    rate_limit = "RateLimitMiddleware" in middleware_names
    audit_results["checks"].append({
        "name": "速率限制",
        "status": "pass" if rate_limit else "warn",
        "details": "已启用速率限制" if rate_limit else "未启用速率限制"
    })
    
    # 返回审计结果
    return audit_results

小结

在本章中,我们探讨了MCP服务的安全最佳实践:

  • 实现强大的认证和授权机制
  • 严格验证和清理所有输入
  • 加密敏感数据并使用TLS加密通信
  • 防御常见安全攻击
  • 安全配置和日志记录
  • 定期进行安全审计

通过采用这些最佳实践,您可以构建安全可靠的MCP服务,保护应用程序和用户数据免受潜在威胁。在下一章中,我们将探讨MCP服务的测试和部署策略。

使用 Hugo 构建
主题 StackJimmy 设计