1. Overview of Text2SQL Models
A Text2SQL model is the core component that translates a natural-language query into a SQL statement. With the development of deep learning, Text2SQL models have evolved from rule-based methods to end-to-end neural networks. This chapter presents the major Text2SQL model architectures and their characteristics.
1.1 A Brief History
First generation: rule-based methods (2000s-2010s)
- Approach: hand-written grammar rules and templates
- Strengths: highly interpretable; effective in narrow domains
- Weaknesses: poor scalability; struggles with complex queries
Second generation: statistical machine learning (2010s)
- Approach: classical machine-learning algorithms
- Representatives: CRF-based sequence-labeling models
- Strengths: learns from data
- Weaknesses: heavy feature engineering; limited performance
Third generation: deep learning (2015-2020)
- Approach: neural networks such as RNNs and LSTMs
- Representatives: Seq2Seq models, attention mechanisms
- Strengths: end-to-end learning; large performance gains
- Weaknesses: requires large amounts of training data
Fourth generation: the pretrained-model era (2020-present)
- Approach: Transformer-based pretrained models
- Representatives: BERT, T5, the GPT series
- Strengths: strong language-understanding ability
- Weaknesses: high compute requirements
1.2 Model Taxonomy
By generation strategy
- Generative models: generate the SQL statement directly
- Classification-based models: decompose SQL generation into several classification sub-tasks (see the sketch after this list)
- Hybrid models: combine the strengths of generation and classification
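To make the classification-based decomposition concrete, here is a minimal, hypothetical sketch in the spirit of SQLNet-style slot filling: instead of emitting SQL tokens, the model fills each slot of a fixed query sketch with a classifier head. The slot inventory (six aggregation functions, four comparison operators) and the hidden size are illustrative assumptions, not a reference implementation.

import torch
import torch.nn as nn

class SketchClassifierHead(nn.Module):
    """Toy sketch-based predictor: each slot of a fixed query sketch
    is filled by its own classification head."""
    def __init__(self, hidden_dim):
        super().__init__()
        # One classification head per sketch slot (labels are assumptions)
        self.agg_head = nn.Linear(hidden_dim, 6)  # NONE/COUNT/SUM/AVG/MAX/MIN
        self.op_head = nn.Linear(hidden_dim, 4)   # =, >, <, !=

    def forward(self, question_repr):
        return {
            'aggregation': self.agg_head(question_repr),
            'operator': self.op_head(question_repr),
        }

head = SketchClassifierHead(hidden_dim=128)
logits = head(torch.randn(2, 128))  # a batch of 2 encoded questions
print({k: v.shape for k, v in logits.items()})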
By architecture
- Sequence-to-sequence models: map the natural-language sequence to a SQL sequence
- Graph neural network models: exploit the graph structure of the database schema
- Grammar-guided models: generate SQL following its grammar structure
2. Sequence-to-Sequence Models
2.1 Basic Seq2Seq Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicSeq2SeqModel(nn.Module):
    """Basic sequence-to-sequence model"""
    def __init__(self, input_vocab_size, output_vocab_size,
                 embedding_dim, hidden_dim, num_layers=2):
        super(BasicSeq2SeqModel, self).__init__()
        # Encoder
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim,
                                    num_layers, batch_first=True)
        # Decoder
        self.decoder_embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.decoder_lstm = nn.LSTM(embedding_dim, hidden_dim,
                                    num_layers, batch_first=True)
        # Output layer
        self.output_projection = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(0.1)

    def encode(self, input_seq):
        """Encode the input sequence"""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)

    def decode(self, target_seq, encoder_hidden):
        """Decode the target sequence"""
        embedded = self.decoder_embedding(target_seq)
        output, _ = self.decoder_lstm(embedded, encoder_hidden)
        output = self.dropout(output)
        logits = self.output_projection(output)
        return logits

    def forward(self, input_seq, target_seq=None):
        """Forward pass"""
        # Encode
        encoder_output, encoder_hidden = self.encode(input_seq)
        if target_seq is not None:
            # Training mode: teacher forcing
            logits = self.decode(target_seq, encoder_hidden)
            return logits
        else:
            # Inference mode: autoregressive generation
            return self.generate(encoder_hidden)

    def generate(self, encoder_hidden, max_length=100):
        """Generate a SQL token sequence"""
        batch_size = encoder_hidden[0].size(1)
        device = encoder_hidden[0].device
        # Initialize the decoder input (token 0 doubles as the start token here)
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden
        outputs = []
        for _ in range(max_length):
            # One decoding step
            embedded = self.decoder_embedding(decoder_input)
            output, decoder_hidden = self.decoder_lstm(embedded, decoder_hidden)
            logits = self.output_projection(output)
            # Pick the next token greedily
            next_token = torch.argmax(logits, dim=-1)
            outputs.append(next_token)
            # Feed it back as the next decoder input
            decoder_input = next_token
            # Stop once every sequence has emitted EOS
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        return torch.cat(outputs, dim=1)

# Model instantiation example
input_vocab_size = 10000
output_vocab_size = 5000
embedding_dim = 256
hidden_dim = 512
model = BasicSeq2SeqModel(input_vocab_size, output_vocab_size,
                          embedding_dim, hidden_dim)
print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")
2.2 Attention Mechanism
class AttentionSeq2SeqModel(nn.Module):
    """Sequence-to-sequence model with an attention mechanism"""
    def __init__(self, input_vocab_size, output_vocab_size,
                 embedding_dim, hidden_dim, num_layers=2):
        super(AttentionSeq2SeqModel, self).__init__()
        # Encoder (bidirectional, so its outputs have size hidden_dim * 2)
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim,
                                    num_layers, batch_first=True,
                                    bidirectional=True)
        # Decoder (consumes the token embedding concatenated with the context vector)
        self.decoder_embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.decoder_lstm = nn.LSTM(embedding_dim + hidden_dim * 2, hidden_dim,
                                    num_layers, batch_first=True)
        # Attention: scores the decoder state (hidden_dim) against
        # the encoder outputs (hidden_dim * 2)
        self.attention = nn.Linear(hidden_dim * 3, hidden_dim)
        # Output layer
        self.output_projection = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(0.1)

    def attention_mechanism(self, decoder_hidden, encoder_outputs):
        """Compute attention weights and the context vector"""
        batch_size, seq_len, hidden_size = encoder_outputs.size()
        # Broadcast the decoder hidden state across source positions
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1).expand(
            batch_size, seq_len, -1)
        # Score each source position
        combined = torch.cat([decoder_hidden_expanded, encoder_outputs], dim=2)
        attention_scores = self.attention(combined)
        attention_weights = F.softmax(attention_scores.sum(dim=2), dim=1)
        # Context vector: attention-weighted sum of the encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        return context, attention_weights

    def encode(self, input_seq):
        """Encode the input sequence"""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)

    def decode_step(self, input_token, decoder_hidden, encoder_outputs):
        """Run one decoding step"""
        embedded = self.decoder_embedding(input_token)
        # Attend over the encoder outputs using the top decoder layer's state
        context, attention_weights = self.attention_mechanism(
            decoder_hidden[0][-1], encoder_outputs)
        # Concatenate the input embedding with the context vector
        combined_input = torch.cat([embedded.squeeze(1), context], dim=1)
        combined_input = combined_input.unsqueeze(1)
        # LSTM decoding
        output, decoder_hidden = self.decoder_lstm(combined_input, decoder_hidden)
        # Project to vocabulary logits
        output = self.dropout(output)
        logits = self.output_projection(output)
        return logits, decoder_hidden, attention_weights

# Usage example
attention_model = AttentionSeq2SeqModel(input_vocab_size, output_vocab_size,
                                        embedding_dim, hidden_dim)
print(f"Number of attention-model parameters: {sum(p.numel() for p in attention_model.parameters())}")
3. Transformer Architecture
3.1 Basic Transformer Model
import math
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (batch-first layout)"""
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_len, d_model) so it broadcasts over the batch dimension
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]
class TransformerText2SQL(nn.Module):
    """Transformer-based Text2SQL model"""
    def __init__(self, input_vocab_size, output_vocab_size, d_model=512,
                 nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        super(TransformerText2SQL, self).__init__()
        self.d_model = d_model
        # Embedding layers
        self.input_embedding = nn.Embedding(input_vocab_size, d_model)
        self.output_embedding = nn.Embedding(output_vocab_size, d_model)
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            batch_first=True
        )
        # Output layer
        self.output_projection = nn.Linear(d_model, output_vocab_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, tgt=None, src_padding_mask=None, tgt_padding_mask=None):
        """Forward pass"""
        # Embed the source sequence
        src_emb = self.input_embedding(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoder(src_emb)
        if tgt is not None:
            # Training mode: a causal mask keeps the decoder from peeking ahead
            tgt_emb = self.output_embedding(tgt) * math.sqrt(self.d_model)
            tgt_emb = self.pos_encoder(tgt_emb)
            causal_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(1)).to(tgt.device)
            output = self.transformer(src_emb, tgt_emb,
                                      tgt_mask=causal_mask,
                                      src_key_padding_mask=src_padding_mask,
                                      tgt_key_padding_mask=tgt_padding_mask)
            output = self.dropout(output)
            logits = self.output_projection(output)
            return logits
        else:
            # Inference mode
            return self.generate(src_emb, src_padding_mask)

    def generate(self, src_emb, src_padding_mask=None, max_length=100):
        """Generate a SQL token sequence"""
        batch_size = src_emb.size(0)
        device = src_emb.device
        # Initialize the target sequence with the start token (assumed id 0)
        tgt = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for _ in range(max_length):
            tgt_emb = self.output_embedding(tgt) * math.sqrt(self.d_model)
            tgt_emb = self.pos_encoder(tgt_emb)
            causal_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(1)).to(device)
            output = self.transformer(src_emb, tgt_emb,
                                      tgt_mask=causal_mask,
                                      src_key_padding_mask=src_padding_mask)
            logits = self.output_projection(output[:, -1:, :])
            next_token = torch.argmax(logits, dim=-1)
            tgt = torch.cat([tgt, next_token], dim=1)
            # Stop once every sequence has emitted EOS
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        return tgt[:, 1:]  # strip the start token
# Model instantiation
transformer_model = TransformerText2SQL(
    input_vocab_size=10000,
    output_vocab_size=5000,
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6
)
print(f"Number of Transformer parameters: {sum(p.numel() for p in transformer_model.parameters())}")
3.2 Fine-Tuning Pretrained Models
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
class T5Text2SQL(nn.Module):
    """T5-based Text2SQL model"""
    def __init__(self, model_name='t5-base'):
        super(T5Text2SQL, self).__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        # Add schema-markup special tokens
        special_tokens = ['<table>', '</table>', '<column>', '</column>']
        self.tokenizer.add_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))

    def forward(self, input_text, target_text=None):
        """Forward pass"""
        # Tokenize the input
        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512
        )
        if target_text is not None:
            # Training mode
            targets = self.tokenizer(
                target_text,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=256
            )
            outputs = self.model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                labels=targets['input_ids']
            )
            return outputs.loss, outputs.logits
        else:
            # Inference mode
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=256,
                num_beams=4,
                early_stopping=True
            )
            return outputs

    def predict(self, input_text):
        """Predict SQL statements"""
        outputs = self.forward(input_text)
        sql_queries = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return sql_queries

# Usage example
t5_model = T5Text2SQL('t5-base')
# Example input
input_text = [
    "translate to SQL: Show me all employees with salary greater than 50000 | tables: employees (id, name, salary, department_id)"
]
# Prediction
sql_queries = t5_model.predict(input_text)
print(f"Generated SQL: {sql_queries[0]}")
4. Grammar-Guided Models
4.1 Abstract Syntax Tree (AST) Generation
class SQLASTNode:
    """Node in a SQL abstract syntax tree"""
    def __init__(self, node_type, value=None, children=None):
        self.node_type = node_type
        self.value = value
        self.children = children or []

    def add_child(self, child):
        self.children.append(child)

    def to_sql(self):
        """Render the AST as a SQL string"""
        if self.node_type == 'SELECT':
            return self._select_to_sql()
        elif self.node_type == 'FROM':
            return f"FROM {self.value}"
        elif self.node_type == 'WHERE':
            return f"WHERE {self._condition_to_sql()}"
        elif self.node_type == 'COLUMN':
            return self.value
        elif self.node_type == 'CONDITION':
            return self._condition_to_sql()
        else:
            return str(self.value)

    def _select_to_sql(self):
        """Render a SELECT statement"""
        parts = ['SELECT']
        # Collect columns until the FROM clause is reached
        columns = []
        for child in self.children:
            if child.node_type == 'COLUMN':
                columns.append(child.to_sql())
            elif child.node_type == 'FROM':
                parts.append(', '.join(columns) if columns else '*')
                parts.append(child.to_sql())
            elif child.node_type == 'WHERE':
                parts.append(child.to_sql())
        return ' '.join(parts)

    def _condition_to_sql(self):
        """Render a condition"""
        if len(self.children) == 3:
            left, op, right = self.children
            return f"{left.to_sql()} {op.value} {right.to_sql()}"
        return self.value
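To check the rendering logic, we can assemble the AST for a small query by hand and render it:

# Build the AST for "SELECT name FROM employees WHERE salary > 50000"
select = SQLASTNode('SELECT')
select.add_child(SQLASTNode('COLUMN', 'name'))
select.add_child(SQLASTNode('FROM', 'employees'))
where = SQLASTNode('WHERE', children=[
    SQLASTNode('COLUMN', 'salary'),
    SQLASTNode('OPERATOR', '>'),
    SQLASTNode('VALUE', '50000'),
])
select.add_child(where)
print(select.to_sql())  # SELECT name FROM employees WHERE salary > 50000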
class GrammarGuidedDecoder(nn.Module):
    """Grammar-guided decoder"""
    def __init__(self, hidden_dim, vocab_size):
        super(GrammarGuidedDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        # One decoder head per grammar rule
        self.select_decoder = nn.Linear(hidden_dim, vocab_size)
        self.from_decoder = nn.Linear(hidden_dim, vocab_size)
        self.where_decoder = nn.Linear(hidden_dim, vocab_size)
        self.column_decoder = nn.Linear(hidden_dim, vocab_size)
        # Grammar-rule classifier
        self.rule_classifier = nn.Linear(hidden_dim, 4)  # SELECT, FROM, WHERE, COLUMN

    def forward(self, hidden_state, current_rule=None):
        """Decode according to the current grammar rule"""
        if current_rule is None:
            # Predict the next grammar rule
            rule_logits = self.rule_classifier(hidden_state)
            return rule_logits
        # Pick the decoder head that matches the rule
        if current_rule == 'SELECT':
            return self.select_decoder(hidden_state)
        elif current_rule == 'FROM':
            return self.from_decoder(hidden_state)
        elif current_rule == 'WHERE':
            return self.where_decoder(hidden_state)
        else:
            # Terminal rules (COLUMN, TABLE, OPERATOR, VALUE, CONDITION)
            # share the column decoder head in this simplified sketch
            return self.column_decoder(hidden_state)

class GrammarGuidedText2SQL(nn.Module):
    """Grammar-guided Text2SQL model"""
    def __init__(self, input_vocab_size, output_vocab_size,
                 embedding_dim, hidden_dim):
        super(GrammarGuidedText2SQL, self).__init__()
        # Encoder
        self.encoder_embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim,
                                    batch_first=True, bidirectional=True)
        # Grammar-guided decoder
        self.decoder = GrammarGuidedDecoder(hidden_dim * 2, output_vocab_size)
        # SQL grammar rules ('?' marks an optional expansion)
        self.grammar_rules = {
            'ROOT': ['SELECT'],
            'SELECT': ['COLUMN', 'FROM'],
            'FROM': ['TABLE', 'WHERE?'],
            'WHERE': ['CONDITION'],
            'CONDITION': ['COLUMN', 'OPERATOR', 'VALUE']
        }
    def encode(self, input_seq):
        """Encode the input sequence"""
        embedded = self.encoder_embedding(input_seq)
        output, (hidden, cell) = self.encoder_lstm(embedded)
        return output, (hidden, cell)

    def decode_with_grammar(self, encoder_output, max_depth=10):
        """Decode by expanding grammar rules"""
        # Initialize the AST root
        root = SQLASTNode('ROOT')
        # A stack drives the top-down expansion
        decode_stack = [(root, 'SELECT', 0)]
        while decode_stack and max_depth > 0:
            current_node, rule, depth = decode_stack.pop()
            if depth >= max_depth:
                break
            # Use the mean of the encoder outputs as the context vector
            context = encoder_output.mean(dim=1)
            # Decode under the current rule
            logits = self.decoder(context, rule)
            predicted_token = torch.argmax(logits, dim=-1)
            # Create a node for the prediction
            # (value is a raw token id; .item() assumes batch size 1)
            new_node = SQLASTNode(rule, predicted_token.item())
            current_node.add_child(new_node)
            # Push the rule's children onto the stack
            if rule in self.grammar_rules:
                for child_rule in reversed(self.grammar_rules[rule]):
                    if not child_rule.endswith('?'):  # skip optional rules
                        decode_stack.append((new_node, child_rule, depth + 1))
            max_depth -= 1
        return root

    def forward(self, input_seq):
        """Forward pass"""
        encoder_output, _ = self.encode(input_seq)
        ast = self.decode_with_grammar(encoder_output)
        return ast
# Usage example
grammar_model = GrammarGuidedText2SQL(
    input_vocab_size=10000,
    output_vocab_size=5000,
    embedding_dim=256,
    hidden_dim=512
)
print(f"Number of grammar-guided model parameters: {sum(p.numel() for p in grammar_model.parameters())}")
5. Graph Neural Network Models
5.1 Schema Graph Construction
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
class SchemaGraphBuilder:
    """Builds a graph over the database schema"""
    def __init__(self):
        self.node_types = ['table', 'column', 'value']
        self.edge_types = ['table_to_column', 'column_to_table',
                           'foreign_key', 'column_to_value']

    def build_graph(self, schema, query_entities):
        """Build the schema graph"""
        nodes = []
        edges = []
        node_features = []
        node_id = 0
        node_mapping = {}
        # Table nodes
        for table_name in schema['tables']:
            nodes.append({
                'id': node_id,
                'type': 'table',
                'name': table_name
            })
            node_mapping[f"table_{table_name}"] = node_id
            node_features.append([1, 0, 0])  # one-hot for table
            node_id += 1
        # Column nodes
        for table_name, columns in schema['columns'].items():
            table_id = node_mapping[f"table_{table_name}"]
            for column_name in columns:
                nodes.append({
                    'id': node_id,
                    'type': 'column',
                    'name': column_name,
                    'table': table_name
                })
                node_mapping[f"column_{table_name}_{column_name}"] = node_id
                node_features.append([0, 1, 0])  # one-hot for column
                # Table-column edges in both directions
                edges.append([table_id, node_id])
                edges.append([node_id, table_id])
                node_id += 1
        # Value nodes mentioned in the query
        for entity in query_entities.get('values', []):
            nodes.append({
                'id': node_id,
                'type': 'value',
                'name': entity
            })
            node_features.append([0, 0, 1])  # one-hot for value
            node_id += 1
        # Convert to PyTorch Geometric format
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        x = torch.tensor(node_features, dtype=torch.float)
        graph_data = Data(x=x, edge_index=edge_index)
        graph_data.node_mapping = node_mapping
        graph_data.nodes_info = nodes
        return graph_data
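A quick test with a toy two-table schema shows the resulting graph sizes; the dict layout follows what build_graph expects (a 'tables' list plus a 'columns' mapping):

toy_schema = {
    'tables': ['employees', 'departments'],
    'columns': {
        'employees': ['id', 'name', 'salary', 'department_id'],
        'departments': ['id', 'name'],
    },
}
builder = SchemaGraphBuilder()
graph_data = builder.build_graph(toy_schema, {'values': ['50000']})
print(graph_data)  # 9 nodes (2 tables + 6 columns + 1 value), 12 edges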
class GraphNeuralNetwork(nn.Module):
    """Graph neural network (stacked GCN layers)"""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super(GraphNeuralNetwork, self).__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        # First layer
        self.convs.append(GCNConv(input_dim, hidden_dim))
        # Intermediate layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        # Final layer
        if num_layers > 1:
            self.convs.append(GCNConv(hidden_dim, output_dim))
        self.dropout = nn.Dropout(0.1)
        self.activation = nn.ReLU()

    def forward(self, x, edge_index):
        """Forward pass"""
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = self.activation(x)
                x = self.dropout(x)
        return x
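Propagating over the toy graph from the previous example turns each node's 3-dimensional type one-hot into a learned embedding:

gnn = GraphNeuralNetwork(input_dim=3, hidden_dim=64, output_dim=128)
node_embeddings = gnn(graph_data.x, graph_data.edge_index)
print(node_embeddings.shape)  # torch.Size([9, 128])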
class GraphText2SQL(nn.Module):
    """Graph-neural-network-based Text2SQL model"""
    def __init__(self, text_vocab_size, sql_vocab_size,
                 embedding_dim, hidden_dim, graph_dim):
        super(GraphText2SQL, self).__init__()
        # Text encoder
        self.text_embedding = nn.Embedding(text_vocab_size, embedding_dim)
        self.text_encoder = nn.LSTM(embedding_dim, hidden_dim,
                                    batch_first=True, bidirectional=True)
        # Graph neural network
        self.graph_nn = GraphNeuralNetwork(
            input_dim=3,  # dimensionality of the node-type one-hot
            hidden_dim=graph_dim,
            output_dim=graph_dim
        )
        # Fusion layer
        self.fusion = nn.Linear(hidden_dim * 2 + graph_dim, hidden_dim)
        # SQL decoder
        self.sql_decoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.sql_embedding = nn.Embedding(sql_vocab_size, embedding_dim)
        self.output_projection = nn.Linear(hidden_dim, sql_vocab_size)
        self.schema_builder = SchemaGraphBuilder()

    def encode_text(self, text_seq):
        """Encode the text sequence"""
        embedded = self.text_embedding(text_seq)
        output, (hidden, cell) = self.text_encoder(embedded)
        return output, (hidden, cell)

    def encode_graph(self, schema, query_entities):
        """Encode the schema graph"""
        graph_data = self.schema_builder.build_graph(schema, query_entities)
        graph_embeddings = self.graph_nn(graph_data.x, graph_data.edge_index)
        # Aggregate node embeddings into a single graph representation
        graph_repr = torch.mean(graph_embeddings, dim=0, keepdim=True)
        return graph_repr, graph_data

    def forward(self, text_seq, schema, query_entities, target_sql=None):
        """Forward pass"""
        # Encode the text
        text_output, text_hidden = self.encode_text(text_seq)
        text_repr = text_output.mean(dim=1)  # mean pooling
        # Encode the graph
        graph_repr, graph_data = self.encode_graph(schema, query_entities)
        graph_repr = graph_repr.expand(text_repr.size(0), -1)
        # Fuse the text and graph representations
        combined_repr = torch.cat([text_repr, graph_repr], dim=1)
        fused_repr = self.fusion(combined_repr)
        if target_sql is not None:
            # Training mode
            sql_embedded = self.sql_embedding(target_sql)
            decoder_output, _ = self.sql_decoder(sql_embedded)
            # Inject the fused representation
            decoder_output = decoder_output + fused_repr.unsqueeze(1)
            logits = self.output_projection(decoder_output)
            return logits
        else:
            # Inference mode
            return self.generate_sql(fused_repr)

    def generate_sql(self, fused_repr, max_length=100):
        """Generate a SQL token sequence"""
        batch_size = fused_repr.size(0)
        device = fused_repr.device
        # Initialize the decoder input (token 0 doubles as the start token)
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        decoder_hidden = None
        outputs = []
        for _ in range(max_length):
            # Embed the current input
            embedded = self.sql_embedding(decoder_input)
            # One decoding step
            output, decoder_hidden = self.sql_decoder(embedded, decoder_hidden)
            # Inject the fused representation
            output = output + fused_repr.unsqueeze(1)
            logits = self.output_projection(output)
            # Pick the next token greedily
            next_token = torch.argmax(logits, dim=-1)
            outputs.append(next_token)
            # Feed it back as the next input
            decoder_input = next_token
            # Stop condition
            if torch.all(next_token == 0):  # assumes 0 is the EOS token
                break
        return torch.cat(outputs, dim=1)

# Usage example
graph_model = GraphText2SQL(
    text_vocab_size=10000,
    sql_vocab_size=5000,
    embedding_dim=256,
    hidden_dim=512,
    graph_dim=128
)
print(f"Number of GNN model parameters: {sum(p.numel() for p in graph_model.parameters())}")
6. Multi-Task Learning Architecture
6.1 Jointly Trained Model
class MultiTaskText2SQL(nn.Module):
    """Multi-task Text2SQL model"""
    def __init__(self, vocab_size, embedding_dim, hidden_dim,
                 num_tables, num_columns, num_operators):
        super(MultiTaskText2SQL, self).__init__()
        # Shared encoder
        self.shared_encoder = nn.LSTM(
            embedding_dim, hidden_dim,
            batch_first=True, bidirectional=True
        )
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Task-specific heads
        self.table_classifier = nn.Linear(hidden_dim * 2, num_tables)
        self.column_classifier = nn.Linear(hidden_dim * 2, num_columns)
        self.operator_classifier = nn.Linear(hidden_dim * 2, num_operators)
        self.aggregation_classifier = nn.Linear(hidden_dim * 2, 6)  # COUNT, SUM, AVG, MAX, MIN, NONE
        # SQL generator (consumes the bidirectional encoder outputs,
        # hence the hidden_dim * 2 input size)
        self.sql_generator = nn.LSTM(
            hidden_dim * 2, hidden_dim, batch_first=True
        )
        self.sql_projection = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(0.1)

    def encode(self, input_seq):
        """Shared encoder"""
        embedded = self.embedding(input_seq)
        output, (hidden, cell) = self.shared_encoder(embedded)
        return output, (hidden, cell)

    def forward(self, input_seq, task='all'):
        """Forward pass"""
        # Shared encoding
        encoder_output, encoder_hidden = self.encode(input_seq)
        # Pool into a sentence representation
        sentence_repr = encoder_output.mean(dim=1)
        sentence_repr = self.dropout(sentence_repr)
        outputs = {}
        if task in ['all', 'table']:
            outputs['table'] = self.table_classifier(sentence_repr)
        if task in ['all', 'column']:
            outputs['column'] = self.column_classifier(sentence_repr)
        if task in ['all', 'operator']:
            outputs['operator'] = self.operator_classifier(sentence_repr)
        if task in ['all', 'aggregation']:
            outputs['aggregation'] = self.aggregation_classifier(sentence_repr)
        if task in ['all', 'sql']:
            # SQL generation task
            sql_output, _ = self.sql_generator(encoder_output)
            outputs['sql'] = self.sql_projection(sql_output)
        return outputs
    def compute_loss(self, outputs, targets, task_weights=None):
        """Compute the weighted multi-task loss"""
        if task_weights is None:
            task_weights = {
                'table': 1.0,
                'column': 1.0,
                'operator': 1.0,
                'aggregation': 1.0,
                'sql': 2.0  # the SQL-generation task gets a higher weight
            }
        total_loss = 0
        loss_components = {}
        for task, output in outputs.items():
            if task in targets:
                if task == 'sql':
                    # Sequence-generation loss
                    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
                    loss = loss_fn(
                        output.view(-1, output.size(-1)),
                        targets[task].view(-1)
                    )
                else:
                    # Classification loss
                    loss_fn = nn.CrossEntropyLoss()
                    loss = loss_fn(output, targets[task])
                weighted_loss = loss * task_weights.get(task, 1.0)
                total_loss += weighted_loss
                loss_components[task] = loss.item()
        return total_loss, loss_components
# Usage example
multi_task_model = MultiTaskText2SQL(
    vocab_size=10000,
    embedding_dim=256,
    hidden_dim=512,
    num_tables=100,
    num_columns=500,
    num_operators=10
)
print(f"Number of multi-task model parameters: {sum(p.numel() for p in multi_task_model.parameters())}")
7. Model Evaluation and Comparison
7.1 Evaluation Metrics
import sqlparse
from typing import List, Dict
class Text2SQLEvaluator:
    """Evaluator for Text2SQL models"""
    def __init__(self):
        self.metrics = {
            'exact_match': 0,
            'execution_accuracy': 0,
            'component_accuracy': {
                'select': 0,
                'from': 0,
                'where': 0,
                'group_by': 0,
                'order_by': 0
            }
        }

    def normalize_sql(self, sql: str) -> str:
        """Normalize a SQL statement"""
        # Parse the SQL
        parsed = sqlparse.parse(sql)[0]
        # Reformat with canonical casing
        formatted = sqlparse.format(
            str(parsed),
            reindent=True,
            keyword_case='upper',
            identifier_case='lower',
            strip_comments=True
        )
        return formatted.strip()
    def exact_match_accuracy(self, predictions: List[str],
                             references: List[str]) -> float:
        """Exact-match accuracy"""
        correct = 0
        total = len(predictions)
        for pred, ref in zip(predictions, references):
            try:
                pred_normalized = self.normalize_sql(pred)
                ref_normalized = self.normalize_sql(ref)
                if pred_normalized == ref_normalized:
                    correct += 1
            except Exception:
                # SQL failed to parse
                continue
        return correct / total if total > 0 else 0
    def component_accuracy(self, predictions: List[str],
                           references: List[str]) -> Dict[str, float]:
        """Per-component accuracy"""
        components = ['select', 'from', 'where', 'group_by', 'order_by']
        accuracies = {comp: 0 for comp in components}
        for pred, ref in zip(predictions, references):
            try:
                pred_components = self.extract_components(pred)
                ref_components = self.extract_components(ref)
                for comp in components:
                    if pred_components[comp] == ref_components[comp]:
                        accuracies[comp] += 1
            except Exception:
                continue
        total = len(predictions)
        return {comp: acc / total for comp, acc in accuracies.items()}
    def extract_components(self, sql: str) -> Dict[str, str]:
        """Extract SQL clauses"""
        components = {
            'select': '',
            'from': '',
            'where': '',
            'group_by': '',
            'order_by': ''
        }
        clause_starts = {
            'SELECT': 'select',
            'FROM': 'from',
            'WHERE': 'where',
            'GROUP BY': 'group_by',
            'ORDER BY': 'order_by'
        }
        try:
            parsed = sqlparse.parse(sql)[0]
            tokens = list(parsed.flatten())
            current_component = None
            component_tokens = []
            for token in tokens:
                # `in` (not `is`) also matches keyword subtypes such as
                # Keyword.DML, which is what sqlparse assigns to SELECT
                if token.ttype in sqlparse.tokens.Keyword:
                    # sqlparse emits "GROUP BY"/"ORDER BY" as single tokens;
                    # normalize internal whitespace before the lookup
                    keyword = ' '.join(token.value.upper().split())
                    if keyword in clause_starts:
                        # Flush the previous clause and start a new one
                        if current_component and component_tokens:
                            components[current_component] = ' '.join(component_tokens).strip()
                        component_tokens = []
                        current_component = clause_starts[keyword]
                        continue
                # Everything else (including AND/OR) belongs to the current clause
                if current_component and not token.is_whitespace:
                    component_tokens.append(token.value)
            # Flush the final clause
            if current_component and component_tokens:
                components[current_component] = ' '.join(component_tokens).strip()
        except Exception:
            pass
        return components
    def evaluate(self, predictions: List[str], references: List[str]) -> Dict:
        """Run the full evaluation"""
        results = {
            'exact_match': self.exact_match_accuracy(predictions, references),
            'component_accuracy': self.component_accuracy(predictions, references),
            'total_samples': len(predictions)
        }
        return results

# Usage example
evaluator = Text2SQLEvaluator()
# Sample data
predictions = [
    "SELECT name FROM employees WHERE salary > 50000",
    "SELECT COUNT(*) FROM departments",
    "SELECT AVG(salary) FROM employees GROUP BY department_id"
]
references = [
    "SELECT name FROM employees WHERE salary > 50000",
    "SELECT count(*) FROM departments",
    "SELECT avg(salary) FROM employees GROUP BY department_id"
]
# Evaluate
results = evaluator.evaluate(predictions, references)
print("Evaluation results:")
print(f"Exact-match accuracy: {results['exact_match']:.4f}")
print("Component accuracy:")
for comp, acc in results['component_accuracy'].items():
    print(f"  {comp}: {acc:.4f}")
7.2 Model Comparison Framework
class ModelComparison:
    """Framework for comparing models"""
    def __init__(self):
        self.models = {}
        self.evaluator = Text2SQLEvaluator()

    def add_model(self, name: str, model, tokenizer=None):
        """Register a model"""
        self.models[name] = {
            'model': model,
            'tokenizer': tokenizer,
            'results': None
        }

    def evaluate_model(self, model_name: str, test_data: List[Dict]):
        """Evaluate a single model"""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not found")
        model_info = self.models[model_name]
        model = model_info['model']
        tokenizer = model_info['tokenizer']
        predictions = []
        references = []
        model.eval()
        with torch.no_grad():
            for sample in test_data:
                query = sample['question']
                reference_sql = sample['sql']
                # Model prediction
                if hasattr(model, 'predict'):
                    predicted_sql = model.predict([query])[0]
                else:
                    # Generic prediction path
                    if tokenizer:
                        inputs = tokenizer(query, return_tensors='pt')
                        outputs = model.generate(**inputs)
                        predicted_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
                    else:
                        # Assume the model exposes its own prediction method
                        predicted_sql = "SELECT * FROM table"  # placeholder
                predictions.append(predicted_sql)
                references.append(reference_sql)
        # Evaluate
        results = self.evaluator.evaluate(predictions, references)
        self.models[model_name]['results'] = results
        return results

    def compare_all(self, test_data: List[Dict]):
        """Evaluate every registered model"""
        comparison_results = {}
        for model_name in self.models:
            print(f"Evaluating model: {model_name}")
            results = self.evaluate_model(model_name, test_data)
            comparison_results[model_name] = results
        return comparison_results

    def print_comparison(self):
        """Print the comparison table"""
        print("\n=== Model comparison ===")
        print(f"{'Model':<20} {'Exact match':<12} {'SELECT':<10} {'FROM':<10} {'WHERE':<10}")
        print("-" * 64)
        for model_name, model_info in self.models.items():
            if model_info['results']:
                results = model_info['results']
                em = results['exact_match']
                select_acc = results['component_accuracy']['select']
                from_acc = results['component_accuracy']['from']
                where_acc = results['component_accuracy']['where']
                print(f"{model_name:<20} {em:<12.4f} {select_acc:<10.4f} {from_acc:<10.4f} {where_acc:<10.4f}")
# Usage example
comparison = ModelComparison()
# Register models (examples)
# comparison.add_model('Seq2Seq', seq2seq_model)
# comparison.add_model('Transformer', transformer_model)
# comparison.add_model('T5', t5_model, t5_tokenizer)
# Sample test data
test_data = [
    {
        'question': 'Show me all employees with salary greater than 50000',
        'sql': 'SELECT * FROM employees WHERE salary > 50000'
    },
    {
        'question': 'How many departments are there',
        'sql': 'SELECT COUNT(*) FROM departments'
    }
]
# Compare the models
# results = comparison.compare_all(test_data)
# comparison.print_comparison()
Summary
This chapter surveyed the major Text2SQL model architectures:
- Sequence-to-sequence models: the basic encoder-decoder architecture, extended with attention mechanisms
- Transformer architecture: modern self-attention models, including fine-tuned pretrained models
- Grammar-guided models: use SQL grammar structure to steer the generation process
- Graph neural network models: exploit the graph structure of the database schema
- Multi-task learning: jointly trains several related tasks to boost performance
- Evaluation framework: comprehensive methods for evaluating and comparing models
Each architecture has its own strengths and use cases:
- Seq2Seq suits simple queries and is easy to implement
- Transformers offer strong performance and are the current mainstream
- Grammar-guided decoding guarantees syntactically valid SQL
- Graph neural networks make better use of schema information
- Multi-task learning improves generalization
In practice, choose the architecture that fits your requirements, or combine the strengths of several approaches. In the next chapter we turn to data preprocessing and feature engineering.