工具设计原则
1. 单一职责原则
每个工具应该只负责一个明确的功能,这样可以提高可维护性和可测试性。
好的设计:
# Separated tool design: one handler per file operation.
# NOTE(review): in the MCP Python SDK's low-level Server, `call_tool()` appears to
# register a single dispatcher per server, so decorating several functions this way
# would likely keep only the last one — confirm against the SDK version in use.
@server.call_tool()
async def handle_read_file(arguments):
    """Only responsible for reading a file (illustrative stub)."""
    pass
@server.call_tool()
async def handle_write_file(arguments):
    """Only responsible for writing a file (illustrative stub)."""
    pass
@server.call_tool()
async def handle_delete_file(arguments):
    """Only responsible for deleting a file (illustrative stub)."""
    pass
不好的设计:
# Tool design with mixed responsibilities (anti-pattern shown for contrast).
@server.call_tool()
async def handle_file_operations(arguments):
    """One tool that handles every file operation (anti-pattern example).

    Dispatches on ``arguments['operation']``; any value other than
    'read', 'write' or 'delete' matches no branch and the function
    returns None.
    """
    operation = arguments.get('operation')
    if operation == 'read':
        # read logic
        pass
    elif operation == 'write':
        # write logic
        pass
    elif operation == 'delete':
        # delete logic
        pass
    # This design is hard to maintain and test.
2. 接口一致性原则
所有工具应该遵循一致的接口设计模式。
统一的参数结构:
import json
from functools import wraps
from typing import Optional, Dict, Any, Tuple

from pydantic import BaseModel, Field
class BaseToolInput(BaseModel):
    """Common ancestor for every tool input model.

    It declares no fields of its own; subclasses add the arguments
    specific to each tool.
    """
class FileOperationInput(BaseToolInput):
    """Arguments shared by file tools: a required path and a text encoding."""

    path: str = Field(..., description="文件路径")
    encoding: Optional[str] = Field("utf-8", description="文件编码")
class NetworkRequestInput(BaseToolInput):
    """Arguments for an HTTP request: URL, method, headers and timeout.

    NOTE(review): a second, stricter ``NetworkRequestInput`` is defined later
    in this file and rebinds this module-level name, so code resolving the
    name at call time gets that later model — confirm which is intended.
    """

    url: str = Field(..., description="请求URL")
    method: Optional[str] = Field("GET", description="HTTP方法")
    headers: Optional[Dict[str, str]] = Field(None, description="请求头")
    timeout: Optional[int] = Field(30, description="超时时间(秒)")
class DatabaseQueryInput(BaseToolInput):
    """Arguments for a database query: SQL text, bind parameters and a row limit.

    NOTE(review): a second ``DatabaseQueryInput`` with validation is defined
    later in this file and rebinds this name — confirm which is intended.
    """

    query: str = Field(..., description="SQL查询语句")
    parameters: Optional[Dict[str, Any]] = Field(None, description="查询参数")
    limit: Optional[int] = Field(100, description="结果限制")
统一的返回结构:
from mcp.types import CallToolResult, TextContent, ImageContent
from typing import List, Union, Optional
class ToolResponse:
    """Helpers that build ``CallToolResult`` objects in one consistent shape."""

    @staticmethod
    def success(message: str, data: Optional[Dict[str, Any]] = None) -> CallToolResult:
        """Build a successful result.

        The first text item is *message*; when *data* is truthy, a second
        text item carries it pretty-printed as JSON with non-ASCII kept.
        """
        items = [TextContent(type="text", text=message)]
        if data:
            items.append(
                TextContent(
                    type="text",
                    text=f"数据: {json.dumps(data, ensure_ascii=False, indent=2)}",
                )
            )
        return CallToolResult(content=items, isError=False)

    @staticmethod
    def error(message: str, error_code: Optional[str] = None) -> CallToolResult:
        """Build an error result with a single text item.

        The text is the prefixed *message*, followed by the error code in
        parentheses when *error_code* is truthy.
        """
        suffix = f" (错误码: {error_code})" if error_code else ""
        return CallToolResult(
            content=[TextContent(type="text", text=f"错误: {message}" + suffix)],
            isError=True,
        )

    @staticmethod
    def with_file_content(file_path: str, content: str,
                          content_type: str = "text") -> CallToolResult:
        """Build a result that embeds a file's text under a header line.

        Only ``content_type == "text"`` embeds *content*. Any other type
        currently yields a plain success message without the content.
        """
        if content_type != "text":
            # Placeholder for future non-text content types.
            return ToolResponse.success(f"文件读取成功: {file_path}")
        body = TextContent(type="text", text=f"文件: {file_path}\n\n{content}")
        return CallToolResult(content=[body])
3. 错误处理原则
分层错误处理:
import logging
from enum import Enum
from typing import Optional
class ErrorLevel(Enum):
    """Severity attached to a ToolError."""

    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


class ToolError(Exception):
    """Base class for errors raised by tools.

    Args:
        message: Human-readable description; also the exception's ``str()``.
        error_code: Optional machine-readable code such as "VALIDATION_ERROR".
        level: Severity; defaults to ``ErrorLevel.ERROR``.
        details: Optional extra context. It is copied, so later updates made
            by subclasses never mutate a dict the caller still holds.
    """

    def __init__(self, message: str, error_code: Optional[str] = None,
                 level: ErrorLevel = ErrorLevel.ERROR,
                 details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.level = level
        # Copy: subclasses call self.details.update(...), which must not leak
        # back into the caller's dict.
        self.details = dict(details) if details else {}


class ValidationError(ToolError):
    """Invalid tool arguments (code VALIDATION_ERROR, severity WARNING).

    Args:
        message: Description of the problem.
        field: Optional name of the offending argument, stored in ``details``.
    """

    def __init__(self, message: str, field: Optional[str] = None):
        super().__init__(message, "VALIDATION_ERROR", ErrorLevel.WARNING)
        if field:
            self.details["field"] = field


class ResourceNotFoundError(ToolError):
    """A named resource does not exist (code RESOURCE_NOT_FOUND)."""

    def __init__(self, resource_type: str, resource_id: str):
        message = f"{resource_type} '{resource_id}' 未找到"
        super().__init__(message, "RESOURCE_NOT_FOUND", ErrorLevel.ERROR)
        self.details.update({
            "resource_type": resource_type,
            "resource_id": resource_id,
        })


class PermissionError(ToolError):
    """An operation is not permitted on a resource (code PERMISSION_DENIED).

    NOTE: this name shadows the builtin ``PermissionError`` for the rest of
    the module, so a bare ``except PermissionError`` below this point catches
    THIS class, not OS permission failures. Catch those as
    ``builtins.PermissionError``. New code can use the non-shadowing alias
    ``ToolPermissionError``.
    """

    def __init__(self, operation: str, resource: str):
        message = f"没有权限执行操作 '{operation}' 在资源 '{resource}' 上"
        super().__init__(message, "PERMISSION_DENIED", ErrorLevel.ERROR)
        self.details.update({
            "operation": operation,
            "resource": resource,
        })


# Non-shadowing alias for the tool-level permission error.
ToolPermissionError = PermissionError
# Decorator that converts exceptions raised by tools into error results.
def handle_tool_errors(func):
    """Wrap an async tool entry point so it returns an error result instead of raising.

    * ToolError subclasses are logged, with code, level and details passed as
      logging ``extra``. They become an error result carrying the exception's
      own message and code.
    * Any other Exception is logged with its traceback and reported as
      ``INTERNAL_ERROR``.
    """
    @wraps(func)
    async def guarded(*args, **kwargs):
        try:
            outcome = await func(*args, **kwargs)
        except ToolError as tool_err:
            log_extra = {
                "error_code": tool_err.error_code,
                "level": tool_err.level.value,
                "details": tool_err.details,
            }
            logging.error(f"工具错误: {tool_err.message}", extra=log_extra)
            return ToolResponse.error(tool_err.message, tool_err.error_code)
        except Exception as unexpected:
            logging.exception(f"未处理的工具错误: {str(unexpected)}")
            return ToolResponse.error(f"内部错误: {str(unexpected)}", "INTERNAL_ERROR")
        return outcome

    return guarded
高级工具开发模式
1. 工具工厂模式
from abc import ABC, abstractmethod
from typing import Type, Dict, Any
class ToolInterface(ABC):
    """Abstract contract that every tool implements.

    A tool exposes a ``name`` and ``description``, a JSON schema for its
    arguments via ``get_schema()``, and an async ``execute()`` that returns
    a ``CallToolResult``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name used to register and look up the tool."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable summary of what the tool does."""
        ...

    @abstractmethod
    def get_schema(self) -> Dict[str, Any]:
        """Return the JSON schema describing the tool's arguments."""
        ...

    @abstractmethod
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run the tool with *arguments* and return its result."""
        ...
class FileReadTool(ToolInterface):
    """Tool ``read_file``: return the text content of a file."""

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return "读取文件内容"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "文件路径"
                },
                "encoding": {
                    "type": "string",
                    "description": "文件编码",
                    "default": "utf-8"
                }
            },
            "required": ["path"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Read ``arguments['path']`` using ``arguments.get('encoding', 'utf-8')``.

        handle_tool_errors converts the following ToolErrors into error results:
            ValidationError: arguments do not fit FileOperationInput, or the
                file cannot be decoded with the requested encoding.
            ResourceNotFoundError: the file does not exist.
            PermissionError: the module's ToolError subclass, raised when the
                OS denies access.
        """
        # This module defines its own ``PermissionError`` (a ToolError) that
        # shadows the builtin, so the OS error is caught via ``builtins``.
        import builtins

        try:
            input_data = FileOperationInput(**arguments)
        except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
            raise ValidationError(f"参数验证失败: {exc}") from exc

        # NOTE(review): `aiofiles` is not imported anywhere in the visible file;
        # confirm it is imported at module level.
        try:
            async with aiofiles.open(input_data.path, 'r',
                                     encoding=input_data.encoding) as f:
                content = await f.read()
        except FileNotFoundError:
            raise ResourceNotFoundError("文件", input_data.path) from None
        except builtins.PermissionError:
            raise PermissionError("读取", input_data.path) from None
        except UnicodeDecodeError:
            raise ValidationError(
                f"无法使用编码 {input_data.encoding} 读取文件"
            ) from None

        return ToolResponse.with_file_content(input_data.path, content)
class HttpRequestTool(ToolInterface):
    """Tool ``http_request``: send an HTTP request and return status, headers and body."""

    @property
    def name(self) -> str:
        return "http_request"

    @property
    def description(self) -> str:
        return "发送HTTP请求"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "请求URL"
                },
                "method": {
                    "type": "string",
                    "description": "HTTP方法",
                    "enum": ["GET", "POST", "PUT", "DELETE"],
                    "default": "GET"
                },
                "headers": {
                    "type": "object",
                    "description": "请求头"
                },
                "data": {
                    "type": "object",
                    "description": "请求数据"
                },
                "timeout": {
                    "type": "integer",
                    "description": "超时时间(秒)",
                    "default": 30
                }
            },
            "required": ["url"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Perform the request; ``arguments['data']``, if present, is sent as a JSON body.

        handle_tool_errors converts the following ToolErrors into error results:
            ValidationError: arguments fail NetworkRequestInput validation.
            ToolError(REQUEST_TIMEOUT): the total timeout elapsed.
            ToolError(REQUEST_FAILED): aiohttp reported a client error.
        """
        import asyncio

        try:
            input_data = NetworkRequestInput(**arguments)
        except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
            raise ValidationError(f"参数验证失败: {exc}") from exc

        import aiohttp

        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=input_data.method,
                    url=input_data.url,
                    headers=input_data.headers,
                    json=arguments.get('data'),
                    timeout=aiohttp.ClientTimeout(total=input_data.timeout),
                ) as response:
                    content = await response.text()
                    result_data = {
                        "status_code": response.status,
                        "headers": dict(response.headers),
                        "content": content,
                    }
        # aiohttp.ClientTimeout is a settings object, not an exception, and
        # cannot appear in an except clause. aiohttp signals timeouts with
        # asyncio.TimeoutError (subclasses included), so catch that, and catch
        # it before ClientError because some timeout types are both.
        except asyncio.TimeoutError:
            raise ToolError(f"请求超时: {input_data.url}", "REQUEST_TIMEOUT") from None
        except aiohttp.ClientError as e:
            raise ToolError(f"请求失败: {str(e)}", "REQUEST_FAILED") from e

        return ToolResponse.success(
            f"HTTP请求成功: {input_data.method} {input_data.url}",
            result_data,
        )
class ToolFactory:
    """Registry that maps tool names to ToolInterface classes."""

    def __init__(self):
        self._tools: Dict[str, Type[ToolInterface]] = {}

    def register(self, tool_class: Type[ToolInterface]):
        """Register *tool_class* under the ``name`` its instance reports.

        A throwaway instance is created to read the name. Registering another
        class with the same name replaces the earlier entry.
        """
        self._tools[tool_class().name] = tool_class

    def create(self, tool_name: str) -> ToolInterface:
        """Return a fresh instance of the tool registered as *tool_name*.

        Raises:
            ValueError: no tool is registered under that name.
        """
        tool_class = self._tools.get(tool_name)
        if tool_class is None:
            raise ValueError(f"未知工具: {tool_name}")
        return tool_class()

    def get_all_tools(self) -> List[Dict[str, Any]]:
        """Describe every registered tool as a name/description/inputSchema dict."""
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.get_schema(),
            }
            for tool in (cls() for cls in self._tools.values())
        ]
# Shared factory holding the tools exposed by this server.
tool_factory = ToolFactory()
for _tool_cls in (FileReadTool, HttpRequestTool):
    tool_factory.register(_tool_cls)
del _tool_cls
@server.list_tools()
async def handle_list_tools() -> ListToolsResult:
    """List every tool registered in the shared ``tool_factory``."""
    described = tool_factory.get_all_tools()
    return ListToolsResult(tools=described)
@server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> CallToolResult:
    """Dispatch an MCP tool call to the matching factory-registered tool.

    Only a failed lookup is reported as TOOL_NOT_FOUND. A ValueError raised
    while the tool itself runs is no longer mislabelled as a missing tool.
    """
    try:
        tool = tool_factory.create(name)
    except ValueError as e:
        return ToolResponse.error(str(e), "TOOL_NOT_FOUND")
    # Clients may omit arguments entirely; tools expect a mapping.
    return await tool.execute(arguments or {})
2. 工具组合模式
class CompositeToolInterface(ToolInterface):
    """Base for tools assembled from other tools; keeps an ordered child list."""

    def __init__(self):
        self.sub_tools: List[ToolInterface] = []

    def add_tool(self, tool: ToolInterface):
        """Append *tool* to the child list; duplicates are allowed."""
        self.sub_tools.append(tool)

    def remove_tool(self, tool: ToolInterface):
        """Remove the first child equal to *tool*; does nothing when absent."""
        try:
            self.sub_tools.remove(tool)
        except ValueError:
            pass
class FileProcessingPipeline(CompositeToolInterface):
    """Tool ``file_processing_pipeline``: read a file, transform it, write the result.

    Operation types: ``uppercase``, ``lowercase``, ``replace`` (params
    ``old``/``new``) and ``filter`` (keep lines matching regex param
    ``pattern``). Unknown types leave the content unchanged.
    """

    @property
    def name(self) -> str:
        return "file_processing_pipeline"

    @property
    def description(self) -> str:
        return "文件处理管道,支持读取、处理、写入文件"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "input_path": {
                    "type": "string",
                    "description": "输入文件路径"
                },
                "output_path": {
                    "type": "string",
                    "description": "输出文件路径"
                },
                "operations": {
                    "type": "array",
                    "description": "处理操作列表",
                    "items": {
                        "type": "object",
                        "properties": {
                            "type": {
                                "type": "string",
                                "enum": ["uppercase", "lowercase", "replace", "filter"]
                            },
                            "params": {
                                "type": "object",
                                "description": "操作参数"
                            }
                        }
                    }
                }
            },
            "required": ["input_path", "output_path"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run the read -> transform -> write pipeline.

        A failed read or write result is returned unchanged. On success, the
        data reports how many operations were applied and the output length
        in characters.
        """
        input_path = arguments["input_path"]
        output_path = arguments["output_path"]
        operations = arguments.get("operations") or []

        # 1. Read the file.
        read_result = await FileReadTool().execute({"path": input_path})
        if read_result.isError:
            return read_result
        content = self._strip_file_header(read_result.content[0].text, input_path)

        # 2. Transform the content.
        for operation in operations:
            content = await self._apply_operation(content, operation)

        # 3. Write the result.
        write_result = await FileWriteTool().execute({
            "path": output_path,
            "content": content,
        })
        if write_result.isError:
            return write_result

        return ToolResponse.success(
            f"文件处理完成: {input_path} -> {output_path}",
            {
                "operations_applied": len(operations),
                "output_size": len(content),
            },
        )

    @staticmethod
    def _strip_file_header(text: str, path: str) -> str:
        """Remove the ``文件: <path>`` header line and blank line that prefix the file text."""
        header = f"文件: {path}\n\n"
        if text.startswith(header):
            return text[len(header):]
        # Unexpected header: drop everything up to the first blank line if
        # there is one. Never raise IndexError as a bare split()[1] would.
        _, sep, rest = text.partition("\n\n")
        return rest if sep else text

    async def _apply_operation(self, content: str, operation: Dict[str, Any]) -> str:
        """Apply one operation to *content* and return the new text.

        Raises:
            KeyError: the operation has no ``type`` key.
            ValidationError: a ``filter`` pattern is not a valid regex.
        """
        import re

        op_type = operation["type"]
        # `params` may be missing or explicitly null.
        params = operation.get("params") or {}

        if op_type == "uppercase":
            return content.upper()
        if op_type == "lowercase":
            return content.lower()
        if op_type == "replace":
            old = params.get("old", "")
            if not old:
                # str.replace("", x) would insert x between every character.
                return content
            return content.replace(old, params.get("new", ""))
        if op_type == "filter":
            try:
                regex = re.compile(params.get("pattern", ""))
            except re.error as exc:
                raise ValidationError(f"无效的正则表达式: {exc}", "pattern") from exc
            return "\n".join(line for line in content.split("\n") if regex.search(line))
        return content
class FileWriteTool(ToolInterface):
    """Tool ``write_file``: write text to a file, creating parent directories."""

    @property
    def name(self) -> str:
        return "write_file"

    @property
    def description(self) -> str:
        return "写入文件内容"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "文件路径"
                },
                "content": {
                    "type": "string",
                    "description": "文件内容"
                },
                "encoding": {
                    "type": "string",
                    "description": "文件编码",
                    "default": "utf-8"
                }
            },
            "required": ["path", "content"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Write ``arguments['content']`` to ``arguments['path']``.

        The result data reports ``bytes_written``: the length of the content
        encoded with the same encoding used for the file. An explicit null
        encoding falls back to UTF-8 for both.

        handle_tool_errors converts the following ToolErrors into error results:
            ValidationError: invalid path/encoding arguments, or content that
                is missing or not a string.
            PermissionError: the module's ToolError subclass, raised when the
                OS denies access.
        """
        # The module-level ``PermissionError`` is a ToolError that shadows the
        # builtin, so the OS error is caught via ``builtins``.
        import builtins
        import os

        try:
            input_data = FileOperationInput(**arguments)
        except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
            raise ValidationError(f"参数验证失败: {exc}") from exc

        content = arguments.get("content")
        if not isinstance(content, str):
            raise ValidationError("content 必须是字符串", "content")
        encoding = input_data.encoding or "utf-8"

        # NOTE(review): `aiofiles` is not imported anywhere in the visible file;
        # confirm it is imported at module level.
        try:
            # A bare filename has no directory part; os.makedirs("") would raise.
            directory = os.path.dirname(input_data.path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            async with aiofiles.open(input_data.path, 'w', encoding=encoding) as f:
                await f.write(content)
        except builtins.PermissionError:
            raise PermissionError("写入", input_data.path) from None

        return ToolResponse.success(
            f"文件写入成功: {input_data.path}",
            {"bytes_written": len(content.encode(encoding))},
        )
3. 工具链模式
from typing import List, Callable, Any
class ToolChain:
    """An ordered list of tool invocations run against a shared context dict.

    String arguments of the exact form ``"${a.b.c}"`` are resolved against
    the context before each step. Each step's record is stored back into the
    context as ``step_<i>_result``, where ``i`` is the step's position in the
    chain, counting skipped steps.
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
        self.steps: List[Dict[str, Any]] = []

    def add_step(self, tool_name: str, arguments: Dict[str, Any],
                 condition: Optional[Callable[[Any], bool]] = None):
        """Append a step and return ``self`` so calls can be chained.

        If *condition* is given, it is called with the context at run time
        and the step is skipped when it returns a falsy value.
        """
        self.steps.append({
            "tool_name": tool_name,
            "arguments": arguments,
            "condition": condition,
        })
        return self

    def add_conditional_step(self, condition: Callable[[Any], bool],
                             tool_name: str, arguments: Dict[str, Any]):
        """Same as ``add_step`` with the condition given first."""
        return self.add_step(tool_name, arguments, condition)

    async def execute(self, initial_context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run the steps in order and return a summary dict.

        The caller's *initial_context* is copied, never mutated. Execution
        stops after the first step whose result has ``isError`` set.
        Exceptions from ``tool_factory.create`` propagate: a ValueError for an
        unknown tool name when the file's ToolFactory is used.
        """
        context: Dict[str, Any] = dict(initial_context) if initial_context else {}
        results = []

        for i, step in enumerate(self.steps):
            condition = step["condition"]
            if condition and not condition(context):
                continue

            resolved_args = self._resolve_arguments(step["arguments"], context)
            tool = tool_factory.create(step["tool_name"])
            result = await tool.execute(resolved_args)

            step_result = {
                "step": i,
                "tool_name": step["tool_name"],
                "arguments": resolved_args,
                "result": result,
                "success": not result.isError,
            }
            results.append(step_result)
            context[f"step_{i}_result"] = step_result

            # Stop at the first failed step; there is no error recovery.
            if result.isError:
                break

        return {
            "chain_name": self.name,
            "steps_executed": len(results),
            "success": all(r["success"] for r in results),
            "results": results,
            "final_context": context,
        }

    def _resolve_arguments(self, arguments: Dict[str, Any],
                           context: Dict[str, Any]) -> Dict[str, Any]:
        """Replace ``"${path}"`` string values with the value at *path* in *context*.

        Only whole values that both start with ``${`` and end with ``}`` are
        treated as references. Anything else, including an unterminated
        ``"${x"``, is passed through unchanged.
        """
        resolved = {}
        for key, value in arguments.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                # e.g. ${step_0_result.content}
                resolved[key] = self._get_nested_value(context, value[2:-1])
            else:
                resolved[key] = value
        return resolved

    def _get_nested_value(self, obj: Any, path: str) -> Any:
        """Follow a dotted *path* through dict keys or attributes.

        Returns None when a segment cannot be resolved.
        """
        current = obj
        for part in path.split("."):
            if isinstance(current, dict):
                current = current.get(part)
            elif hasattr(current, part):
                current = getattr(current, part)
            else:
                return None
        return current
# Builders for the predefined tool chains.
class ToolChainBuilder:
    """Factory functions that assemble the predefined ToolChain instances.

    NOTE(review): "extract_text" and "process_data" are not registered with
    the tool_factory set up in this file. A step's ``result`` holds whatever
    the tool returned (a CallToolResult here), whose ``content`` is a list and
    which shows no ``text``/``processed_data`` attribute in this file. These
    ``${...}`` references therefore likely resolve to a list or None — confirm
    before relying on these chains.
    """

    @staticmethod
    def create_web_scraping_chain() -> ToolChain:
        """Chain: fetch ``${target_url}``, extract text if the fetch succeeded, save it."""
        chain = ToolChain("web_scraping", "网页内容抓取和处理工具链")
        chain.add_step("http_request", {"url": "${target_url}", "method": "GET"})
        chain.add_conditional_step(
            lambda ctx: ctx.get("step_0_result", {}).get("success", False),
            "extract_text",
            {"html": "${step_0_result.result.content}"},
        )
        chain.add_step(
            "write_file",
            {"path": "${output_file}", "content": "${step_1_result.result.text}"},
        )
        return chain

    @staticmethod
    def create_data_processing_chain() -> ToolChain:
        """Chain: read ``${input_file}``, process it with ``${operations}``, save the result."""
        chain = ToolChain("data_processing", "数据读取、处理和保存工具链")
        chain.add_step("read_file", {"path": "${input_file}"})
        chain.add_step(
            "process_data",
            {"data": "${step_0_result.result.content}", "operations": "${operations}"},
        )
        chain.add_step(
            "write_file",
            {"path": "${output_file}", "content": "${step_1_result.result.processed_data}"},
        )
        return chain
# Tool that runs one of the predefined tool chains.
class ToolChainExecutor(ToolInterface):
    """Tool ``execute_chain``: run a predefined ToolChain by name."""

    def __init__(self):
        self.chains = {
            "web_scraping": ToolChainBuilder.create_web_scraping_chain(),
            "data_processing": ToolChainBuilder.create_data_processing_chain()
        }

    @property
    def name(self) -> str:
        return "execute_chain"

    @property
    def description(self) -> str:
        return "执行预定义的工具链"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "chain_name": {
                    "type": "string",
                    "description": "工具链名称",
                    "enum": list(self.chains.keys())
                },
                "context": {
                    "type": "object",
                    "description": "执行上下文"
                }
            },
            "required": ["chain_name"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run ``arguments['chain_name']`` with ``arguments['context']`` as its initial context.

        Raises (converted by handle_tool_errors):
            ValidationError: the chain name is unknown.
        """
        chain_name = arguments["chain_name"]
        context = arguments.get("context") or {}

        if chain_name not in self.chains:
            raise ValidationError(f"未知工具链: {chain_name}", "chain_name")

        # Give the chain a copy so the caller's arguments are never mutated.
        result = await self.chains[chain_name].execute(dict(context))

        # The chain summary embeds CallToolResult objects, which json.dumps
        # (used by ToolResponse.success) cannot serialise. Convert them first.
        return ToolResponse.success(
            f"工具链执行完成: {chain_name}",
            self._to_jsonable(result),
        )

    @classmethod
    def _to_jsonable(cls, value: Any) -> Any:
        """Recursively convert *value* into JSON-encodable data.

        Objects exposing ``isError`` and ``content`` become
        ``{"isError": ..., "content": [text, ...]}``. Other unknown objects
        become ``str(value)``.
        """
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, dict):
            return {str(k): cls._to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(v) for v in value]
        if hasattr(value, "isError") and hasattr(value, "content"):
            return {
                "isError": bool(value.isError),
                "content": [getattr(item, "text", str(item)) for item in (value.content or [])],
            }
        return str(value)
工具参数验证和类型安全
1. 高级参数验证
from pydantic import BaseModel, Field, validator, root_validator
from typing import Union, List, Optional
import re
from pathlib import Path
class AdvancedFileInput(BaseModel):
    """Validated file-read arguments: relative path, encoding and optional limits."""

    path: str = Field(description="文件路径")
    encoding: str = Field(default="utf-8", description="文件编码")
    max_size: Optional[int] = Field(default=None, description="最大文件大小(字节)")
    allowed_extensions: Optional[List[str]] = Field(
        default=None,
        description="允许的文件扩展名"
    )

    @validator('path')
    def validate_path(cls, v):
        """Reject empty paths, absolute or drive/root-anchored paths, and ``..`` segments.

        The value is stripped BEFORE it is checked. Previously " /etc/passwd"
        passed the check and was then stripped to an absolute path. Both POSIX
        and Windows path syntax are examined, so ``C:\\x`` or ``..\\x`` cannot
        slip through. Harmless names such as ``a..b.txt`` are allowed.
        """
        from pathlib import PurePosixPath, PureWindowsPath

        v = v.strip() if v else v
        if not v:
            raise ValueError("路径不能为空")

        posix, windows = PurePosixPath(v), PureWindowsPath(v)
        if (posix.anchor or windows.anchor
                or ".." in posix.parts or ".." in windows.parts):
            raise ValueError("路径包含不安全字符")
        return v

    @validator('encoding')
    def validate_encoding(cls, v):
        """Accept only encodings known to the ``codecs`` registry."""
        import codecs
        try:
            codecs.lookup(v)
        except LookupError:
            raise ValueError(f"不支持的编码格式: {v}")
        return v

    @validator('max_size')
    def validate_max_size(cls, v):
        """Require a positive size limit when one is given."""
        if v is not None and v <= 0:
            raise ValueError("最大文件大小必须大于0")
        return v

    @validator('allowed_extensions')
    def validate_extensions(cls, v):
        """Normalise extensions to lowercase with a leading dot."""
        if v is None:
            return v
        return [(ext if ext.startswith('.') else '.' + ext).lower() for ext in v]

    # skip_on_failure=True: required by pydantic v2 for root validators and
    # also accepted by v1. The check needs a valid path, so skipping it after
    # field errors changes nothing.
    @root_validator(skip_on_failure=True)
    def validate_file_constraints(cls, values):
        """Ensure the path's extension is in ``allowed_extensions`` when that list is set."""
        path = values.get('path')
        allowed_extensions = values.get('allowed_extensions')

        if path and allowed_extensions:
            file_ext = Path(path).suffix.lower()
            if file_ext not in allowed_extensions:
                raise ValueError(
                    f"文件扩展名 '{file_ext}' 不在允许列表中: {allowed_extensions}"
                )
        return values
class NetworkRequestInput(BaseModel):
    """Validated HTTP request arguments (http/https only, no credential headers)."""

    url: str = Field(description="请求URL")
    method: str = Field(default="GET", description="HTTP方法")
    headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
    timeout: int = Field(default=30, ge=1, le=300, description="超时时间(秒)")
    follow_redirects: bool = Field(default=True, description="是否跟随重定向")
    max_redirects: int = Field(default=5, ge=0, le=20, description="最大重定向次数")

    @validator('url')
    def validate_url(cls, v):
        """Require an absolute http(s) URL with a network location.

        Only the parsing step is wrapped. The previous broad ``except
        Exception`` also swallowed the scheme error and replaced it with the
        generic message.
        """
        import urllib.parse
        try:
            result = urllib.parse.urlparse(v)
        except (ValueError, TypeError):  # e.g. malformed IPv6 netloc
            raise ValueError("URL格式无效") from None

        if not (result.scheme and result.netloc):
            raise ValueError("URL格式无效")

        # Only HTTP and HTTPS are allowed (urlparse lowercases the scheme).
        if result.scheme not in ('http', 'https'):
            raise ValueError("只支持HTTP和HTTPS协议")
        return v

    @validator('method')
    def validate_method(cls, v):
        """Accept a known HTTP method, case-insensitively; return it uppercased."""
        allowed_methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS']
        if v.upper() not in allowed_methods:
            raise ValueError(f"不支持的HTTP方法: {v}")
        return v.upper()

    @validator('headers')
    def validate_headers(cls, v):
        """Refuse credential-bearing headers supplied by the caller."""
        if v is not None:
            dangerous_headers = ['authorization', 'cookie', 'x-api-key']
            for header in v.keys():
                if header.lower() in dangerous_headers:
                    raise ValueError(f"不允许设置敏感请求头: {header}")
        return v
class DatabaseQueryInput(BaseModel):
    """Validated read-only SQL query arguments."""

    query: str = Field(description="SQL查询语句")
    parameters: Optional[Dict[str, Any]] = Field(default=None, description="查询参数")
    limit: int = Field(default=100, ge=1, le=10000, description="结果限制")
    timeout: int = Field(default=30, ge=1, le=300, description="查询超时(秒)")

    @validator('query')
    def validate_query(cls, v):
        """Allow only non-empty SELECT statements without write/DDL keywords.

        Keywords are matched as whole words. A substring test rejected
        harmless identifiers such as ``created_at`` or ``updated_at``. This
        remains a coarse heuristic: bind values through ``parameters``
        rather than relying on it.
        """
        if not v or not v.strip():
            raise ValueError("查询语句不能为空")

        dangerous_keywords = [
            'drop', 'delete', 'truncate', 'alter', 'create',
            'insert', 'update', 'exec', 'execute'
        ]
        query_lower = v.lower()
        for keyword in dangerous_keywords:
            if re.search(rf"\b{keyword}\b", query_lower):
                raise ValueError(f"查询包含危险关键字: {keyword}")

        # Only SELECT statements are allowed.
        if not query_lower.strip().startswith('select'):
            raise ValueError("只允许SELECT查询")
        return v.strip()
2. 类型安全的工具接口
from typing import TypeVar, Generic, Type
from pydantic import BaseModel
T = TypeVar('T', bound=BaseModel)  # validated input model type
R = TypeVar('R')  # raw result type returned by execute_typed


class TypedToolInterface(Generic[T, R]):
    """Tool base that validates arguments with a pydantic model before running.

    Subclasses implement ``execute_typed``. ``execute`` validates the
    arguments, converts errors into error results, and wraps the return
    value in a CallToolResult.

    NOTE(review): this class does not derive from ABC, so ``@abstractmethod``
    is not enforced at instantiation time.
    """

    def __init__(self, input_model: Type[T]):
        self.input_model = input_model

    @abstractmethod
    async def execute_typed(self, input_data: T) -> R:
        """Run the tool on already-validated input."""
        pass

    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Validate *arguments*, run ``execute_typed``, and convert the outcome.

        * Model validation failures become ``参数验证失败`` errors. pydantic's
          ValidationError subclasses ValueError; the module's own
          ``ValidationError`` is a ToolError and never matched these.
        * ToolErrors raised while running keep their message and error code.
        * Any other exception becomes an ``执行失败`` error.
        """
        try:
            input_data = self.input_model(**arguments)
        except (ValueError, TypeError) as e:  # TypeError: arguments is not a mapping
            return ToolResponse.error(f"参数验证失败: {str(e)}", "VALIDATION_ERROR")

        try:
            result = await self.execute_typed(input_data)
            return self._convert_result(result)
        except ValidationError as e:
            return ToolResponse.error(f"参数验证失败: {e.message}", e.error_code)
        except ToolError as e:
            return ToolResponse.error(e.message, e.error_code)
        except Exception as e:
            return ToolResponse.error(f"执行失败: {str(e)}")

    def _convert_result(self, result: R) -> CallToolResult:
        """Wrap *result* in a CallToolResult.

        CallToolResult objects pass through. Strings become the message.
        Dicts and pydantic models become structured data. Anything else is
        stringified.
        """
        if isinstance(result, CallToolResult):
            return result
        elif isinstance(result, str):
            return ToolResponse.success(result)
        elif isinstance(result, dict):
            return ToolResponse.success("操作成功", result)
        elif isinstance(result, BaseModel):
            return ToolResponse.success("操作成功", result.dict())
        else:
            return ToolResponse.success(str(result))
class FileReadResult(BaseModel):
    """Structured outcome of a typed file read."""

    path: str  # path of the file that was read
    content: str  # decoded text
    size: int  # file size in bytes
    encoding: str  # encoding used to decode the file
    lines: int  # number of lines in content
class TypedFileReadTool(TypedToolInterface[AdvancedFileInput, FileReadResult]):
    """Tool ``typed_read_file``: read a text file through the typed-tool pipeline."""

    def __init__(self):
        super().__init__(AdvancedFileInput)

    @property
    def name(self) -> str:
        return "typed_read_file"

    @property
    def description(self) -> str:
        return "类型安全的文件读取工具"

    def get_schema(self) -> Dict[str, Any]:
        return self.input_model.schema()

    async def execute_typed(self, input_data: AdvancedFileInput) -> FileReadResult:
        """Read the file described by *input_data*.

        Raises:
            ResourceNotFoundError: the path is missing or is not a regular file.
            ValidationError: the file exceeds ``max_size`` or cannot be
                decoded with the requested encoding.
        """
        import asyncio

        file_path = Path(input_data.path)

        # is_file() also rejects directories, which exists() let through to open().
        if not file_path.is_file():
            raise ResourceNotFoundError("文件", str(file_path))

        file_size = file_path.stat().st_size
        if input_data.max_size and file_size > input_data.max_size:
            raise ValidationError(
                f"文件大小 {file_size} 超过限制 {input_data.max_size}"
            )

        def _read() -> str:
            with open(file_path, 'r', encoding=input_data.encoding) as f:
                return f.read()

        try:
            # Run the blocking read in the default executor so the event loop
            # is not stalled by large files.
            content = await asyncio.get_running_loop().run_in_executor(None, _read)
        except FileNotFoundError:  # removed between the check and the open
            raise ResourceNotFoundError("文件", str(file_path)) from None
        except UnicodeDecodeError:
            raise ValidationError(f"无法使用编码 {input_data.encoding} 读取文件") from None

        return FileReadResult(
            path=str(file_path),
            content=content,
            size=file_size,
            encoding=input_data.encoding,
            lines=len(content.splitlines()),
        )
工具性能优化
1. 异步并发优化
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Callable, Any
class ConcurrentToolExecutor:
    """Helpers for running tools concurrently on the current event loop."""

    def __init__(self, max_workers: int = 10):
        self.max_workers = max_workers
        self.thread_pool = ThreadPoolExecutor(max_workers=max_workers)

    async def execute_parallel(self,
                               tools: List[Tuple[str, Dict[str, Any]]],
                               max_concurrent: int = 5) -> List[CallToolResult]:
        """Run ``(tool_name, arguments)`` pairs concurrently.

        At most *max_concurrent* tools run at once. Results come back in
        input order; because ``return_exceptions=True`` is used, a raised
        exception is returned in its slot instead of propagating.
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def execute_single(tool_name: str, arguments: Dict[str, Any]):
            async with semaphore:
                tool = tool_factory.create(tool_name)
                return await tool.execute(arguments)

        tasks = [
            execute_single(tool_name, arguments)
            for tool_name, arguments in tools
        ]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def execute_cpu_bound(self,
                                func: Callable,
                                *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` in the thread pool and await its result.

        ``loop.run_in_executor`` forwards positional arguments only, so the
        call is bound with ``functools.partial`` first. Keyword arguments
        previously raised TypeError.

        NOTE: threads share the GIL, so pure-Python CPU work gains no parallelism.
        """
        import functools

        # get_running_loop() is the non-deprecated way to reach the loop
        # from inside a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.thread_pool,
            functools.partial(func, *args, **kwargs),
        )

    async def execute_with_timeout(self,
                                   tool_name: str,
                                   arguments: Dict[str, Any],
                                   timeout: float) -> CallToolResult:
        """Run one tool, returning an EXECUTION_TIMEOUT error if it takes longer than *timeout* seconds."""
        try:
            tool = tool_factory.create(tool_name)
            return await asyncio.wait_for(
                tool.execute(arguments),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            return ToolResponse.error(
                f"工具执行超时: {tool_name} (超时时间: {timeout}秒)",
                "EXECUTION_TIMEOUT"
            )

    def shutdown(self, wait: bool = True) -> None:
        """Release the thread pool's worker threads."""
        self.thread_pool.shutdown(wait=wait)
# Tool that runs many tool calls in one request.
class BatchProcessingTool(ToolInterface):
    """Tool ``batch_process``: run a list of tool calls concurrently and summarise them."""

    def __init__(self):
        self.executor = ConcurrentToolExecutor()

    @property
    def name(self) -> str:
        return "batch_process"

    @property
    def description(self) -> str:
        return "批量执行工具操作"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "operations": {
                    "type": "array",
                    "description": "操作列表",
                    "items": {
                        "type": "object",
                        "properties": {
                            "tool_name": {
                                "type": "string",
                                "description": "工具名称"
                            },
                            "arguments": {
                                "type": "object",
                                "description": "工具参数"
                            }
                        },
                        "required": ["tool_name", "arguments"]
                    }
                },
                "max_concurrent": {
                    "type": "integer",
                    "description": "最大并发数",
                    "default": 5,
                    "minimum": 1,
                    "maximum": 20
                },
                "fail_fast": {
                    "type": "boolean",
                    "description": "遇到错误时是否立即停止",
                    "default": False
                }
            },
            "required": ["operations"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run every operation and report per-operation outcomes.

        All operations are launched in parallel. ``fail_fast`` therefore
        cannot cancel work already started: it stops collecting results at
        the first failure and turns the response into an error. A failure is
        either a raised exception or an ``isError`` result.
        """
        operations = arguments["operations"]
        max_concurrent = arguments.get("max_concurrent", 5)
        fail_fast = arguments.get("fail_fast", False)

        if not operations:
            return ToolResponse.error("操作列表不能为空")

        # Enforce the schema bounds: asyncio.Semaphore(0) would block every
        # task forever, and a negative value raises.
        if (isinstance(max_concurrent, bool) or not isinstance(max_concurrent, int)
                or not 1 <= max_concurrent <= 20):
            return ToolResponse.error(
                "max_concurrent 必须是 1 到 20 之间的整数",
                "VALIDATION_ERROR"
            )

        tool_calls = [(op["tool_name"], op["arguments"]) for op in operations]
        results = await self.executor.execute_parallel(tool_calls, max_concurrent)

        success_count = 0
        error_count = 0
        detailed_results = []

        for i, (operation, result) in enumerate(zip(operations, results)):
            # gather(return_exceptions=True) can also yield BaseException
            # instances such as CancelledError.
            if isinstance(result, BaseException):
                ok = False
                detailed_results.append({
                    "index": i,
                    "operation": operation,
                    "success": False,
                    "error": str(result)
                })
            else:
                ok = not result.isError
                detailed_results.append({
                    "index": i,
                    "operation": operation,
                    "success": ok,
                    "result": result.content[0].text if result.content else None
                })

            if ok:
                success_count += 1
            else:
                error_count += 1
                if fail_fast:
                    break

        summary = {
            "total_operations": len(operations),
            "successful": success_count,
            "failed": error_count,
            "results": detailed_results
        }

        if error_count > 0 and fail_fast:
            return ToolResponse.error(
                f"批量处理失败: {error_count} 个操作出错",
                "BATCH_PROCESSING_FAILED"
            )

        return ToolResponse.success(
            f"批量处理完成: {success_count} 成功, {error_count} 失败",
            summary
        )
2. 缓存和记忆化
import hashlib
import pickle
from typing import Optional, Any, Dict
from datetime import datetime, timedelta
class ToolResultCache:
    """In-memory TTL cache of tool results with least-recently-used eviction.

    Keys are an MD5 digest of the tool name plus the JSON-encoded arguments
    (sorted keys), so arguments must be JSON-serialisable. Expiry is checked
    lazily in ``get``. When the cache is full, expired entries are dropped
    before falling back to evicting the least-recently-accessed entry.
    A ``max_size`` of 0 or less disables caching.
    """

    def __init__(self,
                 max_size: int = 1000,
                 default_ttl: int = 3600):
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.access_times: Dict[str, datetime] = {}

    def _generate_key(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Build a stable key from the tool name and sorted-key JSON of the arguments.

        Raises:
            TypeError / ValueError: the arguments are not JSON-encodable.
        """
        args_str = json.dumps(arguments, sort_keys=True, ensure_ascii=False)
        combined = f"{tool_name}:{args_str}"
        return hashlib.md5(combined.encode()).hexdigest()

    def get(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[CallToolResult]:
        """Return the cached result, or None if it is absent or expired.

        An expired entry is removed when found; a hit refreshes its access time.
        """
        key = self._generate_key(tool_name, arguments)
        entry = self.cache.get(key)
        if entry is None:
            return None

        if datetime.now() > entry["expires_at"]:
            self._remove(key)
            return None

        self.access_times[key] = datetime.now()
        return entry["result"]

    def set(self, tool_name: str, arguments: Dict[str, Any],
            result: CallToolResult, ttl: Optional[int] = None) -> None:
        """Store *result* for *ttl* seconds, or ``default_ttl`` when ttl is None."""
        if self.max_size <= 0:
            return  # caching disabled

        key = self._generate_key(tool_name, arguments)

        if key not in self.cache and len(self.cache) >= self.max_size:
            # Prefer dropping dead entries over evicting live ones.
            self._purge_expired()
            if len(self.cache) >= self.max_size:
                self._evict_lru()

        # `is None` so an explicit ttl=0 is honoured instead of replaced.
        effective_ttl = self.default_ttl if ttl is None else ttl
        now = datetime.now()
        self.cache[key] = {
            "result": result,
            "created_at": now,
            "expires_at": now + timedelta(seconds=effective_ttl),
            "tool_name": tool_name,
            "arguments": arguments
        }
        self.access_times[key] = now

    def _remove(self, key: str) -> None:
        """Drop *key* from both internal maps, if present."""
        self.cache.pop(key, None)
        self.access_times.pop(key, None)

    def _purge_expired(self) -> None:
        """Remove every entry whose expiry time has passed."""
        now = datetime.now()
        for key in [k for k, entry in self.cache.items() if now > entry["expires_at"]]:
            self._remove(key)

    def _evict_lru(self) -> None:
        """Remove the entry with the oldest access time."""
        if not self.access_times:
            return
        lru_key = min(self.access_times, key=self.access_times.__getitem__)
        self._remove(lru_key)

    def clear(self) -> None:
        """Empty the cache."""
        self.cache.clear()
        self.access_times.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Report entry counts and fill ratio. ``usage_ratio`` is 0.0 when caching is disabled."""
        now = datetime.now()
        expired_count = sum(
            1 for entry in self.cache.values()
            if now > entry["expires_at"]
        )
        return {
            "total_entries": len(self.cache),
            "expired_entries": expired_count,
            "active_entries": len(self.cache) - expired_count,
            "max_size": self.max_size,
            "usage_ratio": len(self.cache) / self.max_size if self.max_size > 0 else 0.0
        }
# Class decorator that adds result caching to a tool.
def cached_tool(ttl: int = 3600,
                cache_errors: bool = False):
    """Class decorator that caches ``execute`` results in the module-level ``tool_cache``.

    Args:
        ttl: Seconds a cached result stays valid.
        cache_errors: Also cache results whose ``isError`` is set.

    Results are keyed by the tool's ``name`` and its arguments. Calls whose
    arguments cannot be JSON-encoded bypass the cache instead of failing.
    NOTE(review): results are cached regardless of side effects (for example
    an HTTP POST), so apply this only to idempotent tools.
    """
    def decorator(tool_class: Type[ToolInterface]):
        original_execute = tool_class.execute

        @wraps(original_execute)
        async def cached_execute(self, arguments: Dict[str, Any]) -> CallToolResult:
            try:
                cached_result = tool_cache.get(self.name, arguments)
            except (TypeError, ValueError):
                # The cache key JSON-encodes the arguments; unencodable
                # arguments simply run uncached.
                return await original_execute(self, arguments)
            if cached_result is not None:
                return cached_result

            result = await original_execute(self, arguments)

            if not result.isError or cache_errors:
                tool_cache.set(self.name, arguments, result, ttl)
            return result

        tool_class.execute = cached_execute
        return tool_class
    return decorator
# Shared cache instance used by the cached_tool decorator.
tool_cache = ToolResultCache(default_ttl=3600, max_size=1000)


# Cached variants of the basic tools.
@cached_tool(ttl=1800)  # 30 minutes
class CachedHttpRequestTool(HttpRequestTool):
    """HttpRequestTool whose successful results are cached for 30 minutes.

    NOTE(review): it reports the same ``name`` ("http_request") as its parent
    and caches every HTTP method, including POST/PUT/DELETE — confirm intended.
    """


@cached_tool(ttl=300, cache_errors=True)  # 5 minutes, errors included
class CachedFileReadTool(FileReadTool):
    """FileReadTool whose results, including errors, are cached for 5 minutes."""
工具安全性
1. 权限控制
from enum import Enum
from typing import Set, List
class Permission(Enum):
    """Capabilities a security context may grant."""

    READ_FILE = "read_file"
    WRITE_FILE = "write_file"
    DELETE_FILE = "delete_file"
    NETWORK_REQUEST = "network_request"
    EXECUTE_COMMAND = "execute_command"
    DATABASE_READ = "database_read"
    DATABASE_WRITE = "database_write"


class Role(Enum):
    """Caller roles; each maps to a default permission set in SecurityManager."""

    GUEST = "guest"
    USER = "user"
    ADMIN = "admin"
    SYSTEM = "system"


class SecurityContext:
    """Identity, permissions and optional path/domain allow-lists for one caller.

    Empty ``allowed_paths`` / ``allowed_domains`` mean "no restriction".
    """

    def __init__(self,
                 user_id: str,
                 role: Role,
                 permissions: Set[Permission],
                 allowed_paths: Optional[List[str]] = None,
                 allowed_domains: Optional[List[str]] = None):
        self.user_id = user_id
        self.role = role
        # Copy so a context can never mutate a shared set, such as a role default.
        self.permissions = set(permissions)
        self.allowed_paths = allowed_paths or []
        self.allowed_domains = allowed_domains or []

    def has_permission(self, permission: Permission) -> bool:
        """Return True if *permission* was granted."""
        return permission in self.permissions

    def can_access_path(self, path: str) -> bool:
        """Return True if *path* lies inside (or equals) one of ``allowed_paths``.

        Both sides are resolved with ``realpath``, so symlinks cannot escape.
        Containment is tested by path components: a plain string-prefix test
        let ``/data/allowed_evil`` match ``/data/allowed``.
        """
        if not self.allowed_paths:
            return True  # no path restriction

        import os
        real_path = os.path.realpath(path)
        for allowed_path in self.allowed_paths:
            allowed_real = os.path.realpath(allowed_path)
            try:
                if os.path.commonpath([real_path, allowed_real]) == allowed_real:
                    return True
            except ValueError:  # e.g. paths on different drives on Windows
                continue
        return False

    def can_access_domain(self, url: str) -> bool:
        """Return True if the URL's host equals an allowed domain or is a subdomain of one.

        ``hostname`` is used instead of ``netloc``: netloc includes any port
        and ``user:pass@`` prefix, so ``example.com:8443`` failed to match.
        """
        if not self.allowed_domains:
            return True  # no domain restriction

        from urllib.parse import urlparse
        try:
            host = urlparse(url).hostname  # already lowercased by urllib
        except ValueError:
            return False
        if not host:
            return False

        for allowed_domain in self.allowed_domains:
            allowed = allowed_domain.lower()
            if host == allowed or host.endswith(f".{allowed}"):
                return True
        return False


class SecurityManager:
    """Creates security contexts from roles and checks tool-level permissions."""

    def __init__(self):
        self.role_permissions = {
            Role.GUEST: {
                Permission.READ_FILE,
            },
            Role.USER: {
                Permission.READ_FILE,
                Permission.WRITE_FILE,
                Permission.NETWORK_REQUEST,
                Permission.DATABASE_READ,
            },
            Role.ADMIN: {
                Permission.READ_FILE,
                Permission.WRITE_FILE,
                Permission.DELETE_FILE,
                Permission.NETWORK_REQUEST,
                Permission.DATABASE_READ,
                Permission.DATABASE_WRITE,
            },
            Role.SYSTEM: set(Permission),  # every permission
        }

    def create_context(self,
                       user_id: str,
                       role: Role,
                       custom_permissions: Optional[Set[Permission]] = None,
                       **kwargs) -> SecurityContext:
        """Create a context using *custom_permissions* if given, else the role's defaults.

        ``is not None`` is used so an explicit empty set really grants
        nothing. With ``or``, it silently fell back to the role's defaults.
        Extra keyword arguments are forwarded to SecurityContext.
        """
        if custom_permissions is not None:
            permissions = set(custom_permissions)
        else:
            permissions = set(self.role_permissions.get(role, set()))
        return SecurityContext(user_id, role, permissions, **kwargs)

    def check_tool_permission(self,
                              context: SecurityContext,
                              tool_name: str,
                              allow_unknown: bool = True) -> bool:
        """Return True if *context* may run *tool_name*.

        Tools without a mapped permission are allowed only when
        *allow_unknown* is True, which is the backward-compatible default.
        Pass False to fail closed.
        """
        tool_permissions = {
            "read_file": Permission.READ_FILE,
            "typed_read_file": Permission.READ_FILE,
            "secure_read_file": Permission.READ_FILE,
            "write_file": Permission.WRITE_FILE,
            "delete_file": Permission.DELETE_FILE,
            "http_request": Permission.NETWORK_REQUEST,
            "secure_http_request": Permission.NETWORK_REQUEST,
            "execute_command": Permission.EXECUTE_COMMAND,
            "database_query": Permission.DATABASE_READ,
        }

        required_permission = tool_permissions.get(tool_name)
        if required_permission is None:
            return allow_unknown
        return context.has_permission(required_permission)
# Decorator enforcing a permission on secure tools.
def require_permission(permission: Permission):
    """Build a decorator for ``execute(self, arguments, context)`` methods.

    The wrapped method runs only when a SecurityContext is supplied and that
    context holds *permission*. Otherwise an error result is returned:
    SECURITY_CONTEXT_MISSING or PERMISSION_DENIED.
    """
    def apply(func):
        @wraps(func)
        async def checked(self, arguments: Dict[str, Any],
                          context: Optional[SecurityContext] = None) -> CallToolResult:
            if context is not None and context.has_permission(permission):
                return await func(self, arguments, context)
            if context is None:
                return ToolResponse.error(
                    "缺少安全上下文",
                    "SECURITY_CONTEXT_MISSING"
                )
            return ToolResponse.error(
                f"权限不足: 需要 {permission.value} 权限",
                "PERMISSION_DENIED"
            )
        return checked
    return apply
# Secure file-reading tool.
class SecureFileReadTool(ToolInterface):
    """Tool ``secure_read_file``: read a file after permission and path allow-list checks."""

    @property
    def name(self) -> str:
        return "secure_read_file"

    @property
    def description(self) -> str:
        return "安全的文件读取工具"

    def get_schema(self) -> Dict[str, Any]:
        return AdvancedFileInput.schema()

    # handle_tool_errors (outermost) turns the ToolErrors raised below into
    # error results instead of letting them escape to the MCP handler.
    @handle_tool_errors
    @require_permission(Permission.READ_FILE)
    async def execute(self, arguments: Dict[str, Any],
                      context: SecurityContext) -> CallToolResult:
        """Read a file the caller is allowed to access.

        Returns a PATH_ACCESS_DENIED error when the path is outside the
        context's allow-list. Raises ValidationError, ResourceNotFoundError
        or the module's PermissionError, which handle_tool_errors converts.
        """
        # The module-level ``PermissionError`` is a ToolError that shadows the
        # builtin; catch the OS error via ``builtins``.
        import builtins

        try:
            input_data = AdvancedFileInput(**arguments)
        except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
            raise ValidationError(f"参数验证失败: {exc}") from exc

        if not context.can_access_path(input_data.path):
            return ToolResponse.error(
                f"无权访问路径: {input_data.path}",
                "PATH_ACCESS_DENIED"
            )

        # NOTE(review): `aiofiles` is not imported anywhere in the visible file;
        # confirm it is imported at module level.
        try:
            async with aiofiles.open(input_data.path, 'r',
                                     encoding=input_data.encoding) as f:
                content = await f.read()
        except FileNotFoundError:
            raise ResourceNotFoundError("文件", input_data.path) from None
        except builtins.PermissionError:
            raise PermissionError("读取", input_data.path) from None
        except UnicodeDecodeError:
            raise ValidationError(f"无法使用编码 {input_data.encoding} 读取文件") from None

        return ToolResponse.with_file_content(input_data.path, content)
class SecureHttpRequestTool(ToolInterface):
    """HTTP client tool gated by ``Permission.NETWORK_REQUEST`` plus a domain check."""

    @property
    def name(self) -> str:
        """Registered tool name."""
        return "secure_http_request"

    @property
    def description(self) -> str:
        """Human-readable tool description."""
        return "安全的HTTP请求工具"

    def get_schema(self) -> Dict[str, Any]:
        """Input schema, generated from ``NetworkRequestInput``."""
        return NetworkRequestInput.schema()

    @require_permission(Permission.NETWORK_REQUEST)
    async def execute(self, arguments: Dict[str, Any],
                      context: SecurityContext) -> CallToolResult:
        """Perform the requested HTTP call and return status, headers and body.

        Args:
            arguments: Raw tool arguments, validated into ``NetworkRequestInput``
                (``url``, ``method``, ``headers``, ``timeout``).
            context: Caller's security context. ``require_permission`` has
                already rejected a missing context or one lacking
                NETWORK_REQUEST.

        Returns:
            A ``DOMAIN_ACCESS_DENIED`` error result when
            ``context.can_access_domain`` rejects the URL. Otherwise a success
            result whose data holds ``status_code``, ``headers`` and the
            decoded ``content``.

        Raises:
            ToolError: ``REQUEST_TIMEOUT`` when the request exceeds the total
                ``timeout``; ``REQUEST_FAILED`` for other aiohttp client errors.
        """
        input_data = NetworkRequestInput(**arguments)

        if not context.can_access_domain(input_data.url):
            return ToolResponse.error(
                f"无权访问域名: {input_data.url}",
                "DOMAIN_ACCESS_DENIED"
            )

        import asyncio
        import aiohttp

        # NOTE(review): aiohttp follows redirects by default, so only the initial
        # URL is checked by can_access_domain; a redirect can reach a host the
        # context would reject. The response body is also read with no size cap.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=input_data.method,
                    url=input_data.url,
                    headers=input_data.headers,
                    timeout=aiohttp.ClientTimeout(total=input_data.timeout),
                ) as response:
                    content = await response.text()
                    result_data = {
                        "status_code": response.status,
                        "headers": dict(response.headers),
                        "content": content,
                    }
        except asyncio.TimeoutError as e:
            # aiohttp.ClientTimeout is a configuration object, not an exception.
            # Naming it in an except clause raised TypeError as soon as any
            # error occurred. aiohttp reports timeouts as asyncio.TimeoutError
            # (ServerTimeoutError subclasses it), so handle that before
            # ClientError.
            raise ToolError(f"请求超时: {input_data.url}", "REQUEST_TIMEOUT") from e
        except aiohttp.ClientError as e:
            raise ToolError(f"请求失败: {str(e)}", "REQUEST_FAILED") from e

        return ToolResponse.success(
            f"HTTP请求成功: {input_data.method} {input_data.url}",
            result_data
        )
2. 输入验证、清理与速率限制
import re
from typing import Any, Dict, List
class InputSanitizer:
    """Regex-based cleaners and validators for untrusted tool input.

    These helpers are defense in depth only. They do not replace resolving a
    path against an allowed root directory, parameterized SQL, or
    context-aware HTML escaping (e.g. ``html.escape``).
    """

    @staticmethod
    def sanitize_path(path: str) -> str:
        """Clean a file path.

        Removes the characters ``<>:"|?*`` and converts backslashes to ``/``.
        It then repeatedly strips ``../`` sequences, ``/../``, a leading
        ``../``, a trailing ``..`` and surrounding whitespace until the string
        stops changing.

        A single pass (the previous behaviour) was bypassable, because the
        removals themselves can assemble new ``..`` components:
        ``"....//....//etc"`` became ``"../etc"``, and ``".. "`` became
        ``".."`` after the final ``strip()``.

        Absolute paths (leading ``/``) are kept as-is, so callers must still
        confine the result to an allowed directory.
        """
        path = re.sub(r'[<>:"|?*]', '', path)
        path = path.replace('\\', '/')
        # Each pass only deletes characters, so any pass that changes the
        # string shortens it and the loop terminates.
        while True:
            cleaned = re.sub(r'\.\./', '', path)
            cleaned = re.sub(r'/\.\./|^\.\./|\.\.$', '', cleaned)
            cleaned = cleaned.strip()
            if cleaned == path:
                return cleaned
            path = cleaned

    @staticmethod
    def sanitize_sql(query: str) -> str:
        """Strip SQL comments and collapse runs of whitespace to one space.

        NOTE: this is not an injection defense, and it is not quote-aware.
        ``--`` or ``/* */`` inside a string literal is removed too (e.g.
        ``'a--b'`` is truncated). Use parameterized queries for untrusted
        values.
        """
        query = re.sub(r'--.*$', '', query, flags=re.MULTILINE)
        query = re.sub(r'/\*.*?\*/', '', query, flags=re.DOTALL)
        query = re.sub(r'\s+', ' ', query)
        return query.strip()

    @staticmethod
    def sanitize_html(text: str) -> str:
        """Remove ``<script>`` elements, quoted ``on*=`` handlers and ``javascript:``.

        The removals repeat until the text stops changing. A single pass let
        nested payloads reassemble, e.g. ``<scr<script></script>ipt>`` or
        ``javajavascript:script:``.

        NOTE: unquoted handler values (``onerror=alert(1)``) are not matched.
        For untrusted HTML prefer escaping (``html.escape``) over stripping.
        """
        previous = None
        while text != previous:
            previous = text
            text = re.sub(r'<script[^>]*>.*?</script>', '', text, flags=re.DOTALL | re.IGNORECASE)
            text = re.sub(r'\s*on\w+\s*=\s*["\'][^"\'>]*["\']', '', text, flags=re.IGNORECASE)
            text = re.sub(r'javascript:', '', text, flags=re.IGNORECASE)
        return text

    @staticmethod
    def validate_email(email: str) -> bool:
        """Return True if the whole of ``email`` matches a simple address pattern.

        Uses ``re.fullmatch``: the previous ``^...$`` form accepted a trailing
        newline, because ``$`` also matches just before a final ``\\n``.
        """
        pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        return re.fullmatch(pattern, email) is not None

    @staticmethod
    def validate_url(url: str) -> bool:
        """Return True if ``url`` is an http(s) URL with a dotted host name.

        Ports, IP literals and ``localhost`` are not accepted by the pattern.
        Uses ``re.fullmatch`` so a trailing newline no longer validates.
        """
        pattern = r'https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/.*)?'
        return re.fullmatch(pattern, url) is not None
class RateLimiter:
    """In-memory sliding-window rate limiter keyed by user and operation type.

    For each ``"{user_id}:{operation_type}"`` key it stores the timestamps of
    accepted requests. A request is allowed while fewer than ``max_requests``
    of them fall within the last ``window_seconds``.
    """

    def __init__(self):
        self.requests: Dict[str, List[datetime]] = {}
        # operation_type -> (max_requests, window_seconds); unknown operation
        # types fall back to "default".
        self.limits = {
            "default": (100, 3600),           # 100 per hour
            "file_operations": (50, 3600),    # 50 per hour
            "network_requests": (200, 3600),  # 200 per hour
        }

    def _limit_for(self, operation_type: str):
        """Return ``(max_requests, window_seconds)`` for ``operation_type``."""
        return self.limits.get(operation_type, self.limits["default"])

    def is_allowed(self, user_id: str, operation_type: str = "default") -> bool:
        """Record and allow the request if the user is under the limit.

        Expired timestamps for the key are pruned first. A rejected request is
        not recorded.

        Returns:
            True if the request was accepted (and recorded), False otherwise.
        """
        key = f"{user_id}:{operation_type}"
        now = datetime.now()
        max_requests, window_seconds = self._limit_for(operation_type)
        cutoff_time = now - timedelta(seconds=window_seconds)

        recent = [t for t in self.requests.get(key, ()) if t > cutoff_time]
        self.requests[key] = recent

        if len(recent) >= max_requests:
            return False
        recent.append(now)
        return True

    def get_remaining(self, user_id: str, operation_type: str = "default") -> int:
        """Return how many more requests are allowed in the current window.

        Only timestamps inside the window are counted. Previously expired
        entries were counted too, so the quota looked exhausted after the
        window had passed. Does not modify the stored history.
        """
        key = f"{user_id}:{operation_type}"
        max_requests, window_seconds = self._limit_for(operation_type)
        cutoff_time = datetime.now() - timedelta(seconds=window_seconds)
        active = sum(1 for t in self.requests.get(key, ()) if t > cutoff_time)
        return max(0, max_requests - active)
# Process-wide limiter shared by every method decorated with rate_limit
rate_limiter = RateLimiter()


# Rate-limiting decorator
def rate_limit(operation_type: str = "default"):
    """Build a decorator that throttles an async tool method per user.

    The wrapper returns ``SECURITY_CONTEXT_MISSING`` when no context is
    given. It asks the module-level ``rate_limiter`` whether
    ``context.user_id`` may perform ``operation_type``. If not, it returns
    ``RATE_LIMIT_EXCEEDED`` with the remaining count in the message;
    otherwise it awaits and returns the wrapped method's result.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(self, arguments: Dict[str, Any],
                          context: Optional[SecurityContext] = None) -> CallToolResult:
            if context is None:
                return ToolResponse.error("缺少安全上下文", "SECURITY_CONTEXT_MISSING")
            uid = context.user_id
            if rate_limiter.is_allowed(uid, operation_type):
                return await func(self, arguments, context)
            left = rate_limiter.get_remaining(uid, operation_type)
            return ToolResponse.error(
                f"请求频率过高,剩余次数: {left}",
                "RATE_LIMIT_EXCEEDED",
            )
        return wrapper
    return decorator
工具监控和日志
1. 工具执行监控
import time
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from datetime import datetime
import json
@dataclass
class ToolExecutionMetrics:
    """Metrics recorded for one tool execution.

    Field order is the positional constructor order, so do not reorder.
    """
    tool_name: str
    user_id: str
    start_time: datetime
    # Filled in when the execution is finished; None while in flight.
    end_time: Optional[datetime] = None
    # (end_time - start_time) in milliseconds, once finished.
    duration_ms: Optional[float] = None
    # None until finished; then whether the execution succeeded.
    success: Optional[bool] = None
    error_message: Optional[str] = None
    # Size measures; ToolMonitor in this file fills these with character
    # counts (JSON-serialized arguments / total text of result content).
    input_size: int = 0
    output_size: int = 0
    # NOTE(review): not populated by ToolMonitor in this file; units unspecified.
    memory_usage: Optional[float] = None
    cpu_usage: Optional[float] = None
class ToolMonitor:
    """In-memory recorder and summarizer of tool executions.

    ``start_execution`` opens an in-flight record and returns its id.
    ``end_execution`` closes that record and appends it to ``self.metrics``,
    which keeps only the newest ``max_metrics`` entries.
    """

    def __init__(self, max_metrics: int = 1000):
        """Create an empty monitor.

        Args:
            max_metrics: Number of finished records to retain (default 1000,
                the previously hard-coded value).
        """
        self.metrics: List[ToolExecutionMetrics] = []
        self.active_executions: Dict[str, ToolExecutionMetrics] = {}
        self.max_metrics = max_metrics
        # Per-monitor sequence number that keeps execution ids unique.
        self._execution_seq = 0

    def start_execution(self, tool_name: str, user_id: str,
                        arguments: Dict[str, Any]) -> str:
        """Open an in-flight record for one tool call.

        Args:
            tool_name: Name of the tool being executed.
            user_id: Caller identifier.
            arguments: Tool arguments. Only their JSON-serialized length is
                stored, as ``input_size``.

        Returns:
            The execution id to pass to ``end_execution``.
        """
        self._execution_seq += 1
        # The millisecond timestamp alone is not unique. Two concurrent calls of
        # the same tool by the same user within one millisecond (e.g. via
        # asyncio.gather) got the same id and overwrote each other's record.
        execution_id = (
            f"{tool_name}_{user_id}_{int(time.time() * 1000)}_{self._execution_seq}"
        )
        metrics = ToolExecutionMetrics(
            tool_name=tool_name,
            user_id=user_id,
            start_time=datetime.now(),
            # default=str: arguments that json cannot serialize (bytes, Path,
            # datetime, ...) must not make monitoring raise inside the tool call.
            input_size=len(json.dumps(arguments, ensure_ascii=False, default=str)),
        )
        self.active_executions[execution_id] = metrics
        return execution_id

    def end_execution(self, execution_id: str, success: bool,
                      result: Optional[CallToolResult] = None,
                      error_message: Optional[str] = None) -> None:
        """Finish the in-flight record ``execution_id`` and store it.

        Unknown or already-finished ids are ignored. ``output_size`` is the
        total length of the ``text`` attribute of the result's content items
        (items without text count as 0).
        """
        metrics = self.active_executions.pop(execution_id, None)
        if metrics is None:
            return

        metrics.end_time = datetime.now()
        metrics.duration_ms = (
            metrics.end_time - metrics.start_time
        ).total_seconds() * 1000
        metrics.success = success
        metrics.error_message = error_message

        if result and result.content:
            metrics.output_size = sum(
                len(getattr(item, "text", None) or "") for item in result.content
            )

        self.metrics.append(metrics)
        overflow = len(self.metrics) - self.max_metrics
        if overflow > 0:
            # del rather than a [-n:] slice, so that max_metrics == 0 keeps
            # nothing instead of everything.
            del self.metrics[:overflow]

    def get_statistics(self, tool_name: Optional[str] = None,
                       user_id: Optional[str] = None,
                       hours: int = 24) -> Dict[str, Any]:
        """Summarize finished executions that started within the last ``hours`` hours.

        Args:
            tool_name: Restrict to this tool, if given.
            user_id: Restrict to this user, if given.
            hours: Look-back window.

        Returns:
            ``{"total_executions": 0}`` when nothing matches. Otherwise totals,
            success rate, min/avg/max duration in ms, and per-tool
            total/success/failed counts. Records whose ``success`` is not
            truthy count as failed.
        """
        # This section's own imports only bring in ``datetime``; import
        # timedelta here so the method does not depend on an import elsewhere.
        from datetime import timedelta

        cutoff_time = datetime.now() - timedelta(hours=hours)
        filtered_metrics = [
            m for m in self.metrics
            if m.start_time > cutoff_time
            and (tool_name is None or m.tool_name == tool_name)
            and (user_id is None or m.user_id == user_id)
        ]
        if not filtered_metrics:
            return {"total_executions": 0}

        total_executions = len(filtered_metrics)
        successful_executions = sum(1 for m in filtered_metrics if m.success)
        failed_executions = total_executions - successful_executions

        durations = [m.duration_ms for m in filtered_metrics if m.duration_ms is not None]
        avg_duration = sum(durations) / len(durations) if durations else 0
        max_duration = max(durations) if durations else 0
        min_duration = min(durations) if durations else 0

        tool_stats: Dict[str, Dict[str, int]] = {}
        for m in filtered_metrics:
            bucket = tool_stats.setdefault(
                m.tool_name, {"total": 0, "success": 0, "failed": 0}
            )
            bucket["total"] += 1
            bucket["success" if m.success else "failed"] += 1

        return {
            "total_executions": total_executions,
            "successful_executions": successful_executions,
            "failed_executions": failed_executions,
            "success_rate": successful_executions / total_executions,
            "average_duration_ms": avg_duration,
            "max_duration_ms": max_duration,
            "min_duration_ms": min_duration,
            "tool_statistics": tool_stats,
        }

    def get_recent_errors(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` failed executions with a message, newest first."""
        error_metrics = [
            m for m in reversed(self.metrics)
            if not m.success and m.error_message
        ][:limit]
        return [
            {
                "tool_name": m.tool_name,
                "user_id": m.user_id,
                "timestamp": m.start_time.isoformat(),
                "error_message": m.error_message,
                "duration_ms": m.duration_ms,
            }
            for m in error_metrics
        ]
# Process-wide monitor used by monitor_execution
tool_monitor = ToolMonitor()


# Execution-monitoring decorator
def monitor_execution(func):
    """Decorator that records each call of an async tool method in ``tool_monitor``.

    The call is attributed to ``context.user_id``, or ``"anonymous"`` when no
    context is given. A returned result counts as successful unless its
    ``isError`` is set. A raised exception is recorded as a failure and then
    re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(self, arguments: Dict[str, Any],
                      context: Optional[SecurityContext] = None) -> CallToolResult:
        user_id = context.user_id if context else "anonymous"
        execution_id = tool_monitor.start_execution(self.name, user_id, arguments)
        try:
            result = await func(self, arguments, context)
        except BaseException as e:
            # BaseException, not Exception: asyncio.CancelledError derives from
            # BaseException (Python 3.8+). Catching only Exception left cancelled
            # calls in tool_monitor.active_executions forever.
            tool_monitor.end_execution(
                execution_id,
                False,
                # str() is often empty (e.g. CancelledError), and
                # get_recent_errors skips records without a message.
                error_message=str(e) or type(e).__name__,
            )
            raise
        tool_monitor.end_execution(execution_id, not result.isError, result)
        return result
    return wrapper
2. 结构化日志
import logging
import json
from typing import Any, Dict, Optional
from datetime import datetime
class StructuredLogger:
    """Logger that writes each record as one JSON object.

    Every record contains ``timestamp``, ``level`` and ``message``, plus any
    keyword fields supplied by the caller. Values that ``json`` cannot
    serialize are rendered with ``str``.
    """

    def __init__(self, name: str = "mcp_tools", level: int = logging.INFO):
        """Attach to ``logging.getLogger(name)``.

        Args:
            name: Logger name.
            level: Threshold set on the logger. The default
                ``logging.INFO`` is the previously hard-coded value; at that
                level ``debug`` records are dropped.
        """
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)
        # getLogger returns the same object for the same name, so adding a
        # handler unconditionally duplicated every line once per extra
        # StructuredLogger(name) instance. Attach our handler only once.
        if not any(getattr(h, "_structured_logger_handler", False)
                   for h in self.logger.handlers):
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            handler._structured_logger_handler = True
            self.logger.addHandler(handler)

    def _log_structured(self, level: str, message: str, **kwargs):
        """Serialize ``message`` and ``kwargs`` to JSON and log them at ``level``.

        ``level`` must name a logger method ("INFO", "WARNING", ...).
        Caller-supplied keys override ``timestamp``/``level``/``message``.
        """
        log_data = {
            "timestamp": datetime.now().isoformat(),
            "level": level,
            "message": message,
            **kwargs
        }
        log_message = json.dumps(log_data, ensure_ascii=False, default=str)
        getattr(self.logger, level.lower())(log_message)

    def info(self, message: str, **kwargs):
        """Log an INFO record."""
        self._log_structured("INFO", message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log a WARNING record."""
        self._log_structured("WARNING", message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log an ERROR record."""
        self._log_structured("ERROR", message, **kwargs)

    def debug(self, message: str, **kwargs):
        """Log a DEBUG record (dropped unless the logger level allows DEBUG)."""
        self._log_structured("DEBUG", message, **kwargs)

    def tool_execution(self, tool_name: str, user_id: str,
                       arguments: Dict[str, Any],
                       result: Optional["CallToolResult"] = None,
                       duration_ms: Optional[float] = None,
                       error: Optional[str] = None):
        """Log one tool execution: INFO on success, ERROR when ``error`` is given.

        NOTE(review): ``arguments`` are logged verbatim; redact secrets or
        file contents before they reach here if that matters.
        """
        self._log_structured(
            "INFO" if error is None else "ERROR",
            f"Tool execution: {tool_name}",
            tool_name=tool_name,
            user_id=user_id,
            arguments=arguments,
            success=error is None,
            duration_ms=duration_ms,
            error=error,
            # Number of content items in the result, not a byte/char size.
            result_size=len(result.content) if result and result.content else 0
        )

    def security_event(self, event_type: str, user_id: str,
                       details: Dict[str, Any]):
        """Log a security event at WARNING level."""
        self._log_structured(
            "WARNING",
            f"Security event: {event_type}",
            event_type=event_type,
            user_id=user_id,
            details=details
        )

    def performance_warning(self, tool_name: str, duration_ms: float,
                            threshold_ms: float = 5000):
        """Log a WARNING when ``duration_ms`` exceeds ``threshold_ms``."""
        if duration_ms > threshold_ms:
            self._log_structured(
                "WARNING",
                f"Slow tool execution: {tool_name}",
                tool_name=tool_name,
                duration_ms=duration_ms,
                threshold_ms=threshold_ms
            )
# Process-wide structured logger used by log_execution
structured_logger = StructuredLogger()


# Execution-logging decorator
def log_execution(func):
    """Decorator that logs each call of an async tool method via ``structured_logger``.

    Logs the tool name, caller (``context.user_id`` or ``"anonymous"``),
    arguments and duration. A result with ``isError`` set is logged as a
    failure, consistent with ``monitor_execution``. A raised exception is
    logged and then re-raised. Successful-path calls also get a
    slow-execution warning check.
    """
    @wraps(func)
    async def wrapper(self, arguments: Dict[str, Any],
                      context: Optional[SecurityContext] = None) -> CallToolResult:
        user_id = context.user_id if context else "anonymous"
        # perf_counter is monotonic. time.time() follows wall-clock adjustments
        # and can yield negative or inflated durations.
        start = time.perf_counter()
        try:
            result = await func(self, arguments, context)
        except Exception as e:
            duration_ms = (time.perf_counter() - start) * 1000
            structured_logger.tool_execution(
                self.name, user_id, arguments, None, duration_ms, str(e)
            )
            raise

        duration_ms = (time.perf_counter() - start) * 1000
        error = "tool returned an error result" if getattr(result, "isError", False) else None
        structured_logger.tool_execution(
            self.name, user_id, arguments, result, duration_ms, error
        )
        structured_logger.performance_warning(self.name, duration_ms)
        return result
    return wrapper
本章总结
本章深入介绍了MCP协议中工具开发与管理的高级技术,包括:
核心内容
工具设计原则
- 单一职责原则
- 接口一致性原则
- 分层错误处理
高级开发模式
- 工具工厂模式
- 工具组合模式
- 工具链模式
参数验证和类型安全
- 高级参数验证
- 类型安全的工具接口
- 输入清理和验证
性能优化
- 异步并发优化
- 缓存和记忆化
- 批量处理
安全性
- 权限控制系统
- 输入验证和清理
- 速率限制
监控和日志
- 工具执行监控
- 结构化日志记录
- 性能指标收集
最佳实践
设计阶段
- 明确工具职责边界
- 设计一致的接口
- 考虑安全性和性能
开发阶段
- 使用类型安全的接口
- 实现完善的错误处理
- 添加适当的缓存机制
部署阶段
- 配置权限控制
- 启用监控和日志
- 设置合理的限制
维护阶段
- 定期检查性能指标
- 分析错误日志
- 优化热点工具
下一章我们将学习资源管理与订阅机制,了解如何在MCP协议中管理动态资源和实现实时更新。