安全架构概述
1. MCP安全模型
MCP协议的安全模型基于多层防护策略,确保通信安全、数据保护和访问控制:
- 传输层安全:TLS/SSL加密通信
- 身份认证:多种认证机制支持
- 授权控制:基于角色的访问控制(RBAC)
- 数据保护:敏感数据加密和脱敏
- 审计日志:完整的操作审计追踪
from typing import Dict, List, Optional, Any, Set, Callable
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import hashlib
import secrets
import jwt
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
import os
class SecurityLevel(Enum):
    """Classification / clearance levels, declared from least to most sensitive.

    Members are ordered by declaration position, so
    ``SecurityLevel.PUBLIC < SecurityLevel.SECRET`` is True. Compare members
    (or their ``rank``) rather than ``.value``: the values are plain strings
    and sort alphabetically, which does not match sensitivity
    (e.g. ``"confidential" < "public"``).
    """
    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"
    TOP_SECRET = "top_secret"

    @property
    def rank(self) -> int:
        """Return the 0-based sensitivity rank (PUBLIC == 0, TOP_SECRET == 4)."""
        return list(type(self)).index(self)

    def __lt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank < other.rank

    def __le__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank <= other.rank

    def __gt__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank > other.rank

    def __ge__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.rank >= other.rank
class AuthenticationMethod(Enum):
    """Mechanisms a caller may use to prove its identity."""

    # Shared-secret style credentials
    API_KEY = "api_key"
    JWT_TOKEN = "jwt_token"
    # Delegated / cryptographic / combined mechanisms
    OAUTH2 = "oauth2"
    CERTIFICATE = "certificate"
    MULTI_FACTOR = "multi_factor"
class Permission(Enum):
    """Atomic actions that roles, API keys and tokens can grant."""

    # Ordinary data-access verbs
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    DELETE = "delete"
    # Privileged operations
    ADMIN = "admin"
    AUDIT = "audit"
@dataclass
class SecurityContext:
    """Identity and authorization state attached to one authenticated caller.

    Only ``user_id`` and ``session_id`` are required. By default the context
    has no roles or permissions, PUBLIC level, API_KEY authentication, an
    ``authenticated_at`` of "now" and no expiry.
    """

    # --- identity ---
    user_id: str
    session_id: str
    # --- grants ---
    roles: Set[str] = field(default_factory=set)
    permissions: Set[Permission] = field(default_factory=set)
    security_level: SecurityLevel = SecurityLevel.PUBLIC
    # --- how and when the caller authenticated ---
    authentication_method: AuthenticationMethod = AuthenticationMethod.API_KEY
    authenticated_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None  # None means no expiry is recorded
    # --- client details and free-form extras ---
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SecurityPolicy:
    """Named set of access requirements that a SecurityContext is checked against."""

    # --- identification ---
    name: str
    description: str
    # --- mandatory requirements ---
    required_security_level: SecurityLevel
    allowed_authentication_methods: Set[AuthenticationMethod]
    required_permissions: Set[Permission]
    # --- optional constraints ---
    ip_whitelist: Optional[Set[str]] = None
    ip_blacklist: Optional[Set[str]] = None
    rate_limit: Optional[int] = None  # maximum requests per minute
    session_timeout: Optional[timedelta] = None
    require_mfa: bool = False
    audit_required: bool = True
    # Extra predicates; each receives the context and returns a bool verdict.
    custom_rules: List[Callable[[SecurityContext], bool]] = field(default_factory=list)
class SecurityException(Exception):
    """Root of the security error hierarchy.

    Attributes:
        error_code: Machine-readable category, "SECURITY_ERROR" by default.
        context: The SecurityContext involved, if the raiser supplied one.
        timestamp: Local time at which the exception object was created.
    """

    def __init__(self, message: str, error_code: str = "SECURITY_ERROR",
                 context: Optional[SecurityContext] = None):
        super().__init__(message)
        self.error_code, self.context = error_code, context
        self.timestamp = datetime.now()
class AuthenticationException(SecurityException):
    """Raised when a caller's identity cannot be established (code AUTHENTICATION_ERROR)."""

    def __init__(self, message: str, context: Optional[SecurityContext] = None):
        super().__init__(message, error_code="AUTHENTICATION_ERROR", context=context)
class AuthorizationException(SecurityException):
    """Raised when an authenticated caller is not allowed to act (code AUTHORIZATION_ERROR)."""

    def __init__(self, message: str, context: Optional[SecurityContext] = None):
        super().__init__(message, error_code="AUTHORIZATION_ERROR", context=context)
2. 加密工具类
import json
from typing import Union, Tuple
class CryptographyManager:
    """Symmetric encryption, password hashing and random-token helpers.

    The Fernet key used by ``encrypt_data``/``decrypt_data`` is derived from
    ``master_key`` with PBKDF2-HMAC-SHA256.
    """

    # Salt the key derivation has always used. It remains the default so data
    # encrypted with an existing master key stays decryptable; deployments
    # should pass their own random, persisted salt instead.
    _LEGACY_KDF_SALT = b'mcp_salt_2024'
    # PBKDF2 iteration count shared by key derivation and password hashing.
    _KDF_ITERATIONS = 100000

    def __init__(self, master_key: Optional[str] = None,
                 salt: Optional[bytes] = None):
        """Derive the Fernet key.

        Args:
            master_key: Secret to derive the encryption key from. If empty or
                None, 32 random bytes are used, so the data can only be
                decrypted by this instance unless ``self.master_key`` is saved.
            salt: KDF salt. Defaults to the legacy fixed salt for backward
                compatibility.
        """
        if master_key:
            self.master_key = master_key.encode()
        else:
            self.master_key = os.urandom(32)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=self._LEGACY_KDF_SALT if salt is None else salt,
            iterations=self._KDF_ITERATIONS,
        )
        key = base64.urlsafe_b64encode(kdf.derive(self.master_key))
        self.fernet = Fernet(key)

    def encrypt_data(self, data: Union[str, dict, list]) -> str:
        """Encrypt ``data`` and return a URL-safe base64 string.

        dicts and lists are JSON-serialised first (non-ASCII kept as-is); any
        other non-str value is converted with ``str()``. The Fernet token is
        base64-encoded once more; ``decrypt_data`` expects that format.
        """
        if isinstance(data, (dict, list)):
            data = json.dumps(data, ensure_ascii=False)
        elif not isinstance(data, str):
            data = str(data)

        encrypted = self.fernet.encrypt(data.encode('utf-8'))
        return base64.urlsafe_b64encode(encrypted).decode('utf-8')

    def decrypt_data(self, encrypted_data: str) -> str:
        """Reverse ``encrypt_data`` and return the plaintext string.

        Raises:
            SecurityException: if the input is malformed, was tampered with,
                or was encrypted under a different key.
        """
        try:
            encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
            decrypted = self.fernet.decrypt(encrypted_bytes)
            return decrypted.decode('utf-8')
        except Exception as e:
            raise SecurityException(f"数据解密失败: {e}") from e

    def decrypt_json(self, encrypted_data: str) -> Union[dict, list]:
        """Decrypt data produced from a dict/list and parse it back from JSON.

        Raises:
            SecurityException: if decryption fails or the plaintext is not JSON.
        """
        decrypted_str = self.decrypt_data(encrypted_data)
        try:
            return json.loads(decrypted_str)
        except json.JSONDecodeError as e:
            raise SecurityException(f"JSON解密失败: {e}") from e

    def hash_password(self, password: str, salt: Optional[str] = None) -> Tuple[str, str]:
        """Hash ``password`` with PBKDF2-HMAC-SHA256.

        Args:
            password: Plaintext password.
            salt: Hex salt string; a fresh 16-byte random salt is generated
                when omitted.

        Returns:
            ``(password_hash, salt)``. The hash is URL-safe base64; both values
            must be stored to verify the password later.
        """
        if salt is None:
            salt = secrets.token_hex(16)

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt.encode(),
            iterations=self._KDF_ITERATIONS,
        )
        password_hash = base64.urlsafe_b64encode(kdf.derive(password.encode())).decode()
        return password_hash, salt

    def verify_password(self, password: str, password_hash: str, salt: str) -> bool:
        """Return True if ``password`` matches the stored hash and salt.

        Uses a constant-time comparison. Any error (e.g. a non-ASCII stored
        hash) is deliberately treated as a mismatch.
        """
        try:
            computed_hash, _ = self.hash_password(password, salt)
            return secrets.compare_digest(password_hash, computed_hash)
        except Exception:
            return False

    def generate_token(self, length: int = 32) -> str:
        """Return a URL-safe random token built from ``length`` random bytes."""
        return secrets.token_urlsafe(length)

    def generate_api_key(self) -> str:
        """Return a new random API key with the ``mcp_`` prefix."""
        return f"mcp_{secrets.token_urlsafe(32)}"
class DataMasking:
    """Helpers that partially hide personal data before display or logging."""

    @staticmethod
    def mask_email(email: str) -> str:
        """Mask the local part of an e-mail address, keeping its first and last char.

        Local parts of 1-2 characters are fully masked. Strings without '@'
        are returned unchanged.
        """
        if '@' not in email:
            return email

        local, domain = email.split('@', 1)
        if len(local) <= 2:
            masked_local = '*' * len(local)
        else:
            masked_local = local[0] + '*' * (len(local) - 2) + local[-1]

        return f"{masked_local}@{domain}"

    @staticmethod
    def mask_phone(phone: str) -> str:
        """Mask a phone number, keeping its first 3 and last 3 characters.

        Numbers of 6 characters or fewer are fully masked, because keeping
        3 + 3 characters of such a short value would reveal all of it.
        """
        if len(phone) <= 6:
            return '*' * len(phone)
        return phone[:3] + '*' * (len(phone) - 6) + phone[-3:]

    @staticmethod
    def mask_id_card(id_card: str) -> str:
        """Mask an ID-card number, keeping its first and last 4 characters.

        Values of 8 characters or fewer are fully masked.
        """
        if len(id_card) <= 8:
            return '*' * len(id_card)
        return id_card[:4] + '*' * (len(id_card) - 8) + id_card[-4:]

    @staticmethod
    def mask_credit_card(card_number: str) -> str:
        """Mask a card number, keeping its first and last 4 characters.

        Values of 8 characters or fewer are fully masked.
        """
        if len(card_number) <= 8:
            return '*' * len(card_number)
        return card_number[:4] + '*' * (len(card_number) - 8) + card_number[-4:]

    @staticmethod
    def mask_sensitive_data(data: Any, sensitive_fields: Set[str]) -> Any:
        """Recursively mask values stored under sensitive dict keys.

        Keys are matched case-insensitively against ``sensitive_fields``; only
        string keys can match. The masking style is picked from the key name:
        'email' -> mask_email, 'phone' -> mask_phone, 'card' -> mask_credit_card,
        'id' -> mask_id_card. Any other sensitive value is replaced by '*'
        repeated to its ``str()`` length. Lists are processed element-wise;
        other values pass through unchanged. A new structure is returned and
        the input is not modified.
        """
        # Built once per call instead of once per visited key.
        lowered_fields = {name.lower() for name in sensitive_fields}
        return DataMasking._mask_recursive(data, lowered_fields)

    @staticmethod
    def _mask_recursive(data: Any, lowered_fields: Set[str]) -> Any:
        """Walk dicts/lists and mask values whose lower-cased key is sensitive."""
        if isinstance(data, dict):
            masked_data = {}
            for key, value in data.items():
                key_lower = key.lower() if isinstance(key, str) else None
                if key_lower is not None and key_lower in lowered_fields:
                    masked_data[key] = DataMasking._mask_field(key_lower, str(value))
                else:
                    masked_data[key] = DataMasking._mask_recursive(value, lowered_fields)
            return masked_data
        if isinstance(data, list):
            return [DataMasking._mask_recursive(item, lowered_fields) for item in data]
        return data

    @staticmethod
    def _mask_field(key_lower: str, text: str) -> str:
        """Pick a masking routine for ``text`` from its (lower-cased) key name."""
        if 'email' in key_lower:
            return DataMasking.mask_email(text)
        if 'phone' in key_lower:
            return DataMasking.mask_phone(text)
        if 'card' in key_lower:
            return DataMasking.mask_credit_card(text)
        if 'id' in key_lower:
            return DataMasking.mask_id_card(text)
        return '*' * len(text)
身份认证系统
1. 认证管理器
import time
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod
class AuthenticationProvider(ABC):
    """Interface that every authentication backend implements.

    All operations are coroutines. Returning ``None`` signals that the
    credential or token was not accepted.
    """

    @abstractmethod
    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Check raw ``credentials``; return a SecurityContext, or None on failure."""
        ...

    @abstractmethod
    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Check a previously issued ``token``; return its context, or None."""
        ...

    @abstractmethod
    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """Exchange ``refresh_token`` for a new access token, or return None."""
        ...
class APIKeyAuthProvider(AuthenticationProvider):
    """Authenticate callers by API keys kept in an in-memory registry."""

    def __init__(self, crypto_manager: CryptographyManager):
        self.crypto_manager = crypto_manager
        # api_key -> metadata (owner, grants, expiry, usage counters)
        self.api_keys: Dict[str, Dict[str, Any]] = {}
        # user_id -> that user's api_keys
        self.user_keys: Dict[str, Set[str]] = {}

    @staticmethod
    def _session_id_for(api_key: str) -> str:
        """Derive a stable session id that contains no characters of the key.

        Keys from ``CryptographyManager.generate_api_key`` all share the
        ``mcp_`` prefix, so a raw ``key[:8]`` slice carried only 4 random
        characters (ids collided across keys) and copied secret material into
        session ids and access logs. A SHA-256 digest prefix avoids both.
        """
        digest = hashlib.sha256(api_key.encode('utf-8')).hexdigest()[:16]
        return f"api_{digest}"

    def create_api_key(self, user_id: str, roles: Set[str],
                       permissions: Set[Permission],
                       expires_in: Optional[timedelta] = None,
                       description: str = "") -> str:
        """Create, register and return a new API key for ``user_id``.

        ``roles`` and ``permissions`` are copied, so later changes to the
        caller's sets do not alter the key's grants. ``expires_in=None``
        creates a key that never expires.
        """
        api_key = self.crypto_manager.generate_api_key()
        now = datetime.now()
        expires_at = now + expires_in if expires_in else None

        self.api_keys[api_key] = {
            "user_id": user_id,
            "roles": set(roles),
            "permissions": set(permissions),
            "created_at": now,
            "expires_at": expires_at,
            "description": description,
            "last_used": None,
            "usage_count": 0
        }
        self.user_keys.setdefault(user_id, set()).add(api_key)
        return api_key

    def revoke_api_key(self, api_key: str) -> bool:
        """Delete ``api_key``; return True if it existed."""
        key_info = self.api_keys.pop(api_key, None)
        if key_info is None:
            return False

        user_id = key_info["user_id"]
        owned = self.user_keys.get(user_id)
        if owned is not None:
            owned.discard(api_key)
            if not owned:
                # Do not keep empty per-user sets around.
                del self.user_keys[user_id]
        return True

    def list_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Return metadata copies of ``user_id``'s keys, with the key truncated.

        Only the first 8 characters of each key are shown. The grant sets are
        copied so callers cannot modify the stored ones.
        """
        result = []
        for api_key in self.user_keys.get(user_id, set()):
            stored = self.api_keys.get(api_key)
            if stored is None:
                continue
            key_info = dict(stored)
            key_info["roles"] = set(stored["roles"])
            key_info["permissions"] = set(stored["permissions"])
            key_info["api_key"] = api_key[:8] + "..."
            result.append(key_info)
        return result

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Authenticate with ``credentials["api_key"]``; None if missing or invalid."""
        api_key = credentials.get("api_key")
        if not isinstance(api_key, str) or not api_key:
            return None
        return await self.validate_token(api_key)

    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Return a context for a registered, unexpired key, else None.

        Side effect: updates the key's ``last_used`` and ``usage_count``.
        The key's expiry is copied to ``context.expires_at`` so session
        tracking can expire it too.
        """
        key_info = self.api_keys.get(token)
        if key_info is None:
            return None

        now = datetime.now()
        expires_at = key_info["expires_at"]
        if expires_at and now > expires_at:
            return None

        key_info["last_used"] = now
        key_info["usage_count"] += 1

        return SecurityContext(
            user_id=key_info["user_id"],
            session_id=self._session_id_for(token),
            # Copies: mutating the context must not change the stored grants.
            roles=set(key_info["roles"]),
            permissions=set(key_info["permissions"]),
            authentication_method=AuthenticationMethod.API_KEY,
            authenticated_at=now,
            expires_at=expires_at
        )

    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """API keys cannot be refreshed; always returns None."""
        return None
class JWTAuthProvider(AuthenticationProvider):
    """Issue and verify JWT access/refresh token pairs signed with a shared secret."""

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        self.secret_key = secret_key
        self.algorithm = algorithm
        # Raw token strings revoked via blacklist_token(); checked before decoding.
        # NOTE(review): entries are never pruned, so this set grows without bound.
        self.blacklisted_tokens: Set[str] = set()

    def create_jwt_token(self, user_id: str, roles: Set[str],
                         permissions: Set[Permission],
                         expires_in: timedelta = timedelta(hours=24)) -> Dict[str, Any]:
        """Create a signed access token plus a 30-day refresh token.

        Both payloads carry a random ``jti`` claim, so tokens issued to the same
        user in the same second are still distinguishable.

        Returns:
            Dict with ``access_token``, ``refresh_token``, ``token_type``
            ("Bearer") and ``expires_in`` (access-token lifetime in seconds).
        """
        now = datetime.now()
        issued_at = int(now.timestamp())

        access_payload = {
            "user_id": user_id,
            "roles": list(roles),
            "permissions": [p.value for p in permissions],
            "iat": issued_at,
            "exp": int((now + expires_in).timestamp()),
            "type": "access",
            "jti": secrets.token_hex(16)
        }
        refresh_payload = {
            "user_id": user_id,
            "iat": issued_at,
            "exp": int((now + timedelta(days=30)).timestamp()),
            "type": "refresh",
            "jti": secrets.token_hex(16)
        }

        access_token = jwt.encode(access_payload, self.secret_key, algorithm=self.algorithm)
        refresh_token = jwt.encode(refresh_payload, self.secret_key, algorithm=self.algorithm)

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            "expires_in": int(expires_in.total_seconds())
        }

    def blacklist_token(self, token: str):
        """Revoke ``token`` (access or refresh) for this provider instance."""
        self.blacklisted_tokens.add(token)

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[SecurityContext]:
        """Username/password login. PLACEHOLDER: do not deploy as-is.

        WARNING: the password is never checked. Any non-empty username and
        password pair is accepted and granted READ/WRITE. Replace this body with
        a real credential lookup, e.g. ``CryptographyManager.verify_password``
        against a stored hash, before use.
        """
        username = credentials.get("username")
        password = credentials.get("password")
        if not (username and password):
            return None

        return SecurityContext(
            user_id=username,
            # Random suffix: a per-second timestamp collided for concurrent logins.
            session_id=f"jwt_{secrets.token_hex(8)}",
            roles={"user"},
            permissions={Permission.READ, Permission.WRITE},
            authentication_method=AuthenticationMethod.JWT_TOKEN
        )

    async def validate_token(self, token: str) -> Optional[SecurityContext]:
        """Verify an access token and build its SecurityContext.

        Returns None for blacklisted tokens and for non-access (e.g. refresh)
        tokens.

        Raises:
            AuthenticationException: expired, badly signed, or malformed token,
                including missing/invalid claims.
        """
        if token in self.blacklisted_tokens:
            return None

        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as e:
            raise AuthenticationException("令牌已过期") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationException("无效的令牌") from e

        if payload.get("type") != "access":
            return None

        try:
            issued_at = payload["iat"]
            # Prefer the unique jti; tokens minted before jti existed fall back to iat.
            session_key = payload.get("jti", issued_at)
            user_id = payload["user_id"]
            roles = set(payload.get("roles", []))
            permissions = {Permission(p) for p in payload.get("permissions", [])}
            authenticated_at = datetime.fromtimestamp(issued_at)
            expires_at = datetime.fromtimestamp(payload["exp"])
        except (KeyError, TypeError, ValueError, OverflowError, OSError) as e:
            raise AuthenticationException("无效的令牌") from e

        return SecurityContext(
            user_id=user_id,
            session_id=f"jwt_{session_key}",
            roles=roles,
            permissions=permissions,
            authentication_method=AuthenticationMethod.JWT_TOKEN,
            authenticated_at=authenticated_at,
            expires_at=expires_at
        )

    async def refresh_token(self, refresh_token: str) -> Optional[str]:
        """Exchange a valid refresh token for a new access token.

        Returns None for blacklisted tokens and for non-refresh tokens.
        NOTE: the new token always gets role "user" with READ/WRITE; a real
        deployment should load the user's current grants instead.

        Raises:
            AuthenticationException: expired, invalid or malformed refresh token.
        """
        # Revoked refresh tokens must not mint new access tokens.
        if refresh_token in self.blacklisted_tokens:
            return None

        try:
            payload = jwt.decode(refresh_token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as e:
            raise AuthenticationException("刷新令牌已过期") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationException("无效的刷新令牌") from e

        if payload.get("type") != "refresh":
            return None

        user_id = payload.get("user_id")
        if not user_id:
            raise AuthenticationException("无效的刷新令牌")

        new_tokens = self.create_jwt_token(
            user_id=user_id,
            roles={"user"},
            permissions={Permission.READ, Permission.WRITE}
        )
        return new_tokens["access_token"]
class AuthenticationManager:
    """Dispatch authentication to registered providers and track sessions.

    Also throttles brute force per IP address: once ``max_failed_attempts``
    failures fall within ``lockout_duration``, further attempts from that IP
    are rejected until older failures age out.
    """

    def __init__(self):
        self.providers: Dict[AuthenticationMethod, AuthenticationProvider] = {}
        # session_id -> context of each successful authenticate() call
        self.active_sessions: Dict[str, SecurityContext] = {}
        # ip -> timestamps of recent failed attempts
        self.failed_attempts: Dict[str, List[datetime]] = {}
        self.max_failed_attempts = 5
        self.lockout_duration = timedelta(minutes=15)

    def register_provider(self, method: AuthenticationMethod, provider: AuthenticationProvider):
        """Use ``provider`` for ``method``, replacing any earlier registration."""
        self.providers[method] = provider

    def is_ip_locked(self, ip_address: str) -> bool:
        """Return True if ``ip_address`` has too many recent failed attempts.

        Side effect: failures older than ``lockout_duration`` are pruned.
        """
        attempts = self.failed_attempts.get(ip_address)
        if not attempts:
            return False

        now = datetime.now()
        recent_attempts = [t for t in attempts if now - t < self.lockout_duration]
        if recent_attempts:
            self.failed_attempts[ip_address] = recent_attempts
        else:
            # Drop idle entries so the dict does not grow with every IP ever seen.
            del self.failed_attempts[ip_address]
        return len(recent_attempts) >= self.max_failed_attempts

    def record_failed_attempt(self, ip_address: str):
        """Record one failed attempt from ``ip_address`` at the current time."""
        self.failed_attempts.setdefault(ip_address, []).append(datetime.now())

    def clear_failed_attempts(self, ip_address: str):
        """Forget all failed attempts from ``ip_address``."""
        self.failed_attempts.pop(ip_address, None)

    async def authenticate(self, method: AuthenticationMethod,
                           credentials: Dict[str, Any],
                           ip_address: Optional[str] = None) -> SecurityContext:
        """Authenticate ``credentials`` with the provider for ``method``.

        On success the context is stamped with ``ip_address``, stored in
        ``active_sessions`` and that IP's failure history is cleared. Each
        failed attempt is recorded exactly once for ``ip_address``.

        Raises:
            AuthenticationException: the IP is locked out, the method has no
                provider, or the credentials are rejected.
        """
        if ip_address and self.is_ip_locked(ip_address):
            raise AuthenticationException(f"IP地址 {ip_address} 已被锁定")

        provider = self.providers.get(method)
        if not provider:
            raise AuthenticationException(f"不支持的认证方法: {method}")

        try:
            context = await provider.authenticate(credentials)
        except AuthenticationException:
            if ip_address:
                self.record_failed_attempt(ip_address)
            raise

        # Recorded here only. Previously this path was also counted by the
        # except-handler, so lockout triggered after about half the intended attempts.
        if not context:
            if ip_address:
                self.record_failed_attempt(ip_address)
            raise AuthenticationException("认证失败")

        if ip_address:
            context.ip_address = ip_address
            self.clear_failed_attempts(ip_address)

        self.active_sessions[context.session_id] = context
        return context

    async def validate_session(self, session_id: str) -> Optional[SecurityContext]:
        """Return the live context for ``session_id``, or None.

        An expired session is removed and None is returned.
        """
        context = self.active_sessions.get(session_id)
        if not context:
            return None

        if context.expires_at and datetime.now() > context.expires_at:
            self.active_sessions.pop(session_id, None)
            return None

        return context

    async def validate_token(self, method: AuthenticationMethod, token: str) -> SecurityContext:
        """Validate ``token`` with the provider for ``method``.

        Raises:
            AuthenticationException: unknown method or token rejected.
        """
        provider = self.providers.get(method)
        if not provider:
            raise AuthenticationException(f"不支持的认证方法: {method}")

        context = await provider.validate_token(token)
        if not context:
            raise AuthenticationException("令牌验证失败")

        return context

    def logout(self, session_id: str) -> bool:
        """End ``session_id``; return True if it was active."""
        return self.active_sessions.pop(session_id, None) is not None

    def get_active_sessions(self, user_id: Optional[str] = None) -> List[SecurityContext]:
        """List stored sessions, optionally only those of ``user_id``.

        Sessions are not checked for expiry here.
        """
        sessions = list(self.active_sessions.values())
        if user_id:
            sessions = [s for s in sessions if s.user_id == user_id]
        return sessions

    def cleanup_expired_sessions(self) -> int:
        """Remove sessions whose ``expires_at`` has passed; return how many."""
        now = datetime.now()
        expired_sessions = [
            session_id
            for session_id, context in self.active_sessions.items()
            if context.expires_at and now > context.expires_at
        ]
        for session_id in expired_sessions:
            del self.active_sessions[session_id]
        return len(expired_sessions)
授权控制系统
1. 基于角色的访问控制(RBAC)
from typing import Set, Dict, List, Optional, Any
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
@dataclass
class Role:
    """A named bundle of permissions that may inherit from other roles."""

    name: str
    description: str
    # Permissions granted directly by this role.
    permissions: Set[Permission] = field(default_factory=set)
    # Names (not objects) of roles whose permissions this role inherits.
    parent_roles: Set[str] = field(default_factory=set)
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Creation / last-modification times (local, naive datetimes).
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
@dataclass
class Resource:
    """A protected object, keyed by URI, that authorization decisions refer to."""

    # Mandatory descriptive fields.
    uri: str
    name: str
    resource_type: str
    security_level: SecurityLevel
    # Optional: owner's user id, permission requirements and free-form data.
    owner: Optional[str] = None
    required_permissions: Set[Permission] = field(default_factory=set)
    metadata: Dict[str, Any] = field(default_factory=dict)
class AuthorizationProvider(ABC):
    """Contract for pluggable authorization backends; every method is a coroutine."""

    @abstractmethod
    async def check_permission(self, context: SecurityContext,
                               resource: str, permission: Permission) -> bool:
        """Return True if ``context`` may use ``permission`` on the resource URI ``resource``."""
        ...

    @abstractmethod
    async def get_user_permissions(self, user_id: str) -> Set[Permission]:
        """Return every permission that ``user_id`` holds."""
        ...

    @abstractmethod
    async def get_resource_permissions(self, resource_uri: str) -> Set[Permission]:
        """Return the permissions declared as required for ``resource_uri``."""
        ...
class RBACAuthorizationProvider(AuthorizationProvider):
    """Role-based authorization with role inheritance and resource ownership.

    Everything is held in memory: role definitions, user -> role assignments,
    registered resources, and a cache of each role's transitive ancestors.
    """

    def __init__(self):
        self.roles: Dict[str, Role] = {}
        self.user_roles: Dict[str, Set[str]] = {}  # user_id -> role_names
        self.resources: Dict[str, Resource] = {}
        # role name -> that role plus every role it (transitively) inherits
        self.role_hierarchy_cache: Dict[str, Set[str]] = {}

    @staticmethod
    def _level_rank(level: SecurityLevel) -> int:
        """Return ``level``'s position in SecurityLevel declaration order (PUBLIC == 0).

        ``.value`` strings must not be compared directly: they sort
        alphabetically, so e.g. "public" > "confidential".
        """
        return list(SecurityLevel).index(level)

    def create_role(self, name: str, description: str,
                    permissions: Set[Permission],
                    parent_roles: Optional[Set[str]] = None) -> Role:
        """Create and register a role.

        The given sets are copied, so later changes to them by the caller do
        not affect the role.

        Raises:
            ValueError: a role named ``name`` already exists.
        """
        if name in self.roles:
            raise ValueError(f"角色已存在: {name}")

        role = Role(
            name=name,
            description=description,
            permissions=set(permissions),
            parent_roles=set(parent_roles) if parent_roles else set()
        )
        self.roles[name] = role
        self._invalidate_hierarchy_cache()
        return role

    def update_role(self, name: str, permissions: Optional[Set[Permission]] = None,
                    parent_roles: Optional[Set[str]] = None) -> bool:
        """Replace a role's permissions and/or parents; False if the role is unknown."""
        role = self.roles.get(name)
        if role is None:
            return False

        if permissions is not None:
            role.permissions = set(permissions)
        if parent_roles is not None:
            role.parent_roles = set(parent_roles)

        role.updated_at = datetime.now()
        self._invalidate_hierarchy_cache()
        return True

    def delete_role(self, name: str) -> bool:
        """Delete an unused role; False if it does not exist.

        Raises:
            ValueError: the role is still assigned to a user or inherited by
                another role.
        """
        if name not in self.roles:
            return False

        if any(name in assigned for assigned in self.user_roles.values()):
            raise ValueError(f"角色 {name} 正在被用户使用,无法删除")

        if any(name in role.parent_roles for role in self.roles.values()):
            raise ValueError(f"角色 {name} 被其他角色继承,无法删除")

        del self.roles[name]
        self._invalidate_hierarchy_cache()
        return True

    def assign_role_to_user(self, user_id: str, role_name: str) -> bool:
        """Give ``user_id`` the role ``role_name``; False if the role is unknown."""
        if role_name not in self.roles:
            return False
        self.user_roles.setdefault(user_id, set()).add(role_name)
        return True

    def revoke_role_from_user(self, user_id: str, role_name: str) -> bool:
        """Remove ``role_name`` from ``user_id``.

        Returns False only if the user has no roles at all; otherwise True,
        even when the user did not hold ``role_name``.
        """
        if user_id not in self.user_roles:
            return False

        self.user_roles[user_id].discard(role_name)
        if not self.user_roles[user_id]:
            del self.user_roles[user_id]

        return True

    def register_resource(self, resource: Resource):
        """Register (or replace) ``resource`` under its URI."""
        self.resources[resource.uri] = resource

    def unregister_resource(self, resource_uri: str) -> bool:
        """Remove a resource; return True if it was registered."""
        return self.resources.pop(resource_uri, None) is not None

    def _get_all_role_permissions(self, role_name: str) -> Set[Permission]:
        """Return a role's permissions including those inherited from ancestors.

        Raises:
            ValueError: the role hierarchy contains a cycle.
        """
        all_roles = self.role_hierarchy_cache.get(role_name)
        if all_roles is None:
            all_roles = self._compute_role_hierarchy(role_name)
            self.role_hierarchy_cache[role_name] = all_roles

        permissions: Set[Permission] = set()
        for member in all_roles:
            role = self.roles.get(member)
            if role is not None:
                permissions.update(role.permissions)
        return permissions

    def _compute_role_hierarchy(self, role_name: str, visited: Optional[Set[str]] = None) -> Set[str]:
        """Return ``role_name`` plus all roles it inherits from, transitively.

        ``visited`` holds the roles on the current inheritance path, so a role
        reached twice along one path is a cycle. Diamonds are fine because each
        branch gets its own copy. Unknown role names contribute nothing.

        Raises:
            ValueError: an inheritance cycle is detected.
        """
        if visited is None:
            visited = set()

        if role_name in visited:
            raise ValueError(f"检测到角色循环依赖: {role_name}")

        if role_name not in self.roles:
            return set()

        visited.add(role_name)
        all_roles = {role_name}

        for parent_role in self.roles[role_name].parent_roles:
            all_roles.update(self._compute_role_hierarchy(parent_role, visited.copy()))

        return all_roles

    def _invalidate_hierarchy_cache(self):
        """Drop cached hierarchies after any change to role definitions."""
        self.role_hierarchy_cache.clear()

    async def check_permission(self, context: SecurityContext,
                               resource_uri: str, permission: Permission) -> bool:
        """Decide whether ``context`` may use ``permission`` on ``resource_uri``.

        Evaluation order:
        1. The owner of a registered resource is always allowed.
        2. A registered resource whose security level exceeds the caller's
           ``context.security_level`` is denied, whatever permissions the
           caller holds.
        3. Otherwise access is granted if ``permission`` is among the
           context's direct permissions or the permissions of the user's roles,
           including inherited ones.

        The previous version returned on grants before reaching the resource
        checks, so security levels were never enforced (and were compared as
        strings). ``resource.required_permissions`` is not consulted here; it
        is exposed via ``get_resource_permissions``.
        """
        resource = self.resources.get(resource_uri)
        if resource is not None:
            if resource.owner is not None and resource.owner == context.user_id:
                return True
            if self._level_rank(context.security_level) < self._level_rank(resource.security_level):
                return False

        if permission in context.permissions:
            return True

        user_permissions = await self.get_user_permissions(context.user_id)
        return permission in user_permissions

    async def get_user_permissions(self, user_id: str) -> Set[Permission]:
        """Return the union of permissions from all of ``user_id``'s roles."""
        permissions: Set[Permission] = set()
        for role_name in self.user_roles.get(user_id, set()):
            permissions.update(self._get_all_role_permissions(role_name))
        return permissions

    async def get_resource_permissions(self, resource_uri: str) -> Set[Permission]:
        """Return a copy of a resource's required permissions (empty if unknown)."""
        resource = self.resources.get(resource_uri)
        if resource:
            return set(resource.required_permissions)
        return set()

    def get_user_roles(self, user_id: str) -> Set[str]:
        """Return a copy of the role names assigned to ``user_id``."""
        return self.user_roles.get(user_id, set()).copy()

    def get_role_users(self, role_name: str) -> Set[str]:
        """Return the ids of users directly assigned ``role_name``."""
        return {user_id for user_id, roles in self.user_roles.items() if role_name in roles}

    def export_rbac_config(self) -> Dict[str, Any]:
        """Serialise roles, assignments and resources to JSON-compatible data.

        Role and resource ``metadata`` are not included.
        """
        return {
            "roles": {
                name: {
                    "description": role.description,
                    "permissions": [p.value for p in role.permissions],
                    "parent_roles": list(role.parent_roles),
                    "created_at": role.created_at.isoformat(),
                    "updated_at": role.updated_at.isoformat()
                }
                for name, role in self.roles.items()
            },
            "user_roles": {
                user_id: list(roles)
                for user_id, roles in self.user_roles.items()
            },
            "resources": {
                uri: {
                    "name": resource.name,
                    "resource_type": resource.resource_type,
                    "security_level": resource.security_level.value,
                    "owner": resource.owner,
                    "required_permissions": [p.value for p in resource.required_permissions]
                }
                for uri, resource in self.resources.items()
            }
        }

    def import_rbac_config(self, config: Dict[str, Any]):
        """Load data in the ``export_rbac_config`` format, overwriting same-named entries.

        Raises:
            ValueError: unknown permission / security-level values or bad
                ISO timestamps.
        """
        for name, role_data in config.get("roles", {}).items():
            now_iso = datetime.now().isoformat()
            self.roles[name] = Role(
                name=name,
                description=role_data.get("description", ""),
                permissions={Permission(p) for p in role_data.get("permissions", [])},
                parent_roles=set(role_data.get("parent_roles", [])),
                created_at=datetime.fromisoformat(role_data.get("created_at", now_iso)),
                updated_at=datetime.fromisoformat(role_data.get("updated_at", now_iso))
            )

        for user_id, roles in config.get("user_roles", {}).items():
            self.user_roles[user_id] = set(roles)

        for uri, resource_data in config.get("resources", {}).items():
            self.resources[uri] = Resource(
                uri=uri,
                name=resource_data.get("name", ""),
                resource_type=resource_data.get("resource_type", ""),
                security_level=SecurityLevel(resource_data.get("security_level", "public")),
                owner=resource_data.get("owner"),
                required_permissions={
                    Permission(p) for p in resource_data.get("required_permissions", [])
                }
            )

        self._invalidate_hierarchy_cache()
2. 授权管理器
class AuthorizationManager:
    """Combine security policies with authorization providers and keep an access log."""

    def __init__(self):
        self.providers: Dict[str, AuthorizationProvider] = {}
        self.policies: Dict[str, SecurityPolicy] = {}
        # In-memory audit trail; trimmed to the newest half once it exceeds max_log_entries.
        self.access_logs: List[Dict[str, Any]] = []
        self.max_log_entries = 10000

    @staticmethod
    def _level_rank(level: SecurityLevel) -> int:
        """Return ``level``'s position in SecurityLevel declaration order (PUBLIC == 0).

        ``.value`` strings must not be compared directly: they sort
        alphabetically, so e.g. "public" > "confidential".
        """
        return list(SecurityLevel).index(level)

    def register_provider(self, name: str, provider: AuthorizationProvider):
        """Register (or replace) an authorization provider under ``name``."""
        self.providers[name] = provider

    def register_policy(self, policy: SecurityPolicy):
        """Register (or replace) ``policy`` under its name."""
        self.policies[policy.name] = policy

    def get_policy(self, policy_name: str) -> Optional[SecurityPolicy]:
        """Return the policy named ``policy_name``, or None."""
        return self.policies.get(policy_name)

    async def check_access(self, context: SecurityContext, resource_uri: str,
                           permission: Permission, policy_name: Optional[str] = None) -> bool:
        """Decide whether ``context`` may use ``permission`` on ``resource_uri``.

        If ``policy_name`` is given, that policy must exist and pass; an
        unknown policy name denies access (fail closed). Access is then granted
        if any registered provider grants it. Every decision is logged, and any
        unexpected error denies access.
        """
        try:
            if policy_name:
                policy = self.get_policy(policy_name)
                if policy is None:
                    self._log_access(context, resource_uri, permission, False,
                                     f"策略不存在: {policy_name}")
                    return False
                if not await self._check_policy(context, policy):
                    self._log_access(context, resource_uri, permission, False, "策略检查失败")
                    return False

            access_granted = False
            for provider in self.providers.values():
                if await provider.check_permission(context, resource_uri, permission):
                    access_granted = True
                    break

            self._log_access(context, resource_uri, permission, access_granted)
            return access_granted

        except Exception as e:
            self._log_access(context, resource_uri, permission, False, f"检查失败: {e}")
            return False

    async def _check_policy(self, context: SecurityContext, policy: SecurityPolicy) -> bool:
        """Return True only if ``context`` satisfies every constraint in ``policy``.

        Fails closed: a configured IP whitelist rejects contexts without an IP,
        and ``require_mfa`` demands MULTI_FACTOR authentication.
        NOTE(review): ``policy.rate_limit`` is not enforced here.
        """
        if self._level_rank(context.security_level) < self._level_rank(policy.required_security_level):
            return False

        if context.authentication_method not in policy.allowed_authentication_methods:
            return False

        if not policy.required_permissions.issubset(context.permissions):
            return False

        if policy.ip_whitelist:
            # An unknown address can never be shown to be on the whitelist.
            if not context.ip_address or context.ip_address not in policy.ip_whitelist:
                return False

        if policy.ip_blacklist and context.ip_address:
            if context.ip_address in policy.ip_blacklist:
                return False

        if policy.session_timeout and context.authenticated_at:
            if datetime.now() - context.authenticated_at > policy.session_timeout:
                return False

        # Previously a no-op, which silently let single-factor logins through.
        if policy.require_mfa and context.authentication_method != AuthenticationMethod.MULTI_FACTOR:
            return False

        for rule in policy.custom_rules:
            try:
                if not rule(context):
                    return False
            except Exception as e:
                print(f"自定义规则执行失败: {e}")
                return False

        return True

    def _log_access(self, context: SecurityContext, resource_uri: str,
                    permission: Permission, granted: bool, reason: str = ""):
        """Append one decision to ``access_logs``, trimming it when over capacity."""
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "user_id": context.user_id,
            "session_id": context.session_id,
            "resource_uri": resource_uri,
            "permission": permission.value,
            "granted": granted,
            "reason": reason,
            "ip_address": context.ip_address,
            "user_agent": context.user_agent
        }

        self.access_logs.append(log_entry)

        if len(self.access_logs) > self.max_log_entries:
            self.access_logs = self.access_logs[-self.max_log_entries // 2:]

    def get_access_logs(self, user_id: Optional[str] = None,
                        resource_uri: Optional[str] = None,
                        start_time: Optional[datetime] = None,
                        end_time: Optional[datetime] = None,
                        granted_only: Optional[bool] = None) -> List[Dict[str, Any]]:
        """Return log entries matching every filter that was supplied.

        Time bounds are inclusive. They are compared as ISO-8601 strings,
        which order correctly for the naive local timestamps written by
        ``_log_access``. ``granted_only`` selects granted (True) or denied
        (False) entries.
        """
        start_iso = start_time.isoformat() if start_time else None
        end_iso = end_time.isoformat() if end_time else None

        def matches(log: Dict[str, Any]) -> bool:
            if user_id and log["user_id"] != user_id:
                return False
            if resource_uri and log["resource_uri"] != resource_uri:
                return False
            if start_iso and log["timestamp"] < start_iso:
                return False
            if end_iso and log["timestamp"] > end_iso:
                return False
            if granted_only is not None and log["granted"] != granted_only:
                return False
            return True

        return [log for log in self.access_logs if matches(log)]

    def get_access_statistics(self) -> Dict[str, Any]:
        """Summarise the access log overall, per user and per resource."""
        if not self.access_logs:
            return {"total_requests": 0}

        total_requests = len(self.access_logs)
        granted_requests = sum(1 for log in self.access_logs if log["granted"])
        denied_requests = total_requests - granted_requests

        user_stats: Dict[str, Dict[str, int]] = {}
        resource_stats: Dict[str, Dict[str, int]] = {}
        for log in self.access_logs:
            outcome = "granted" if log["granted"] else "denied"
            for stats, key in ((user_stats, log["user_id"]),
                               (resource_stats, log["resource_uri"])):
                bucket = stats.setdefault(key, {"total": 0, "granted": 0, "denied": 0})
                bucket["total"] += 1
                bucket[outcome] += 1

        return {
            "total_requests": total_requests,
            "granted_requests": granted_requests,
            "denied_requests": denied_requests,
            "success_rate": granted_requests / total_requests,
            "user_statistics": user_stats,
            "resource_statistics": resource_stats
        }
本章总结
本章详细介绍了MCP协议中的安全性与权限控制机制,包括:
核心内容
安全架构概述
- 多层防护策略
- 安全模型和上下文定义
- 加密工具和数据脱敏
身份认证系统
- 多种认证方法支持(已实现API密钥与JWT认证提供者;OAuth2、证书与多因子认证已在枚举中预留,可按同一提供者接口扩展)
- 认证提供者架构
- 会话管理和安全控制
授权控制系统
- 基于角色的访问控制(RBAC)
- 资源权限管理
- 安全策略引擎
审计和监控
- 访问日志记录
- 安全事件监控
- 统计分析功能
最佳实践
安全设计原则
- 最小权限原则
- 深度防御策略
- 零信任架构
认证安全
- 强密码策略
- 多因子认证
- 令牌安全管理
授权控制
- 细粒度权限控制
- 动态权限评估
- 权限继承和委托
监控和审计
- 实时安全监控
- 完整审计日志
- 异常行为检测
下一章我们将学习性能优化与监控,了解如何优化MCP协议的性能并实现有效的监控机制。