安全架构概述

1. MCP安全模型

MCP协议的安全模型基于多层防护策略,确保通信安全、数据保护和访问控制:

  • 传输层安全:TLS 加密通信(SSL 已弃用,应使用 TLS 1.2 及以上版本)
  • 身份认证:支持多种认证机制
  • 授权控制:基于角色的访问控制(RBAC)
  • 数据保护:敏感数据加密和脱敏
  • 审计日志:完整的操作审计追踪
from typing import Dict, List, Optional, Any, Set, Callable
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import hashlib
import secrets
import jwt
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
import os

class SecurityLevel(Enum):
    """Data classification / clearance level, ordered from least to most sensitive.

    Member values are the lowercase strings used for (de)serialisation, e.g.
    ``SecurityLevel("public")``. Those strings do NOT sort in sensitivity
    order (alphabetically ``"confidential" < "internal" < "public"``), so
    comparing ``.value`` is meaningless. Compare members directly
    (``a < b``) or use :attr:`rank` instead.
    """
    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"
    TOP_SECRET = "top_secret"

    @property
    def rank(self) -> int:
        """Position in sensitivity order: 0 for PUBLIC up to 4 for TOP_SECRET."""
        # Enum members iterate in definition order, which is the sensitivity order.
        return list(type(self)).index(self)

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank < other.rank

    def __le__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank <= other.rank

    def __gt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank > other.rank

    def __ge__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank >= other.rank

class AuthenticationMethod(Enum):
    """How a caller proved its identity.

    Used as the key when registering providers with AuthenticationManager and
    recorded on SecurityContext.authentication_method.
    """
    API_KEY = "api_key"  # static API key; produced by APIKeyAuthProvider
    JWT_TOKEN = "jwt_token"  # signed JWT bearer token; produced by JWTAuthProvider
    OAUTH2 = "oauth2"  # NOTE(review): no provider for this method is visible in this file
    CERTIFICATE = "certificate"  # NOTE(review): no provider for this method is visible in this file
    MULTI_FACTOR = "multi_factor"  # NOTE(review): no provider for this method is visible in this file

class Permission(Enum):
    """Kinds of operation that can be granted to roles, API keys, tokens or contexts.

    The string values are what gets stored in JWT payloads and in exported
    RBAC configurations. Permissions are checked by exact set membership in
    the visible code, so e.g. ADMIN does not implicitly include READ/WRITE.
    """
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    DELETE = "delete"
    ADMIN = "admin"
    AUDIT = "audit"

@dataclass
class SecurityContext:
    """Identity and authorization state of one authenticated caller.

    Built by the authentication providers and consulted by the authorization
    layer and AuthenticationManager's session store.
    """
    user_id: str  # identifier of the authenticated principal
    session_id: str  # key under which AuthenticationManager stores this context
    roles: Set[str] = field(default_factory=set)  # role names carried by the credential; RBACAuthorizationProvider looks roles up by user_id instead
    permissions: Set[Permission] = field(default_factory=set)  # permissions granted directly to this context
    security_level: SecurityLevel = SecurityLevel.PUBLIC  # caller clearance; the providers in this file never set it, so it stays PUBLIC
    authentication_method: AuthenticationMethod = AuthenticationMethod.API_KEY  # how the caller authenticated
    authenticated_at: datetime = field(default_factory=datetime.now)  # naive local time of authentication
    expires_at: Optional[datetime] = None  # naive local expiry; None means session checks never expire it
    ip_address: Optional[str] = None  # client IP, filled in by AuthenticationManager.authenticate when known
    user_agent: Optional[str] = None  # client user agent; not populated anywhere in the visible code
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra data

@dataclass
class SecurityPolicy:
    """Declarative access requirements selectable by name at check time.

    Registered with AuthorizationManager.register_policy and applied when
    AuthorizationManager.check_access is called with ``policy_name``.
    """
    name: str  # unique key in AuthorizationManager.policies
    description: str  # human-readable description
    required_security_level: SecurityLevel  # minimum caller clearance, compared with SecurityContext.security_level
    allowed_authentication_methods: Set[AuthenticationMethod]  # accepted authentication methods — TODO confirm enforcement in the policy checker
    required_permissions: Set[Permission]  # permissions the policy demands — TODO confirm enforcement in the policy checker
    ip_whitelist: Optional[Set[str]] = None  # allowed client IPs; None presumably means "no restriction" — verify
    ip_blacklist: Optional[Set[str]] = None  # denied client IPs; None presumably means "none denied" — verify
    rate_limit: Optional[int] = None  # requests per minute
    session_timeout: Optional[timedelta] = None  # presumably the maximum session age — verify
    require_mfa: bool = False  # whether multi-factor authentication is required
    audit_required: bool = True  # whether accesses under this policy should be audited
    custom_rules: List[Callable[[SecurityContext], bool]] = field(default_factory=list)  # extra predicates over the caller's context

class SecurityException(Exception):
    """Root of the security exception hierarchy.

    Attributes:
        error_code: machine-readable category (``"SECURITY_ERROR"`` unless a
            subclass or caller supplies another one).
        context: the SecurityContext involved in the failure, if known.
        timestamp: naive local time at which the exception was constructed.
    """

    def __init__(self, message: str, error_code: str = "SECURITY_ERROR", 
                 context: Optional[SecurityContext] = None):
        super().__init__(message)
        self.error_code, self.context = error_code, context
        self.timestamp = datetime.now()

class AuthenticationException(SecurityException):
    """Raised when a caller's identity cannot be established.

    Always carries ``error_code == "AUTHENTICATION_ERROR"``.
    """

    def __init__(self, message: str, context: Optional[SecurityContext] = None):
        code = "AUTHENTICATION_ERROR"
        super().__init__(message, error_code=code, context=context)

class AuthorizationException(SecurityException):
    """Raised when an authenticated caller is not allowed to perform an action.

    Always carries ``error_code == "AUTHORIZATION_ERROR"``.
    """

    def __init__(self, message: str, context: Optional[SecurityContext] = None):
        code = "AUTHORIZATION_ERROR"
        super().__init__(message, error_code=code, context=context)

2. 加密工具类

import json
from typing import Union, Tuple

class CryptographyManager:
    """Symmetric encryption, password hashing and token generation helpers.

    Encryption uses Fernet with a key derived from ``master_key`` through
    PBKDF2-HMAC-SHA256. When no master key is supplied a random 32-byte key is
    generated, so ciphertexts produced by that instance cannot be decrypted by
    any other instance.
    """

    def __init__(self, master_key: Optional[Union[str, bytes]] = None,
                 salt: Optional[bytes] = None):
        """Derive the Fernet key.

        Args:
            master_key: key material; ``str`` is UTF-8 encoded, ``bytes`` is
                used as-is. A falsy value (``None`` or empty) selects a random
                32-byte key.
            salt: PBKDF2 salt for the Fernet key. ``None`` keeps the historical
                fixed salt ``b'mcp_salt_2024'`` so previously encrypted data
                stays decryptable. Production deployments should pass a random
                salt and persist it together with the ciphertexts.
        """
        if master_key:
            self.master_key = (master_key.encode() if isinstance(master_key, str)
                               else bytes(master_key))
        else:
            self.master_key = os.urandom(32)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=b'mcp_salt_2024' if salt is None else salt,
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(self.master_key))
        self.fernet = Fernet(key)

    def encrypt_data(self, data: Union[str, dict, list]) -> str:
        """Encrypt *data* and return it as URL-safe base64 text.

        dicts and lists are serialised with ``json.dumps(ensure_ascii=False)``.
        Any other non-str value is converted with ``str()``. The Fernet token
        is base64-encoded once more. This outer layer is what decrypt_data
        expects and is kept for compatibility with existing ciphertexts.
        """
        if isinstance(data, (dict, list)):
            data = json.dumps(data, ensure_ascii=False)
        elif not isinstance(data, str):
            data = str(data)

        encrypted = self.fernet.encrypt(data.encode('utf-8'))
        return base64.urlsafe_b64encode(encrypted).decode('utf-8')

    def decrypt_data(self, encrypted_data: str) -> str:
        """Reverse encrypt_data and return the plaintext string.

        Raises:
            SecurityException: if the input cannot be base64-decoded, is not a
                valid token for this key, or is not UTF-8. The underlying
                error is chained as ``__cause__``.
        """
        try:
            encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
            decrypted = self.fernet.decrypt(encrypted_bytes)
            return decrypted.decode('utf-8')
        except Exception as e:
            # Deliberately broad: every failure surfaces as SecurityException,
            # but keep the root cause for debugging.
            raise SecurityException(f"数据解密失败: {e}") from e

    def decrypt_json(self, encrypted_data: str) -> Union[dict, list]:
        """Decrypt data produced by encrypt_data and parse it as JSON.

        Raises:
            SecurityException: if decryption fails or the plaintext is not JSON.
        """
        decrypted_str = self.decrypt_data(encrypted_data)
        try:
            return json.loads(decrypted_str)
        except json.JSONDecodeError as e:
            raise SecurityException(f"JSON解密失败: {e}") from e

    def hash_password(self, password: str, salt: Optional[str] = None) -> Tuple[str, str]:
        """Hash *password* with PBKDF2-HMAC-SHA256 (100000 iterations).

        Args:
            password: the plaintext password.
            salt: hex salt to reuse (for verification). A fresh 16-byte random
                salt (32 hex chars) is generated when omitted.

        Returns:
            ``(hash, salt)``, where hash is the URL-safe base64 digest.
        """
        if salt is None:
            salt = secrets.token_hex(16)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt.encode(),
            iterations=100000,
        )

        password_hash = base64.urlsafe_b64encode(kdf.derive(password.encode())).decode()
        return password_hash, salt

    def verify_password(self, password: str, password_hash: str, salt: str) -> bool:
        """Return True if *password* hashes to *password_hash* under *salt*.

        The comparison is constant-time. Any error (e.g. non-ASCII hash input)
        is treated as a mismatch.
        """
        try:
            computed_hash, _ = self.hash_password(password, salt)
            return secrets.compare_digest(password_hash, computed_hash)
        except Exception:
            return False

    def generate_token(self, length: int = 32) -> str:
        """Return a URL-safe random token built from *length* random bytes."""
        return secrets.token_urlsafe(length)

    def generate_api_key(self) -> str:
        """Return a new random API key with the ``mcp_`` prefix."""
        return f"mcp_{secrets.token_urlsafe(32)}"

class DataMasking:
    """Redaction helpers for personal data shown in logs or API responses.

    Each masker keeps a short prefix and suffix visible and replaces the
    middle with ``*``. Values too short to hide anything that way are masked
    completely.
    """

    @staticmethod
    def mask_email(email: str) -> str:
        """Mask the local part of an e-mail address, keeping its first and last character.

        ``"alice@example.com"`` becomes ``"a***e@example.com"``. Local parts
        of at most two characters are fully masked. Strings without ``@`` are
        returned unchanged.
        """
        if '@' not in email:
            return email

        local, domain = email.split('@', 1)
        if len(local) <= 2:
            masked_local = '*' * len(local)
        else:
            masked_local = local[0] + '*' * (len(local) - 2) + local[-1]

        return f"{masked_local}@{domain}"

    @staticmethod
    def mask_phone(phone: str) -> str:
        """Mask a phone number, keeping the first 3 and last 3 characters.

        ``"13812345678"`` becomes ``"138*****678"``.
        """
        # Anything of 6 characters or fewer must be fully masked. Otherwise the
        # 3-char prefix and suffix overlap and the "masked" result reveals (and
        # even duplicates) the whole number, e.g. "12345" -> "123345".
        if len(phone) <= 6:
            return '*' * len(phone)

        return phone[:3] + '*' * (len(phone) - 6) + phone[-3:]

    @staticmethod
    def mask_id_card(id_card: str) -> str:
        """Mask an ID-card number, keeping the first 4 and last 4 characters."""
        if len(id_card) <= 8:
            return '*' * len(id_card)

        return id_card[:4] + '*' * (len(id_card) - 8) + id_card[-4:]

    @staticmethod
    def mask_credit_card(card_number: str) -> str:
        """Mask a card number, keeping the first 4 and last 4 characters."""
        if len(card_number) <= 8:
            return '*' * len(card_number)

        return card_number[:4] + '*' * (len(card_number) - 8) + card_number[-4:]

    @staticmethod
    def mask_sensitive_data(data: Any, sensitive_fields: Set[str]) -> Any:
        """Return a copy of *data* with values of sensitive keys masked.

        dicts and lists are walked recursively. A dict entry is masked when
        its (string) key matches one of *sensitive_fields*, case-insensitively.
        The masker is chosen from the key name: ``email``, ``phone``,
        ``id``/``card``, otherwise every character is replaced by ``*``.
        Non-string keys are never treated as sensitive. Other values are
        returned unchanged.
        """
        # Lower-case the field names once instead of once per dict key.
        lowered_fields = {name.lower() for name in sensitive_fields}
        return DataMasking._mask_recursive(data, lowered_fields)

    @staticmethod
    def _mask_recursive(data: Any, lowered_fields: Set[str]) -> Any:
        """Recursive worker for mask_sensitive_data; field names are pre-lowered."""
        if isinstance(data, dict):
            masked_data = {}
            for key, value in data.items():
                lowered_key = key.lower() if isinstance(key, str) else None
                if lowered_key is not None and lowered_key in lowered_fields:
                    masked_data[key] = DataMasking._mask_value(lowered_key, value)
                else:
                    masked_data[key] = DataMasking._mask_recursive(value, lowered_fields)
            return masked_data
        if isinstance(data, list):
            return [DataMasking._mask_recursive(item, lowered_fields) for item in data]
        return data

    @staticmethod
    def _mask_value(lowered_key: str, value: Any) -> str:
        """Mask one sensitive value, picking the masker from the lower-cased key name."""
        text = str(value)
        if 'email' in lowered_key:
            return DataMasking.mask_email(text)
        if 'phone' in lowered_key:
            return DataMasking.mask_phone(text)
        if 'id' in lowered_key:
            return DataMasking.mask_id_card(text)
        if 'card' in lowered_key:
            return DataMasking.mask_credit_card(text)
        return '*' * len(text)

身份认证系统

1. 认证管理器

import time
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod

class AuthenticationProvider(ABC):
    """Interface implemented by every pluggable authentication back-end.

    The concrete providers in this module return ``None`` for credentials they
    do not accept. JWTAuthProvider additionally raises AuthenticationException
    for malformed or expired tokens.
    """

    @abstractmethod
    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Check raw *credentials* and return the resulting context, or ``None``."""

    @abstractmethod
    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Check a previously issued *token* and return its context, or ``None``."""

    @abstractmethod
    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """Exchange *refresh_token* for a new access token, or return ``None`` if unsupported/refused."""

class APIKeyAuthProvider(AuthenticationProvider):
    """Authenticate callers by opaque API keys held in memory.

    ``api_keys`` maps each key to its metadata (user_id, roles, permissions,
    created_at, expires_at, description, last_used, usage_count).
    ``user_keys`` maps each user_id to the set of keys it owns.
    """

    def __init__(self, crypto_manager: CryptographyManager):
        self.crypto_manager = crypto_manager
        self.api_keys: Dict[str, Dict[str, Any]] = {}  # api_key -> user_info
        self.user_keys: Dict[str, Set[str]] = {}  # user_id -> api_keys

    def create_api_key(self, user_id: str, roles: Set[str], 
                      permissions: Set[Permission],
                      expires_in: Optional[timedelta] = None,
                      description: str = "") -> str:
        """Create, store and return a new API key for *user_id*.

        Args:
            user_id: owner of the key.
            roles: role names granted by the key (copied).
            permissions: permissions granted by the key (copied).
            expires_in: lifetime from now. ``None`` means the key never expires.
                ``timedelta(0)`` yields an already-expiring key instead of
                being mistaken for "no expiry".
            description: free-form note shown in list_user_api_keys.
        """
        api_key = self.crypto_manager.generate_api_key()
        created_at = datetime.now()
        expires_at = created_at + expires_in if expires_in is not None else None

        self.api_keys[api_key] = {
            "user_id": user_id,
            # Copies: later mutation of the caller's sets must not change
            # what this key grants.
            "roles": set(roles),
            "permissions": set(permissions),
            "created_at": created_at,
            "expires_at": expires_at,
            "description": description,
            "last_used": None,
            "usage_count": 0
        }
        self.user_keys.setdefault(user_id, set()).add(api_key)

        return api_key

    def revoke_api_key(self, api_key: str) -> bool:
        """Delete *api_key*. Return True if it existed, False otherwise."""
        key_info = self.api_keys.pop(api_key, None)
        if key_info is None:
            return False

        user_id = key_info["user_id"]
        owned = self.user_keys.get(user_id)
        if owned is not None:
            owned.discard(api_key)
            if not owned:
                # Drop empty per-user sets so revoked users do not linger.
                del self.user_keys[user_id]

        return True

    def list_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Return metadata for each key owned by *user_id*.

        The key itself is truncated to its first 8 characters plus ``"..."``.
        The returned dicts and sets are copies, so callers cannot mutate the
        stored grants.
        """
        result = []
        for api_key in self.user_keys.get(user_id, set()):
            key_info = self.api_keys.get(api_key)
            if key_info is None:
                continue
            entry = dict(key_info)
            entry["roles"] = set(key_info["roles"])
            entry["permissions"] = set(key_info["permissions"])
            entry["api_key"] = api_key[:8] + "..."  # only expose the first 8 chars
            result.append(entry)

        return result

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Authenticate with ``credentials["api_key"]``; None if absent or invalid."""
        api_key = credentials.get("api_key")
        if not api_key:
            return None

        return await self.validate_token(api_key)

    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Return a context for a known, unexpired API key, else None.

        Also updates the key's ``last_used`` / ``usage_count`` statistics.
        """
        key_info = self.api_keys.get(token)
        if key_info is None:
            return None

        now = datetime.now()
        expires_at = key_info["expires_at"]
        if expires_at is not None and now > expires_at:
            return None

        key_info["last_used"] = now
        key_info["usage_count"] += 1

        return SecurityContext(
            user_id=key_info["user_id"],
            # Derived from a hash of the whole key: stable per key, does not
            # leak key characters, and does not collide across users the way
            # a short key prefix (all keys start with "mcp_") could.
            session_id=f"api_{hashlib.sha256(token.encode('utf-8')).hexdigest()[:16]}",
            # Copies so mutating the context cannot alter the stored key.
            roles=set(key_info["roles"]),
            permissions=set(key_info["permissions"]),
            authentication_method=AuthenticationMethod.API_KEY,
            authenticated_at=now,
            # Propagate key expiry so sessions built from this context expire
            # together with the key.
            expires_at=expires_at
        )

    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """API keys cannot be refreshed; always returns None."""
        return None

class JWTAuthProvider(AuthenticationProvider):
    """Issue and validate signed JWT access/refresh tokens.

    Tokens are signed with ``secret_key`` using ``algorithm``. Revocation uses
    an in-memory blacklist of raw token strings, which is honoured for both
    access tokens (validate_token) and refresh tokens (refresh_token).
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256",
                 credential_verifier: Optional[Callable[[str, str], bool]] = None):
        """
        Args:
            secret_key: signing/verification key.
            algorithm: JWT algorithm name; the only one accepted on decode.
            credential_verifier: optional ``(username, password) -> bool``
                used by authenticate(). When omitted, authenticate() keeps its
                original demo behaviour of accepting any non-empty pair.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        # Revoked raw token strings (in memory only; grows until restart).
        self.blacklisted_tokens: Set[str] = set()
        self.credential_verifier = credential_verifier

    def create_jwt_token(self, user_id: str, roles: Set[str], 
                        permissions: Set[Permission],
                        expires_in: timedelta = timedelta(hours=24)) -> Dict[str, str]:
        """Create an access token (lifetime *expires_in*) and a 30-day refresh token.

        Returns:
            dict with ``access_token``, ``refresh_token``, ``token_type``
            (``"Bearer"``) and ``expires_in`` (access lifetime in seconds).
        """
        now = datetime.now()
        expires_at = now + expires_in

        access_payload = {
            "user_id": user_id,
            "roles": list(roles),
            "permissions": [p.value for p in permissions],
            "iat": int(now.timestamp()),
            "exp": int(expires_at.timestamp()),
            "type": "access",
            # Unique token id: without it, tokens issued to the same user in
            # the same second are byte-identical (blacklisting one revokes the
            # other), and it gives each token a distinct session id.
            "jti": secrets.token_urlsafe(16),
        }

        refresh_expires_at = now + timedelta(days=30)
        refresh_payload = {
            "user_id": user_id,
            "iat": int(now.timestamp()),
            "exp": int(refresh_expires_at.timestamp()),
            "type": "refresh",
            "jti": secrets.token_urlsafe(16),
        }

        access_token = jwt.encode(access_payload, self.secret_key, algorithm=self.algorithm)
        refresh_token = jwt.encode(refresh_payload, self.secret_key, algorithm=self.algorithm)

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            "expires_in": int(expires_in.total_seconds())
        }

    def blacklist_token(self, token: str):
        """Revoke *token* (access or refresh) for the lifetime of this provider."""
        self.blacklisted_tokens.add(token)

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Authenticate with ``credentials["username"]`` / ``["password"]``.

        WARNING: without a credential_verifier, every non-empty
        username/password pair is accepted (demo behaviour kept for backward
        compatibility). Supply a verifier in real deployments.
        """
        username = credentials.get("username")
        password = credentials.get("password")
        if not (username and password):
            return None

        if self.credential_verifier is not None and not self.credential_verifier(username, password):
            return None

        return SecurityContext(
            user_id=username,
            # Random id: a per-second timestamp let two users who logged in
            # during the same second share (and overwrite) one session.
            session_id=f"jwt_{secrets.token_hex(16)}",
            roles={"user"},
            permissions={Permission.READ, Permission.WRITE},
            authentication_method=AuthenticationMethod.JWT_TOKEN
        )

    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Validate an access token and return its context.

        Returns None for blacklisted tokens and non-access tokens.

        Raises:
            AuthenticationException: expired, badly signed or malformed token
                (including missing/invalid claims).
        """
        if token in self.blacklisted_tokens:
            return None

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as e:
            raise AuthenticationException("令牌已过期") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationException("无效的令牌") from e

        if payload.get("type") != "access":
            return None

        try:
            user_id = payload["user_id"]
            issued_at = payload["iat"]
            jti = payload.get("jti")
            # Tokens issued before "jti" existed fall back to user+iat, which
            # at least cannot collide across different users.
            session_id = f"jwt_{jti}" if jti else f"jwt_{user_id}_{issued_at}"
            return SecurityContext(
                user_id=user_id,
                session_id=session_id,
                roles=set(payload.get("roles", [])),
                permissions={Permission(p) for p in payload.get("permissions", [])},
                authentication_method=AuthenticationMethod.JWT_TOKEN,
                authenticated_at=datetime.fromtimestamp(issued_at),
                expires_at=datetime.fromtimestamp(payload["exp"])
            )
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as e:
            # Missing claims or unknown permission values: report as an
            # invalid token instead of leaking KeyError/ValueError.
            raise AuthenticationException("无效的令牌") from e

    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """Exchange a refresh token for a new access token.

        Returns None for blacklisted tokens and non-refresh tokens.

        Raises:
            AuthenticationException: expired or invalid refresh token.
        """
        # Revoked refresh tokens must not mint new access tokens.
        if refresh_token in self.blacklisted_tokens:
            return None

        try:
            payload = jwt.decode(refresh_token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as e:
            raise AuthenticationException("刷新令牌已过期") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationException("无效的刷新令牌") from e

        if payload.get("type") != "refresh":
            return None

        if "user_id" not in payload:
            raise AuthenticationException("无效的刷新令牌")
        user_id = payload["user_id"]

        # TODO: load the user's current roles/permissions from the user store
        # instead of these hard-coded defaults.
        new_tokens = self.create_jwt_token(
            user_id=user_id,
            roles={"user"},
            permissions={Permission.READ, Permission.WRITE}
        )

        return new_tokens["access_token"]

class AuthenticationManager:
    """Dispatch authentication to registered providers and track sessions.

    Also implements a per-IP lockout. After ``max_failed_attempts`` failures
    within ``lockout_duration``, further authenticate() calls from that IP
    are rejected.
    """

    def __init__(self):
        self.providers: Dict[AuthenticationMethod, AuthenticationProvider] = {}
        # session_id -> context of each successful authenticate() call
        self.active_sessions: Dict[str, SecurityContext] = {}
        self.failed_attempts: Dict[str, List[datetime]] = {}  # ip -> attempt_times
        self.max_failed_attempts = 5
        self.lockout_duration = timedelta(minutes=15)

    def register_provider(self, method: AuthenticationMethod, provider: AuthenticationProvider):
        """Register (or replace) the provider handling *method*."""
        self.providers[method] = provider

    def is_ip_locked(self, ip_address: str) -> bool:
        """Return True if *ip_address* has too many recent failed attempts.

        Attempts older than ``lockout_duration`` are pruned as a side effect.
        """
        attempts = self.failed_attempts.get(ip_address)
        if not attempts:
            return False

        now = datetime.now()
        recent_attempts = [t for t in attempts if now - t < self.lockout_duration]
        if recent_attempts:
            self.failed_attempts[ip_address] = recent_attempts
        else:
            # Forget IPs with only stale failures so the dict stays bounded.
            del self.failed_attempts[ip_address]

        return len(recent_attempts) >= self.max_failed_attempts

    def record_failed_attempt(self, ip_address: str):
        """Record one failed authentication attempt for *ip_address*."""
        self.failed_attempts.setdefault(ip_address, []).append(datetime.now())

    def clear_failed_attempts(self, ip_address: str):
        """Forget all failed attempts for *ip_address*."""
        self.failed_attempts.pop(ip_address, None)

    async def authenticate(self, method: AuthenticationMethod, 
                         credentials: Dict[str, Any],
                         ip_address: Optional[str] = None) -> SecurityContext:
        """Authenticate *credentials* with the provider for *method*.

        On success the context is stored in ``active_sessions`` and the IP's
        failure history is cleared.

        Raises:
            AuthenticationException: IP locked out, unsupported method, or
                the credentials were rejected (each rejection counts exactly
                once towards the lockout).
        """
        if ip_address and self.is_ip_locked(ip_address):
            raise AuthenticationException(f"IP地址 {ip_address} 已被锁定")

        provider = self.providers.get(method)
        if not provider:
            raise AuthenticationException(f"不支持的认证方法: {method}")

        try:
            context = await provider.authenticate(credentials)
        except AuthenticationException:
            if ip_address:
                self.record_failed_attempt(ip_address)
            raise

        if not context:
            # Recorded here only. Previously the raise below was also caught
            # by the handler above, counting every rejection twice and
            # locking IPs out after 3 attempts instead of 5.
            if ip_address:
                self.record_failed_attempt(ip_address)
            raise AuthenticationException("认证失败")

        if ip_address:
            context.ip_address = ip_address
            self.clear_failed_attempts(ip_address)

        self.active_sessions[context.session_id] = context

        return context

    async def validate_session(self, session_id: str) -> Optional[SecurityContext]:
        """Return the stored context for *session_id*, or None if unknown/expired.

        Expired sessions are removed.
        """
        context = self.active_sessions.get(session_id)
        if not context:
            return None

        if context.expires_at is not None and datetime.now() > context.expires_at:
            del self.active_sessions[session_id]
            return None

        return context

    async def validate_token(self, method: AuthenticationMethod, token: str) -> SecurityContext:
        """Validate *token* with the provider for *method* (no session is stored).

        Raises:
            AuthenticationException: unsupported method or invalid token.
        """
        provider = self.providers.get(method)
        if not provider:
            raise AuthenticationException(f"不支持的认证方法: {method}")

        context = await provider.validate_token(token)
        if not context:
            raise AuthenticationException("令牌验证失败")

        return context

    def logout(self, session_id: str) -> bool:
        """Drop *session_id*. Return True if it was active."""
        return self.active_sessions.pop(session_id, None) is not None

    def get_active_sessions(self, user_id: Optional[str] = None) -> List[SecurityContext]:
        """Return all stored sessions, optionally only those of *user_id*."""
        sessions = list(self.active_sessions.values())

        if user_id:
            sessions = [s for s in sessions if s.user_id == user_id]

        return sessions

    def cleanup_expired_sessions(self) -> int:
        """Remove every expired session and return how many were removed."""
        now = datetime.now()
        expired_sessions = [
            session_id
            for session_id, context in self.active_sessions.items()
            if context.expires_at and now > context.expires_at
        ]

        for session_id in expired_sessions:
            del self.active_sessions[session_id]

        return len(expired_sessions)

授权控制系统

1. 基于角色的访问控制(RBAC)

from typing import Set, Dict, List, Optional, Any
from dataclasses import dataclass, field
from abc import ABC, abstractmethod

@dataclass
class Role:
    """A named bundle of permissions that may inherit from other roles.

    Inherited permissions are resolved by RBACAuthorizationProvider by
    following ``parent_roles``.
    """
    name: str  # role name; key in RBACAuthorizationProvider.roles
    description: str  # human-readable description
    permissions: Set[Permission] = field(default_factory=set)  # permissions granted directly by this role
    parent_roles: Set[str] = field(default_factory=set)  # names of the inherited parent roles
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra data
    created_at: datetime = field(default_factory=datetime.now)  # naive local creation time
    updated_at: datetime = field(default_factory=datetime.now)  # naive local time of the last update

@dataclass
class Resource:
    """A protected resource registered with RBACAuthorizationProvider."""
    uri: str  # identifier; key in RBACAuthorizationProvider.resources
    name: str  # display name
    resource_type: str  # free-form category string
    security_level: SecurityLevel  # classification, compared with the caller's SecurityContext.security_level
    owner: Optional[str] = None  # user_id of the owner; check_permission grants the owner access
    required_permissions: Set[Permission] = field(default_factory=set)  # permissions associated with the resource (see get_resource_permissions)
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra data

class AuthorizationProvider(ABC):
    """Interface implemented by pluggable authorization back-ends.

    AuthorizationManager grants access if any registered provider's
    check_permission returns True.
    """

    @abstractmethod
    async def check_permission(self, context: SecurityContext, 
                             resource: str, permission: Permission) -> bool:
        """Return True if *context* may exercise *permission* on the resource URI *resource*.

        Note: RBACAuthorizationProvider names this parameter ``resource_uri``.
        AuthorizationManager passes it positionally.
        """

    @abstractmethod
    async def get_user_permissions(self, user_id: str) -> Set[Permission]:
        """Return every permission effectively held by *user_id*."""

    @abstractmethod
    async def get_resource_permissions(self, resource_uri: str) -> Set[Permission]:
        """Return the permissions associated with the resource at *resource_uri*."""

class RBACAuthorizationProvider(AuthorizationProvider):
    """Role-based authorization with role inheritance and per-resource rules.

    In-memory state:
      * ``roles``: role name -> Role
      * ``user_roles``: user_id -> set of role names
      * ``resources``: resource URI -> Resource
      * ``role_hierarchy_cache``: role name -> that role plus all ancestors;
        cleared whenever roles change.
    """

    def __init__(self):
        self.roles: Dict[str, Role] = {}
        self.user_roles: Dict[str, Set[str]] = {}  # user_id -> role_names
        self.resources: Dict[str, Resource] = {}
        self.role_hierarchy_cache: Dict[str, Set[str]] = {}  # role hierarchy cache

    def create_role(self, name: str, description: str, 
                   permissions: Set[Permission],
                   parent_roles: Optional[Set[str]] = None) -> Role:
        """Create, register and return a new role.

        Raises:
            ValueError: a role named *name* already exists.
        """
        if name in self.roles:
            raise ValueError(f"角色已存在: {name}")

        role = Role(
            name=name,
            description=description,
            # Copies: later mutation of the caller's sets must not silently
            # change this role's grants.
            permissions=set(permissions),
            parent_roles=set(parent_roles) if parent_roles else set()
        )

        self.roles[name] = role
        self._invalidate_hierarchy_cache()

        return role

    def update_role(self, name: str, permissions: Optional[Set[Permission]] = None,
                   parent_roles: Optional[Set[str]] = None) -> bool:
        """Replace the permissions and/or parents of role *name*.

        Returns False if the role does not exist.
        """
        role = self.roles.get(name)
        if role is None:
            return False

        if permissions is not None:
            role.permissions = set(permissions)

        if parent_roles is not None:
            role.parent_roles = set(parent_roles)

        role.updated_at = datetime.now()
        self._invalidate_hierarchy_cache()

        return True

    def delete_role(self, name: str) -> bool:
        """Delete role *name*. Returns False if it does not exist.

        Raises:
            ValueError: the role is still assigned to a user or is a parent of
                some role.
        """
        if name not in self.roles:
            return False

        if any(name in roles for roles in self.user_roles.values()):
            raise ValueError(f"角色 {name} 正在被用户使用,无法删除")

        if any(name in role.parent_roles for role in self.roles.values()):
            raise ValueError(f"角色 {name} 被其他角色继承,无法删除")

        del self.roles[name]
        self._invalidate_hierarchy_cache()

        return True

    def assign_role_to_user(self, user_id: str, role_name: str) -> bool:
        """Give *role_name* to *user_id*. Returns False if the role is unknown."""
        if role_name not in self.roles:
            return False

        self.user_roles.setdefault(user_id, set()).add(role_name)
        return True

    def revoke_role_from_user(self, user_id: str, role_name: str) -> bool:
        """Remove *role_name* from *user_id*.

        Returns False only if the user has no roles at all. Removing a role
        the user does not hold still returns True.
        """
        if user_id not in self.user_roles:
            return False

        self.user_roles[user_id].discard(role_name)

        if not self.user_roles[user_id]:
            del self.user_roles[user_id]

        return True

    def register_resource(self, resource: Resource):
        """Register (or replace) *resource* under its URI."""
        self.resources[resource.uri] = resource

    def unregister_resource(self, resource_uri: str) -> bool:
        """Remove the resource at *resource_uri*. Returns True if it existed."""
        return self.resources.pop(resource_uri, None) is not None

    def _get_all_role_permissions(self, role_name: str) -> Set[Permission]:
        """Return the permissions of *role_name* including all inherited ones."""
        all_roles = self.role_hierarchy_cache.get(role_name)
        if all_roles is None:
            all_roles = self._compute_role_hierarchy(role_name)
            self.role_hierarchy_cache[role_name] = all_roles

        permissions: Set[Permission] = set()
        # Distinct loop name: the original reused (shadowed) `role_name` here.
        for member_name in all_roles:
            member = self.roles.get(member_name)
            if member is not None:
                permissions.update(member.permissions)

        return permissions

    def _compute_role_hierarchy(self, role_name: str, visited: Optional[Set[str]] = None) -> Set[str]:
        """Return *role_name* plus all of its (transitive) existing ancestors.

        Unknown role names contribute nothing.

        Raises:
            ValueError: the inheritance graph contains a cycle through this role.
        """
        if visited is None:
            visited = set()

        if role_name in visited:
            raise ValueError(f"检测到角色循环依赖: {role_name}")

        if role_name not in self.roles:
            return set()

        visited.add(role_name)
        all_roles = {role_name}

        # `visited` tracks the current path only (copied per branch), so
        # diamond-shaped inheritance is not mistaken for a cycle.
        for parent_role in self.roles[role_name].parent_roles:
            all_roles.update(self._compute_role_hierarchy(parent_role, visited.copy()))

        return all_roles

    def _invalidate_hierarchy_cache(self):
        """Clear the role hierarchy cache; call after any role change."""
        self.role_hierarchy_cache.clear()

    @staticmethod
    def _level_rank(level: SecurityLevel) -> int:
        """Sensitivity position of *level* (PUBLIC lowest) in SecurityLevel's definition order."""
        return list(SecurityLevel).index(level)

    async def check_permission(self, context: SecurityContext, 
                             resource_uri: str, permission: Permission) -> bool:
        """Decide whether *context* may exercise *permission* on *resource_uri*.

        For a registered resource:
          1. its owner is always allowed;
          2. a caller whose security_level is below the resource's is denied,
             regardless of permissions.
        Otherwise access is granted if *permission* is held directly by the
        context or through the user's (inherited) roles.
        """
        resource = self.resources.get(resource_uri)
        if resource is not None:
            if resource.owner is not None and resource.owner == context.user_id:
                return True

            # Clearance gate. It must run BEFORE the permission checks (it
            # previously ran after them, so it never restricted anyone) and
            # must compare by rank: the enum's string values sort
            # alphabetically ("confidential" < "internal" < "public").
            if self._level_rank(context.security_level) < self._level_rank(resource.security_level):
                return False

        if permission in context.permissions:
            return True

        user_permissions = await self.get_user_permissions(context.user_id)
        return permission in user_permissions

    async def get_user_permissions(self, user_id: str) -> Set[Permission]:
        """Return the union of all (inherited) permissions of *user_id*'s roles.

        Raises:
            ValueError: a role cycle is detected.
        """
        permissions: Set[Permission] = set()
        for role_name in self.user_roles.get(user_id, set()):
            permissions.update(self._get_all_role_permissions(role_name))
        return permissions

    async def get_resource_permissions(self, resource_uri: str) -> Set[Permission]:
        """Return a copy of the resource's required permissions (empty if unregistered)."""
        resource = self.resources.get(resource_uri)
        if resource:
            # Copy so callers cannot mutate the registered resource.
            return set(resource.required_permissions)
        return set()

    def get_user_roles(self, user_id: str) -> Set[str]:
        """Return a copy of the role names directly assigned to *user_id*."""
        return self.user_roles.get(user_id, set()).copy()

    def get_role_users(self, role_name: str) -> Set[str]:
        """Return the user_ids directly assigned *role_name*."""
        return {user_id for user_id, roles in self.user_roles.items() if role_name in roles}

    def export_rbac_config(self) -> Dict[str, Any]:
        """Serialise roles, user assignments and resources to JSON-friendly data."""
        return {
            "roles": {
                name: {
                    "description": role.description,
                    "permissions": [p.value for p in role.permissions],
                    "parent_roles": list(role.parent_roles),
                    "created_at": role.created_at.isoformat(),
                    "updated_at": role.updated_at.isoformat()
                }
                for name, role in self.roles.items()
            },
            "user_roles": {
                user_id: list(roles)
                for user_id, roles in self.user_roles.items()
            },
            "resources": {
                uri: {
                    "name": resource.name,
                    "resource_type": resource.resource_type,
                    "security_level": resource.security_level.value,
                    "owner": resource.owner,
                    "required_permissions": [p.value for p in resource.required_permissions]
                }
                for uri, resource in self.resources.items()
            }
        }

    def import_rbac_config(self, config: Dict[str, Any]):
        """Merge a configuration produced by export_rbac_config into this provider.

        Entries with the same name/user/URI are overwritten. Missing or null
        timestamps default to the current time.

        Raises:
            ValueError: unknown permission or security-level values, or a
                malformed ISO timestamp.
        """
        for name, role_data in config.get("roles", {}).items():
            created_raw = role_data.get("created_at")
            updated_raw = role_data.get("updated_at")

            self.roles[name] = Role(
                name=name,
                description=role_data.get("description", ""),
                permissions={Permission(p) for p in role_data.get("permissions", [])},
                parent_roles=set(role_data.get("parent_roles", [])),
                created_at=datetime.fromisoformat(created_raw) if created_raw else datetime.now(),
                updated_at=datetime.fromisoformat(updated_raw) if updated_raw else datetime.now()
            )

        for user_id, roles in config.get("user_roles", {}).items():
            self.user_roles[user_id] = set(roles)

        for uri, resource_data in config.get("resources", {}).items():
            self.resources[uri] = Resource(
                uri=uri,
                name=resource_data.get("name", ""),
                resource_type=resource_data.get("resource_type", ""),
                security_level=SecurityLevel(resource_data.get("security_level", "public")),
                owner=resource_data.get("owner"),
                required_permissions={Permission(p) for p in resource_data.get("required_permissions", [])}
            )

        self._invalidate_hierarchy_cache()

2. 授权管理器

class AuthorizationManager:
    """Authorization manager.

    Grants access when the optional named security policy is satisfied and
    at least one registered authorization provider grants the requested
    permission. Every decision is recorded in an in-memory, size-bounded
    access log (``access_logs``).

    All checks fail closed:
    - an unknown policy name denies access;
    - an exception during evaluation denies access.
    """

    def __init__(self):
        # Providers keyed by registration name; consulted in insertion order
        # until one grants access.
        self.providers: Dict[str, AuthorizationProvider] = {}
        # Security policies keyed by SecurityPolicy.name.
        self.policies: Dict[str, SecurityPolicy] = {}
        # Access decisions, oldest first; trimmed in _log_access.
        self.access_logs: List[Dict[str, Any]] = []
        self.max_log_entries = 10000

    def register_provider(self, name: str, provider: AuthorizationProvider):
        """Register (or replace) an authorization provider under ``name``."""
        self.providers[name] = provider

    def register_policy(self, policy: SecurityPolicy):
        """Register (or replace) a security policy, keyed by ``policy.name``."""
        self.policies[policy.name] = policy

    def get_policy(self, policy_name: str) -> Optional[SecurityPolicy]:
        """Return the policy named ``policy_name``, or None if unknown."""
        return self.policies.get(policy_name)

    async def check_access(self, context: SecurityContext, resource_uri: str,
                           permission: Permission, policy_name: Optional[str] = None) -> bool:
        """Decide whether ``context`` may exercise ``permission`` on ``resource_uri``.

        Args:
            context: The caller's security context.
            resource_uri: URI of the resource being accessed.
            permission: The permission requested.
            policy_name: Optional name of a registered policy that must also
                be satisfied. An unknown name denies access.

        Returns:
            True if access is granted, otherwise False. Exceptions raised
            during evaluation are logged and treated as a denial.
        """
        try:
            if policy_name:
                policy = self.get_policy(policy_name)
                if policy is None:
                    # Fail closed: a misspelled or unregistered policy must not
                    # silently disable policy enforcement.
                    self._log_access(context, resource_uri, permission, False,
                                     f"策略不存在: {policy_name}")
                    return False
                if not await self._check_policy(context, policy):
                    self._log_access(context, resource_uri, permission, False, "策略检查失败")
                    return False

            # Any single provider granting the permission is sufficient.
            access_granted = False
            for provider in self.providers.values():
                if await provider.check_permission(context, resource_uri, permission):
                    access_granted = True
                    break

            self._log_access(context, resource_uri, permission, access_granted)
            return access_granted

        except Exception as e:
            # Top-level boundary: record the failure and deny.
            self._log_access(context, resource_uri, permission, False, f"检查失败: {e}")
            return False

    async def _check_policy(self, context: SecurityContext, policy: SecurityPolicy) -> bool:
        """Return True if ``context`` satisfies every constraint of ``policy``.

        Checks run in order; the first failing one returns False.
        """
        # Security levels are ranked by their declaration order in the enum
        # (PUBLIC, INTERNAL, CONFIDENTIAL, SECRET, TOP_SECRET). The string
        # values must not be compared directly: lexicographically
        # "public" > "internal" > "confidential", which would let low-level
        # contexts through.
        levels = list(type(policy.required_security_level))
        if levels.index(context.security_level) < levels.index(policy.required_security_level):
            return False

        if context.authentication_method not in policy.allowed_authentication_methods:
            return False

        if not policy.required_permissions.issubset(context.permissions):
            return False

        # Fail closed: with a whitelist configured, a missing client IP cannot
        # be shown to be whitelisted.
        if policy.ip_whitelist:
            if not context.ip_address or context.ip_address not in policy.ip_whitelist:
                return False

        if policy.ip_blacklist and context.ip_address:
            if context.ip_address in policy.ip_blacklist:
                return False

        # Session age is measured from the authentication time.
        if policy.session_timeout and context.authenticated_at:
            if datetime.now() - context.authenticated_at > policy.session_timeout:
                return False

        # Fail closed: SecurityContext carries no separate MFA status, so only
        # contexts authenticated via MULTI_FACTOR satisfy an MFA requirement.
        if policy.require_mfa and context.authentication_method != AuthenticationMethod.MULTI_FACTOR:
            return False

        # Custom rules are callables taking the context; a rule that raises
        # is treated as a failed check.
        for rule in policy.custom_rules:
            try:
                if not rule(context):
                    return False
            except Exception as e:
                print(f"自定义规则执行失败: {e}")
                return False

        return True

    def _log_access(self, context: SecurityContext, resource_uri: str,
                    permission: Permission, granted: bool, reason: str = ""):
        """Append one access decision to ``access_logs``, trimming if needed.

        Once the log exceeds ``max_log_entries``, only the newest
        ceil(max_log_entries / 2) entries are kept. Trimming in halves avoids
        re-slicing on every call.
        """
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "user_id": context.user_id,
            "session_id": context.session_id,
            "resource_uri": resource_uri,
            "permission": permission.value,
            "granted": granted,
            "reason": reason,
            "ip_address": context.ip_address,
            "user_agent": context.user_agent
        }

        self.access_logs.append(log_entry)

        if len(self.access_logs) > self.max_log_entries:
            keep = (self.max_log_entries + 1) // 2
            # A non-positive limit keeps nothing; slicing with [-0:] would
            # otherwise keep everything.
            self.access_logs = self.access_logs[-keep:] if keep > 0 else []

    def get_access_logs(self, user_id: Optional[str] = None,
                        resource_uri: Optional[str] = None,
                        start_time: Optional[datetime] = None,
                        end_time: Optional[datetime] = None,
                        granted_only: Optional[bool] = None) -> List[Dict[str, Any]]:
        """Return access-log entries matching every given (truthy) filter.

        Args:
            user_id: Keep only entries for this user.
            resource_uri: Keep only entries for this resource.
            start_time: Keep entries at or after this time (inclusive).
            end_time: Keep entries at or before this time (inclusive).
            granted_only: If not None, keep only entries whose ``granted``
                flag equals this value.

        Returns:
            A new list; the internal log is not modified.
        """
        filtered_logs = self.access_logs.copy()

        if user_id:
            filtered_logs = [log for log in filtered_logs if log["user_id"] == user_id]

        if resource_uri:
            filtered_logs = [log for log in filtered_logs if log["resource_uri"] == resource_uri]

        # Timestamps are stored as naive-datetime ISO strings, which sort
        # chronologically as plain strings.
        # NOTE(review): passing a timezone-aware start/end time would break
        # this string comparison.
        if start_time:
            start_iso = start_time.isoformat()
            filtered_logs = [log for log in filtered_logs if log["timestamp"] >= start_iso]

        if end_time:
            end_iso = end_time.isoformat()
            filtered_logs = [log for log in filtered_logs if log["timestamp"] <= end_iso]

        if granted_only is not None:
            filtered_logs = [log for log in filtered_logs if log["granted"] == granted_only]

        return filtered_logs

    def _tally_by(self, key: str) -> Dict[Any, Dict[str, int]]:
        """Count total/granted/denied log entries grouped by ``log[key]``."""
        stats: Dict[Any, Dict[str, int]] = {}
        for log in self.access_logs:
            bucket = stats.setdefault(log[key], {"total": 0, "granted": 0, "denied": 0})
            bucket["total"] += 1
            bucket["granted" if log["granted"] else "denied"] += 1
        return stats

    def get_access_statistics(self) -> Dict[str, Any]:
        """Summarize the access log overall, per user, and per resource.

        Returns ``{"total_requests": 0}`` when the log is empty. Otherwise
        returns totals, the granted fraction, and per-user / per-resource
        total/granted/denied counts.
        """
        if not self.access_logs:
            return {"total_requests": 0}

        total_requests = len(self.access_logs)
        granted_requests = sum(1 for log in self.access_logs if log["granted"])

        return {
            "total_requests": total_requests,
            "granted_requests": granted_requests,
            "denied_requests": total_requests - granted_requests,
            # total_requests > 0 is guaranteed by the early return above.
            "success_rate": granted_requests / total_requests,
            "user_statistics": self._tally_by("user_id"),
            "resource_statistics": self._tally_by("resource_uri")
        }

本章总结

本章详细介绍了MCP协议中的安全性与权限控制机制,包括:

核心内容

  1. 安全架构概述

    • 多层防护策略
    • 安全模型和上下文定义
    • 加密工具和数据脱敏
  2. 身份认证系统

    • 多种认证方法支持(API密钥、JWT、OAuth2等)
    • 认证提供者架构
    • 会话管理和安全控制
  3. 授权控制系统

    • 基于角色的访问控制(RBAC)
    • 资源权限管理
    • 安全策略引擎
  4. 审计和监控

    • 访问日志记录
    • 安全事件监控
    • 统计分析功能

最佳实践

  1. 安全设计原则

    • 最小权限原则
    • 深度防御策略
    • 零信任架构
  2. 认证安全

    • 强密码策略
    • 多因子认证
    • 令牌安全管理
  3. 授权控制

    • 细粒度权限控制
    • 动态权限评估
    • 权限继承和委托
  4. 监控和审计

    • 实时安全监控
    • 完整审计日志
    • 异常行为检测

下一章我们将学习性能优化与监控,了解如何优化MCP协议的性能并实现有效的监控机制。