Chapter Overview
Deployment and operations take a RAG system from development to production, covering architecture design, containerized deployment, performance monitoring, and incident handling. This chapter walks through building a scalable, highly available deployment for a RAG system.
Learning Objectives
- Master the architectural design principles of RAG systems
- Learn containerization and orchestration techniques
- Understand load balancing and high-availability configuration
- Become familiar with monitoring, alerting, and log management
- Master performance optimization and troubleshooting techniques
1. System Architecture Design
1.1 Microservice Architecture
# src/deployment/service_manager.py - service manager
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
import json
from pathlib import Path
@dataclass
class ServiceConfig:
"""服务配置"""
name: str
host: str
port: int
health_check_path: str = "/health"
timeout: int = 30
max_retries: int = 3
environment: str = "production"
    resources: Optional[Dict[str, Any]] = None
def __post_init__(self):
if self.resources is None:
self.resources = {
"cpu": "1000m",
"memory": "2Gi",
"storage": "10Gi"
}
@dataclass
class ServiceStatus:
"""服务状态"""
name: str
status: str # running, stopped, error, unknown
last_check: datetime
response_time: float
    error_message: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
class BaseService(ABC):
"""服务基类"""
def __init__(self, config: ServiceConfig):
self.config = config
self.logger = logging.getLogger(f"service.{config.name}")
@abstractmethod
async def start(self) -> bool:
"""启动服务"""
pass
@abstractmethod
async def stop(self) -> bool:
"""停止服务"""
pass
@abstractmethod
async def health_check(self) -> ServiceStatus:
"""健康检查"""
pass
@abstractmethod
async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""处理请求"""
pass
class DocumentLoaderService(BaseService):
"""文档加载服务"""
def __init__(self, config: ServiceConfig):
super().__init__(config)
self.is_running = False
self.document_loaders = {}
async def start(self) -> bool:
"""启动文档加载服务"""
try:
# 初始化文档加载器
from ..document.loader import DocumentLoaderFactory
self.document_loaders = {
'text': DocumentLoaderFactory.create_loader('text'),
'pdf': DocumentLoaderFactory.create_loader('pdf'),
'docx': DocumentLoaderFactory.create_loader('docx')
}
self.is_running = True
self.logger.info(f"文档加载服务已启动,端口: {self.config.port}")
return True
except Exception as e:
self.logger.error(f"文档加载服务启动失败: {e}")
return False
async def stop(self) -> bool:
"""停止文档加载服务"""
try:
self.is_running = False
self.document_loaders.clear()
self.logger.info("文档加载服务已停止")
return True
except Exception as e:
self.logger.error(f"文档加载服务停止失败: {e}")
return False
async def health_check(self) -> ServiceStatus:
"""健康检查"""
start_time = datetime.now()
try:
if not self.is_running:
return ServiceStatus(
name=self.config.name,
status="stopped",
last_check=start_time,
response_time=0.0
)
# 简单的健康检查
test_result = len(self.document_loaders) > 0
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="running" if test_result else "error",
last_check=start_time,
response_time=response_time,
metadata={"loaders_count": len(self.document_loaders)}
)
except Exception as e:
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="error",
last_check=start_time,
response_time=response_time,
error_message=str(e)
)
async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""处理文档加载请求"""
try:
file_path = request.get('file_path')
file_type = request.get('file_type', 'text')
if not file_path:
return {"error": "缺少file_path参数"}
if file_type not in self.document_loaders:
return {"error": f"不支持的文件类型: {file_type}"}
loader = self.document_loaders[file_type]
documents = await asyncio.to_thread(loader.load, file_path)
return {
"success": True,
"documents": [{
"content": doc.content,
"metadata": doc.metadata
} for doc in documents]
}
except Exception as e:
self.logger.error(f"文档加载失败: {e}")
return {"error": str(e)}
class VectorSearchService(BaseService):
"""向量搜索服务"""
def __init__(self, config: ServiceConfig, vector_store_config: Dict[str, Any]):
super().__init__(config)
self.vector_store_config = vector_store_config
self.vector_store = None
self.is_running = False
async def start(self) -> bool:
"""启动向量搜索服务"""
try:
from ..vectorstore.vector_manager import VectorStoreFactory
self.vector_store = VectorStoreFactory.create_vector_store(
self.vector_store_config['type'],
self.vector_store_config
)
self.is_running = True
self.logger.info(f"向量搜索服务已启动,端口: {self.config.port}")
return True
except Exception as e:
self.logger.error(f"向量搜索服务启动失败: {e}")
return False
async def stop(self) -> bool:
"""停止向量搜索服务"""
try:
self.is_running = False
if self.vector_store:
# 清理资源
self.vector_store = None
self.logger.info("向量搜索服务已停止")
return True
except Exception as e:
self.logger.error(f"向量搜索服务停止失败: {e}")
return False
async def health_check(self) -> ServiceStatus:
"""健康检查"""
start_time = datetime.now()
try:
if not self.is_running or not self.vector_store:
return ServiceStatus(
name=self.config.name,
status="stopped",
last_check=start_time,
response_time=0.0
)
# 执行简单的搜索测试
test_query = "健康检查测试查询"
results = await asyncio.to_thread(
self.vector_store.search,
test_query,
k=1
)
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="running",
last_check=start_time,
response_time=response_time,
metadata={"test_results_count": len(results)}
)
except Exception as e:
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="error",
last_check=start_time,
response_time=response_time,
error_message=str(e)
)
async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""处理向量搜索请求"""
try:
query = request.get('query')
k = request.get('k', 5)
filters = request.get('filters')
if not query:
return {"error": "缺少query参数"}
results = await asyncio.to_thread(
self.vector_store.search,
query,
k=k,
filters=filters
)
return {
"success": True,
"results": [{
"chunk_id": result.chunk_id,
"content": result.content,
"score": result.score,
"metadata": result.metadata
} for result in results]
}
except Exception as e:
self.logger.error(f"向量搜索失败: {e}")
return {"error": str(e)}
class GenerationService(BaseService):
"""生成服务"""
def __init__(self, config: ServiceConfig, generation_config: Dict[str, Any]):
super().__init__(config)
self.generation_config = generation_config
self.generator = None
self.is_running = False
async def start(self) -> bool:
"""启动生成服务"""
try:
from ..generation.generator import GeneratorFactory
self.generator = GeneratorFactory.create_generator(
self.generation_config['type'],
self.generation_config
)
self.is_running = True
self.logger.info(f"生成服务已启动,端口: {self.config.port}")
return True
except Exception as e:
self.logger.error(f"生成服务启动失败: {e}")
return False
async def stop(self) -> bool:
"""停止生成服务"""
try:
self.is_running = False
self.generator = None
self.logger.info("生成服务已停止")
return True
except Exception as e:
self.logger.error(f"生成服务停止失败: {e}")
return False
async def health_check(self) -> ServiceStatus:
"""健康检查"""
start_time = datetime.now()
try:
if not self.is_running or not self.generator:
return ServiceStatus(
name=self.config.name,
status="stopped",
last_check=start_time,
response_time=0.0
)
# 执行简单的生成测试
test_prompt = "这是一个健康检查测试。"
result = await asyncio.to_thread(
self.generator.generate,
test_prompt,
max_length=10
)
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="running" if result.text else "error",
last_check=start_time,
response_time=response_time,
metadata={"test_output_length": len(result.text)}
)
except Exception as e:
response_time = (datetime.now() - start_time).total_seconds()
return ServiceStatus(
name=self.config.name,
status="error",
last_check=start_time,
response_time=response_time,
error_message=str(e)
)
async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""处理生成请求"""
try:
prompt = request.get('prompt')
max_length = request.get('max_length', 512)
temperature = request.get('temperature', 0.7)
if not prompt:
return {"error": "缺少prompt参数"}
result = await asyncio.to_thread(
self.generator.generate,
prompt,
max_length=max_length,
temperature=temperature
)
return {
"success": True,
"text": result.text,
"metadata": result.metadata
}
except Exception as e:
self.logger.error(f"文本生成失败: {e}")
return {"error": str(e)}
class ServiceManager:
"""服务管理器"""
def __init__(self):
self.services: Dict[str, BaseService] = {}
self.logger = logging.getLogger("service_manager")
def register_service(self, service: BaseService):
"""注册服务"""
self.services[service.config.name] = service
self.logger.info(f"已注册服务: {service.config.name}")
async def start_all_services(self) -> Dict[str, bool]:
"""启动所有服务"""
results = {}
for name, service in self.services.items():
try:
result = await service.start()
results[name] = result
if result:
self.logger.info(f"服务 {name} 启动成功")
else:
self.logger.error(f"服务 {name} 启动失败")
except Exception as e:
self.logger.error(f"服务 {name} 启动异常: {e}")
results[name] = False
return results
async def stop_all_services(self) -> Dict[str, bool]:
"""停止所有服务"""
results = {}
for name, service in self.services.items():
try:
result = await service.stop()
results[name] = result
if result:
self.logger.info(f"服务 {name} 停止成功")
else:
self.logger.error(f"服务 {name} 停止失败")
except Exception as e:
self.logger.error(f"服务 {name} 停止异常: {e}")
results[name] = False
return results
async def health_check_all(self) -> Dict[str, ServiceStatus]:
"""检查所有服务健康状态"""
results = {}
for name, service in self.services.items():
try:
status = await service.health_check()
results[name] = status
except Exception as e:
self.logger.error(f"服务 {name} 健康检查异常: {e}")
results[name] = ServiceStatus(
name=name,
status="error",
last_check=datetime.now(),
response_time=0.0,
error_message=str(e)
)
return results
async def get_service_status(self, service_name: str) -> Optional[ServiceStatus]:
"""获取指定服务状态"""
if service_name not in self.services:
return None
try:
return await self.services[service_name].health_check()
except Exception as e:
self.logger.error(f"获取服务 {service_name} 状态失败: {e}")
return ServiceStatus(
name=service_name,
status="error",
last_check=datetime.now(),
response_time=0.0,
error_message=str(e)
)
def get_service(self, service_name: str) -> Optional[BaseService]:
"""获取服务实例"""
return self.services.get(service_name)
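Before moving on to the gateway, here is a minimal sketch of how these pieces fit together: register a service instance, start everything, and poll health once. It assumes the module path above is importable; the start call can still fail if the underlying document loaders are unavailable, and the file name is hypothetical.

# examples/service_manager_demo.py - a minimal sketch (assumed file name)
import asyncio

from src.deployment.service_manager import (
    ServiceManager, ServiceConfig, DocumentLoaderService
)

async def demo():
    manager = ServiceManager()
    # Register a document-loader instance built from a plain config
    manager.register_service(DocumentLoaderService(
        ServiceConfig(name="document_loader", host="127.0.0.1", port=8001)
    ))
    # Start everything, then run one round of health checks
    print(await manager.start_all_services())  # e.g. {'document_loader': True}
    for name, status in (await manager.health_check_all()).items():
        print(name, status.status, f"{status.response_time:.3f}s")
    await manager.stop_all_services()

asyncio.run(demo())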
1.2 API Gateway
# src/deployment/api_gateway.py - API gateway
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass
import asyncio
import logging
import time
import json
from datetime import datetime, timedelta
from collections import defaultdict, deque
import hashlib
import jwt
from .service_manager import ServiceManager, ServiceStatus
@dataclass
class RateLimitConfig:
"""限流配置"""
requests_per_minute: int = 60
requests_per_hour: int = 1000
burst_size: int = 10
@dataclass
class AuthConfig:
"""认证配置"""
enabled: bool = True
jwt_secret: str = "your-secret-key"
token_expiry_hours: int = 24
allowed_origins: List[str] = None
def __post_init__(self):
if self.allowed_origins is None:
self.allowed_origins = ["*"]
@dataclass
class LoadBalancerConfig:
"""负载均衡配置"""
strategy: str = "round_robin" # round_robin, least_connections, weighted
health_check_interval: int = 30
failure_threshold: int = 3
recovery_threshold: int = 2
class RateLimiter:
"""限流器"""
def __init__(self, config: RateLimitConfig):
self.config = config
self.requests = defaultdict(lambda: {
'minute': deque(),
'hour': deque(),
'burst': deque()
})
def is_allowed(self, client_id: str) -> bool:
"""检查是否允许请求"""
now = time.time()
client_requests = self.requests[client_id]
# 清理过期记录
self._cleanup_expired_requests(client_requests, now)
# 检查突发限制
if len(client_requests['burst']) >= self.config.burst_size:
return False
# 检查分钟限制
if len(client_requests['minute']) >= self.config.requests_per_minute:
return False
# 检查小时限制
if len(client_requests['hour']) >= self.config.requests_per_hour:
return False
# 记录请求
client_requests['burst'].append(now)
client_requests['minute'].append(now)
client_requests['hour'].append(now)
return True
def _cleanup_expired_requests(self, client_requests: Dict[str, deque], now: float):
"""清理过期请求记录"""
# 清理突发记录(1秒窗口)
while (client_requests['burst'] and
now - client_requests['burst'][0] > 1):
client_requests['burst'].popleft()
# 清理分钟记录
while (client_requests['minute'] and
now - client_requests['minute'][0] > 60):
client_requests['minute'].popleft()
# 清理小时记录
while (client_requests['hour'] and
now - client_requests['hour'][0] > 3600):
client_requests['hour'].popleft()
class LoadBalancer:
"""负载均衡器"""
def __init__(self, config: LoadBalancerConfig):
self.config = config
self.service_instances = defaultdict(list)
self.service_health = defaultdict(dict)
self.round_robin_counters = defaultdict(int)
self.connection_counts = defaultdict(int)
self.logger = logging.getLogger("load_balancer")
def register_service_instance(self, service_name: str, instance_id: str, weight: float = 1.0):
"""注册服务实例"""
self.service_instances[service_name].append({
'id': instance_id,
'weight': weight,
'connections': 0
})
self.service_health[service_name][instance_id] = {
'healthy': True,
'failure_count': 0,
'last_check': datetime.now()
}
self.logger.info(f"已注册服务实例: {service_name}/{instance_id}")
def get_service_instance(self, service_name: str) -> Optional[str]:
"""获取服务实例"""
instances = self.service_instances.get(service_name, [])
if not instances:
return None
# 过滤健康的实例
healthy_instances = [
inst for inst in instances
if self.service_health[service_name][inst['id']]['healthy']
]
if not healthy_instances:
self.logger.warning(f"服务 {service_name} 没有健康的实例")
return None
# 根据策略选择实例
if self.config.strategy == "round_robin":
return self._round_robin_select(service_name, healthy_instances)
elif self.config.strategy == "least_connections":
return self._least_connections_select(healthy_instances)
elif self.config.strategy == "weighted":
return self._weighted_select(healthy_instances)
else:
return healthy_instances[0]['id']
def _round_robin_select(self, service_name: str, instances: List[Dict[str, Any]]) -> str:
"""轮询选择"""
counter = self.round_robin_counters[service_name]
selected = instances[counter % len(instances)]
self.round_robin_counters[service_name] = (counter + 1) % len(instances)
return selected['id']
def _least_connections_select(self, instances: List[Dict[str, Any]]) -> str:
"""最少连接选择"""
return min(instances, key=lambda x: x['connections'])['id']
def _weighted_select(self, instances: List[Dict[str, Any]]) -> str:
"""加权选择"""
import random
total_weight = sum(inst['weight'] for inst in instances)
if total_weight == 0:
return instances[0]['id']
rand_val = random.uniform(0, total_weight)
cumulative_weight = 0
for instance in instances:
cumulative_weight += instance['weight']
if rand_val <= cumulative_weight:
return instance['id']
return instances[-1]['id']
def update_instance_health(self, service_name: str, instance_id: str, healthy: bool):
"""更新实例健康状态"""
if service_name not in self.service_health:
return
if instance_id not in self.service_health[service_name]:
return
health_info = self.service_health[service_name][instance_id]
if healthy:
health_info['failure_count'] = 0
if not health_info['healthy']:
self.logger.info(f"服务实例恢复健康: {service_name}/{instance_id}")
health_info['healthy'] = True
else:
health_info['failure_count'] += 1
if health_info['failure_count'] >= self.config.failure_threshold:
if health_info['healthy']:
self.logger.warning(f"服务实例标记为不健康: {service_name}/{instance_id}")
health_info['healthy'] = False
health_info['last_check'] = datetime.now()
class APIGateway:
"""API网关"""
def __init__(self,
service_manager: ServiceManager,
rate_limit_config: RateLimitConfig = None,
auth_config: AuthConfig = None,
load_balancer_config: LoadBalancerConfig = None):
self.service_manager = service_manager
self.rate_limiter = RateLimiter(rate_limit_config or RateLimitConfig())
self.auth_config = auth_config or AuthConfig()
self.load_balancer = LoadBalancer(load_balancer_config or LoadBalancerConfig())
self.logger = logging.getLogger("api_gateway")
# 请求统计
self.request_stats = {
'total_requests': 0,
'successful_requests': 0,
'failed_requests': 0,
'rate_limited_requests': 0,
'auth_failed_requests': 0
}
async def process_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""处理API请求"""
start_time = time.time()
self.request_stats['total_requests'] += 1
try:
# 提取客户端信息
client_id = self._get_client_id(request)
# 限流检查
if not self.rate_limiter.is_allowed(client_id):
self.request_stats['rate_limited_requests'] += 1
return {
"error": "Rate limit exceeded",
"code": 429
}
# 认证检查
if self.auth_config.enabled:
auth_result = self._authenticate_request(request)
if not auth_result['valid']:
self.request_stats['auth_failed_requests'] += 1
return {
"error": "Authentication failed",
"code": 401,
"details": auth_result.get('error')
}
# 路由请求
service_name = request.get('service')
if not service_name:
return {
"error": "Missing service parameter",
"code": 400
}
# 获取服务实例
service = self.service_manager.get_service(service_name)
if not service:
return {
"error": f"Service not found: {service_name}",
"code": 404
}
# 处理请求
result = await service.process_request(request)
# 记录响应时间
response_time = time.time() - start_time
result['response_time'] = response_time
if 'error' not in result:
self.request_stats['successful_requests'] += 1
else:
self.request_stats['failed_requests'] += 1
return result
except Exception as e:
self.logger.error(f"请求处理失败: {e}")
self.request_stats['failed_requests'] += 1
return {
"error": "Internal server error",
"code": 500,
"details": str(e)
}
def _get_client_id(self, request: Dict[str, Any]) -> str:
"""获取客户端ID"""
# 优先使用API密钥
api_key = request.get('api_key')
if api_key:
return hashlib.md5(api_key.encode()).hexdigest()
# 使用IP地址
client_ip = request.get('client_ip', 'unknown')
return hashlib.md5(client_ip.encode()).hexdigest()
def _authenticate_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""认证请求"""
try:
# 检查API密钥
api_key = request.get('api_key')
if api_key:
# 简单的API密钥验证(实际应用中应该查询数据库)
if self._validate_api_key(api_key):
return {'valid': True}
else:
return {'valid': False, 'error': 'Invalid API key'}
# 检查JWT令牌
token = request.get('token')
if token:
try:
payload = jwt.decode(
token,
self.auth_config.jwt_secret,
algorithms=['HS256']
)
                    # jwt.decode validates 'exp' by default and raises
                    # jwt.ExpiredSignatureError (a subclass of InvalidTokenError)
return {'valid': True, 'user_id': payload.get('user_id')}
except jwt.InvalidTokenError as e:
return {'valid': False, 'error': f'Invalid token: {e}'}
return {'valid': False, 'error': 'No authentication provided'}
except Exception as e:
return {'valid': False, 'error': f'Authentication error: {e}'}
def _validate_api_key(self, api_key: str) -> bool:
"""验证API密钥"""
# 简化的验证逻辑(实际应用中应该查询数据库)
valid_keys = {
'test-key-123',
'prod-key-456',
'demo-key-789'
}
return api_key in valid_keys
def generate_token(self, user_id: str, additional_claims: Dict[str, Any] = None) -> str:
"""生成JWT令牌"""
payload = {
'user_id': user_id,
'iat': time.time(),
'exp': time.time() + (self.auth_config.token_expiry_hours * 3600)
}
if additional_claims:
payload.update(additional_claims)
return jwt.encode(payload, self.auth_config.jwt_secret, algorithm='HS256')
async def health_check(self) -> Dict[str, Any]:
"""网关健康检查"""
service_statuses = await self.service_manager.health_check_all()
overall_healthy = all(
status.status == "running"
for status in service_statuses.values()
)
return {
'gateway_status': 'healthy' if overall_healthy else 'degraded',
'services': {
name: {
'status': status.status,
'response_time': status.response_time,
'last_check': status.last_check.isoformat()
}
for name, status in service_statuses.items()
},
'request_stats': self.request_stats,
'timestamp': datetime.now().isoformat()
}
def get_metrics(self) -> Dict[str, Any]:
"""获取网关指标"""
total_requests = self.request_stats['total_requests']
return {
'total_requests': total_requests,
'success_rate': (
self.request_stats['successful_requests'] / total_requests
if total_requests > 0 else 0
),
'error_rate': (
self.request_stats['failed_requests'] / total_requests
if total_requests > 0 else 0
),
'rate_limit_rate': (
self.request_stats['rate_limited_requests'] / total_requests
if total_requests > 0 else 0
),
'auth_failure_rate': (
self.request_stats['auth_failed_requests'] / total_requests
if total_requests > 0 else 0
)
}
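The gateway can be exercised directly as an async callable, which is handy in tests before any HTTP layer is attached. A minimal sketch, assuming the classes above; with an empty ServiceManager the routed call deliberately comes back as a 404-style error, and the file name is hypothetical.

# examples/gateway_demo.py - driving the gateway in-process (assumed file name)
import asyncio

from src.deployment.service_manager import ServiceManager
from src.deployment.api_gateway import APIGateway, AuthConfig

async def demo():
    gateway = APIGateway(ServiceManager(),
                         auth_config=AuthConfig(jwt_secret="dev-only-secret"))
    token = gateway.generate_token("user-42")
    # The gateway expects a flat dict carrying routing, auth, and payload fields
    response = await gateway.process_request({
        "service": "vector_search",   # resolved via the ServiceManager
        "token": token,               # alternatively: "api_key": "test-key-123"
        "client_ip": "10.0.0.7",
        "query": "deployment best practices",
        "k": 3,
    })
    print(response)  # {'error': 'Service not found: vector_search', 'code': 404}

asyncio.run(demo())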
2. Containerized Deployment
2.1 Docker Configuration
# Dockerfile - container image for the RAG system
FROM python:3.9-slim
# 设置工作目录
WORKDIR /app
# 安装系统依赖
RUN apt-get update && apt-get install -y \
gcc \
g++ \
curl \
wget \
git \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY src/ ./src/
COPY config/ ./config/
COPY scripts/ ./scripts/
# 创建必要的目录
RUN mkdir -p /app/data /app/logs /app/models
# 设置环境变量
ENV PYTHONPATH=/app
ENV RAG_CONFIG_PATH=/app/config
ENV RAG_DATA_PATH=/app/data
ENV RAG_LOG_PATH=/app/logs
# 暴露端口
EXPOSE 8000
# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# 启动命令
CMD ["python", "-m", "src.api.main"]
# docker-compose.yml - Docker Compose configuration
version: '3.8'
services:
# API网关
api-gateway:
build: .
ports:
- "8000:8000"
environment:
- SERVICE_NAME=api-gateway
- LOG_LEVEL=INFO
- REDIS_URL=redis://redis:6379
- POSTGRES_URL=postgresql://postgres:password@postgres:5432/rag_db
depends_on:
- redis
- postgres
- document-loader
- vector-search
- generation
volumes:
- ./data:/app/data
- ./logs:/app/logs
- ./config:/app/config
networks:
- rag-network
restart: unless-stopped
# 文档加载服务
document-loader:
build: .
command: ["python", "-m", "src.services.document_loader"]
environment:
- SERVICE_NAME=document-loader
- SERVICE_PORT=8001
- LOG_LEVEL=INFO
volumes:
- ./data:/app/data
- ./logs:/app/logs
networks:
- rag-network
restart: unless-stopped
# 向量搜索服务
vector-search:
build: .
command: ["python", "-m", "src.services.vector_search"]
environment:
- SERVICE_NAME=vector-search
- SERVICE_PORT=8002
- LOG_LEVEL=INFO
- CHROMA_HOST=chroma
      - CHROMA_PORT=8000  # chroma's in-network port; 8003 is only the host mapping
depends_on:
- chroma
volumes:
- ./data:/app/data
- ./logs:/app/logs
networks:
- rag-network
restart: unless-stopped
# 生成服务
generation:
build: .
command: ["python", "-m", "src.services.generation"]
environment:
- SERVICE_NAME=generation
- SERVICE_PORT=8004
- LOG_LEVEL=INFO
- OPENAI_API_KEY=${OPENAI_API_KEY}
volumes:
- ./data:/app/data
- ./logs:/app/logs
- ./models:/app/models
networks:
- rag-network
restart: unless-stopped
deploy:
resources:
limits:
memory: 4G
reservations:
memory: 2G
# Chroma向量数据库
chroma:
image: chromadb/chroma:latest
ports:
- "8003:8000"
environment:
- CHROMA_SERVER_HOST=0.0.0.0
- CHROMA_SERVER_HTTP_PORT=8000
volumes:
- chroma-data:/chroma/chroma
networks:
- rag-network
restart: unless-stopped
# Redis缓存
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis-data:/data
networks:
- rag-network
restart: unless-stopped
command: redis-server --appendonly yes
# PostgreSQL数据库
postgres:
image: postgres:15-alpine
ports:
- "5432:5432"
environment:
- POSTGRES_DB=rag_db
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=password
volumes:
- postgres-data:/var/lib/postgresql/data
- ./scripts/init.sql:/docker-entrypoint-initdb.d/init.sql
networks:
- rag-network
restart: unless-stopped
# Nginx负载均衡
nginx:
image: nginx:alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
- ./nginx/ssl:/etc/nginx/ssl
depends_on:
- api-gateway
networks:
- rag-network
restart: unless-stopped
# Prometheus监控
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus-data:/prometheus
networks:
- rag-network
restart: unless-stopped
# Grafana可视化
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana-data:/var/lib/grafana
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
networks:
- rag-network
restart: unless-stopped
volumes:
chroma-data:
redis-data:
postgres-data:
prometheus-data:
grafana-data:
networks:
rag-network:
driver: bridge
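After `docker compose up -d`, it is worth a quick smoke test before wiring anything else in. A sketch, assuming the host port mappings above; the Chroma heartbeat path varies by server version, so adjust it to yours, and the file name is hypothetical.

# scripts/smoke_test.py - post-deploy health probe (assumed file name)
import sys
import requests

CHECKS = {
    "api-gateway": "http://localhost:8000/health",
    "chroma": "http://localhost:8003/api/v1/heartbeat",  # version-dependent path
    "grafana": "http://localhost:3000/api/health",
}

failed = False
for name, url in CHECKS.items():
    try:
        resp = requests.get(url, timeout=5)
        print(f"{name:12s} HTTP {resp.status_code}")
        failed = failed or resp.status_code >= 400
    except requests.RequestException as exc:
        print(f"{name:12s} DOWN ({exc})")
        failed = True

sys.exit(1 if failed else 0)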
2.2 Kubernetes Deployment
# k8s/namespace.yaml - namespace
apiVersion: v1
kind: Namespace
metadata:
name: rag-system
labels:
name: rag-system
---
# k8s/configmap.yaml - config map
apiVersion: v1
kind: ConfigMap
metadata:
name: rag-config
namespace: rag-system
data:
app.yaml: |
logging:
level: INFO
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
services:
document_loader:
port: 8001
workers: 2
vector_search:
port: 8002
workers: 2
chroma_url: "http://chroma-service:8000"
generation:
port: 8004
workers: 1
model_path: "/app/models"
api_gateway:
port: 8000
rate_limit:
requests_per_minute: 60
requests_per_hour: 1000
auth:
enabled: true
jwt_secret: "your-secret-key"
---
# k8s/secret.yaml - secrets
apiVersion: v1
kind: Secret
metadata:
name: rag-secrets
namespace: rag-system
type: Opaque
data:
openai-api-key: <base64-encoded-api-key>
postgres-password: <base64-encoded-password>
jwt-secret: <base64-encoded-secret>
---
# k8s/api-gateway-deployment.yaml - API gateway deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: api-gateway
namespace: rag-system
labels:
app: api-gateway
spec:
replicas: 3
selector:
matchLabels:
app: api-gateway
template:
metadata:
labels:
app: api-gateway
spec:
containers:
- name: api-gateway
image: rag-system:latest
ports:
- containerPort: 8000
env:
- name: SERVICE_NAME
value: "api-gateway"
- name: LOG_LEVEL
value: "INFO"
        - name: REDIS_URL
          value: "redis://redis-service:6379"
        # POSTGRES_PASSWORD must be defined before POSTGRES_URL so that the
        # $(POSTGRES_PASSWORD) reference below actually expands
        - name: POSTGRES_PASSWORD
          valueFrom:
            secretKeyRef:
              name: rag-secrets
              key: postgres-password
        - name: POSTGRES_URL
          value: "postgresql://postgres:$(POSTGRES_PASSWORD)@postgres-service:5432/rag_db"
volumeMounts:
- name: config-volume
mountPath: /app/config
- name: data-volume
mountPath: /app/data
resources:
requests:
memory: "512Mi"
cpu: "250m"
limits:
memory: "1Gi"
cpu: "500m"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /ready
port: 8000
initialDelaySeconds: 5
periodSeconds: 5
volumes:
- name: config-volume
configMap:
name: rag-config
- name: data-volume
persistentVolumeClaim:
claimName: rag-data-pvc
---
# k8s/api-gateway-service.yaml - API gateway service
apiVersion: v1
kind: Service
metadata:
name: api-gateway-service
namespace: rag-system
labels:
app: api-gateway
spec:
selector:
app: api-gateway
ports:
- protocol: TCP
port: 80
targetPort: 8000
type: LoadBalancer
---
# k8s/api-gateway-hpa.yaml - horizontal pod autoscaling
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: api-gateway-hpa
namespace: rag-system
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: api-gateway
minReplicas: 2
maxReplicas: 10
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: 80
---
# k8s/persistent-volume.yaml - persistent volume claim
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: rag-data-pvc
namespace: rag-system
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 100Gi
storageClassName: fast-ssd
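The liveness and readiness probes in the Deployment above assume the container actually serves `/health` and `/ready`. Here is a minimal sketch of those two endpoints using FastAPI (an assumption; any HTTP framework works), where readiness is tied to the downstream service checks from section 1.1:

# src/api/probes.py - probe endpoints matching the Deployment above (a sketch)
from fastapi import FastAPI, Response

from src.deployment.service_manager import ServiceManager

app = FastAPI()
service_manager = ServiceManager()  # wire in your real, populated instance

@app.get("/health")
async def health():
    # Liveness: the process is up and able to answer HTTP
    return {"status": "ok"}

@app.get("/ready")
async def ready(response: Response):
    # Readiness: only accept traffic once downstream services pass health checks
    statuses = await service_manager.health_check_all()
    ok = all(s.status == "running" for s in statuses.values())
    if not ok:
        response.status_code = 503
    return {"ready": ok}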
3. Monitoring and Alerting
3.1 Monitoring System
# src/monitoring/metrics_collector.py - metrics collector
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
import time
import logging
import threading
from collections import defaultdict, deque
from datetime import datetime, timedelta
import json
import psutil
import asyncio
@dataclass
class MetricPoint:
"""指标数据点"""
name: str
value: float
timestamp: float
labels: Dict[str, str] = field(default_factory=dict)
class MetricsCollector:
"""指标收集器"""
def __init__(self, collection_interval: int = 10):
self.collection_interval = collection_interval
self.metrics = defaultdict(lambda: deque(maxlen=1000)) # 保留最近1000个数据点
self.custom_metrics = defaultdict(float)
self.is_collecting = False
self.collection_thread = None
self.logger = logging.getLogger("metrics_collector")
def start_collection(self):
"""开始收集指标"""
if self.is_collecting:
return
self.is_collecting = True
self.collection_thread = threading.Thread(target=self._collection_loop)
self.collection_thread.daemon = True
self.collection_thread.start()
self.logger.info("指标收集已启动")
def stop_collection(self):
"""停止收集指标"""
self.is_collecting = False
if self.collection_thread:
self.collection_thread.join(timeout=5)
self.logger.info("指标收集已停止")
def _collection_loop(self):
"""收集循环"""
while self.is_collecting:
try:
self._collect_system_metrics()
self._collect_application_metrics()
time.sleep(self.collection_interval)
except Exception as e:
self.logger.error(f"指标收集失败: {e}")
time.sleep(self.collection_interval)
def _collect_system_metrics(self):
"""收集系统指标"""
timestamp = time.time()
# CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
self.add_metric("system_cpu_usage_percent", cpu_percent, timestamp)
# 内存使用情况
memory = psutil.virtual_memory()
self.add_metric("system_memory_usage_percent", memory.percent, timestamp)
self.add_metric("system_memory_used_bytes", memory.used, timestamp)
self.add_metric("system_memory_available_bytes", memory.available, timestamp)
# 磁盘使用情况
disk = psutil.disk_usage('/')
self.add_metric("system_disk_usage_percent", (disk.used / disk.total) * 100, timestamp)
self.add_metric("system_disk_used_bytes", disk.used, timestamp)
self.add_metric("system_disk_free_bytes", disk.free, timestamp)
# 网络I/O
network = psutil.net_io_counters()
self.add_metric("system_network_bytes_sent", network.bytes_sent, timestamp)
self.add_metric("system_network_bytes_recv", network.bytes_recv, timestamp)
# 进程信息
process = psutil.Process()
self.add_metric("process_cpu_percent", process.cpu_percent(), timestamp)
self.add_metric("process_memory_rss_bytes", process.memory_info().rss, timestamp)
self.add_metric("process_memory_vms_bytes", process.memory_info().vms, timestamp)
self.add_metric("process_num_threads", process.num_threads(), timestamp)
def _collect_application_metrics(self):
"""收集应用指标"""
timestamp = time.time()
# 添加自定义指标
for metric_name, value in self.custom_metrics.items():
self.add_metric(f"app_{metric_name}", value, timestamp)
def add_metric(self, name: str, value: float, timestamp: float = None, labels: Dict[str, str] = None):
"""添加指标"""
if timestamp is None:
timestamp = time.time()
metric_point = MetricPoint(
name=name,
value=value,
timestamp=timestamp,
labels=labels or {}
)
self.metrics[name].append(metric_point)
def increment_counter(self, name: str, value: float = 1.0, labels: Dict[str, str] = None):
"""递增计数器"""
self.custom_metrics[name] += value
self.add_metric(name, self.custom_metrics[name], labels=labels)
def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None):
"""设置仪表值"""
self.custom_metrics[name] = value
self.add_metric(name, value, labels=labels)
def record_histogram(self, name: str, value: float, labels: Dict[str, str] = None):
"""记录直方图值"""
self.add_metric(f"{name}_value", value, labels=labels)
# 简单的分位数计算
recent_values = [point.value for point in list(self.metrics[f"{name}_value"])[-100:]]
if recent_values:
import numpy as np
self.add_metric(f"{name}_p50", np.percentile(recent_values, 50), labels=labels)
self.add_metric(f"{name}_p95", np.percentile(recent_values, 95), labels=labels)
self.add_metric(f"{name}_p99", np.percentile(recent_values, 99), labels=labels)
def get_metrics(self, name: str = None, start_time: float = None, end_time: float = None) -> List[MetricPoint]:
"""获取指标数据"""
if name:
metrics_data = list(self.metrics.get(name, []))
else:
metrics_data = []
for metric_name, points in self.metrics.items():
metrics_data.extend(points)
# 时间过滤
if start_time or end_time:
filtered_data = []
for point in metrics_data:
if start_time and point.timestamp < start_time:
continue
if end_time and point.timestamp > end_time:
continue
filtered_data.append(point)
return filtered_data
return metrics_data
def get_metric_names(self) -> List[str]:
"""获取所有指标名称"""
return list(self.metrics.keys())
def export_prometheus_format(self) -> str:
"""导出Prometheus格式"""
lines = []
for metric_name, points in self.metrics.items():
if not points:
continue
latest_point = points[-1]
# 构建标签字符串
labels_str = ""
if latest_point.labels:
label_pairs = [f'{k}="{v}"' for k, v in latest_point.labels.items()]
labels_str = "{" + ",".join(label_pairs) + "}"
# 添加指标行
lines.append(f"{metric_name}{labels_str} {latest_point.value} {int(latest_point.timestamp * 1000)}")
return "\n".join(lines)
def export_json_format(self) -> str:
"""导出JSON格式"""
data = {}
for metric_name, points in self.metrics.items():
data[metric_name] = [
{
"value": point.value,
"timestamp": point.timestamp,
"labels": point.labels
}
for point in points
]
return json.dumps(data, indent=2)
class AlertManager:
"""告警管理器"""
def __init__(self, metrics_collector: MetricsCollector):
self.metrics_collector = metrics_collector
self.alert_rules = []
self.active_alerts = {}
self.alert_history = deque(maxlen=1000)
self.notification_handlers = []
self.logger = logging.getLogger("alert_manager")
def add_alert_rule(self,
name: str,
metric_name: str,
condition: str, # ">", "<", ">=", "<=", "=="
threshold: float,
duration: int = 60, # 持续时间(秒)
severity: str = "warning"):
"""添加告警规则"""
rule = {
'name': name,
'metric_name': metric_name,
'condition': condition,
'threshold': threshold,
'duration': duration,
'severity': severity,
'triggered_at': None
}
self.alert_rules.append(rule)
self.logger.info(f"已添加告警规则: {name}")
def add_notification_handler(self, handler: callable):
"""添加通知处理器"""
self.notification_handlers.append(handler)
def check_alerts(self):
"""检查告警"""
current_time = time.time()
for rule in self.alert_rules:
try:
self._check_single_rule(rule, current_time)
except Exception as e:
self.logger.error(f"检查告警规则 {rule['name']} 失败: {e}")
def _check_single_rule(self, rule: Dict[str, Any], current_time: float):
"""检查单个告警规则"""
metric_name = rule['metric_name']
metrics = self.metrics_collector.get_metrics(
metric_name,
start_time=current_time - rule['duration']
)
if not metrics:
return
# 检查是否满足条件
latest_value = metrics[-1].value
condition_met = self._evaluate_condition(
latest_value,
rule['condition'],
rule['threshold']
)
alert_key = rule['name']
if condition_met:
if alert_key not in self.active_alerts:
# 新告警
if rule['triggered_at'] is None:
rule['triggered_at'] = current_time
elif current_time - rule['triggered_at'] >= rule['duration']:
# 持续时间满足,触发告警
alert = {
'name': rule['name'],
'metric_name': metric_name,
'current_value': latest_value,
'threshold': rule['threshold'],
'severity': rule['severity'],
'triggered_at': current_time,
'status': 'firing'
}
self.active_alerts[alert_key] = alert
self.alert_history.append(alert.copy())
self._send_notification(alert)
self.logger.warning(f"告警触发: {rule['name']} - {latest_value} {rule['condition']} {rule['threshold']}")
else:
# 条件不满足,重置触发时间
rule['triggered_at'] = None
# 如果之前有活跃告警,则解除
if alert_key in self.active_alerts:
resolved_alert = self.active_alerts[alert_key].copy()
resolved_alert['status'] = 'resolved'
resolved_alert['resolved_at'] = current_time
del self.active_alerts[alert_key]
self.alert_history.append(resolved_alert)
self._send_notification(resolved_alert)
self.logger.info(f"告警解除: {rule['name']}")
def _evaluate_condition(self, value: float, condition: str, threshold: float) -> bool:
"""评估条件"""
if condition == ">":
return value > threshold
elif condition == "<":
return value < threshold
elif condition == ">=":
return value >= threshold
elif condition == "<=":
return value <= threshold
elif condition == "==":
return abs(value - threshold) < 1e-6
else:
return False
def _send_notification(self, alert: Dict[str, Any]):
"""发送通知"""
for handler in self.notification_handlers:
try:
handler(alert)
except Exception as e:
self.logger.error(f"发送告警通知失败: {e}")
def get_active_alerts(self) -> List[Dict[str, Any]]:
"""获取活跃告警"""
return list(self.active_alerts.values())
def get_alert_history(self, limit: int = 100) -> List[Dict[str, Any]]:
"""获取告警历史"""
return list(self.alert_history)[-limit:]
def email_notification_handler(alert: Dict[str, Any]):
"""邮件通知处理器"""
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
# 邮件配置(实际使用时应从配置文件读取)
smtp_server = "smtp.gmail.com"
smtp_port = 587
sender_email = "alerts@yourcompany.com"
sender_password = "your-password"
recipient_emails = ["admin@yourcompany.com"]
try:
# 创建邮件内容
msg = MIMEMultipart()
msg['From'] = sender_email
msg['To'] = ", ".join(recipient_emails)
if alert['status'] == 'firing':
msg['Subject'] = f"[ALERT] {alert['name']} - {alert['severity'].upper()}"
body = f"""
告警名称: {alert['name']}
指标名称: {alert['metric_name']}
当前值: {alert['current_value']}
阈值: {alert['threshold']}
严重程度: {alert['severity']}
触发时间: {datetime.fromtimestamp(alert['triggered_at']).strftime('%Y-%m-%d %H:%M:%S')}
请及时处理此告警。
"""
else:
msg['Subject'] = f"[RESOLVED] {alert['name']}"
body = f"""
告警名称: {alert['name']}
状态: 已解除
解除时间: {datetime.fromtimestamp(alert['resolved_at']).strftime('%Y-%m-%d %H:%M:%S')}
告警已自动解除。
"""
msg.attach(MIMEText(body, 'plain'))
# 发送邮件
server = smtplib.SMTP(smtp_server, smtp_port)
server.starttls()
server.login(sender_email, sender_password)
text = msg.as_string()
server.sendmail(sender_email, recipient_emails, text)
server.quit()
print(f"告警邮件已发送: {alert['name']}")
except Exception as e:
print(f"发送告警邮件失败: {e}")
def slack_notification_handler(alert: Dict[str, Any]):
"""Slack通知处理器"""
import requests
# Slack Webhook URL(实际使用时应从配置文件读取)
webhook_url = "https://hooks.slack.com/services/YOUR/SLACK/WEBHOOK"
try:
if alert['status'] == 'firing':
color = "danger" if alert['severity'] == "critical" else "warning"
title = f"🚨 Alert: {alert['name']}"
text = f"Metric: {alert['metric_name']}\nCurrent Value: {alert['current_value']}\nThreshold: {alert['threshold']}"
else:
color = "good"
title = f"✅ Resolved: {alert['name']}"
text = f"Alert has been resolved at {datetime.fromtimestamp(alert['resolved_at']).strftime('%Y-%m-%d %H:%M:%S')}"
payload = {
"attachments": [
{
"color": color,
"title": title,
"text": text,
"ts": alert.get('triggered_at', alert.get('resolved_at'))
}
]
}
response = requests.post(webhook_url, json=payload)
response.raise_for_status()
print(f"Slack通知已发送: {alert['name']}")
except Exception as e:
print(f"发送Slack通知失败: {e}")
3.2 Log Management
# src/monitoring/log_manager.py - log manager
import logging
import logging.handlers
from typing import Dict, Any, Optional
from pathlib import Path
import json
from datetime import datetime
import gzip
import os
class StructuredFormatter(logging.Formatter):
"""结构化日志格式化器"""
def format(self, record):
log_entry = {
'timestamp': datetime.fromtimestamp(record.created).isoformat(),
'level': record.levelname,
'logger': record.name,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno
}
# 添加额外字段
if hasattr(record, 'user_id'):
log_entry['user_id'] = record.user_id
if hasattr(record, 'request_id'):
log_entry['request_id'] = record.request_id
if hasattr(record, 'service_name'):
log_entry['service_name'] = record.service_name
if record.exc_info:
log_entry['exception'] = self.formatException(record.exc_info)
return json.dumps(log_entry, ensure_ascii=False)
class LogManager:
"""日志管理器"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.loggers = {}
self._setup_logging()
def _setup_logging(self):
"""设置日志配置"""
# 创建日志目录
log_dir = Path(self.config.get('log_dir', './logs'))
log_dir.mkdir(parents=True, exist_ok=True)
# 根日志配置
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, self.config.get('level', 'INFO')))
# 清除现有处理器
root_logger.handlers.clear()
# 控制台处理器
if self.config.get('console_enabled', True):
console_handler = logging.StreamHandler()
console_handler.setLevel(getattr(logging, self.config.get('console_level', 'INFO')))
if self.config.get('structured_logging', False):
console_handler.setFormatter(StructuredFormatter())
else:
console_formatter = logging.Formatter(
self.config.get('format',
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
# 文件处理器
if self.config.get('file_enabled', True):
file_handler = logging.handlers.RotatingFileHandler(
log_dir / 'application.log',
maxBytes=self.config.get('max_file_size', 10 * 1024 * 1024), # 10MB
backupCount=self.config.get('backup_count', 5)
)
file_handler.setLevel(getattr(logging, self.config.get('file_level', 'DEBUG')))
if self.config.get('structured_logging', False):
file_handler.setFormatter(StructuredFormatter())
else:
file_formatter = logging.Formatter(
self.config.get('format',
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
file_handler.setFormatter(file_formatter)
root_logger.addHandler(file_handler)
# 错误日志处理器
if self.config.get('error_file_enabled', True):
error_handler = logging.handlers.RotatingFileHandler(
log_dir / 'error.log',
maxBytes=self.config.get('max_file_size', 10 * 1024 * 1024),
backupCount=self.config.get('backup_count', 5)
)
error_handler.setLevel(logging.ERROR)
if self.config.get('structured_logging', False):
error_handler.setFormatter(StructuredFormatter())
else:
error_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d'
)
error_handler.setFormatter(error_formatter)
root_logger.addHandler(error_handler)
def get_logger(self, name: str) -> logging.Logger:
"""获取日志器"""
if name not in self.loggers:
logger = logging.getLogger(name)
self.loggers[name] = logger
return self.loggers[name]
def log_request(self, request_id: str, method: str, path: str,
status_code: int, response_time: float, user_id: str = None):
"""记录请求日志"""
logger = self.get_logger('request')
extra = {
'request_id': request_id,
'method': method,
'path': path,
'status_code': status_code,
'response_time': response_time
}
if user_id:
extra['user_id'] = user_id
logger.info(f"{method} {path} - {status_code} - {response_time:.3f}s", extra=extra)
def log_error(self, error: Exception, context: Dict[str, Any] = None):
"""记录错误日志"""
logger = self.get_logger('error')
extra = context or {}
logger.error(f"Error occurred: {str(error)}", exc_info=True, extra=extra)
def compress_old_logs(self, days_old: int = 7):
"""压缩旧日志文件"""
log_dir = Path(self.config.get('log_dir', './logs'))
for log_file in log_dir.glob('*.log.*'):
if log_file.suffix != '.gz':
# 检查文件年龄
file_age = datetime.now().timestamp() - log_file.stat().st_mtime
if file_age > days_old * 24 * 3600:
# 压缩文件
compressed_file = log_file.with_suffix(log_file.suffix + '.gz')
with open(log_file, 'rb') as f_in:
with gzip.open(compressed_file, 'wb') as f_out:
f_out.writelines(f_in)
# 删除原文件
log_file.unlink()
print(f"已压缩日志文件: {log_file} -> {compressed_file}")
def cleanup_old_logs(self, days_old: int = 30):
"""清理旧日志文件"""
log_dir = Path(self.config.get('log_dir', './logs'))
for log_file in log_dir.glob('*.log.*.gz'):
file_age = datetime.now().timestamp() - log_file.stat().st_mtime
if file_age > days_old * 24 * 3600:
log_file.unlink()
print(f"已删除旧日志文件: {log_file}")
4. Performance Optimization
4.1 Caching Strategy
# src/optimization/cache_manager.py - cache manager
from typing import Any, Optional, Dict, List
from abc import ABC, abstractmethod
import time
import json
import hashlib
from dataclasses import dataclass
import redis
import pickle
from datetime import datetime, timedelta
@dataclass
class CacheEntry:
"""缓存条目"""
key: str
value: Any
created_at: float
expires_at: Optional[float] = None
access_count: int = 0
    last_accessed: Optional[float] = None
class BaseCacheBackend(ABC):
"""缓存后端基类"""
@abstractmethod
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
pass
@abstractmethod
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""设置缓存值"""
pass
@abstractmethod
def delete(self, key: str) -> bool:
"""删除缓存值"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""检查键是否存在"""
pass
@abstractmethod
def clear(self) -> bool:
"""清空缓存"""
pass
class MemoryCacheBackend(BaseCacheBackend):
"""内存缓存后端"""
def __init__(self, max_size: int = 1000):
self.max_size = max_size
self.cache: Dict[str, CacheEntry] = {}
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
if key not in self.cache:
return None
entry = self.cache[key]
# 检查是否过期
if entry.expires_at and time.time() > entry.expires_at:
del self.cache[key]
return None
# 更新访问统计
entry.access_count += 1
entry.last_accessed = time.time()
return entry.value
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""设置缓存值"""
try:
# 检查缓存大小限制
if len(self.cache) >= self.max_size and key not in self.cache:
self._evict_lru()
expires_at = None
if ttl:
expires_at = time.time() + ttl
entry = CacheEntry(
key=key,
value=value,
created_at=time.time(),
expires_at=expires_at,
last_accessed=time.time()
)
self.cache[key] = entry
return True
except Exception:
return False
def delete(self, key: str) -> bool:
"""删除缓存值"""
if key in self.cache:
del self.cache[key]
return True
return False
def exists(self, key: str) -> bool:
"""检查键是否存在"""
return key in self.cache
def clear(self) -> bool:
"""清空缓存"""
self.cache.clear()
return True
def _evict_lru(self):
"""淘汰最近最少使用的条目"""
if not self.cache:
return
# 找到最近最少访问的条目
lru_key = min(
self.cache.keys(),
key=lambda k: self.cache[k].last_accessed or 0
)
del self.cache[lru_key]
class RedisCacheBackend(BaseCacheBackend):
"""Redis缓存后端"""
def __init__(self, redis_url: str = "redis://localhost:6379",
key_prefix: str = "rag_cache:"):
self.redis_client = redis.from_url(redis_url)
self.key_prefix = key_prefix
def _make_key(self, key: str) -> str:
"""生成完整的键名"""
return f"{self.key_prefix}{key}"
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
try:
full_key = self._make_key(key)
data = self.redis_client.get(full_key)
if data is None:
return None
return pickle.loads(data)
except Exception:
return None
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""设置缓存值"""
try:
full_key = self._make_key(key)
data = pickle.dumps(value)
if ttl:
return self.redis_client.setex(full_key, ttl, data)
else:
return self.redis_client.set(full_key, data)
except Exception:
return False
def delete(self, key: str) -> bool:
"""删除缓存值"""
try:
full_key = self._make_key(key)
return bool(self.redis_client.delete(full_key))
except Exception:
return False
def exists(self, key: str) -> bool:
"""检查键是否存在"""
try:
full_key = self._make_key(key)
return bool(self.redis_client.exists(full_key))
except Exception:
return False
def clear(self) -> bool:
"""清空缓存"""
try:
pattern = f"{self.key_prefix}*"
keys = self.redis_client.keys(pattern)
if keys:
return bool(self.redis_client.delete(*keys))
return True
except Exception:
return False
class CacheManager:
"""缓存管理器"""
def __init__(self, backend: BaseCacheBackend):
self.backend = backend
self.stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'deletes': 0
}
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
value = self.backend.get(key)
if value is not None:
self.stats['hits'] += 1
else:
self.stats['misses'] += 1
return value
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""设置缓存值"""
result = self.backend.set(key, value, ttl)
if result:
self.stats['sets'] += 1
return result
def delete(self, key: str) -> bool:
"""删除缓存值"""
result = self.backend.delete(key)
if result:
self.stats['deletes'] += 1
return result
def get_or_set(self, key: str, factory_func: callable,
ttl: Optional[int] = None) -> Any:
"""获取或设置缓存值"""
value = self.get(key)
if value is None:
value = factory_func()
self.set(key, value, ttl)
return value
def cache_key(self, *args, **kwargs) -> str:
"""生成缓存键"""
# 将参数转换为字符串并生成哈希
key_data = {
'args': args,
'kwargs': sorted(kwargs.items())
}
key_str = json.dumps(key_data, sort_keys=True, default=str)
return hashlib.md5(key_str.encode()).hexdigest()
def get_stats(self) -> Dict[str, Any]:
"""获取缓存统计"""
total_requests = self.stats['hits'] + self.stats['misses']
hit_rate = self.stats['hits'] / total_requests if total_requests > 0 else 0
return {
'hits': self.stats['hits'],
'misses': self.stats['misses'],
'sets': self.stats['sets'],
'deletes': self.stats['deletes'],
'hit_rate': hit_rate,
'total_requests': total_requests
}
def clear_stats(self):
"""清空统计"""
self.stats = {
'hits': 0,
'misses': 0,
'sets': 0,
'deletes': 0
}
def cache_decorator(cache_manager: CacheManager, ttl: Optional[int] = None,
key_func: Optional[callable] = None):
"""缓存装饰器"""
def decorator(func):
def wrapper(*args, **kwargs):
# 生成缓存键
if key_func:
cache_key = key_func(*args, **kwargs)
else:
cache_key = f"{func.__name__}:{cache_manager.cache_key(*args, **kwargs)}"
# 尝试从缓存获取
result = cache_manager.get(cache_key)
if result is None:
# 缓存未命中,执行函数
result = func(*args, **kwargs)
# 存储到缓存
cache_manager.set(cache_key, result, ttl)
return result
return wrapper
return decorator
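The decorator makes caching transparent at the call site. A minimal sketch with the in-memory backend and a stand-in for a slow embedding call (the function, its values, and the file name are hypothetical). One limitation of this get-or-recompute pattern: a function whose legitimate result is None is never cached, because `get` uses None as its miss marker.

# examples/cache_demo.py - memoizing an expensive call (assumed file name)
from src.optimization.cache_manager import (
    CacheManager, MemoryCacheBackend, cache_decorator
)

cache = CacheManager(MemoryCacheBackend(max_size=500))

@cache_decorator(cache, ttl=600)
def embed_query(text: str) -> list:
    # Stand-in for a slow embedding model call
    return [float(len(text)), float(text.count(" ") + 1)]

embed_query("what is rag")  # miss: computed, then stored for 10 minutes
embed_query("what is rag")  # hit: served from the cache
print(cache.get_stats())    # e.g. {'hits': 1, 'misses': 1, 'sets': 1, ...}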
4.2 Connection Pool Management
# src/optimization/connection_pool.py - connection pool manager
from typing import Any, Optional, Dict, List, Callable
from abc import ABC, abstractmethod
import threading
import time
import queue
from dataclasses import dataclass
from contextlib import contextmanager
import logging
@dataclass
class ConnectionConfig:
"""连接配置"""
host: str
port: int
username: Optional[str] = None
password: Optional[str] = None
database: Optional[str] = None
max_connections: int = 10
min_connections: int = 2
connection_timeout: int = 30
idle_timeout: int = 300
max_retries: int = 3
retry_delay: float = 1.0
@dataclass
class ConnectionInfo:
"""连接信息"""
connection: Any
created_at: float
last_used: float
is_active: bool = True
use_count: int = 0
class BaseConnectionPool(ABC):
"""连接池基类"""
def __init__(self, config: ConnectionConfig):
self.config = config
self.connections: queue.Queue = queue.Queue(maxsize=config.max_connections)
self.active_connections: Dict[int, ConnectionInfo] = {}
self.lock = threading.RLock()
self.logger = logging.getLogger(self.__class__.__name__)
self._initialize_pool()
@abstractmethod
def _create_connection(self) -> Any:
"""创建新连接"""
pass
@abstractmethod
def _validate_connection(self, connection: Any) -> bool:
"""验证连接是否有效"""
pass
@abstractmethod
def _close_connection(self, connection: Any):
"""关闭连接"""
pass
def _initialize_pool(self):
"""初始化连接池"""
for _ in range(self.config.min_connections):
try:
connection = self._create_connection()
connection_info = ConnectionInfo(
connection=connection,
created_at=time.time(),
last_used=time.time()
)
self.connections.put(connection_info)
self.logger.info("已创建初始连接")
except Exception as e:
self.logger.error(f"创建初始连接失败: {e}")
@contextmanager
def get_connection(self):
"""获取连接(上下文管理器)"""
connection_info = None
try:
connection_info = self._get_connection()
yield connection_info.connection
finally:
if connection_info:
self._return_connection(connection_info)
    def _get_connection(self) -> ConnectionInfo:
        """Acquire a connection, creating one if the pool allows."""
        with self.lock:
            # Try to take an existing connection from the pool
            try:
                connection_info = self.connections.get_nowait()
                # Validate it before handing it out
                if self._validate_connection(connection_info.connection):
                    connection_info.last_used = time.time()
                    connection_info.use_count += 1
                    self.active_connections[id(connection_info)] = connection_info
                    return connection_info
                else:
                    # Stale connection: close it and fall through to create a new one
                    self._close_connection(connection_info.connection)
            except queue.Empty:
                pass
            # Create a new connection if we are under the limit
            if len(self.active_connections) < self.config.max_connections:
                try:
                    connection = self._create_connection()
                    connection_info = ConnectionInfo(
                        connection=connection,
                        created_at=time.time(),
                        last_used=time.time(),
                        use_count=1
                    )
                    self.active_connections[id(connection_info)] = connection_info
                    self.logger.info("已创建新连接")
                    return connection_info
                except Exception as e:
                    self.logger.error(f"创建连接失败: {e}")
                    raise
        # Wait for a returned connection OUTSIDE the lock: blocking while holding
        # it would deadlock, since _return_connection needs the same lock.
        try:
            connection_info = self.connections.get(timeout=self.config.connection_timeout)
        except queue.Empty:
            raise TimeoutError("获取连接超时")
        with self.lock:
            connection_info.last_used = time.time()
            connection_info.use_count += 1
            self.active_connections[id(connection_info)] = connection_info
        return connection_info
def _return_connection(self, connection_info: ConnectionInfo):
"""归还连接"""
with self.lock:
connection_id = id(connection_info)
if connection_id in self.active_connections:
del self.active_connections[connection_id]
# 检查连接是否仍然有效
if (connection_info.is_active and
self._validate_connection(connection_info.connection)):
try:
self.connections.put_nowait(connection_info)
except queue.Full:
# 连接池已满,关闭连接
self._close_connection(connection_info.connection)
self.logger.info("连接池已满,关闭多余连接")
else:
# 连接无效,关闭
self._close_connection(connection_info.connection)
self.logger.info("关闭无效连接")
def cleanup_idle_connections(self):
"""清理空闲连接"""
with self.lock:
current_time = time.time()
connections_to_remove = []
# 检查池中的连接
temp_connections = []
while not self.connections.empty():
try:
connection_info = self.connections.get_nowait()
# 检查是否超过空闲时间
if (current_time - connection_info.last_used > self.config.idle_timeout):
connections_to_remove.append(connection_info)
else:
temp_connections.append(connection_info)
except queue.Empty:
break
# 将有效连接放回池中
for connection_info in temp_connections:
try:
self.connections.put_nowait(connection_info)
except queue.Full:
connections_to_remove.append(connection_info)
# 关闭需要移除的连接
for connection_info in connections_to_remove:
self._close_connection(connection_info.connection)
self.logger.info("已清理空闲连接")
def get_pool_stats(self) -> Dict[str, Any]:
"""获取连接池统计信息"""
with self.lock:
return {
'total_connections': self.connections.qsize() + len(self.active_connections),
'available_connections': self.connections.qsize(),
'active_connections': len(self.active_connections),
'max_connections': self.config.max_connections,
'min_connections': self.config.min_connections
}
def close_all_connections(self):
"""关闭所有连接"""
with self.lock:
# 关闭池中的连接
while not self.connections.empty():
try:
connection_info = self.connections.get_nowait()
self._close_connection(connection_info.connection)
except queue.Empty:
break
# 关闭活跃连接
for connection_info in self.active_connections.values():
connection_info.is_active = False
self._close_connection(connection_info.connection)
self.active_connections.clear()
self.logger.info("已关闭所有连接")
class DatabaseConnectionPool(BaseConnectionPool):
"""数据库连接池"""
def __init__(self, config: ConnectionConfig, db_driver: str = "postgresql"):
self.db_driver = db_driver
super().__init__(config)
def _create_connection(self) -> Any:
"""创建数据库连接"""
if self.db_driver == "postgresql":
import psycopg2
return psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.username,
password=self.config.password,
database=self.config.database
)
elif self.db_driver == "mysql":
import pymysql
return pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.username,
password=self.config.password,
database=self.config.database
)
else:
raise ValueError(f"不支持的数据库驱动: {self.db_driver}")
def _validate_connection(self, connection: Any) -> bool:
"""验证数据库连接"""
try:
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchone()
cursor.close()
return True
except Exception:
return False
def _close_connection(self, connection: Any):
"""关闭数据库连接"""
try:
connection.close()
except Exception as e:
self.logger.error(f"关闭数据库连接失败: {e}")
class RedisConnectionPool(BaseConnectionPool):
"""Redis连接池"""
def _create_connection(self) -> Any:
"""创建Redis连接"""
import redis
return redis.Redis(
host=self.config.host,
port=self.config.port,
password=self.config.password,
db=self.config.database or 0
)
def _validate_connection(self, connection: Any) -> bool:
"""验证Redis连接"""
try:
connection.ping()
return True
except Exception:
return False
def _close_connection(self, connection: Any):
"""关闭Redis连接"""
try:
connection.close()
except Exception as e:
self.logger.error(f"关闭Redis连接失败: {e}")
4.3 Usage Example
# examples/deployment_example.py - deployment example
import asyncio
import time
from datetime import datetime
from typing import Dict, Any
# Import the components built in this chapter
from src.deployment.service_manager import (
    ServiceManager, ServiceConfig,
    DocumentLoaderService, VectorSearchService, GenerationService
)
from src.deployment.api_gateway import APIGateway, RateLimitConfig, AuthConfig
# AlertManager and the email handler live in metrics_collector.py (see 3.1)
from src.monitoring.metrics_collector import (
    MetricsCollector, AlertManager, email_notification_handler
)
from src.monitoring.log_manager import LogManager
from src.optimization.cache_manager import CacheManager, RedisCacheBackend
from src.optimization.connection_pool import DatabaseConnectionPool, ConnectionConfig
def setup_logging():
    """Configure logging."""
    log_config = {
        'level': 'INFO',
        'console_enabled': True,
        'file_enabled': True,
        'structured_logging': True,
        'log_dir': './logs'
    }
    return LogManager(log_config)
def setup_cache():
    """Configure the cache."""
    # Use Redis as the cache backend
    redis_backend = RedisCacheBackend(
        redis_url="redis://localhost:6379",
        key_prefix="rag_prod:"
    )
    return CacheManager(redis_backend)
def setup_database_pool():
    """Configure the database connection pool."""
    db_config = ConnectionConfig(
        host="localhost",
        port=5432,
        username="rag_user",
        password="rag_password",
        database="rag_db",
        max_connections=20,
        min_connections=5
    )
    return DatabaseConnectionPool(db_config, "postgresql")
def setup_monitoring():
    """Configure monitoring and alerting."""
    # Metrics collector
    metrics_collector = MetricsCollector()
    # Alert manager
    alert_manager = AlertManager(metrics_collector)
    # Alert rules (metric names must match what MetricsCollector actually emits)
    alert_manager.add_alert_rule(
        name="high_cpu_usage",
        metric_name="system_cpu_usage_percent",
        condition=">",
        threshold=80.0,
        duration=300,  # 5 minutes
        severity="warning"
    )
    alert_manager.add_alert_rule(
        name="high_memory_usage",
        metric_name="system_memory_usage_percent",
        condition=">",
        threshold=90.0,
        duration=180,  # 3 minutes
        severity="critical"
    )
    alert_manager.add_alert_rule(
        name="high_response_time",
        metric_name="api_response_time_p95",
        condition=">",
        threshold=2.0,
        duration=120,  # 2 minutes
        severity="warning"
    )
    # Notification handlers
    alert_manager.add_notification_handler(email_notification_handler)
    return metrics_collector, alert_manager
def setup_api_gateway(service_manager):
    """Configure the API gateway."""
    # Rate limiting
    rate_limit_config = RateLimitConfig(
        requests_per_minute=1000,
        burst_size=100
    )
    # Authentication
    auth_config = AuthConfig(
        jwt_secret="your-secret-key",
        token_expiry_hours=1
    )
    # The gateway routes requests through the ServiceManager
    return APIGateway(service_manager, rate_limit_config, auth_config)
def setup_services(cache_manager, db_pool, metrics_collector):
    """Configure and register the services."""
    service_manager = ServiceManager()
    # Document-loader service
    doc_loader = DocumentLoaderService(ServiceConfig(
        name="document_loader",
        host="localhost",
        port=8001
    ))
    # Vector-search service (needs a vector-store config; keys are illustrative)
    vector_search = VectorSearchService(
        ServiceConfig(name="vector_search", host="localhost", port=8002),
        vector_store_config={"type": "chroma", "host": "localhost", "port": 8000}
    )
    # Generation service (needs a generator config; keys are illustrative)
    generation = GenerationService(
        ServiceConfig(name="generation", host="localhost", port=8004),
        generation_config={"type": "openai", "model": "gpt-3.5-turbo"}
    )
    # register_service takes BaseService instances, not bare configs
    service_manager.register_service(doc_loader)
    service_manager.register_service(vector_search)
    service_manager.register_service(generation)
    return service_manager
async def run_monitoring_loop(metrics_collector, alert_manager):
    """Run the monitoring loop."""
    # System metrics are gathered by the collector's background thread
    metrics_collector.start_collection()
    while True:
        try:
            # Evaluate alert rules
            alert_manager.check_alerts()
            # Wait 30 seconds
            await asyncio.sleep(30)
        except Exception as e:
            print(f"Monitoring loop error: {e}")
            await asyncio.sleep(5)
async def simulate_api_requests(api_gateway, metrics_collector):
    """Simulate API traffic."""
    import random
    while True:
        try:
            # Simulate varying response times
            response_time = random.uniform(0.1, 3.0)
            # Record API metrics (record_histogram also emits p50/p95/p99)
            metrics_collector.record_histogram(
                "api_response_time",
                response_time,
                labels={"endpoint": "/api/search", "method": "POST"}
            )
            metrics_collector.increment_counter("api_requests_total")
            # Simulate request spacing
            await asyncio.sleep(random.uniform(0.1, 1.0))
        except Exception as e:
            print(f"Request simulation error: {e}")
            await asyncio.sleep(1)
async def main():
    """Entry point."""
    print("Starting the RAG deployment example...")
    # Set up the components (services before the gateway, which depends on them)
    log_manager = setup_logging()
    cache_manager = setup_cache()
    db_pool = setup_database_pool()
    metrics_collector, alert_manager = setup_monitoring()
    service_manager = setup_services(cache_manager, db_pool, metrics_collector)
    api_gateway = setup_api_gateway(service_manager)
    logger = log_manager.get_logger("main")
    logger.info("RAG system components initialized")
    try:
        # Start the services
        await service_manager.start_all_services()
        logger.info("All services started")
        # Start the monitoring and simulation tasks
        monitoring_task = asyncio.create_task(
            run_monitoring_loop(metrics_collector, alert_manager)
        )
        simulation_task = asyncio.create_task(
            simulate_api_requests(api_gateway, metrics_collector)
        )
        logger.info("Monitoring and simulation tasks started")
        # After a while, report some statistics
        await asyncio.sleep(60)
        # Cache statistics
        cache_stats = cache_manager.get_stats()
        logger.info(f"Cache stats: {cache_stats}")
        # Connection pool statistics
        pool_stats = db_pool.get_pool_stats()
        logger.info(f"Pool stats: {pool_stats}")
        # Service health
        service_status = await service_manager.health_check_all()
        logger.info(f"Service status: {service_status}")
        # Active alerts
        active_alerts = alert_manager.get_active_alerts()
        if active_alerts:
            logger.warning(f"Active alerts: {active_alerts}")
        else:
            logger.info("No active alerts")
        # Keep running
        await asyncio.gather(monitoring_task, simulation_task)
    except KeyboardInterrupt:
        logger.info("Stop signal received, shutting down...")
    except Exception as e:
        logger.error(f"Runtime error: {e}")
    finally:
        # Clean up
        await service_manager.stop_all_services()
        metrics_collector.stop_collection()
        db_pool.close_all_connections()
        logger.info("System shut down")

if __name__ == "__main__":
    asyncio.run(main())
5. Summary
This chapter covered deploying and operating a RAG system, including:
Key Points
System architecture design
- Microservice architecture patterns
- An API gateway as the unified entry point
- Service registration and discovery
- Load-balancing strategies
Containerized deployment
- Docker containerization
- Orchestration with Docker Compose
- Environment configuration management
- Data persistence
Monitoring and alerting
- System metrics collection
- Application performance monitoring
- Automated alerting
- Log management
Performance optimization
- Multi-level caching strategies
- Connection pool management
- Resource sizing and configuration
- Performance tuning techniques
Best Practices
Deployment strategies
- Blue-green deployment
- Rolling updates
- Canary releases
- Rollback mechanisms
Monitoring
- End-to-end tracing
- Real-time alerting
- Performance analysis
- Capacity planning
Operations automation
- CI/CD pipelines
- Automated testing
- Configuration management
- Self-healing
With this chapter you have the full deployment and operations toolkit for a production-grade RAG system: high availability, solid performance, and maintainability.