8.1 评估指标体系

8.1.1 Text2SQL评估指标概述

Text2SQL模型的评估需要从多个维度进行,包括SQL语法正确性、语义准确性、执行结果正确性等。本节将详细介绍各种评估指标的定义、计算方法和适用场景。

import difflib
import json
import re
import sqlite3
import time
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import sqlparse
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

class Text2SQLEvaluator:
    """Evaluate Text2SQL predictions against ground-truth SQL.

    Computes normalized exact match, execution accuracy (when a SQLite
    database path is supplied), per-clause component match, and heuristic
    syntax / semantic / execution error rates.

    Config keys (all optional):
        case_sensitive (default False): if False, exact match ignores case.
        order_sensitive (default False): if False, execution results are
            compared as multisets of rows (row order ignored, duplicate rows
            still counted); if True, as ordered lists.
        value_sensitive (default True): if False, numeric and single-quoted
            string literals are replaced by the placeholder ``value`` before
            exact and component matching, so only query structure is compared.
    """

    # Clauses scored by component_match_score, in reporting order.
    COMPONENTS = ('select', 'from', 'where', 'group_by',
                  'having', 'order_by', 'limit')

    # Lexer used for component parsing. Quoted strings/identifiers are kept as
    # single tokens so keywords, commas and parentheses inside them are not
    # mistaken for SQL structure.
    _TOKEN_RE = re.compile(r"""
          '(?:[^']|'')*'              # single-quoted string literal
        | "(?:[^"]|"")*"              # double-quoted identifier
        | `[^`]*`                     # backtick-quoted identifier
        | \[[^\]]*\]                  # bracket-quoted identifier
        | \d+(?:\.\d+)?               # numeric literal
        | \w+(?:\.(?:\w+|\*))*        # keyword, identifier or qualified name
        | <>|!=|>=|<=|==|\|\|         # multi-character operators
        | \S                          # any other single character
    """, re.VERBOSE)

    # Keywords that open a clause when they appear outside parentheses.
    # GROUP / ORDER only count when the next token is BY.
    _ONE_WORD_CLAUSES = frozenset({'select', 'from', 'where', 'having', 'limit'})
    _TWO_WORD_CLAUSES = {'group': 'group_by', 'order': 'order_by'}
    # Component parsing stops at the first top-level set operator, i.e. only
    # the first SELECT of a compound query is decomposed.
    _SET_OPERATORS = frozenset({'union', 'intersect', 'except'})

    # Literals masked when value_sensitive is False.
    _VALUE_RE = re.compile(r"'(?:[^']|'')*'|\b\d+(?:\.\d+)?\b")

    # Keywords for the semantic-error heuristic, matched as whole words;
    # multi-word keywords accept any whitespace between their words.
    _KEYWORD_PATTERNS = tuple(
        (kw, re.compile(r'\b' + r'\s+'.join(kw.split()) + r'\b'))
        for kw in ('SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING',
                   'ORDER BY', 'LIMIT', 'JOIN', 'INNER', 'LEFT', 'RIGHT')
    )

    # Fragments of sqlite3.OperationalError messages produced by SQLite's
    # parser (as opposed to name-resolution errors such as "no such table").
    _SYNTAX_ERROR_MARKERS = ('syntax error', 'incomplete input',
                             'unrecognized token')

    def __init__(self, config: Dict):
        """Store *config* and read the comparison switches from it."""
        self.config = config
        # Open SQLite connections keyed by database path; released by close().
        self.db_connections = {}

        # Comparison switches (see class docstring)
        self.case_sensitive = config.get('case_sensitive', False)
        self.order_sensitive = config.get('order_sensitive', False)
        self.value_sensitive = config.get('value_sensitive', True)

        self.reset_metrics()

    def reset_metrics(self):
        """Clear all per-example metric accumulators."""
        self.metrics = {
            'exact_match': [],
            'execution_accuracy': [],
            'component_match': defaultdict(list),
            'syntax_errors': [],
            'semantic_errors': [],
            'execution_errors': []
        }

    def close(self):
        """Close every SQLite connection cached by execute_sql."""
        for conn in self.db_connections.values():
            conn.close()
        self.db_connections.clear()

    def evaluate_batch(self, predictions: List[str],
                       ground_truths: List[str],
                       databases: Optional[List[str]] = None) -> Dict[str, float]:
        """Evaluate aligned lists of predicted and gold SQL.

        Args:
            predictions: predicted SQL strings.
            ground_truths: gold SQL strings, aligned with *predictions*.
            databases: optional SQLite paths aligned with *predictions*;
                execution-based metrics are computed only when provided.

        Returns:
            The aggregate metrics from compute_final_metrics(). Accumulators
            from any previous batch are discarded first.
        """
        self.reset_metrics()

        for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
            db_path = databases[i] if databases else None

            self.metrics['exact_match'].append(self.exact_match_score(pred, gt))

            if db_path:
                self.metrics['execution_accuracy'].append(
                    self.execution_accuracy_score(pred, gt, db_path))

            for component, score in self.component_match_score(pred, gt).items():
                self.metrics['component_match'][component].append(score)

            self.analyze_errors(pred, gt, db_path)

        return self.compute_final_metrics()

    def exact_match_score(self, prediction: str, ground_truth: str) -> float:
        """Return 1.0 if both queries normalize to the same string, else 0.0."""
        return float(self.normalize_sql(prediction) == self.normalize_sql(ground_truth))

    def normalize_sql(self, sql: str) -> str:
        """Canonicalize *sql* for string comparison.

        Formats with sqlparse (upper-case keywords, lower-case identifiers,
        comments stripped) and collapses whitespace. If formatting fails, only
        whitespace is collapsed. Literals are masked when value_sensitive is
        False, and everything is lower-cased when case_sensitive is False.
        """
        try:
            parsed = sqlparse.parse(sql)[0]
            formatted = sqlparse.format(
                str(parsed),
                reindent=True,
                keyword_case='upper',
                identifier_case='lower',
                strip_comments=True
            )
            normalized = ' '.join(formatted.split())
        except Exception:
            # Best effort: e.g. empty input (parse() returns no statements)
            normalized = ' '.join(sql.split())

        normalized = self._mask_values(normalized)
        if not self.case_sensitive:
            normalized = normalized.lower()
        return normalized

    def _mask_values(self, text: str) -> str:
        """Replace literals in *text* with 'value' unless value_sensitive is set."""
        if self.value_sensitive:
            return text
        return self._VALUE_RE.sub('value', text)

    def execution_accuracy_score(self, prediction: str,
                                 ground_truth: str, db_path: str) -> float:
        """Return 1.0 if both queries execute and yield equal results, else 0.0."""
        try:
            pred_result = self.execute_sql(prediction, db_path)
            if pred_result is None:
                return 0.0

            gt_result = self.execute_sql(ground_truth, db_path)
            if gt_result is None:
                return 0.0

            return self.compare_results(pred_result, gt_result)

        except Exception as e:
            print(f"执行准确性评估错误: {e}")
            return 0.0

    def execute_sql(self, sql: str, db_path: str) -> Optional[List[Tuple]]:
        """Run *sql* against the SQLite file *db_path* and return all rows.

        Connections are cached per path and opened with ``PRAGMA query_only``
        so model-generated INSERT/UPDATE/DELETE/DROP statements fail instead
        of altering the data later examples are evaluated against.

        Returns:
            The fetched rows, or None if connecting or executing failed.
        """
        try:
            conn = self.db_connections.get(db_path)
            if conn is None:
                conn = sqlite3.connect(db_path)
                conn.execute('PRAGMA query_only = ON')
                self.db_connections[db_path] = conn

            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                return cursor.fetchall()
            finally:
                cursor.close()

        except Exception as e:
            print(f"SQL执行错误: {e}")
            return None

    def compare_results(self, result1: List[Tuple],
                        result2: List[Tuple]) -> float:
        """Compare two result sets; 1.0 if equal, else 0.0.

        When order_sensitive is False the rows are compared as multisets:
        order is ignored but duplicate rows must occur equally often.
        """
        if not self.order_sensitive:
            return float(Counter(result1) == Counter(result2))
        return float(result1 == result2)

    def component_match_score(self, prediction: str,
                              ground_truth: str) -> Dict[str, float]:
        """Score each clause in COMPONENTS.

        A clause absent from both queries scores 1.0; present in only one,
        0.0; otherwise the Jaccard similarity from compare_components().
        On unexpected failure every clause scores 0.0.
        """
        try:
            pred_components = self.parse_sql_components(prediction)
            gt_components = self.parse_sql_components(ground_truth)

            scores = {}
            for component in self.COMPONENTS:
                pred_comp = pred_components.get(component, [])
                gt_comp = gt_components.get(component, [])

                if not gt_comp and not pred_comp:
                    scores[component] = 1.0
                elif not gt_comp or not pred_comp:
                    scores[component] = 0.0
                else:
                    scores[component] = self.compare_components(pred_comp, gt_comp)

            return scores

        except Exception as e:
            print(f"组件匹配评估错误: {e}")
            return {comp: 0.0 for comp in self.COMPONENTS}

    def parse_sql_components(self, sql: str) -> Dict[str, List[str]]:
        """Split the top-level query of *sql* into clause components.

        Clause keywords are recognized only outside parentheses and quotes,
        so sub-queries and function calls stay inside their clause.

        Returns a dict keyed by COMPONENTS:
            select / from / group_by / order_by: items split on top-level
                commas, each item's tokens joined by single spaces.
            where / having: the clause's token list, so conditions that share
                column and operator but differ in value get partial credit.
            limit: a single-item list with the clause text.
        Clauses not present map to [].
        """
        components = {name: [] for name in self.COMPONENTS}

        try:
            for name, tokens in self._split_top_level_clauses(sql).items():
                if not tokens:
                    continue
                if name in ('where', 'having'):
                    components[name] = tokens
                elif name == 'limit':
                    components[name] = [' '.join(tokens)]
                else:
                    components[name] = self._split_top_level_commas(tokens)

        except Exception as e:
            print(f"SQL组件解析错误: {e}")

        return components

    def _split_top_level_clauses(self, sql: str) -> Dict[str, List[str]]:
        """Tokenize *sql* and group depth-0 tokens under the clause that owns them."""
        tokens = self._TOKEN_RE.findall(sql)
        clauses: Dict[str, List[str]] = {}
        current = None
        depth = 0
        i = 0
        while i < len(tokens):
            token = tokens[i]
            lowered = token.lower()
            if depth == 0:
                if lowered in self._SET_OPERATORS and current is not None:
                    break
                key = None
                if lowered in self._ONE_WORD_CLAUSES:
                    key = lowered
                elif (lowered in self._TWO_WORD_CLAUSES
                      and i + 1 < len(tokens) and tokens[i + 1].lower() == 'by'):
                    key = self._TWO_WORD_CLAUSES[lowered]
                    i += 1  # consume BY as well
                if key is not None:
                    current = key
                    clauses.setdefault(key, [])
                    i += 1
                    continue
            if token == '(':
                depth += 1
            elif token == ')':
                depth = max(depth - 1, 0)
            if current is not None:
                clauses[current].append(token)
            i += 1
        return clauses

    @staticmethod
    def _split_top_level_commas(tokens: List[str]) -> List[str]:
        """Join *tokens* into items separated by commas outside parentheses."""
        items: List[str] = []
        current: List[str] = []
        depth = 0
        for token in tokens:
            if token == '(':
                depth += 1
            elif token == ')':
                depth = max(depth - 1, 0)
            if token == ',' and depth == 0:
                if current:
                    items.append(' '.join(current))
                current = []
                continue
            current.append(token)
        if current:
            items.append(' '.join(current))
        return items

    def compare_components(self, comp1: List[str], comp2: List[str]) -> float:
        """Jaccard similarity of two component lists after normalization."""
        set1 = {self.normalize_component(c) for c in comp1}
        set2 = {self.normalize_component(c) for c in comp2}

        union = len(set1 | set2)
        return len(set1 & set2) / union if union > 0 else 0.0

    def normalize_component(self, component: str) -> str:
        """Lower-case, collapse whitespace and (optionally) mask literals."""
        normalized = ' '.join(component.strip().lower().split())
        return self._mask_values(normalized)

    def analyze_errors(self, prediction: str, ground_truth: str,
                       db_path: str = None):
        """Record syntax, semantic and (if *db_path*) execution error flags."""
        self.metrics['syntax_errors'].append(self.check_syntax_error(prediction))
        self.metrics['semantic_errors'].append(
            self.check_semantic_error(prediction, ground_truth))

        if db_path:
            self.metrics['execution_errors'].append(
                self.check_execution_error(prediction, db_path))

    def check_syntax_error(self, sql: str) -> bool:
        """Return True if SQLite's parser rejects *sql*.

        The statement is compiled with EXPLAIN against an empty in-memory
        database (nothing is executed). Name-resolution errors such as
        "no such table" are not syntax errors and return False; so does
        input that SQLite refuses for other reasons (e.g. several statements).
        Non-string input counts as a syntax error.
        """
        if not isinstance(sql, str):
            return True

        conn = sqlite3.connect(':memory:')
        try:
            conn.execute('EXPLAIN ' + sql)
        except sqlite3.OperationalError as e:
            message = str(e).lower()
            return any(marker in message for marker in self._SYNTAX_ERROR_MARKERS)
        except (sqlite3.Error, sqlite3.Warning, ValueError):
            return False
        finally:
            conn.close()
        return False

    def check_semantic_error(self, prediction: str, ground_truth: str) -> bool:
        """Heuristic: flag when the queries' keyword sets have Jaccard < 0.5.

        Two queries that contain none of the tracked keywords are treated as
        consistent (False) rather than dividing by zero.
        """
        pred_keywords = set(self.extract_keywords(prediction))
        gt_keywords = set(self.extract_keywords(ground_truth))

        union = pred_keywords | gt_keywords
        if not union:
            return False
        return len(pred_keywords & gt_keywords) / len(union) < 0.5

    def extract_keywords(self, sql: str) -> List[str]:
        """Return the tracked SQL keywords present in *sql* as whole words."""
        sql_upper = sql.upper()
        return [kw for kw, pattern in self._KEYWORD_PATTERNS
                if pattern.search(sql_upper)]

    def check_execution_error(self, sql: str, db_path: str) -> bool:
        """Return True if *sql* cannot be executed against *db_path*."""
        return self.execute_sql(sql, db_path) is None

    def compute_final_metrics(self) -> Dict[str, float]:
        """Average every non-empty accumulator into a flat metrics dict."""
        final_metrics = {}

        if self.metrics['exact_match']:
            final_metrics['exact_match_accuracy'] = np.mean(self.metrics['exact_match'])

        if self.metrics['execution_accuracy']:
            final_metrics['execution_accuracy'] = np.mean(self.metrics['execution_accuracy'])

        for component, scores in self.metrics['component_match'].items():
            if scores:
                final_metrics[f'{component}_accuracy'] = np.mean(scores)

        if self.metrics['syntax_errors']:
            final_metrics['syntax_error_rate'] = np.mean(self.metrics['syntax_errors'])

        if self.metrics['semantic_errors']:
            final_metrics['semantic_error_rate'] = np.mean(self.metrics['semantic_errors'])

        if self.metrics['execution_errors']:
            final_metrics['execution_error_rate'] = np.mean(self.metrics['execution_errors'])

        return final_metrics

    def generate_detailed_report(self, predictions: List[str],
                                 ground_truths: List[str],
                                 questions: List[str] = None) -> Dict:
        """Build a report: summary/error analysis from the last batch plus per-example results."""
        report = {
            'summary': self.compute_final_metrics(),
            'detailed_results': [],
            'error_analysis': self.analyze_error_patterns(),
            'recommendations': self.generate_recommendations()
        }

        for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
            question = questions[i] if questions else f"Question {i+1}"

            report['detailed_results'].append({
                'question': question,
                'prediction': pred,
                'ground_truth': gt,
                'exact_match': self.exact_match_score(pred, gt),
                'component_scores': self.component_match_score(pred, gt),
                'errors': {
                    'syntax': self.check_syntax_error(pred),
                    'semantic': self.check_semantic_error(pred, gt)
                }
            })

        return report

    def analyze_error_patterns(self) -> Dict:
        """Summarize error patterns; currently only the per-component error rate is filled."""
        patterns = {
            'common_syntax_errors': [],
            'common_semantic_errors': [],
            'component_error_distribution': {},
            'error_frequency': {}
        }

        for component, scores in self.metrics['component_match'].items():
            if scores:
                patterns['component_error_distribution'][component] = 1 - np.mean(scores)

        return patterns

    def generate_recommendations(self) -> List[str]:
        """Turn error rates and weak components from the last batch into advice strings."""
        recommendations = []

        if self.metrics['syntax_errors'] and np.mean(self.metrics['syntax_errors']) > 0.1:
            recommendations.append("建议加强SQL语法训练,提高语法正确性")

        if self.metrics['semantic_errors'] and np.mean(self.metrics['semantic_errors']) > 0.2:
            recommendations.append("建议改进语义理解模块,提高语义准确性")

        for component, scores in self.metrics['component_match'].items():
            if scores and np.mean(scores) < 0.7:
                recommendations.append(f"建议重点优化{component}组件的生成")

        return recommendations

# Usage example: build an evaluator from an explicit configuration
evaluator_config = dict(
    case_sensitive=False,
    order_sensitive=False,
    value_sensitive=True,
)

evaluator = Text2SQLEvaluator(evaluator_config)
print("Text2SQL评估器初始化完成")

8.1.2 Spider评估指标

Spider是Text2SQL领域最权威的评估基准之一。本节参照Spider的两项核心指标(精确匹配与执行准确率)实现一个简化版评估器:精确匹配基于SQL字符串标准化,而不是官方评估脚本中的SQL结构化解析,因此结果可能与官方评估存在差异。正式汇报结果时应使用官方评估脚本。

class SpiderEvaluator:
    """Simplified evaluator for Spider's exact-match and execution metrics.

    NOTE: this is not the official Spider evaluation script. Exact match here
    compares normalized SQL strings rather than parsed SQL structures.

    Config keys (all optional):
        db_dir: directory holding ``<db_id>/<db_id>.sqlite`` files
            (default ``./spider/database``).
        timeout: per-query execution time limit in seconds (default 5.0).
    """

    def __init__(self, config: Dict):
        self.config = config
        self.kmaps = {}  # keyword mapping (not used by the methods below)
        # db_id -> schema entry from tables.json; filled by load_spider_data.
        # Initialized here so exact match works (without schema) before loading.
        self.db_schemas = {}
        self.db_dir = config.get('db_dir', './spider/database')
        self.timeout = config.get('timeout', 5.0)

    def load_spider_data(self, data_path: str, table_path: str):
        """Load the Spider examples file and tables.json, indexing schemas by db_id."""
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        with open(table_path, 'r', encoding='utf-8') as f:
            self.tables = json.load(f)

        self.db_schemas = {table['db_id']: table for table in self.tables}

    def evaluate_spider(self, predictions: List[str],
                        gold_sqls: List[str],
                        db_ids: List[str]) -> Dict[str, float]:
        """Return mean exact-match and execution accuracy over aligned inputs."""
        exact_match_scores = []
        execution_scores = []

        for pred, gold, db_id in zip(predictions, gold_sqls, db_ids):
            exact_match_scores.append(self.spider_exact_match(pred, gold, db_id))
            execution_scores.append(self.spider_execution_match(pred, gold, db_id))

        return {
            'exact_match': np.mean(exact_match_scores),
            'execution_accuracy': np.mean(execution_scores)
        }

    def spider_exact_match(self, prediction: str, gold: str, db_id: str) -> float:
        """1.0 if both queries normalize to the same string under *db_id*'s schema."""
        try:
            schema = self.db_schemas.get(db_id, {})

            pred_normalized = self.spider_normalize_sql(prediction, schema)
            gold_normalized = self.spider_normalize_sql(gold, schema)

            return float(pred_normalized == gold_normalized)

        except Exception as e:
            print(f"Spider精确匹配错误: {e}")
            return 0.0

    def spider_normalize_sql(self, sql: str, schema: Dict) -> str:
        """Lower-case, collapse whitespace and unquote known schema identifiers."""
        normalized = ' '.join(sql.lower().strip().split())

        if schema:
            normalized = self.normalize_identifiers(normalized, schema)

        return normalized

    def normalize_identifiers(self, sql: str, schema: Dict) -> str:
        """Strip "..", `..` or [..] quoting around names that exist in *schema*.

        Reads Spider tables.json fields: ``table_names[_original]`` are lists
        of strings and ``column_names[_original]`` are ``[table_index, name]``
        pairs. *sql* is expected to be lower-cased already.
        """
        names = set()
        for key in ('table_names_original', 'table_names'):
            names.update(name.lower() for name in schema.get(key, [])
                         if isinstance(name, str))
        for key in ('column_names_original', 'column_names'):
            for column in schema.get(key, []):
                if (isinstance(column, (list, tuple)) and len(column) >= 2
                        and isinstance(column[1], str)):
                    names.add(column[1].lower())
        names.discard('*')

        normalized = sql
        for name in sorted(names):
            quoted = re.escape(name)
            pattern = rf'"{quoted}"|`{quoted}`|\[{quoted}\]'
            # A function replacement avoids backslash interpretation in *name*
            normalized = re.sub(pattern, lambda _m, bare=name: bare, normalized)

        return normalized

    def spider_execution_match(self, prediction: str, gold: str, db_id: str) -> float:
        """1.0 if both queries run on *db_id*'s database and return equal rows."""
        try:
            db_path = f"{self.db_dir}/{db_id}/{db_id}.sqlite"

            pred_result = self.execute_sql_safe(prediction, db_path)
            gold_result = self.execute_sql_safe(gold, db_path)

            if pred_result is None or gold_result is None:
                return 0.0

            return self.compare_execution_results(pred_result, gold_result)

        except Exception as e:
            print(f"Spider执行匹配错误: {e}")
            return 0.0

    def execute_sql_safe(self, sql: str, db_path: str) -> Optional[List]:
        """Execute *sql* read-only with a time limit; return rows or None on any failure.

        - ``PRAGMA query_only`` makes write statements fail instead of
          modifying the benchmark database.
        - A progress handler aborts the query once ``self.timeout`` seconds
          have elapsed (SQLite has no ``PRAGMA timeout``).
        - Text is decoded with errors ignored so stray invalid UTF-8 in a
          database does not make an otherwise valid query fail.
        The connection is always closed.
        """
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            conn.text_factory = lambda raw: raw.decode(errors='ignore')
            conn.execute('PRAGMA query_only = ON')

            deadline = time.monotonic() + self.timeout
            conn.set_progress_handler(lambda: int(time.monotonic() > deadline), 1000)

            cursor = conn.cursor()
            cursor.execute(sql)
            return cursor.fetchall()

        except Exception as e:
            print(f"SQL执行错误: {e}")
            return None

        finally:
            if conn is not None:
                conn.close()

    def compare_execution_results(self, result1: List, result2: List) -> float:
        """Compare results as multisets of rows (order ignored, duplicates counted)."""
        try:
            rows1 = Counter(tuple(row) if isinstance(row, (list, tuple)) else (row,)
                            for row in result1)
            rows2 = Counter(tuple(row) if isinstance(row, (list, tuple)) else (row,)
                            for row in result2)
            return float(rows1 == rows2)

        except TypeError:
            # Unhashable cell values: fall back to ordered comparison
            return float(result1 == result2)

    def evaluate_by_difficulty(self, predictions: List[str],
                               gold_sqls: List[str],
                               db_ids: List[str],
                               difficulties: List[str]) -> Dict[str, Dict[str, float]]:
        """Mean exact-match and execution accuracy grouped by difficulty label."""
        difficulty_results = defaultdict(lambda: {'exact_match': [], 'execution': []})

        for pred, gold, db_id, difficulty in zip(predictions, gold_sqls, db_ids, difficulties):
            difficulty_results[difficulty]['exact_match'].append(
                self.spider_exact_match(pred, gold, db_id))
            difficulty_results[difficulty]['execution'].append(
                self.spider_execution_match(pred, gold, db_id))

        return {
            difficulty: {
                'exact_match': np.mean(scores['exact_match']),
                'execution_accuracy': np.mean(scores['execution'])
            }
            for difficulty, scores in difficulty_results.items()
        }

# Usage example: a Spider evaluator with an empty configuration
spider_config = {}
spider_evaluator = SpiderEvaluator(spider_config)
print("Spider评估器初始化完成")

8.2 测试框架设计

8.2.1 单元测试框架

import unittest
from unittest.mock import Mock, patch
import tempfile
import os

class Text2SQLTestCase(unittest.TestCase):
    """Base class for Text2SQL tests.

    Provides a per-test temporary directory, small question/SQL/schema
    fixtures, a mock model, and a helper that builds a SQLite test database.
    """

    def setUp(self):
        """Create the temporary directory and sample fixtures."""
        self.test_config = {
            'model_name': 'test_model',
            'max_length': 512,
            'batch_size': 2
        }

        self.temp_dir = tempfile.mkdtemp()

        self.sample_questions = [
            "Show all students",
            "Count the number of courses",
            "Find students with GPA > 3.5"
        ]

        self.sample_sqls = [
            "SELECT * FROM students",
            "SELECT COUNT(*) FROM courses",
            "SELECT * FROM students WHERE gpa > 3.5"
        ]

        self.sample_schema = {
            'table_names': ['students', 'courses'],
            'column_names': [
                ['students', 'id'],
                ['students', 'name'],
                ['students', 'gpa'],
                ['courses', 'id'],
                ['courses', 'name']
            ]
        }

    def tearDown(self):
        """Remove the temporary directory and everything in it."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def create_mock_model(self):
        """Return a Mock whose predict() always returns the sample SQL list."""
        mock_model = Mock()
        mock_model.predict.return_value = self.sample_sqls
        return mock_model

    def create_test_database(self) -> str:
        """Create ``test.db`` in the temp dir with students/courses tables.

        The connection is closed even if table creation or insertion fails.

        Returns:
            Path of the created SQLite file.
        """
        db_path = os.path.join(self.temp_dir, 'test.db')
        conn = sqlite3.connect(db_path)
        try:
            conn.execute("""
                CREATE TABLE students (
                    id INTEGER PRIMARY KEY,
                    name TEXT,
                    gpa REAL
                )
            """)
            conn.execute("""
                CREATE TABLE courses (
                    id INTEGER PRIMARY KEY,
                    name TEXT
                )
            """)
            conn.executemany("INSERT INTO students VALUES (?, ?, ?)",
                             [(1, 'Alice', 3.8), (2, 'Bob', 3.2)])
            conn.executemany("INSERT INTO courses VALUES (?, ?)",
                             [(1, 'Math'), (2, 'Physics')])
            conn.commit()
        finally:
            conn.close()

        return db_path

class TestText2SQLEvaluator(Text2SQLTestCase):
    """Tests for Text2SQLEvaluator."""

    def setUp(self):
        super().setUp()
        self.evaluator = Text2SQLEvaluator({
            'case_sensitive': False,
            'order_sensitive': False
        })

    def tearDown(self):
        # The evaluator caches SQLite connections; close them before the base
        # class deletes temp_dir (an open handle blocks deletion on some OSes).
        for conn in self.evaluator.db_connections.values():
            conn.close()
        self.evaluator.db_connections.clear()
        super().tearDown()

    def test_exact_match_score(self):
        """Exact match: identical, different, and case-only-different queries."""
        score1 = self.evaluator.exact_match_score(
            "SELECT * FROM students",
            "SELECT * FROM students"
        )
        self.assertEqual(score1, 1.0)

        score2 = self.evaluator.exact_match_score(
            "SELECT * FROM students",
            "SELECT * FROM courses"
        )
        self.assertEqual(score2, 0.0)

        # Different case, same meaning
        score3 = self.evaluator.exact_match_score(
            "select * from students",
            "SELECT * FROM students"
        )
        self.assertEqual(score3, 1.0)

    def test_execution_accuracy_score(self):
        """Execution accuracy on the fixture database."""
        db_path = self.create_test_database()

        score1 = self.evaluator.execution_accuracy_score(
            "SELECT * FROM students",
            "SELECT * FROM students",
            db_path
        )
        self.assertEqual(score1, 1.0)

        score2 = self.evaluator.execution_accuracy_score(
            "SELECT * FROM students",
            "SELECT * FROM courses",
            db_path
        )
        self.assertEqual(score2, 0.0)

    def test_component_match_score(self):
        """Component match when only the WHERE value differs."""
        scores = self.evaluator.component_match_score(
            "SELECT name FROM students WHERE gpa > 3.5",
            "SELECT name FROM students WHERE gpa > 3.0"
        )

        # SELECT and FROM are identical
        self.assertEqual(scores['select'], 1.0)
        self.assertEqual(scores['from'], 1.0)

        # WHERE differs only in its value: partial credit expected
        self.assertGreater(scores['where'], 0.0)
        self.assertLess(scores['where'], 1.0)

    def test_batch_evaluation(self):
        """Batch evaluation reports the expected metrics, all within [0, 1]."""
        predictions = [
            "SELECT * FROM students",
            "SELECT COUNT(*) FROM courses",
            "SELECT * FROM students WHERE gpa > 3.5"
        ]

        ground_truths = [
            "SELECT * FROM students",
            "SELECT COUNT(*) FROM courses",
            "SELECT * FROM students WHERE gpa > 3.0"
        ]

        metrics = self.evaluator.evaluate_batch(predictions, ground_truths)

        self.assertIn('exact_match_accuracy', metrics)
        self.assertIn('select_accuracy', metrics)
        self.assertIn('from_accuracy', metrics)

        for metric_name, value in metrics.items():
            self.assertGreaterEqual(value, 0.0)
            self.assertLessEqual(value, 1.0)

class TestModelComponents(Text2SQLTestCase):
    """Model component tests (placeholders).

    Each test skips explicitly so the runner reports it as skipped instead of
    counting an empty body as a pass.
    """

    def test_tokenizer(self):
        """Tokenizer test."""
        # TODO: exercise the real tokenizer component
        self.skipTest("分词器测试尚未实现")

    def test_encoder(self):
        """Encoder test."""
        # TODO: exercise the real encoder component
        self.skipTest("编码器测试尚未实现")

    def test_decoder(self):
        """Decoder test."""
        # TODO: exercise the real decoder component
        self.skipTest("解码器测试尚未实现")

class TestDataProcessing(Text2SQLTestCase):
    """Data-processing tests."""

    def test_data_loading(self):
        """The JSON data fixture round-trips intact."""
        test_data = {
            'questions': self.sample_questions,
            'sqls': self.sample_sqls
        }

        test_file = os.path.join(self.temp_dir, 'test_data.json')
        with open(test_file, 'w', encoding='utf-8') as f:
            json.dump(test_data, f)

        # TODO: call the real data-loading function once it exists; until then,
        # at least verify the fixture it would read is well-formed.
        with open(test_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
        self.assertEqual(loaded, test_data)
        self.assertEqual(len(loaded['questions']), len(loaded['sqls']))

    def test_preprocessing(self):
        """Text preprocessing test."""
        # TODO: test the text preprocessing logic
        self.skipTest("文本预处理测试尚未实现")

    def test_schema_processing(self):
        """Schema processing test."""
        # TODO: test the schema processing logic
        self.skipTest("schema处理测试尚未实现")

# Test suite assembly
def create_test_suite(test_cases=None):
    """Build a unittest.TestSuite from TestCase classes.

    Args:
        test_cases: iterable of unittest.TestCase subclasses to load. Defaults
            to TestText2SQLEvaluator, TestModelComponents and TestDataProcessing.

    Returns:
        A TestSuite containing every test method of the given classes.
    """
    if test_cases is None:
        test_cases = (TestText2SQLEvaluator, TestModelComponents, TestDataProcessing)

    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # TestLoader.loadTestsFromTestCase is its documented replacement.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    for case in test_cases:
        suite.addTest(loader.loadTestsFromTestCase(case))

    return suite

# Run the suite when this file is executed as a script, then print a summary
if __name__ == '__main__':
    suite = create_test_suite()
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    print("\n测试运行完成:")
    print(f"运行测试数: {result.testsRun}")
    print(f"失败数: {len(result.failures)}")
    print(f"错误数: {len(result.errors)}")

    # Failures first, then errors; each section only if it has entries
    for heading, entries in (("\n失败的测试:", result.failures),
                             ("\n错误的测试:", result.errors)):
        if not entries:
            continue
        print(heading)
        for test, traceback in entries:
            print(f"- {test}: {traceback}")

8.2.2 集成测试框架

class IntegrationTestFramework:
    """Integration-test harness for the Text2SQL pipeline.

    The harness writes a tiny fixture to disk: ``train.json`` and
    ``dev.json`` plus a SQLite ``university.db``. It then drives a mock model
    through the stages load -> train -> infer -> evaluate and collects simple
    performance figures. Finally it renders every recorded run as a Markdown
    report.

    Args:
        config: Optional keys ``test_data_dir``, ``test_model_dir`` and
            ``test_output_dir`` override the default ``./test_data``,
            ``./test_models`` and ``./test_outputs`` directories.
    """

    def __init__(self, config: Dict):
        self.config = config
        # One result dict is appended per run_end_to_end_test()/run_performance_test() call.
        self.test_results = []

        # Test environment directories.
        self.test_env = {
            'data_dir': config.get('test_data_dir', './test_data'),
            'model_dir': config.get('test_model_dir', './test_models'),
            'output_dir': config.get('test_output_dir', './test_outputs')
        }

        # Create every directory up front so later file writes cannot hit a missing path.
        for dir_path in self.test_env.values():
            os.makedirs(dir_path, exist_ok=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _check(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` when *condition* is false.

        Used instead of a bare ``assert`` so the validation still runs under
        ``python -O``.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _is_successful(result: Dict) -> bool:
        """Return whether a recorded result counts as a success in the report.

        End-to-end results carry an explicit ``overall_success`` flag.
        Results without one (performance runs) count as failed when they
        recorded an ``error``.
        """
        return bool(result.get('overall_success', 'error' not in result))

    @staticmethod
    def _rate(count: int, seconds: float) -> float:
        """Return ``count / seconds``, or ``inf`` when the interval is zero."""
        return count / seconds if seconds > 0 else float('inf')

    def _load_split(self, split: str) -> Dict:
        """Load ``<data_dir>/<split>.json`` as written by prepare_test_data()."""
        file_path = os.path.join(self.test_env['data_dir'], f'{split}.json')
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # ------------------------------------------------------------------
    # Environment setup
    # ------------------------------------------------------------------

    def setup_test_environment(self):
        """Prepare the test data, the mock model and the test database."""
        print("设置集成测试环境...")

        self.prepare_test_data()
        self.initialize_test_models()
        self.setup_test_databases()

        print("测试环境设置完成")

    def prepare_test_data(self):
        """Write small ``train`` and ``dev`` splits as JSON into the data directory."""
        # Small-scale dataset: parallel lists of questions, SQL and database ids.
        test_data = {
            'train': {
                'questions': [
                    "Show all students",
                    "Count the number of courses",
                    "Find students with high GPA"
                ],
                'sqls': [
                    "SELECT * FROM students",
                    "SELECT COUNT(*) FROM courses",
                    "SELECT * FROM students WHERE gpa > 3.5"
                ],
                'db_ids': ['university', 'university', 'university']
            },
            'dev': {
                'questions': [
                    "List all course names",
                    "Find the average GPA"
                ],
                'sqls': [
                    "SELECT name FROM courses",
                    "SELECT AVG(gpa) FROM students"
                ],
                'db_ids': ['university', 'university']
            }
        }

        for split, data in test_data.items():
            file_path = os.path.join(self.test_env['data_dir'], f'{split}.json')
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)

    def initialize_test_models(self):
        """Create the model under test (currently a mock)."""
        # TODO: initialise the real model here; a mock is used for demonstration.
        self.test_model = self.create_mock_model()

    def create_mock_model(self):
        """Return a rule-based stand-in with ``train``/``predict``/``evaluate`` methods."""
        class MockText2SQLModel:
            def __init__(self):
                self.is_trained = False

            def train(self, train_data):
                print("模拟训练过程...")
                self.is_trained = True
                return {'loss': 0.5, 'accuracy': 0.8}

            def predict(self, questions, schemas=None):
                # Keyword rules produce placeholder SQL, one string per question.
                predictions = []
                for question in questions:
                    if 'count' in question.lower():
                        predictions.append("SELECT COUNT(*) FROM table")
                    elif 'all' in question.lower():
                        predictions.append("SELECT * FROM table")
                    else:
                        predictions.append("SELECT column FROM table WHERE condition")
                return predictions

            def evaluate(self, test_data):
                # Fixed metrics; test_data is ignored by the mock.
                return {'exact_match': 0.7, 'execution_accuracy': 0.6}

        return MockText2SQLModel()

    def setup_test_databases(self):
        """Create ``university.db`` with ``students`` and ``courses`` tables and seed rows."""
        db_path = os.path.join(self.test_env['data_dir'], 'university.db')
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            cursor.execute("""
            CREATE TABLE IF NOT EXISTS students (
                id INTEGER PRIMARY KEY,
                name TEXT,
                gpa REAL,
                major TEXT
            )
        """)

            cursor.execute("""
            CREATE TABLE IF NOT EXISTS courses (
                id INTEGER PRIMARY KEY,
                name TEXT,
                credits INTEGER
            )
        """)

            students_data = [
                (1, 'Alice', 3.8, 'Computer Science'),
                (2, 'Bob', 3.2, 'Mathematics'),
                (3, 'Charlie', 3.9, 'Physics')
            ]

            courses_data = [
                (1, 'Database Systems', 3),
                (2, 'Machine Learning', 4),
                (3, 'Linear Algebra', 3)
            ]

            # INSERT OR REPLACE keeps re-runs idempotent on the fixed primary keys.
            cursor.executemany("INSERT OR REPLACE INTO students VALUES (?, ?, ?, ?)", students_data)
            cursor.executemany("INSERT OR REPLACE INTO courses VALUES (?, ?, ?)", courses_data)

            conn.commit()
        finally:
            # Close even when a statement fails so the database handle is not leaked.
            conn.close()

    # ------------------------------------------------------------------
    # End-to-end test
    # ------------------------------------------------------------------

    def run_end_to_end_test(self) -> Dict:
        """Run every pipeline stage in order and record the combined result.

        Returns:
            Dict with ``test_name``, ``start_time``/``end_time``/``duration``
            (wall-clock seconds), per-stage results under ``stages``, and
            ``overall_success``. The flag is False if any stage reported
            failure or an unexpected exception escaped (``error`` is then
            also set).
        """
        print("开始端到端测试...")

        test_result = {
            'test_name': 'end_to_end',
            'start_time': time.time(),
            'stages': {},
            'overall_success': True
        }

        try:
            print("阶段1: 数据加载测试")
            test_result['stages']['data_loading'] = self.test_data_loading()

            print("阶段2: 模型训练测试")
            test_result['stages']['training'] = self.test_model_training()

            print("阶段3: 模型推理测试")
            test_result['stages']['inference'] = self.test_model_inference()

            print("阶段4: 评估测试")
            test_result['stages']['evaluation'] = self.test_model_evaluation()

        except Exception as e:
            print(f"端到端测试失败: {e}")
            test_result['overall_success'] = False
            test_result['error'] = str(e)

        # Each stage catches its own exceptions and reports failure only through
        # its 'success' flag, so the overall verdict must honour those flags too.
        if not all(stage['success'] for stage in test_result['stages'].values()):
            test_result['overall_success'] = False

        test_result['end_time'] = time.time()
        test_result['duration'] = test_result['end_time'] - test_result['start_time']

        self.test_results.append(test_result)
        return test_result

    def test_data_loading(self) -> Dict:
        """Load the training split and validate its structure."""
        try:
            train_data = self._load_split('train')

            self._check('questions' in train_data, "训练数据缺少 'questions' 字段")
            self._check('sqls' in train_data, "训练数据缺少 'sqls' 字段")
            self._check(len(train_data['questions']) == len(train_data['sqls']),
                        "questions 与 sqls 数量不一致")

            return {
                'success': True,
                'data_size': len(train_data['questions']),
                'message': '数据加载成功'
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'message': '数据加载失败'
            }

    def test_model_training(self) -> Dict:
        """Train the model on the training split and return its metrics."""
        try:
            train_data = self._load_split('train')
            training_metrics = self.test_model.train(train_data)

            return {
                'success': True,
                'metrics': training_metrics,
                'message': '模型训练成功'
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'message': '模型训练失败'
            }

    def test_model_inference(self) -> Dict:
        """Predict SQL for a few questions and validate the output shape and type."""
        try:
            test_questions = [
                "Show all students",
                "Count the courses"
            ]

            predictions = self.test_model.predict(test_questions)

            self._check(len(predictions) == len(test_questions),
                        "预测数量与问题数量不一致")
            self._check(all(isinstance(pred, str) for pred in predictions),
                        "预测结果必须全部为字符串")

            return {
                'success': True,
                'predictions': predictions,
                'message': '模型推理成功'
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'message': '模型推理失败'
            }

    def test_model_evaluation(self) -> Dict:
        """Evaluate the model on the dev split and check the expected metric keys."""
        try:
            dev_data = self._load_split('dev')
            evaluation_metrics = self.test_model.evaluate(dev_data)

            self._check('exact_match' in evaluation_metrics,
                        "评估结果缺少 'exact_match'")
            self._check('execution_accuracy' in evaluation_metrics,
                        "评估结果缺少 'execution_accuracy'")

            return {
                'success': True,
                'metrics': evaluation_metrics,
                'message': '模型评估成功'
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'message': '模型评估失败'
            }

    # ------------------------------------------------------------------
    # Performance test
    # ------------------------------------------------------------------

    def run_performance_test(self) -> Dict:
        """Measure inference speed, memory growth and concurrent throughput.

        Returns:
            Dict with ``test_name``, timing fields and a ``metrics`` dict.
            ``error`` is set if an unexpected exception aborted the run.
        """
        print("开始性能测试...")

        performance_result = {
            'test_name': 'performance',
            'start_time': time.time(),
            'metrics': {}
        }

        try:
            performance_result['metrics']['inference_speed'] = self.test_inference_speed()
            performance_result['metrics']['memory_usage'] = self.test_memory_usage()
            performance_result['metrics']['concurrent_performance'] = self.test_concurrent_processing()

        except Exception as e:
            performance_result['error'] = str(e)

        performance_result['end_time'] = time.time()
        performance_result['duration'] = performance_result['end_time'] - performance_result['start_time']

        self.test_results.append(performance_result)
        return performance_result

    def test_inference_speed(self) -> Dict:
        """Time one batched ``predict`` call over 100 identical questions."""
        questions = ["Show all students"] * 100

        # perf_counter is monotonic and high-resolution. time.time() can report
        # a zero interval for a fast model, which made the QPS division raise.
        start_time = time.perf_counter()
        self.test_model.predict(questions)
        total_time = time.perf_counter() - start_time

        return {
            'total_questions': len(questions),
            'total_time': total_time,
            'questions_per_second': self._rate(len(questions), total_time)
        }

    def test_memory_usage(self) -> Dict:
        """Measure process RSS (MB) before and after a 50-question batch.

        Requires the optional third-party ``psutil`` package. When it is not
        installed, ``{'error': ...}`` is returned instead of raising, so the
        remaining performance measurements still run.
        """
        try:
            import psutil
        except ImportError:
            return {'error': 'psutil 未安装,跳过内存测试'}

        process = psutil.Process()

        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        questions = ["Show all students"] * 50
        # Bound to a name so the results are still referenced at the second reading.
        predictions = self.test_model.predict(questions)

        final_memory = process.memory_info().rss / 1024 / 1024  # MB

        return {
            'initial_memory_mb': initial_memory,
            'final_memory_mb': final_memory,
            'memory_increase_mb': final_memory - initial_memory
        }

    def test_concurrent_processing(self) -> Dict:
        """Run single-question predictions from several threads sharing one queue.

        Returns:
            Dict with the thread and question counts, wall time
            (``total_time``) and throughput (``concurrent_qps``). It also
            holds the fraction of questions that produced a prediction
            (``success_rate``) and the number of failed predictions
            (``error_count``).
        """
        import threading
        import queue

        def worker(question_queue, result_queue, error_queue):
            while True:
                # The queue is fully populated before any worker starts, so
                # "empty" means "done". Exit at once instead of idling on a
                # get() timeout that would be counted in total_time.
                try:
                    question = question_queue.get_nowait()
                except queue.Empty:
                    break
                try:
                    prediction = self.test_model.predict([question])[0]
                except Exception as exc:
                    # Record the failure (it lowers success_rate) instead of
                    # letting it kill the worker thread.
                    error_queue.put(repr(exc))
                else:
                    result_queue.put(prediction)
                finally:
                    # Always acknowledge the item. Without task_done(),
                    # question_queue.join() below blocks forever.
                    question_queue.task_done()

        questions = ["Show all students", "Count courses"] * 10
        question_queue = queue.Queue()
        result_queue = queue.Queue()
        error_queue = queue.Queue()

        for question in questions:
            question_queue.put(question)

        num_threads = 4
        threads = []

        start_time = time.perf_counter()

        for _ in range(num_threads):
            thread = threading.Thread(target=worker,
                                      args=(question_queue, result_queue, error_queue))
            thread.start()
            threads.append(thread)

        # Wait until every question is acknowledged, then for the workers to exit.
        question_queue.join()
        for thread in threads:
            thread.join()

        total_time = time.perf_counter() - start_time

        # Every worker has exited, so the queue sizes are final here.
        num_success = result_queue.qsize()

        return {
            'num_threads': num_threads,
            'total_questions': len(questions),
            'total_time': total_time,
            'concurrent_qps': self._rate(len(questions), total_time),
            'success_rate': num_success / len(questions),
            'error_count': error_queue.qsize()
        }

    # ------------------------------------------------------------------
    # Reporting and cleanup
    # ------------------------------------------------------------------

    def generate_test_report(self) -> str:
        """Render every recorded result in ``self.test_results`` as Markdown."""
        report = ["# Text2SQL 集成测试报告\n"]

        # Overview.
        report.append("## 测试概览\n")
        report.append(f"总测试数: {len(self.test_results)}\n")

        successful_tests = sum(1 for result in self.test_results
                               if self._is_successful(result))
        report.append(f"成功测试数: {successful_tests}\n")
        report.append(f"失败测试数: {len(self.test_results) - successful_tests}\n\n")

        # Per-test details.
        for i, result in enumerate(self.test_results, 1):
            report.append(f"## 测试 {i}: {result['test_name']}\n")
            report.append(f"开始时间: {time.ctime(result['start_time'])}\n")
            report.append(f"持续时间: {result['duration']:.2f}秒\n")

            if 'stages' in result:
                report.append("### 测试阶段\n")
                for stage_name, stage_result in result['stages'].items():
                    status = "✓" if stage_result['success'] else "✗"
                    report.append(f"- {stage_name}: {status} {stage_result['message']}\n")

            if 'metrics' in result:
                report.append("### 性能指标\n")
                for metric_name, metric_value in result['metrics'].items():
                    report.append(f"- {metric_name}: {metric_value}\n")

            if 'error' in result:
                report.append(f"### 错误信息\n{result['error']}\n")

            report.append("\n")

        return ''.join(report)

    def cleanup_test_environment(self):
        """Close any registered database connections; temp files are kept by default."""
        print("清理测试环境...")

        # This class never sets db_connections itself; getattr keeps cleanup
        # safe when no connections were registered.
        for conn in getattr(self, 'db_connections', {}).values():
            if conn:
                conn.close()

        # Optionally remove temporary outputs:
        # shutil.rmtree(self.test_env['output_dir'], ignore_errors=True)

        print("测试环境清理完成")

# 使用示例
if __name__ == "__main__":
    # Integration-test directories; the framework creates them on construction.
    test_config = {
        'test_data_dir': './integration_test_data',
        'test_model_dir': './integration_test_models',
        'test_output_dir': './integration_test_outputs'
    }

    test_framework = IntegrationTestFramework(test_config)

    try:
        test_framework.setup_test_environment()

        # End-to-end pipeline check.
        e2e_result = test_framework.run_end_to_end_test()
        e2e_status = '成功' if e2e_result['overall_success'] else '失败'
        print(f"端到端测试结果: {e2e_status}")

        # Performance check.
        perf_result = test_framework.run_performance_test()
        print(f"性能测试完成,耗时: {perf_result['duration']:.2f}秒")

        # Render the Markdown report, then persist it in the output directory.
        report_text = test_framework.generate_test_report()
        report_path = os.path.join(test_config['test_output_dir'], 'test_report.md')
        with open(report_path, 'w', encoding='utf-8') as report_file:
            report_file.write(report_text)

        print(f"测试报告已保存到: {report_path}")

    finally:
        # Cleanup runs even when a step above raised.
        test_framework.cleanup_test_environment()