工具设计原则

1. 单一职责原则

每个工具应该只负责一个明确的功能，这样可以提高可维护性和可测试性。

好的设计

# Split design: each tool owns exactly one job.
@server.call_tool()
async def handle_read_file(arguments):
    """Read a file and do nothing else (placeholder body).

    NOTE(review): if the server keeps a single call_tool handler, each
    @server.call_tool() registration would replace the previous one, and
    such handlers are usually called with (name, arguments). Confirm both
    points against the MCP SDK version in use.
    """
    ...

@server.call_tool()
async def handle_write_file(arguments):
    """Write a file and do nothing else (placeholder body)."""
    ...

@server.call_tool()
async def handle_delete_file(arguments):
    """Delete a file and do nothing else (placeholder body)."""
    ...

不好的设计

# Counter-example: one tool multiplexing every file operation.
@server.call_tool()
async def handle_file_operations(arguments):
    """Handle read, write and delete inside a single tool (anti-pattern).

    The ``operation`` key picks a branch. Every branch is a placeholder, so
    the handler always returns None. Mixing responsibilities like this makes
    the tool hard to maintain and hard to test.
    """
    op = arguments.get('operation')
    if op == 'read':
        ...  # read logic would live here
    elif op == 'write':
        ...  # write logic would live here
    elif op == 'delete':
        ...  # delete logic would live here

2. 接口一致性原则

所有工具应该遵循一致的接口设计模式。

统一的参数结构

from pydantic import BaseModel, Field
from typing import Optional, Dict, Any

class BaseToolInput(BaseModel):
    """工具输入基类"""
    # Common pydantic base for the tool input models in this section.
    # It declares no fields of its own.
    # The class docstring is kept verbatim because pydantic uses class
    # docstrings as the "description" in generated schemas.
    pass

class FileOperationInput(BaseToolInput):
    """文件操作输入"""
    # Arguments shared by the file tools (read/write) in this document.
    # The class docstring is kept verbatim: pydantic exposes it as the
    # schema description.
    path: str = Field(description="文件路径")  # file to operate on
    # NOTE(review): Optional allows an explicit None. Consumers that pass this
    # value to open()/str.encode() should be checked for that case.
    encoding: Optional[str] = Field(default="utf-8", description="文件编码")

class NetworkRequestInput(BaseToolInput):
    """网络请求输入"""
    # Arguments for an outgoing HTTP request. No validation beyond types.
    # The class docstring is kept verbatim: pydantic exposes it as the
    # schema description.
    # NOTE(review): a stricter BaseModel with the same name is defined later
    # in this document. If both live in one module, the later one wins.
    url: str = Field(description="请求URL")  # target URL, not format-checked here
    method: Optional[str] = Field(default="GET", description="HTTP方法")  # HTTP verb
    headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")  # extra headers or None
    timeout: Optional[int] = Field(default=30, description="超时时间(秒)")  # seconds

class DatabaseQueryInput(BaseToolInput):
    """数据库查询输入"""
    # Arguments for a database query. No validation beyond types.
    # The class docstring is kept verbatim: pydantic exposes it as the
    # schema description.
    # NOTE(review): a validated BaseModel with the same name is defined later
    # in this document. If both live in one module, the later one wins.
    query: str = Field(description="SQL查询语句")  # SQL text
    parameters: Optional[Dict[str, Any]] = Field(default=None, description="查询参数")  # query parameters or None
    limit: Optional[int] = Field(default=100, description="结果限制")  # maximum rows to return

统一的返回结构

from mcp.types import CallToolResult, TextContent, ImageContent
from typing import List, Union, Optional

class ToolResponse:
    """Static helpers that build ``CallToolResult`` objects in one consistent shape."""

    @staticmethod
    def success(message: str, data: Optional[Dict[str, Any]] = None) -> CallToolResult:
        """Build a successful result.

        Args:
            message: Human-readable summary, emitted as the first text item.
            data: Optional payload. When non-empty it is serialized as
                pretty-printed JSON (non-ASCII kept) into a second text item.

        Returns:
            A ``CallToolResult`` with ``isError=False``.

        Raises:
            TypeError: if ``data`` holds values ``json.dumps`` cannot encode.
        """
        # Imported locally: the import block of this snippet does not bring in
        # json, so the module-level name was undefined.
        import json

        content = [TextContent(type="text", text=message)]
        if data:
            payload = json.dumps(data, ensure_ascii=False, indent=2)
            content.append(TextContent(type="text", text=f"数据: {payload}"))
        return CallToolResult(content=content, isError=False)

    @staticmethod
    def error(message: str, error_code: Optional[str] = None) -> CallToolResult:
        """Build an error result whose single text item is the message.

        The message is prefixed with "错误: ". When ``error_code`` is given,
        it is appended in parentheses.
        """
        error_text = f"错误: {message}"
        if error_code:
            error_text += f" (错误码: {error_code})"

        return CallToolResult(
            content=[TextContent(type="text", text=error_text)],
            isError=True
        )

    @staticmethod
    def with_file_content(file_path: str, content: str,
                          content_type: str = "text") -> CallToolResult:
        """Wrap file contents as a result.

        For ``content_type == "text"`` the single text item is
        "文件: <path>" + blank line + content. Other content types are not
        implemented yet: they return a plain success message and the content
        is NOT included.
        """
        if content_type == "text":
            return CallToolResult(
                content=[
                    TextContent(
                        type="text",
                        text=f"文件: {file_path}\n\n{content}"
                    )
                ],
                isError=False
            )
        # Extension point for other content types (e.g. images).
        return ToolResponse.success(f"文件读取成功: {file_path}")

3. 错误处理原则

分层错误处理

import logging
from enum import Enum
from typing import Optional

class ErrorLevel(Enum):
    """Severity attached to a ToolError.

    handle_tool_errors copies ``.value`` into the ``level`` key of the log
    call's ``extra`` mapping.
    """
    WARNING = "warning"  # used by ValidationError
    ERROR = "error"  # ToolError's default level
    CRITICAL = "critical"

class ToolError(Exception):
    """Root of the tool exception hierarchy.

    Attributes:
        message: Human-readable description (also the Exception argument).
        error_code: Machine-readable code, or None.
        level: Severity, an ErrorLevel member.
        details: Structured context. It is always a dict: a new empty dict
            is used when no (or an empty) mapping is given.
    """

    def __init__(self, message: str, error_code: str = None,
                 level: ErrorLevel = ErrorLevel.ERROR,
                 details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message, self.error_code, self.level = message, error_code, level
        self.details = details if details else {}

class ValidationError(ToolError):
    """Invalid tool arguments (code VALIDATION_ERROR, level WARNING).

    When ``field`` is given, it is recorded as ``details["field"]``.
    NOTE(review): this shares its name with pydantic's ValidationError. Any
    ``except ValidationError`` in the same module catches this class, not
    pydantic's.
    """
    def __init__(self, message: str, field: str = None):
        super().__init__(
            message,
            "VALIDATION_ERROR",
            ErrorLevel.WARNING,
            {"field": field} if field else None,
        )

class ResourceNotFoundError(ToolError):
    """A named resource does not exist (code RESOURCE_NOT_FOUND, level ERROR).

    ``details`` records ``resource_type`` and ``resource_id``.
    """
    def __init__(self, resource_type: str, resource_id: str):
        super().__init__(
            f"{resource_type} '{resource_id}' 未找到",
            "RESOURCE_NOT_FOUND",
            ErrorLevel.ERROR,
            {"resource_type": resource_type, "resource_id": resource_id},
        )

class PermissionError(ToolError):
    """An operation on a resource is not permitted (code PERMISSION_DENIED).

    ``details`` records ``operation`` and ``resource``.
    NOTE(review): defining this at module scope shadows the built-in
    ``PermissionError`` (an OSError subclass). An ``except PermissionError``
    later in the same module catches this ToolError subclass, not OS
    permission failures. Use ``builtins.PermissionError`` for those.
    """
    def __init__(self, operation: str, resource: str):
        super().__init__(
            f"没有权限执行操作 '{operation}' 在资源 '{resource}' 上",
            "PERMISSION_DENIED",
            ErrorLevel.ERROR,
            {"operation": operation, "resource": resource},
        )

# Error-handling decorator
def handle_tool_errors(func):
    """Wrap an async tool method so that it returns error results instead of raising.

    A ``ToolError`` is logged at the logging level matching its ``level`` and
    converted to ``ToolResponse.error(message, error_code)``. Any other
    exception is logged with its traceback and reported as INTERNAL_ERROR.

    Args:
        func: The coroutine function to wrap.

    Returns:
        An async wrapper that keeps ``func``'s name and docstring.
    """
    # Imported here: this snippet's imports do not include functools, so the
    # bare ``wraps`` name was undefined.
    from functools import wraps

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except ToolError as e:
            # ErrorLevel values are "warning"/"error"/"critical", which map
            # onto logging.WARNING/ERROR/CRITICAL. Fall back to ERROR.
            log_level = getattr(logging, str(e.level.value).upper(), logging.ERROR)
            logging.log(log_level, "工具错误: %s", e.message, extra={
                "error_code": e.error_code,
                "level": e.level.value,
                "details": e.details
            })
            return ToolResponse.error(e.message, e.error_code)
        except Exception as e:
            logging.exception("未处理的工具错误: %s", e)
            return ToolResponse.error(
                f"内部错误: {str(e)}",
                "INTERNAL_ERROR"
            )
    return wrapper

高级工具开发模式

1. 工具工厂模式

from abc import ABC, abstractmethod
from typing import Type, Dict, Any

class ToolInterface(ABC):
    """Contract every tool class must fulfil.

    A tool exposes read-only ``name`` and ``description`` properties, a JSON
    schema for its arguments via ``get_schema``, and an async ``execute``
    that turns an argument dict into a ``CallToolResult``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique tool identifier."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable summary of the tool."""

    @abstractmethod
    def get_schema(self) -> Dict[str, Any]:
        """Return the JSON schema of the accepted arguments."""

    @abstractmethod
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run the tool with ``arguments`` and return its result."""

class FileReadTool(ToolInterface):
    """Tool ``read_file``: return a text file's contents."""

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return "读取文件内容"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "文件路径"
                },
                "encoding": {
                    "type": "string",
                    "description": "文件编码",
                    "default": "utf-8"
                }
            },
            "required": ["path"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Read ``arguments['path']`` with ``arguments['encoding']`` (default utf-8).

        Returns:
            ``ToolResponse.with_file_content(path, content)``.

        Raises (turned into error results by handle_tool_errors):
            ResourceNotFoundError: The file does not exist.
            PermissionError: The module's ToolError subclass, raised when the
                OS denies access.
            ValidationError: The file cannot be decoded with the encoding.
        """
        # This module defines its own PermissionError (a ToolError) that
        # shadows the built-in one. Name the OS exception explicitly so access
        # failures are really caught here.
        import builtins

        input_data = FileOperationInput(**arguments)

        try:
            async with aiofiles.open(input_data.path, 'r',
                                     encoding=input_data.encoding) as f:
                content = await f.read()
        except FileNotFoundError as e:
            raise ResourceNotFoundError("文件", input_data.path) from e
        except builtins.PermissionError as e:
            raise PermissionError("读取", input_data.path) from e
        except UnicodeDecodeError as e:
            # Same message as the typed reader uses for decode failures.
            raise ValidationError(f"无法使用编码 {input_data.encoding} 读取文件") from e

        return ToolResponse.with_file_content(
            input_data.path, content
        )

class HttpRequestTool(ToolInterface):
    """Tool ``http_request``: send an HTTP request and report status, headers and body."""

    @property
    def name(self) -> str:
        return "http_request"

    @property
    def description(self) -> str:
        return "发送HTTP请求"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "请求URL"
                },
                "method": {
                    "type": "string",
                    "description": "HTTP方法",
                    "enum": ["GET", "POST", "PUT", "DELETE"],
                    "default": "GET"
                },
                "headers": {
                    "type": "object",
                    "description": "请求头"
                },
                "data": {
                    "type": "object",
                    "description": "请求数据"
                },
                "timeout": {
                    "type": "integer",
                    "description": "超时时间(秒)",
                    "default": 30
                }
            },
            "required": ["url"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Perform the request described by ``arguments``.

        ``arguments['data']`` (if present) is sent as the JSON body. The
        response body is read as text and returned with the status code and
        the response headers.

        Raises (turned into error results by handle_tool_errors):
            ToolError: REQUEST_TIMEOUT on timeout, REQUEST_FAILED on any
                other aiohttp client error.
        """
        import asyncio

        import aiohttp

        input_data = NetworkRequestInput(**arguments)

        # aiohttp.ClientTimeout is a settings object, not an exception class.
        # Naming it in an ``except`` clause raised TypeError whenever anything
        # failed. Timeouts surface as asyncio.TimeoutError. That clause comes
        # first because some aiohttp timeout errors also subclass ClientError.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=input_data.method,
                    url=input_data.url,
                    headers=input_data.headers,
                    json=arguments.get('data'),
                    timeout=aiohttp.ClientTimeout(total=input_data.timeout)
                ) as response:
                    content = await response.text()
                    result_data = {
                        "status_code": response.status,
                        "headers": dict(response.headers),
                        "content": content
                    }
        except asyncio.TimeoutError as e:
            raise ToolError(f"请求超时: {input_data.url}", "REQUEST_TIMEOUT") from e
        except aiohttp.ClientError as e:
            raise ToolError(f"请求失败: {str(e)}", "REQUEST_FAILED") from e

        return ToolResponse.success(
            f"HTTP请求成功: {input_data.method} {input_data.url}",
            result_data
        )

class ToolFactory:
    """Registry from tool name to tool class; builds a new instance per request."""

    def __init__(self):
        # tool name -> tool class (the class, not an instance)
        self._tools: Dict[str, Type[ToolInterface]] = {}

    def register(self, tool_class: Type[ToolInterface]):
        """Register ``tool_class`` under the name reported by a throwaway instance.

        A later registration with the same name replaces the earlier one.
        """
        self._tools[tool_class().name] = tool_class

    def create(self, tool_name: str) -> ToolInterface:
        """Return a new instance of the tool registered as ``tool_name``.

        Raises:
            ValueError: If no tool is registered under that name.
        """
        tool_class = self._tools.get(tool_name)
        if tool_class is None:
            raise ValueError(f"未知工具: {tool_name}")
        return tool_class()

    def get_all_tools(self) -> List[Dict[str, Any]]:
        """Describe every registered tool as a name/description/inputSchema dict."""
        def describe(tool):
            return {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.get_schema()
            }

        return [describe(tool_class()) for tool_class in self._tools.values()]

# Module-level tool registry used by the two MCP handlers that follow.
tool_factory = ToolFactory()
tool_factory.register(FileReadTool)
tool_factory.register(HttpRequestTool)

@server.list_tools()
async def handle_list_tools() -> ListToolsResult:
    """Return descriptions of every tool registered in ``tool_factory``."""
    return ListToolsResult(tools=tool_factory.get_all_tools())

@server.call_tool()
async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> CallToolResult:
    """Dispatch a call to the tool registered as ``name``.

    Only a failed lookup is reported as TOOL_NOT_FOUND. Previously any
    ValueError raised while the tool ran was mislabelled the same way.
    """
    try:
        tool = tool_factory.create(name)
    except ValueError as e:
        return ToolResponse.error(str(e), "TOOL_NOT_FOUND")
    # NOTE(review): assumes clients may omit arguments (None). Treat that as
    # an empty dict so ``Model(**arguments)`` in the tools does not fail.
    return await tool.execute(arguments or {})

2. 工具组合模式

class CompositeToolInterface(ToolInterface):
    """Base for tools composed of other tools.

    Keeps an ordered list of child tools. Subclasses still implement the
    ToolInterface members (name, description, get_schema, execute).
    """

    def __init__(self):
        self.sub_tools: List[ToolInterface] = []

    def add_tool(self, tool: ToolInterface):
        """Append ``tool`` to the children."""
        self.sub_tools.append(tool)

    def remove_tool(self, tool: ToolInterface):
        """Remove the first child equal to ``tool``; do nothing if it is absent."""
        try:
            self.sub_tools.remove(tool)
        except ValueError:
            pass

class FileProcessingPipeline(CompositeToolInterface):
    """Tool ``file_processing_pipeline``: read a file, transform the text, write the result."""

    @property
    def name(self) -> str:
        return "file_processing_pipeline"

    @property
    def description(self) -> str:
        return "文件处理管道,支持读取、处理、写入文件"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "input_path": {
                    "type": "string",
                    "description": "输入文件路径"
                },
                "output_path": {
                    "type": "string",
                    "description": "输出文件路径"
                },
                "operations": {
                    "type": "array",
                    "description": "处理操作列表",
                    "items": {
                        "type": "object",
                        "properties": {
                            "type": {
                                "type": "string",
                                "enum": ["uppercase", "lowercase", "replace", "filter"]
                            },
                            "params": {
                                "type": "object",
                                "description": "操作参数"
                            }
                        }
                    }
                }
            },
            "required": ["input_path", "output_path"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Read ``input_path``, apply ``operations`` in order, then write ``output_path``.

        A failed read or write result is returned as-is. On success, the
        result data reports how many operations were requested and the
        length, in characters, of the written text.
        """
        input_path = arguments["input_path"]
        output_path = arguments["output_path"]
        # A missing or explicit-null list means "no operations".
        operations = arguments.get("operations") or []

        # 1. Read
        read_result = await FileReadTool().execute({"path": input_path})
        if read_result.isError:
            return read_result

        # FileReadTool's text is "文件: <path>" + blank line + content. Strip
        # that exact header rather than splitting on the first blank line,
        # which misfires if the path itself contains "\n\n".
        text = read_result.content[0].text
        header = f"文件: {input_path}\n\n"
        if text.startswith(header):
            content = text[len(header):]
        else:
            content = text.split("\n\n", 1)[1]

        # 2. Transform
        for operation in operations:
            content = await self._apply_operation(content, operation)

        # 3. Write
        write_result = await FileWriteTool().execute({
            "path": output_path,
            "content": content
        })
        if write_result.isError:
            return write_result

        return ToolResponse.success(
            f"文件处理完成: {input_path} -> {output_path}",
            {
                "operations_applied": len(operations),
                "output_size": len(content)
            }
        )

    async def _apply_operation(self, content: str, operation: Dict[str, Any]) -> str:
        """Apply one transformation step to ``content`` and return the new text.

        Supported ``operation["type"]`` values:
            uppercase, lowercase: case conversion.
            replace: replace ``params.old`` with ``params.new``. An empty or
                missing ``old`` leaves the text unchanged.
            filter: keep only the lines matching regex ``params.pattern``.
        Unknown types leave the text unchanged.

        Raises:
            KeyError: If ``operation`` has no ``type`` key.
            ValidationError: If a ``filter`` pattern is not a valid regex.
        """
        import re

        op_type = operation["type"]
        # "params" may be absent or explicitly null.
        params = operation.get("params") or {}

        if op_type == "uppercase":
            return content.upper()
        elif op_type == "lowercase":
            return content.lower()
        elif op_type == "replace":
            old = params.get("old", "")
            new = params.get("new", "")
            # str.replace("", x) inserts x between every character, which
            # would corrupt the file instead of being a no-op.
            if not old:
                return content
            return content.replace(old, new)
        elif op_type == "filter":
            pattern = params.get("pattern", "")
            try:
                regex = re.compile(pattern)
            except re.error as e:
                raise ValidationError(f"无效的正则表达式: {pattern}", "operations") from e
            return "\n".join(line for line in content.split("\n") if regex.search(line))
        else:
            return content

class FileWriteTool(ToolInterface):
    """Tool ``write_file``: write text to a file, creating parent directories."""

    @property
    def name(self) -> str:
        return "write_file"

    @property
    def description(self) -> str:
        return "写入文件内容"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "文件路径"
                },
                "content": {
                    "type": "string",
                    "description": "文件内容"
                },
                "encoding": {
                    "type": "string",
                    "description": "文件编码",
                    "default": "utf-8"
                }
            },
            "required": ["path", "content"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Write ``arguments['content']`` to ``arguments['path']``, overwriting it.

        Returns:
            A success result whose data holds ``bytes_written``, the
            encoded size of the content.

        Raises (turned into error results by handle_tool_errors):
            ValidationError: ``content`` is missing or not a string.
            PermissionError: The module's ToolError subclass, raised when the
                OS denies access.
        """
        import builtins
        import os

        input_data = FileOperationInput(**arguments)
        content = arguments.get("content")
        if not isinstance(content, str):
            raise ValidationError("content 必须是字符串", "content")

        try:
            # For a bare file name, dirname() is "" and os.makedirs("")
            # raises FileNotFoundError. Only create parents when there are any.
            parent = os.path.dirname(input_data.path)
            if parent:
                os.makedirs(parent, exist_ok=True)

            async with aiofiles.open(input_data.path, 'w',
                                     encoding=input_data.encoding) as f:
                await f.write(content)
        except builtins.PermissionError as e:
            # The bare name PermissionError is the module's ToolError subclass
            # (it shadows the built-in), so catch the OS error explicitly.
            raise PermissionError("写入", input_data.path) from e

        return ToolResponse.success(
            f"文件写入成功: {input_data.path}",
            {"bytes_written": len(content.encode(input_data.encoding))}
        )

3. 工具链模式

from typing import List, Callable, Any

class ToolChain:
    """A named, ordered sequence of tool calls that share one context dict.

    After step ``i`` runs, the context gains ``step_{i}_result``, a dict with
    ``step``, ``tool_name``, ``arguments``, ``result`` and ``success``. Later
    steps can reference it through ``"${...}"`` placeholders in their
    arguments.
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description
        # Each entry has the keys "tool_name", "arguments" and "condition".
        self.steps: List[Dict[str, Any]] = []

    def add_step(self, tool_name: str, arguments: Dict[str, Any],
                 condition: Optional[Callable[[Any], bool]] = None):
        """Append a step and return ``self`` so calls can be chained.

        ``condition``, if given, is called with the context right before the
        step would run. The step is skipped when it returns a falsy value.
        """
        self.steps.append({
            "tool_name": tool_name,
            "arguments": arguments,
            "condition": condition
        })
        return self

    def add_conditional_step(self, condition: Callable[[Any], bool],
                             tool_name: str, arguments: Dict[str, Any]):
        """Shorthand for ``add_step(tool_name, arguments, condition)``."""
        return self.add_step(tool_name, arguments, condition)

    async def execute(self, initial_context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Run the steps in order and return an execution report.

        Steps whose condition is falsy are skipped. Execution stops after the
        first step whose result has ``isError`` set.

        Args:
            initial_context: Variables visible to ``${...}`` placeholders.
                It is copied and never mutated.

        Returns:
            A dict with these keys:
            ``chain_name``: the chain's name.
            ``steps_executed``: number of steps that ran.
            ``success``: True when every executed step succeeded, also when
                none ran.
            ``results``: the per-step dicts.
            ``final_context``: the context after the last step.

        Raises:
            ValueError: From ``tool_factory.create`` for an unregistered tool.
        """
        # Work on a copy. Step results used to be written into the caller's
        # own dict.
        context = dict(initial_context or {})
        results = []

        for i, step in enumerate(self.steps):
            if step["condition"] and not step["condition"](context):
                continue

            resolved_args = self._resolve_arguments(step["arguments"], context)

            tool = tool_factory.create(step["tool_name"])
            result = await tool.execute(resolved_args)

            step_result = {
                "step": i,
                "tool_name": step["tool_name"],
                "arguments": resolved_args,
                "result": result,
                "success": not result.isError
            }
            results.append(step_result)
            context[f"step_{i}_result"] = step_result

            # There is no per-step error handling, so stop at the first failure.
            if result.isError:
                break

        return {
            "chain_name": self.name,
            "steps_executed": len(results),
            "success": all(r["success"] for r in results),
            "results": results,
            "final_context": context
        }

    def _resolve_arguments(self, arguments: Dict[str, Any],
                           context: Dict[str, Any]) -> Dict[str, Any]:
        """Substitute top-level ``"${path}"`` string values from ``context``.

        Only a value that is exactly one placeholder (starts with ``${`` and
        ends with ``}``) is replaced. It becomes the object found at ``path``,
        which may be any type, or None if the path does not resolve. All other
        values, including nested containers, pass through unchanged.
        """
        resolved = {}
        for key, value in arguments.items():
            is_placeholder = (isinstance(value, str)
                              and value.startswith("${")
                              and value.endswith("}"))
            # e.g. "${step_0_result.success}" -> look up "step_0_result.success"
            resolved[key] = (self._get_nested_value(context, value[2:-1])
                             if is_placeholder else value)
        return resolved

    def _get_nested_value(self, obj: Any, path: str) -> Any:
        """Follow a dotted ``path`` through dict keys, list indices and attributes.

        Each segment is tried as a dict key (if the current value is a dict),
        then as a decimal index (if it is a list/tuple), then as an attribute.
        Returns None as soon as a segment cannot be resolved.
        """
        current = obj
        for part in path.split("."):
            if isinstance(current, dict):
                current = current.get(part)
            elif isinstance(current, (list, tuple)) and part.isdecimal():
                # Allows e.g. "step_0_result.result.content.0.text".
                index = int(part)
                current = current[index] if index < len(current) else None
            elif hasattr(current, part):
                current = getattr(current, part)
            else:
                return None
        return current

# Factory for the predefined tool chains.
class ToolChainBuilder:
    """Static constructors for the predefined tool chains."""

    @staticmethod
    def create_web_scraping_chain() -> ToolChain:
        """Build "web_scraping": fetch ``${target_url}``, extract text, save to ``${output_file}``.

        The extract step runs only if the HTTP step succeeded.
        NOTE(review): "extract_text" is not among the tools registered in
        this document's tool_factory setup. Also, when the extract step is
        skipped, ``step_1_result`` is absent and the write step receives
        None as content. Confirm both before relying on this chain.
        """
        def fetched_ok(ctx):
            return ctx.get("step_0_result", {}).get("success", False)

        chain = ToolChain("web_scraping", "网页内容抓取和处理工具链")
        chain.add_step("http_request", {"url": "${target_url}", "method": "GET"})
        chain.add_conditional_step(
            fetched_ok,
            "extract_text",
            {"html": "${step_0_result.result.content}"}
        )
        chain.add_step(
            "write_file",
            {"path": "${output_file}", "content": "${step_1_result.result.text}"}
        )
        return chain

    @staticmethod
    def create_data_processing_chain() -> ToolChain:
        """Build "data_processing": read ``${input_file}``, process it, write ``${output_file}``.

        NOTE(review): "process_data" is not among the tools registered in
        this document's tool_factory setup. Confirm it exists.
        """
        chain = ToolChain("data_processing", "数据读取、处理和保存工具链")
        chain.add_step("read_file", {"path": "${input_file}"})
        chain.add_step(
            "process_data",
            {
                "data": "${step_0_result.result.content}",
                "operations": "${operations}"
            }
        )
        chain.add_step(
            "write_file",
            {"path": "${output_file}", "content": "${step_1_result.result.processed_data}"}
        )
        return chain

# Tool that runs a predefined chain.
class ToolChainExecutor(ToolInterface):
    """Tool ``execute_chain``: run one of the predefined chains by name.

    The chains are built once per executor instance and reused for every call.
    """

    def __init__(self):
        self.chains = {
            "web_scraping": ToolChainBuilder.create_web_scraping_chain(),
            "data_processing": ToolChainBuilder.create_data_processing_chain()
        }

    @property
    def name(self) -> str:
        return "execute_chain"

    @property
    def description(self) -> str:
        return "执行预定义的工具链"

    def get_schema(self) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "chain_name": {
                    "type": "string",
                    "description": "工具链名称",
                    "enum": list(self.chains.keys())
                },
                "context": {
                    "type": "object",
                    "description": "执行上下文"
                }
            },
            "required": ["chain_name"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run ``arguments['chain_name']`` with ``arguments['context']`` as its variables.

        Raises (turned into error results by handle_tool_errors):
            ValidationError: The chain name is missing or unknown.
        """
        chain_name = arguments.get("chain_name")
        if chain_name not in self.chains:
            raise ValidationError(f"未知工具链: {chain_name}", "chain_name")

        # Give the chain a private copy; a null context counts as empty.
        context = dict(arguments.get("context") or {})
        result = await self.chains[chain_name].execute(context)

        return ToolResponse.success(
            f"工具链执行完成: {chain_name}",
            self._summarize(result)
        )

    @staticmethod
    def _summarize(result: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce a ToolChain.execute report to JSON-serializable data.

        The raw report holds CallToolResult objects, both under each step's
        ``result`` and, via ``step_i_result`` entries, in ``final_context``.
        json.dumps inside ToolResponse.success cannot encode them. Each step
        is therefore reduced to its index, tool name, success flag and the
        ``text`` of its content items.
        """
        steps = []
        for step in result["results"]:
            outcome = step["result"]
            steps.append({
                "step": step["step"],
                "tool_name": step["tool_name"],
                "success": step["success"],
                "output": [getattr(item, "text", None)
                           for item in getattr(outcome, "content", None) or []],
            })
        return {
            "chain_name": result["chain_name"],
            "steps_executed": result["steps_executed"],
            "success": result["success"],
            "results": steps,
        }

工具参数验证和类型安全

1. 高级参数验证

from pydantic import BaseModel, Field, validator, root_validator
from typing import Union, List, Optional
import re
from pathlib import Path

class AdvancedFileInput(BaseModel):
    """高级文件输入验证"""
    # File-tool arguments with safety checks on the path, encoding, size
    # limit and extension whitelist. The class docstring is kept verbatim:
    # pydantic uses it as the schema description.

    path: str = Field(description="文件路径")
    encoding: str = Field(default="utf-8", description="文件编码")
    max_size: Optional[int] = Field(default=None, description="最大文件大小(字节)")
    allowed_extensions: Optional[List[str]] = Field(
        default=None,
        description="允许的文件扩展名"
    )

    @validator('path')
    def validate_path(cls, v):
        """Return the stripped path; reject empty, absolute or dot-traversing paths.

        Both POSIX and Windows interpretations are checked, because the same
        string may be opened on either platform.
        """
        from pathlib import PurePosixPath, PureWindowsPath

        # Strip *before* checking, so the checks see exactly the value that is
        # returned. Checking first let " /etc/passwd" through as "/etc/passwd".
        v = v.strip()
        if not v:
            raise ValueError("路径不能为空")

        posix_path = PurePosixPath(v)
        windows_path = PureWindowsPath(v)
        # On Windows a non-empty anchor covers "C:\\x", "C:x", "\\\\host\\share"
        # and "/x".
        if posix_path.is_absolute() or windows_path.anchor:
            raise ValueError("路径包含不安全字符")

        # Reject any component made only of dots ("..", "...", ...). A plain
        # '".." in v' substring test also rejected harmless names such as
        # "a..b.txt".
        for part in posix_path.parts + windows_path.parts:
            if len(part) >= 2 and not part.strip("."):
                raise ValueError("路径包含不安全字符")

        return v

    @validator('encoding')
    def validate_encoding(cls, v):
        """Accept only encodings known to the codecs registry."""
        import codecs
        try:
            codecs.lookup(v)
        except LookupError:
            raise ValueError(f"不支持的编码格式: {v}")
        return v

    @validator('max_size')
    def validate_max_size(cls, v):
        """Require a positive limit when one is given."""
        if v is not None and v <= 0:
            raise ValueError("最大文件大小必须大于0")
        return v

    @validator('allowed_extensions')
    def validate_extensions(cls, v):
        """Normalize extensions to lowercase with a leading dot."""
        if v is not None:
            normalized = []
            for ext in v:
                if not ext.startswith('.'):
                    ext = '.' + ext
                normalized.append(ext.lower())
            return normalized
        return v

    @root_validator
    def validate_file_constraints(cls, values):
        """Check the path's suffix against ``allowed_extensions`` when both are present."""
        path = values.get('path')
        allowed_extensions = values.get('allowed_extensions')

        if path and allowed_extensions:
            file_ext = Path(path).suffix.lower()
            if file_ext not in allowed_extensions:
                raise ValueError(
                    f"文件扩展名 '{file_ext}' 不在允许列表中: {allowed_extensions}"
                )

        return values

class NetworkRequestInput(BaseModel):
    """网络请求输入验证"""
    # Validated HTTP request arguments. The class docstring is kept verbatim:
    # pydantic uses it as the schema description.

    url: str = Field(description="请求URL")
    method: str = Field(default="GET", description="HTTP方法")
    headers: Optional[Dict[str, str]] = Field(default=None, description="请求头")
    timeout: int = Field(default=30, ge=1, le=300, description="超时时间(秒)")
    follow_redirects: bool = Field(default=True, description="是否跟随重定向")
    max_redirects: int = Field(default=5, ge=0, le=20, description="最大重定向次数")

    @validator('url')
    def validate_url(cls, v):
        """Require an absolute http(s) URL with a host."""
        import urllib.parse

        try:
            result = urllib.parse.urlparse(v)
        except Exception:
            raise ValueError("URL格式无效") from None

        # These checks sit outside the try. Previously the blanket
        # ``except Exception`` caught them too and replaced the protocol
        # message with the generic one.
        if not all([result.scheme, result.netloc]):
            raise ValueError("URL格式无效")
        if result.scheme not in ['http', 'https']:
            raise ValueError("只支持HTTP和HTTPS协议")

        return v

    @validator('method')
    def validate_method(cls, v):
        """Accept standard HTTP verbs case-insensitively; return them upper-cased."""
        allowed_methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS']
        if v.upper() not in allowed_methods:
            raise ValueError(f"不支持的HTTP方法: {v}")
        return v.upper()

    @validator('headers')
    def validate_headers(cls, v):
        """Reject credential-bearing headers (Authorization, Cookie, X-API-Key)."""
        if v is not None:
            dangerous_headers = ['authorization', 'cookie', 'x-api-key']
            for header in v.keys():
                if header.lower() in dangerous_headers:
                    raise ValueError(f"不允许设置敏感请求头: {header}")
        return v

class DatabaseQueryInput(BaseModel):
    """数据库查询输入验证"""
    # Validated arguments for a read-only query. The class docstring is kept
    # verbatim: pydantic uses it as the schema description.

    query: str = Field(description="SQL查询语句")
    parameters: Optional[Dict[str, Any]] = Field(default=None, description="查询参数")
    limit: int = Field(default=100, ge=1, le=10000, description="结果限制")
    timeout: int = Field(default=30, ge=1, le=300, description="查询超时(秒)")

    @validator('query')
    def validate_query(cls, v):
        """Accept only SELECT statements free of write/DDL keywords; return them stripped.

        This is a coarse keyword screen, not a SQL parser. Untrusted values
        should go through ``parameters`` rather than into the query text
        (presumably bound by the executor; confirm).
        """
        if not v or not v.strip():
            raise ValueError("查询语句不能为空")

        dangerous_keywords = [
            'drop', 'delete', 'truncate', 'alter', 'create',
            'insert', 'update', 'exec', 'execute'
        ]

        query_lower = v.lower()
        for keyword in dangerous_keywords:
            # Match whole words only. The substring test rejected harmless
            # identifiers such as created_at, is_deleted or last_update_time.
            if re.search(rf"\b{keyword}\b", query_lower):
                raise ValueError(f"查询包含危险关键字: {keyword}")

        if not query_lower.strip().startswith('select'):
            raise ValueError("只允许SELECT查询")

        return v.strip()

2. 类型安全的工具接口

from typing import TypeVar, Generic, Type
from pydantic import BaseModel

T = TypeVar('T', bound=BaseModel)
R = TypeVar('R')

class TypedToolInterface(Generic[T, R]):
    """Base for tools whose input is a pydantic model ``T`` and whose result is ``R``.

    Subclasses implement ``execute_typed``. ``execute`` validates the raw
    argument dict into ``input_model``, runs the tool, and converts the
    outcome (or any failure) into a ``CallToolResult``.
    NOTE(review): this class does not derive from ABC, so @abstractmethod
    is not enforced here.
    """

    def __init__(self, input_model: Type[T]):
        self.input_model = input_model

    @abstractmethod
    async def execute_typed(self, input_data: T) -> R:
        """Run the tool on already-validated input."""
        pass

    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Validate ``arguments``, run ``execute_typed`` and wrap the outcome.

        This method never raises:
            pydantic validation failure -> "参数验证失败: ..." (VALIDATION_ERROR)
            ToolError -> its own message and error_code
            anything else -> "执行失败: ..."
        """
        # Imported under an alias: this module defines its own ValidationError
        # (a ToolError subclass) that shadows pydantic's. The old bare
        # ``except ValidationError`` therefore never caught input errors.
        from pydantic import ValidationError as PydanticValidationError

        try:
            input_data = self.input_model(**(arguments or {}))
            result = await self.execute_typed(input_data)
            return self._convert_result(result)
        except PydanticValidationError as e:
            return ToolResponse.error(f"参数验证失败: {str(e)}", "VALIDATION_ERROR")
        except ToolError as e:
            # Keep the structured code (e.g. RESOURCE_NOT_FOUND) instead of
            # flattening it into a generic failure.
            return ToolResponse.error(e.message, e.error_code)
        except Exception as e:
            return ToolResponse.error(f"执行失败: {str(e)}")

    def _convert_result(self, result: R) -> CallToolResult:
        """Convert a typed result into a ``CallToolResult``.

        Conversion by result type:
            CallToolResult -> returned unchanged.
            str -> used as the success message.
            dict or pydantic model -> success with the data attached.
            anything else -> ``str(result)`` as the message.
        """
        if isinstance(result, CallToolResult):
            return result
        elif isinstance(result, str):
            return ToolResponse.success(result)
        elif isinstance(result, dict):
            return ToolResponse.success("操作成功", result)
        elif isinstance(result, BaseModel):
            # Structured results such as FileReadResult previously reached the
            # client only as their str() form. Use model_dump on pydantic v2,
            # .dict() on v1.
            dump = getattr(result, "model_dump", None) or result.dict
            return ToolResponse.success("操作成功", dump())
        else:
            return ToolResponse.success(str(result))

class FileReadResult(BaseModel):
    """文件读取结果"""
    # Result returned by TypedFileReadTool.execute_typed. The class docstring
    # is kept verbatim: pydantic uses it as the schema description.
    path: str  # path that was read, as a string
    content: str  # full decoded text
    size: int  # stat().st_size at read time (bytes on disk)
    encoding: str  # encoding used to decode the file
    lines: int  # line count, computed with str.splitlines()

class TypedFileReadTool(TypedToolInterface[AdvancedFileInput, FileReadResult]):
    """Typed file reader: input validated by AdvancedFileInput, output a FileReadResult."""

    def __init__(self):
        super().__init__(AdvancedFileInput)

    @property
    def name(self) -> str:
        return "typed_read_file"

    @property
    def description(self) -> str:
        return "类型安全的文件读取工具"

    def get_schema(self) -> Dict[str, Any]:
        """JSON schema generated from the pydantic input model."""
        return self.input_model.schema()

    async def execute_typed(self, input_data: AdvancedFileInput) -> FileReadResult:
        """Read the file described by ``input_data`` and describe it.

        Raises:
            ResourceNotFoundError: The path is not an existing regular file.
            ValidationError: The file exceeds ``max_size`` or cannot be
                decoded with ``encoding``.
        """
        file_path = Path(input_data.path)

        # is_file() rather than exists(): a directory passes exists() and then
        # failed inside open() with an unhelpful IsADirectoryError.
        if not file_path.is_file():
            raise ResourceNotFoundError("文件", str(file_path))

        file_size = file_path.stat().st_size
        if input_data.max_size and file_size > input_data.max_size:
            raise ValidationError(
                f"文件大小 {file_size} 超过限制 {input_data.max_size}"
            )

        try:
            with open(file_path, 'r', encoding=input_data.encoding) as f:
                content = f.read()
        except UnicodeDecodeError as e:
            raise ValidationError(f"无法使用编码 {input_data.encoding} 读取文件") from e

        return FileReadResult(
            path=str(file_path),
            content=content,
            size=file_size,
            encoding=input_data.encoding,
            lines=len(content.splitlines())
        )

工具性能优化

1. 异步并发优化

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Callable, Any

class ConcurrentToolExecutor:
    """Run tools concurrently (with a concurrency cap) or under a timeout.

    Also owns a ThreadPoolExecutor of ``max_workers`` threads, used by
    ``execute_cpu_bound`` to offload blocking callables.
    NOTE(review): this class never shuts the pool down. Confirm that
    instances are long-lived or add explicit shutdown if they are not.
    """

    def __init__(self, max_workers: int = 10):
        self.max_workers = max_workers
        self.thread_pool = ThreadPoolExecutor(max_workers=max_workers)

    # The ``tools`` annotation is a string because ``Tuple`` is not among
    # this snippet's typing imports. Evaluated eagerly, it raised NameError
    # when the class was created.
    async def execute_parallel(self,
                               tools: "List[Tuple[str, Dict[str, Any]]]",
                               max_concurrent: int = 5) -> List[CallToolResult]:
        """Run ``(tool_name, arguments)`` pairs concurrently, at most ``max_concurrent`` at once.

        Returns:
            Results in input order. Because ``return_exceptions=True``, an
            exception (e.g. ValueError for an unknown tool) appears in the
            list in place of that call's result.

        Raises:
            ValueError: If ``max_concurrent`` is less than 1.
        """
        if max_concurrent < 1:
            # asyncio.Semaphore(0) would make every task wait forever.
            raise ValueError(f"max_concurrent 必须大于等于 1: {max_concurrent}")
        semaphore = asyncio.Semaphore(max_concurrent)

        async def execute_single(tool_name: str, arguments: Dict[str, Any]):
            async with semaphore:
                tool = tool_factory.create(tool_name)
                return await tool.execute(arguments)

        tasks = [
            execute_single(tool_name, arguments)
            for tool_name, arguments in tools
        ]

        return await asyncio.gather(*tasks, return_exceptions=True)

    async def execute_cpu_bound(self,
                                func: Callable,
                                *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` in the thread pool and await the result.

        Threads share the GIL, so this keeps the event loop responsive but
        does not parallelize pure-Python CPU work.
        """
        import functools

        loop = asyncio.get_running_loop()
        # run_in_executor forwards positional arguments only; passing kwargs
        # to it raised TypeError. Bind everything with partial instead.
        return await loop.run_in_executor(
            self.thread_pool,
            functools.partial(func, *args, **kwargs)
        )

    async def execute_with_timeout(self,
                                   tool_name: str,
                                   arguments: Dict[str, Any],
                                   timeout: float) -> CallToolResult:
        """Run one tool and return an EXECUTION_TIMEOUT error if it exceeds ``timeout`` seconds.

        Raises:
            ValueError: From ``tool_factory.create`` for an unknown tool.
        """
        try:
            tool = tool_factory.create(tool_name)
            return await asyncio.wait_for(
                tool.execute(arguments),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            return ToolResponse.error(
                f"工具执行超时: {tool_name} (超时时间: {timeout}秒)",
                "EXECUTION_TIMEOUT"
            )

# 批量处理工具
class BatchProcessingTool(ToolInterface):
    """Tool that runs a list of other tools concurrently and summarizes the outcome."""

    def __init__(self):
        self.executor = ConcurrentToolExecutor()

    @property
    def name(self) -> str:
        return "batch_process"

    @property
    def description(self) -> str:
        return "批量执行工具操作"

    def get_schema(self) -> Dict[str, Any]:
        # JSON Schema for the tool input; execute() enforces the
        # max_concurrent bounds declared here.
        return {
            "type": "object",
            "properties": {
                "operations": {
                    "type": "array",
                    "description": "操作列表",
                    "items": {
                        "type": "object",
                        "properties": {
                            "tool_name": {
                                "type": "string",
                                "description": "工具名称"
                            },
                            "arguments": {
                                "type": "object",
                                "description": "工具参数"
                            }
                        },
                        "required": ["tool_name", "arguments"]
                    }
                },
                "max_concurrent": {
                    "type": "integer",
                    "description": "最大并发数",
                    "default": 5,
                    "minimum": 1,
                    "maximum": 20
                },
                "fail_fast": {
                    "type": "boolean",
                    "description": "遇到错误时是否立即停止",
                    "default": False
                }
            },
            "required": ["operations"]
        }

    @handle_tool_errors
    async def execute(self, arguments: Dict[str, Any]) -> CallToolResult:
        """Run every operation and return a per-operation summary.

        All operations are dispatched concurrently before any result is
        inspected, so ``fail_fast`` cannot stop operations from running. It
        stops result collection at the first failure (a raised exception *or*
        a result with ``isError`` set) and turns the response into an error.

        Returns:
            A success response whose data holds counts and per-operation
            details, or an error response (empty list / fail_fast failure).
        """
        operations = arguments["operations"]
        requested = arguments.get("max_concurrent")
        # Enforce the schema bounds (1..20). A value <= 0 would otherwise build
        # a zero-capacity semaphore in execute_parallel and hang the batch.
        max_concurrent = 5 if requested is None else min(max(int(requested), 1), 20)
        fail_fast = arguments.get("fail_fast", False)

        if not operations:
            return ToolResponse.error("操作列表不能为空")

        tool_calls = [
            (op["tool_name"], op["arguments"])
            for op in operations
        ]

        results = await self.executor.execute_parallel(
            tool_calls,
            max_concurrent
        )

        success_count = 0
        error_count = 0
        detailed_results = []

        for i, (operation, result) in enumerate(zip(operations, results)):
            # gather(return_exceptions=True) can also return BaseException
            # subclasses such as asyncio.CancelledError.
            if isinstance(result, BaseException):
                failed = True
                detailed_results.append({
                    "index": i,
                    "operation": operation,
                    "success": False,
                    "error": str(result)
                })
            else:
                failed = bool(result.isError)
                first = result.content[0] if result.content else None
                detailed_results.append({
                    "index": i,
                    "operation": operation,
                    "success": not failed,
                    # Non-text content (e.g. images) has no .text attribute.
                    "result": getattr(first, "text", None)
                })

            if failed:
                error_count += 1
            else:
                success_count += 1

            if failed and fail_fast:
                break

        summary = {
            "total_operations": len(operations),
            # Smaller than total_operations when fail_fast stopped collection early.
            "processed_operations": len(detailed_results),
            "successful": success_count,
            "failed": error_count,
            "results": detailed_results
        }

        if error_count > 0 and fail_fast:
            return ToolResponse.error(
                f"批量处理失败: {error_count} 个操作出错",
                "BATCH_PROCESSING_FAILED"
            )

        return ToolResponse.success(
            f"批量处理完成: {success_count} 成功, {error_count} 失败",
            summary
        )

2. 缓存和记忆化

import hashlib
import pickle
from typing import Optional, Any, Dict
from datetime import datetime, timedelta

class ToolResultCache:
    """In-memory TTL cache for tool results with least-recently-used eviction.

    Entries live in ``self.cache``. Its insertion order doubles as recency
    order, so the first key is always the least recently used one. The last
    access time of each key is mirrored in ``self.access_times``.
    """

    def __init__(self,
                 max_size: int = 1000,
                 default_ttl: int = 3600):
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.access_times: Dict[str, datetime] = {}

    def _generate_key(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Return a stable fingerprint of ``(tool_name, arguments)``."""
        import json

        # sort_keys makes the key independent of dict insertion order.
        args_str = json.dumps(arguments, sort_keys=True, ensure_ascii=False)
        combined = f"{tool_name}:{args_str}"
        # MD5 is only a compact fingerprint here, not a security measure.
        return hashlib.md5(combined.encode()).hexdigest()

    def get(self, tool_name: str, arguments: Dict[str, Any]) -> Optional[CallToolResult]:
        """Return the cached result, or None when missing or expired (expired entries are dropped)."""
        key = self._generate_key(tool_name, arguments)

        entry = self.cache.get(key)
        if entry is None:
            return None

        now = datetime.now()
        if now >= entry["expires_at"]:
            self._remove(key)
            return None

        # Re-insert so that dict order keeps tracking recency.
        self.cache[key] = self.cache.pop(key)
        self.access_times[key] = now

        return entry["result"]

    def set(self, tool_name: str, arguments: Dict[str, Any],
            result: CallToolResult, ttl: Optional[int] = None) -> None:
        """Store *result* for *ttl* seconds (``default_ttl`` when None; 0 expires immediately)."""
        key = self._generate_key(tool_name, arguments)
        now = datetime.now()

        if key in self.cache:
            # Deleting first makes the re-insert below the most recent entry.
            del self.cache[key]
        elif len(self.cache) >= self.max_size:
            self._evict_lru()

        # Explicit None check: "ttl or default" would silently turn ttl=0 into
        # the default TTL.
        if ttl is None:
            ttl = self.default_ttl

        self.cache[key] = {
            "result": result,
            "created_at": now,
            "expires_at": now + timedelta(seconds=ttl),
            "tool_name": tool_name,
            "arguments": arguments
        }

        self.access_times[key] = now

    def _remove(self, key: str) -> None:
        """Drop *key* from both bookkeeping dicts (no error if absent)."""
        self.cache.pop(key, None)
        self.access_times.pop(key, None)

    def _evict_lru(self) -> None:
        """Free space: purge expired entries, else drop the least recently used one."""
        now = datetime.now()
        expired = [key for key, entry in self.cache.items() if now >= entry["expires_at"]]
        for key in expired:
            self._remove(key)

        if expired or not self.cache:
            return

        # First key in insertion order is the least recently used.
        self._remove(next(iter(self.cache)))

    def clear(self) -> None:
        """Remove every entry."""
        self.cache.clear()
        self.access_times.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return entry counts and the fill ratio of the cache."""
        now = datetime.now()
        expired_count = sum(
            1 for entry in self.cache.values()
            if now >= entry["expires_at"]
        )

        return {
            "total_entries": len(self.cache),
            "expired_entries": expired_count,
            "active_entries": len(self.cache) - expired_count,
            "max_size": self.max_size,
            # Guard against ZeroDivisionError for a zero-capacity cache.
            "usage_ratio": len(self.cache) / self.max_size if self.max_size > 0 else 0.0
        }

# 带缓存的工具装饰器
def cached_tool(ttl: int = 3600,
                cache_errors: bool = False):
    """Class decorator that memoizes ``execute`` results in the global ``tool_cache``.

    Results are keyed by ``(self.name, arguments)`` and kept for *ttl*
    seconds. Error results (``isError``) are cached only when
    *cache_errors* is true.

    The patched ``execute`` takes only ``arguments``. The cache key carries no
    caller identity, so this decorator is not suitable for tools that need a
    security context: a cached result would be served to every caller.
    """
    from functools import wraps

    def decorator(tool_class: Type[ToolInterface]):
        original_execute = tool_class.execute

        # wraps() keeps the original name/docstring visible on the patched class.
        @wraps(original_execute)
        async def cached_execute(self, arguments: Dict[str, Any]) -> CallToolResult:
            cached_result = tool_cache.get(self.name, arguments)
            if cached_result is not None:
                return cached_result

            result = await original_execute(self, arguments)

            if not result.isError or cache_errors:
                tool_cache.set(self.name, arguments, result, ttl)

            return result

        # Assigned on tool_class itself, so a decorated subclass does not
        # alter its base class.
        tool_class.execute = cached_execute
        return tool_class

    return decorator

# Global cache instance shared by every @cached_tool class.
tool_cache = ToolResultCache(max_size=1000, default_ttl=3600)

# 使用缓存装饰器
# HTTP results are memoized for 1800 s (30 min); cache_errors keeps its default.
@cached_tool(ttl=1800)
class CachedHttpRequestTool(HttpRequestTool):
    """HttpRequestTool whose successful results are cached for 1800 seconds."""

# File reads are memoized for 300 s (5 min), error results included.
@cached_tool(ttl=300, cache_errors=True)
class CachedFileReadTool(FileReadTool):
    """FileReadTool whose results, errors included, are cached for 300 seconds.

    NOTE(review): with cache_errors=True a failed read stays cached for the
    whole TTL even if the file appears in the meantime, and successful reads
    do not see file changes within the TTL.
    """

工具安全性

1. 权限控制

from enum import Enum
from typing import Set, List

class Permission(Enum):
    """Capabilities that may be granted to a SecurityContext."""

    # File-system operations
    READ_FILE = "read_file"
    WRITE_FILE = "write_file"
    DELETE_FILE = "delete_file"
    # Outbound network access
    NETWORK_REQUEST = "network_request"
    # Command execution
    EXECUTE_COMMAND = "execute_command"
    # Database access
    DATABASE_READ = "database_read"
    DATABASE_WRITE = "database_write"

class Role(Enum):
    """Caller roles; SecurityManager maps each one to a default permission set."""

    # Declared from least to most privileged.
    GUEST = "guest"
    USER = "user"
    ADMIN = "admin"
    SYSTEM = "system"

class SecurityContext:
    """Per-caller security state: identity, role, permissions and resource allow-lists.

    An empty ``allowed_paths`` / ``allowed_domains`` list means that kind of
    resource is not restricted.
    """

    def __init__(self,
                 user_id: str,
                 role: Role,
                 permissions: Set[Permission],
                 allowed_paths: Optional[List[str]] = None,
                 allowed_domains: Optional[List[str]] = None):
        self.user_id = user_id
        self.role = role
        self.permissions = permissions
        self.allowed_paths = allowed_paths or []
        self.allowed_domains = allowed_domains or []

    def has_permission(self, permission: Permission) -> bool:
        """Return whether *permission* is in this context's permission set."""
        return permission in self.permissions

    def can_access_path(self, path: str) -> bool:
        """Return whether *path* lies inside one of ``allowed_paths``.

        Both sides are resolved with ``os.path.realpath``, which normalizes
        ``..`` and follows symlinks. Containment is then checked per path
        component, so ``/data/app`` does not also admit ``/data/app_secrets``.
        A plain string prefix test would admit it.
        """
        if not self.allowed_paths:
            return True  # no path restriction configured

        import os
        target = os.path.realpath(path)

        for allowed_path in self.allowed_paths:
            root = os.path.realpath(allowed_path)
            try:
                if os.path.commonpath([target, root]) == root:
                    return True
            except ValueError:
                # e.g. paths on different drives on Windows: not contained.
                continue

        return False

    def can_access_domain(self, url: str) -> bool:
        """Return whether the host of *url* equals or is a subdomain of an allowed domain.

        Uses ``urlparse(url).hostname``, which drops the port and any
        ``user@`` prefix and is lower-cased. Raw ``netloc`` keeps both, which
        wrongly rejects ``example.com:8443``. URLs without a host are rejected.
        """
        if not self.allowed_domains:
            return True  # no domain restriction configured

        from urllib.parse import urlparse
        try:
            host = urlparse(url).hostname
        except ValueError:
            return False  # malformed URL (e.g. broken IPv6 literal)

        if not host:
            return False
        host = host.rstrip(".")

        for allowed_domain in self.allowed_domains:
            allowed = allowed_domain.lower().rstrip(".")
            if host == allowed or host.endswith(f".{allowed}"):
                return True

        return False

class SecurityManager:
    """Map roles to default permission sets and tool names to required permissions."""

    def __init__(self):
        # Default permissions granted to each role.
        self.role_permissions = {
            Role.GUEST: {
                Permission.READ_FILE,
            },
            Role.USER: {
                Permission.READ_FILE,
                Permission.WRITE_FILE,
                Permission.NETWORK_REQUEST,
                Permission.DATABASE_READ,
            },
            Role.ADMIN: {
                Permission.READ_FILE,
                Permission.WRITE_FILE,
                Permission.DELETE_FILE,
                Permission.NETWORK_REQUEST,
                Permission.DATABASE_READ,
                Permission.DATABASE_WRITE,
            },
            Role.SYSTEM: set(Permission),  # every permission
        }
        # Permission each tool name requires. Tools missing from this map are
        # handled by the allow_unknown flag of check_tool_permission().
        self.tool_permissions = {
            "read_file": Permission.READ_FILE,
            "write_file": Permission.WRITE_FILE,
            "delete_file": Permission.DELETE_FILE,
            "http_request": Permission.NETWORK_REQUEST,
            "execute_command": Permission.EXECUTE_COMMAND,
            "database_query": Permission.DATABASE_READ,
            # The secure tool variants gate on the same permissions through
            # @require_permission; without these entries they would fall
            # through to the "unknown tool" branch.
            "secure_read_file": Permission.READ_FILE,
            "secure_http_request": Permission.NETWORK_REQUEST,
        }

    def create_context(self,
                       user_id: str,
                       role: Role,
                       custom_permissions: Optional[Set[Permission]] = None,
                       **kwargs) -> SecurityContext:
        """Build a SecurityContext for *user_id*.

        Args:
            custom_permissions: Explicit permission set. When given (even an
                empty set) it replaces the role defaults; only ``None`` falls
                back to ``role_permissions[role]``.
            **kwargs: Forwarded to SecurityContext (``allowed_paths``,
                ``allowed_domains``).
        """
        if custom_permissions is not None:
            permissions = set(custom_permissions)
        else:
            # Copy so that mutating a context never alters the role defaults.
            permissions = set(self.role_permissions.get(role, set()))
        return SecurityContext(user_id, role, permissions, **kwargs)

    def check_tool_permission(self,
                              context: SecurityContext,
                              tool_name: str,
                              allow_unknown: bool = True) -> bool:
        """Return whether *context* may run *tool_name*.

        Tools absent from ``tool_permissions`` are allowed when
        *allow_unknown* is true (the historical default). Pass False to deny
        them instead.
        """
        required_permission = self.tool_permissions.get(tool_name)
        if required_permission is None:
            return allow_unknown

        return context.has_permission(required_permission)

# 安全工具装饰器
def require_permission(permission: Permission):
    """Build a decorator that gates an async tool method on one permission.

    The decorated coroutine is awaited as ``func(self, arguments, context)``
    only when *context* is present and ``context.has_permission(permission)``
    is true. Otherwise an error result is returned and ``func`` is not called.
    """
    def decorator(func):
        @wraps(func)
        async def guarded(self, arguments: Dict[str, Any],
                          context: Optional[SecurityContext] = None) -> CallToolResult:
            # Work out whether (and why) to reject before touching the tool.
            if context is None:
                rejection = ("缺少安全上下文", "SECURITY_CONTEXT_MISSING")
            elif not context.has_permission(permission):
                rejection = (f"权限不足: 需要 {permission.value} 权限", "PERMISSION_DENIED")
            else:
                rejection = None

            if rejection is not None:
                message, code = rejection
                return ToolResponse.error(message, code)
            return await func(self, arguments, context)

        return guarded
    return decorator

# 安全的文件操作工具
class SecureFileReadTool(ToolInterface):
    """File-reading tool gated by READ_FILE permission and the context's path allow-list."""

    @property
    def name(self) -> str:
        return "secure_read_file"

    @property
    def description(self) -> str:
        return "安全的文件读取工具"

    def get_schema(self) -> Dict[str, Any]:
        return AdvancedFileInput.schema()

    @require_permission(Permission.READ_FILE)
    async def execute(self, arguments: Dict[str, Any],
                      context: SecurityContext) -> CallToolResult:
        """Read ``input_data.path`` with ``input_data.encoding`` and return its content.

        Returns a ``PATH_ACCESS_DENIED`` error when the path is outside the
        context's allow-list.

        Raises:
            ResourceNotFoundError: the file does not exist.
            PermissionError: the OS refused access. This is whichever
                ``PermissionError`` is in scope in this module, possibly a
                project tool-error class.
        """
        import builtins

        input_data = AdvancedFileInput(**arguments)

        if not context.can_access_path(input_data.path):
            return ToolResponse.error(
                f"无权访问路径: {input_data.path}",
                "PATH_ACCESS_DENIED"
            )

        try:
            async with aiofiles.open(input_data.path, 'r',
                                     encoding=input_data.encoding) as f:
                content = await f.read()
        except FileNotFoundError as exc:
            raise ResourceNotFoundError("文件", input_data.path) from exc
        except builtins.PermissionError as exc:
            # Catch the OS-level error explicitly: if the module-level name
            # PermissionError is a project error class shadowing the builtin,
            # "except PermissionError" would never match what the OS raises.
            raise PermissionError("读取", input_data.path) from exc

        return ToolResponse.with_file_content(
            input_data.path, content
        )

class SecureHttpRequestTool(ToolInterface):
    """HTTP request tool gated by NETWORK_REQUEST permission and the domain allow-list."""

    @property
    def name(self) -> str:
        return "secure_http_request"

    @property
    def description(self) -> str:
        return "安全的HTTP请求工具"

    def get_schema(self) -> Dict[str, Any]:
        return NetworkRequestInput.schema()

    @require_permission(Permission.NETWORK_REQUEST)
    async def execute(self, arguments: Dict[str, Any],
                      context: SecurityContext) -> CallToolResult:
        """Perform the request described by NetworkRequestInput and return status, headers and body.

        Returns a ``DOMAIN_ACCESS_DENIED`` error when the URL's host is not
        allowed.

        Raises:
            ToolError: ``REQUEST_TIMEOUT`` on timeout, ``REQUEST_FAILED`` on
                any other aiohttp client error.

        NOTE(review): aiohttp follows redirects by default, so an allowed
        host could redirect to one outside ``allowed_domains``. Consider
        ``allow_redirects=False`` if that matters for this deployment.
        """
        import asyncio

        import aiohttp

        input_data = NetworkRequestInput(**arguments)

        if not context.can_access_domain(input_data.url):
            return ToolResponse.error(
                f"无权访问域名: {input_data.url}",
                "DOMAIN_ACCESS_DENIED"
            )

        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=input_data.method,
                    url=input_data.url,
                    headers=input_data.headers,
                    timeout=aiohttp.ClientTimeout(total=input_data.timeout)
                ) as response:
                    content = await response.text()

                    result_data = {
                        "status_code": response.status,
                        "headers": dict(response.headers),
                        "content": content
                    }

                    return ToolResponse.success(
                        f"HTTP请求成功: {input_data.method} {input_data.url}",
                        result_data
                    )
        # aiohttp.ClientTimeout is a configuration object, not an exception:
        # naming it in an except clause raises TypeError. aiohttp signals
        # timeouts with asyncio.TimeoutError (ServerTimeoutError subclasses
        # it), so this clause must come before the ClientError one.
        except asyncio.TimeoutError as exc:
            raise ToolError(f"请求超时: {input_data.url}", "REQUEST_TIMEOUT") from exc
        except aiohttp.ClientError as exc:
            raise ToolError(f"请求失败: {str(exc)}", "REQUEST_FAILED") from exc

2. 输入验证和清理

import re
from typing import Any, Dict, List

class InputSanitizer:
    """Best-effort, regex-based clean-up and format checks for untrusted strings.

    These helpers are a defence-in-depth layer only. They do not replace
    parameterized SQL, proper HTML escaping, or path-containment checks such
    as SecurityContext.can_access_path.
    """

    @staticmethod
    def _until_stable(step: "Callable[[str], str]", text: str) -> str:
        """Apply *step* repeatedly until *text* stops changing.

        Single-pass removal is defeated by nesting: ``....//`` becomes
        ``../`` after one pass. Iterating to a fixed point closes that gap.
        Every step used here only deletes characters, so the loop terminates.
        """
        previous = None
        while text != previous:
            previous, text = text, step(text)
        return text

    @staticmethod
    def sanitize_path(path: str) -> str:
        """Strip dangerous characters and every ``..`` parent reference from *path*.

        Absolute paths are left absolute; callers must still check that the
        result lies inside an allowed directory.
        """
        # Characters that are invalid or special in file names.
        path = re.sub(r'[<>:"|?*]', '', path)

        # Normalize separators so the patterns below only need "/".
        path = path.replace('\\', '/')

        def _strip_parent_refs(value: str) -> str:
            value = re.sub(r'\.\./', '', value)
            value = re.sub(r'/\.\./|^\.\./|\.\.$', '', value)
            # Strip inside the loop so trailing whitespace cannot hide a
            # trailing ".." from the "\.\.$" pattern.
            return value.strip()

        return InputSanitizer._until_stable(_strip_parent_refs, path)

    @staticmethod
    def sanitize_sql(query: str) -> str:
        """Remove SQL comments and collapse whitespace.

        NOTE: this is cosmetic, not injection protection. It also strips
        ``--`` / ``/* */`` sequences inside string literals. Use bound
        parameters for untrusted values.
        """
        query = re.sub(r'--.*$', '', query, flags=re.MULTILINE)
        query = re.sub(r'/\*.*?\*/', '', query, flags=re.DOTALL)

        query = re.sub(r'\s+', ' ', query)

        return query.strip()

    @staticmethod
    def sanitize_html(text: str) -> str:
        """Remove ``<script>`` blocks, ``on*=`` event-handler attributes and ``javascript:``.

        Blacklist filtering cannot make arbitrary HTML safe. Prefer
        ``html.escape`` when markup does not need to be preserved.
        """
        def _strip_active_content(value: str) -> str:
            value = re.sub(r'<script[^>]*>.*?</script>', '', value,
                           flags=re.DOTALL | re.IGNORECASE)
            # Handler values may be double-quoted, single-quoted (possibly
            # containing the other quote) or unquoted; \b avoids eating
            # words such as "Money=" that merely contain "on".
            value = re.sub(r'\s*\bon\w+\s*=\s*(?:"[^"]*"|\'[^\']*\'|[^\s>]+)', '', value,
                           flags=re.IGNORECASE)
            value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE)
            return value

        # Iterate so that nested forms like "javajavascript:script:" are removed too.
        return InputSanitizer._until_stable(_strip_active_content, text)

    @staticmethod
    def validate_email(email: str) -> bool:
        """Return whether *email* looks like ``local@domain.tld``.

        fullmatch is used because ``$`` in re.match also matches before a
        trailing newline, which would accept ``"a@b.com\\n"``.
        """
        pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        return re.fullmatch(pattern, email) is not None

    @staticmethod
    def validate_url(url: str) -> bool:
        """Return whether *url* is an http(s) URL with a dotted host, optional port and path."""
        pattern = r'https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(?::\d{1,5})?(?:/.*)?'
        return re.fullmatch(pattern, url) is not None

class RateLimiter:
    """Sliding-window rate limiter keyed by ``"{user_id}:{operation_type}"``.

    Each key stores the timestamps of its accepted requests. A new request is
    allowed while fewer than ``max_requests`` of them fall inside the trailing
    ``window_seconds`` window.
    """

    def __init__(self):
        self.requests: Dict[str, List[datetime]] = {}
        # operation_type -> (max_requests, window_seconds); unknown types use "default".
        self.limits = {
            "default": (100, 3600),  # 100 per hour
            "file_operations": (50, 3600),  # 50 per hour
            "network_requests": (200, 3600),  # 200 per hour
        }

    def _prune(self, key: str, window_seconds: int, now: datetime) -> List[datetime]:
        """Keep only the timestamps of *key* newer than the window, store and return them."""
        cutoff_time = now - timedelta(seconds=window_seconds)
        recent = [
            req_time for req_time in self.requests.get(key, [])
            if req_time > cutoff_time
        ]
        self.requests[key] = recent
        return recent

    def is_allowed(self, user_id: str, operation_type: str = "default") -> bool:
        """Record and allow the request if the window still has room; otherwise return False."""
        key = f"{user_id}:{operation_type}"
        now = datetime.now()

        max_requests, window_seconds = self.limits.get(operation_type, self.limits["default"])

        recent = self._prune(key, window_seconds, now)

        if len(recent) >= max_requests:
            return False

        recent.append(now)
        return True

    def get_remaining(self, user_id: str, operation_type: str = "default") -> int:
        """Return how many more requests the current window allows.

        Expired timestamps are pruned first. Without that, requests from
        earlier windows keep counting and the result is too low.
        """
        key = f"{user_id}:{operation_type}"
        max_requests, window_seconds = self.limits.get(operation_type, self.limits["default"])

        if key not in self.requests:
            return max_requests

        recent = self._prune(key, window_seconds, datetime.now())
        return max(0, max_requests - len(recent))

# Module-wide rate limiter instance.
rate_limiter = RateLimiter()

# 速率限制装饰器
def rate_limit(operation_type: str = "default"):
    """Build a decorator that applies the module-level ``rate_limiter`` to a tool method.

    The decorated coroutine must accept ``(self, arguments, context)``. A call
    without a context, or one over the limit for
    ``(context.user_id, operation_type)``, gets an error result instead of
    reaching the tool.
    """
    def decorator(func):
        @wraps(func)
        async def limited(self, arguments: Dict[str, Any],
                          context: Optional[SecurityContext] = None) -> CallToolResult:
            if context is None:
                return ToolResponse.error("缺少安全上下文", "SECURITY_CONTEXT_MISSING")

            user = context.user_id
            if rate_limiter.is_allowed(user, operation_type):
                return await func(self, arguments, context)

            remaining = rate_limiter.get_remaining(user, operation_type)
            return ToolResponse.error(
                f"请求频率过高,剩余次数: {remaining}",
                "RATE_LIMIT_EXCEEDED"
            )

        return limited
    return decorator

工具监控和日志

1. 工具执行监控

import time
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from datetime import datetime
import json

@dataclass
class ToolExecutionMetrics:
    """Timing, outcome and payload-size record for one tool execution."""
    tool_name: str
    user_id: str
    start_time: datetime
    # The next four fields are filled in by ToolMonitor.end_execution().
    end_time: Optional[datetime] = None
    duration_ms: Optional[float] = None
    success: Optional[bool] = None
    error_message: Optional[str] = None
    # Length (characters) of the JSON-serialized arguments.
    input_size: int = 0
    # Total characters of text content in the result.
    output_size: int = 0
    # Not populated by ToolMonitor in this module.
    memory_usage: Optional[float] = None
    cpu_usage: Optional[float] = None

class ToolMonitor:
    """Record per-execution metrics and aggregate them into statistics.

    In-flight executions live in ``active_executions`` keyed by execution id.
    Finished ones are appended to ``metrics`` (oldest first), which is capped
    at ``max_history`` entries.
    """

    def __init__(self, max_history: int = 1000):
        self.metrics: List[ToolExecutionMetrics] = []
        self.active_executions: Dict[str, ToolExecutionMetrics] = {}
        self.max_history = max_history
        # Incrementing counter appended to execution ids. Without it, two
        # executions of the same tool by the same user in the same millisecond
        # would share an id and overwrite each other in active_executions.
        self._sequence = 0

    def start_execution(self, tool_name: str, user_id: str,
                        arguments: Dict[str, Any]) -> str:
        """Begin tracking an execution and return its unique execution id."""
        self._sequence += 1
        execution_id = f"{tool_name}_{user_id}_{int(time.time() * 1000)}_{self._sequence}"

        metrics = ToolExecutionMetrics(
            tool_name=tool_name,
            user_id=user_id,
            start_time=datetime.now(),
            # default=str keeps monitoring from breaking the tool call when an
            # argument is not JSON-serializable.
            input_size=len(json.dumps(arguments, ensure_ascii=False, default=str))
        )

        self.active_executions[execution_id] = metrics
        return execution_id

    def end_execution(self, execution_id: str, success: bool,
                      result: Optional[CallToolResult] = None,
                      error_message: Optional[str] = None) -> None:
        """Finish tracking *execution_id* (unknown ids are ignored) and archive its metrics."""
        if execution_id not in self.active_executions:
            return

        metrics = self.active_executions.pop(execution_id)
        metrics.end_time = datetime.now()
        metrics.duration_ms = (
            metrics.end_time - metrics.start_time
        ).total_seconds() * 1000
        metrics.success = success
        metrics.error_message = error_message

        if result and result.content:
            metrics.output_size = sum(
                len(content.text) if hasattr(content, 'text') else 0
                for content in result.content
            )

        self.metrics.append(metrics)

        # Trim in place to the newest max_history entries (also correct for 0).
        excess = len(self.metrics) - self.max_history
        if excess > 0:
            del self.metrics[:excess]

    def get_statistics(self, tool_name: Optional[str] = None,
                       user_id: Optional[str] = None,
                       hours: int = 24) -> Dict[str, Any]:
        """Aggregate finished executions started within the last *hours*, optionally filtered."""
        cutoff_time = datetime.now() - timedelta(hours=hours)

        filtered_metrics = [
            m for m in self.metrics
            if m.start_time > cutoff_time
            and (tool_name is None or m.tool_name == tool_name)
            and (user_id is None or m.user_id == user_id)
        ]

        if not filtered_metrics:
            return {"total_executions": 0}

        total_executions = len(filtered_metrics)
        successful_executions = sum(1 for m in filtered_metrics if m.success)
        failed_executions = total_executions - successful_executions

        durations = [m.duration_ms for m in filtered_metrics if m.duration_ms is not None]
        avg_duration = sum(durations) / len(durations) if durations else 0
        max_duration = max(durations) if durations else 0
        min_duration = min(durations) if durations else 0

        # Per-tool success/failure counts.
        tool_stats = {}
        for metrics in filtered_metrics:
            tool = metrics.tool_name
            if tool not in tool_stats:
                tool_stats[tool] = {"total": 0, "success": 0, "failed": 0}

            tool_stats[tool]["total"] += 1
            if metrics.success:
                tool_stats[tool]["success"] += 1
            else:
                tool_stats[tool]["failed"] += 1

        return {
            "total_executions": total_executions,
            "successful_executions": successful_executions,
            "failed_executions": failed_executions,
            "success_rate": successful_executions / total_executions if total_executions > 0 else 0,
            "average_duration_ms": avg_duration,
            "max_duration_ms": max_duration,
            "min_duration_ms": min_duration,
            "tool_statistics": tool_stats
        }

    def get_recent_errors(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to *limit* most recent failures that carry an error message, newest first."""
        error_metrics = [
            m for m in reversed(self.metrics)
            if not m.success and m.error_message
        ][:limit]

        return [
            {
                "tool_name": m.tool_name,
                "user_id": m.user_id,
                "timestamp": m.start_time.isoformat(),
                "error_message": m.error_message,
                "duration_ms": m.duration_ms
            }
            for m in error_metrics
        ]

# Module-wide monitor instance.
tool_monitor = ToolMonitor()

# 监控装饰器
def monitor_execution(func):
    """Decorator that records every call of a tool's ``execute`` in ``tool_monitor``.

    The wrapped coroutine must accept ``(self, arguments, context)``. A result
    with ``isError`` set counts as a failure. A raised exception is recorded
    with its message and then re-raised.
    """
    @wraps(func)
    async def wrapper(self, arguments: Dict[str, Any],
                      context: Optional[SecurityContext] = None) -> CallToolResult:
        user_id = context.user_id if context else "anonymous"
        execution_id = tool_monitor.start_execution(self.name, user_id, arguments)

        try:
            result = await func(self, arguments, context)
            tool_monitor.end_execution(
                execution_id,
                not result.isError,
                result
            )
            return result
        except BaseException as e:
            # BaseException also covers asyncio.CancelledError, raised e.g.
            # when asyncio.wait_for times out. Catching only Exception would
            # leave the entry in tool_monitor.active_executions forever.
            # CancelledError usually has an empty message, so fall back to the
            # type name; get_recent_errors() skips empty messages.
            tool_monitor.end_execution(
                execution_id,
                False,
                error_message=str(e) or type(e).__name__
            )
            raise

    return wrapper

2. 结构化日志

import logging
import json
from typing import Any, Dict, Optional
from datetime import datetime

class StructuredLogger:
    """Emit log records whose message is a single JSON object.

    Wraps ``logging.getLogger(name)``. Extra keyword arguments passed to the
    logging methods become fields of the JSON payload.
    """

    # Attribute set on the handler this class installs. Instantiating the
    # class twice for the same logger name must not attach a second handler,
    # which would print every record twice.
    _HANDLER_MARKER = "_structured_logger_handler"

    def __init__(self, name: str = "mcp_tools"):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(logging.INFO)

        already_installed = any(
            getattr(existing, self._HANDLER_MARKER, False)
            for existing in self.logger.handlers
        )
        if not already_installed:
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            setattr(handler, self._HANDLER_MARKER, True)
            self.logger.addHandler(handler)

    def _log_structured(self, level: str, message: str, **kwargs):
        """Serialize ``{timestamp, level, message, **kwargs}`` to JSON and log it at *level*.

        *level* is a standard level name such as ``"INFO"`` or ``"WARNING"``.
        Values that are not JSON-serializable are rendered with ``str``.
        """
        log_level = getattr(logging, level.upper())
        # Skip building the JSON payload when the record would be dropped anyway.
        if not self.logger.isEnabledFor(log_level):
            return

        log_data = {
            "timestamp": datetime.now().isoformat(),
            "level": level,
            "message": message,
            **kwargs
        }

        log_message = json.dumps(log_data, ensure_ascii=False, default=str)
        self.logger.log(log_level, log_message)

    def info(self, message: str, **kwargs):
        """Log at INFO level."""
        self._log_structured("INFO", message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log at WARNING level."""
        self._log_structured("WARNING", message, **kwargs)

    def error(self, message: str, **kwargs):
        """Log at ERROR level."""
        self._log_structured("ERROR", message, **kwargs)

    def debug(self, message: str, **kwargs):
        """Log at DEBUG level (dropped unless the logger level is lowered below INFO)."""
        self._log_structured("DEBUG", message, **kwargs)

    def tool_execution(self, tool_name: str, user_id: str,
                       arguments: Dict[str, Any],
                       result: Optional[CallToolResult] = None,
                       duration_ms: Optional[float] = None,
                       error: Optional[str] = None):
        """Log one tool execution: INFO when *error* is None, ERROR otherwise.

        ``result_size`` is the number of content items in *result*, not a
        byte count. NOTE(review): *arguments* are logged verbatim; redact
        them first if they may contain secrets.
        """
        self._log_structured(
            "INFO" if error is None else "ERROR",
            f"Tool execution: {tool_name}",
            tool_name=tool_name,
            user_id=user_id,
            arguments=arguments,
            success=error is None,
            duration_ms=duration_ms,
            error=error,
            result_size=len(result.content) if result and result.content else 0
        )

    def security_event(self, event_type: str, user_id: str,
                       details: Dict[str, Any]):
        """Log a security-relevant event at WARNING level."""
        self._log_structured(
            "WARNING",
            f"Security event: {event_type}",
            event_type=event_type,
            user_id=user_id,
            details=details
        )

    def performance_warning(self, tool_name: str, duration_ms: float,
                            threshold_ms: float = 5000):
        """Log a WARNING when *duration_ms* exceeds *threshold_ms*; otherwise do nothing."""
        if duration_ms > threshold_ms:
            self._log_structured(
                "WARNING",
                f"Slow tool execution: {tool_name}",
                tool_name=tool_name,
                duration_ms=duration_ms,
                threshold_ms=threshold_ms
            )

# Module-wide structured logger instance.
structured_logger = StructuredLogger()

# 日志记录装饰器
def log_execution(func):
    """Decorator that writes a structured log entry for every ``execute`` call.

    The wrapped coroutine must accept ``(self, arguments, context)``. Both
    raised exceptions and results with ``isError`` set are logged at ERROR
    level. Slow calls additionally trigger ``performance_warning``, whatever
    their outcome. Exceptions are re-raised after logging.
    """
    @wraps(func)
    async def wrapper(self, arguments: Dict[str, Any],
                      context: Optional[SecurityContext] = None) -> CallToolResult:
        user_id = context.user_id if context else "anonymous"
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()

        try:
            result = await func(self, arguments, context)
        except Exception as e:
            duration_ms = (time.perf_counter() - start_time) * 1000
            structured_logger.tool_execution(
                self.name, user_id, arguments, None, duration_ms, str(e)
            )
            structured_logger.performance_warning(self.name, duration_ms)
            raise

        duration_ms = (time.perf_counter() - start_time) * 1000

        error = None
        if getattr(result, "isError", False):
            # Tool-level failures come back as a result rather than an
            # exception. Report them as errors instead of successes.
            first = result.content[0] if result.content else None
            error = getattr(first, "text", None) or "tool returned an error result"

        structured_logger.tool_execution(
            self.name, user_id, arguments, result, duration_ms, error
        )
        structured_logger.performance_warning(self.name, duration_ms)

        return result

    return wrapper

本章总结

本章深入介绍了MCP协议中工具开发与管理的高级技术,包括:

核心内容

  1. 工具设计原则

    • 单一职责原则
    • 接口一致性原则
    • 分层错误处理
  2. 高级开发模式

    • 工具工厂模式
    • 工具组合模式
    • 工具链模式
  3. 参数验证和类型安全

    • 高级参数验证
    • 类型安全的工具接口
    • 输入清理和验证
  4. 性能优化

    • 异步并发优化
    • 缓存和记忆化
    • 批量处理
  5. 安全性

    • 权限控制系统
    • 输入验证和清理
    • 速率限制
  6. 监控和日志

    • 工具执行监控
    • 结构化日志记录
    • 性能指标收集

最佳实践

  1. 设计阶段

    • 明确工具职责边界
    • 设计一致的接口
    • 考虑安全性和性能
  2. 开发阶段

    • 使用类型安全的接口
    • 实现完善的错误处理
    • 添加适当的缓存机制
  3. 部署阶段

    • 配置权限控制
    • 启用监控和日志
    • 设置合理的限制
  4. 维护阶段

    • 定期检查性能指标
    • 分析错误日志
    • 优化热点工具

下一章我们将学习资源管理与订阅机制,了解如何在MCP协议中管理动态资源和实现实时更新。