1. Overview of Text2SQL Models

Text2SQL models are the core component for translating natural-language queries into SQL statements. With advances in deep learning, Text2SQL models have evolved from rule-based methods to end-to-end neural networks. This chapter surveys the major Text2SQL model architectures and their characteristics.

1.1 Evolution of the Field

Generation 1: Rule-based methods (2000s-2010s)

  • Approach: hand-written grammar rules and templates
  • Strengths: highly interpretable; effective in narrow domains
  • Weaknesses: poor scalability; struggles with complex queries

Generation 2: Statistical machine learning (2010s)

  • Approach: classical machine-learning algorithms
  • Representative: CRF-based sequence-labeling models
  • Strengths: can learn from data
  • Weaknesses: heavy feature engineering; limited performance

Generation 3: Deep learning (2015-2020)

  • Approach: neural networks such as RNNs and LSTMs
  • Representative: Seq2Seq models, attention mechanisms
  • Strengths: end-to-end learning; large performance gains
  • Weaknesses: requires large amounts of training data

Generation 4: The pretrained-model era (2020-present)

  • Approach: Transformer-based pretrained models
  • Representative: BERT, T5, the GPT family
  • Strengths: strong language-understanding ability
  • Weaknesses: high compute requirements

1.2 Model Taxonomy

By generation strategy

  • Generative models: produce the SQL statement directly
  • Classification models: decompose SQL generation into several classification tasks
  • Hybrid models: combine the strengths of generation and classification

By architecture

  • Sequence-to-sequence models: map a natural-language sequence to a SQL sequence
  • Graph neural network models: exploit the graph structure of the database schema
  • Grammar-guided models: generate SQL following its grammatical structure

2. Sequence-to-Sequence Models

2.1 Basic Seq2Seq Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicSeq2SeqModel(nn.Module):
    """Basic sequence-to-sequence model."""
    
    def __init__(self, input_vocab_size, output_vocab_size, 
                 embedding_dim, hidden_dim, num_layers=2):
        super(BasicSeq2SeqModel, self).__init__()
        
        # Encoder
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim, 
                                   num_layers, batch_first=True)
        
        # Decoder
        self.decoder_embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.decoder_lstm = nn.LSTM(embedding_dim, hidden_dim, 
                                   num_layers, batch_first=True)
        
        # Output layer
        self.output_projection = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(0.1)
        
    def encode(self, input_seq):
        """Encode the input sequence."""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)
    
    def decode(self, target_seq, encoder_hidden):
        """Decode the target sequence."""
        embedded = self.decoder_embedding(target_seq)
        output, _ = self.decoder_lstm(embedded, encoder_hidden)
        output = self.dropout(output)
        logits = self.output_projection(output)
        return logits
    
    def forward(self, input_seq, target_seq=None):
        """Forward pass."""
        # Encode
        encoder_output, encoder_hidden = self.encode(input_seq)
        
        if target_seq is not None:
            # Training mode: teacher forcing
            logits = self.decode(target_seq, encoder_hidden)
            return logits
        else:
            # Inference mode: autoregressive generation
            return self.generate(encoder_hidden)
    
    def generate(self, encoder_hidden, max_length=100):
        """Generate a SQL token sequence."""
        batch_size = encoder_hidden[0].size(1)
        device = encoder_hidden[0].device
        
        # Initialize the decoder input (token id 0 doubles as BOS here)
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden
        
        outputs = []
        
        for _ in range(max_length):
            # One decoding step
            embedded = self.decoder_embedding(decoder_input)
            output, decoder_hidden = self.decoder_lstm(embedded, decoder_hidden)
            logits = self.output_projection(output)
            
            # Pick the next token greedily
            next_token = torch.argmax(logits, dim=-1)
            outputs.append(next_token)
            
            # Feed the prediction back in
            decoder_input = next_token
            
            # Stop once every sequence has emitted EOS
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        
        return torch.cat(outputs, dim=1)

# Model instantiation example
input_vocab_size = 10000
output_vocab_size = 5000
embedding_dim = 256
hidden_dim = 512

model = BasicSeq2SeqModel(input_vocab_size, output_vocab_size, 
                         embedding_dim, hidden_dim)
print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")

2.2 Attention Mechanism

class AttentionSeq2SeqModel(nn.Module):
    """Sequence-to-sequence model with an attention mechanism."""
    
    def __init__(self, input_vocab_size, output_vocab_size, 
                 embedding_dim, hidden_dim, num_layers=2):
        super(AttentionSeq2SeqModel, self).__init__()
        
        # Encoder (bidirectional, so its outputs have size hidden_dim * 2)
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim, 
                                   num_layers, batch_first=True, 
                                   bidirectional=True)
        
        # Decoder (consumes the token embedding concatenated with the context)
        self.decoder_embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.decoder_lstm = nn.LSTM(embedding_dim + hidden_dim * 2, hidden_dim, 
                                   num_layers, batch_first=True)
        
        # Attention
        self.attention = nn.Linear(hidden_dim * 3, hidden_dim)
        self.attention_combine = nn.Linear(hidden_dim * 2, hidden_dim)  # reserved; unused below
        
        # Output layer
        self.output_projection = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(0.1)
        
    def attention_mechanism(self, decoder_hidden, encoder_outputs):
        """Compute attention weights and the context vector."""
        batch_size, seq_len, hidden_size = encoder_outputs.size()
        
        # Broadcast the decoder hidden state across the source positions
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1).expand(
            batch_size, seq_len, -1)
        
        # Attention scores
        combined = torch.cat([decoder_hidden_expanded, encoder_outputs], dim=2)
        attention_scores = self.attention(combined)
        attention_weights = F.softmax(attention_scores.sum(dim=2), dim=1)
        
        # Context vector: attention-weighted sum of the encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        
        return context, attention_weights
    
    def encode(self, input_seq):
        """Encode the input sequence."""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)
    
    def decode_step(self, input_token, decoder_hidden, encoder_outputs):
        """Run one decoding step."""
        embedded = self.decoder_embedding(input_token)
        
        # Attend over the encoder outputs using the top decoder layer's state
        context, attention_weights = self.attention_mechanism(
            decoder_hidden[0][-1], encoder_outputs)
        
        # Concatenate the token embedding with the context vector
        combined_input = torch.cat([embedded.squeeze(1), context], dim=1)
        combined_input = combined_input.unsqueeze(1)
        
        # LSTM decoding
        output, decoder_hidden = self.decoder_lstm(combined_input, decoder_hidden)
        
        # Output projection
        output = self.dropout(output)
        logits = self.output_projection(output)
        
        return logits, decoder_hidden, attention_weights

# Usage example
attention_model = AttentionSeq2SeqModel(input_vocab_size, output_vocab_size, 
                                       embedding_dim, hidden_dim)
print(f"Number of attention-model parameters: {sum(p.numel() for p in attention_model.parameters())}")

3. Transformer Architecture

3.1 Basic Transformer Model

import math

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (batch-first layout)."""
    
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), matching batch-first inputs
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]

class TransformerText2SQL(nn.Module):
    """Transformer-based Text2SQL model."""
    
    def __init__(self, input_vocab_size, output_vocab_size, d_model=512, 
                 nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        super(TransformerText2SQL, self).__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.input_embedding = nn.Embedding(input_vocab_size, d_model)
        self.output_embedding = nn.Embedding(output_vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            batch_first=True
        )
        
        # Output layer
        self.output_projection = nn.Linear(d_model, output_vocab_size)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, src, tgt=None, src_padding_mask=None, tgt_padding_mask=None):
        """Forward pass."""
        # Embed the source sequence
        src_emb = self.input_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoder(src_emb)
        
        if tgt is not None:
            # Training mode
            tgt_emb = self.output_embedding(tgt) * math.sqrt(self.d_model)
            tgt_emb = self.pos_encoder(tgt_emb)
            
            # Causal mask so each position only attends to earlier positions
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                tgt.size(1)).to(tgt.device)
            
            output = self.transformer(src_emb, tgt_emb, 
                                    tgt_mask=causal_mask,
                                    src_key_padding_mask=src_padding_mask,
                                    tgt_key_padding_mask=tgt_padding_mask)
            
            output = self.dropout(output)
            logits = self.output_projection(output)
            return logits
        else:
            # Inference mode
            return self.generate(src_emb, src_padding_mask)
    
    def generate(self, src_emb, src_padding_mask=None, max_length=100):
        """Generate a SQL token sequence."""
        batch_size = src_emb.size(0)
        device = src_emb.device
        
        # Initialize the target sequence with a BOS token (id 0 here)
        tgt = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        
        for _ in range(max_length):
            tgt_emb = self.output_embedding(tgt) * math.sqrt(self.d_model)
            tgt_emb = self.pos_encoder(tgt_emb)
            
            # Keep decoding causal so inference matches training
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                tgt.size(1)).to(device)
            
            output = self.transformer(src_emb, tgt_emb,
                                    tgt_mask=causal_mask,
                                    src_key_padding_mask=src_padding_mask)
            
            logits = self.output_projection(output[:, -1:, :])
            next_token = torch.argmax(logits, dim=-1)
            
            tgt = torch.cat([tgt, next_token], dim=1)
            
            # Stop once every sequence has emitted EOS
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        
        return tgt[:, 1:]  # drop the BOS token

# Model instantiation
transformer_model = TransformerText2SQL(
    input_vocab_size=10000,
    output_vocab_size=5000,
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6
)

print(f"Transformer模型参数数量: {sum(p.numel() for p in transformer_model.parameters())}")

3.2 Fine-Tuning Pretrained Models

from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

class T5Text2SQL(nn.Module):
    """T5-based Text2SQL model."""
    
    def __init__(self, model_name='t5-base'):
        super(T5Text2SQL, self).__init__()
        
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        
        # Add schema-marker special tokens
        special_tokens = ['<table>', '</table>', '<column>', '</column>']
        self.tokenizer.add_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))
    
    def forward(self, input_text, target_text=None):
        """Forward pass."""
        # Tokenize the input
        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512
        )
        
        if target_text is not None:
            # Training mode
            targets = self.tokenizer(
                target_text,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=256
            )
            
            outputs = self.model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                labels=targets['input_ids']
            )
            
            return outputs.loss, outputs.logits
        else:
            # Inference mode
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=256,
                num_beams=4,
                early_stopping=True
            )
            
            return outputs
    
    def predict(self, input_text):
        """Predict SQL statements."""
        outputs = self.forward(input_text)
        sql_queries = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return sql_queries

# Usage example
t5_model = T5Text2SQL('t5-base')

# Example input
input_text = [
    "translate to SQL: Show me all employees with salary greater than 50000 | tables: employees (id, name, salary, department_id)"
]

# Predict
sql_queries = t5_model.predict(input_text)
print(f"Generated SQL: {sql_queries[0]}")

4. Grammar-Guided Models

4.1 Abstract Syntax Tree (AST) Generation

class SQLASTNode:
    """Node in a SQL abstract syntax tree."""
    
    def __init__(self, node_type, value=None, children=None):
        self.node_type = node_type
        self.value = value
        self.children = children or []
    
    def add_child(self, child):
        self.children.append(child)
    
    def to_sql(self):
        """Render the AST as a SQL string."""
        if self.node_type == 'SELECT':
            return self._select_to_sql()
        elif self.node_type == 'FROM':
            return f"FROM {self.value}"
        elif self.node_type == 'WHERE':
            return f"WHERE {self._condition_to_sql()}"
        elif self.node_type == 'COLUMN':
            return self.value
        elif self.node_type == 'CONDITION':
            return self._condition_to_sql()
        else:
            return str(self.value)
    
    def _select_to_sql(self):
        """Render a SELECT statement."""
        parts = ['SELECT']
        
        # Collect columns, then splice in the FROM and WHERE clauses
        columns = []
        for child in self.children:
            if child.node_type == 'COLUMN':
                columns.append(child.to_sql())
            elif child.node_type == 'FROM':
                parts.append(', '.join(columns) if columns else '*')
                parts.append(child.to_sql())
            elif child.node_type == 'WHERE':
                parts.append(child.to_sql())
        
        return ' '.join(parts)
    
    def _condition_to_sql(self):
        """Render a condition."""
        if len(self.children) == 3:
            left, op, right = self.children
            return f"{left.to_sql()} {op.value} {right.to_sql()}"
        return self.value
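
A quick usage sketch: build the tree for a simple query by hand and render it. Note that to_sql relies on child ordering (COLUMN children must precede the FROM child) and on a WHERE node carrying its three condition children directly:

# Build the AST for: SELECT name FROM employees WHERE salary > 50000
root = SQLASTNode('SELECT')
root.add_child(SQLASTNode('COLUMN', 'name'))
root.add_child(SQLASTNode('FROM', 'employees'))

where = SQLASTNode('WHERE')
where.add_child(SQLASTNode('COLUMN', 'salary'))
where.add_child(SQLASTNode('OPERATOR', '>'))
where.add_child(SQLASTNode('VALUE', 50000))
root.add_child(where)

print(root.to_sql())  # SELECT name FROM employees WHERE salary > 50000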

class GrammarGuidedDecoder(nn.Module):
    """Grammar-guided decoder."""
    
    def __init__(self, hidden_dim, vocab_size):
        super(GrammarGuidedDecoder, self).__init__()
        
        self.hidden_dim = hidden_dim
        
        # One output head per grammar rule, covering every rule the grammar
        # below can emit (including TABLE, CONDITION, OPERATOR, and VALUE)
        self.rules = ['SELECT', 'FROM', 'WHERE', 'COLUMN',
                      'TABLE', 'CONDITION', 'OPERATOR', 'VALUE']
        self.rule_decoders = nn.ModuleDict({
            rule: nn.Linear(hidden_dim, vocab_size) for rule in self.rules
        })
        
        # Classifier over grammar rules
        self.rule_classifier = nn.Linear(hidden_dim, len(self.rules))
        
    def forward(self, hidden_state, current_rule=None):
        """Decode according to the current grammar rule."""
        if current_rule is None:
            # Predict the next grammar rule
            rule_logits = self.rule_classifier(hidden_state)
            return rule_logits
        
        # Pick the decoder head for the given rule
        if current_rule not in self.rule_decoders:
            raise ValueError(f"Unknown rule: {current_rule}")
        return self.rule_decoders[current_rule](hidden_state)

class GrammarGuidedText2SQL(nn.Module):
    """Grammar-guided Text2SQL model."""
    
    def __init__(self, input_vocab_size, output_vocab_size, 
                 embedding_dim, hidden_dim):
        super(GrammarGuidedText2SQL, self).__init__()
        
        # Encoder
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim, 
                                   batch_first=True, bidirectional=True)
        
        # Grammar-guided decoder
        self.decoder = GrammarGuidedDecoder(hidden_dim * 2, output_vocab_size)
        
        # SQL grammar rules ('?' marks an optional expansion)
        self.grammar_rules = {
            'ROOT': ['SELECT'],
            'SELECT': ['COLUMN', 'FROM'],
            'FROM': ['TABLE', 'WHERE?'],
            'WHERE': ['CONDITION'],
            'CONDITION': ['COLUMN', 'OPERATOR', 'VALUE']
        }
    
    def encode(self, input_seq):
        """Encode the input sequence."""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)
    
    def decode_with_grammar(self, encoder_output, max_depth=10):
        """Decode following the grammar rules (assumes batch size 1)."""
        # Initialize the AST root
        root = SQLASTNode('ROOT')
        
        # A stack drives the top-down expansion
        decode_stack = [(root, 'SELECT', 0)]
        
        while decode_stack and max_depth > 0:
            current_node, rule, depth = decode_stack.pop()
            
            if depth >= max_depth:
                break
            
            # Mean-pool the encoder outputs as the context vector
            context = encoder_output.mean(dim=1)
            
            # Decode under the current rule
            logits = self.decoder(context, rule)
            predicted_token = torch.argmax(logits, dim=-1)
            
            # Create the new AST node
            new_node = SQLASTNode(rule, predicted_token.item())
            current_node.add_child(new_node)
            
            # Push the rule's children onto the stack
            if rule in self.grammar_rules:
                for child_rule in reversed(self.grammar_rules[rule]):
                    if not child_rule.endswith('?'):  # skip optional rules
                        decode_stack.append((new_node, child_rule, depth + 1))
            
            max_depth -= 1
        
        return root
    
    def forward(self, input_seq):
        """Forward pass."""
        encoder_output, _ = self.encode(input_seq)
        ast = self.decode_with_grammar(encoder_output)
        return ast

# Usage example
grammar_model = GrammarGuidedText2SQL(
    input_vocab_size=10000,
    output_vocab_size=5000,
    embedding_dim=256,
    hidden_dim=512
)

print(f"语法制导模型参数数量: {sum(p.numel() for p in grammar_model.parameters())}")

5. Graph Neural Network Models

5.1 Building the Schema Graph

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data

class SchemaGraphBuilder:
    """Builds a graph over the database schema."""
    
    def __init__(self):
        self.node_types = ['table', 'column', 'value']
        self.edge_types = ['table_to_column', 'column_to_table', 
                          'foreign_key', 'column_to_value']
    
    def build_graph(self, schema, query_entities):
        """Build the schema graph."""
        nodes = []
        edges = []
        node_features = []
        
        node_id = 0
        node_mapping = {}
        
        # Table nodes
        for table_name in schema['tables']:
            nodes.append({
                'id': node_id,
                'type': 'table',
                'name': table_name
            })
            node_mapping[f"table_{table_name}"] = node_id
            node_features.append([1, 0, 0])  # one-hot for table
            node_id += 1
        
        # Column nodes
        for table_name, columns in schema['columns'].items():
            table_id = node_mapping[f"table_{table_name}"]
            
            for column_name in columns:
                nodes.append({
                    'id': node_id,
                    'type': 'column',
                    'name': column_name,
                    'table': table_name
                })
                node_mapping[f"column_{table_name}_{column_name}"] = node_id
                node_features.append([0, 1, 0])  # one-hot for column
                
                # Edges between the table and its column (both directions)
                edges.append([table_id, node_id])
                edges.append([node_id, table_id])
                
                node_id += 1
        
        # Value nodes mentioned in the query
        for entity in query_entities.get('values', []):
            nodes.append({
                'id': node_id,
                'type': 'value',
                'name': entity
            })
            node_features.append([0, 0, 1])  # one-hot for value
            node_id += 1
        
        # Convert to PyTorch Geometric format
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        x = torch.tensor(node_features, dtype=torch.float)
        
        graph_data = Data(x=x, edge_index=edge_index)
        graph_data.node_mapping = node_mapping
        graph_data.nodes_info = nodes
        
        return graph_data
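
A usage sketch showing the input layout build_graph expects: a 'tables' list and a 'columns' dict keyed by table name. The example schema itself is hypothetical:

# Build a graph for a small example schema.
schema = {
    'tables': ['employees', 'departments'],
    'columns': {
        'employees': ['id', 'name', 'salary', 'department_id'],
        'departments': ['id', 'name']
    }
}
query_entities = {'values': ['50000']}

builder = SchemaGraphBuilder()
graph = builder.build_graph(schema, query_entities)
print(graph.x.shape)           # (9, 3): 2 tables + 6 columns + 1 value
print(graph.edge_index.shape)  # (2, 12): two directed edges per column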

class GraphNeuralNetwork(nn.Module):
    """Graph neural network (a stack of GCN layers)."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super(GraphNeuralNetwork, self).__init__()
        
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        
        if num_layers == 1:
            # A single layer maps straight from input to output
            self.convs.append(GCNConv(input_dim, output_dim))
        else:
            # First layer
            self.convs.append(GCNConv(input_dim, hidden_dim))
            
            # Middle layers
            for _ in range(num_layers - 2):
                self.convs.append(GCNConv(hidden_dim, hidden_dim))
            
            # Final layer
            self.convs.append(GCNConv(hidden_dim, output_dim))
        
        self.dropout = nn.Dropout(0.1)
        self.activation = nn.ReLU()
    
    def forward(self, x, edge_index):
        """Forward pass."""
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            
            if i < len(self.convs) - 1:
                x = self.activation(x)
                x = self.dropout(x)
        
        return x

class GraphText2SQL(nn.Module):
    """Graph-neural-network-based Text2SQL model."""
    
    def __init__(self, text_vocab_size, sql_vocab_size, 
                 embedding_dim, hidden_dim, graph_dim):
        super(GraphText2SQL, self).__init__()
        
        # Text encoder
        self.text_embedding = nn.Embedding(text_vocab_size, embedding_dim)
        self.text_encoder = nn.LSTM(embedding_dim, hidden_dim, 
                                   batch_first=True, bidirectional=True)
        
        # Graph neural network
        self.graph_nn = GraphNeuralNetwork(
            input_dim=3,  # size of the node-type one-hot vector
            hidden_dim=graph_dim,
            output_dim=graph_dim
        )
        
        # Fusion layer
        self.fusion = nn.Linear(hidden_dim * 2 + graph_dim, hidden_dim)
        
        # SQL decoder
        self.sql_decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.sql_embedding = nn.Embedding(sql_vocab_size, embedding_dim)
        self.output_projection = nn.Linear(hidden_dim, sql_vocab_size)
        
        self.schema_builder = SchemaGraphBuilder()
    
    def encode_text(self, text_seq):
        """Encode the text sequence."""
        embedded = self.text_embedding(text_seq)
        output, (hidden, cell) = self.text_encoder(embedded)
        return output, (hidden, cell)
    
    def encode_graph(self, schema, query_entities):
        """Encode the schema graph."""
        graph_data = self.schema_builder.build_graph(schema, query_entities)
        graph_embeddings = self.graph_nn(graph_data.x, graph_data.edge_index)
        
        # Aggregate into a single graph representation
        graph_repr = torch.mean(graph_embeddings, dim=0, keepdim=True)
        return graph_repr, graph_data
    
    def forward(self, text_seq, schema, query_entities, target_sql=None):
        """Forward pass."""
        # Encode the text
        text_output, text_hidden = self.encode_text(text_seq)
        text_repr = text_output.mean(dim=1)  # mean pooling
        
        # Encode the graph
        graph_repr, graph_data = self.encode_graph(schema, query_entities)
        graph_repr = graph_repr.expand(text_repr.size(0), -1)
        
        # Fuse the text and graph representations
        combined_repr = torch.cat([text_repr, graph_repr], dim=1)
        fused_repr = self.fusion(combined_repr)
        
        if target_sql is not None:
            # Training mode
            sql_embedded = self.sql_embedding(target_sql)
            decoder_output, _ = self.sql_decoder(sql_embedded)
            
            # Mix in the fused representation
            decoder_output = decoder_output + fused_repr.unsqueeze(1)
            logits = self.output_projection(decoder_output)
            return logits
        else:
            # Inference mode
            return self.generate_sql(fused_repr)
    
    def generate_sql(self, fused_repr, max_length=100):
        """Generate a SQL token sequence."""
        batch_size = fused_repr.size(0)
        device = fused_repr.device
        
        # Initialize the decoder input (token id 0 doubles as BOS here)
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        decoder_hidden = None
        
        outputs = []
        
        for _ in range(max_length):
            # Embed the current input
            embedded = self.sql_embedding(decoder_input)
            
            # One decoding step
            output, decoder_hidden = self.sql_decoder(embedded, decoder_hidden)
            
            # Mix in the fused representation
            output = output + fused_repr.unsqueeze(1)
            logits = self.output_projection(output)
            
            # Pick the next token greedily
            next_token = torch.argmax(logits, dim=-1)
            outputs.append(next_token)
            
            # Feed the prediction back in
            decoder_input = next_token
            
            # Stop once every sequence has emitted EOS
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        
        return torch.cat(outputs, dim=1)

# Usage example
graph_model = GraphText2SQL(
    text_vocab_size=10000,
    sql_vocab_size=5000,
    embedding_dim=256,
    hidden_dim=512,
    graph_dim=128
)

print(f"图神经网络模型参数数量: {sum(p.numel() for p in graph_model.parameters())}")

6. Multi-Task Learning Architecture

6.1 Jointly Trained Model

class MultiTaskText2SQL(nn.Module):
    """Multi-task Text2SQL model."""
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim, 
                 num_tables, num_columns, num_operators):
        super(MultiTaskText2SQL, self).__init__()
        
        # Shared encoder
        self.shared_encoder = nn.LSTM(
            embedding_dim, hidden_dim, 
            batch_first=True, bidirectional=True
        )
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Task-specific heads
        self.table_classifier = nn.Linear(hidden_dim * 2, num_tables)
        self.column_classifier = nn.Linear(hidden_dim * 2, num_columns)
        self.operator_classifier = nn.Linear(hidden_dim * 2, num_operators)
        self.aggregation_classifier = nn.Linear(hidden_dim * 2, 6)  # COUNT, SUM, AVG, MAX, MIN, NONE
        
        # SQL generator (reads the bidirectional encoder outputs, size hidden_dim * 2)
        self.sql_generator = nn.LSTM(
            hidden_dim * 2, hidden_dim, batch_first=True
        )
        self.sql_projection = nn.Linear(hidden_dim, vocab_size)
        
        self.dropout = nn.Dropout(0.1)
    
    def encode(self, input_seq):
        """Shared encoder."""
        embedded = self.embedding(input_seq)
        output, (hidden, cell) = self.shared_encoder(embedded)
        return output, (hidden, cell)
    
    def forward(self, input_seq, task='all'):
        """Forward pass."""
        # Shared encoding
        encoder_output, encoder_hidden = self.encode(input_seq)
        
        # Mean-pool into a sentence representation
        sentence_repr = encoder_output.mean(dim=1)
        sentence_repr = self.dropout(sentence_repr)
        
        outputs = {}
        
        if task in ['all', 'table']:
            outputs['table'] = self.table_classifier(sentence_repr)
        
        if task in ['all', 'column']:
            outputs['column'] = self.column_classifier(sentence_repr)
        
        if task in ['all', 'operator']:
            outputs['operator'] = self.operator_classifier(sentence_repr)
        
        if task in ['all', 'aggregation']:
            outputs['aggregation'] = self.aggregation_classifier(sentence_repr)
        
        if task in ['all', 'sql']:
            # SQL-generation task. The bidirectional encoder state is not
            # shape-compatible with this unidirectional LSTM, so the
            # generator starts from a zero state and reads the encoder outputs.
            sql_output, _ = self.sql_generator(encoder_output)
            outputs['sql'] = self.sql_projection(sql_output)
        
        return outputs
    
    def compute_loss(self, outputs, targets, task_weights=None):
        """Compute the weighted multi-task loss."""
        if task_weights is None:
            task_weights = {
                'table': 1.0,
                'column': 1.0,
                'operator': 1.0,
                'aggregation': 1.0,
                'sql': 2.0  # weight SQL generation more heavily
            }
        
        total_loss = 0
        loss_components = {}
        
        for task, output in outputs.items():
            if task in targets:
                if task == 'sql':
                    # Sequence-generation loss
                    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(
                        output.view(-1, output.size(-1)),
                        targets[task].view(-1)
                    )
                else:
                    # Classification loss
                    loss_fn = nn.CrossEntropyLoss()
                    loss = loss_fn(output, targets[task])
                
                weighted_loss = loss * task_weights.get(task, 1.0)
                total_loss += weighted_loss
                loss_components[task] = loss.item()
        
        return total_loss, loss_components

# Usage example
multi_task_model = MultiTaskText2SQL(
    vocab_size=10000,
    embedding_dim=256,
    hidden_dim=512,
    num_tables=100,
    num_columns=500,
    num_operators=10
)

print(f"多任务模型参数数量: {sum(p.numel() for p in multi_task_model.parameters())}")

7. Model Evaluation and Comparison

7.1 Evaluation Metrics

import sqlparse
from typing import List, Dict

class Text2SQLEvaluator:
    """Evaluator for Text2SQL models."""
    
    def __init__(self):
        self.metrics = {
            'exact_match': 0,
            'execution_accuracy': 0,
            'component_accuracy': {
                'select': 0,
                'from': 0,
                'where': 0,
                'group_by': 0,
                'order_by': 0
            }
        }
    
    def normalize_sql(self, sql: str) -> str:
        """Normalize a SQL statement."""
        # Parse the SQL
        parsed = sqlparse.parse(sql)[0]
        
        # Reformat it canonically
        formatted = sqlparse.format(
            str(parsed),
            reindent=True,
            keyword_case='upper',
            identifier_case='lower',
            strip_comments=True
        )
        
        return formatted.strip()
    
    def exact_match_accuracy(self, predictions: List[str], 
                           references: List[str]) -> float:
        """Exact-match accuracy."""
        correct = 0
        total = len(predictions)
        
        for pred, ref in zip(predictions, references):
            try:
                pred_normalized = self.normalize_sql(pred)
                ref_normalized = self.normalize_sql(ref)
                
                if pred_normalized == ref_normalized:
                    correct += 1
            except Exception:
                # SQL failed to parse; count it as incorrect
                continue
        
        return correct / total if total > 0 else 0
    
    def component_accuracy(self, predictions: List[str], 
                         references: List[str]) -> Dict[str, float]:
        """Per-component accuracy."""
        components = ['select', 'from', 'where', 'group_by', 'order_by']
        accuracies = {comp: 0 for comp in components}
        
        for pred, ref in zip(predictions, references):
            try:
                pred_components = self.extract_components(pred)
                ref_components = self.extract_components(ref)
                
                for comp in components:
                    if pred_components[comp] == ref_components[comp]:
                        accuracies[comp] += 1
            except Exception:
                continue
        
        total = len(predictions)
        return {comp: acc / total for comp, acc in accuracies.items()}
    
    def extract_components(self, sql: str) -> Dict[str, str]:
        """Extract the clauses of a SQL statement."""
        components = {
            'select': '',
            'from': '',
            'where': '',
            'group_by': '',
            'order_by': ''
        }
        
        try:
            parsed = sqlparse.parse(sql)[0]
            tokens = list(parsed.flatten())
            
            current_component = None
            component_tokens = []
            
            for token in tokens:
                # SELECT is tokenized as DML; other clause openers as Keyword
                if token.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.DML):
                    keyword = token.value.upper()
                    
                    if current_component and component_tokens:
                        components[current_component] = ' '.join(component_tokens).strip()
                        component_tokens = []
                    
                    if keyword == 'SELECT':
                        current_component = 'select'
                    elif keyword == 'FROM':
                        current_component = 'from'
                    elif keyword == 'WHERE':
                        current_component = 'where'
                    elif keyword.startswith('GROUP') or keyword.startswith('ORDER'):
                        # sqlparse emits 'GROUP BY' / 'ORDER BY' as single keyword tokens
                        current_component = keyword.split()[0].lower() + '_by'
                    else:
                        current_component = None
                
                elif current_component and token.ttype not in [
                    sqlparse.tokens.Whitespace, sqlparse.tokens.Newline
                ]:
                    component_tokens.append(token.value)
            
            # Flush the final component
            if current_component and component_tokens:
                components[current_component] = ' '.join(component_tokens).strip()
        
        except Exception:
            pass
        
        return components
    
    def evaluate(self, predictions: List[str], references: List[str]) -> Dict:
        """Run the full evaluation."""
        results = {
            'exact_match': self.exact_match_accuracy(predictions, references),
            'component_accuracy': self.component_accuracy(predictions, references),
            'total_samples': len(predictions)
        }
        
        return results

# 使用示例
evaluator = Text2SQLEvaluator()

# 示例数据
predictions = [
    "SELECT name FROM employees WHERE salary > 50000",
    "SELECT COUNT(*) FROM departments",
    "SELECT AVG(salary) FROM employees GROUP BY department_id"
]

references = [
    "SELECT name FROM employees WHERE salary > 50000",
    "SELECT count(*) FROM departments",
    "SELECT avg(salary) FROM employees GROUP BY department_id"
]

# 评估
results = evaluator.evaluate(predictions, references)
print("评估结果:")
print(f"精确匹配准确率: {results['exact_match']:.4f}")
print("组件准确率:")
for comp, acc in results['component_accuracy'].items():
    print(f"  {comp}: {acc:.4f}")

7.2 Model Comparison Framework

class ModelComparison:
    """Framework for comparing models."""
    
    def __init__(self):
        self.models = {}
        self.evaluator = Text2SQLEvaluator()
    
    def add_model(self, name: str, model, tokenizer=None):
        """Register a model."""
        self.models[name] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': None
        }
    
    def evaluate_model(self, model_name: str, test_data: List[Dict]):
        """Evaluate a single model."""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not found")
        
        model_info = self.models[model_name]
        model = model_info['model']
        tokenizer = model_info['tokenizer']
        
        predictions = []
        references = []
        
        model.eval()
        with torch.no_grad():
            for sample in test_data:
                query = sample['question']
                reference_sql = sample['sql']
                
                # Model prediction
                if hasattr(model, 'predict'):
                    predicted_sql = model.predict([query])[0]
                else:
                    # Generic prediction path
                    if tokenizer:
                        inputs = tokenizer(query, return_tensors='pt')
                        outputs = model.generate(**inputs)
                        predicted_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
                    else:
                        # Assume the model exposes its own prediction method
                        predicted_sql = "SELECT * FROM table"  # placeholder
                
                predictions.append(predicted_sql)
                references.append(reference_sql)
        
        # Evaluate
        results = self.evaluator.evaluate(predictions, references)
        self.models[model_name]['results'] = results
        
        return results
    
    def compare_all(self, test_data: List[Dict]):
        """Compare all registered models."""
        comparison_results = {}
        
        for model_name in self.models:
            print(f"Evaluating model: {model_name}")
            results = self.evaluate_model(model_name, test_data)
            comparison_results[model_name] = results
        
        return comparison_results
    
    def print_comparison(self):
        """Print the comparison table."""
        print("\n=== Model Comparison ===")
        print(f"{'Model':<20} {'Exact match':<12} {'SELECT':<10} {'FROM':<10} {'WHERE':<10}")
        print("-" * 62)
        
        for model_name, model_info in self.models.items():
            if model_info['results']:
                results = model_info['results']
                em = results['exact_match']
                select_acc = results['component_accuracy']['select']
                from_acc = results['component_accuracy']['from']
                where_acc = results['component_accuracy']['where']
                
                print(f"{model_name:<20} {em:<12.4f} {select_acc:<10.4f} {from_acc:<10.4f} {where_acc:<10.4f}")

# Usage example
comparison = ModelComparison()

# Register models (examples)
# comparison.add_model('Seq2Seq', seq2seq_model)
# comparison.add_model('Transformer', transformer_model)
# comparison.add_model('T5', t5_model, t5_tokenizer)

# Sample test data
test_data = [
    {
        'question': 'Show me all employees with salary greater than 50000',
        'sql': 'SELECT * FROM employees WHERE salary > 50000'
    },
    {
        'question': 'How many departments are there',
        'sql': 'SELECT COUNT(*) FROM departments'
    }
]

# Compare the models
# results = comparison.compare_all(test_data)
# comparison.print_comparison()

Summary

This chapter covered the major Text2SQL model architectures:

  1. Sequence-to-sequence models: the basic encoder-decoder architecture, plus attention-based refinements
  2. Transformer architecture: modern self-attention models, including applications of pretrained models
  3. Grammar-guided models: using SQL grammar structure to steer the generation process
  4. Graph neural network models: exploiting the graph structure of the database schema
  5. Multi-task learning: jointly training several related tasks to improve performance
  6. Evaluation framework: comprehensive methods for evaluating and comparing models

Each architecture has its own strengths and suitable use cases:

  • Seq2Seq suits simple queries and is easy to implement
  • Transformers offer strong performance and are the current mainstream
  • Grammar-guided models guarantee syntactically valid SQL
  • Graph neural networks make better use of schema information
  • Multi-task learning improves generalization

In practice, choose the architecture that fits your requirements, or combine the strengths of several approaches. In the next chapter, we turn to data preprocessing and feature engineering.