9.1 项目概述与架构设计

9.1.1 项目目标

本章将指导你构建一个完整的Text2SQL系统,该系统具备以下功能:

  • 自然语言查询理解:解析用户的自然语言问题
  • SQL生成:将自然语言转换为可执行的SQL查询
  • 查询执行:在数据库中执行生成的SQL
  • 结果展示:以用户友好的方式展示查询结果
  • 交互式界面:提供Web界面供用户使用
  • 性能监控:监控系统性能和查询质量

9.1.2 系统架构

import os
import json
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from pathlib import Path

@dataclass
class SystemConfig:
    """Configuration values shared by every component of the Text2SQL system.

    All fields have defaults, so ``SystemConfig()`` is a complete
    configuration and callers override only the fields they need.
    """
    # Model settings
    model_name: str = "text2sql-t5-base"  # passed to the tokenizer/model ``from_pretrained`` loaders
    model_path: str = "./models/text2sql"  # NOTE(review): not read by the code visible in this file
    device: str = "cuda"  # the model is moved to CUDA only when torch reports it as available
    max_length: int = 512  # used both as the tokenizer truncation length and the generation length
    
    # Database settings
    db_type: str = "sqlite"  # the database manager only implements "sqlite"
    db_path: str = "./data/database.db"  # file opened for the default database
    schema_path: str = "./data/schema.json"  # JSON file loaded as the schema catalogue
    
    # API settings
    # NOTE(review): the web API shown in this file takes host/port/debug as
    # arguments to run() and does not read these three fields - confirm usage.
    api_host: str = "0.0.0.0"
    api_port: int = 8000
    debug: bool = False
    
    # Logging settings
    log_level: str = "INFO"  # name of a ``logging`` level constant
    log_file: str = "./logs/text2sql.log"  # parent directory is created on start-up
    
    # Cache settings
    enable_cache: bool = True  # when False no cache manager is created
    cache_size: int = 1000  # maximum number of cached entries
    cache_ttl: int = 3600  # entry lifetime, compared against elapsed seconds

class Text2SQLSystem:
    """Entry point that wires the Text2SQL pipeline together.

    Construction configures logging and instantiates the query processor,
    model manager, database manager, optional cache manager and result
    formatter from one ``SystemConfig``.  ``process_query`` then runs a
    question through: preprocess -> cache lookup -> SQL generation ->
    SQL execution -> result formatting -> cache store.
    """

    def __init__(self, config: "SystemConfig"):
        """Store *config*, set up logging and build every component.

        Raises:
            Exception: whatever a component constructor raises; the error is
                logged and re-raised by ``_initialize_components``.
        """
        self.config = config
        self.logger = self._setup_logging()

        # Components start as None so the attributes exist even when
        # initialization fails part-way through.
        self.query_processor = None
        self.model_manager = None
        self.database_manager = None
        self.cache_manager = None
        self.result_formatter = None

        self._initialize_components()

    @property
    def db_manager(self):
        """Alias of ``database_manager``.

        The web and command-line interfaces in this file access the database
        manager under this shorter name.
        """
        return self.database_manager

    def _setup_logging(self) -> logging.Logger:
        """Configure root logging to a file plus the console.

        Returns:
            The logger named after this module.
        """
        # The log directory must exist before FileHandler opens the file.
        log_dir = Path(self.config.log_file).parent
        log_dir.mkdir(parents=True, exist_ok=True)

        # Accept "info" as well as "INFO"; fall back to INFO for unknown names
        # instead of failing with AttributeError during start-up.
        level = getattr(logging, str(self.config.log_level).upper(), None)
        if not isinstance(level, int):
            level = logging.INFO

        logging.basicConfig(
            level=level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                # Log messages contain non-ASCII text, so do not depend on the
                # platform default encoding.
                logging.FileHandler(self.config.log_file, encoding='utf-8'),
                logging.StreamHandler()
            ]
        )

        return logging.getLogger(__name__)

    def _initialize_components(self):
        """Instantiate all pipeline components from the configuration.

        The cache manager is only created when ``config.enable_cache`` is
        true.  Any failure is logged and re-raised.
        """
        self.logger.info("初始化Text2SQL系统组件...")

        try:
            self.query_processor = QueryProcessor(self.config)
            self.model_manager = ModelManager(self.config)
            self.database_manager = DatabaseManager(self.config)

            if self.config.enable_cache:
                self.cache_manager = CacheManager(self.config)

            self.result_formatter = ResultFormatter(self.config)

            self.logger.info("系统组件初始化完成")

        except Exception as e:
            self.logger.error(f"系统初始化失败: {e}")
            raise

    @staticmethod
    def _error_result(question: str, error: Any) -> Dict[str, Any]:
        """Build the failure payload returned by ``process_query``.

        ``sql`` is always present (as None) so consumers can read the same
        keys on failures as on successes.
        """
        return {
            'success': False,
            'error': str(error),
            'question': question,
            'sql': None
        }

    def process_query(self, question: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Answer one natural-language question.

        Args:
            question: the user's question.
            db_id: identifier of the target database, or None for the default.

        Returns:
            The formatted result dict on success or execution failure, or a
            dict with ``success`` False, ``error``, ``question`` and ``sql``
            (None) when SQL generation fails or an exception is raised.
            Exceptions are never propagated.
        """
        self.logger.info(f"处理查询: {question}")

        try:
            # 1. Preprocess the question.
            processed_query = self.query_processor.preprocess(question)

            # 2. Look the question up in the cache.
            cache_key = f"{processed_query}_{db_id}"
            if self.cache_manager is not None:
                cached_result = self.cache_manager.get(cache_key)
                if cached_result:
                    self.logger.info("返回缓存结果")
                    return cached_result

            # 3. Generate SQL.  The model manager reports failures through a
            #    None 'sql' entry; stop here rather than hand None to the
            #    database and surface an unrelated driver error.
            sql_result = self.model_manager.generate_sql(
                processed_query, db_id
            )
            if not sql_result.get('sql'):
                reason = sql_result.get('error') or 'SQL generation returned no statement'
                self.logger.error(f"查询处理失败: {reason}")
                return self._error_result(question, reason)

            # 4. Execute the SQL.
            execution_result = self.database_manager.execute_sql(
                sql_result['sql'], db_id
            )

            # 5. Format the result.
            formatted_result = self.result_formatter.format_result(
                question, sql_result, execution_result
            )

            # 6. Cache successful results only, so a transient failure is not
            #    replayed from the cache until the TTL expires.
            if self.cache_manager is not None and formatted_result.get('success'):
                self.cache_manager.set(cache_key, formatted_result)

            self.logger.info("查询处理完成")
            return formatted_result

        except Exception as e:
            self.logger.error(f"查询处理失败: {e}")
            return self._error_result(question, e)

    def get_system_status(self) -> Dict[str, Any]:
        """Return the status of the model, database and cache plus key settings.

        ``cache_status`` is None when caching is disabled.
        """
        return {
            'model_status': self.model_manager.get_status(),
            'database_status': self.database_manager.get_status(),
            'cache_status': self.cache_manager.get_status() if self.cache_manager else None,
            'system_config': {
                'model_name': self.config.model_name,
                'db_type': self.config.db_type,
                'cache_enabled': self.config.enable_cache
            }
        }

    def get_status(self) -> Dict[str, Any]:
        """Alias of ``get_system_status`` used by the interface layers."""
        return self.get_system_status()

# Runs at import time: prints a notice that the architecture definitions have been loaded.
print("Text2SQL系统架构定义完成")

9.2 核心组件实现

9.2.1 查询处理器

import re
import string
from typing import Dict, List, Tuple
import spacy
from transformers import AutoTokenizer

class QueryProcessor:
    """Cleans a natural-language question and tags it with intent markers.

    ``preprocess`` is the main entry point; ``extract_keywords`` is an
    independent helper.  spaCy is optional: when the English model is not
    installed, ``self.nlp`` is None and keyword extraction falls back to a
    plain stop-word filter.
    """

    # Typographic quotes mapped to their ASCII equivalents.
    _QUOTE_TABLE = str.maketrans({
        '\u2018': "'",
        '\u2019': "'",
        '\u201c': '"',
        '\u201d': '"',
    })

    # Number words rewritten as digits by _normalize_entities.
    _NUMBER_WORDS = {
        'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
        'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10,
    }

    def __init__(self, config: "SystemConfig"):
        """Load the tokenizer, the optional spaCy model and the intent patterns."""
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

        # spaCy raises OSError when the model package is not installed.
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.nlp = None
            print("警告: spaCy英文模型未安装,某些功能可能受限")

        # Regex per intent; matched case-insensitively in _analyze_intent.
        self.query_patterns = {
            'aggregation': r'\b(count|sum|avg|average|max|maximum|min|minimum|total)\b',
            'comparison': r'\b(greater|less|more|fewer|higher|lower|above|below|between)\b',
            'temporal': r'\b(year|month|day|date|time|before|after|during|since)\b',
            'sorting': r'\b(sort|order|rank|top|bottom|first|last|highest|lowest)\b'
        }

    def preprocess(self, question: str) -> str:
        """Return *question* cleaned, normalised and prefixed with intent tags."""
        processed = self._basic_cleaning(question)
        processed = self._normalize_entities(processed)
        intent = self._analyze_intent(processed)
        return self._add_context(processed, intent)

    def _basic_cleaning(self, text: str) -> str:
        """Collapse whitespace, normalise quotes and trim outer punctuation."""
        text = re.sub(r'\s+', ' ', text.strip())

        # Curly quotes are written as escapes: pasted literally they are
        # easily flattened to ASCII quotes, which breaks the source file.
        text = text.translate(self._QUOTE_TABLE)

        # Only leading/trailing punctuation is removed; inner marks are kept.
        return text.strip(string.punctuation)

    def _normalize_entities(self, text: str) -> str:
        """Rewrite number words as digits and unify relative-year phrases.

        The rewriting is purely regex based, so it runs whether or not the
        spaCy model is available.
        """
        # The match is case-insensitive, so lower-case it before the lookup
        # ("Two" must map to 2 instead of raising KeyError).
        text = re.sub(
            r'\b(one|two|three|four|five|six|seven|eight|nine|ten)\b',
            lambda m: str(self._NUMBER_WORDS[m.group().lower()]),
            text,
            flags=re.IGNORECASE,
        )

        text = re.sub(r'\b(last year|previous year)\b', 'year before', text, flags=re.IGNORECASE)
        text = re.sub(r'\b(this year|current year)\b', 'current year', text, flags=re.IGNORECASE)

        return text

    def _analyze_intent(self, text: str) -> Dict[str, bool]:
        """Return, per pattern name, whether its regex occurs in *text*."""
        return {
            pattern_name: bool(re.search(pattern, text, re.IGNORECASE))
            for pattern_name, pattern in self.query_patterns.items()
        }

    def _add_context(self, text: str, intent: Dict[str, bool]) -> str:
        """Prefix *text* with one ``[TAG]`` per detected intent.

        Tags are prepended in the order aggregation, comparison, temporal,
        sorting, so the last detected intent ends up left-most.
        """
        for intent_name, tag in (
            ('aggregation', '[AGGREGATION]'),
            ('comparison', '[COMPARISON]'),
            ('temporal', '[TEMPORAL]'),
            ('sorting', '[SORTING]'),
        ):
            if intent.get(intent_name):
                text = f"{tag} {text}"

        return text

    def extract_keywords(self, text: str) -> List[str]:
        """Return the content words of *text*.

        With spaCy: lower-cased lemmas of tokens that are not stop words,
        punctuation or whitespace and are longer than two characters.
        Without spaCy: lower-cased whitespace-separated words longer than two
        characters that are not in a small built-in stop-word list.
        """
        if not self.nlp:
            words = text.lower().split()
            stopwords = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
            return [word for word in words if word not in stopwords and len(word) > 2]

        doc = self.nlp(text)
        return [
            token.lemma_.lower()
            for token in doc
            if not token.is_stop
            and not token.is_punct
            and not token.is_space
            and len(token.text) > 2
        ]

class ModelManager:
    """Loads the seq2seq model and turns questions into SQL text.

    NOTE(review): the model and tokenizer are loaded from
    ``config.model_name``; ``config.model_path`` is not consulted here.
    """

    # Keywords that _postprocess_sql rewrites in upper case.
    _SQL_KEYWORDS = (
        'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY',
        'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN', 'OUTER JOIN',
        'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'COUNT', 'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT',
    )
    # Longest alternatives first so "LEFT JOIN" is tried before "JOIN".
    _KEYWORD_RE = re.compile(
        r'\b(?:' + '|'.join(sorted(_SQL_KEYWORDS, key=len, reverse=True)) + r')\b',
        re.IGNORECASE,
    )
    # Single- or double-quoted SQL tokens, with doubled quotes as escapes.
    _QUOTED_RE = re.compile(r"""('(?:[^']|'')*'|"(?:[^"]|"")*")""")

    def __init__(self, config: "SystemConfig"):
        """Remember *config* and load the model immediately.

        Raises:
            Exception: whatever model loading raises (re-raised by
                ``_load_model`` after printing a message).
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.device = config.device

        self._load_model()

    def _load_model(self):
        """Load tokenizer and model, move the model to the device, set eval mode.

        When a CUDA device is requested but CUDA is unavailable, the model
        stays on the CPU and ``self.device`` is updated to ``'cpu'`` so that
        ``get_status`` reports where the model actually runs.
        """
        try:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            import torch

            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.config.model_name)

            if str(self.device).startswith('cuda') and not torch.cuda.is_available():
                self.device = 'cpu'
            # .to() also honours explicit devices such as "cuda:1".
            self.model = self.model.to(self.device)

            self.model.eval()

            print(f"模型 {self.config.model_name} 加载成功")

        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    def generate_sql(self, question: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Generate SQL for *question* with beam search.

        Args:
            question: the (preprocessed) question text.
            db_id: optional database identifier appended to the model input.

        Returns:
            On success a dict with ``sql`` (post-processed), ``raw_sql``,
            ``confidence`` and ``input_text``.  On any exception a dict with
            ``sql`` set to None, ``error`` and ``input_text``.
        """
        # Built before the try block so the error branch can always report it.
        if db_id:
            input_text = f"question: {question} database: {db_id}"
        else:
            input_text = f"question: {question}"

        try:
            import torch

            inputs = self.tokenizer(
                input_text,
                max_length=self.config.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )

            # Inputs must live on the same device as the model.
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.config.max_length,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False
                )

            generated_sql = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            )

            processed_sql = self._postprocess_sql(generated_sql)

            return {
                'sql': processed_sql,
                'raw_sql': generated_sql,
                'confidence': self._calculate_confidence(outputs),
                'input_text': input_text
            }

        except Exception as e:
            return {
                'sql': None,
                'error': str(e),
                'input_text': input_text
            }

    def _postprocess_sql(self, sql: str) -> str:
        """Tidy generated SQL: collapse whitespace, upper-case keywords, add ';'.

        Quoted tokens (string literals and quoted identifiers) are left
        untouched, so a value such as ``'min or max'`` keeps its case and
        spacing.
        """
        # re.split with one capture group yields: outside, quoted, outside, ...
        parts = self._QUOTED_RE.split(sql.strip())
        for index in range(0, len(parts), 2):
            chunk = re.sub(r'\s+', ' ', parts[index])
            parts[index] = self._KEYWORD_RE.sub(lambda m: m.group(0).upper(), chunk)
        sql = ''.join(parts)

        if not sql.endswith(';'):
            sql += ';'

        return sql

    def _calculate_confidence(self, outputs) -> float:
        """Return a confidence score for the generation.

        Placeholder: always 0.8; *outputs* is not inspected.
        """
        return 0.8

    def get_status(self) -> Dict[str, Any]:
        """Report whether the model is loaded, its name, device and vocab size."""
        return {
            'model_loaded': self.model is not None,
            'model_name': self.config.model_name,
            'device': self.device,
            'tokenizer_vocab_size': len(self.tokenizer) if self.tokenizer else 0
        }

class DatabaseManager:
    """Opens SQLite connections per database id and executes SQL on them.

    Connections are created lazily and kept in ``self.connections``; schemas
    are loaded once from ``config.schema_path``.

    NOTE(review): ``execute_sql`` runs model-generated SQL verbatim, including
    statements that modify data.  Consider a read-only connection if the
    system is meant to answer questions only.
    """

    def __init__(self, config: "SystemConfig"):
        """Remember *config* and load the schema catalogue (best effort)."""
        self.config = config
        self.connections = {}
        self.schemas = {}

        self._load_schemas()

    def _load_schemas(self):
        """Load the schema JSON file if it exists.

        Best effort: a missing file is ignored; an unreadable or malformed
        file prints a message and leaves ``self.schemas`` unchanged.
        """
        schema_path = self.config.schema_path
        if not os.path.exists(schema_path):
            return

        try:
            with open(schema_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            # get_schema/get_status treat the catalogue as a mapping.
            if not isinstance(loaded, dict):
                raise ValueError("schema file must contain a JSON object")
        except (OSError, ValueError) as e:
            print(f"Schema加载失败: {e}")
            return

        self.schemas = loaded
        print(f"加载了 {len(self.schemas)} 个数据库Schema")

    def get_connection(self, db_id: Optional[str] = None):
        """Return the cached connection for *db_id*, opening it on first use.

        ``None`` and ``'default'`` map to ``config.db_path``; any other id
        maps to ``./data/<db_id>.db``.

        Raises:
            ValueError: for an unsupported ``config.db_type`` or a *db_id*
                containing anything but letters, digits, ``_`` and ``-``.
        """
        if db_id is None:
            db_id = 'default'

        if db_id not in self.connections:
            if self.config.db_type != 'sqlite':
                raise ValueError(f"不支持的数据库类型: {self.config.db_type}")

            import sqlite3

            if db_id == 'default':
                db_path = self.config.db_path
            else:
                # db_id can come straight from a request and becomes part of
                # a file path; reject separators and dots so it cannot point
                # outside the data directory.
                if not isinstance(db_id, str) or not re.fullmatch(r'[A-Za-z0-9_\-]+', db_id):
                    raise ValueError(f"无效的数据库标识: {db_id}")
                db_path = f"./data/{db_id}.db"

            self.connections[db_id] = sqlite3.connect(db_path)

        return self.connections[db_id]

    def execute_sql(self, sql: str, db_id: Optional[str] = None) -> Dict[str, Any]:
        """Execute *sql* and return a result dict; exceptions are not raised.

        Returns:
            For statements that produce a result set: ``success``, ``columns``,
            ``rows`` and ``row_count``.  For other statements: ``success`` and
            ``affected_rows`` (the change is committed).  On any error:
            ``success`` False, ``error`` and ``sql``.
        """
        cursor = None
        try:
            if not isinstance(sql, str) or not sql.strip():
                raise ValueError("SQL语句为空")

            conn = self.get_connection(db_id)
            cursor = conn.cursor()
            cursor.execute(sql)

            # description is None exactly when the statement returns no rows,
            # which also classifies WITH ... SELECT and PRAGMA correctly.
            if cursor.description is not None:
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
                # A row-returning statement may still have written data
                # (e.g. INSERT ... RETURNING); do not leave that uncommitted.
                if conn.in_transaction:
                    conn.commit()

                return {
                    'success': True,
                    'columns': columns,
                    'rows': rows,
                    'row_count': len(rows)
                }

            conn.commit()
            return {
                'success': True,
                'affected_rows': cursor.rowcount
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'sql': sql
            }
        finally:
            if cursor is not None:
                cursor.close()

    def get_schema(self, db_id: str) -> Dict[str, Any]:
        """Return the schema stored for *db_id*, or an empty dict."""
        return self.schemas.get(db_id, {})

    def get_status(self) -> Dict[str, Any]:
        """Report the database type, open connection count and known schemas."""
        return {
            'db_type': self.config.db_type,
            'active_connections': len(self.connections),
            'available_schemas': list(self.schemas.keys())
        }

    def close(self):
        """Close every cached connection and forget it."""
        for conn in self.connections.values():
            conn.close()
        self.connections.clear()

class CacheManager:
    """In-memory cache with a size limit (LRU eviction) and a per-entry TTL.

    The TTL is measured from the moment a value is stored: reading an entry
    keeps it from being evicted early but does not extend its lifetime, so a
    frequently requested result still expires after ``cache_ttl`` seconds.
    """

    def __init__(self, config: "SystemConfig"):
        """Create an empty cache sized by ``config.cache_size`` / ``cache_ttl``."""
        self.config = config
        self.cache = {}          # key -> value; dict order is LRU order, oldest first
        self.access_times = {}   # key -> monotonic time of the last read or write
        self._stored_at = {}     # key -> monotonic time the value was stored (drives TTL)
        self.max_size = config.cache_size
        self.ttl = config.cache_ttl
        self._hits = 0
        self._misses = 0

    def _remove(self, key: str):
        """Drop *key* from every internal table."""
        self.cache.pop(key, None)
        self.access_times.pop(key, None)
        self._stored_at.pop(key, None)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None when absent or expired."""
        import time

        if key not in self.cache:
            self._misses += 1
            return None

        now = time.monotonic()
        if now - self._stored_at.get(key, now) >= self.ttl:
            # Expired: drop it and report a miss.
            self._remove(key)
            self._misses += 1
            return None

        # Re-insert so the key becomes the most recently used entry.
        value = self.cache.pop(key)
        self.cache[key] = value
        self.access_times[key] = now
        self._hits += 1
        return value

    def set(self, key: str, value: Any):
        """Store *value* under *key*, evicting the LRU entry when full."""
        import time

        if self.max_size <= 0:
            # A non-positive size means nothing can be retained.
            return

        if key in self.cache:
            # Overwriting must not evict an unrelated entry.
            self._remove(key)
        elif len(self.cache) >= self.max_size:
            self._evict_oldest()

        now = time.monotonic()
        self.cache[key] = value
        self.access_times[key] = now
        self._stored_at[key] = now

    def _evict_oldest(self):
        """Remove the least recently used entry, if any."""
        if self.cache:
            self._remove(next(iter(self.cache)))

    def get_status(self) -> Dict[str, Any]:
        """Report the current size, the size limit and the observed hit rate."""
        return {
            'cache_size': len(self.cache),
            'max_size': self.max_size,
            'hit_rate': self._calculate_hit_rate()
        }

    def _calculate_hit_rate(self) -> float:
        """Return hits / lookups since creation, or 0.0 before any lookup."""
        lookups = self._hits + self._misses
        if lookups == 0:
            return 0.0
        return self._hits / lookups

class ResultFormatter:
    """Combines the generation and execution results into one response dict."""

    def __init__(self, config: "SystemConfig"):
        """Keep a reference to the system configuration."""
        self.config = config

    def format_result(self, question: str, sql_result: Dict,
                     execution_result: Dict) -> Dict[str, Any]:
        """Build the response for one processed question.

        Args:
            question: the original question text.
            sql_result: output of SQL generation (``sql``, ``confidence``,
                ``raw_sql`` are read).
            execution_result: output of SQL execution (``success`` plus either
                ``columns``/``rows``/``row_count`` or ``error``).

        Returns:
            A dict with ``question``, ``timestamp``, ``sql``, ``success`` and
            ``sql_generation``; on success also ``columns``, ``rows``,
            ``row_count`` and ``formatted_table``; otherwise ``error`` and
            ``error_type``.
        """
        formatted = {
            'question': question,
            'timestamp': self._get_timestamp(),
            'sql': sql_result.get('sql'),
            'success': execution_result.get('success', False)
        }

        if execution_result.get('success'):
            columns = execution_result.get('columns', [])
            rows = execution_result.get('rows', [])
            formatted.update({
                'columns': columns,
                'rows': rows,
                'row_count': execution_result.get('row_count', 0),
                'formatted_table': self._format_table(columns, rows)
            })
        else:
            error = execution_result.get('error')
            formatted.update({
                'error': error,
                'error_type': self._classify_error(error)
            })

        formatted['sql_generation'] = {
            'confidence': sql_result.get('confidence'),
            'raw_sql': sql_result.get('raw_sql')
        }

        return formatted

    def _get_timestamp(self) -> str:
        """Return the current local time in ISO 8601 format."""
        from datetime import datetime
        return datetime.now().isoformat()

    def _format_table(self, columns: List[str], rows: List[Tuple]) -> str:
        """Render *columns* and *rows* as a fixed-width text table.

        Returns ``"No data found."`` when there are no columns or no rows.
        Cells beyond the number of columns are ignored.
        """
        if not columns or not rows:
            return "No data found."

        headers = [str(col) for col in columns]
        body = [[str(cell) for cell in row] for row in rows]

        # Each column is as wide as its longest header or cell.
        col_widths = [len(header) for header in headers]
        for cells in body:
            for i, cell in enumerate(cells[:len(col_widths)]):
                col_widths[i] = max(col_widths[i], len(cell))

        header_line = " | ".join(h.ljust(w) for h, w in zip(headers, col_widths))
        table_lines = [header_line, "-" * len(header_line)]
        table_lines.extend(
            " | ".join(cell.ljust(w) for cell, w in zip(cells, col_widths))
            for cells in body
        )

        return "\n".join(table_lines)

    def _classify_error(self, error_msg: str) -> str:
        """Map a database error message to a coarse error category.

        A missing (None) or empty message is classified as ``unknown_error``.
        """
        error_msg = str(error_msg or '').lower()

        categories = (
            ('syntax error', 'syntax_error'),
            ('no such table', 'table_not_found'),
            ('no such column', 'column_not_found'),
            ('ambiguous column', 'ambiguous_column'),
        )
        for needle, category in categories:
            if needle in error_msg:
                return category
        return 'unknown_error'

# Runs at import time: prints a notice that the core component classes have been defined.
print("核心组件实现完成")

## 9.3 系统集成与接口设计

### 9.3.1 Web接口实现

```python
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import logging
from typing import Dict, Any

class Text2SQLWebAPI:
    """Flask application exposing a Text2SQL system over HTTP.

    Routes:
        GET  /                    - renders ``index.html``
        POST /api/query           - JSON body ``{"question": ..., "db_id": ...}``
        GET  /api/status          - system status
        GET  /api/schemas         - names of the available schemas
        GET  /api/schema/<db_id>  - schema of one database
    """

    def __init__(self, system: Text2SQLSystem):
        """Create the Flask app around *system*, enable CORS, register routes."""
        self.system = system
        self.app = Flask(__name__)
        CORS(self.app)

        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        self._setup_routes()

    def _setup_routes(self):
        """Register every HTTP route on ``self.app``."""

        @self.app.route('/', methods=['GET'])
        def index():
            """Serve the landing page."""
            return render_template('index.html')

        @self.app.route('/api/query', methods=['POST'])
        def query():
            """Run one question through the system.

            Responds 400 for a missing/invalid body, 500 for unexpected
            errors.  The outer ``success`` flag only says the request was
            handled; the outcome of the query itself is in ``result``.
            """
            try:
                # silent=True returns None for a malformed or non-JSON body
                # instead of raising, so the client gets the 400 below rather
                # than a 500 from the generic handler.
                data = request.get_json(silent=True)

                if not isinstance(data, dict) or 'question' not in data:
                    return jsonify({
                        'success': False,
                        'error': 'Missing question parameter'
                    }), 400

                question = data['question']
                if not isinstance(question, str) or not question.strip():
                    return jsonify({
                        'success': False,
                        'error': 'Question must be a non-empty string'
                    }), 400

                db_id = data.get('db_id')

                result = self.system.process_query(question, db_id)

                return jsonify({
                    'success': True,
                    'result': result
                })

            except Exception as e:
                self.logger.error(f"查询处理错误: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/status', methods=['GET'])
        def status():
            """Return the system status."""
            try:
                # The system exposes its status as get_system_status().
                system_status = self.system.get_system_status()
                return jsonify({
                    'success': True,
                    'status': system_status
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/schemas', methods=['GET'])
        def get_schemas():
            """Return the names of the available database schemas."""
            try:
                # The system stores its database manager as database_manager.
                schemas = self.system.database_manager.get_status()['available_schemas']
                return jsonify({
                    'success': True,
                    'schemas': schemas
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

        @self.app.route('/api/schema/<db_id>', methods=['GET'])
        def get_schema(db_id):
            """Return the schema of the database named *db_id*."""
            try:
                schema = self.system.database_manager.get_schema(db_id)
                return jsonify({
                    'success': True,
                    'schema': schema
                })
            except Exception as e:
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500

    def run(self, host='0.0.0.0', port=5000, debug=False):
        """Start the Flask development server (blocks until it stops)."""
        self.logger.info(f"启动Text2SQL Web API服务: http://{host}:{port}")
        self.app.run(host=host, port=port, debug=debug)

class CommandLineInterface:
    """Interactive prompt for querying a Text2SQL system.

    Commands: ``help``, ``status``, ``use <db>`` and ``quit``; any other
    non-empty input is treated as a question.
    """

    def __init__(self, system: "Text2SQLSystem"):
        """Bind the interface to *system*; no database is selected initially."""
        self.system = system
        # Database chosen with "use <db>"; None means the system default.
        self.current_db = None

    def run(self):
        """Read commands and questions until the user quits.

        The loop ends on ``quit``, Ctrl-C, or end of input (Ctrl-D / closed
        stdin).  Other errors are printed and the loop continues.
        """
        print("=" * 50)
        print("Text2SQL 系统命令行界面")
        print("=" * 50)
        print("输入 'help' 查看帮助信息")
        print("输入 'quit' 退出系统")
        print()

        while True:
            try:
                user_input = input("请输入自然语言查询: ").strip()
                command = user_input.lower()

                if command == 'quit':
                    print("再见!")
                    break
                elif command == 'help':
                    self._show_help()
                    continue
                elif command == 'status':
                    self._show_status()
                    continue
                elif command.startswith('use '):
                    db_id = user_input[4:].strip()
                    self._set_database(db_id)
                    continue
                elif not user_input:
                    continue

                # Questions run against the database selected via "use".
                result = self.system.process_query(user_input, self.current_db)
                self._display_result(result)

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop as well: once stdin is exhausted
                # every further input() call raises it again.
                print("\n再见!")
                break
            except Exception as e:
                print(f"错误: {e}")

    def _show_help(self):
        """Print the list of commands and some example questions."""
        help_text = """
可用命令:
  help     - 显示此帮助信息
  status   - 显示系统状态
  use <db> - 切换数据库
  quit     - 退出系统

示例查询:
  "显示所有用户"
  "查找年龄大于25的用户"
  "统计每个部门的员工数量"
        """
        print(help_text)

    def _show_status(self):
        """Print each entry of the system status."""
        # The system exposes its status as get_system_status().
        status = self.system.get_system_status()
        print("\n系统状态:")
        for key, value in status.items():
            print(f"  {key}: {value}")
        print()

    def _set_database(self, db_id: str):
        """Select *db_id* for subsequent questions if a schema exists for it.

        The current selection is kept when the database is unknown or the
        lookup fails.
        """
        try:
            # The system stores its database manager as database_manager.
            schema = self.system.database_manager.get_schema(db_id)
            if schema:
                self.current_db = db_id
                print(f"已切换到数据库: {db_id}")
            else:
                print(f"数据库 {db_id} 不存在")
        except Exception as e:
            print(f"切换数据库失败: {e}")

    def _display_result(self, result: Dict[str, Any]):
        """Print the question, the generated SQL and the outcome.

        Keys are read with ``get`` because failure results do not carry
        every key a successful result has.
        """
        print("\n" + "="*50)
        print(f"问题: {result.get('question')}")
        print(f"生成的SQL: {result.get('sql')}")

        if result.get('success'):
            print(f"执行成功,返回 {result.get('row_count', 0)} 行数据")
            if result.get('formatted_table'):
                print("\n查询结果:")
                print(result['formatted_table'])
        else:
            print(f"执行失败: {result.get('error', '未知错误')}")

        print("="*50 + "\n")

```

## 9.4 完整使用示例

### 9.4.1 基本使用示例

```python
def basic_usage_example():
    """Build a system with a small demo model and run sample questions.

    For every question the generated SQL and the row count are printed on
    success, or the error message on failure.
    """
    # Small model and example database/schema files, for demonstration only.
    system = Text2SQLSystem(
        SystemConfig(
            model_name="t5-small",
            db_path="./data/example.db",
            schema_path="./data/schemas.json",
        )
    )

    sample_questions = (
        "显示所有用户的姓名和邮箱",
        "查找年龄大于25岁的用户",
        "统计每个部门的员工数量",
        "找出薪资最高的前5名员工",
    )

    for current in sample_questions:
        print(f"\n处理问题: {current}")
        outcome = system.process_query(current)

        # Report failures first, then fall through to the success output.
        if not outcome['success']:
            print(f"处理失败: {outcome.get('error')}")
            continue

        print(f"生成SQL: {outcome['sql']}")
        print(f"返回 {outcome['row_count']} 行数据")

def web_api_example():
    """Serve a default-configured system over HTTP on localhost:5000.

    The server is started in debug mode and blocks until it is stopped.
    """
    # Default configuration -> system -> web API wrapper.
    api = Text2SQLWebAPI(Text2SQLSystem(SystemConfig()))
    api.run(host='localhost', port=5000, debug=True)

def cli_example():
    """Start the interactive command-line interface on a default system."""
    # Default configuration -> system -> CLI; run() blocks until the user quits.
    CommandLineInterface(Text2SQLSystem(SystemConfig())).run()

```

### 9.4.2 高级使用示例

```python
class AdvancedText2SQLDemo:
    """Demonstrates batch processing, timing and error handling."""

    def __init__(self):
        """Build a system with a larger model, on CUDA when available."""
        # torch is not imported at module level in this file, so import it
        # here where the device check needs it.
        import torch

        self.config = SystemConfig(
            model_name="facebook/bart-large",
            device="cuda" if torch.cuda.is_available() else "cpu",
            cache_size=1000,
            cache_ttl=3600
        )
        self.system = Text2SQLSystem(self.config)

    def batch_processing_demo(self):
        """Process a fixed list of questions and print a summary report."""
        questions = [
            "显示所有产品的名称和价格",
            "查找库存少于10的产品",
            "统计每个类别的产品数量",
            "找出销量最好的产品",
            "显示最近一个月的订单"
        ]

        results = [self.system.process_query(question) for question in questions]

        self._generate_report(results)

    def _generate_report(self, results: List[Dict]):
        """Print totals, the success rate and one line per result.

        An empty *results* list is reported with a success rate of 0.0%.
        """
        print("\n" + "="*60)
        print("批处理结果报告")
        print("="*60)

        success_count = sum(1 for r in results if r['success'])
        total_count = len(results)
        # Guard the division: an empty batch has no meaningful rate.
        success_rate = success_count / total_count * 100 if total_count else 0.0

        print(f"总查询数: {total_count}")
        print(f"成功数: {success_count}")
        print(f"成功率: {success_rate:.1f}%")

        print("\n详细结果:")
        for i, result in enumerate(results, 1):
            status = "✓" if result['success'] else "✗"
            print(f"{i}. {status} {result['question']}")
            if result['success']:
                print(f"   SQL: {result['sql']}")
                print(f"   结果: {result['row_count']} 行")
            else:
                print(f"   错误: {result.get('error')}")

    def performance_test(self):
        """Time 30 queries (3 questions x 10) and print throughput figures."""
        import time

        test_questions = [
            "显示所有用户",
            "查找活跃用户",
            "统计用户数量"
        ] * 10  # repeat the same questions

        # perf_counter is monotonic and has the best available resolution.
        start_time = time.perf_counter()

        for question in test_questions:
            self.system.process_query(question)

        total_time = time.perf_counter() - start_time
        avg_time = total_time / len(test_questions)
        # Very fast runs can measure as zero elapsed time.
        qps = len(test_questions) / total_time if total_time > 0 else float('inf')

        print(f"\n性能测试结果:")
        print(f"总查询数: {len(test_questions)}")
        print(f"总耗时: {total_time:.2f}秒")
        print(f"平均耗时: {avg_time:.3f}秒/查询")
        print(f"QPS: {qps:.1f}")

    def error_handling_demo(self):
        """Run questions expected to fail and print how each failure is reported."""
        error_questions = [
            "显示不存在的表",
            "查询语法错误的问题",
            "使用不存在的列名",
            "复杂的嵌套查询"
        ]

        print("\n错误处理演示:")
        for question in error_questions:
            result = self.system.process_query(question)

            print(f"\n问题: {question}")
            if result['success']:
                print("意外成功")
            else:
                error_type = result.get('error_type', 'unknown')
                print(f"错误类型: {error_type}")
                print(f"错误信息: {result.get('error')}")

```

## 9.5 部署与运维

### 9.5.1 Docker部署

```dockerfile
# Dockerfile
# Builds the image that runs the Text2SQL web service.
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies: C/C++ compilers (presumably needed to build
# Python packages with native extensions - confirm against requirements.txt).
# The apt package lists are removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency manifest on its own, before the application code, so the
# pip install layer below is rebuilt only when requirements.txt changes.
COPY requirements.txt .

# Install Python dependencies without keeping pip's download cache.
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code.
COPY . .

# Port the service listens on.
EXPOSE 5000

# Start command. NOTE(review): app.py is not part of this chapter's listings.
CMD ["python", "app.py"]
# docker-compose.yml
# Runs the Text2SQL service together with an nginx container in front of it.
version: '3.8'

services:
  # The Text2SQL application, built from the Dockerfile in this directory.
  text2sql:
    build: .
    ports:
      - "5000:5000"
    # Database files and model weights are mounted from the host.
    volumes:
      - ./data:/app/data
      - ./models:/app/models
    # NOTE(review): the Python code shown in this chapter does not read these
    # variables; confirm that app.py maps them onto the system configuration.
    environment:
      - MODEL_NAME=t5-base
      - DB_PATH=/app/data/database.db
      - SCHEMA_PATH=/app/data/schemas.json
    restart: unless-stopped
  
  # nginx on port 80, configured by the mounted nginx.conf (not shown here).
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - text2sql
    restart: unless-stopped

```

### 9.5.2 监控与日志

class SystemMonitor:
    """Collects simple in-process metrics about the queries a system handled."""

    def __init__(self, system: "Text2SQLSystem"):
        """Bind the monitor to *system* and start every counter at zero."""
        self.system = system
        self.metrics = {
            'total_queries': 0,
            'successful_queries': 0,
            'failed_queries': 0,
            'avg_response_time': 0.0,
            'cache_hit_rate': 0.0
        }

    def log_query(self, question: str, result: Dict, response_time: float):
        """Record the outcome and duration of one query.

        Args:
            question: the question that was processed.
            result: the dict returned for it; its ``success`` entry is read.
            response_time: processing time in seconds.
        """
        succeeded = bool(result.get('success'))

        self.metrics['total_queries'] += 1
        if succeeded:
            self.metrics['successful_queries'] += 1
        else:
            self.metrics['failed_queries'] += 1

        # Running mean: rebuild the previous total from the old average, add
        # the new sample, divide by the new count.
        count = self.metrics['total_queries']
        previous_total = self.metrics['avg_response_time'] * (count - 1)
        self.metrics['avg_response_time'] = (previous_total + response_time) / count

        logging.info(f"Query: {question}, Success: {succeeded}, Time: {response_time:.3f}s")

    def get_metrics(self) -> Dict[str, Any]:
        """Return the counters plus the success rate and the system status.

        ``cache_hit_rate`` is refreshed from the system's cache status when
        that status is available.
        """
        success_rate = 0.0
        if self.metrics['total_queries'] > 0:
            success_rate = self.metrics['successful_queries'] / self.metrics['total_queries']

        # The system exposes its status as get_system_status().
        system_status = self.system.get_system_status()

        cache_status = system_status.get('cache_status') if isinstance(system_status, dict) else None
        if isinstance(cache_status, dict) and 'hit_rate' in cache_status:
            self.metrics['cache_hit_rate'] = cache_status['hit_rate']

        return {
            **self.metrics,
            'success_rate': success_rate,
            'system_status': system_status
        }

## 9.6 本章总结

在本章中,我们构建了一个完整的Text2SQL系统,包括:

主要成果

  1. 系统架构设计

    • 模块化的组件设计
    • 清晰的接口定义
    • 可扩展的架构模式
  2. 核心组件实现

    • 查询处理器:自然语言预处理和意图分析
    • 模型管理器:模型加载和SQL生成
    • 数据库管理器:数据库连接和查询执行
    • 缓存管理器:查询结果缓存
    • 结果格式化器:结果展示和错误处理
  3. 接口设计

    • Web API接口:支持HTTP请求
    • 命令行接口:交互式查询
    • 批处理接口:大量查询处理
  4. 部署方案

    • Docker容器化部署
    • 负载均衡配置
    • 监控和日志系统

技术特点

  • 高性能:缓存机制和批处理优化
  • 高可用:错误处理和容错机制
  • 可扩展:模块化设计和插件架构
  • 易维护:完善的日志和监控

应用场景

  • 企业数据查询系统
  • 商业智能平台
  • 数据分析工具
  • 教育培训系统

通过本章的学习,你已经掌握了构建完整Text2SQL系统的方法,可以根据实际需求进行定制和扩展。