1. 数据预处理概述

数据预处理是Text2SQL系统中至关重要的一步,它直接影响模型的训练效果和最终性能。本章将详细介绍Text2SQL任务中的数据预处理技术和特征工程方法。

1.1 Text2SQL数据特点

数据组成

  • 自然语言查询:用户的问题或需求描述
  • 数据库Schema:表结构、列信息、数据类型等
  • SQL语句:对应的标准SQL查询
  • 执行结果:SQL执行后的结果(可选)

数据挑战

  • 语言多样性:同一查询意图可能有多种表达方式
  • Schema复杂性:数据库可能包含大量表和列,以及复杂的主外键关系
  • SQL复杂性:从简单的单表查询,到多表连接、嵌套子查询和聚合分组
  • 数据不平衡:不同类型查询的分布不均

1.2 预处理流程

import re
import json
import pandas as pd
from typing import List, Dict, Tuple
from collections import Counter
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML

class Text2SQLPreprocessor:
    """Text2SQL data preprocessor.

    Drives the per-sample pipeline (question, schema, SQL), then builds the
    vocabularies and converts everything to a numerical representation.

    NOTE(review): ``preprocess_question``, ``preprocess_schema``,
    ``preprocess_sql``, ``build_vocabularies`` and ``convert_to_numerical``
    are not defined on this class; they must be supplied (for example by a
    subclass) before ``preprocess_dataset`` can run.
    """

    def __init__(self):
        # Vocabularies; empty until populated elsewhere.
        self.vocab = {}
        self.schema_vocab = {}
        self.sql_vocab = {}
        # Reserved special tokens with fixed ids.
        self.special_tokens = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<SOS>': 2,
            '<EOS>': 3,
            '<TABLE>': 4,
            '<COLUMN>': 5,
            '<VALUE>': 6
        }

    def preprocess_dataset(self, dataset_path: str) -> Dict:
        """Preprocess every sample of a dataset file and return its numerical form.

        Each raw sample must provide the keys ``'question'``, ``'db_schema'``
        and ``'sql'``.

        Args:
            dataset_path: Path to a ``.json`` or ``.jsonl`` file (see
                ``load_dataset``).

        Returns:
            Whatever ``convert_to_numerical`` produces for the processed data.
        """
        raw_data = self.load_dataset(dataset_path)

        processed_data = {
            'questions': [],
            'schemas': [],
            'sqls': [],
            'metadata': {}
        }

        for sample in raw_data:
            processed_data['questions'].append(self.preprocess_question(sample['question']))
            processed_data['schemas'].append(self.preprocess_schema(sample['db_schema']))
            processed_data['sqls'].append(self.preprocess_sql(sample['sql']))

        self.build_vocabularies(processed_data)
        return self.convert_to_numerical(processed_data)

    def load_dataset(self, dataset_path: str) -> List[Dict]:
        """Load raw samples from a ``.json`` or ``.jsonl`` file.

        Args:
            dataset_path: Path to the dataset. The extension is matched
                case-insensitively. ``.json`` files hold a single JSON
                document; ``.jsonl`` files hold one JSON value per line, and
                blank lines are skipped.

        Returns:
            The decoded samples.

        Raises:
            ValueError: If the extension is neither ``.json`` nor ``.jsonl``.
                This is checked before the file is opened, so an unsupported
                path never surfaces as a file-system error instead.
        """
        lowered = dataset_path.lower()
        if lowered.endswith('.json'):
            with open(dataset_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if lowered.endswith('.jsonl'):
            with open(dataset_path, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a separator or doubled trailing newline)
                # would make json.loads raise; skip them.
                return [json.loads(line) for line in f if line.strip()]

        raise ValueError("Unsupported file format")

# 使用示例
preprocessor = Text2SQLPreprocessor()
print("Text2SQL预处理器初始化完成")

2. 自然语言查询预处理

2.1 文本清洗

import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer

class QuestionPreprocessor:
    """Preprocessor for natural-language questions.

    Cleans and normalizes the text, tokenizes it with NLTK, optionally removes
    stop words / stems / lemmatizes, and extracts simple entities.
    """

    # Placeholder that normalize_numbers substitutes for every integer literal.
    _NUM_TOKEN = '<NUM>'

    # Number words rewritten to digits before integer literals are masked.
    _NUMBER_WORDS = {
        'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
        'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
        'ten': '10', 'eleven': '11', 'twelve': '12', 'thirteen': '13',
        'fourteen': '14', 'fifteen': '15', 'sixteen': '16', 'seventeen': '17',
        'eighteen': '18', 'nineteen': '19', 'twenty': '20', 'thirty': '30',
        'forty': '40', 'fifty': '50', 'sixty': '60', 'seventy': '70',
        'eighty': '80', 'ninety': '90', 'hundred': '100', 'thousand': '1000'
    }

    # One compiled alternation for all number words. The \b anchors keep e.g.
    # 'seven' from matching inside 'seventeen' and 'one' inside 'someone'.
    _NUMBER_WORD_RE = re.compile(r'\b(' + '|'.join(_NUMBER_WORDS) + r')\b')
    _INTEGER_RE = re.compile(r'\b\d+\b')

    # Lower-case SQL keywords reported under entities['keywords'].
    _SQL_KEYWORDS = frozenset({
        'select', 'from', 'where', 'group', 'order', 'by', 'having',
        'count', 'sum', 'avg', 'max', 'min', 'distinct', 'limit',
        'and', 'or', 'not', 'in', 'like', 'between', 'is', 'null'
    })

    def __init__(self, language: str = 'english'):
        """Create NLTK helpers and load the stop-word list for ``language``.

        If the stopwords corpus is missing, the NLTK ``stopwords``, ``punkt``
        and ``wordnet`` packages are downloaded and loading is retried.
        """
        self.language = language
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()

        try:
            self.stop_words = set(stopwords.words(language))
        except LookupError:
            # NLTK raises LookupError when a data package is not installed.
            nltk.download('stopwords')
            nltk.download('punkt')
            nltk.download('wordnet')
            self.stop_words = set(stopwords.words(language))

    def clean_text(self, text: str) -> str:
        """Lower-case, collapse whitespace, drop unusual symbols and mask numbers."""
        text = text.lower()
        text = re.sub(r'\s+', ' ', text)
        # Keep word characters, whitespace and a few meaningful symbols
        # (quotes, sentence punctuation and comparison signs).
        text = re.sub(r'[^\w\s\-\'".,?!><=]', '', text)
        text = self.normalize_numbers(text)
        return text.strip()

    def normalize_numbers(self, text: str) -> str:
        """Rewrite lower-case number words as digits, then mask integers as '<NUM>'.

        Matching is case-sensitive, so callers should lower-case first (as
        ``clean_text`` does). Each word is converted independently, e.g.
        "fifty thousand" becomes "50 1000" before masking.
        """
        text = self._NUMBER_WORD_RE.sub(lambda m: self._NUMBER_WORDS[m.group(1)], text)
        return self._INTEGER_RE.sub(self._NUM_TOKEN, text)

    def tokenize(self, text: str) -> List[str]:
        """Tokenize with NLTK's ``word_tokenize``, keeping '<NUM>' as one token.

        The NLTK tokenizer treats '<' and '>' as separate punctuation, so the
        placeholder would otherwise come back as '<', 'NUM', '>'.
        """
        return self._merge_placeholders(word_tokenize(text))

    @staticmethod
    def _merge_placeholders(tokens: List[str]) -> List[str]:
        """Re-join every consecutive '<', 'NUM', '>' triple into a single '<NUM>'."""
        merged = []
        i = 0
        while i < len(tokens):
            if tokens[i:i + 3] == ['<', 'NUM', '>']:
                merged.append('<NUM>')
                i += 3
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def remove_stopwords(self, tokens: List[str]) -> List[str]:
        """Drop tokens whose lower-cased form is a stop word."""
        return [token for token in tokens if token.lower() not in self.stop_words]

    def stem_tokens(self, tokens: List[str]) -> List[str]:
        """Porter-stem every token except the '<NUM>' placeholder."""
        return [token if token == self._NUM_TOKEN else self.stemmer.stem(token)
                for token in tokens]

    def lemmatize_tokens(self, tokens: List[str]) -> List[str]:
        """WordNet-lemmatize every token except the '<NUM>' placeholder."""
        return [token if token == self._NUM_TOKEN else self.lemmatizer.lemmatize(token)
                for token in tokens]

    def preprocess_question(self, question: str,
                            remove_stopwords: bool = False,
                            use_stemming: bool = False,
                            use_lemmatization: bool = True) -> Dict:
        """Run the full question pipeline.

        Args:
            question: Raw question text.
            remove_stopwords: Drop stop words after tokenizing.
            use_stemming: Stem tokens; when True, ``use_lemmatization`` is ignored.
            use_lemmatization: Lemmatize tokens (only if not stemming).

        Returns:
            Dict with 'original', 'cleaned', 'tokens', 'entities' and 'length'
            (the number of final tokens).
        """
        original = question
        cleaned = self.clean_text(question)
        tokens = self.tokenize(cleaned)

        if remove_stopwords:
            tokens = self.remove_stopwords(tokens)

        if use_stemming:
            tokens = self.stem_tokens(tokens)
        elif use_lemmatization:
            tokens = self.lemmatize_tokens(tokens)

        entities = self.extract_entities(tokens)

        return {
            'original': original,
            'cleaned': cleaned,
            'tokens': tokens,
            'entities': entities,
            'length': len(tokens)
        }

    def extract_entities(self, tokens: List[str]) -> Dict:
        """Bucket tokens into 'numbers', 'dates', 'names' and 'keywords'.

        Each token lands in at most one bucket, checked in that order.
        """
        entities = {
            'numbers': [],
            'dates': [],
            'names': [],
            'keywords': []
        }

        for token in tokens:
            if token == self._NUM_TOKEN or re.match(r'^\d+$', token):
                entities['numbers'].append(token)
            # NOTE: clean_text masks digit runs as '<NUM>', so ISO dates only
            # reach this branch when tokens bypass clean_text.
            elif re.match(r'^\d{4}-\d{2}-\d{2}$', token):
                entities['dates'].append(token)
            elif token.lower() in self._SQL_KEYWORDS:
                entities['keywords'].append(token.lower())
            # Length check first so an empty token cannot raise IndexError.
            # NOTE: preprocess_question lower-cases the text before tokenizing,
            # so this branch only fires for tokens that kept their case.
            elif len(token) > 1 and token[0].isupper():
                entities['names'].append(token)

        return entities

# Usage example: run the question pipeline on a few sample questions.
question_processor = QuestionPreprocessor()

# Sample questions
questions = [
    "Show me all employees with salary greater than fifty thousand",
    "What is the average age of students in Computer Science department?",
    "List the top 10 products by sales in 2023"
]

# (label, result key) pairs printed for every question, in display order.
report_fields = (
    ("原始", "original"),
    ("清洗", "cleaned"),
    ("分词", "tokens"),
    ("实体", "entities"),
)

for question in questions:
    processed = question_processor.preprocess_question(question)
    for label, key in report_fields:
        print(f"{label}: {processed[key]}")
    print("-" * 50)

2.2 实体链接与识别

class EntityLinker:
    """Link question text to schema tables/columns and spot operators and values."""

    def __init__(self, schema_info: Dict):
        """Index table and column names from ``schema_info``.

        ``schema_info`` is expected to have ``'tables'`` (list of names) and
        ``'columns'`` (table name -> list of column names).
        """
        self.schema_info = schema_info
        self.table_names = set(schema_info.get('tables', []))
        self.column_names = set()
        self.value_patterns = {}

        for table, columns in schema_info.get('columns', {}).items():
            self.column_names.update(columns)

        self.build_value_patterns()

    def build_value_patterns(self):
        """Register regexes for common literal values (applied case-insensitively)."""
        self.value_patterns = {
            'year': r'\b(19|20)\d{2}\b',
            'month': r'\b(january|february|march|april|may|june|july|august|september|october|november|december|jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)\b',
            'number': r'\b\d+(\.\d+)?\b',
            # No trailing \b: '%' is not a word character, so '%\b' would only
            # match when a letter or digit directly followed the sign.
            'percentage': r'\b\d+(\.\d+)?%',
            'currency': r'\$\d+(\.\d+)?\b'
        }

    def link_entities(self, tokens: List[str], schema: Dict) -> Dict:
        """Link the joined, lower-cased ``tokens`` against ``schema``.

        Returns a dict with 'tables', 'columns' (each entry carrying a
        'confidence'), 'operators' and 'values'.
        """
        linked_entities = {
            'tables': [],
            'columns': [],
            'values': [],
            'operators': []
        }

        text = ' '.join(tokens).lower()

        for table in schema.get('tables', []):
            if self.fuzzy_match(table.lower(), text):
                linked_entities['tables'].append({
                    'name': table,
                    'confidence': self.calculate_confidence(table, text)
                })

        for table, columns in schema.get('columns', {}).items():
            for column in columns:
                if self.fuzzy_match(column.lower(), text):
                    linked_entities['columns'].append({
                        'name': column,
                        'table': table,
                        'confidence': self.calculate_confidence(column, text)
                    })

        linked_entities['operators'] = self.extract_operators(text)
        linked_entities['values'] = self.extract_values(text)

        return linked_entities

    @staticmethod
    def _occurs_at_word_start(entity: str, text: str) -> bool:
        """True if ``entity`` occurs in ``text`` starting at a word boundary.

        'employee' matches 'employees', but 'id' does not match 'width'.
        """
        return re.search(r'\b' + re.escape(entity), text) is not None

    @staticmethod
    def _best_word_ratio(entity: str, text: str) -> float:
        """Highest difflib similarity between ``entity`` and any word of ``text``."""
        from difflib import SequenceMatcher
        return max(
            (SequenceMatcher(None, entity, word).ratio() for word in text.split()),
            default=0.0,
        )

    def fuzzy_match(self, entity: str, text: str, threshold: float = 0.8) -> bool:
        """Return True if ``entity`` plausibly appears in ``text``.

        A match is any of:
        1. ``entity`` occurs at the start of a word in ``text``;
        2. one of its '_'-separated parts longer than two characters does; or
        3. some whitespace-separated word of ``text`` has a difflib similarity
           ratio of at least ``threshold`` with ``entity``.
        """
        if self._occurs_at_word_start(entity, text):
            return True

        for part in entity.split('_'):
            if len(part) > 2 and self._occurs_at_word_start(part, text):
                return True

        return self._best_word_ratio(entity, text) >= threshold

    def calculate_confidence(self, entity: str, text: str) -> float:
        """Score in [0, 1] for how well ``entity`` is supported by ``text``.

        The comparison is case-insensitive; ``link_entities`` passes
        original-case schema names with lower-cased text. The score is 1.0
        when the entity occurs at the start of a word. Otherwise it is the
        best difflib ratio against a single word of ``text``, or 0.0 when
        ``text`` has no words.
        """
        entity = entity.lower()
        text = text.lower()
        if self._occurs_at_word_start(entity, text):
            return 1.0
        return self._best_word_ratio(entity, text)

    def extract_operators(self, text: str) -> List[Dict]:
        """Find comparison/filter operators, as words or symbols.

        Returns dicts with 'type', 'text' and 'position' (span in ``text``).
        """
        operators = []

        # Word operators keep their \b anchors. Symbolic operators cannot use
        # \b (it needs an adjacent word character, so e.g. ' > ' never
        # matched); lookarounds keep them from matching inside '>=', '<=',
        # '!=', '<>' or '=='.
        operator_patterns = {
            'greater_than': r'\b(?:greater than|more than|above|over)\b|(?<![<>!=])>(?!=)',
            'less_than': r'\b(?:less than|below|under)\b|(?<![<>!=])<(?![=>])',
            'equal': r'\b(?:equal to|equals|is)\b|(?<![<>!=])=(?!=)',
            'not_equal': r'\b(?:not equal|not)\b|!=|<>',
            'between': r'\bbetween\b',
            'in': r'\bin\b',
            'like': r'\b(?:like|contains|includes)\b',
            'greater_equal': r'>=',
            'less_equal': r'<=',
        }

        # Blank out placeholders such as '<num>' with same-length spaces (so
        # spans stay valid); their angle brackets are not comparisons.
        scan_text = re.sub(r'<\w+>', lambda m: ' ' * len(m.group()), text)

        for op_type, pattern in operator_patterns.items():
            for match in re.finditer(pattern, scan_text, re.IGNORECASE):
                operators.append({
                    'type': op_type,
                    'text': match.group(),
                    'position': match.span()
                })

        return operators

    def extract_values(self, text: str) -> List[Dict]:
        """Find literal values using ``self.value_patterns``.

        Returns dicts with 'type', 'value' and 'position'; a span may be
        reported under several types (e.g. a year is also a number).
        """
        values = []

        for value_type, pattern in self.value_patterns.items():
            for match in re.finditer(pattern, text, re.IGNORECASE):
                values.append({
                    'type': value_type,
                    'value': match.group(),
                    'position': match.span()
                })

        return values

# Usage example: a small three-table schema
schema_example = {
    'tables': ['employees', 'departments', 'projects'],
    'columns': {
        'employees': ['id', 'name', 'salary', 'age', 'department_id'],
        'departments': ['id', 'name', 'budget'],
        'projects': ['id', 'title', 'budget', 'start_date'],
    },
}

entity_linker = EntityLinker(schema_example)

# Link a whitespace-tokenized test question against the schema
test_question = "show employees with salary greater than 50000"
tokens = test_question.split()
linked = entity_linker.link_entities(tokens, schema_example)

print("实体链接结果:")
for entity_type, entities in linked.items():
    if not entities:
        continue
    print(f"{entity_type}: {entities}")

3. Schema预处理

3.1 Schema标准化

class SchemaPreprocessor:
    """Normalize raw database schemas and derive graph and feature views of them."""

    def __init__(self):
        # Maps lower-cased SQL base type names to coarse categories; anything
        # not listed falls back to 'text' in normalize_type.
        self.type_mapping = {
            'varchar': 'text',
            'nvarchar': 'text',
            'char': 'text',
            'nchar': 'text',
            'text': 'text',
            'int': 'number',
            'integer': 'number',
            'tinyint': 'number',
            'smallint': 'number',
            'mediumint': 'number',
            'bigint': 'number',
            'float': 'number',
            'real': 'number',
            'double': 'number',
            'decimal': 'number',
            'numeric': 'number',
            'date': 'date',
            'datetime': 'datetime',
            'timestamp': 'datetime',
            'boolean': 'boolean',
            'bool': 'boolean'
        }

    def normalize_schema(self, raw_schema: Dict) -> Dict:
        """Convert a raw schema into a lower-cased, typed, relationship-aware form.

        ``raw_schema['tables']`` is a list of ``{'name', 'columns'}`` dicts.
        Each column may carry 'type', 'nullable', 'default', 'description',
        'primary_key' and 'foreign_key' (``{'table', 'column'}``).

        Returns:
            Dict with 'tables' (name -> columns / primary_keys / foreign_keys),
            'relationships' (from build_relationships), and empty
            'constraints' / 'metadata'.
        """
        normalized = {
            'tables': {},
            'relationships': [],
            'constraints': {},
            'metadata': {}
        }

        for table_info in raw_schema.get('tables', []):
            table_name = table_info['name'].lower()

            normalized['tables'][table_name] = {
                'columns': {},
                'primary_keys': [],
                'foreign_keys': []
            }

            for column_info in table_info.get('columns', []):
                column_name = column_info['name'].lower()
                column_type = self.normalize_type(column_info.get('type', 'text'))

                normalized['tables'][table_name]['columns'][column_name] = {
                    'type': column_type,
                    'nullable': column_info.get('nullable', True),
                    'default': column_info.get('default'),
                    'description': column_info.get('description', '')
                }

                if column_info.get('primary_key', False):
                    normalized['tables'][table_name]['primary_keys'].append(column_name)

                if 'foreign_key' in column_info:
                    fk_info = column_info['foreign_key']
                    normalized['tables'][table_name]['foreign_keys'].append({
                        'column': column_name,
                        'references_table': fk_info['table'].lower(),
                        'references_column': fk_info['column'].lower()
                    })

        normalized['relationships'] = self.build_relationships(normalized['tables'])

        return normalized

    def normalize_type(self, data_type: str) -> str:
        """Map a raw SQL type to 'text', 'number', 'date', 'datetime' or 'boolean'.

        Size/precision arguments such as ``VARCHAR(100)`` or ``DECIMAL(10,2)``
        are ignored. If the full name is unknown, its first word is tried, so
        ``INT UNSIGNED`` and ``DOUBLE PRECISION`` resolve. Anything still
        unknown, and an empty or None type, maps to 'text'.
        """
        if not data_type:
            return 'text'

        data_type = data_type.lower().strip()

        # Strip every parenthesised argument list, including multi-argument
        # forms such as DECIMAL(10,2).
        data_type = re.sub(r'\s*\([^)]*\)', '', data_type).strip()

        if data_type in self.type_mapping:
            return self.type_mapping[data_type]

        words = data_type.split()
        return self.type_mapping.get(words[0], 'text') if words else 'text'

    def build_relationships(self, tables: Dict) -> List[Dict]:
        """List one 'foreign_key' relationship per foreign key of every table."""
        relationships = []

        for table_name, table_info in tables.items():
            for fk in table_info['foreign_keys']:
                relationships.append({
                    'from_table': table_name,
                    'from_column': fk['column'],
                    'to_table': fk['references_table'],
                    'to_column': fk['references_column'],
                    'type': 'foreign_key'
                })

        return relationships

    def generate_schema_graph(self, normalized_schema: Dict) -> Dict:
        """Build a node/edge graph of tables and columns.

        Table nodes come first, then column nodes. Edges are 'has_column'
        (table -> column) and 'foreign_key' (column -> referenced column).
        Foreign keys whose target column is not in the schema are skipped.

        Returns:
            Dict with 'nodes', 'edges' and 'node_mapping' (string keys
            ``table_<t>`` / ``column_<t>_<c>`` -> node id).
        """
        nodes = []
        edges = []

        node_id = 0
        node_mapping = {}
        # (table, column) -> node id. Used to resolve foreign keys, because the
        # string keys above can collide (table 'a_b'.'c' vs table 'a'.'b_c').
        column_ids = {}

        for table_name in normalized_schema['tables']:
            nodes.append({
                'id': node_id,
                'type': 'table',
                'name': table_name,
                'label': table_name
            })
            node_mapping[f"table_{table_name}"] = node_id
            node_id += 1

        for table_name, table_info in normalized_schema['tables'].items():
            table_node_id = node_mapping[f"table_{table_name}"]

            for column_name, column_info in table_info['columns'].items():
                nodes.append({
                    'id': node_id,
                    'type': 'column',
                    'name': column_name,
                    'table': table_name,
                    'data_type': column_info['type'],
                    'label': f"{table_name}.{column_name}"
                })

                edges.append({
                    'from': table_node_id,
                    'to': node_id,
                    'type': 'has_column'
                })

                node_mapping[f"column_{table_name}_{column_name}"] = node_id
                column_ids[(table_name, column_name)] = node_id
                node_id += 1

        for relationship in normalized_schema['relationships']:
            from_node_id = column_ids.get(
                (relationship['from_table'], relationship['from_column']))
            to_node_id = column_ids.get(
                (relationship['to_table'], relationship['to_column']))

            if from_node_id is not None and to_node_id is not None:
                edges.append({
                    'from': from_node_id,
                    'to': to_node_id,
                    'type': 'foreign_key'
                })

        return {
            'nodes': nodes,
            'edges': edges,
            'node_mapping': node_mapping
        }

    def extract_schema_features(self, normalized_schema: Dict) -> Dict:
        """Summarize schema size and shape.

        The complexity score weights tables by 1.0, columns by 0.5 and
        relationships by 2.0.
        """
        features = {
            'num_tables': len(normalized_schema['tables']),
            'num_columns': 0,
            'num_relationships': len(normalized_schema['relationships']),
            'column_types': Counter(),
            'table_sizes': {},
            'complexity_score': 0
        }

        for table_name, table_info in normalized_schema['tables'].items():
            num_columns = len(table_info['columns'])
            features['num_columns'] += num_columns
            features['table_sizes'][table_name] = num_columns

            for column_info in table_info['columns'].values():
                features['column_types'][column_info['type']] += 1

        features['complexity_score'] = (
            features['num_tables'] * 1.0 +
            features['num_columns'] * 0.5 +
            features['num_relationships'] * 2.0
        )

        return features

# Usage example
schema_processor = SchemaPreprocessor()

# Sample raw schema: Employees.DepartmentID references Departments.ID
raw_schema = {
    'tables': [
        {
            'name': 'Employees',
            'columns': [
                {'name': 'ID', 'type': 'INT', 'primary_key': True},
                {'name': 'Name', 'type': 'VARCHAR(100)'},
                {'name': 'Salary', 'type': 'DECIMAL(10,2)'},
                {'name': 'DepartmentID', 'type': 'INT',
                 'foreign_key': {'table': 'Departments', 'column': 'ID'}},
            ],
        },
        {
            'name': 'Departments',
            'columns': [
                {'name': 'ID', 'type': 'INT', 'primary_key': True},
                {'name': 'Name', 'type': 'VARCHAR(50)'},
                {'name': 'Budget', 'type': 'DECIMAL(15,2)'},
            ],
        },
    ]
}

# Normalize, then dump the result as indented JSON under a heading
normalized = schema_processor.normalize_schema(raw_schema)
print("标准化Schema:", json.dumps(normalized, indent=2), sep="\n")

# Build the graph view and report its size
schema_graph = schema_processor.generate_schema_graph(normalized)
print(
    f"\nSchema图节点数: {len(schema_graph['nodes'])}",
    f"Schema图边数: {len(schema_graph['edges'])}",
    sep="\n",
)

# Summary features
features = schema_processor.extract_schema_features(normalized)
print(f"\nSchema特征: {features}")

3.2 Schema编码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

class SchemaEncoder:
    """Encode a normalized schema into token-id sequences and a pooled embedding.

    The input is a dict with 'tables' (name -> 'columns' / 'primary_keys' /
    'foreign_keys') and 'relationships', as produced by
    ``SchemaPreprocessor.normalize_schema``.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 128):
        # NOTE(review): vocab_size is stored but never caps the vocabulary;
        # add_token grows it without limit.
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.vocab = {'<PAD>': 0, '<UNK>': 1}
        self.reverse_vocab = {0: '<PAD>', 1: '<UNK>'}
        self.vocab_counter = 2
        # Embedding table shared by all create_schema_embeddings calls. It is
        # created lazily and grown when the vocabulary grows.
        self._embedding = None

        # Marker tokens used inside the encoded sequences.
        self.special_tokens = {
            '<TABLE>': self.add_token('<TABLE>'),
            '<COLUMN>': self.add_token('<COLUMN>'),
            '<TYPE>': self.add_token('<TYPE>'),
            '<PK>': self.add_token('<PK>'),
            '<FK>': self.add_token('<FK>')
        }

    def add_token(self, token: str) -> int:
        """Return the id of ``token``, assigning the next free id if it is new."""
        if token not in self.vocab:
            self.vocab[token] = self.vocab_counter
            self.reverse_vocab[self.vocab_counter] = token
            self.vocab_counter += 1
        return self.vocab[token]

    def encode_schema(self, normalized_schema: Dict) -> Dict:
        """Encode tables and columns as token-id sequences.

        - Table sequence: ``<TABLE>`` followed by the '_'-split table name.
        - Column sequence: ``<COLUMN>``, the table name, the column name,
          ``<TYPE>`` and the type, then ``<PK>`` and/or ``<FK>`` plus the
          referenced table/column.
        - ``relationship_matrix``: a symmetric (num_tables x num_tables) 0/1
          tensor. Relationships that reference a table outside this schema
          are ignored.
        """
        encoded_schema = {
            'table_sequences': [],
            'column_sequences': [],
            'type_sequences': [],
            'relationship_matrix': None,
            'metadata': {}
        }

        table_names = list(normalized_schema['tables'].keys())
        table_to_id = {name: i for i, name in enumerate(table_names)}

        for table_name in table_names:
            table_seq = [self.special_tokens['<TABLE>']]
            for token in table_name.split('_'):
                table_seq.append(self.add_token(token))
            encoded_schema['table_sequences'].append(table_seq)

        all_columns = []
        for table_name, table_info in normalized_schema['tables'].items():
            for column_name, column_info in table_info['columns'].items():
                column_seq = [self.special_tokens['<COLUMN>']]

                for token in table_name.split('_'):
                    column_seq.append(self.add_token(token))

                for token in column_name.split('_'):
                    column_seq.append(self.add_token(token))

                column_seq.append(self.special_tokens['<TYPE>'])
                column_seq.append(self.add_token(column_info['type']))

                if column_name in table_info['primary_keys']:
                    column_seq.append(self.special_tokens['<PK>'])

                for fk in table_info['foreign_keys']:
                    if fk['column'] == column_name:
                        column_seq.append(self.special_tokens['<FK>'])
                        for token in fk['references_table'].split('_'):
                            column_seq.append(self.add_token(token))
                        for token in fk['references_column'].split('_'):
                            column_seq.append(self.add_token(token))

                encoded_schema['column_sequences'].append(column_seq)
                all_columns.append((table_name, column_name))

        num_tables = len(table_names)
        relationship_matrix = torch.zeros(num_tables, num_tables)

        for relationship in normalized_schema['relationships']:
            from_table_id = table_to_id.get(relationship['from_table'])
            to_table_id = table_to_id.get(relationship['to_table'])
            # A foreign key may point at a table that is not part of this
            # schema; skip it instead of raising KeyError.
            if from_table_id is None or to_table_id is None:
                continue
            relationship_matrix[from_table_id][to_table_id] = 1
            relationship_matrix[to_table_id][from_table_id] = 1  # symmetric

        encoded_schema['relationship_matrix'] = relationship_matrix
        encoded_schema['metadata'] = {
            'table_names': table_names,
            'table_to_id': table_to_id,
            'all_columns': all_columns,
            'vocab_size': len(self.vocab)
        }

        return encoded_schema

    def _get_embedding_layer(self) -> nn.Embedding:
        """Return the shared embedding table, sized to the current vocabulary."""
        needed = len(self.vocab)
        layer = getattr(self, '_embedding', None)
        if layer is None:
            layer = nn.Embedding(needed, self.embedding_dim)
        elif layer.num_embeddings < needed:
            # The vocabulary grew since the table was built: keep the existing
            # rows and append freshly initialised ones for the new tokens.
            grown = nn.Embedding(needed, self.embedding_dim)
            with torch.no_grad():
                grown.weight[:layer.num_embeddings].copy_(layer.weight)
            layer = grown
        self._embedding = layer
        return layer

    def create_schema_embeddings(self, encoded_schema: Dict) -> torch.Tensor:
        """Mean-pool token embeddings into one schema vector.

        Each table/column sequence is averaged over its tokens. The mean table
        vector and mean column vector are then concatenated, giving a 1-D
        tensor of size ``2 * embedding_dim``. A zero vector stands in when
        there are no tables or no columns. The same embedding table is reused
        across calls, so repeated calls on the same input give the same result.
        """
        embedding_layer = self._get_embedding_layer()

        table_embeddings = [
            embedding_layer(torch.tensor(table_seq)).mean(dim=0)
            for table_seq in encoded_schema['table_sequences']
        ]
        column_embeddings = [
            embedding_layer(torch.tensor(column_seq)).mean(dim=0)
            for column_seq in encoded_schema['column_sequences']
        ]

        if table_embeddings:
            table_emb_tensor = torch.stack(table_embeddings)
        else:
            table_emb_tensor = torch.zeros(1, self.embedding_dim)

        if column_embeddings:
            column_emb_tensor = torch.stack(column_embeddings)
        else:
            column_emb_tensor = torch.zeros(1, self.embedding_dim)

        return torch.cat([
            table_emb_tensor.mean(dim=0),
            column_emb_tensor.mean(dim=0)
        ])

# Usage example
schema_encoder = SchemaEncoder(vocab_size=10000)

# Encode the schema normalized in the previous example
encoded = schema_encoder.encode_schema(normalized)
table_seq_count = len(encoded['table_sequences'])
column_seq_count = len(encoded['column_sequences'])
relationship_shape = encoded['relationship_matrix'].shape
print(f"编码后的表序列数量: {table_seq_count}")
print(f"编码后的列序列数量: {column_seq_count}")
print(f"关系矩阵形状: {relationship_shape}")

# Pool the encoded sequences into a single schema embedding
schema_emb = schema_encoder.create_schema_embeddings(encoded)
print(f"Schema嵌入维度: {schema_emb.shape}")

4. SQL预处理

4.1 SQL解析与标准化

import sqlparse
from sqlparse.sql import Statement, IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Name

class SQLPreprocessor:
    """SQL预处理器"""
    
    def __init__(self):
        self.sql_keywords = {
            'SELECT', 'FROM', 'WHERE', 'GROUP', 'BY', 'ORDER', 'HAVING',
            'LIMIT', 'OFFSET', 'UNION', 'INTERSECT', 'EXCEPT', 'WITH',
            'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER',
            'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE',
            'IS', 'NULL', 'DISTINCT', 'ALL', 'ANY', 'SOME'
        }
        
        self.aggregation_functions = {
            'COUNT', 'SUM', 'AVG', 'MAX', 'MIN', 'STDDEV', 'VARIANCE'
        }
        
        self.operators = {
            '=', '!=', '<>', '<', '>', '<=', '>=', '+', '-', '*', '/', '%'
        }
    
    def normalize_sql(self, sql: str) -> str:
        """Return a canonical single-line form of ``sql``.

        The output has:
        - keywords upper-cased and identifiers lower-cased;
        - comments stripped;
        - spaces around operators;
        - every whitespace run collapsed to one space.

        Only the first statement is kept when ``sql`` contains several. If
        parsing fails (for example, an empty string yields no statement to
        index), ``sql`` is returned unchanged.
        """
        try:
            parsed = sqlparse.parse(sql)[0]
        except Exception:
            # Fall back to the raw input on any parser failure. Unlike a bare
            # except, this lets KeyboardInterrupt/SystemExit propagate.
            return sql

        formatted = sqlparse.format(
            str(parsed),
            reindent=True,
            keyword_case='upper',
            identifier_case='lower',
            strip_comments=True,
            use_space_around_operators=True
        )

        # Flatten the reindented output back to a single line.
        formatted = re.sub(r'\s+', ' ', formatted)

        return formatted.strip()
    
    def parse_sql_structure(self, sql: str) -> Dict:
        """Split a SELECT statement into its top-level clauses.

        Each clause's tokens are post-processed by ``process_clause_tokens``.
        Only keywords at parenthesis depth 0 start a new clause. Subqueries
        therefore stay inside the clause that contains them, and keywords
        such as AND/OR/AS/DESC/JOIN/ON stay part of the current clause.

        Returns:
            Dict with keys 'type', 'select', 'from', 'where', 'group_by',
            'having', 'order_by', 'limit', 'joins' and 'subqueries'. Clauses
            absent from the query keep their empty defaults, and 'joins' /
            'subqueries' are not populated here. On a parse error the error
            is printed and the structure filled so far is returned.
        """
        structure = {
            'type': 'SELECT',  # default
            'select': [],
            'from': [],
            'where': [],
            'group_by': [],
            'having': [],
            'order_by': [],
            'limit': None,
            'joins': [],
            'subqueries': []
        }

        # Keywords (whitespace-normalized, upper-case) that open a clause.
        # sqlparse emits "GROUP BY" / "ORDER BY" as single keyword tokens;
        # bare GROUP / ORDER are kept as a fallback.
        clause_starts = {
            'SELECT': 'select',
            'FROM': 'from',
            'WHERE': 'where',
            'GROUP BY': 'group_by',
            'GROUP': 'group_by',
            'HAVING': 'having',
            'ORDER BY': 'order_by',
            'ORDER': 'order_by',
            'LIMIT': 'limit',
        }
        # Set operators close the current clause without opening a new one.
        set_operators = {'UNION', 'UNION ALL', 'INTERSECT', 'EXCEPT'}
        skipped_types = [sqlparse.tokens.Whitespace, sqlparse.tokens.Newline]
        punctuation = sqlparse.tokens.Punctuation

        try:
            parsed = sqlparse.parse(sql)[0]

            current_clause = None
            clause_tokens = []
            depth = 0  # parenthesis nesting level

            for token in parsed.flatten():
                if token.ttype in skipped_types:
                    continue

                if depth == 0:
                    # `in Keyword` also matches subtypes such as Keyword.DML
                    # (used for SELECT); an identity test would miss them.
                    if token.ttype in Keyword:
                        keyword = ' '.join(token.value.upper().split())
                        if keyword in clause_starts or keyword in set_operators:
                            if current_clause and clause_tokens:
                                structure[current_clause] = self.process_clause_tokens(
                                    current_clause, clause_tokens
                                )
                            clause_tokens = []
                            current_clause = clause_starts.get(keyword)
                            if current_clause == 'select':
                                structure['type'] = 'SELECT'
                            continue
                        # A separate BY right after bare GROUP/ORDER belongs
                        # to the clause keyword, not to its content.
                        if (keyword == 'BY' and not clause_tokens
                                and current_clause in ('group_by', 'order_by')):
                            continue
                    elif token.ttype is punctuation and token.value == ';':
                        # Statement terminator, not clause content.
                        continue

                if token.ttype is punctuation:
                    if token.value == '(':
                        depth += 1
                    elif token.value == ')':
                        depth = max(depth - 1, 0)

                if current_clause:
                    clause_tokens.append(token)

            if current_clause and clause_tokens:
                structure[current_clause] = self.process_clause_tokens(
                    current_clause, clause_tokens
                )

        except Exception as e:
            print(f"SQL解析错误: {e}")

        return structure
    
    def process_clause_tokens(self, clause_type: str, tokens: List) -> List:
        """处理子句tokens"""
        if clause_type == 'select':
            return self.process_select_clause(tokens)
        elif clause_type == 'from':
            return self.process_from_clause(tokens)
        elif clause_type == 'where':
            return self.process_where_clause(tokens)
        elif clause_type in ['group_by', 'order_by']:
            return self.process_list_clause(tokens)
        elif clause_type == 'limit':
            return self.process_limit_clause(tokens)
        else:
            return [token.value for token in tokens if token.value.strip()]
    
    def process_select_clause(self, tokens: List) -> List:
        """处理SELECT子句"""
        select_items = []
        current_item = []
        
        for token in tokens:
            if token.value == ',':
                if current_item:
                    select_items.append(' '.join(t.value for t in current_item))
                    current_item = []
            else:
                current_item.append(token)
        
        # 添加最后一项
        if current_item:
            select_items.append(' '.join(t.value for t in current_item))
        
        return select_items
    
    def process_from_clause(self, tokens: List) -> List:
        """Split FROM-clause tokens into one string per table reference.

        Items are separated by commas and by JOIN keywords. A token counts as a
        JOIN separator when every whitespace-separated word in it is one of
        JOIN, LEFT, RIGHT, INNER, OUTER, FULL, CROSS or NATURAL. This covers
        both separate keyword tokens and multi-word tokens such as
        ``'LEFT JOIN'`` (the form sqlparse emits). Whitespace-only tokens are
        skipped. An ``ON ...`` join condition stays attached to the table it
        follows; no further JOIN analysis is done here.

        Args:
            tokens: Token objects exposing a ``.value`` string.

        Returns:
            List of table-reference strings, e.g. ``['employees e', 'departments d ON ...']``.
        """
        join_words = {'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'FULL', 'CROSS', 'NATURAL'}
        from_items = []
        current_item = []

        for token in tokens:
            value = token.value
            if not value.strip():
                continue  # ignore whitespace tokens
            words = value.upper().split()
            if value == ',' or all(word in join_words for word in words):
                # Separator: close the current table reference, if any
                if current_item:
                    from_items.append(' '.join(t.value for t in current_item))
                    current_item = []
                continue
            current_item.append(token)

        if current_item:
            from_items.append(' '.join(t.value for t in current_item))

        return from_items
    
    def process_where_clause(self, tokens: List) -> List:
        """Split WHERE-clause tokens into conditions joined by top-level AND/OR.

        Each entry is ``{'condition': str, 'operator': str | None}``.
        ``operator`` is the connective (upper-cased) that FOLLOWS the
        condition, and ``None`` for the last one.

        These cases do not split the current condition:

        - the ``AND`` of ``x BETWEEN lo AND hi``;
        - any AND/OR nested inside parentheses, when ``(`` and ``)`` arrive as
          separate tokens.

        Whitespace-only tokens are skipped.

        Args:
            tokens: Token objects exposing a ``.value`` string.

        Returns:
            List of condition dicts in their original order.
        """
        conditions = []
        current_condition = []
        depth = 0  # nesting level of '(' ... ')' tokens seen so far
        awaiting_between_and = False  # True after a top-level BETWEEN until its AND

        for token in tokens:
            value = token.value
            if not value.strip():
                continue  # ignore whitespace tokens
            keyword = value.upper()
            if value == '(':
                depth += 1
            elif value == ')':
                depth = max(depth - 1, 0)
            elif depth == 0 and keyword == 'BETWEEN':
                awaiting_between_and = True
            elif depth == 0 and keyword == 'AND' and awaiting_between_and:
                # Second half of "x BETWEEN lo AND hi": same condition.
                awaiting_between_and = False
            elif depth == 0 and keyword in ('AND', 'OR'):
                # Top-level connective: close the current condition.
                if current_condition:
                    conditions.append({
                        'condition': ' '.join(t.value for t in current_condition),
                        'operator': keyword
                    })
                    current_condition = []
                continue
            current_condition.append(token)

        if current_condition:
            conditions.append({
                'condition': ' '.join(t.value for t in current_condition),
                'operator': None
            })

        return conditions
    
    def process_list_clause(self, tokens: List) -> List:
        """Split GROUP BY / ORDER BY tokens into one string per comma-separated item.

        Whitespace-only tokens are skipped, so items carry no stray padding.
        The remaining token values of an item are joined with single spaces.
        For example, ``salary DESC`` stays one item. Empty items are dropped.

        Args:
            tokens: Token objects exposing a ``.value`` string.

        Returns:
            List of item strings in their original order.
        """
        items = []
        current_item = []

        for token in tokens:
            if not token.value.strip():
                continue  # ignore whitespace tokens between items
            if token.value == ',':
                if current_item:
                    items.append(' '.join(t.value for t in current_item))
                    current_item = []
            else:
                current_item.append(token)

        if current_item:
            items.append(' '.join(t.value for t in current_item))

        return items
    
    def process_limit_clause(self, tokens: List) -> str:
        """Return the LIMIT clause as a single space-joined string.

        Whitespace-only tokens are dropped before joining.
        """
        kept_values = [tok.value for tok in tokens if tok.value.strip()]
        return ' '.join(kept_values)
    
    def extract_sql_features(self, sql_structure: Dict) -> Dict:
        """Summarize a parsed SQL structure as counts, flags and a complexity score.

        Args:
            sql_structure: Mapping with keys ``type``, ``select``, ``from``,
                ``where``, ``group_by``, ``having``, ``order_by`` and ``limit``.
                ``limit`` is ``None`` when there is no LIMIT clause.

        Returns:
            Feature dict. ``complexity_score`` is a weighted sum of:

            - 0.5 per SELECT item;
            - 1.0 per FROM item;
            - 1.5 per WHERE condition;
            - 2.0 for GROUP BY, 1.0 for HAVING and 0.5 for ORDER BY.
        """
        def non_empty(clause: str) -> bool:
            return len(sql_structure[clause]) > 0

        features = {
            'query_type': sql_structure['type'],
            'num_select_items': len(sql_structure['select']),
            'num_tables': len(sql_structure['from']),
            'has_where': non_empty('where'),
            'has_group_by': non_empty('group_by'),
            'has_having': non_empty('having'),
            'has_order_by': non_empty('order_by'),
            'has_limit': sql_structure['limit'] is not None,
            'num_conditions': len(sql_structure['where']),
            'complexity_score': 0
        }

        # Terms are summed left to right, starting from 0.
        weighted_terms = (
            features['num_select_items'] * 0.5,
            features['num_tables'] * 1.0,
            features['num_conditions'] * 1.5,
            2.0 if features['has_group_by'] else 0,
            1.0 if features['has_having'] else 0,
            0.5 if features['has_order_by'] else 0,
        )
        features['complexity_score'] = sum(weighted_terms)

        return features
    
    def tokenize_sql(self, sql: str) -> List[str]:
        """Split a SQL string into simple lexical tokens.

        The SQL is first passed through ``self.normalize_sql``. Tokens are then
        separated on whitespace, and each of ``(``, ``)``, ``,`` and ``;`` is
        emitted as a token of its own.

        Quoted literals (``'...'`` or ``"..."``) are kept whole, including any
        whitespace or punctuation inside them. A value such as ``'John Smith'``
        therefore stays one token and can later be recognized as a string
        literal. Doubled quotes (``'It''s'``) also stay in one token. An
        unterminated quote extends to the end of the input.

        Args:
            sql: Raw SQL text.

        Returns:
            List of token strings.
        """
        normalized = self.normalize_sql(sql)

        tokens: List[str] = []
        buffer: List[str] = []
        quote = None  # quote character of the literal currently open, if any

        def flush() -> None:
            # Emit the pending token, if there is one.
            if buffer:
                tokens.append(''.join(buffer))
                buffer.clear()

        for char in normalized:
            if quote is not None:
                # Inside a literal: take everything until the matching quote.
                buffer.append(char)
                if char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
                buffer.append(char)
            elif char.isspace():
                flush()
            elif char in '(),;':
                flush()
                tokens.append(char)
            else:
                buffer.append(char)

        flush()
        return tokens

# Usage example: run each SQLPreprocessor stage on a few sample queries and print the results.
sql_processor = SQLPreprocessor()

# Sample SQL statements: a simple filter, an aggregate with GROUP BY/HAVING,
# and a JOIN with ORDER BY/LIMIT.
test_sqls = [
    "SELECT name, salary FROM employees WHERE salary > 50000",
    "SELECT COUNT(*) FROM departments GROUP BY budget HAVING COUNT(*) > 5",
    "SELECT e.name, d.name FROM employees e JOIN departments d ON e.dept_id = d.id ORDER BY e.salary DESC LIMIT 10"
]

for sql in test_sqls:
    print(f"原始SQL: {sql}")
    
    # Normalize the SQL text
    normalized = sql_processor.normalize_sql(sql)
    print(f"标准化: {normalized}")
    
    # Parse the raw SQL into its clause structure
    structure = sql_processor.parse_sql_structure(sql)
    print(f"结构: {structure}")
    
    # Derive summary features from the parsed structure
    features = sql_processor.extract_sql_features(structure)
    print(f"特征: {features}")
    
    # Tokenize (tokenize_sql normalizes internally, so the raw SQL is passed in)
    tokens = sql_processor.tokenize_sql(sql)
    print(f"分词: {tokens}")
    
    print("-" * 80)

4.2 SQL序列化

class SQLSequentializer:
    """Encode SQL (and paired questions) as integer id sequences, and decode them back.

    Questions and SQL share a single vocabulary. It is seeded with special
    tokens, common SQL keywords and operators. It then grows as new tokens are
    encountered, until it holds ``max_vocab_size`` entries; after that, unseen
    tokens map to ``<UNK>``. Numeric literals are encoded as ``<NUM>`` and
    quoted literals as ``<STR>``.
    """

    # Integer or decimal literal with an optional sign, e.g. "42", "-3.5".
    _NUMBER_PATTERN = re.compile(r'^[+-]?\d+(\.\d+)?$')

    def __init__(self, max_vocab_size: int = 50000):
        """Build the base vocabulary.

        Args:
            max_vocab_size: Upper bound on the vocabulary size when tokens are
                added during encoding. Seed tokens are always present.
        """
        self.max_vocab_size = max_vocab_size
        self.vocab = {
            '<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3,
            '<SEP>': 4, '<NUM>': 5, '<STR>': 6
        }
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_counter = len(self.vocab)
        # SQLPreprocessor used for tokenization; created on first use and reused.
        self._sql_tokenizer = None

        # SQL keywords
        sql_keywords = [
            'SELECT', 'FROM', 'WHERE', 'GROUP', 'BY', 'ORDER', 'HAVING',
            'LIMIT', 'OFFSET', 'UNION', 'INTERSECT', 'EXCEPT',
            'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE',
            'IS', 'NULL', 'DISTINCT', 'ALL', 'ANY', 'SOME',
            'COUNT', 'SUM', 'AVG', 'MAX', 'MIN',
            'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON',
            'AS', 'ASC', 'DESC'
        ]
        for keyword in sql_keywords:
            self.add_token(keyword)

        # Operators
        operators = ['=', '!=', '<>', '<', '>', '<=', '>=', '+', '-', '*', '/', '%']
        for op in operators:
            self.add_token(op)

    def add_token(self, token: str) -> int:
        """Return the id of *token*, appending it to the vocabulary if it is new.

        This ignores ``max_vocab_size``; encoding paths use ``_lookup_or_add``.
        """
        if token not in self.vocab:
            self.vocab[token] = self.vocab_counter
            self.reverse_vocab[self.vocab_counter] = token
            self.vocab_counter += 1
        return self.vocab[token]

    def _lookup_or_add(self, token: str) -> int:
        """Return the id of *token*.

        A new token is added while the vocabulary is below
        ``max_vocab_size``; otherwise ``<UNK>`` is returned.
        """
        if token in self.vocab:
            return self.vocab[token]
        if len(self.vocab) < self.max_vocab_size:
            return self.add_token(token)
        return self.vocab['<UNK>']

    def preprocess_token(self, token: str) -> str:
        """Normalize one token before lookup.

        Numeric literals become ``<NUM>``. Tokens wrapped in matching single or
        double quotes (at least two characters long) become ``<STR>``.
        Everything else is upper-cased.
        """
        if self._NUMBER_PATTERN.match(token):
            return '<NUM>'

        # A lone quote character is not a string literal.
        if len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"'):
            return '<STR>'

        return token.upper()

    def sql_to_sequence(self, sql: str) -> List[int]:
        """Encode *sql* as ``[<SOS>, token ids..., <EOS>]``.

        Tokens come from ``SQLPreprocessor.tokenize_sql`` and are normalized
        with ``preprocess_token``. Unseen tokens may be added to the vocabulary
        (see ``_lookup_or_add``).
        """
        tokenizer = getattr(self, '_sql_tokenizer', None)
        if tokenizer is None:
            tokenizer = self._sql_tokenizer = SQLPreprocessor()
        tokens = tokenizer.tokenize_sql(sql)

        sequence = [self.vocab['<SOS>']]
        sequence.extend(self._lookup_or_add(self.preprocess_token(token)) for token in tokens)
        sequence.append(self.vocab['<EOS>'])

        return sequence

    def sequence_to_sql(self, sequence: List[int]) -> str:
        """Decode ids back to a space-joined token string.

        ``<SOS>``, ``<EOS>`` and ``<PAD>`` are dropped, as are unknown ids. The
        result is the normalized form: upper-cased, with literals shown as
        ``<NUM>`` / ``<STR>``.
        """
        skipped = {'<SOS>', '<EOS>', '<PAD>'}
        tokens = []

        for token_id in sequence:
            token = self.reverse_vocab.get(token_id)
            if token is not None and token not in skipped:
                tokens.append(token)

        return ' '.join(tokens)

    def create_training_pairs(self, questions: List[str],
                            sqls: List[str]) -> List[Tuple[List[int], List[int]]]:
        """Encode aligned (question, SQL) pairs as ``(question_ids, sql_ids)`` tuples.

        Question tokens come from ``QuestionPreprocessor.preprocess_question``
        (its ``'tokens'`` entry). They are upper-cased and share the SQL
        vocabulary and its size cap.

        Raises:
            ValueError: If *questions* and *sqls* differ in length. Previously
                the extra items were silently dropped.
        """
        questions = list(questions)
        sqls = list(sqls)
        if len(questions) != len(sqls):
            raise ValueError(
                f"questions and sqls must have the same length "
                f"(got {len(questions)} and {len(sqls)})"
            )

        question_processor = QuestionPreprocessor()
        pairs = []

        for question, sql in zip(questions, sqls):
            processed_question = question_processor.preprocess_question(question)

            question_sequence = [self.vocab['<SOS>']]
            question_sequence.extend(
                self._lookup_or_add(token.upper()) for token in processed_question['tokens']
            )
            question_sequence.append(self.vocab['<EOS>'])

            pairs.append((question_sequence, self.sql_to_sequence(sql)))

        return pairs

    def pad_sequences(self, sequences: List[List[int]],
                     max_length: int = None) -> 'torch.Tensor':
        """Right-pad (or truncate) sequences with ``<PAD>`` to a common length.

        Args:
            sequences: Id sequences.
            max_length: Target length. Defaults to the longest sequence, or 0
                when *sequences* is empty. Longer sequences are truncated, which
                may drop ``<EOS>``.

        Returns:
            LongTensor of shape ``(len(sequences), max_length)``.
        """
        import torch  # local: torch may not be imported at module level yet

        pad_id = self.vocab['<PAD>']
        if max_length is None:
            max_length = max((len(seq) for seq in sequences), default=0)

        padded = []
        for seq in sequences:
            row = list(seq[:max_length])
            row.extend([pad_id] * (max_length - len(row)))
            padded.append(row)

        # reshape keeps a 2-D shape even when there are no sequences.
        return torch.tensor(padded, dtype=torch.long).reshape(len(padded), max_length)

    def create_attention_mask(self, sequences: 'torch.Tensor') -> 'torch.Tensor':
        """Return a float mask: 1.0 at real tokens, 0.0 at ``<PAD>``."""
        return (sequences != self.vocab['<PAD>']).float()

# Usage example: encode question/SQL pairs as id sequences, then pad them into tensors.
sequentializer = SQLSequentializer()

# Sample data: questions and their SQL, aligned by index
test_questions = [
    "Show all employees with salary greater than 50000",
    "What is the average salary by department",
    "List top 10 highest paid employees"
]

test_sqls = [
    "SELECT * FROM employees WHERE salary > 50000",
    "SELECT department, AVG(salary) FROM employees GROUP BY department",
    "SELECT * FROM employees ORDER BY salary DESC LIMIT 10"
]

# Build (question_sequence, sql_sequence) pairs; tokens not yet in the shared
# vocabulary are added to it along the way
training_pairs = sequentializer.create_training_pairs(test_questions, test_sqls)

print(f"训练对数量: {len(training_pairs)}")
print(f"词汇表大小: {len(sequentializer.vocab)}")

# Show the first training pair
question_seq, sql_seq = training_pairs[0]
print(f"问题序列: {question_seq}")
print(f"SQL序列: {sql_seq}")

# Decode back to text: tokens come back upper-cased and numbers as <NUM>,
# so this is a normalized form rather than the original string
recovered_sql = sequentializer.sequence_to_sql(sql_seq)
print(f"恢复的SQL: {recovered_sql}")

# Batch processing: separate the question and SQL sequences
question_sequences = [pair[0] for pair in training_pairs]
sql_sequences = [pair[1] for pair in training_pairs]

# Pad each group to its own longest sequence -> LongTensor of shape (num_pairs, max_len)
padded_questions = sequentializer.pad_sequences(question_sequences)
padded_sqls = sequentializer.pad_sequences(sql_sequences)

print(f"填充后的问题序列形状: {padded_questions.shape}")
print(f"填充后的SQL序列形状: {padded_sqls.shape}")

# Attention masks: 1.0 at real tokens, 0.0 at <PAD>
question_mask = sequentializer.create_attention_mask(padded_questions)
sql_mask = sequentializer.create_attention_mask(padded_sqls)

print(f"问题注意力掩码形状: {question_mask.shape}")
print(f"SQL注意力掩码形状: {sql_mask.shape}")

5. 特征工程

5.1 多模态特征融合

class MultiModalFeatureExtractor:
    """Hand-crafted numeric features for a question, its schema and (optionally) its SQL.

    Question, schema and SQL analysis is delegated to ``QuestionPreprocessor``,
    ``SchemaPreprocessor`` and ``SQLPreprocessor``. ``create_feature_vector``
    flattens the results into a single float32 tensor.
    """

    # Aggregate-function *calls* ("COUNT(", "max (" ...). Requiring "(" stops
    # identifiers such as ``admin`` or ``account`` counting as MIN / COUNT.
    _AGGREGATE_PATTERN = re.compile(r'\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\(', re.IGNORECASE)
    # JOIN as a whole word, so a column such as ``joined_at`` does not count.
    _JOIN_PATTERN = re.compile(r'\bJOIN\b', re.IGNORECASE)
    # A parenthesis that opens a nested SELECT, in any letter case.
    _SUBQUERY_PATTERN = re.compile(r'\(\s*SELECT\b', re.IGNORECASE)

    _QUESTION_WORDS = frozenset({'what', 'how', 'which', 'who', 'when', 'where', 'why'})
    _ACTION_WORDS = frozenset({'show', 'list', 'find', 'get', 'count', 'calculate'})

    def __init__(self, embedding_dim: int = 256):
        self.embedding_dim = embedding_dim
        self.question_processor = QuestionPreprocessor()
        self.schema_processor = SchemaPreprocessor()
        self.sql_processor = SQLPreprocessor()

    def extract_question_features(self, question: str) -> Dict:
        """Extract length, entity, type and complexity features for a question.

        Relies on ``preprocess_question`` returning ``'length'``, ``'tokens'``
        and an ``'entities'`` dict that includes ``'numbers'`` and
        ``'keywords'`` lists.
        """
        processed = self.question_processor.preprocess_question(question)

        features = {
            'length': processed['length'],
            'num_entities': sum(len(entities) for entities in processed['entities'].values()),
            'has_numbers': len(processed['entities']['numbers']) > 0,
            'has_keywords': len(processed['entities']['keywords']) > 0,
            'question_type': self.classify_question_type(processed['tokens']),
            'complexity': self.calculate_question_complexity(processed)
        }

        return features

    def classify_question_type(self, tokens: List[str]) -> str:
        """Assign a coarse intent label from the question's words (case-insensitive).

        Returns one of 'count', 'aggregation', 'selection', 'other' (questions
        with a wh-word), 'listing', 'action' (imperatives), or 'unknown'.
        """
        tokens_lower = [token.lower() for token in tokens]

        if any(word in tokens_lower for word in self._QUESTION_WORDS):
            if 'how many' in ' '.join(tokens_lower) or 'count' in tokens_lower:
                return 'count'
            elif 'what' in tokens_lower and ('average' in tokens_lower or 'avg' in tokens_lower):
                return 'aggregation'
            elif 'what' in tokens_lower or 'which' in tokens_lower:
                return 'selection'
            else:
                return 'other'
        elif any(word in tokens_lower for word in self._ACTION_WORDS):
            if 'show' in tokens_lower or 'list' in tokens_lower:
                return 'listing'
            elif 'count' in tokens_lower:
                return 'count'
            else:
                return 'action'
        else:
            return 'unknown'

    def calculate_question_complexity(self, processed_question: Dict) -> float:
        """Score complexity as a weighted sum of three terms.

        - 0.1 per unit of ``length``;
        - 0.5 per entity of any kind;
        - an extra 1.0 per keyword.
        """
        complexity = 0.0

        # Length-based term
        complexity += processed_question['length'] * 0.1

        # Entity-count term
        entities = processed_question['entities']
        complexity += sum(len(entity_list) for entity_list in entities.values()) * 0.5

        # Keywords are weighted again on top of the entity term
        complexity += len(entities['keywords']) * 1.0

        return complexity

    def extract_schema_features(self, schema: Dict) -> Dict:
        """Schema features from SchemaPreprocessor, plus average width and relationship density.

        ``relationship_density`` is relationships divided by the number of
        table pairs, n*(n-1)/2. It is 0 when there are fewer than two tables.
        """
        normalized = self.schema_processor.normalize_schema(schema)
        features = self.schema_processor.extract_schema_features(normalized)

        features['avg_columns_per_table'] = (
            features['num_columns'] / features['num_tables']
            if features['num_tables'] > 0 else 0
        )

        features['relationship_density'] = (
            features['num_relationships'] / (features['num_tables'] * (features['num_tables'] - 1) / 2)
            if features['num_tables'] > 1 else 0
        )

        return features

    @classmethod
    def _surface_sql_flags(cls, sql: str) -> Dict[str, bool]:
        """Detect aggregates, JOINs and subqueries from the raw SQL text (case-insensitive)."""
        return {
            'has_aggregation': bool(cls._AGGREGATE_PATTERN.search(sql)),
            'has_join': bool(cls._JOIN_PATTERN.search(sql)),
            'has_subquery': bool(cls._SUBQUERY_PATTERN.search(sql)),
        }

    def extract_sql_features(self, sql: str) -> Dict:
        """Structural SQL features from SQLPreprocessor, plus aggregate/JOIN/subquery flags."""
        structure = self.sql_processor.parse_sql_structure(sql)
        features = self.sql_processor.extract_sql_features(structure)
        features.update(self._surface_sql_flags(sql))
        return features

    @staticmethod
    def _collect_schema_entities(schema: Dict) -> set:
        """Return lower-cased table and column names from *schema*.

        Two layouts are accepted:

        - ``{'tables': [{'name': ..., 'columns': [{'name': ...} | str, ...]}]}``,
          the format used by this file's examples;
        - ``{'tables': [name, ...], 'columns': {name: [column, ...]}}``.
        """
        entities = set()
        for table in schema.get('tables', []):
            if isinstance(table, dict):
                table_name = table.get('name', '')
                columns = table.get('columns', [])
            else:
                table_name = table
                columns = schema.get('columns', {}).get(table_name, [])
            if table_name:
                entities.add(str(table_name).lower())
            for column in columns:
                column_name = column.get('name', '') if isinstance(column, dict) else column
                if column_name:
                    entities.add(str(column_name).lower())
        return entities

    def extract_alignment_features(self, question: str, schema: Dict, sql: str) -> Dict:
        """Measure word overlap between question, schema entities and SQL entities.

        Comparisons are case-insensitive. SQL entities are the whitespace-split
        words of the parsed SELECT and FROM items. Ratios are relative to the
        number of distinct question tokens, and 0 if there are none.
        """
        processed_question = self.question_processor.preprocess_question(question)
        question_tokens = set(token.lower() for token in processed_question['tokens'])

        schema_entities = self._collect_schema_entities(schema)

        sql_structure = self.sql_processor.parse_sql_structure(sql)
        sql_entities = set()
        for select_item in sql_structure['select']:
            sql_entities.update(select_item.lower().split())
        for from_item in sql_structure['from']:
            sql_entities.update(from_item.lower().split())

        question_schema_overlap = len(question_tokens & schema_entities)
        question_sql_overlap = len(question_tokens & sql_entities)
        schema_sql_overlap = len(schema_entities & sql_entities)

        features = {
            'question_schema_overlap': question_schema_overlap,
            'question_sql_overlap': question_sql_overlap,
            'schema_sql_overlap': schema_sql_overlap,
            'question_schema_ratio': (
                question_schema_overlap / len(question_tokens)
                if len(question_tokens) > 0 else 0
            ),
            'question_sql_ratio': (
                question_sql_overlap / len(question_tokens)
                if len(question_tokens) > 0 else 0
            )
        }

        return features

    def create_feature_vector(self, question: str, schema: Dict, sql: str = None) -> 'torch.Tensor':
        """Flatten the features into a 1-D float32 tensor.

        The layout is 5 question features followed by 6 schema features. When
        *sql* is given, 9 SQL features and 5 alignment features follow.
        """
        import torch  # local: torch may not be imported at module level yet

        question_features = self.extract_question_features(question)
        schema_features = self.extract_schema_features(schema)

        feature_vector = []

        # Question features
        feature_vector.extend([
            question_features['length'],
            question_features['num_entities'],
            float(question_features['has_numbers']),
            float(question_features['has_keywords']),
            question_features['complexity']
        ])

        # Schema features
        feature_vector.extend([
            schema_features['num_tables'],
            schema_features['num_columns'],
            schema_features['num_relationships'],
            schema_features['avg_columns_per_table'],
            schema_features['relationship_density'],
            schema_features['complexity_score']
        ])

        # SQL and alignment features only when SQL is available
        if sql:
            sql_features = self.extract_sql_features(sql)
            alignment_features = self.extract_alignment_features(question, schema, sql)

            feature_vector.extend([
                sql_features['num_select_items'],
                sql_features['num_tables'],
                sql_features['num_conditions'],
                float(sql_features['has_group_by']),
                float(sql_features['has_order_by']),
                float(sql_features['has_aggregation']),
                float(sql_features['has_join']),
                float(sql_features['has_subquery']),
                sql_features['complexity_score']
            ])

            feature_vector.extend([
                alignment_features['question_schema_overlap'],
                alignment_features['question_sql_overlap'],
                alignment_features['schema_sql_overlap'],
                alignment_features['question_schema_ratio'],
                alignment_features['question_sql_ratio']
            ])

        return torch.tensor(feature_vector, dtype=torch.float32)

# Usage example: extract each feature group for a single question/schema/SQL example.
feature_extractor = MultiModalFeatureExtractor()

# Test data; test_schema describes each table as a dict with a 'name' and a list of column dicts
test_question = "Show me all employees with salary greater than 50000"
test_schema = {
    'tables': [
        {
            'name': 'employees',
            'columns': [
                {'name': 'id', 'type': 'INT', 'primary_key': True},
                {'name': 'name', 'type': 'VARCHAR(100)'},
                {'name': 'salary', 'type': 'DECIMAL(10,2)'},
                {'name': 'department_id', 'type': 'INT'}
            ]
        }
    ]
}
test_sql = "SELECT * FROM employees WHERE salary > 50000"

# Extract the feature groups one at a time
question_features = feature_extractor.extract_question_features(test_question)
schema_features = feature_extractor.extract_schema_features(test_schema)
sql_features = feature_extractor.extract_sql_features(test_sql)
alignment_features = feature_extractor.extract_alignment_features(
    test_question, test_schema, test_sql
)

print("问题特征:", question_features)
print("Schema特征:", schema_features)
print("SQL特征:", sql_features)
print("对齐特征:", alignment_features)

# Combine everything into one flat vector (SQL is given, so SQL and alignment features are included)
feature_vector = feature_extractor.create_feature_vector(
    test_question, test_schema, test_sql
)
print(f"特征向量维度: {feature_vector.shape}")
print(f"特征向量: {feature_vector}")

6.4 数据增强技术

6.4.1 问题改写

import random
from typing import List, Dict

class QuestionAugmenter:
    """Generate surface variants of natural-language questions for data augmentation.

    Three strategies are offered: random synonym replacement, template filling
    and rule-based paraphrasing.
    """

    # (pattern, replacements) rules for paraphrase_generation. Patterns match
    # whole words case-insensitively, so "Show me" is rewritten while the "all"
    # inside "smaller" is left alone.
    _PARAPHRASE_RULES = (
        (re.compile(r'\bshow me\b', re.IGNORECASE), ('list', 'display')),
        (re.compile(r'\ball\b', re.IGNORECASE), ('every',)),
        (re.compile(r'\bgreater than\b', re.IGNORECASE), ('more than', 'above')),
    )

    def __init__(self):
        # Synonym dictionary: lower-case word -> candidate replacements
        self.synonyms = {
            'show': ['display', 'list', 'get', 'find', 'retrieve'],
            'all': ['every', 'each', 'total'],
            'employee': ['worker', 'staff', 'personnel'],
            'salary': ['wage', 'pay', 'income', 'earnings'],
            'greater': ['higher', 'more', 'above', 'over'],
            'department': ['division', 'section', 'unit']
        }

        # Sentence templates with {entity} and {condition} placeholders
        self.templates = [
            "Show me {entity} where {condition}",
            "List all {entity} with {condition}",
            "Find {entity} that have {condition}",
            "Get {entity} where {condition}",
            "Display {entity} with {condition}"
        ]

    @staticmethod
    def _match_case(source: str, replacement: str) -> str:
        """Return *replacement* with the capitalization style of *source*.

        An all-caps *source* longer than one character gives an all-caps
        result, and a capitalized *source* gives a capitalized result.
        Otherwise *replacement* is returned unchanged.
        """
        if len(source) > 1 and source.isupper():
            return replacement.upper()
        if source[:1].isupper():
            return replacement[:1].upper() + replacement[1:]
        return replacement

    def synonym_replacement(self, question: str, num_replacements: int = 1) -> str:
        """Replace up to *num_replacements* distinct words with random synonyms.

        Words are matched case-insensitively against ``self.synonyms``. Words
        that are not replaced keep their spelling, and each replacement follows
        the capitalization of the word it replaces. Words are re-joined with
        single spaces. If no word is replaceable, *question* is returned as is.
        """
        words = question.split()
        replaceable = [i for i, word in enumerate(words) if word.lower() in self.synonyms]

        if not replaceable:
            return question

        new_words = list(words)
        count = min(max(num_replacements, 0), len(replaceable))
        # random.sample picks distinct positions, so no word is replaced twice.
        for index in random.sample(replaceable, count):
            word = words[index]
            synonym = random.choice(self.synonyms[word.lower()])
            new_words[index] = self._match_case(word, synonym)

        return ' '.join(new_words)

    def template_based_generation(self, entities: List[str], conditions: List[str]) -> List[str]:
        """Fill every template with every (entity, condition) combination.

        Results are ordered by template, then entity, then condition.
        """
        return [
            template.format(entity=entity, condition=condition)
            for template in self.templates
            for entity in entities
            for condition in conditions
        ]

    def paraphrase_generation(self, question: str) -> List[str]:
        """Produce rule-based paraphrases, one per applicable (rule, replacement).

        Each paraphrase applies exactly one rule, replacing every whole-word,
        case-insensitive occurrence of its phrase. Replacements keep the
        capitalization of the matched text, so "Show me" becomes "List".
        """
        paraphrases = []

        for pattern, replacements in self._PARAPHRASE_RULES:
            if not pattern.search(question):
                continue
            for replacement in replacements:
                paraphrases.append(pattern.sub(
                    lambda match, new=replacement: self._match_case(match.group(0), new),
                    question,
                ))

        return paraphrases

# Usage example
augmenter = QuestionAugmenter()
original_question = "Show me all employees with salary greater than 50000"

# Synonym replacement: three independently drawn random variants, one word replaced in each
synonym_variants = [
    augmenter.synonym_replacement(original_question, 1)
    for _ in range(3)
]
print("同义词替换:")
for variant in synonym_variants:
    print(f"  - {variant}")

# Rule-based paraphrases
paraphrases = augmenter.paraphrase_generation(original_question)
print("\n释义生成:")
for paraphrase in paraphrases:
    print(f"  - {paraphrase}")

6.4.2 SQL增强

class SQLAugmenter:
    """Generate SQL variants intended to return exactly the same result.

    Augmented SQL becomes the training label for an unchanged question. Any
    rewrite that alters semantics (e.g. ``>`` -> ``>=``) would therefore inject
    label noise, so every transformation here aims to preserve results.
    """

    # Operator to use when the two operands are swapped: ``a > 5`` == ``5 < a``.
    _MIRRORED_OPERATORS = {
        '>': '<', '<': '>', '>=': '<=', '<=': '>=',
        '=': '=', '!=': '!=', '<>': '<>',
    }
    # "identifier <op> literal", where the literal is a number or a quoted string.
    _SIMPLE_COMPARISON = re.compile(
        r"(?P<left>\b[A-Za-z_][\w.]*)\s*(?P<op>>=|<=|<>|!=|=|>|<)\s*"
        r"(?P<right>-?\d+(?:\.\d+)?\b|'(?:[^']|'')*')"
    )
    # Single-quoted SQL string literal ('' is an escaped quote), captured for re.split.
    _STRING_LITERAL = re.compile(r"('(?:[^']|'')*')")
    # A comparison is only mirrored when it is delimited by one of these words
    # (or by a parenthesis / string boundary). Anything else, such as arithmetic,
    # casts or SET lists, could bind differently after the swap.
    _BOUNDARY_BEFORE = frozenset({'WHERE', 'AND', 'OR', 'NOT', 'ON', 'HAVING', 'WHEN'})
    _BOUNDARY_AFTER = frozenset({'AND', 'OR', 'THEN', 'ORDER', 'GROUP', 'HAVING',
                                 'LIMIT', 'UNION', 'INTERSECT', 'EXCEPT'})

    def __init__(self):
        # Operators that are exact synonyms in SQL. Do not add pairs such as
        # '>'->'>=', '='->'IN', 'AND'->'&&' or 'OR'->'||': they change or break
        # the query.
        self.equivalent_operators = {
            '!=': ['!=', '<>'],
            '<>': ['<>', '!='],
        }

    @staticmethod
    def _operator_pattern(operator: str) -> 're.Pattern':
        """Compile a regex matching *operator* only as a whole operator."""
        escaped = re.escape(operator)
        if re.fullmatch(r'\w+', operator):
            # Keyword operators must be whole words (no 'OR' inside 'ORDER').
            return re.compile(r'\b' + escaped + r'\b')
        # Symbolic operators must not be part of a longer one ('=' inside '>=').
        return re.compile(r'(?<![<>!=])' + escaped + r'(?![<>=])')

    @classmethod
    def _sub_outside_literals(cls, pattern, replacement: str, sql: str) -> str:
        """Replace *pattern* matches with *replacement*, skipping quoted string literals."""
        parts = cls._STRING_LITERAL.split(sql)
        # re.split with a capturing group leaves the literals at odd indexes.
        return ''.join(
            part if index % 2 else pattern.sub(lambda _match: replacement, part)
            for index, part in enumerate(parts)
        )

    @classmethod
    def _is_standalone(cls, sql: str, start: int, end: int) -> bool:
        """Check whether ``sql[start:end]`` is delimited so it can be mirrored safely."""
        before = sql[:start].rstrip()
        after = sql[end:].lstrip()
        before_ok = (not before or before.endswith('(')
                     or before.split()[-1].upper() in cls._BOUNDARY_BEFORE)
        after_ok = (not after or after[0] in ');'
                    or after.split()[0].upper() in cls._BOUNDARY_AFTER)
        return before_ok and after_ok

    def _mirror_comparisons(self, sql: str) -> str:
        """Swap the operands of every standalone ``identifier <op> literal`` comparison."""
        def mirror(match):
            if not self._is_standalone(sql, match.start(), match.end()):
                return match.group(0)
            op = self._MIRRORED_OPERATORS[match.group('op')]
            return f"{match.group('right')} {op} {match.group('left')}"

        return self._SIMPLE_COMPARISON.sub(mirror, sql)

    def operator_substitution(self, sql: str) -> List[str]:
        """Return semantically equivalent variants of *sql*, never *sql* itself.

        Variants are produced in this order:

        1. For each operator in ``self.equivalent_operators`` that occurs as a
           whole operator outside string literals: one variant per alternative,
           with every such occurrence replaced (e.g. ``!=`` -> ``<>``).
        2. One variant in which each standalone ``column <op> literal``
           comparison is mirrored (``salary > 50000`` -> ``50000 < salary``).

        Duplicates are removed, preserving order.
        """
        variants = []

        for original, alternatives in self.equivalent_operators.items():
            pattern = self._operator_pattern(original)
            for alt in alternatives:
                if alt == original:
                    continue
                variant = self._sub_outside_literals(pattern, alt, sql)
                if variant != sql:
                    variants.append(variant)

        mirrored = self._mirror_comparisons(sql)
        if mirrored != sql:
            variants.append(mirrored)

        return list(dict.fromkeys(variants))

    def column_alias_addition(self, sql: str) -> str:
        """Give the first SELECT item the alias ``result``.

        Returns *sql* unchanged in these cases:

        - it contains ``SELECT *``;
        - it lacks ``SELECT`` or ``FROM``;
        - the first item already ends in ``AS <alias>``.

        Only the text before the first ``FROM`` is rewritten. Everything after
        it, including nested queries, is kept verbatim. Keywords are matched in
        upper case only.
        """
        if "SELECT *" in sql or "SELECT" not in sql or "FROM" not in sql:
            return sql

        select_part, _, after_from = sql.partition("FROM")
        prefix, found, column_text = select_part.partition("SELECT ")
        if not found:
            return sql

        columns = [column.strip() for column in column_text.split(",")]
        first = columns[0]
        if not first or re.search(r'\bAS\s+[\w"`]+$', first, re.IGNORECASE):
            return sql  # nothing to alias, or already aliased

        columns[0] = first + " AS result"
        return f"{prefix}SELECT {', '.join(columns)} FROM{after_from}"

    def subquery_transformation(self, sql: str) -> str:
        """Wrap a query containing WHERE as ``SELECT * FROM (<sql>) AS subquery``.

        Surrounding whitespace and trailing semicolons are removed from the
        inner query, since ``;`` is not allowed inside the parentheses.

        NOTE(review): the wrapped query may fail on engines that reject
        duplicate column names in a derived table (e.g. ``e.name, d.name``).
        """
        if "WHERE" in sql and "SELECT" in sql:
            inner = sql.strip().rstrip(';').rstrip()
            return f"SELECT * FROM ({inner}) AS subquery"
        return sql

# Usage example
sql_augmenter = SQLAugmenter()
original_sql = "SELECT name, salary FROM employees WHERE salary > 50000"

# Operator-substitution variants
operator_variants = sql_augmenter.operator_substitution(original_sql)
print("操作符替换:")
for variant in operator_variants:
    print(f"  - {variant}")

# Alias the first SELECT column
alias_sql = sql_augmenter.column_alias_addition(original_sql)
print(f"\n添加别名: {alias_sql}")

# Wrap the query as a derived table
subquery_sql = sql_augmenter.subquery_transformation(original_sql)
print(f"子查询转换: {subquery_sql}")

6.5 批处理与优化

6.5.1 批处理数据加载器

import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict

class Text2SQLDataset(Dataset):
    """PyTorch ``Dataset`` that turns raw Text2SQL records into fixed-length tensors.

    Each record is a dict with ``'question'``, ``'schema'`` and an optional
    ``'sql'``. The question and a text rendering of the schema form the model
    input; the SQL, when present, becomes the label sequence.
    """

    def __init__(self, data: List[Dict], tokenizer, max_length: int = 512):
        """
        Args:
            data: Raw records (see the class docstring).
            tokenizer: Callable with the Hugging Face tokenizer call signature
                (``truncation``, ``padding``, ``max_length``, ``return_tensors``)
                that returns ``input_ids`` and ``attention_mask``.
            max_length: Length that inputs and labels are padded/truncated to.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessor = Text2SQLPreprocessor()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return 1-D ``input_ids``, ``attention_mask`` and ``labels`` for record *idx*.

        Label positions that are padding are set to -100 so that losses using
        ``ignore_index=-100`` skip them. A record without SQL gets an all -100
        label vector.
        """
        item = self.data[idx]

        # NOTE(review): confirm preprocess_question returns text; a dict would
        # be embedded into the prompt via its repr.
        question = self.preprocessor.preprocess_question(item['question'])
        schema_text = self.preprocessor.schema_to_text(item['schema'])

        input_text = f"Question: {question} Schema: {schema_text}"

        encoding = self.tokenizer(
            input_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        sql = item.get('sql', '')
        if sql:
            sql_encoding = self.tokenizer(
                sql,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            # squeeze(0) drops only the batch dim, so max_length == 1 stays 1-D.
            labels = sql_encoding['input_ids'].squeeze(0).clone()
            sql_mask = sql_encoding.get('attention_mask')
            if sql_mask is not None:
                # Padding must not contribute to the loss.
                labels[sql_mask.squeeze(0) == 0] = -100
        else:
            labels = torch.full((self.max_length,), -100, dtype=torch.long)  # all ignored

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': labels
        }

class DataCollator:
    """Collate dataset items into a batch, padding each field to the batch maximum.

    Padding values are:

    - ``tokenizer.pad_token_id`` for ``input_ids``;
    - 0 for ``attention_mask``;
    - -100 for ``labels``, which losses using ``ignore_index=-100`` skip.

    Inputs and labels are padded independently, because a label sequence need
    not be as long as its input.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @staticmethod
    def _pad(tensor, length, value):
        """Right-pad 1-D *tensor* with *value* up to *length*, keeping its dtype and device."""
        missing = length - len(tensor)
        if missing <= 0:
            return tensor
        filler = torch.full((missing,), value, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, filler])

    def __call__(self, batch):
        """Stack a list of item dicts into batched ``(batch, length)`` tensors."""
        # Dynamic padding: inputs and labels each use their own batch maximum
        max_input_length = max(len(item['input_ids']) for item in batch)
        max_label_length = max(len(item['labels']) for item in batch)
        pad_token_id = self.tokenizer.pad_token_id

        input_ids = []
        attention_masks = []
        labels = []

        for item in batch:
            input_ids.append(self._pad(item['input_ids'], max_input_length, pad_token_id))
            attention_masks.append(self._pad(item['attention_mask'], max_input_length, 0))
            labels.append(self._pad(item['labels'], max_label_length, -100))

        return {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_masks),
            'labels': torch.stack(labels)
        }

# Usage example
from transformers import AutoTokenizer

# Initialize the tokenizer; fall back to the EOS token for tokenizers that define no pad token
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Sample data: one record with a question, a dict-style schema and its SQL
sample_data = [
    {
        'question': 'Show all employees',
        'schema': {
            'tables': [{
                'name': 'employees',
                'columns': [{'name': 'id'}, {'name': 'name'}]
            }]
        },
        'sql': 'SELECT * FROM employees'
    }
]

# Build the dataset, collator and loader
# (sample_data holds one record, so the only batch has size 1 despite batch_size=2)
dataset = Text2SQLDataset(sample_data, tokenizer)
collator = DataCollator(tokenizer)
dataloader = DataLoader(
    dataset, 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collator
)

# Inspect the first batch only
for batch in dataloader:
    print(f"输入形状: {batch['input_ids'].shape}")
    print(f"注意力掩码形状: {batch['attention_mask'].shape}")
    print(f"标签形状: {batch['labels'].shape}")
    break

6.5.2 内存优化

import gc
import psutil
from typing import Iterator, List

class MemoryOptimizedProcessor:
    """Helpers for chunked processing, streaming JSONL loading and result caching."""

    def __init__(self, chunk_size: int = 1000):
        # Number of items passed to processor_func per call in process_in_chunks.
        self.chunk_size = chunk_size

    def get_memory_usage(self) -> float:
        """Return this process's memory use as a percentage (psutil ``Process.memory_percent``)."""
        process = psutil.Process()
        return process.memory_percent()

    def process_in_chunks(self, data: List, processor_func) -> Iterator:
        """Yield ``processor_func(chunk)`` for consecutive slices of ``chunk_size`` items.

        After each yielded result, the chunk reference is dropped and
        ``gc.collect()`` runs. A warning is printed when memory usage exceeds
        80%.
        """
        for start in range(0, len(data), self.chunk_size):
            chunk = data[start:start + self.chunk_size]

            yield processor_func(chunk)

            # Release memory
            del chunk
            gc.collect()

            # Monitor memory usage
            memory_usage = self.get_memory_usage()
            if memory_usage > 80:  # above 80% usage
                print(f"警告: 内存使用率达到 {memory_usage:.1f}%")

    def lazy_load_dataset(self, file_path: str) -> Iterator[Dict]:
        """Stream records from a UTF-8 JSON Lines file, one parsed object per line.

        Blank lines are skipped. Lines that are not valid JSON are skipped
        deliberately (best-effort loading).
        """
        import json

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Yield outside the try so consumer-side errors are never swallowed.
                yield record

    def cache_preprocessed_data(self, data: List[Dict], cache_file: str):
        """Pickle *data* to *cache_file* atomically.

        The pickle is written to a temporary file in the same directory and
        then moved into place with ``os.replace``. An interrupted write
        therefore never leaves a truncated cache behind.
        """
        import contextlib
        import os
        import pickle
        import tempfile

        directory = os.path.dirname(os.path.abspath(cache_file))
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.tmp-', suffix='.cache')
        try:
            with os.fdopen(fd, 'wb') as f:
                pickle.dump(data, f)
            os.replace(tmp_path, cache_file)
        except BaseException:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise

    def load_cached_data(self, cache_file: str) -> List[Dict]:
        """Return the unpickled contents of *cache_file*, or ``None`` if it does not exist.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load cache
        files this process (or another trusted one) wrote.
        """
        import pickle
        import os

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        return None

# Usage example
processor = MemoryOptimizedProcessor(chunk_size=500)

# Simulate processing a large dataset
def preprocess_chunk(chunk):
    """Preprocess one chunk of raw samples with a fresh Text2SQLPreprocessor.

    Each sample needs 'question' and 'schema' keys; 'sql' is optional and
    defaults to ''. Returns one dict per sample with the same three keys, in
    input order.
    """
    preprocessor = Text2SQLPreprocessor()
    return [
        {
            'question': preprocessor.preprocess_question(sample['question']),
            'schema': preprocessor.preprocess_schema(sample['schema']),
            'sql': preprocessor.preprocess_sql(sample.get('sql', '')),
        }
        for sample in chunk
    ]

# Chunked processing example. [sample_data[0]] * 10000 repeats the same dict
# object, simulating a large dataset without building 10000 distinct records.
large_dataset = [sample_data[0]] * 10000  # simulated large dataset

print(f"开始处理,初始内存使用: {processor.get_memory_usage():.1f}%")

# All results are collected into one list, so memory still grows with the
# dataset size; the per-chunk prints show how usage evolves.
processed_chunks = []
for chunk_result in processor.process_in_chunks(large_dataset, preprocess_chunk):
    processed_chunks.extend(chunk_result)
    print(f"处理了 {len(processed_chunks)} 条数据,内存使用: {processor.get_memory_usage():.1f}%")

print(f"处理完成,最终内存使用: {processor.get_memory_usage():.1f}%")

6.6 实践项目:构建完整的预处理管道

class Text2SQLPipeline:
    """End-to-end Text2SQL preprocessing pipeline.

    Streams a JSON-Lines dataset, runs each record through preprocessing,
    feature extraction and (optionally) question augmentation, writes the
    result as JSON-Lines and caches it with pickle.

    Recognised ``config`` keys (all optional):
        chunk_size        -- records per processing chunk (default 1000).
        enable_caching    -- read/write the pickle cache (default True).
        max_augmentations -- augmented samples kept per record (default 2).
    """

    # Word-bounded, case-insensitive patterns for _calculate_sql_complexity.
    # Plain substring tests would misfire on identifiers such as ``admin``
    # (MIN), ``discount`` (COUNT) or ``selected_at`` (a second SELECT).
    _SELECT_RE = re.compile(r'\bselect\b', re.IGNORECASE)
    _WHERE_RE = re.compile(r'\bwhere\b', re.IGNORECASE)
    _JOIN_RE = re.compile(r'\bjoin\b', re.IGNORECASE)
    _AGG_RE = re.compile(r'\b(?:count|sum|avg|max|min)\s*\(', re.IGNORECASE)
    _GROUP_BY_RE = re.compile(r'\bgroup\s+by\b', re.IGNORECASE)

    def __init__(self, config: Dict):
        self.config = config
        self.preprocessor = Text2SQLPreprocessor()
        self.feature_extractor = MultiModalFeatureExtractor()
        self.question_augmenter = QuestionAugmenter()
        self.sql_augmenter = SQLAugmenter()
        self.memory_optimizer = MemoryOptimizedProcessor(
            chunk_size=config.get('chunk_size', 1000)
        )

    def process_dataset(self, input_file: str, output_file: str,
                        enable_augmentation: bool = True):
        """Process a JSON-Lines dataset end to end and return the records.

        Args:
            input_file: JSON-Lines file, read lazily via the memory optimizer.
            output_file: destination JSON-Lines file; the pickle cache is
                stored next to it.
            enable_augmentation: also emit augmented copies of every record.

        Returns:
            The list of processed (and possibly augmented) records, either
            freshly computed or loaded from the cache.
        """
        print("开始处理数据集...")

        use_cache = self.config.get('enable_caching', True)
        # Augmented and non-augmented runs get separate cache files, so
        # toggling enable_augmentation never serves the other mode's data.
        # NOTE(review): the cache is still not keyed on input_file or its
        # contents -- delete the cache after changing the input dataset.
        cache_file = output_file + ('.cache' if enable_augmentation else '.noaug.cache')

        if use_cache:
            cached_data = self.memory_optimizer.load_cached_data(cache_file)
            # `is not None` so a legitimately empty cached result is still a hit.
            if cached_data is not None:
                print("发现缓存数据,直接加载...")
                return cached_data

        chunk_size = self.config.get('chunk_size', 1000)
        processed_data = []
        current_chunk = []

        # Lazy loading + chunked processing keeps only one raw chunk in memory.
        for item in self.memory_optimizer.lazy_load_dataset(input_file):
            current_chunk.append(item)
            if len(current_chunk) >= chunk_size:
                processed_data.extend(
                    self._process_chunk(current_chunk, enable_augmentation)
                )
                current_chunk = []
                print(f"已处理 {len(processed_data)} 条数据")

        # Trailing partial chunk.
        if current_chunk:
            processed_data.extend(
                self._process_chunk(current_chunk, enable_augmentation)
            )

        self._save_processed_data(processed_data, output_file)

        if use_cache:
            self.memory_optimizer.cache_preprocessed_data(processed_data, cache_file)

        print(f"数据集处理完成,共 {len(processed_data)} 条数据")
        return processed_data

    def _process_chunk(self, chunk: List[Dict], enable_augmentation: bool) -> List[Dict]:
        """Process a chunk of raw records.

        Each processed record is followed by its augmented copies when
        *enable_augmentation* is true.
        """
        processed_chunk = []

        for item in chunk:
            processed_chunk.append(self._process_single_item(item))
            if enable_augmentation:
                processed_chunk.extend(self._augment_item(item))

        return processed_chunk

    def _process_single_item(self, item: Dict) -> Dict:
        """Preprocess one raw record and attach features and metadata.

        Requires ``item['question']`` and ``item['schema']``; ``sql`` is
        optional. The ``augmented`` flag of the input record (default False)
        is carried through, so ``get_statistics`` can count augmented samples.
        """
        question = self.preprocessor.preprocess_question(item['question'])
        schema = self.preprocessor.preprocess_schema(item['schema'])
        sql = self.preprocessor.preprocess_sql(item.get('sql', ''))

        # NOTE(review): assumes create_feature_vector returns an array-like
        # with .tolist() (e.g. a numpy array) -- confirm with the extractor.
        features = self.feature_extractor.create_feature_vector(
            question, schema, sql if sql else None
        )

        return {
            'original_question': item['question'],
            'processed_question': question,
            'schema': schema,
            'sql': sql,
            'features': features.tolist(),
            'metadata': {
                'question_length': len(question.split()),
                'sql_complexity': self._calculate_sql_complexity(sql)
            },
            'augmented': item.get('augmented', False),
        }

    def _augment_item(self, item: Dict) -> List[Dict]:
        """Return up to ``config['max_augmentations']`` (default 2) augmented records.

        Candidates are the synonym-replacement variant followed by the
        paraphrases, in that order. Each kept candidate replaces the question,
        is flagged ``augmented=True`` and is run through
        ``_process_single_item``. The SQL is left unchanged.
        """
        limit = self.config.get('max_augmentations', 2)
        question = item['question']

        synonym_variant = self.question_augmenter.synonym_replacement(
            question, num_replacements=1
        )
        paraphrases = self.question_augmenter.paraphrase_generation(question)

        # SQL-side augmentation (self.sql_augmenter) is not applied here: the
        # variants were previously computed but never used, and presumably a
        # substituted operator would no longer match the unchanged question.

        # Truncate *before* processing so discarded variants never pay for
        # preprocessing and feature extraction.
        candidates = ([synonym_variant] + list(paraphrases))[:limit]

        augmented_items = []
        for variant_question in candidates:
            augmented_item = item.copy()
            augmented_item['question'] = variant_question
            augmented_item['augmented'] = True
            augmented_items.append(self._process_single_item(augmented_item))

        return augmented_items

    def _calculate_sql_complexity(self, sql: str) -> int:
        """Heuristic SQL complexity score (0 for empty SQL).

        Scoring: SELECT +1, WHERE +1, JOIN +2, any aggregate call
        (COUNT/SUM/AVG/MAX/MIN followed by ``(``) +2, GROUP BY +2, and more
        than one SELECT (a subquery) +3. Keywords are matched case-insensitively
        as whole words.
        """
        if not sql:
            return 0

        complexity = 0
        select_count = len(self._SELECT_RE.findall(sql))

        if select_count:                      # basic query
            complexity += 1
        if self._WHERE_RE.search(sql):        # filter
            complexity += 1
        if self._JOIN_RE.search(sql):         # join
            complexity += 2
        if self._AGG_RE.search(sql):          # aggregation
            complexity += 2
        if self._GROUP_BY_RE.search(sql):     # grouping
            complexity += 2
        if select_count > 1:                  # subquery
            complexity += 3

        return complexity

    def _save_processed_data(self, data: List[Dict], output_file: str):
        """Write *data* to *output_file* as UTF-8 JSON-Lines (one record per line)."""
        import json

        with open(output_file, 'w', encoding='utf-8') as f:
            for item in data:
                json.dump(item, f, ensure_ascii=False)
                f.write('\n')

    def get_statistics(self, data: List[Dict]) -> Dict:
        """Summarise processed records.

        Returns a dict with ``total_samples``, ``avg_question_length``,
        ``avg_sql_complexity`` (both 0 for empty *data*),
        ``complexity_distribution`` (complexity -> count, in first-seen order)
        and ``augmented_samples``.
        """
        question_lengths = [item['metadata']['question_length'] for item in data]
        sql_complexities = [item['metadata']['sql_complexity'] for item in data]
        total = len(data)

        return {
            'total_samples': total,
            # Guard against ZeroDivisionError on an empty dataset.
            'avg_question_length': sum(question_lengths) / total if total else 0,
            'avg_sql_complexity': sum(sql_complexities) / total if total else 0,
            'complexity_distribution': dict(Counter(sql_complexities)),
            'augmented_samples': sum(
                1 for item in data if item.get('augmented', False)
            ),
        }

# Usage example: pipeline configuration.
config = {
    'chunk_size': 500,
    'enable_caching': True,
    'augmentation_ratio': 0.3,
}

pipeline = Text2SQLPipeline(config)

# Write a small synthetic JSON-Lines dataset of 100 single-table samples.
sample_data_file = 'sample_text2sql.jsonl'
with open(sample_data_file, 'w', encoding='utf-8') as f:
    for i in range(100):
        sample = {
            'question': f'Show employees in department {i}',
            'schema': {
                'tables': [{
                    'name': 'employees',
                    'columns': [{'name': col} for col in ('id', 'name', 'dept_id')],
                }],
            },
            'sql': f'SELECT * FROM employees WHERE dept_id = {i}',
        }
        f.write(json.dumps(sample, ensure_ascii=False) + '\n')

# Run the whole pipeline, with augmentation enabled.
processed_data = pipeline.process_dataset(
    sample_data_file,
    'processed_text2sql.jsonl',
    enable_augmentation=True,
)

# Report summary statistics for the processed dataset.
stats = pipeline.get_statistics(processed_data)
print("\n数据集统计信息:")
for key, value in stats.items():
    print(f"{key}: {value}")

6.7 总结

本章详细介绍了Text2SQL任务中的数据预处理与特征工程技术,主要内容包括:

核心技术点

  1. 多模态预处理

    • 自然语言查询的标准化处理
    • Schema结构的编码与表示
    • SQL语句的解析与标准化
  2. 特征工程

    • 问题特征提取(长度、实体、关键词等)
    • Schema特征提取(表数量、列数量、关系复杂度等)
    • SQL特征提取(查询类型、复杂度等)
    • 跨模态对齐特征(实体重叠、语义匹配等)
  3. 数据增强

    • 问题改写(同义词替换、释义生成)
    • SQL变换(操作符替换、结构转换)
    • 模板生成(基于规则的数据扩充)
  4. 性能优化

    • 批处理数据加载
    • 内存优化策略
    • 缓存机制
    • 懒加载技术

实践要点

  1. 数据质量:确保预处理后的数据保持语义一致性
  2. 特征选择:根据具体任务选择合适的特征组合
  3. 增强策略:平衡数据增强的数量和质量
  4. 性能监控:实时监控内存使用和处理速度

下一步学习

  • 第7章将介绍模型训练与优化技术
  • 学习如何使用预处理后的数据训练高效的Text2SQL模型
  • 掌握模型调优和性能优化方法

通过本章的学习,你应该能够构建一个完整的Text2SQL数据预处理管道,为后续的模型训练打下坚实基础。