## 9.1 项目概述与架构设计
### 9.1.1 项目目标
本章将指导你构建一个完整的Text2SQL系统,该系统具备以下功能:
- 自然语言查询理解:解析用户的自然语言问题
- SQL生成:将自然语言转换为可执行的SQL查询
- 查询执行:在数据库中执行生成的SQL
- 结果展示:以用户友好的方式展示查询结果
- 交互式界面:提供Web界面供用户使用
- 性能监控:监控系统性能和查询质量
### 9.1.2 系统架构
import os
import json
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from pathlib import Path
@dataclass
class SystemConfig:
    """Settings shared by every Text2SQL component.

    All fields have defaults, so ``SystemConfig()`` is a usable
    configuration; override individual values with keyword arguments.
    """

    # --- Model ---
    model_name: str = "text2sql-t5-base"
    model_path: str = "./models/text2sql"
    device: str = "cuda"
    max_length: int = 512  # used as the token limit for encoding and generation

    # --- Database ---
    db_type: str = "sqlite"
    db_path: str = "./data/database.db"
    schema_path: str = "./data/schema.json"

    # --- HTTP API ---
    api_host: str = "0.0.0.0"
    api_port: int = 8000
    debug: bool = False

    # --- Logging ---
    log_level: str = "INFO"
    log_file: str = "./logs/text2sql.log"

    # --- Result cache ---
    enable_cache: bool = True
    cache_size: int = 1000  # maximum number of cached entries
    cache_ttl: int = 3600  # entry lifetime, compared against clock seconds
class Text2SQLSystem:
    """Facade that wires the Text2SQL components into one query pipeline.

    ``process_query`` runs: preprocess -> cache lookup -> SQL generation ->
    SQL execution -> result formatting -> cache store.
    """

    def __init__(self, config: "SystemConfig"):
        """Store *config*, configure logging and build every component.

        Raises:
            Exception: anything raised by a component constructor; it is
                logged and re-raised by ``_initialize_components``.
        """
        self.config = config
        self.logger = self._setup_logging()
        # Filled in by _initialize_components(); cache_manager stays None
        # when caching is disabled in the configuration.
        self.query_processor = None
        self.model_manager = None
        self.database_manager = None
        self.cache_manager = None
        self.result_formatter = None
        self._initialize_components()

    def _setup_logging(self) -> logging.Logger:
        """Create the log directory, configure root logging, return a logger."""
        log_dir = Path(self.config.log_file).parent
        log_dir.mkdir(parents=True, exist_ok=True)
        # Accept "info" as well as "INFO"; fall back to INFO for names that
        # are not numeric logging levels instead of raising AttributeError.
        level = getattr(logging, str(self.config.log_level).upper(), None)
        if not isinstance(level, int):
            level = logging.INFO
        logging.basicConfig(
            level=level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                # Log messages contain non-ASCII text, so pin the encoding
                # rather than relying on the platform default.
                logging.FileHandler(self.config.log_file, encoding='utf-8'),
                logging.StreamHandler(),
            ],
        )
        return logging.getLogger(__name__)

    def _initialize_components(self):
        """Instantiate all pipeline components; log and re-raise on failure."""
        self.logger.info("初始化Text2SQL系统组件...")
        try:
            self.query_processor = QueryProcessor(self.config)
            self.model_manager = ModelManager(self.config)
            self.database_manager = DatabaseManager(self.config)
            if self.config.enable_cache:
                self.cache_manager = CacheManager(self.config)
            self.result_formatter = ResultFormatter(self.config)
            self.logger.info("系统组件初始化完成")
        except Exception as e:
            self.logger.error(f"系统初始化失败: {e}")
            raise

    @staticmethod
    def _error_result(question: str, error: Any) -> Dict[str, Any]:
        """Build the failure payload; 'sql' is always present for consumers."""
        return {
            'success': False,
            'error': str(error),
            'question': question,
            'sql': None,
        }

    def process_query(self, question: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Translate *question* to SQL, execute it and return a result dict.

        Args:
            question: natural-language question.
            db_id: optional database identifier passed to the model and the
                database manager.

        Returns:
            The formatter's result dict on the normal path, or a dict with
            ``success=False``, ``error``, ``question`` and ``sql=None`` when
            SQL generation fails or any step raises. This method never
            raises.
        """
        self.logger.info(f"处理查询: {question}")
        try:
            # 1. Preprocess
            processed_query = self.query_processor.preprocess(question)

            # 2. Cache lookup ("is not None" so a falsy cached value still hits)
            cache_key = f"{processed_query}_{db_id}"
            if self.cache_manager is not None:
                cached_result = self.cache_manager.get(cache_key)
                if cached_result is not None:
                    self.logger.info("返回缓存结果")
                    return cached_result

            # 3. Generate SQL; stop here if the model produced nothing rather
            #    than handing None to the database layer.
            sql_result = self.model_manager.generate_sql(processed_query, db_id)
            sql = sql_result.get('sql')
            if not sql:
                error = sql_result.get('error') or 'SQL generation failed'
                self.logger.error(f"查询处理失败: {error}")
                return self._error_result(question, error)

            # 4. Execute
            execution_result = self.database_manager.execute_sql(sql, db_id)

            # 5. Format
            formatted_result = self.result_formatter.format_result(
                question, sql_result, execution_result
            )

            # 6. Cache successful results only, so a transient failure is
            #    not replayed for the whole TTL.
            if self.cache_manager is not None and formatted_result.get('success'):
                self.cache_manager.set(cache_key, formatted_result)

            self.logger.info("查询处理完成")
            return formatted_result
        except Exception as e:
            self.logger.error(f"查询处理失败: {e}")
            return self._error_result(question, e)

    def get_system_status(self) -> Dict[str, Any]:
        """Return model, database and cache status plus key config values."""
        return {
            'model_status': self.model_manager.get_status(),
            'database_status': self.database_manager.get_status(),
            'cache_status': self.cache_manager.get_status() if self.cache_manager else None,
            'system_config': {
                'model_name': self.config.model_name,
                'db_type': self.config.db_type,
                'cache_enabled': self.config.enable_cache,
            },
        }
# Progress marker printed once the definitions above have been executed.
print("Text2SQL系统架构定义完成")
## 9.2 核心组件实现
### 9.2.1 查询处理器
import re
import string
from typing import Dict, List, Tuple
import spacy
from transformers import AutoTokenizer
class QueryProcessor:
    """Clean and annotate a natural-language question before SQL generation."""

    # Regexes (matched case-insensitively) used to tag the question's intent.
    DEFAULT_QUERY_PATTERNS = {
        'aggregation': r'\b(count|sum|avg|average|max|maximum|min|minimum|total)\b',
        'comparison': r'\b(greater|less|more|fewer|higher|lower|above|below|between)\b',
        'temporal': r'\b(year|month|day|date|time|before|after|during|since)\b',
        'sorting': r'\b(sort|order|rank|top|bottom|first|last|highest|lowest)\b'
    }

    # Intent name -> prefix tag, applied in this order (last one ends up first).
    _INTENT_TAGS = (
        ('aggregation', '[AGGREGATION]'),
        ('comparison', '[COMPARISON]'),
        ('temporal', '[TEMPORAL]'),
        ('sorting', '[SORTING]'),
    )

    _NUMBER_WORDS = {'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
                     'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10}

    # Typographic quotes -> ASCII quotes, written as escapes so the mapping
    # cannot be damaged by editors or copy/paste.
    _QUOTE_TABLE = str.maketrans({
        '\u2018': "'", '\u2019': "'",
        '\u201c': '"', '\u201d': '"',
    })

    # Sentence punctuation removed from both ends of the question. Quotes and
    # brackets are deliberately excluded so a quoted value at the edge of the
    # question (e.g. ... named "Bob") keeps both of its delimiters.
    _EDGE_PUNCTUATION = ' .,;:!?\u3002\uff0c\uff1b\uff1a\uff01\uff1f'

    _FALLBACK_STOPWORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by',
    })

    def __init__(self, config: "SystemConfig"):
        """Load the tokenizer for ``config.model_name`` and, if installed,
        the spaCy English model (``self.nlp`` is None when it is missing)."""
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.nlp = None
            print("警告: spaCy英文模型未安装,某些功能可能受限")
        # Per-instance copy so callers may customise patterns safely.
        self.query_patterns = dict(self.DEFAULT_QUERY_PATTERNS)

    def preprocess(self, question: str) -> str:
        """Return *question* cleaned, normalised and prefixed with intent tags."""
        processed = self._basic_cleaning(question)
        processed = self._normalize_entities(processed)
        intent = self._analyze_intent(processed)
        return self._add_context(processed, intent)

    def _basic_cleaning(self, text: str) -> str:
        """Collapse whitespace, normalise quotes, trim sentence punctuation."""
        text = re.sub(r'\s+', ' ', text.strip())
        text = text.translate(self._QUOTE_TABLE)
        return text.strip(self._EDGE_PUNCTUATION)

    def _normalize_entities(self, text: str) -> str:
        """Rewrite number words (one..ten) as digits and unify year phrases.

        Only regular expressions are used, so this works whether or not the
        spaCy model is available.
        """
        # Lower-case the match before the lookup: the regex is
        # case-insensitive, so "Five" must map the same as "five".
        text = re.sub(
            r'\b(one|two|three|four|five|six|seven|eight|nine|ten)\b',
            lambda m: str(self._NUMBER_WORDS[m.group().lower()]),
            text, flags=re.IGNORECASE)
        text = re.sub(r'\b(last year|previous year)\b', 'year before', text, flags=re.IGNORECASE)
        text = re.sub(r'\b(this year|current year)\b', 'current year', text, flags=re.IGNORECASE)
        return text

    def _analyze_intent(self, text: str) -> Dict[str, bool]:
        """Return ``{pattern_name: matched}`` for every configured pattern."""
        return {
            name: bool(re.search(pattern, text, re.IGNORECASE))
            for name, pattern in self.query_patterns.items()
        }

    def _add_context(self, text: str, intent: Dict[str, bool]) -> str:
        """Prefix *text* with one bracketed tag per detected intent."""
        for name, tag in self._INTENT_TAGS:
            if intent.get(name):
                text = f"{tag} {text}"
        return text

    def extract_keywords(self, text: str) -> List[str]:
        """Return content words of *text*.

        With spaCy: lower-cased lemmas of tokens that are not stop words,
        punctuation or whitespace and are longer than two characters.
        Without spaCy: lower-cased whitespace-separated words with
        surrounding punctuation removed, filtered by a small stop-word list
        and the same length rule.
        """
        if self.nlp is None:
            words = (w.strip(string.punctuation) for w in text.lower().split())
            return [w for w in words
                    if len(w) > 2 and w not in self._FALLBACK_STOPWORDS]
        return [
            token.lemma_.lower()
            for token in self.nlp(text)
            if not (token.is_stop or token.is_punct or token.is_space)
            and len(token.text) > 2
        ]
class ModelManager:
    """Load a seq2seq model and turn questions into SQL text."""

    _SQL_KEYWORDS = (
        'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY',
        'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN', 'OUTER JOIN',
        'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'COUNT', 'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT',
    )
    # Longest keyword first so multi-word keywords win over their parts.
    _KEYWORD_RE = re.compile(
        r'\b(?:' + '|'.join(
            re.escape(k) for k in sorted(_SQL_KEYWORDS, key=len, reverse=True)
        ) + r')\b',
        re.IGNORECASE,
    )
    # Single- or double-quoted segment, with doubled quotes as escapes.
    _LITERAL_RE = re.compile(r"""('(?:[^']|'')*'|"(?:[^"]|"")*")""")

    def __init__(self, config: "SystemConfig"):
        """Store *config* and load the model immediately.

        Raises:
            Exception: re-raised from ``_load_model`` when loading fails.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.device = config.device
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model for ``config.model_name`` and place the
        model on the resolved device; print and re-raise on failure."""
        try:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            import torch

            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_name)

            # Record the device actually used so get_status() does not report
            # "cuda" while the model runs on the CPU.
            if str(self.device).startswith('cuda') and torch.cuda.is_available():
                self.model = self.model.to(self.device)
            else:
                self.device = 'cpu'
            self.model.eval()
            print(f"模型 {self.config.model_name} 加载成功")
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    def generate_sql(self, question: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Generate SQL for *question*.

        Returns:
            On success a dict with ``sql`` (post-processed), ``raw_sql``,
            ``confidence`` and ``input_text``. On any failure a dict with
            ``sql=None``, ``error`` and ``input_text``; this method does not
            raise.
        """
        # Built before the try block so the error path can always report it.
        input_text = f"question: {question}"
        if db_id:
            input_text = f"{input_text} database: {db_id}"

        try:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("model is not loaded")

            inputs = self.tokenizer(
                input_text,
                max_length=self.config.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            import torch
            if str(self.device).startswith('cuda') and torch.cuda.is_available():
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.config.max_length,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False
                )

            generated_sql = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )
            return {
                'sql': self._postprocess_sql(generated_sql),
                'raw_sql': generated_sql,
                'confidence': self._calculate_confidence(outputs),
                'input_text': input_text
            }
        except Exception as e:
            return {
                'sql': None,
                'error': str(e),
                'input_text': input_text
            }

    def _postprocess_sql(self, sql: str) -> str:
        """Normalise generated SQL.

        Outside quoted segments, runs of whitespace are collapsed and known
        keywords are upper-cased. Quoted segments are left untouched so
        values such as ``'min  in'`` are not rewritten. A trailing semicolon
        is appended when missing. Empty input yields an empty string so
        callers can detect that nothing was generated.
        """
        sql = (sql or '').strip()
        if not sql:
            return ''
        # re.split with a capturing group alternates: outside, literal, ...
        parts = self._LITERAL_RE.split(sql)
        for index in range(0, len(parts), 2):
            outside = re.sub(r'\s+', ' ', parts[index])
            parts[index] = self._KEYWORD_RE.sub(
                lambda m: m.group().upper(), outside)
        sql = ''.join(parts)
        if not sql.endswith(';'):
            sql += ';'
        return sql

    def _calculate_confidence(self, outputs) -> float:
        """Return a fixed placeholder confidence; *outputs* is not inspected."""
        return 0.8  # placeholder

    def get_status(self) -> Dict[str, Any]:
        """Return load state, model name, device and tokenizer vocabulary size."""
        return {
            'model_loaded': self.model is not None,
            'model_name': self.config.model_name,
            'device': self.device,
            'tokenizer_vocab_size': len(self.tokenizer) if self.tokenizer is not None else 0
        }
class DatabaseManager:
    """Open SQLite connections on demand, run SQL and expose loaded schemas.

    NOTE(review): ``execute_sql`` runs whatever statement it is given,
    including writes and DDL, and the SQL comes from a model driven by user
    input. Consider a read-only connection or a statement allow-list before
    exposing this to untrusted users.
    """

    def __init__(self, config: "SystemConfig"):
        """Store *config* and load schemas from ``config.schema_path``."""
        import threading

        self.config = config
        self.connections = {}
        self.schemas = {}
        # Connections are cached and may be used from several threads (for
        # example a threaded web server), so access is serialised.
        self._lock = threading.RLock()
        self._load_schemas()

    def _load_schemas(self):
        """Load the schema file if it exists; failures are printed, not raised.

        A JSON object is used as-is (db_id -> schema). A JSON list of objects
        that carry a ``db_id`` key is converted to that mapping so later
        ``.get``/``.keys`` calls work on it.
        """
        try:
            if os.path.exists(self.config.schema_path):
                with open(self.config.schema_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    data = {
                        entry['db_id']: entry
                        for entry in data
                        if isinstance(entry, dict) and 'db_id' in entry
                    }
                if not isinstance(data, dict):
                    raise ValueError("schema file must contain an object or a list")
                self.schemas = data
                print(f"加载了 {len(self.schemas)} 个数据库Schema")
        except Exception as e:
            print(f"Schema加载失败: {e}")

    def get_connection(self, db_id: Optional[str] = None):
        """Return the cached connection for *db_id*, opening it on first use.

        ``None`` or ``'default'`` maps to ``config.db_path``; any other id
        maps to ``./data/<db_id>.db``.

        Raises:
            ValueError: for an unsupported ``config.db_type`` or a *db_id*
                containing anything other than word characters and ``-``.
        """
        if db_id is None:
            db_id = 'default'
        with self._lock:
            if db_id not in self.connections:
                if self.config.db_type != 'sqlite':
                    raise ValueError(f"不支持的数据库类型: {self.config.db_type}")
                import sqlite3

                if db_id == 'default':
                    db_path = self.config.db_path
                else:
                    # db_id can come from an HTTP request; refuse separators
                    # and dots so it cannot escape the data directory.
                    if not isinstance(db_id, str) or not re.fullmatch(r'[\w\-]+', db_id):
                        raise ValueError(f"非法的数据库标识: {db_id}")
                    db_path = f"./data/{db_id}.db"
                # check_same_thread=False: the cached connection is reused by
                # other threads; self._lock serialises that use.
                self.connections[db_id] = sqlite3.connect(
                    db_path, check_same_thread=False)
            return self.connections[db_id]

    def execute_sql(self, sql: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Execute *sql* and return a result dict; never raises.

        Returns:
            For statements that produce a result set: ``success``,
            ``columns``, ``rows`` and ``row_count``. For other statements
            (committed after execution): ``success`` and ``affected_rows``.
            On any error: ``success=False``, ``error`` and ``sql``.
        """
        if not isinstance(sql, str) or not sql.strip():
            return {'success': False, 'error': 'empty SQL statement', 'sql': sql}
        try:
            with self._lock:
                conn = self.get_connection(db_id)
                cursor = conn.cursor()
                try:
                    cursor.execute(sql)
                    # description is None exactly when the statement returned
                    # no result set; this also covers WITH/PRAGMA queries that
                    # a "starts with SELECT" test would miss.
                    if cursor.description is not None:
                        columns = [desc[0] for desc in cursor.description]
                        rows = cursor.fetchall()
                        return {
                            'success': True,
                            'columns': columns,
                            'rows': rows,
                            'row_count': len(rows)
                        }
                    conn.commit()
                    return {
                        'success': True,
                        'affected_rows': cursor.rowcount
                    }
                finally:
                    cursor.close()
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'sql': sql
            }

    def get_schema(self, db_id: str) -> Dict[str, Any]:
        """Return the loaded schema for *db_id*, or ``{}`` if unknown."""
        return self.schemas.get(db_id, {})

    def get_status(self) -> Dict[str, Any]:
        """Return db type, number of open connections and known schema ids."""
        return {
            'db_type': self.config.db_type,
            'active_connections': len(self.connections),
            'available_schemas': list(self.schemas.keys())
        }

    def close(self):
        """Close every cached connection and forget it."""
        with self._lock:
            for conn in self.connections.values():
                conn.close()
            self.connections.clear()
class CacheManager:
    """In-memory cache with a size bound, absolute TTL and LRU eviction.

    ``access_times`` tracks the last read or write of each key and decides
    which entry is evicted. Expiry is measured from the time the entry was
    stored, so a frequently read entry still expires after ``ttl``.
    """

    def __init__(self, config: "SystemConfig"):
        """Read ``cache_size`` and ``cache_ttl`` from *config*."""
        import time

        self.config = config
        self.cache = {}
        self.access_times = {}
        self.max_size = config.cache_size
        self.ttl = config.cache_ttl
        self._created_at = {}
        self._hits = 0
        self._misses = 0
        # Monotonic clock: TTL arithmetic must not be affected by wall-clock
        # adjustments.
        self._clock = time.monotonic

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent or expired.

        An expired entry is removed. Every call is counted as a hit or miss.
        """
        if key in self.cache:
            now = self._clock()
            if now - self._created_at[key] < self.ttl:
                self.access_times[key] = now
                self._hits += 1
                return self.cache[key]
            self._remove(key)
        self._misses += 1
        return None

    def set(self, key: str, value: Any):
        """Store *value* under *key*, evicting the least recently used entry
        first when a new key would exceed ``max_size``."""
        # Overwriting an existing key does not grow the cache, so it must not
        # push out an unrelated entry.
        if key not in self.cache and len(self.cache) >= self.max_size:
            self._evict_oldest()
        now = self._clock()
        self.cache[key] = value
        self.access_times[key] = now
        self._created_at[key] = now

    def _remove(self, key: str):
        """Drop *key* from the value store and both timestamp tables."""
        del self.cache[key]
        self.access_times.pop(key, None)
        self._created_at.pop(key, None)

    def _evict_oldest(self):
        """Remove the entry with the oldest access time, if any."""
        if self.access_times:
            oldest_key = min(self.access_times, key=self.access_times.get)
            self._remove(oldest_key)

    def get_status(self) -> Dict[str, Any]:
        """Return current size, size limit and observed hit rate."""
        return {
            'cache_size': len(self.cache),
            'max_size': self.max_size,
            'hit_rate': self._calculate_hit_rate()
        }

    def _calculate_hit_rate(self) -> float:
        """Return hits / (hits + misses) over all ``get`` calls, 0.0 if none."""
        lookups = self._hits + self._misses
        return self._hits / lookups if lookups else 0.0
class ResultFormatter:
    """Combine SQL-generation and execution output into one response dict."""

    # (substring of the lower-cased error message, error type), checked in order.
    _ERROR_TYPES = (
        ('syntax error', 'syntax_error'),
        ('no such table', 'table_not_found'),
        ('no such column', 'column_not_found'),
        ('ambiguous column', 'ambiguous_column'),
    )

    def __init__(self, config: "SystemConfig"):
        self.config = config

    def format_result(self, question: str, sql_result: Dict,
                      execution_result: Dict) -> Dict[str, Any]:
        """Build the response for one query.

        Returns:
            A dict that always has ``question``, ``timestamp``, ``sql``,
            ``success`` and ``sql_generation``. On success it adds
            ``columns``, ``rows``, ``row_count`` and ``formatted_table``
            (plus ``affected_rows`` when the execution result carries it);
            on failure it adds ``error`` and ``error_type``.
        """
        formatted = {
            'question': question,
            'timestamp': self._get_timestamp(),
            'sql': sql_result.get('sql'),
            'success': execution_result.get('success', False)
        }

        if execution_result.get('success'):
            columns = execution_result.get('columns', [])
            rows = execution_result.get('rows', [])
            formatted.update({
                'columns': columns,
                'rows': rows,
                'row_count': execution_result.get('row_count', 0),
                'formatted_table': self._format_table(columns, rows)
            })
            # Keep the write count for non-query statements.
            if 'affected_rows' in execution_result:
                formatted['affected_rows'] = execution_result['affected_rows']
        else:
            error = execution_result.get('error')
            formatted.update({
                'error': error,
                'error_type': self._classify_error(error)
            })

        formatted['sql_generation'] = {
            'confidence': sql_result.get('confidence'),
            'raw_sql': sql_result.get('raw_sql')
        }
        return formatted

    def _get_timestamp(self) -> str:
        """Return the current local time in ISO 8601 format."""
        from datetime import datetime
        return datetime.now().isoformat()

    def _format_table(self, columns: List[str], rows: List[Tuple]) -> str:
        """Render *rows* as a left-aligned, pipe-separated text table.

        Returns "No data found." when there are no columns or no rows. Cells
        beyond the number of columns are ignored.
        """
        if not columns or not rows:
            return "No data found."

        headers = [str(col) for col in columns]
        col_widths = [len(h) for h in headers]
        for row in rows:
            # zip stops at the shorter side, so an over-long row cannot index
            # past the header.
            for i, cell in zip(range(len(col_widths)), row):
                col_widths[i] = max(col_widths[i], len(str(cell)))

        header = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths))
        table_lines = [header, "-" * len(header)]
        table_lines.extend(
            " | ".join(str(cell).ljust(w) for cell, w in zip(row, col_widths))
            for row in rows
        )
        return "\n".join(table_lines)

    def _classify_error(self, error_msg: str) -> str:
        """Map an error message to a coarse error type.

        A missing (None or empty) message is classified as 'unknown_error'.
        """
        message = str(error_msg or '').lower()
        for needle, error_type in self._ERROR_TYPES:
            if needle in message:
                return error_type
        return 'unknown_error'
# Progress marker printed once the component classes above have been defined.
print("核心组件实现完成")
## 9.3 系统集成与接口设计
### 9.3.1 Web接口实现
```python
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import logging
from typing import Dict, Any
class Text2SQLWebAPI:
    """Flask application exposing a Text2SQL system over HTTP.

    Routes:
        GET  /                    renders ``index.html``
        POST /api/query           JSON body with ``question`` and optional ``db_id``
        GET  /api/status          system status
        GET  /api/schemas         ids of the loaded schemas
        GET  /api/schema/<db_id>  schema of one database
    """

    def __init__(self, system: "Text2SQLSystem"):
        """Create the Flask app, enable CORS and register the routes."""
        self.system = system
        self.app = Flask(__name__)
        CORS(self.app)
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self._setup_routes()

    def _setup_routes(self):
        """Register all route handlers on ``self.app``."""

        @self.app.route('/', methods=['GET'])
        def index():
            """Serve the landing page."""
            return render_template('index.html')

        @self.app.route('/api/query', methods=['POST'])
        def query():
            """Run one natural-language query."""
            # silent=True: a missing or malformed JSON body yields None and
            # is answered with 400 below, instead of raising inside a
            # catch-all handler and being reported as a server error (500).
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or 'question' not in data:
                return jsonify({
                    'success': False,
                    'error': 'Missing question parameter'
                }), 400
            try:
                result = self.system.process_query(
                    data['question'], data.get('db_id'))
                return jsonify({
                    'success': True,
                    'result': result
                })
            except Exception as e:
                self.logger.error(f"查询处理错误: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/status', methods=['GET'])
        def status():
            """Return the system status."""
            try:
                # Text2SQLSystem defines get_system_status(); it has no
                # get_status() method.
                current = self.system.get_system_status()
                return jsonify({
                    'success': True,
                    'status': current
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/schemas', methods=['GET'])
        def get_schemas():
            """Return the ids of all loaded database schemas."""
            try:
                # The attribute on Text2SQLSystem is database_manager.
                db_status = self.system.database_manager.get_status()
                return jsonify({
                    'success': True,
                    'schemas': db_status['available_schemas']
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/schema/<db_id>', methods=['GET'])
        def get_schema(db_id):
            """Return the schema of the database named *db_id*."""
            try:
                schema = self.system.database_manager.get_schema(db_id)
                return jsonify({
                    'success': True,
                    'schema': schema
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

    def run(self, host='0.0.0.0', port=5000, debug=False):
        """Start the Flask development server on *host*:*port*."""
        self.logger.info(f"启动Text2SQL Web API服务: http://{host}:{port}")
        self.app.run(host=host, port=port, debug=debug)
class CommandLineInterface:
    """Interactive prompt for a Text2SQL system.

    Commands: ``help``, ``status``, ``use <db>``, ``quit``; any other
    non-empty input is treated as a question.
    """

    def __init__(self, system: "Text2SQLSystem"):
        self.system = system
        # Database selected with "use <db>"; None means the default database.
        self.current_db = None

    def run(self):
        """Read commands until ``quit``, Ctrl-C or end of input."""
        print("=" * 50)
        print("Text2SQL 系统命令行界面")
        print("=" * 50)
        print("输入 'help' 查看帮助信息")
        print("输入 'quit' 退出系统")
        print()
        while True:
            try:
                user_input = input("请输入自然语言查询: ").strip()
                command = user_input.lower()
                if command == 'quit':
                    print("再见!")
                    break
                elif command == 'help':
                    self._show_help()
                    continue
                elif command == 'status':
                    self._show_status()
                    continue
                elif command.startswith('use '):
                    self._set_database(user_input[4:].strip())
                    continue
                elif not user_input:
                    continue
                result = self.system.process_query(user_input, self.current_db)
                self._display_result(result)
            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop: once stdin is exhausted every
                # further input() call raises it again, so treating it as an
                # ordinary error would spin forever.
                print("\n再见!")
                break
            except Exception as e:
                print(f"错误: {e}")

    def _show_help(self):
        """Print the command summary and example questions."""
        help_text = """
可用命令:
help - 显示此帮助信息
status - 显示系统状态
use <db> - 切换数据库
quit - 退出系统
示例查询:
"显示所有用户"
"查找年龄大于25的用户"
"统计每个部门的员工数量"
"""
        print(help_text)

    def _show_status(self):
        """Print each top-level entry of the system status."""
        # Text2SQLSystem defines get_system_status(); it has no get_status().
        status = self.system.get_system_status()
        print("\n系统状态:")
        for key, value in status.items():
            print(f" {key}: {value}")
        print()

    def _set_database(self, db_id: str):
        """Select *db_id* for later queries if a schema is loaded for it."""
        try:
            # The attribute on Text2SQLSystem is database_manager.
            schema = self.system.database_manager.get_schema(db_id)
            if schema:
                # Remember the choice so run() passes it to process_query().
                self.current_db = db_id
                print(f"已切换到数据库: {db_id}")
            else:
                print(f"数据库 {db_id} 不存在")
        except Exception as e:
            print(f"切换数据库失败: {e}")

    def _display_result(self, result: Dict[str, Any]):
        """Print the question, generated SQL and either the table or the error.

        Missing keys are tolerated because error results do not carry the
        same fields as successful ones.
        """
        print("\n" + "="*50)
        print(f"问题: {result.get('question')}")
        print(f"生成的SQL: {result.get('sql')}")
        if result.get('success'):
            print(f"执行成功,返回 {result.get('row_count', 0)} 行数据")
            if result.get('formatted_table'):
                print("\n查询结果:")
                print(result['formatted_table'])
        else:
            print(f"执行失败: {result.get('error', '未知错误')}")
        print("="*50 + "\n")
## 9.4 完整使用示例
### 9.4.1 基本使用示例
```python
def basic_usage_example():
    """Build a system with a small model and run four sample questions,
    printing the generated SQL and row count, or the error, for each."""
    # "t5-small" keeps the demonstration lightweight.
    system = Text2SQLSystem(
        SystemConfig(
            model_name="t5-small",
            db_path="./data/example.db",
            schema_path="./data/schemas.json",
        )
    )

    sample_questions = (
        "显示所有用户的姓名和邮箱",
        "查找年龄大于25岁的用户",
        "统计每个部门的员工数量",
        "找出薪资最高的前5名员工",
    )
    for text in sample_questions:
        print(f"\n处理问题: {text}")
        outcome = system.process_query(text)
        if not outcome['success']:
            print(f"处理失败: {outcome.get('error')}")
            continue
        print(f"生成SQL: {outcome['sql']}")
        print(f"返回 {outcome['row_count']} 行数据")
def web_api_example():
    """Start the web API on localhost:5000 in debug mode with default config."""
    default_system = Text2SQLSystem(SystemConfig())
    server = Text2SQLWebAPI(default_system)
    server.run(host='localhost', port=5000, debug=True)
def cli_example():
    """Start the interactive command-line interface with default config."""
    default_system = Text2SQLSystem(SystemConfig())
    CommandLineInterface(default_system).run()
### 9.4.2 高级使用示例
```python
class AdvancedText2SQLDemo:
    """Batch, timing and error-handling demonstrations for Text2SQLSystem."""

    def __init__(self):
        """Build a system using CUDA when available, otherwise the CPU."""
        # Imported here because no module-level torch import is visible in
        # this file; without it the device expression raises NameError.
        import torch

        self.config = SystemConfig(
            model_name="facebook/bart-large",
            device="cuda" if torch.cuda.is_available() else "cpu",
            cache_size=1000,
            cache_ttl=3600
        )
        self.system = Text2SQLSystem(self.config)

    def batch_processing_demo(self):
        """Run a fixed list of questions and print a summary report."""
        questions = [
            "显示所有产品的名称和价格",
            "查找库存少于10的产品",
            "统计每个类别的产品数量",
            "找出销量最好的产品",
            "显示最近一个月的订单"
        ]
        results = [self.system.process_query(q) for q in questions]
        self._generate_report(results)

    def _generate_report(self, results: List[Dict]):
        """Print totals, success rate and one line per result.

        An empty *results* list reports a 0.0% success rate instead of
        dividing by zero.
        """
        print("\n" + "="*60)
        print("批处理结果报告")
        print("="*60)
        success_count = sum(1 for r in results if r.get('success'))
        total_count = len(results)
        success_rate = success_count / total_count * 100 if total_count else 0.0
        print(f"总查询数: {total_count}")
        print(f"成功数: {success_count}")
        print(f"成功率: {success_rate:.1f}%")
        print("\n详细结果:")
        for i, result in enumerate(results, 1):
            succeeded = result.get('success')
            status = "✓" if succeeded else "✗"
            print(f"{i}. {status} {result.get('question')}")
            if succeeded:
                print(f" SQL: {result.get('sql')}")
                print(f" 结果: {result.get('row_count', 0)} 行")
            else:
                print(f" 错误: {result.get('error')}")

    def performance_test(self):
        """Time 30 queries (3 questions x 10) and print throughput figures."""
        import time

        test_questions = [
            "显示所有用户",
            "查找活跃用户",
            "统计用户数量"
        ] * 10  # repeated so cache behaviour is exercised as well
        # perf_counter is the high-resolution clock intended for timing.
        start_time = time.perf_counter()
        for question in test_questions:
            self.system.process_query(question)
        total_time = time.perf_counter() - start_time
        avg_time = total_time / len(test_questions)
        # Guard against a zero elapsed time on coarse clocks.
        qps = len(test_questions) / total_time if total_time > 0 else float('inf')
        print(f"\n性能测试结果:")
        print(f"总查询数: {len(test_questions)}")
        print(f"总耗时: {total_time:.2f}秒")
        print(f"平均耗时: {avg_time:.3f}秒/查询")
        print(f"QPS: {qps:.1f}")

    def error_handling_demo(self):
        """Run questions that are expected to fail and print how they fail."""
        error_questions = [
            "显示不存在的表",
            "查询语法错误的问题",
            "使用不存在的列名",
            "复杂的嵌套查询"
        ]
        print("\n错误处理演示:")
        for question in error_questions:
            result = self.system.process_query(question)
            print(f"\n问题: {question}")
            if result.get('success'):
                print("意外成功")
            else:
                error_type = result.get('error_type', 'unknown')
                print(f"错误类型: {error_type}")
                print(f"错误信息: {result.get('error')}")
## 9.5 部署与运维
### 9.5.1 Docker部署
```dockerfile
# Dockerfile
# Image for the Text2SQL service: installs the Python dependencies listed in
# requirements.txt and starts app.py.
FROM python:3.9-slim
WORKDIR /app
# Install system dependencies (C/C++ compilers), then remove the apt package
# lists in the same layer so they are not kept in the image.
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Copy the dependency list on its own, before the application code, so the
# pip install layer below stays cached while requirements.txt is unchanged.
COPY requirements.txt .
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Copy the application code
COPY . .
# Expose the service port
EXPOSE 5000
# Start command
CMD ["python", "app.py"]
# docker-compose.yml
# Two services: the Text2SQL application and an nginx container in front of
# it. Nesting is significant in YAML: each service must sit under `services`
# and each setting under its service.
version: '3.8'

services:
  text2sql:
    build: .
    ports:
      - "5000:5000"
    volumes:
      - ./data:/app/data
      - ./models:/app/models
    environment:
      - MODEL_NAME=t5-base
      - DB_PATH=/app/data/database.db
      - SCHEMA_PATH=/app/data/schemas.json
    restart: unless-stopped

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      # Mounted read-only: the container only needs to read its config.
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
    depends_on:
      - text2sql
    restart: unless-stopped
### 9.5.2 监控与日志
class SystemMonitor:
    """Collect per-query counters and timing for a Text2SQL system."""

    def __init__(self, system: "Text2SQLSystem"):
        """Start with all counters at zero."""
        self.system = system
        self.metrics = {
            'total_queries': 0,
            'successful_queries': 0,
            'failed_queries': 0,
            'avg_response_time': 0.0,
            'cache_hit_rate': 0.0
        }

    def log_query(self, question: str, result: Dict, response_time: float):
        """Record one finished query.

        Args:
            question: the question that was asked (logged only).
            result: result dict; a missing or falsy ``success`` counts as a
                failure.
            response_time: elapsed time for the query, logged as seconds.
        """
        succeeded = bool(result.get('success'))
        self.metrics['total_queries'] += 1
        if succeeded:
            self.metrics['successful_queries'] += 1
        else:
            self.metrics['failed_queries'] += 1

        # Incremental mean: avoids re-multiplying the running average by the
        # number of queries on every update.
        previous = self.metrics['avg_response_time']
        self.metrics['avg_response_time'] = (
            previous + (response_time - previous) / self.metrics['total_queries']
        )

        logging.info("Query: %s, Success: %s, Time: %.3fs",
                     question, succeeded, response_time)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the counters plus ``success_rate`` and ``system_status``."""
        total = self.metrics['total_queries']
        success_rate = self.metrics['successful_queries'] / total if total > 0 else 0.0
        return {
            **self.metrics,
            'success_rate': success_rate,
            # Text2SQLSystem defines get_system_status(); it has no
            # get_status() method.
            'system_status': self.system.get_system_status()
        }
## 9.6 本章总结
在本章中,我们构建了一个完整的Text2SQL系统,包括:
### 主要成果
系统架构设计
- 模块化的组件设计
- 清晰的接口定义
- 可扩展的架构模式
核心组件实现
- 查询处理器:自然语言预处理和意图分析
- 模型管理器:模型加载和SQL生成
- 数据库管理器:数据库连接和查询执行
- 缓存管理器:查询结果缓存
- 结果格式化器:结果展示和错误处理
接口设计
- Web API接口:支持HTTP请求
- 命令行接口:交互式查询
- 批处理接口:大量查询处理
部署方案
- Docker容器化部署
- 负载均衡配置
- 监控和日志系统
### 技术特点
- 高性能:缓存机制和批处理优化
- 高可用:错误处理和容错机制
- 可扩展:模块化设计和插件架构
- 易维护:完善的日志和监控
### 应用场景
- 企业数据查询系统
- 商业智能平台
- 数据分析工具
- 教育培训系统
通过本章的学习,你已经掌握了构建完整Text2SQL系统的方法,可以根据实际需求进行定制和扩展。