提示模板基础

1. 提示模板概念

在 MCP 协议中，提示（Prompt）是预定义的文本模板，可以包含变量占位符，用于生成动态内容。提示模板具有以下特征：

  • 参数化：支持变量替换和动态内容生成
  • 可重用：同一模板可用于不同场景
  • 结构化：支持复杂的模板结构和逻辑
  • 类型安全：支持参数类型校验与约束
from typing import Dict, List, Any, Optional, Union
from pydantic import BaseModel, Field, validator
from enum import Enum
from datetime import datetime
import re
import json

class PromptParameterType(Enum):
    """Types a prompt parameter can be declared with.

    Member values are the lowercase strings stored in serialized templates
    and converted back with ``PromptParameterType(value)``.
    """
    # Scalars
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    # Containers (list / dict values)
    ARRAY = "array"
    OBJECT = "object"
    # Format-specific types
    DATE = "date"
    DATETIME = "datetime"
    EMAIL = "email"
    URL = "url"
    # Value restricted to PromptParameter.enum_values
    ENUM = "enum"

class PromptParameter(BaseModel):
    """Definition of one parameter accepted by a prompt template.

    Holds the parameter's name, declared type, whether it is required, an
    optional default value, and optional constraints (numeric range, length,
    regex pattern, allowed enum values) that validators can enforce.
    """
    name: str = Field(description="参数名称")
    type: PromptParameterType = Field(description="参数类型")
    description: Optional[str] = Field(default=None, description="参数描述")
    required: bool = Field(default=True, description="是否必需")
    default_value: Optional[Any] = Field(default=None, description="默认值")

    # Constraints
    min_value: Optional[Union[int, float]] = Field(default=None, description="最小值")
    max_value: Optional[Union[int, float]] = Field(default=None, description="最大值")
    min_length: Optional[int] = Field(default=None, description="最小长度")
    max_length: Optional[int] = Field(default=None, description="最大长度")
    pattern: Optional[str] = Field(default=None, description="正则表达式模式")
    enum_values: Optional[List[str]] = Field(default=None, description="枚举值")

    @validator('default_value')
    def validate_default_value(cls, v, values):
        """Check that a non-None default value matches the declared ``type``.

        Only STRING, INTEGER, FLOAT and BOOLEAN defaults are checked; other
        types are accepted as-is. ``bool`` is rejected for INTEGER and FLOAT
        even though it subclasses ``int`` in Python, so ``True`` cannot
        silently pass as a numeric default.

        Raises:
            ValueError: if the default does not match the declared type.
        """
        if v is None:
            return v

        # 'type' is declared before 'default_value', so it has already been
        # validated and is present in ``values`` (absent if it failed).
        param_type = values.get('type')
        is_bool = isinstance(v, bool)

        if param_type == PromptParameterType.STRING and not isinstance(v, str):
            raise ValueError("字符串类型的默认值必须是字符串")
        elif param_type == PromptParameterType.INTEGER and (is_bool or not isinstance(v, int)):
            raise ValueError("整数类型的默认值必须是整数")
        elif param_type == PromptParameterType.FLOAT and (is_bool or not isinstance(v, (int, float))):
            raise ValueError("浮点数类型的默认值必须是数字")
        elif param_type == PromptParameterType.BOOLEAN and not is_bool:
            raise ValueError("布尔类型的默认值必须是布尔值")

        return v

class PromptTemplate(BaseModel):
    """A named, versioned prompt template together with its parameter definitions.

    ``template`` is the raw template text; ``parameters`` declares the
    variables the template may reference. The option flags at the bottom
    configure HTML escaping, whitespace trimming and whether undefined
    variables are tolerated.
    """
    name: str = Field(description="模板名称")
    description: Optional[str] = Field(default=None, description="模板描述")
    template: str = Field(description="模板内容")
    parameters: List[PromptParameter] = Field(default_factory=list, description="参数定义")
    tags: List[str] = Field(default_factory=list, description="标签")
    version: str = Field(default="1.0.0", description="版本号")
    # NOTE(review): both timestamps default to naive local time (datetime.now).
    created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
    updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
    
    # Template options
    escape_html: bool = Field(default=False, description="是否转义HTML")
    trim_whitespace: bool = Field(default=True, description="是否去除空白字符")
    allow_undefined: bool = Field(default=False, description="是否允许未定义变量")
    
    class Config:
        # Encode datetime fields as ISO-8601 strings when serializing to JSON.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }

class PromptResult(BaseModel):
    """Output of rendering a prompt template.

    Carries the rendered text, the parameter values it was rendered with,
    the source template's name, a generation timestamp and free-form
    metadata.
    """
    content: str = Field(description="生成的内容")
    parameters_used: Dict[str, Any] = Field(description="使用的参数")
    template_name: str = Field(description="模板名称")
    # NOTE(review): naive local time (datetime.now).
    generated_at: datetime = Field(default_factory=datetime.now, description="生成时间")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")

2. 模板语法设计

import re
from typing import Pattern
from string import Template
from jinja2 import Environment, BaseLoader, TemplateError, select_autoescape

class TemplateEngine:
    """Common interface implemented by every template engine.

    Concrete engines override all three methods. On this base class each
    method simply raises ``NotImplementedError``.
    """

    def render(self, template: str, parameters: Dict[str, Any]) -> str:
        """Return ``template`` with ``parameters`` substituted in."""
        raise NotImplementedError

    def validate_template(self, template: str) -> List[str]:
        """Return syntax error messages for ``template``; an empty list means valid."""
        raise NotImplementedError

    def extract_variables(self, template: str) -> List[str]:
        """Return the names of the variables that ``template`` references."""
        raise NotImplementedError

class SimpleTemplateEngine(TemplateEngine):
    """Template engine for ``${variable}`` placeholders.

    ``render`` delegates to :class:`string.Template`, so the unbraced
    ``$variable`` form and the ``$$`` escape also work there.
    ``extract_variables`` and ``validate_template`` only inspect the braced
    ``${...}`` form.
    """

    # A valid placeholder name: a letter or underscore, then letters/digits/underscores.
    _IDENTIFIER = re.compile(r'[a-zA-Z_][a-zA-Z0-9_]*')

    def __init__(self):
        # Matches "${...}" and captures the text between the braces.
        self.variable_pattern: Pattern = re.compile(r'\$\{([^}]+)\}')

    def render(self, template: str, parameters: Dict[str, Any]) -> str:
        """Substitute ``parameters`` into ``template``.

        Uses ``safe_substitute``: placeholders with no matching parameter are
        left in the output verbatim instead of raising.

        Raises:
            ValueError: if substitution fails.
        """
        try:
            return Template(template).safe_substitute(parameters)
        except Exception as e:
            raise ValueError(f"模板渲染失败: {e}") from e

    def validate_template(self, template: str) -> List[str]:
        """Return a list of syntax problems in ``template`` (empty when valid).

        Reports every ``${`` that is not closed by a ``}``, and every
        placeholder whose name is not a valid identifier. Braces outside
        ``${...}`` placeholders (e.g. a JSON example inside the prompt) are
        ignored.
        """
        errors = []

        try:
            # Remove all well-formed placeholders. Any "${" left over was never
            # closed (or is an empty "${}"). Counting every "}" in the text
            # instead would wrongly reject literal braces.
            unclosed = self.variable_pattern.sub('', template).count('${')
            if unclosed:
                errors.append(f"括号不匹配: 找到 {unclosed} 个未闭合的 '${{'")

            for var in self.extract_variables(template):
                if not self._IDENTIFIER.fullmatch(var):
                    errors.append(f"无效的变量名: {var}")
        except Exception as e:
            errors.append(f"模板验证失败: {e}")

        return errors

    def extract_variables(self, template: str) -> List[str]:
        """Return the distinct ``${...}`` placeholder names in first-seen order."""
        # dict.fromkeys de-duplicates while keeping a stable order; set()
        # iteration order for strings varies between interpreter runs.
        return list(dict.fromkeys(self.variable_pattern.findall(template)))

class Jinja2TemplateEngine(TemplateEngine):
    """Template engine backed by Jinja2 (``{{ var }}``, ``{% ... %}`` syntax).

    Args:
        auto_escape: when True (the default), the environment is created with
            ``select_autoescape()``. Per Jinja2's documented defaults, that
            turns HTML autoescaping ON for templates created with
            ``from_string``, so ``<``, ``&`` and quotes in values are escaped.
            Pass False for plain-text prompts.
    """

    def __init__(self, auto_escape: bool = True):
        self.env = Environment(
            loader=BaseLoader(),
            autoescape=select_autoescape() if auto_escape else False,
            trim_blocks=True,
            lstrip_blocks=True
        )

        # Extra filter. Jinja2 already ships upper/lower/title/length filters
        # that accept non-str values; replacing them with bare str methods
        # would raise TypeError on e.g. integers, so they are left untouched.
        self.env.filters['json'] = self._to_json

    @staticmethod
    def _to_json(value: Any, **kwargs: Any) -> str:
        """Serialize ``value`` as JSON, keeping non-ASCII text readable by default."""
        kwargs.setdefault('ensure_ascii', False)
        return json.dumps(value, **kwargs)

    def render(self, template: str, parameters: Dict[str, Any]) -> str:
        """Render ``template`` with ``parameters`` as the template context.

        Raises:
            ValueError: on Jinja2 template errors or any other rendering failure.
        """
        try:
            return self.env.from_string(template).render(**parameters)
        except TemplateError as e:
            raise ValueError(f"Jinja2模板渲染失败: {e}") from e
        except Exception as e:
            raise ValueError(f"模板渲染失败: {e}") from e

    def validate_template(self, template: str) -> List[str]:
        """Return a list of syntax errors found by compiling ``template`` (empty if valid)."""
        errors = []

        try:
            self.env.from_string(template)
        except TemplateError as e:
            errors.append(f"Jinja2语法错误: {e}")
        except Exception as e:
            errors.append(f"模板验证失败: {e}")

        return errors

    def extract_variables(self, template: str) -> List[str]:
        """Return the top-level names the template reads from its context.

        Uses ``jinja2.meta.find_undeclared_variables`` on the parsed AST, so
        names bound inside the template (``{% set %}``, loop targets) are
        excluded, and ``{{ user.name }}`` reports ``user``. If the template
        cannot be parsed, falls back to a regex scan of simple ``{{ ... }}``
        expressions, reduced to their root name for consistency.
        """
        from jinja2 import meta

        try:
            ast = self.env.parse(template)
        except TemplateError:
            variable_pattern = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_.]*)\s*\}\}')
            roots = (match.split('.', 1)[0] for match in variable_pattern.findall(template))
            return list(dict.fromkeys(roots))

        return sorted(meta.find_undeclared_variables(ast))

class AdvancedTemplateEngine(TemplateEngine):
    """Engine that routes each template to the Jinja2 or the ``${var}`` engine.

    A template containing any Jinja2 delimiter pair (``{{ }}``, ``{% %}``,
    ``{# #}``) is handled by :class:`Jinja2TemplateEngine`. Everything else,
    including templates with no placeholders, is handled by
    :class:`SimpleTemplateEngine`.
    """

    # re.DOTALL lets tags and comments that span several lines (common for
    # {# ... #} comment blocks) be detected too.
    _JINJA2_PATTERNS = tuple(
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r'\{\{.*?\}\}',  # variable
            r'\{%.*?%\}',    # control structure
            r'\{#.*?#\}',    # comment
        )
    )

    def __init__(self):
        self.simple_engine = SimpleTemplateEngine()
        self.jinja2_engine = Jinja2TemplateEngine()

    def detect_template_type(self, template: str) -> str:
        """Return ``'jinja2'`` if ``template`` uses Jinja2 delimiters, else ``'simple'``."""
        if any(pattern.search(template) for pattern in self._JINJA2_PATTERNS):
            return 'jinja2'
        return 'simple'

    def _engine_for(self, template: str) -> TemplateEngine:
        """Return the sub-engine responsible for ``template``."""
        if self.detect_template_type(template) == 'jinja2':
            return self.jinja2_engine
        return self.simple_engine

    def render(self, template: str, parameters: Dict[str, Any]) -> str:
        """Render ``template`` with the engine matching its syntax."""
        return self._engine_for(template).render(template, parameters)

    def validate_template(self, template: str) -> List[str]:
        """Validate ``template`` with the engine matching its syntax."""
        return self._engine_for(template).validate_template(template)

    def extract_variables(self, template: str) -> List[str]:
        """Extract variables from ``template`` with the engine matching its syntax."""
        return self._engine_for(template).extract_variables(template)

提示管理器实现

1. 核心提示管理器

import asyncio
from typing import Dict, List, Optional, Any, Callable
from pathlib import Path
import yaml
import json

class PromptManager:
    """Registry of prompt templates with validation, rendering and persistence.

    Templates are stored by name. ``render_prompt`` validates call parameters
    against the template's ``PromptParameter`` definitions, runs optional
    pre/post render hooks, and renders through ``template_engine`` (an
    ``AdvancedTemplateEngine`` unless another engine is supplied). Templates
    can be exported to and imported from YAML or JSON files.
    """

    # Formats accepted by export_templates / import_templates.
    _SUPPORTED_FORMATS = ('yaml', 'json')

    def __init__(self, template_engine: Optional[TemplateEngine] = None):
        self.templates: Dict[str, PromptTemplate] = {}
        self.template_engine = template_engine or AdvancedTemplateEngine()
        # Parameter name -> callable(value) returning a list of error messages
        # (or a falsy value). Applies to every template declaring that name.
        self.parameter_validators: Dict[str, Callable] = {}
        # Called as hook(template, parameters); may be sync or async.
        self.pre_render_hooks: List[Callable] = []
        # Called as hook(template, parameters, result); may be sync or async.
        self.post_render_hooks: List[Callable] = []

    def register_template(self, template: PromptTemplate) -> None:
        """Validate ``template`` and store it under ``template.name``.

        A template already registered under the same name is replaced.
        Parameters that are declared but not used by the template only
        produce a printed warning.

        Raises:
            ValueError: if the template has syntax errors, or references
                variables without a parameter definition while
                ``template.allow_undefined`` is False.
        """
        errors = self.template_engine.validate_template(template.template)
        if errors:
            raise ValueError(f"模板语法错误: {', '.join(errors)}")

        # Cross-check the variables used by the template against the declared parameters.
        template_variables = set(self.template_engine.extract_variables(template.template))
        defined_parameters = {param.name for param in template.parameters}

        undefined_variables = template_variables - defined_parameters
        if undefined_variables and not template.allow_undefined:
            raise ValueError(f"模板中包含未定义的变量: {', '.join(sorted(undefined_variables))}")

        unused_parameters = defined_parameters - template_variables
        if unused_parameters:
            print(f"警告: 参数定义中包含未使用的参数: {', '.join(sorted(unused_parameters))}")

        self.templates[template.name] = template

    def unregister_template(self, template_name: str) -> bool:
        """Remove a template by name; return True if it was registered."""
        return self.templates.pop(template_name, None) is not None

    def get_template(self, template_name: str) -> Optional[PromptTemplate]:
        """Return the template registered under ``template_name``, or None."""
        return self.templates.get(template_name)

    def list_templates(self, tag_filter: Optional[str] = None) -> List[PromptTemplate]:
        """Return registered templates, optionally only those tagged ``tag_filter``."""
        templates = list(self.templates.values())
        if tag_filter:
            templates = [t for t in templates if tag_filter in t.tags]
        return templates

    def validate_parameters(self, template_name: str, parameters: Dict[str, Any]) -> List[str]:
        """Check ``parameters`` against the named template's parameter definitions.

        NOTE: ``parameters`` is modified in place. For every declared
        parameter that is missing or None, the declared default (if any) is
        written into the dict.

        Returns:
            A list of error messages; empty when every parameter is valid.
        """
        template = self.get_template(template_name)
        if not template:
            return [f"模板不存在: {template_name}"]

        errors = []

        for param_def in template.parameters:
            param_name = param_def.name
            param_value = parameters.get(param_name)

            # Missing or explicit None falls back to the declared default.
            if param_value is None and param_def.default_value is not None:
                param_value = param_def.default_value
                parameters[param_name] = param_value

            if param_value is None:
                if param_def.required:
                    errors.append(f"缺少必需参数: {param_name}")
                # Optional parameters without a value are not checked further.
                continue

            errors.extend(self._validate_parameter_type(param_def, param_value))
            errors.extend(self._validate_parameter_constraints(param_def, param_value))

            validator_fn = self.parameter_validators.get(param_name)
            if validator_fn is not None:
                try:
                    custom_errors = validator_fn(param_value)
                    if custom_errors:
                        errors.extend(custom_errors)
                except Exception as e:
                    errors.append(f"参数 {param_name} 自定义验证失败: {e}")

        return errors

    @staticmethod
    def _is_valid_temporal(value: Any, with_time: bool) -> bool:
        """Return True if ``value`` is a date (datetime when ``with_time``) or an ISO-8601 string."""
        from datetime import date

        if isinstance(value, datetime if with_time else date):
            return True
        if not isinstance(value, str):
            return False
        parse = datetime.fromisoformat if with_time else date.fromisoformat
        try:
            parse(value)
        except ValueError:
            return False
        return True

    def _validate_parameter_type(self, param_def: PromptParameter, value: Any) -> List[str]:
        """Return type errors for ``value`` according to ``param_def.type``."""
        errors = []
        param_type = param_def.type
        param_name = param_def.name
        # bool subclasses int in Python; it must not satisfy INTEGER/FLOAT.
        is_bool = isinstance(value, bool)

        if param_type == PromptParameterType.STRING:
            if not isinstance(value, str):
                errors.append(f"参数 {param_name} 必须是字符串类型")
        elif param_type == PromptParameterType.INTEGER:
            if is_bool or not isinstance(value, int):
                errors.append(f"参数 {param_name} 必须是整数类型")
        elif param_type == PromptParameterType.FLOAT:
            if is_bool or not isinstance(value, (int, float)):
                errors.append(f"参数 {param_name} 必须是数字类型")
        elif param_type == PromptParameterType.BOOLEAN:
            if not is_bool:
                errors.append(f"参数 {param_name} 必须是布尔类型")
        elif param_type == PromptParameterType.ARRAY:
            if not isinstance(value, list):
                errors.append(f"参数 {param_name} 必须是数组类型")
        elif param_type == PromptParameterType.OBJECT:
            if not isinstance(value, dict):
                errors.append(f"参数 {param_name} 必须是对象类型")
        elif param_type == PromptParameterType.DATE:
            if not self._is_valid_temporal(value, with_time=False):
                errors.append(f"参数 {param_name} 必须是有效的日期")
        elif param_type == PromptParameterType.DATETIME:
            if not self._is_valid_temporal(value, with_time=True):
                errors.append(f"参数 {param_name} 必须是有效的日期时间")
        elif param_type == PromptParameterType.EMAIL:
            if not isinstance(value, str) or not re.match(r'^[^@]+@[^@]+\.[^@]+$', value):
                errors.append(f"参数 {param_name} 必须是有效的邮箱地址")
        elif param_type == PromptParameterType.URL:
            if not isinstance(value, str) or not re.match(r'^https?://', value):
                errors.append(f"参数 {param_name} 必须是有效的URL")
        elif param_type == PromptParameterType.ENUM:
            if param_def.enum_values and value not in param_def.enum_values:
                errors.append(f"参数 {param_name} 必须是以下值之一: {', '.join(param_def.enum_values)}")

        return errors

    def _validate_parameter_constraints(self, param_def: PromptParameter, value: Any) -> List[str]:
        """Return constraint errors (range, length, pattern) for ``value``."""
        errors = []
        param_name = param_def.name

        # Numeric range (booleans are not treated as numbers here).
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            if param_def.min_value is not None and value < param_def.min_value:
                errors.append(f"参数 {param_name} 不能小于 {param_def.min_value}")
            if param_def.max_value is not None and value > param_def.max_value:
                errors.append(f"参数 {param_name} 不能大于 {param_def.max_value}")

        # Length of strings and lists.
        if isinstance(value, (str, list)):
            length = len(value)
            if param_def.min_length is not None and length < param_def.min_length:
                errors.append(f"参数 {param_name} 长度不能小于 {param_def.min_length}")
            if param_def.max_length is not None and length > param_def.max_length:
                errors.append(f"参数 {param_name} 长度不能大于 {param_def.max_length}")

        # Regex pattern; re.match anchors only at the start of the string.
        if isinstance(value, str) and param_def.pattern:
            if not re.match(param_def.pattern, value):
                errors.append(f"参数 {param_name} 不匹配模式: {param_def.pattern}")

        return errors

    @staticmethod
    async def _run_hooks(hooks: List[Callable], args: tuple, failure_label: str) -> None:
        """Call each hook with ``args``, awaiting awaitable results.

        A failing hook is reported and skipped; it does not abort rendering.
        """
        import inspect

        for hook in hooks:
            try:
                outcome = hook(*args)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception as e:
                print(f"{failure_label}: {e}")

    async def render_prompt(self, template_name: str, parameters: Dict[str, Any]) -> PromptResult:
        """Validate ``parameters`` and render the named template.

        The caller's ``parameters`` dict is not modified. Defaults are filled
        into a private copy, which is what hooks receive and what is recorded
        in ``PromptResult.parameters_used``.

        Raises:
            ValueError: if the template is unknown, parameters fail
                validation, or rendering fails.
        """
        template = self.get_template(template_name)
        if not template:
            raise ValueError(f"模板不存在: {template_name}")

        # validate_parameters writes defaults into the dict it receives, and
        # hooks may mutate it as well, so work on a copy.
        parameters = dict(parameters)

        validation_errors = self.validate_parameters(template_name, parameters)
        if validation_errors:
            raise ValueError(f"参数验证失败: {', '.join(validation_errors)}")

        await self._run_hooks(self.pre_render_hooks, (template, parameters), "预渲染钩子执行失败")

        try:
            content = self.template_engine.render(template.template, parameters)

            if template.trim_whitespace:
                content = content.strip()

            if template.escape_html:
                import html
                content = html.escape(content)
        except Exception as e:
            raise ValueError(f"模板渲染失败: {e}") from e

        result = PromptResult(
            content=content,
            parameters_used=parameters.copy(),
            template_name=template_name,
            metadata={
                "template_version": template.version,
                "template_tags": template.tags,
                "render_engine": type(self.template_engine).__name__
            }
        )

        await self._run_hooks(self.post_render_hooks, (template, parameters, result), "后渲染钩子执行失败")

        return result

    def register_parameter_validator(self, parameter_name: str, validator: Callable[[Any], Optional[List[str]]]) -> None:
        """Register a custom validator for every parameter named ``parameter_name``."""
        self.parameter_validators[parameter_name] = validator

    def add_pre_render_hook(self, hook: Callable) -> None:
        """Add a hook called as hook(template, parameters) before rendering."""
        self.pre_render_hooks.append(hook)

    def add_post_render_hook(self, hook: Callable) -> None:
        """Add a hook called as hook(template, parameters, result) after rendering."""
        self.post_render_hooks.append(hook)

    @staticmethod
    def _template_to_dict(template: PromptTemplate) -> Dict[str, Any]:
        """Convert ``template`` to plain data (no Enum or datetime objects)."""
        template_dict = template.dict()
        template_dict['created_at'] = template.created_at.isoformat()
        template_dict['updated_at'] = template.updated_at.isoformat()
        # .dict() keeps PromptParameterType members. json.dump cannot encode
        # them, and yaml.dump would emit Python-object tags that
        # yaml.safe_load refuses, so store their string value instead.
        for param in template_dict.get('parameters', []):
            param_type = param.get('type')
            if isinstance(param_type, Enum):
                param['type'] = param_type.value
        return template_dict

    def export_templates(self, file_path: str, format: str = 'yaml') -> None:
        """Write all registered templates to ``file_path`` as a YAML or JSON list.

        Raises:
            ValueError: if ``format`` is not 'yaml' or 'json'. This is checked
                before the file is opened, so an existing file is not truncated.
        """
        fmt = format.lower()
        if fmt not in self._SUPPORTED_FORMATS:
            raise ValueError(f"不支持的格式: {format}")

        templates_data = [self._template_to_dict(t) for t in self.templates.values()]

        with open(file_path, 'w', encoding='utf-8') as f:
            if fmt == 'yaml':
                yaml.dump(templates_data, f, default_flow_style=False, allow_unicode=True)
            else:
                json.dump(templates_data, f, ensure_ascii=False, indent=2)

    def import_templates(self, file_path: str, format: str = 'yaml') -> int:
        """Load templates from a YAML/JSON file and register each one.

        The file may contain a list of template mappings or a single mapping.
        An empty file imports nothing. Entries that fail to parse or register
        are reported and skipped.

        Returns:
            The number of templates registered.

        Raises:
            ValueError: if ``format`` is not 'yaml' or 'json'.
        """
        fmt = format.lower()
        if fmt not in self._SUPPORTED_FORMATS:
            raise ValueError(f"不支持的格式: {format}")

        with open(file_path, 'r', encoding='utf-8') as f:
            templates_data = yaml.safe_load(f) if fmt == 'yaml' else json.load(f)

        if templates_data is None:
            return 0
        if isinstance(templates_data, dict):
            templates_data = [templates_data]

        imported_count = 0

        for template_dict in templates_data:
            try:
                for key in ('created_at', 'updated_at'):
                    # YAML may already have parsed unquoted timestamps into datetimes.
                    if isinstance(template_dict.get(key), str):
                        template_dict[key] = datetime.fromisoformat(template_dict[key])

                template = PromptTemplate(**template_dict)
                self.register_template(template)
                imported_count += 1
            except Exception as e:
                name = template_dict.get('name', 'unknown') if isinstance(template_dict, dict) else 'unknown'
                print(f"导入模板失败 {name}: {e}")

        return imported_count

2. 模板库管理

from pathlib import Path
import os
from typing import Dict, List, Set

class TemplateLibrary:
    """File-system library of prompt templates stored as YAML files.

    The first directory below ``library_path`` is a template's category.
    Files placed directly in ``library_path`` belong to ``'default'``. A file
    may contain one template mapping or a list of them. ``scan_library``
    (re)builds the in-memory index that the lookup methods rely on.
    """

    def __init__(self, library_path: str):
        self.library_path = Path(library_path)
        self.library_path.mkdir(parents=True, exist_ok=True)
        self.categories: Dict[str, Set[str]] = {}  # category -> template_names
        self.template_files: Dict[str, Path] = {}  # template_name -> file_path

    @staticmethod
    def _read_template_dicts(file_path: Path) -> List[Dict[str, Any]]:
        """Load ``file_path`` and return the template mappings it contains.

        Accepts a single mapping or a list of mappings. An empty file yields
        ``[]``, and entries that are not mappings are ignored.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if data is None:
            return []
        items = data if isinstance(data, list) else [data]
        return [item for item in items if isinstance(item, dict)]

    def scan_library(self) -> Dict[str, List[str]]:
        """Rebuild the index from every ``*.yaml`` file under ``library_path``.

        Files that cannot be read or parsed are reported and skipped.

        Returns:
            Mapping of category -> names of the templates found in it.
        """
        self.categories.clear()
        self.template_files.clear()

        for file_path in self.library_path.rglob('*.yaml'):
            try:
                template_dicts = self._read_template_dicts(file_path)
            except Exception as e:
                print(f"扫描模板文件失败 {file_path}: {e}")
                continue
            for template_dict in template_dicts:
                self._process_template_data(template_dict, file_path)

        return {category: list(templates) for category, templates in self.categories.items()}

    def _process_template_data(self, template_dict: Dict[str, Any], file_path: Path):
        """Index one template mapping found in ``file_path``; nameless entries are skipped."""
        template_name = template_dict.get('name')
        if not template_name:
            return

        self.template_files[template_name] = file_path

        # Category = first directory below library_path; top-level files are 'default'.
        relative_path = file_path.relative_to(self.library_path)
        category = relative_path.parts[0] if len(relative_path.parts) > 1 else 'default'
        self.categories.setdefault(category, set()).add(template_name)

    def load_template(self, template_name: str) -> Optional[PromptTemplate]:
        """Load an indexed template from disk.

        Returns None if it is not indexed, its file is gone, or loading fails
        (failures are reported).
        """
        file_path = self.template_files.get(template_name)
        if not file_path or not file_path.exists():
            return None

        try:
            for template_dict in self._read_template_dicts(file_path):
                if template_dict.get('name') == template_name:
                    return self._create_template_from_dict(template_dict)
            return None
        except Exception as e:
            print(f"加载模板失败 {template_name}: {e}")
            return None

    def _create_template_from_dict(self, template_dict: Dict[str, Any]) -> PromptTemplate:
        """Build a PromptTemplate from a mapping as stored in a library file.

        The input mapping is not modified.
        """
        template_dict = dict(template_dict)

        # A YAML "parameters: null" must be treated like an empty list.
        parameters = []
        for param_dict in template_dict.get('parameters') or []:
            param_dict = dict(param_dict)
            param_dict['type'] = PromptParameterType(param_dict['type'])
            parameters.append(PromptParameter(**param_dict))
        template_dict['parameters'] = parameters

        for key in ('created_at', 'updated_at'):
            # YAML may already have parsed unquoted timestamps into datetimes.
            if isinstance(template_dict.get(key), str):
                template_dict[key] = datetime.fromisoformat(template_dict[key])

        return PromptTemplate(**template_dict)

    def save_template(self, template: PromptTemplate, category: str = 'default') -> Path:
        """Write ``template`` to ``<library_path>/<category>/<name>.yaml`` and index it.

        An existing file with that name is overwritten.

        Returns:
            Path of the written file.
        """
        # NOTE(review): ``category`` and ``template.name`` are used as path
        # components without checks; values such as "../x" would write outside
        # the category directory. Validate them if they come from untrusted input.
        category_path = self.library_path / category
        category_path.mkdir(parents=True, exist_ok=True)

        file_path = category_path / f"{template.name}.yaml"

        template_dict = template.dict()
        template_dict['created_at'] = template.created_at.isoformat()
        template_dict['updated_at'] = template.updated_at.isoformat()
        # .dict() keeps PromptParameterType members. yaml.dump would emit them
        # as Python-object tags, which yaml.safe_load (used when loading)
        # refuses, so store their plain string value.
        for param in template_dict.get('parameters', []):
            if isinstance(param.get('type'), Enum):
                param['type'] = param['type'].value

        with open(file_path, 'w', encoding='utf-8') as f:
            yaml.dump(template_dict, f, default_flow_style=False, allow_unicode=True)

        self.template_files[template.name] = file_path
        self.categories.setdefault(category, set()).add(template.name)

        return file_path

    def delete_template(self, template_name: str) -> bool:
        """Remove ``template_name`` from the library.

        If its file holds other templates too, the file is rewritten without
        this one (the YAML is re-serialized, so comments and formatting are
        lost). Otherwise the file is deleted.

        Returns:
            True on success. False if the template is not indexed, its file is
            missing, or the file operation fails (failures are reported).
        """
        file_path = self.template_files.get(template_name)
        if not file_path or not file_path.exists():
            return False

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)

            remaining = []
            if isinstance(data, list):
                remaining = [
                    item for item in data
                    if not (isinstance(item, dict) and item.get('name') == template_name)
                ]

            if remaining:
                with open(file_path, 'w', encoding='utf-8') as f:
                    yaml.dump(remaining, f, default_flow_style=False, allow_unicode=True)
            else:
                file_path.unlink()

            del self.template_files[template_name]
            for templates in self.categories.values():
                templates.discard(template_name)

            return True

        except Exception as e:
            print(f"删除模板失败 {template_name}: {e}")
            return False

    def list_templates_by_category(self, category: str) -> List[str]:
        """Return the names of indexed templates in ``category``."""
        return list(self.categories.get(category, set()))

    def search_templates(self, query: str) -> List[str]:
        """Return indexed template names containing ``query`` (case-insensitive)."""
        query_lower = query.lower()
        return [name for name in self.template_files if query_lower in name.lower()]

    def get_template_info(self, template_name: str) -> Optional[Dict[str, Any]]:
        """Return a summary dict for a template, or None if it cannot be loaded."""
        template = self.load_template(template_name)
        if not template:
            return None

        file_path = self.template_files.get(template_name)
        category = next(
            (cat for cat, templates in self.categories.items() if template_name in templates),
            None,
        )

        return {
            "name": template.name,
            "description": template.description,
            "version": template.version,
            "tags": template.tags,
            "parameter_count": len(template.parameters),
            "created_at": template.created_at.isoformat(),
            "updated_at": template.updated_at.isoformat(),
            "file_path": str(file_path) if file_path else None,
            "category": category
        }

动态提示生成

1. 上下文感知提示生成

from typing import Protocol, runtime_checkable
from dataclasses import dataclass
from abc import ABC, abstractmethod

@runtime_checkable
class ContextProvider(Protocol):
    """Structural interface for objects that supply context data.

    Any object with a ``get_context()`` method satisfies it. Because of
    ``@runtime_checkable``, ``isinstance(obj, ContextProvider)`` works, but it
    only checks that the method exists, not its signature.
    """
    
    def get_context(self) -> Dict[str, Any]:
        """Return this provider's context as a dictionary."""
        ...

@dataclass
class UserContext:
    """Typed record of user context.

    Field names mirror the keys returned by ``UserContextProvider.get_context``.
    """
    user_id: str
    username: str
    role: str
    preferences: Dict[str, Any]  # free-form user preferences
    session_data: Dict[str, Any]  # free-form per-session data

@dataclass
class SystemContext:
    """Typed record of system context.

    Field names mirror the keys returned by ``SystemContextProvider.get_context``.
    Note that ``timestamp`` is a ``datetime`` here, whereas that provider
    emits an ISO-format string.
    """
    timestamp: datetime
    system_version: str
    environment: str
    resource_usage: Dict[str, Any]  # free-form resource metrics

@dataclass
class ConversationContext:
    """Typed record of conversation context.

    Field names mirror the keys returned by
    ``ConversationContextProvider.get_context``.
    """
    conversation_id: str
    message_history: List[Dict[str, Any]]  # one dict per message
    current_topic: Optional[str]
    intent: Optional[str]
    entities: Dict[str, Any]

class DynamicPromptGenerator:
    """Generate prompts from collected context, template selection and parameter enrichment.

    ``generate_prompt`` gathers context from every registered provider,
    chooses a template (explicit hint, then selectors, then the first
    registered template), enriches the caller's parameters, and renders via
    the ``PromptManager``. Providers, selectors and enrichers may be plain
    callables or return awaitables; awaitable results are awaited, matching
    how ``PromptManager`` handles async render hooks.
    """

    def __init__(self, prompt_manager: PromptManager):
        self.prompt_manager = prompt_manager
        self.context_providers: Dict[str, ContextProvider] = {}
        self.template_selectors: List[Callable] = []
        self.parameter_enrichers: List[Callable] = []

    def register_context_provider(self, name: str, provider: ContextProvider):
        """Register ``provider``; its context is stored under key ``name``."""
        self.context_providers[name] = provider

    def add_template_selector(self, selector: Callable[[Dict[str, Any]], Optional[str]]):
        """Add a selector called with the collected context; returns a template name or None."""
        self.template_selectors.append(selector)

    def add_parameter_enricher(self, enricher: Callable[[Dict[str, Any], Dict[str, Any]], None]):
        """Add an enricher called as enricher(parameters, context) that mutates parameters."""
        self.parameter_enrichers.append(enricher)

    async def generate_prompt(self,
                              base_parameters: Dict[str, Any],
                              template_hint: Optional[str] = None) -> PromptResult:
        """Select a template, enrich ``base_parameters`` and render the prompt.

        Raises:
            ValueError: if no template can be selected, or if rendering fails
                (propagated from ``PromptManager.render_prompt``).
        """
        context = await self._collect_context()

        template_name = await self._select_template(context, template_hint)
        if not template_name:
            raise ValueError("无法选择合适的模板")

        enriched_parameters = await self._enrich_parameters(base_parameters, context)

        return await self.prompt_manager.render_prompt(template_name, enriched_parameters)

    @staticmethod
    async def _maybe_await(value: Any) -> Any:
        """Await ``value`` if it is awaitable; otherwise return it unchanged."""
        import inspect

        if inspect.isawaitable(value):
            return await value
        return value

    async def _collect_context(self) -> Dict[str, Any]:
        """Return {provider name: context}; a failing provider contributes {}."""
        context = {}

        for name, provider in self.context_providers.items():
            try:
                context[name] = await self._maybe_await(provider.get_context())
            except Exception as e:
                print(f"获取上下文失败 {name}: {e}")
                context[name] = {}

        return context

    async def _select_template(self, context: Dict[str, Any], hint: Optional[str]) -> Optional[str]:
        """Return the name of a registered template to use, or None if none exist.

        Priority: a registered ``hint``, then the first selector that returns a
        registered name, then the first registered template.
        """
        if hint and self.prompt_manager.get_template(hint):
            return hint

        for selector in self.template_selectors:
            try:
                selected = await self._maybe_await(selector(context))
                if selected and self.prompt_manager.get_template(selected):
                    return selected
            except Exception as e:
                print(f"模板选择器执行失败: {e}")

        templates = self.prompt_manager.list_templates()
        if templates:
            return templates[0].name

        return None

    async def _enrich_parameters(self, base_parameters: Dict[str, Any],
                                 context: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of ``base_parameters`` with context and enrichments added.

        Adds ``_context`` (the collected context) and ``_timestamp`` (ISO
        local time), then lets each enricher mutate the copy. A failing
        enricher is reported and skipped.
        """
        enriched = base_parameters.copy()

        enriched['_context'] = context
        enriched['_timestamp'] = datetime.now().isoformat()

        for enricher in self.parameter_enrichers:
            try:
                await self._maybe_await(enricher(enriched, context))
            except Exception as e:
                print(f"参数丰富器执行失败: {e}")

        return enriched

class UserContextProvider:
    """Context provider for user information.

    Currently returns fixed placeholder values; ``user_service`` is stored
    for a future real implementation but is not consulted yet.
    """

    def __init__(self, user_service):
        self.user_service = user_service

    def get_context(self) -> Dict[str, Any]:
        """Return a user context dict with placeholder identity and preferences.

        ``session_data`` timestamps are the current local time in ISO format.
        """
        # TODO: source these values from self.user_service.
        preferences = dict(
            language="zh-CN",
            theme="dark",
            timezone="Asia/Shanghai",
        )
        session = dict(
            login_time=datetime.now().isoformat(),
            last_activity=datetime.now().isoformat(),
        )
        return dict(
            user_id="user123",
            username="john_doe",
            role="admin",
            preferences=preferences,
            session_data=session,
        )

class SystemContextProvider:
    """Provide runtime system information: version, environment, resource usage."""

    # Thread ids that already hold a psutil CPU baseline. The non-blocking
    # psutil.cpu_percent() measures usage since the previous call (tracked per
    # thread in recent psutil versions); with no previous call it returns a
    # meaningless 0.0, which made the first snapshot always look idle.
    # (A recycled thread id may skip priming once -- harmless.)
    _cpu_sampled_threads: set = set()

    def get_context(self) -> Dict[str, Any]:
        """Return a snapshot of system information.

        Keys: ``timestamp`` (local ISO-8601), ``system_version``,
        ``environment`` and ``resource_usage`` holding ``cpu_percent``,
        ``memory_percent`` and ``disk_usage`` (percent used of ``/``).

        Requires the third-party ``psutil`` package, imported lazily here.
        The first call on a given thread blocks for about 0.1 s to take a real
        CPU sample; later calls are non-blocking.
        """
        import threading

        import psutil

        thread_id = threading.get_ident()
        sampled = SystemContextProvider._cpu_sampled_threads
        if thread_id in sampled:
            cpu_percent = psutil.cpu_percent(interval=None)
        else:
            # No baseline yet on this thread: take one short blocking sample
            # (this also establishes the baseline for subsequent calls).
            cpu_percent = psutil.cpu_percent(interval=0.1)
            sampled.add(thread_id)

        return {
            "timestamp": datetime.now().isoformat(),
            "system_version": "1.0.0",
            "environment": "production",
            "resource_usage": {
                "cpu_percent": cpu_percent,
                "memory_percent": psutil.virtual_memory().percent,
                "disk_usage": psutil.disk_usage('/').percent,
            },
        }

class ConversationContextProvider:
    """Provide a snapshot of the ongoing conversation (placeholder data)."""

    def __init__(self, conversation_service):
        # Stored for the eventual real lookup; get_context does not use it yet.
        self.conversation_service = conversation_service

    def get_context(self) -> Dict[str, Any]:
        """Return conversation id, message history, topic, intent and entities.

        NOTE: the values are hard-coded placeholders; a real implementation is
        expected to query ``self.conversation_service`` instead.
        """
        history = [
            {"role": "user", "content": "你好", "timestamp": "2024-01-01T10:00:00"},
            {"role": "assistant", "content": "你好!有什么可以帮助你的吗?", "timestamp": "2024-01-01T10:00:01"},
        ]
        return dict(
            conversation_id="conv123",
            message_history=history,
            current_topic="greeting",
            intent="greeting",
            entities={},
        )

2. 智能模板选择

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple

class IntelligentTemplateSelector:
    """Choose a prompt template via rules, past selections and TF-IDF similarity.

    Each template is indexed by the text "name description tags". The index
    is built lazily on the first similarity query. NOTE: templates registered
    after the index exists are not seen until ``build_template_index`` is
    called again.
    """

    # Once the history exceeds MAX_HISTORY entries it is trimmed to the newest
    # TRIMMED_HISTORY ones (batch trimming avoids slicing on every append).
    MAX_HISTORY = 1000
    TRIMMED_HISTORY = 800

    def __init__(self, prompt_manager: "PromptManager"):
        self.prompt_manager = prompt_manager
        self.vectorizer = TfidfVectorizer(stop_words='english')
        # TF-IDF matrix with one row per entry of template_names; None until an
        # index has been built successfully.
        self.template_vectors = None
        self.template_names: List[str] = []
        # (feature fingerprint, template name, score), oldest first.
        self.selection_history: List[Tuple[str, str, float]] = []

    def build_template_index(self):
        """(Re)build the TF-IDF index over all registered templates.

        Leaves the index empty (``template_vectors`` is None) when there are
        no templates or their text yields no usable terms.
        """
        templates = self.prompt_manager.list_templates()

        # Reset first so an empty or failed rebuild never leaves a stale index
        # pointing at templates that are no longer registered.
        self.template_vectors = None
        self.template_names = []
        if not templates:
            return

        names: List[str] = []
        texts: List[str] = []
        for template in templates:
            # Name, description and tags together form the feature text.
            tags = ' '.join(template.tags or [])
            texts.append(f"{template.name} {template.description or ''} {tags}")
            names.append(template.name)

        try:
            vectors = self.vectorizer.fit_transform(texts)
        except ValueError as e:
            # TfidfVectorizer raises ValueError ("empty vocabulary") when no
            # text contains a usable term (e.g. only stop words).
            print(f"构建模板索引失败: {e}")
            return

        self.template_vectors = vectors
        self.template_names = names

    def select_template_by_similarity(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` ``(template_name, score)`` pairs, best first.

        Only templates with a strictly positive cosine similarity to ``query``
        are included. Returns ``[]`` for a blank query, a non-positive
        ``top_k`` or an empty index.
        """
        if top_k <= 0 or not query.strip():
            return []

        if self.template_vectors is None:
            self.build_template_index()
        if self.template_vectors is None:
            return []

        query_vector = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, self.template_vectors).flatten()

        # argsort is ascending; reverse it for best-first order.
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (self.template_names[idx], float(similarities[idx]))
            for idx in top_indices
            if similarities[idx] > 0
        ]

    def select_template_by_context(self, context: Dict[str, Any]) -> Optional[str]:
        """Pick a template for ``context``: rules, then history, then similarity.

        Returns the template name, or ``None`` when no strategy finds one.
        """
        features = self._extract_context_features(context)

        by_rule = self._rule_based_selection(features)
        if by_rule:
            return by_rule

        by_history = self._history_based_selection(features)
        if by_history:
            return by_history

        matches = self.select_template_by_similarity(self._context_to_query(features), top_k=1)
        return matches[0][0] if matches else None

    def _extract_context_features(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten the user/system/conversation sub-contexts into a feature dict.

        Only the sub-contexts present in ``context`` contribute features.
        """
        features: Dict[str, Any] = {}

        if 'user' in context:
            user_context = context['user']
            features['user_role'] = user_context.get('role', 'unknown')
            # `or {}` tolerates an explicit None for preferences.
            features['user_language'] = (user_context.get('preferences') or {}).get('language', 'en')

        if 'system' in context:
            system_context = context['system']
            features['environment'] = system_context.get('environment', 'unknown')
            # `or 0` tolerates missing/None usage data; > 80% CPU counts as high load.
            cpu = (system_context.get('resource_usage') or {}).get('cpu_percent') or 0
            features['system_load'] = 'high' if cpu > 80 else 'normal'

        if 'conversation' in context:
            conv_context = context['conversation']
            features['conversation_length'] = len(conv_context.get('message_history', []))
            features['current_intent'] = conv_context.get('intent', 'unknown')
            features['current_topic'] = conv_context.get('current_topic', 'general')

        return features

    def _rule_based_selection(self, features: Dict[str, Any]) -> Optional[str]:
        """Apply the example rules in order and return the first template found.

        A rule only wins when a template carrying its tag actually exists;
        otherwise the next rule is tried.
        """
        rules = (
            ('current_intent', 'greeting', 'greeting'),
            ('user_role', 'admin', 'admin'),
            ('system_load', 'high', 'simple'),
        )
        for feature, expected, tag in rules:
            if features.get(feature) == expected:
                template = self._find_template_by_tag(tag)
                if template:
                    return template
        return None

    def _find_template_by_tag(self, tag: str) -> Optional[str]:
        """Return the name of the first template listed under ``tag``, if any."""
        templates = self.prompt_manager.list_templates(tag_filter=tag)
        return templates[0].name if templates else None

    @staticmethod
    def _features_key(features: Dict[str, Any]) -> str:
        """Order-independent string fingerprint of a feature dict."""
        return str(sorted(features.items()))

    def _history_based_selection(self, features: Dict[str, Any]) -> Optional[str]:
        """Return the template with the highest summed score for identical features.

        Ties go to the template first seen in the history.
        """
        if not self.selection_history:
            return None

        key = self._features_key(features)
        scores: Dict[str, float] = {}
        for hist_key, template, score in self.selection_history:
            if hist_key == key:
                scores[template] = scores.get(template, 0) + score

        if not scores:
            return None
        return max(scores, key=scores.get)

    def _context_to_query(self, features: Dict[str, Any]) -> str:
        """Join scalar (str/int/float) features into a "key value ..." query string."""
        return ' '.join(
            f"{key} {value}"
            for key, value in features.items()
            if isinstance(value, (str, int, float))
        )

    def record_selection(self, context: Dict[str, Any], template_name: str, score: float = 1.0):
        """Append a selection for ``context`` to the bounded history."""
        key = self._features_key(self._extract_context_features(context))
        self.selection_history.append((key, template_name, score))

        if len(self.selection_history) > self.MAX_HISTORY:
            self.selection_history = self.selection_history[-self.TRIMMED_HISTORY:]

    def get_selection_stats(self) -> Dict[str, Any]:
        """Summarize the selection history (counts, average score, most used)."""
        history = self.selection_history
        if not history:
            return {"total_selections": 0}

        usage: Dict[str, int] = {}
        for _, template, _score in history:
            usage[template] = usage.get(template, 0) + 1
        total_score = sum(score for _, _, score in history)

        # max over a dict keeps the first-inserted template on ties.
        most_used = max(usage, key=usage.get)

        return {
            "total_selections": len(history),
            "unique_templates": len(usage),
            "average_score": total_score / len(history),
            "most_used_template": most_used,
            "most_used_count": usage[most_used],
            "template_usage": usage,
        }

本章总结

本章详细介绍了MCP协议中的提示模板与动态生成机制,包括:

核心内容

  1. 提示模板基础

    • 模板概念和参数类型定义
    • 模板语法设计(简单模板、Jinja2、高级模板引擎)
    • 参数验证和约束机制
  2. 提示管理器实现

    • 核心提示管理器功能
    • 模板注册、验证和渲染
    • 钩子机制和扩展性
    • 模板导入导出功能
  3. 模板库管理

    • 文件系统模板库
    • 分类管理和搜索功能
    • 模板版本控制
  4. 动态提示生成

    • 上下文感知的提示生成
    • 智能模板选择算法
    • 基于TF-IDF文本相似度与选择历史的模板推荐

最佳实践

  1. 模板设计

    • 使用清晰的参数命名
    • 提供详细的参数描述和约束
    • 合理使用默认值
    • 考虑模板的可重用性
  2. 性能优化

    • 缓存编译后的模板
    • 使用异步渲染
    • 优化模板选择算法
  3. 错误处理

    • 完善的参数验证
    • 优雅的模板错误处理
    • 提供有意义的错误信息
  4. 扩展性

    • 支持多种模板引擎
    • 提供钩子机制
    • 支持自定义验证器

下一章我们将学习安全性与权限控制,了解如何在MCP协议中实现安全的通信和访问控制。