本章概述

检索与生成是RAG系统的核心环节,决定了系统的最终效果。本章将深入介绍检索策略优化、重排序技术、生成模型集成、以及如何构建端到端的RAG管道。

学习目标

  • 掌握多种检索策略和优化技术
  • 学习重排序算法提升检索精度
  • 了解生成模型集成和提示工程
  • 熟悉RAG管道的构建和优化
  • 掌握检索增强的生成技术

1. 高级检索策略

1.1 混合检索系统

# src/retrieval/hybrid_retriever.py - 混合检索器
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union, Tuple
import numpy as np
from dataclasses import dataclass
import logging
from ..vectorstore.vector_manager import VectorSearchResult, BaseVectorStore
from ..chunking.base_chunker import Chunk

@dataclass
class RetrievalConfig:
    """Configuration knobs shared by the retrievers in this module.

    ``hybrid_weights`` defaults to a 70/30 semantic/keyword split when not
    supplied (or when explicitly passed as ``None``).
    """
    top_k: int = 10                      # number of results ultimately returned
    similarity_threshold: float = 0.7    # minimum score kept by _filter_results
    enable_reranking: bool = True        # fetch rerank_top_k candidates when True
    rerank_top_k: int = 50               # candidate pool size when reranking follows
    # FIX: the original annotation `Dict[str, float] = None` lied about the
    # type; Optional[...] is the honest annotation since None is accepted and
    # normalised in __post_init__.
    hybrid_weights: Optional[Dict[str, float]] = None
    filter_duplicates: bool = True       # de-duplicate identical contents
    max_content_length: int = 4000       # longer contents are truncated with "..."

    def __post_init__(self) -> None:
        # Fill in the default weights; done here (not as a class default) so
        # every instance gets its own dict.
        if self.hybrid_weights is None:
            self.hybrid_weights = {
                "semantic": 0.7,
                "keyword": 0.3
            }

class BaseRetriever(ABC):
    """Abstract base for all retrievers.

    Subclasses implement :meth:`retrieve`; the shared post-processing
    (threshold filtering, de-duplication, content truncation, top-k trim)
    lives here.
    """

    def __init__(self, config: RetrievalConfig):
        self.config = config

    @abstractmethod
    def retrieve(self, query: str, **kwargs) -> List[VectorSearchResult]:
        """Return results relevant to *query*."""
        pass

    def _filter_results(self, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Filter by score threshold, optionally de-dup, cap content length, trim to top_k."""
        threshold = self.config.similarity_threshold
        kept = [item for item in results if item.score >= threshold]

        if self.config.filter_duplicates:
            kept = self._remove_duplicates(kept)

        # Truncate overly long contents in place (mutates the result objects).
        limit = self.config.max_content_length
        for item in kept:
            if len(item.content) > limit:
                item.content = item.content[:limit] + "..."

        return kept[:self.config.top_k]

    def _remove_duplicates(self, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Drop results whose content hash was already seen, preserving order."""
        unique = []
        seen = set()
        for item in results:
            fingerprint = hash(item.content)
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            unique.append(item)
        return unique

class SemanticRetriever(BaseRetriever):
    """Dense (embedding-based) retriever backed by a vector store."""

    def __init__(self, config: RetrievalConfig, vector_store: BaseVectorStore, embedding_model):
        super().__init__(config)
        self.vector_store = vector_store
        self.embedding_model = embedding_model

    def retrieve(self, query: str, **kwargs) -> List[VectorSearchResult]:
        """Embed *query* and search the vector store; returns [] on any failure."""
        try:
            # Encode a single-element batch and take its only vector.
            query_vector = self.embedding_model.encode([query])[0]

            # Fetch a larger candidate pool when a reranking stage will follow.
            if self.config.enable_reranking:
                fetch_k = self.config.rerank_top_k
            else:
                fetch_k = self.config.top_k

            raw_results = self.vector_store.search(
                query_vector=query_vector,
                top_k=fetch_k,
                filter_dict=kwargs.get('filter_dict')
            )
            return self._filter_results(raw_results)

        except Exception as e:
            logging.error(f"语义检索失败: {e}")
            return []

class KeywordRetriever(BaseRetriever):
    """Sparse keyword retriever using a TF-IDF index over chunk contents."""

    def __init__(self, config: RetrievalConfig, documents: List[Chunk]):
        super().__init__(config)
        self.documents = documents
        self.tfidf_vectorizer = None
        self.tfidf_matrix = None
        self._build_index()

    def _build_index(self):
        """Build the TF-IDF matrix over ``self.documents``.

        Leaves the index as ``None`` (so :meth:`retrieve` returns []) and logs
        on failure; raises RuntimeError only when scikit-learn is missing.
        """
        # FIX: guard the empty corpus — TfidfVectorizer.fit_transform raises
        # on an empty document list, which previously surfaced as a logged
        # "构建TF-IDF索引失败" error.
        if not self.documents:
            logging.warning("文档列表为空,跳过TF-IDF索引构建")
            return

        try:
            # FIX: removed the unused `cosine_similarity` import that was
            # here; it is only needed (and imported) in retrieve().
            from sklearn.feature_extraction.text import TfidfVectorizer

            doc_contents = [doc.content for doc in self.documents]

            # Unigram+bigram TF-IDF. NOTE: stop_words='english' is
            # English-specific and should be adapted per language; min_df=2
            # means terms must appear in at least two documents.
            self.tfidf_vectorizer = TfidfVectorizer(
                max_features=10000,
                stop_words='english',
                ngram_range=(1, 2),
                min_df=2,
                max_df=0.8
            )

            self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(doc_contents)

            logging.info(f"已构建TF-IDF索引,文档数: {len(doc_contents)}")

        except ImportError:
            raise RuntimeError("请安装scikit-learn库: pip install scikit-learn")
        except Exception as e:
            logging.error(f"构建TF-IDF索引失败: {e}")

    def retrieve(self, query: str, **kwargs) -> List[VectorSearchResult]:
        """Score every document against *query* by TF-IDF cosine similarity."""
        if self.tfidf_vectorizer is None or self.tfidf_matrix is None:
            return []

        try:
            from sklearn.metrics.pairwise import cosine_similarity

            query_vector = self.tfidf_vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()

            # Over-fetch (2x top_k) so downstream filtering still has choices.
            top_indices = np.argsort(similarities)[::-1][:self.config.top_k * 2]

            results = []
            for idx in top_indices:
                if similarities[idx] > 0:  # keep only documents with some overlap
                    doc = self.documents[idx]
                    results.append(VectorSearchResult(
                        chunk_id=doc.chunk_id,
                        content=doc.content,
                        metadata=doc.metadata,
                        score=float(similarities[idx])
                    ))

            return self._filter_results(results)

        except Exception as e:
            logging.error(f"关键词检索失败: {e}")
            return []

class HybridRetriever(BaseRetriever):
    """Combines semantic and keyword retrieval with weighted score fusion."""

    def __init__(self,
                 config: RetrievalConfig,
                 semantic_retriever: SemanticRetriever,
                 keyword_retriever: KeywordRetriever):
        super().__init__(config)
        self.semantic_retriever = semantic_retriever
        self.keyword_retriever = keyword_retriever

    def retrieve(self, query: str, **kwargs) -> List[VectorSearchResult]:
        """Run both child retrievers, fuse their scores, and filter."""
        from_semantic = self.semantic_retriever.retrieve(query, **kwargs)
        from_keyword = self.keyword_retriever.retrieve(query, **kwargs)

        fused = self._combine_results(from_semantic, from_keyword)
        return self._filter_results(fused)

    def _combine_results(self,
                        semantic_results: List[VectorSearchResult],
                        keyword_results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Weight and merge the two result lists, keyed by chunk_id.

        Overlapping chunks get the sum of their weighted scores.
        NOTE(review): scores are rescaled in place, mutating the result
        objects produced by the child retrievers.
        """
        weight_semantic = self.config.hybrid_weights.get("semantic", 0.7)
        weight_keyword = self.config.hybrid_weights.get("keyword", 0.3)

        merged = {}

        for item in semantic_results:
            item.score *= weight_semantic
            merged[item.chunk_id] = item

        for item in keyword_results:
            existing = merged.get(item.chunk_id)
            if existing is not None:
                # Chunk found by both retrievers: accumulate the weighted score.
                existing.score += item.score * weight_keyword
            else:
                item.score *= weight_keyword
                merged[item.chunk_id] = item

        return sorted(merged.values(), key=lambda r: r.score, reverse=True)

class QueryExpansionRetriever(BaseRetriever):
    """Wraps another retriever and queries it with expanded variants of the query."""

    def __init__(self,
                 config: RetrievalConfig,
                 base_retriever: BaseRetriever,
                 expansion_model=None):
        super().__init__(config)
        self.base_retriever = base_retriever
        self.expansion_model = expansion_model

    def retrieve(self, query: str, **kwargs) -> List[VectorSearchResult]:
        """Retrieve with the original query plus its expansions, then merge and rank."""
        gathered = []
        for variant in self._expand_query(query):
            gathered.extend(self.base_retriever.retrieve(variant, **kwargs))

        # Collapse duplicates across variants and re-rank by score.
        deduped = self._remove_duplicates(gathered)
        deduped.sort(key=lambda r: r.score, reverse=True)
        return self._filter_results(deduped)

    def _expand_query(self, query: str) -> List[str]:
        """Return the original query followed by synonym/related-term variants."""
        variants = [query]  # always include the untouched query first
        try:
            # Up to three synonym-augmented variants.
            for synonym in self._get_synonyms(query)[:3]:
                variants.append(f"{query} {synonym}")

            # Up to two model-derived related-term variants.
            if self.expansion_model:
                for term in self._get_related_terms(query)[:2]:
                    variants.append(f"{query} {term}")

        except Exception as e:
            logging.warning(f"查询扩展失败: {e}")

        return variants

    def _get_synonyms(self, query: str) -> List[str]:
        """Look up synonyms from a small built-in mapping (WordNet-like
        resources could replace this in production).

        NOTE(review): the query is split on whitespace, so for unsegmented
        Chinese text a whole token must equal a dictionary key to match.
        """
        synonym_dict = {
            "问题": ["疑问", "难题", "课题"],
            "方法": ["办法", "途径", "手段"],
            "技术": ["科技", "工艺", "技能"],
            "系统": ["体系", "制度", "机制"]
        }

        found = []
        for token in query.split():
            found.extend(synonym_dict.get(token, []))
        return found

    def _get_related_terms(self, query: str) -> List[str]:
        """Placeholder for model-based related-term expansion; currently returns []."""
        return []

2. 重排序技术

2.1 重排序模型

# src/retrieval/reranker.py - 重排序器
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from dataclasses import dataclass
import logging
from ..vectorstore.vector_manager import VectorSearchResult

@dataclass
class RerankConfig:
    """Configuration shared by the rerankers in this module."""
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # cross-encoder checkpoint
    batch_size: int = 16  # prediction batch size for the cross-encoder
    max_length: int = 512  # NOTE(review): not read by the rerankers shown here — confirm intended use
    device: str = "cpu"  # inference device passed to the cross-encoder
    score_threshold: float = 0.0  # results scoring below this are dropped

class BaseReranker(ABC):
    """Abstract base class for rerankers."""
    
    def __init__(self, config: RerankConfig):
        self.config = config
    
    @abstractmethod
    def rerank(self, query: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Reorder *results* by relevance to *query* and return the new list."""
        pass

class CrossEncoderReranker(BaseReranker):
    """Reranks results with a sentence-transformers cross-encoder."""

    def __init__(self, config: RerankConfig):
        super().__init__(config)
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the cross-encoder; leaves ``self.model`` as None (and logs) on failure."""
        try:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                self.config.model_name,
                device=self.config.device
            )

            logging.info(f"已加载交叉编码器: {self.config.model_name}")

        except ImportError:
            raise RuntimeError("请安装sentence-transformers库")
        except Exception as e:
            logging.error(f"加载交叉编码器失败: {e}")
            self.model = None

    def rerank(self, query: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Rescore each (query, content) pair with the cross-encoder and sort.

        Falls back to the input list unchanged when the model is unavailable
        or prediction fails.
        """
        if not self.model or not results:
            return results

        try:
            pairs = [(query, item.content) for item in results]

            relevance = self.model.predict(
                pairs,
                batch_size=self.config.batch_size,
                show_progress_bar=False
            )

            # Overwrite retrieval scores with cross-encoder relevance scores.
            for item, new_score in zip(results, relevance):
                item.score = float(new_score)

            survivors = [item for item in results
                         if item.score >= self.config.score_threshold]
            survivors.sort(key=lambda r: r.score, reverse=True)
            return survivors

        except Exception as e:
            logging.error(f"重排序失败: {e}")
            return results

class LLMReranker(BaseReranker):
    """Asks an LLM to order the candidate passages by relevance."""

    def __init__(self, config: RerankConfig, llm_client):
        super().__init__(config)
        self.llm_client = llm_client

    def rerank(self, query: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Prompt the LLM for an ordering; falls back to the input order on error."""
        if not results:
            return results

        try:
            prompt = self._build_rerank_prompt(query, results)
            response = self.llm_client.generate(prompt)
            return self._parse_rerank_response(response, results)

        except Exception as e:
            logging.error(f"LLM重排序失败: {e}")
            return results

    def _build_rerank_prompt(self, query: str, results: List[VectorSearchResult]) -> str:
        """Build a ranking prompt listing a 200-char preview of each candidate."""
        header = f"""请根据查询对以下文档片段进行相关性排序。

查询: {query}

文档片段:
"""
        body = "".join(
            f"{position + 1}. {item.content[:200]}...\n\n"
            for position, item in enumerate(results)
        )
        footer = """请按相关性从高到低排序,只返回序号列表,用逗号分隔。例如: 3,1,5,2,4

排序结果:"""
        return header + body + footer

    def _parse_rerank_response(self, response: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Turn the LLM's "3,1,2"-style answer back into a result ordering.

        Out-of-range indices are ignored; results the LLM omitted are
        appended after the ranked ones. Returns the original list when the
        response cannot be parsed as comma-separated integers.
        """
        try:
            ranked_indices = [int(token.strip()) - 1
                              for token in response.strip().split(',')]

            ordered = [results[i] for i in ranked_indices
                       if 0 <= i < len(results)]

            mentioned = set(ranked_indices)
            ordered.extend(item for i, item in enumerate(results)
                           if i not in mentioned)

            return ordered

        except Exception as e:
            logging.error(f"解析重排序响应失败: {e}")
            return results

class FeatureBasedReranker(BaseReranker):
    """Reranks by a weighted combination of hand-crafted features."""

    def __init__(self, config: RerankConfig):
        super().__init__(config)
        # Feature weights sum to 1.0; tune per domain.
        self.feature_weights = {
            "semantic_similarity": 0.4,
            "keyword_overlap": 0.2,
            "length_penalty": 0.1,
            "position_bonus": 0.1,
            "freshness_bonus": 0.1,
            "source_authority": 0.1
        }

    def rerank(self, query: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Score every result on all features and sort by the weighted sum.

        Note: overwrites ``result.score`` with the combined feature score.
        Returns the input unchanged if feature extraction fails.
        """
        if not results:
            return results

        try:
            for i, result in enumerate(results):
                features = self._extract_features(query, result, i, len(results))

                weighted_score = sum(
                    features[feature] * weight
                    for feature, weight in self.feature_weights.items()
                    if feature in features
                )

                result.score = weighted_score

            results.sort(key=lambda x: x.score, reverse=True)
            return results

        except Exception as e:
            logging.error(f"特征重排序失败: {e}")
            return results

    def _extract_features(self, query: str, result: VectorSearchResult, position: int, total: int) -> Dict[str, float]:
        """Compute the per-result feature vector; *position* is the pre-rerank rank."""
        return {
            # The original retrieval score doubles as the semantic feature.
            "semantic_similarity": result.score,
            "keyword_overlap": self._calculate_keyword_overlap(query, result.content),
            "length_penalty": self._calculate_length_penalty(result.content),
            # Earlier positions in the incoming ranking get a larger bonus.
            "position_bonus": 1.0 - (position / total),
            "freshness_bonus": self._calculate_freshness_bonus(result.metadata),
            "source_authority": self._calculate_source_authority(result.metadata),
        }

    def _calculate_keyword_overlap(self, query: str, content: str) -> float:
        """Fraction of (whitespace-split, lowercased) query words present in the content."""
        query_words = set(query.lower().split())
        content_words = set(content.lower().split())

        if not query_words:
            return 0.0

        return len(query_words & content_words) / len(query_words)

    def _calculate_length_penalty(self, content: str) -> float:
        """1.0 inside the ideal 100-1000 char range, tapering off outside it."""
        length = len(content)
        ideal_min, ideal_max = 100, 1000

        if ideal_min <= length <= ideal_max:
            return 1.0
        if length < ideal_min:
            return length / ideal_min
        return ideal_max / length

    def _calculate_freshness_bonus(self, metadata: Dict[str, Any]) -> float:
        """Newer documents score higher: 1.0 within 30 days, linear decay to 0.5 at one year.

        Returns 0.5 when the creation time is missing or unparsable.
        NOTE(review): a timezone-aware ``created_time`` subtracted from the
        naive ``datetime.now()`` raises and silently falls back to 0.5 —
        confirm whether stored timestamps are naive.
        """
        created_time = metadata.get("created_time")
        if not created_time:
            return 0.5  # no timestamp available

        try:
            # FIX: dropped the unused `timedelta` import that was here.
            from datetime import datetime

            if isinstance(created_time, str):
                created_time = datetime.fromisoformat(created_time)

            age_days = (datetime.now() - created_time).days

            if age_days <= 30:
                return 1.0
            elif age_days <= 365:
                # Linear decay from 1.0 at day 30 to 0.5 at day 365.
                return 1.0 - (age_days - 30) / 335 * 0.5
            else:
                return 0.5

        except Exception:
            return 0.5

    def _calculate_source_authority(self, metadata: Dict[str, Any]) -> float:
        """Map a source-type keyword found in the source string to an authority score."""
        source = metadata.get("source", "").lower()

        # First matching keyword wins (dict order).
        authority_scores = {
            "official": 1.0,
            "academic": 0.9,
            "news": 0.7,
            "blog": 0.5,
            "forum": 0.3
        }

        for source_type, score in authority_scores.items():
            if source_type in source:
                return score

        return 0.5  # unknown source type

class RerankingPipeline:
    """Chains several rerankers, feeding each stage's output into the next."""

    def __init__(self, rerankers: List[BaseReranker]):
        self.rerankers = rerankers

    def rerank(self, query: str, results: List[VectorSearchResult]) -> List[VectorSearchResult]:
        """Apply every reranker in order; a failing stage is logged and skipped."""
        staged = results
        for stage in self.rerankers:
            stage_name = stage.__class__.__name__
            try:
                staged = stage.rerank(query, staged)
                logging.info(f"完成 {stage_name} 重排序")
            except Exception as e:
                logging.error(f"{stage_name} 重排序失败: {e}")
        return staged

3. 生成模型集成

3.1 生成器管理

# src/generation/generator.py - 生成器
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
import logging
from ..vectorstore.vector_manager import VectorSearchResult

@dataclass
class GenerationConfig:
    """Decoding/sampling configuration shared by all generators."""
    model_name: str = "gpt-3.5-turbo"
    max_tokens: int = 1000          # upper bound on generated tokens
    temperature: float = 0.7        # sampling temperature
    top_p: float = 0.9              # nucleus sampling cutoff
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    # FIX: the original annotation `List[str] = None` lied about the type;
    # Optional[...] is honest since None is accepted and normalised below to
    # a fresh list per instance (avoiding the shared-mutable-default pitfall).
    stop_sequences: Optional[List[str]] = None

    def __post_init__(self) -> None:
        if self.stop_sequences is None:
            self.stop_sequences = []

@dataclass
class GenerationResult:
    """Container for a generated answer plus its provenance."""
    content: str  # the generated answer text (or an error message on failure)
    metadata: Dict[str, Any]  # model name, token counts, finish reason, errors, etc.
    sources: List[VectorSearchResult]  # context passages the answer was grounded in
    confidence: float = 0.0  # heuristic confidence score
    reasoning: str = ""  # optional reasoning trace (consumed by TemplateGenerator)

class BaseGenerator(ABC):
    """Abstract base for answer generators; provides context/prompt helpers."""

    def __init__(self, config: GenerationConfig):
        self.config = config

    @abstractmethod
    def generate(self, 
                query: str, 
                context: List[VectorSearchResult],
                **kwargs) -> GenerationResult:
        """Produce an answer to *query* grounded in *context*."""
        pass

    def _build_context(self, results: List[VectorSearchResult]) -> str:
        """Render retrieved passages as numbered, source-attributed blocks."""
        rendered = [
            f"[文档{index + 1}] 来源: {item.metadata.get('source', '未知来源')}\n{item.content}\n"
            for index, item in enumerate(results)
        ]
        return "\n".join(rendered)

    def _build_prompt(self, query: str, context: str) -> str:
        """Assemble the grounded-QA prompt around *context* and *query*."""
        return f"""请基于以下上下文信息回答用户问题。如果上下文中没有相关信息,请明确说明。

上下文信息:
{context}

用户问题: {query}

请提供准确、有用的回答:"""

class OpenAIGenerator(BaseGenerator):
    """Generator backed by the OpenAI chat completion API.

    NOTE(review): this uses the pre-1.0 ``openai`` module interface
    (``openai.api_key`` / ``openai.ChatCompletion``); openai>=1.0 replaced it
    with ``OpenAI().chat.completions`` — confirm the pinned client version.
    """

    def __init__(self, config: GenerationConfig, api_key: str):
        super().__init__(config)
        self.api_key = api_key
        self.client = None
        self._initialize_client()

    def _initialize_client(self):
        """Configure the openai module with our key; raises if it is not installed."""
        try:
            import openai
            openai.api_key = self.api_key
            self.client = openai
            logging.info("已初始化OpenAI客户端")
        except ImportError:
            raise RuntimeError("请安装openai库: pip install openai")

    def generate(self, 
                query: str, 
                context: List[VectorSearchResult],
                **kwargs) -> GenerationResult:
        """Answer *query* grounded in *context* via the chat API.

        Never raises: API failures produce an error-describing
        GenerationResult with confidence 0.0.
        """
        try:
            # Build the grounded prompt from the retrieved passages.
            context_text = self._build_context(context)
            prompt = self._build_prompt(query, context_text)

            response = self.client.ChatCompletion.create(
                model=self.config.model_name,
                messages=[
                    {"role": "system", "content": "你是一个有用的AI助手,能够基于提供的上下文信息准确回答问题。"},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                frequency_penalty=self.config.frequency_penalty,
                presence_penalty=self.config.presence_penalty,
                # The API rejects an empty stop list, so pass None instead.
                stop=self.config.stop_sequences if self.config.stop_sequences else None
            )

            generated_content = response.choices[0].message.content.strip()
            confidence = self._calculate_confidence(response)

            return GenerationResult(
                content=generated_content,
                metadata={
                    "model": self.config.model_name,
                    "tokens_used": response.usage.total_tokens,
                    "finish_reason": response.choices[0].finish_reason
                },
                sources=context,
                confidence=confidence
            )

        except Exception as e:
            logging.error(f"OpenAI生成失败: {e}")
            return GenerationResult(
                content=f"抱歉,生成回答时出现错误: {str(e)}",
                metadata={"error": str(e)},
                sources=context,
                confidence=0.0
            )

    def _calculate_confidence(self, response) -> float:
        """Map the API finish_reason to a coarse heuristic confidence value."""
        try:
            finish_reason = response.choices[0].finish_reason
            if finish_reason == "stop":
                return 0.8  # completed naturally
            elif finish_reason == "length":
                return 0.6  # truncated by max_tokens
            else:
                return 0.4  # content filter or other abnormal termination
        # FIX: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt; narrowed to Exception.
        except Exception:
            return 0.5

class LocalLLMGenerator(BaseGenerator):
    """Generator backed by a local HuggingFace causal language model."""

    def __init__(self, config: GenerationConfig, model_path: str):
        super().__init__(config)
        self.model_path = model_path
        self.model = None
        self.tokenizer = None
        self._load_model()

    def _load_model(self):
        """Load tokenizer and fp16 model; leaves them as None (and logs) on failure."""
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            import torch

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )

            logging.info(f"已加载本地模型: {self.model_path}")

        except ImportError:
            raise RuntimeError("请安装transformers库: pip install transformers")
        except Exception as e:
            logging.error(f"加载本地模型失败: {e}")

    def generate(self, 
                query: str, 
                context: List[VectorSearchResult],
                **kwargs) -> GenerationResult:
        """Answer *query* with the local model; never raises, returns an error result instead."""
        if not self.model or not self.tokenizer:
            return GenerationResult(
                content="模型未正确加载",
                metadata={"error": "model_not_loaded"},
                sources=context,
                confidence=0.0
            )
        
        try:
            # BUG FIX: `torch` was previously imported only inside
            # _load_model's function scope, so `torch.no_grad()` below raised
            # NameError at generation time (silently caught by the except).
            import torch

            context_text = self._build_context(context)
            prompt = self._build_prompt(query, context_text)

            inputs = self.tokenizer.encode(prompt, return_tensors="pt")

            # Inference only — disable gradient tracking.
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    top_p=self.config.top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, skipping the prompt.
            generated_text = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:], 
                skip_special_tokens=True
            )

            return GenerationResult(
                content=generated_text.strip(),
                metadata={
                    "model": self.model_path,
                    "input_tokens": inputs.shape[1],
                    "output_tokens": outputs.shape[1] - inputs.shape[1]
                },
                sources=context,
                confidence=0.7  # fixed default; no calibrated estimate available
            )
            
        except Exception as e:
            logging.error(f"本地模型生成失败: {e}")
            return GenerationResult(
                content=f"生成回答时出现错误: {str(e)}",
                metadata={"error": str(e)},
                sources=context,
                confidence=0.0
            )

class TemplateGenerator(BaseGenerator):
    """Decorates another generator, reshaping its answer with output templates."""

    def __init__(self, config: GenerationConfig, base_generator: BaseGenerator):
        super().__init__(config)
        self.base_generator = base_generator
        self.templates = self._load_templates()

    def _load_templates(self) -> Dict[str, str]:
        """Return the built-in answer templates keyed by template type."""
        return {
            "factual": """基于提供的信息,{answer}

参考来源:
{sources}""",
            
            "analytical": """根据分析,{answer}

分析依据:
{reasoning}

参考资料:
{sources}""",
            
            "comparative": """通过比较可以看出,{answer}

对比要点:
{comparison_points}

数据来源:
{sources}""",
            
            "step_by_step": """解决步骤如下:

{steps}

详细说明:
{explanation}

参考文档:
{sources}"""
        }

    def generate(self, 
                query: str, 
                context: List[VectorSearchResult],
                template_type: str = "factual",
                **kwargs) -> GenerationResult:
        """Generate via the wrapped generator, then wrap the answer in the chosen template.

        Unknown template types (and formatting failures) leave the raw
        answer untouched.
        """
        result = self.base_generator.generate(query, context, **kwargs)

        template = self.templates.get(template_type)
        if template is not None:
            fill_values = {
                "answer": result.content,
                "sources": self._format_sources(context),
                "reasoning": result.reasoning,
                "comparison_points": self._extract_comparison_points(result.content),
                "steps": self._extract_steps(result.content),
                "explanation": result.content
            }

            try:
                result.content = template.format(**fill_values)
                result.metadata["template_type"] = template_type
            except Exception as e:
                # e.g. stray braces in the generated text break str.format.
                logging.warning(f"模板格式化失败: {e}")

        return result

    def _format_sources(self, context: List[VectorSearchResult]) -> str:
        """Render a numbered source list from the context metadata."""
        return "\n".join(
            f"{idx + 1}. {item.metadata.get('source', '未知来源')}"
            for idx, item in enumerate(context)
        )

    def _extract_comparison_points(self, content: str) -> str:
        """Collect lines containing comparison keywords as bullet points."""
        comparison_keywords = ["相比", "对比", "区别", "差异", "优势", "劣势"]
        bullets = [
            f"• {line.strip()}"
            for line in content.split("\n")
            if any(keyword in line for keyword in comparison_keywords)
        ]
        return "\n".join(bullets) if bullets else "无明显对比要点"

    def _extract_steps(self, content: str) -> str:
        """Pull numbered steps ("1." / "1、") out of the content, or return it unchanged."""
        import re
        matches = re.findall(r'(\d+[.、].*?)(?=\d+[.、]|$)', content, re.DOTALL)

        if not matches:
            return content  # no explicit numbered steps found
        return "\n".join(f"步骤 {item.strip()}" for item in matches)

本章总结

本章深入介绍了RAG系统的检索与生成技术:

核心要点

  1. 高级检索策略

    • 混合检索:结合语义和关键词检索
    • 查询扩展:提高检索召回率
    • 多阶段检索:逐步精化结果
  2. 重排序技术

    • 交叉编码器:精确的相关性评分
    • 特征重排序:多维度评估
    • LLM重排序:利用大模型理解能力
  3. 生成模型集成

    • 多种生成器支持(OpenAI、本地模型等)
    • 模板化生成:结构化输出
    • 置信度评估:质量控制

最佳实践

  1. 检索优化

    • 根据查询类型选择检索策略
    • 合理设置检索参数
    • 实施多阶段过滤
  2. 重排序策略

    • 结合多种重排序方法
    • 考虑计算成本和效果平衡
    • 针对特定领域调优
  3. 生成质量

    • 精心设计提示模板
    • 实施输出质量检查
    • 提供可追溯的来源信息

下一章我们将学习RAG系统的评估与优化,包括评估指标、A/B测试、性能优化等内容。