10.1 Multimodal Text2SQL
10.1.1 Chart Understanding and SQL Generation
With the growing adoption of data visualization, generating SQL queries directly from charts has become an important research direction.
import torch
import torch.nn as nn
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast
from PIL import Image
import numpy as np
from typing import Dict, List, Any, Optional
class ChartToSQLModel(nn.Module):
"""图表到SQL模型"""
def __init__(self, config: Dict[str, Any]):
super().__init__()
self.config = config
# 视觉编码器
self.vision_encoder = VisionEncoderDecoderModel.from_pretrained(
"microsoft/trocr-base-printed"
)
# 图像处理器
self.image_processor = ViTImageProcessor.from_pretrained(
"microsoft/trocr-base-printed"
)
        # 文本编码器
        self.text_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config['hidden_size'],
                nhead=config['num_heads'],
                dim_feedforward=config['ff_size'],
                batch_first=True
            ),
            num_layers=config['num_layers']
        )
        # 融合层(注意:config['hidden_size']需与视觉编码器输出维度一致,TrOCR-base为768)
        self.fusion_layer = nn.MultiheadAttention(
            embed_dim=config['hidden_size'],
            num_heads=config['num_heads'],
            batch_first=True
        )
        # SQL解码器
        self.sql_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config['hidden_size'],
                nhead=config['num_heads'],
                dim_feedforward=config['ff_size'],
                batch_first=True
            ),
            num_layers=config['decoder_layers']
        )
# 输出投影
self.output_projection = nn.Linear(
config['hidden_size'],
config['vocab_size']
)
def extract_chart_features(self, image: Image.Image) -> torch.Tensor:
"""提取图表特征"""
# 预处理图像
inputs = self.image_processor(image, return_tensors="pt")
# 提取视觉特征
with torch.no_grad():
vision_outputs = self.vision_encoder.encoder(
pixel_values=inputs['pixel_values']
)
return vision_outputs.last_hidden_state
def extract_text_features(self, question: str, schema: Dict) -> torch.Tensor:
"""提取文本特征"""
# 构建输入序列
text_input = f"Question: {question} Schema: {self._format_schema(schema)}"
# 编码文本(简化实现)
tokens = text_input.split() # 实际应使用tokenizer
embeddings = torch.randn(len(tokens), self.config['hidden_size']) # 占位符
# 文本编码
text_features = self.text_encoder(embeddings.unsqueeze(0))
return text_features
def _format_schema(self, schema: Dict) -> str:
"""格式化Schema"""
tables = []
for table_name, table_info in schema.items():
columns = ", ".join(table_info.get('columns', []))
tables.append(f"{table_name}({columns})")
return " | ".join(tables)
def forward(self, image: Image.Image, question: str, schema: Dict) -> torch.Tensor:
"""前向传播"""
# 提取特征
chart_features = self.extract_chart_features(image)
text_features = self.extract_text_features(question, schema)
# 特征融合
fused_features, _ = self.fusion_layer(
chart_features, text_features, text_features
)
# SQL生成
sql_output = self.sql_decoder(
torch.zeros(1, 1, self.config['hidden_size']), # 起始token
fused_features
)
# 输出投影
logits = self.output_projection(sql_output)
return logits
class MultiModalText2SQL:
"""多模态Text2SQL系统"""
def __init__(self, model_config: Dict[str, Any]):
self.model = ChartToSQLModel(model_config)
self.chart_analyzer = ChartAnalyzer()
self.context_manager = ContextManager()
def process_multimodal_query(self,
question: str,
image: Optional[Image.Image] = None,
schema: Dict = None) -> Dict[str, Any]:
"""处理多模态查询"""
result = {
'question': question,
'has_image': image is not None,
'chart_info': None,
'sql': None,
'confidence': 0.0
}
try:
# 分析图表(如果有)
if image:
chart_info = self.chart_analyzer.analyze_chart(image)
result['chart_info'] = chart_info
# 更新上下文
self.context_manager.update_context(chart_info)
            # 生成SQL
            sql_logits = None
            if image is not None and schema:
                # 多模态生成
                sql_logits = self.model(image, question, schema)
                sql = self._decode_sql(sql_logits)
            else:
                # 纯文本生成
                sql = self._generate_text_sql(question, schema)
            result['sql'] = sql
            result['confidence'] = self._calculate_confidence(sql_logits)
except Exception as e:
result['error'] = str(e)
return result
def _decode_sql(self, logits: torch.Tensor) -> str:
"""解码SQL"""
# 简化的解码实现
predicted_ids = torch.argmax(logits, dim=-1)
# 实际应使用tokenizer解码
return "SELECT * FROM table WHERE condition;" # 占位符
def _generate_text_sql(self, question: str, schema: Dict) -> str:
"""生成纯文本SQL"""
# 使用传统Text2SQL方法
return "SELECT * FROM table;" # 占位符
def _calculate_confidence(self, logits: Optional[torch.Tensor]) -> float:
"""计算置信度"""
if logits is None:
return 0.5
# 基于logits计算置信度
probs = torch.softmax(logits, dim=-1)
max_probs = torch.max(probs, dim=-1)[0]
confidence = torch.mean(max_probs).item()
return confidence
class ChartAnalyzer:
"""图表分析器"""
def __init__(self):
self.chart_types = {
'bar': self._analyze_bar_chart,
'line': self._analyze_line_chart,
'pie': self._analyze_pie_chart,
'scatter': self._analyze_scatter_chart
}
def analyze_chart(self, image: Image.Image) -> Dict[str, Any]:
"""分析图表"""
# 检测图表类型
chart_type = self._detect_chart_type(image)
# 提取图表信息
chart_info = {
'type': chart_type,
'title': self._extract_title(image),
'axes': self._extract_axes_info(image),
'data_points': self._extract_data_points(image),
'legend': self._extract_legend(image)
}
# 特定类型分析
if chart_type in self.chart_types:
specific_info = self.chart_types[chart_type](image)
chart_info.update(specific_info)
return chart_info
def _detect_chart_type(self, image: Image.Image) -> str:
"""检测图表类型"""
# 简化实现,实际应使用图像分类模型
return 'bar' # 占位符
def _extract_title(self, image: Image.Image) -> str:
"""提取图表标题"""
# 使用OCR提取标题
return "Sample Chart Title" # 占位符
def _extract_axes_info(self, image: Image.Image) -> Dict[str, str]:
"""提取坐标轴信息"""
return {
'x_axis': 'X Axis Label',
'y_axis': 'Y Axis Label'
} # 占位符
def _extract_data_points(self, image: Image.Image) -> List[Dict]:
"""提取数据点"""
return [
{'x': 1, 'y': 10, 'label': 'Point 1'},
{'x': 2, 'y': 20, 'label': 'Point 2'}
] # 占位符
def _extract_legend(self, image: Image.Image) -> List[str]:
"""提取图例"""
return ['Series 1', 'Series 2'] # 占位符
def _analyze_bar_chart(self, image: Image.Image) -> Dict[str, Any]:
"""分析柱状图"""
return {
'bars': [
{'category': 'A', 'value': 10},
{'category': 'B', 'value': 20}
],
'orientation': 'vertical'
}
def _analyze_line_chart(self, image: Image.Image) -> Dict[str, Any]:
"""分析折线图"""
return {
'lines': [
{'points': [(1, 10), (2, 20), (3, 15)]}
],
'trend': 'increasing'
}
def _analyze_pie_chart(self, image: Image.Image) -> Dict[str, Any]:
"""分析饼图"""
return {
'slices': [
{'label': 'A', 'value': 30, 'percentage': 30},
{'label': 'B', 'value': 70, 'percentage': 70}
]
}
def _analyze_scatter_chart(self, image: Image.Image) -> Dict[str, Any]:
"""分析散点图"""
return {
'points': [(1, 10), (2, 20), (3, 15)],
'correlation': 'positive'
}
class ContextManager:
"""上下文管理器"""
def __init__(self):
self.context = {
'chart_history': [],
'query_history': [],
'schema_context': {},
'user_preferences': {}
}
def update_context(self, chart_info: Dict[str, Any]):
"""更新上下文"""
self.context['chart_history'].append(chart_info)
# 保持历史记录大小
if len(self.context['chart_history']) > 10:
self.context['chart_history'].pop(0)
def get_relevant_context(self, query: str) -> Dict[str, Any]:
"""获取相关上下文"""
# 基于查询内容返回相关上下文
return {
'recent_charts': self.context['chart_history'][-3:],
'related_queries': self._find_related_queries(query)
}
def _find_related_queries(self, query: str) -> List[str]:
"""查找相关查询"""
# 简化实现
return [q for q in self.context['query_history'] if len(set(query.split()) & set(q.split())) > 1]
print("多模态Text2SQL实现完成")
10.1.2 Speech-to-SQL Conversion
import speech_recognition as sr
import pyttsx3
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import torch
import numpy as np
from typing import Dict, List, Any, Optional
class SpeechToSQLSystem:
"""语音到SQL系统"""
def __init__(self, config: Dict[str, Any]):
self.config = config
# 语音识别
self.recognizer = sr.Recognizer()
self.microphone = sr.Microphone()
# 语音合成
self.tts_engine = pyttsx3.init()
self._setup_tts()
# 深度语音识别模型
self.speech_model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base-960h"
)
self.speech_tokenizer = Wav2Vec2Tokenizer.from_pretrained(
"facebook/wav2vec2-base-960h"
)
# Text2SQL模型
self.text2sql_model = None # 从前面章节导入
# 语音处理器
self.speech_processor = SpeechProcessor()
# 对话管理器
self.dialog_manager = DialogManager()
def _setup_tts(self):
"""设置语音合成"""
voices = self.tts_engine.getProperty('voices')
if voices:
self.tts_engine.setProperty('voice', voices[0].id)
self.tts_engine.setProperty('rate', 150) # 语速
self.tts_engine.setProperty('volume', 0.8) # 音量
def listen_and_process(self, timeout: int = 5) -> Dict[str, Any]:
"""监听并处理语音"""
result = {
'success': False,
'transcription': None,
'sql': None,
'response': None
}
try:
# 监听语音
with self.microphone as source:
self.recognizer.adjust_for_ambient_noise(source)
print("请说话...")
audio = self.recognizer.listen(source, timeout=timeout)
# 语音识别
transcription = self._transcribe_audio(audio)
result['transcription'] = transcription
if transcription:
# 处理自然语言查询
sql_result = self._process_natural_language(transcription)
result.update(sql_result)
# 生成语音响应
response = self._generate_response(sql_result)
result['response'] = response
# 语音输出
self.speak(response)
result['success'] = True
except sr.WaitTimeoutError:
result['error'] = '语音输入超时'
except sr.UnknownValueError:
result['error'] = '无法识别语音'
except sr.RequestError as e:
result['error'] = f'语音识别服务错误: {e}'
except Exception as e:
result['error'] = f'处理错误: {e}'
return result
def _transcribe_audio(self, audio) -> str:
"""转录音频"""
try:
# 使用Google语音识别
text = self.recognizer.recognize_google(audio, language='zh-CN')
# 使用深度学习模型进行后处理
enhanced_text = self._enhance_transcription(text, audio)
return enhanced_text
except Exception as e:
print(f"转录错误: {e}")
return None
def _enhance_transcription(self, text: str, audio) -> str:
"""增强转录结果"""
try:
            # 转换音频格式(假设音频为16kHz单声道;否则需先重采样)
            audio_array = np.frombuffer(audio.get_raw_data(), dtype=np.int16)
            audio_array = audio_array.astype(np.float32) / 32768.0
# 使用Wav2Vec2模型
inputs = self.speech_tokenizer(
audio_array,
sampling_rate=16000,
return_tensors="pt"
)
with torch.no_grad():
logits = self.speech_model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = self.speech_tokenizer.decode(predicted_ids[0])
# 结合两种结果
enhanced_text = self._combine_transcriptions(text, transcription)
return enhanced_text
except Exception as e:
print(f"增强转录错误: {e}")
return text
def _combine_transcriptions(self, google_text: str, wav2vec_text: str) -> str:
"""结合多种转录结果"""
# 简化实现:选择较长的结果
if len(wav2vec_text) > len(google_text):
return wav2vec_text
return google_text
def _process_natural_language(self, text: str) -> Dict[str, Any]:
"""处理自然语言查询"""
# 预处理文本
processed_text = self.speech_processor.preprocess_speech_text(text)
# 意图识别
intent = self.speech_processor.recognize_intent(processed_text)
# 生成SQL
if self.text2sql_model:
sql_result = self.text2sql_model.generate_sql(processed_text)
else:
# 简化实现
sql_result = {
'sql': 'SELECT * FROM table;',
'confidence': 0.8
}
return {
'processed_text': processed_text,
'intent': intent,
'sql': sql_result.get('sql'),
'confidence': sql_result.get('confidence', 0.0)
}
def _generate_response(self, sql_result: Dict[str, Any]) -> str:
"""生成响应"""
if sql_result.get('sql'):
confidence = sql_result.get('confidence', 0.0)
if confidence > 0.8:
return f"我理解了您的查询,生成的SQL是:{sql_result['sql']}"
elif confidence > 0.5:
return f"我认为您想要的SQL可能是:{sql_result['sql']},请确认是否正确。"
else:
return "抱歉,我不太确定您的查询意图,请重新描述或提供更多信息。"
else:
return "抱歉,我无法理解您的查询,请重新尝试。"
def speak(self, text: str):
"""语音输出"""
try:
self.tts_engine.say(text)
self.tts_engine.runAndWait()
except Exception as e:
print(f"语音输出错误: {e}")
def start_conversation(self):
"""开始对话"""
self.speak("您好,我是Text2SQL助手,请告诉我您想要查询什么数据。")
while True:
try:
result = self.listen_and_process()
if result['success']:
print(f"识别文本: {result['transcription']}")
print(f"生成SQL: {result['sql']}")
# 检查是否要退出
if any(word in result['transcription'].lower() for word in ['退出', '结束', '再见']):
self.speak("再见,感谢使用Text2SQL助手!")
break
else:
error_msg = result.get('error', '未知错误')
print(f"错误: {error_msg}")
self.speak("抱歉,出现了错误,请重试。")
except KeyboardInterrupt:
self.speak("再见!")
break
class SpeechProcessor:
"""语音处理器"""
def __init__(self):
self.intent_patterns = {
'select': ['查询', '显示', '找', '看', '获取'],
'count': ['统计', '计算', '数量', '个数'],
'filter': ['筛选', '过滤', '条件', '满足'],
'sort': ['排序', '排列', '最大', '最小', '最高', '最低'],
'group': ['分组', '按照', '每个', '各个']
}
def preprocess_speech_text(self, text: str) -> str:
"""预处理语音文本"""
# 移除填充词
filler_words = ['嗯', '啊', '那个', '这个', '就是']
for word in filler_words:
text = text.replace(word, '')
# 标准化数字
text = self._normalize_numbers(text)
# 标准化时间表达
text = self._normalize_time_expressions(text)
return text.strip()
def _normalize_numbers(self, text: str) -> str:
"""标准化数字表达"""
number_map = {
'一': '1', '二': '2', '三': '3', '四': '4', '五': '5',
'六': '6', '七': '7', '八': '8', '九': '9', '十': '10'
}
for chinese, arabic in number_map.items():
text = text.replace(chinese, arabic)
return text
def _normalize_time_expressions(self, text: str) -> str:
"""标准化时间表达"""
time_map = {
'今天': 'today',
'昨天': 'yesterday',
'明天': 'tomorrow',
'上个月': 'last month',
'这个月': 'this month',
'去年': 'last year',
'今年': 'this year'
}
for chinese, english in time_map.items():
text = text.replace(chinese, english)
return text
def recognize_intent(self, text: str) -> Dict[str, float]:
"""识别意图"""
intent_scores = {}
for intent, keywords in self.intent_patterns.items():
score = 0
for keyword in keywords:
if keyword in text:
score += 1
intent_scores[intent] = score / len(keywords)
return intent_scores
class DialogManager:
"""对话管理器"""
def __init__(self):
self.conversation_history = []
self.context = {}
self.user_preferences = {}
def add_turn(self, user_input: str, system_response: str, sql: str = None):
"""添加对话轮次"""
turn = {
'timestamp': self._get_timestamp(),
'user_input': user_input,
'system_response': system_response,
'sql': sql
}
self.conversation_history.append(turn)
# 保持历史记录大小
if len(self.conversation_history) > 20:
self.conversation_history.pop(0)
def get_context(self) -> Dict[str, Any]:
"""获取对话上下文"""
return {
'recent_queries': self.conversation_history[-3:],
'user_preferences': self.user_preferences,
'session_context': self.context
}
def _get_timestamp(self) -> str:
"""获取时间戳"""
from datetime import datetime
return datetime.now().isoformat()
print("语音到SQL系统实现完成")
10.2 Large Language Models and Text2SQL
10.2.1 GPT-Based Text2SQL
import openai
import json
from typing import Dict, List, Any, Optional
import re
from dataclasses import dataclass
@dataclass
class LLMConfig:
"""大语言模型配置"""
model_name: str = "gpt-3.5-turbo"
api_key: str = ""
max_tokens: int = 1000
temperature: float = 0.1
top_p: float = 0.9
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
class GPTText2SQL:
"""基于GPT的Text2SQL系统"""
def __init__(self, config: LLMConfig):
self.config = config
openai.api_key = config.api_key
# 提示词模板
self.prompt_templates = {
'basic': self._load_basic_template(),
'few_shot': self._load_few_shot_template(),
'chain_of_thought': self._load_cot_template(),
'schema_aware': self._load_schema_aware_template()
}
# 示例管理器
self.example_manager = ExampleManager()
# 结果验证器
self.validator = SQLValidator()
def _load_basic_template(self) -> str:
"""基础提示词模板"""
return """
你是一个专业的SQL查询生成助手。请根据给定的自然语言问题和数据库Schema生成相应的SQL查询。
数据库Schema:
{schema}
问题: {question}
请生成对应的SQL查询,只返回SQL语句,不要包含其他解释。
SQL:
"""
def _load_few_shot_template(self) -> str:
"""少样本学习模板"""
return """
你是一个专业的SQL查询生成助手。以下是一些示例:
{examples}
现在请根据以下信息生成SQL查询:
数据库Schema:
{schema}
问题: {question}
SQL:
"""
def _load_cot_template(self) -> str:
"""思维链模板"""
return """
你是一个专业的SQL查询生成助手。请按照以下步骤分析问题并生成SQL:
1. 理解问题的核心需求
2. 识别需要的表和列
3. 确定查询条件
4. 构建SQL查询
数据库Schema:
{schema}
问题: {question}
请按步骤分析:
步骤1 - 理解需求:
步骤2 - 识别表和列:
步骤3 - 确定条件:
步骤4 - 生成SQL:
"""
def _load_schema_aware_template(self) -> str:
"""Schema感知模板"""
return """
你是一个专业的SQL查询生成助手。请仔细分析数据库Schema,理解表之间的关系。
数据库Schema详情:
{detailed_schema}
表关系:
{relationships}
问题: {question}
请生成准确的SQL查询,注意:
1. 使用正确的表名和列名
2. 考虑表之间的关联关系
3. 使用适当的JOIN操作
4. 确保查询语法正确
SQL:
"""
def generate_sql(self,
question: str,
schema: Dict[str, Any],
method: str = 'basic',
examples: List[Dict] = None) -> Dict[str, Any]:
"""生成SQL查询"""
try:
# 选择提示词模板
template = self.prompt_templates.get(method, self.prompt_templates['basic'])
# 构建提示词
prompt = self._build_prompt(template, question, schema, method, examples)
# 调用GPT API
response = self._call_gpt_api(prompt)
# 提取SQL
sql = self._extract_sql(response)
# 验证SQL
validation_result = self.validator.validate_sql(sql, schema)
return {
'sql': sql,
'raw_response': response,
'prompt': prompt,
'validation': validation_result,
'method': method
}
except Exception as e:
return {
'error': str(e),
'sql': None,
'method': method
}
def _build_prompt(self,
template: str,
question: str,
schema: Dict[str, Any],
method: str,
examples: List[Dict] = None) -> str:
"""构建提示词"""
if method == 'basic':
return template.format(
schema=self._format_schema(schema),
question=question
)
elif method == 'few_shot':
if not examples:
examples = self.example_manager.get_relevant_examples(question, schema)
formatted_examples = self._format_examples(examples)
return template.format(
examples=formatted_examples,
schema=self._format_schema(schema),
question=question
)
elif method == 'chain_of_thought':
return template.format(
schema=self._format_schema(schema),
question=question
)
elif method == 'schema_aware':
detailed_schema = self._format_detailed_schema(schema)
relationships = self._extract_relationships(schema)
return template.format(
detailed_schema=detailed_schema,
relationships=relationships,
question=question
)
else:
return template.format(
schema=self._format_schema(schema),
question=question
)
def _format_schema(self, schema: Dict[str, Any]) -> str:
"""格式化Schema"""
formatted_tables = []
for table_name, table_info in schema.items():
columns = table_info.get('columns', [])
if isinstance(columns, list):
column_str = ', '.join(columns)
else:
column_str = ', '.join([f"{col['name']} ({col['type']})" for col in columns])
formatted_tables.append(f"Table {table_name}: {column_str}")
return '\n'.join(formatted_tables)
def _format_detailed_schema(self, schema: Dict[str, Any]) -> str:
"""格式化详细Schema"""
formatted_tables = []
for table_name, table_info in schema.items():
table_desc = f"Table: {table_name}\n"
columns = table_info.get('columns', [])
if columns:
table_desc += "Columns:\n"
for col in columns:
if isinstance(col, dict):
col_desc = f" - {col['name']} ({col.get('type', 'unknown')})"
if col.get('primary_key'):
col_desc += " [PRIMARY KEY]"
if col.get('foreign_key'):
col_desc += f" [FOREIGN KEY -> {col['foreign_key']}]"
table_desc += col_desc + "\n"
else:
table_desc += f" - {col}\n"
formatted_tables.append(table_desc)
return '\n'.join(formatted_tables)
def _extract_relationships(self, schema: Dict[str, Any]) -> str:
"""提取表关系"""
relationships = []
for table_name, table_info in schema.items():
columns = table_info.get('columns', [])
for col in columns:
if isinstance(col, dict) and col.get('foreign_key'):
relationships.append(
f"{table_name}.{col['name']} -> {col['foreign_key']}"
)
return '\n'.join(relationships) if relationships else "No explicit relationships defined"
def _format_examples(self, examples: List[Dict]) -> str:
"""格式化示例"""
formatted_examples = []
for i, example in enumerate(examples, 1):
formatted_examples.append(
f"示例{i}:\n"
f"问题: {example['question']}\n"
f"SQL: {example['sql']}\n"
)
return '\n'.join(formatted_examples)
def _call_gpt_api(self, prompt: str) -> str:
"""调用GPT API"""
response = openai.ChatCompletion.create(
model=self.config.model_name,
messages=[
{"role": "user", "content": prompt}
],
max_tokens=self.config.max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty
)
return response.choices[0].message.content.strip()
def _extract_sql(self, response: str) -> str:
"""从响应中提取SQL"""
# 查找SQL代码块
sql_pattern = r'```sql\s*([^`]+)\s*```'
match = re.search(sql_pattern, response, re.IGNORECASE | re.DOTALL)
if match:
return match.group(1).strip()
# 查找以SELECT开头的语句
lines = response.split('\n')
for line in lines:
line = line.strip()
if line.upper().startswith('SELECT'):
return line
# 返回整个响应(可能包含SQL)
return response.strip()
class ExampleManager:
"""示例管理器"""
def __init__(self):
self.examples = self._load_examples()
self.embeddings = None # 可以使用embedding进行相似度匹配
def _load_examples(self) -> List[Dict]:
"""加载示例"""
return [
{
'question': '显示所有用户的姓名和邮箱',
'sql': 'SELECT name, email FROM users;',
'schema': ['users(id, name, email, age)']
},
{
'question': '查找年龄大于25的用户数量',
'sql': 'SELECT COUNT(*) FROM users WHERE age > 25;',
'schema': ['users(id, name, email, age)']
},
{
'question': '按部门统计员工数量',
'sql': 'SELECT department, COUNT(*) FROM employees GROUP BY department;',
'schema': ['employees(id, name, department, salary)']
}
]
def get_relevant_examples(self,
question: str,
schema: Dict[str, Any],
top_k: int = 3) -> List[Dict]:
"""获取相关示例"""
# 简化实现:基于关键词匹配
question_words = set(question.lower().split())
scored_examples = []
for example in self.examples:
example_words = set(example['question'].lower().split())
similarity = len(question_words & example_words) / len(question_words | example_words)
scored_examples.append((similarity, example))
# 按相似度排序
scored_examples.sort(key=lambda x: x[0], reverse=True)
return [example for _, example in scored_examples[:top_k]]
class SQLValidator:
"""SQL验证器"""
def validate_sql(self, sql: str, schema: Dict[str, Any]) -> Dict[str, Any]:
"""验证SQL查询"""
validation_result = {
'is_valid': True,
'errors': [],
'warnings': [],
'suggestions': []
}
try:
# 基本语法检查
syntax_errors = self._check_syntax(sql)
validation_result['errors'].extend(syntax_errors)
# Schema一致性检查
schema_errors = self._check_schema_consistency(sql, schema)
validation_result['errors'].extend(schema_errors)
# 性能建议
performance_warnings = self._check_performance(sql)
validation_result['warnings'].extend(performance_warnings)
# 设置验证状态
validation_result['is_valid'] = len(validation_result['errors']) == 0
except Exception as e:
validation_result['is_valid'] = False
validation_result['errors'].append(f"验证过程出错: {str(e)}")
return validation_result
def _check_syntax(self, sql: str) -> List[str]:
"""检查SQL语法"""
errors = []
# 基本关键词检查
sql_upper = sql.upper()
if not any(keyword in sql_upper for keyword in ['SELECT', 'INSERT', 'UPDATE', 'DELETE']):
errors.append("SQL语句缺少主要操作关键词")
# 括号匹配检查
if sql.count('(') != sql.count(')'):
errors.append("括号不匹配")
# 引号匹配检查
single_quotes = sql.count("'")
double_quotes = sql.count('"')
if single_quotes % 2 != 0:
errors.append("单引号不匹配")
if double_quotes % 2 != 0:
errors.append("双引号不匹配")
return errors
def _check_schema_consistency(self, sql: str, schema: Dict[str, Any]) -> List[str]:
"""检查Schema一致性"""
errors = []
# 提取SQL中的表名
table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
table_matches = re.findall(table_pattern, sql, re.IGNORECASE)
for match in table_matches:
table_name = match[0] or match[1]
if table_name and table_name not in schema:
errors.append(f"表 '{table_name}' 在Schema中不存在")
return errors
def _check_performance(self, sql: str) -> List[str]:
"""检查性能问题"""
warnings = []
sql_upper = sql.upper()
# 检查SELECT *
if 'SELECT *' in sql_upper:
warnings.append("建议避免使用SELECT *,明确指定需要的列")
# 检查缺少WHERE子句的DELETE/UPDATE
if ('DELETE FROM' in sql_upper or 'UPDATE' in sql_upper) and 'WHERE' not in sql_upper:
warnings.append("DELETE或UPDATE语句缺少WHERE条件,可能影响所有记录")
        # 检查可能的笛卡尔积(FROM子句以逗号连接多表且无JOIN/WHERE)
        if re.search(r'FROM\s+\w+\s*,', sql, re.IGNORECASE) and 'JOIN' not in sql_upper and 'WHERE' not in sql_upper:
            warnings.append("可能存在笛卡尔积,建议使用JOIN或添加WHERE条件")
return warnings
print("基于GPT的Text2SQL系统实现完成")
10.2.2 Prompt Engineering Optimization
from typing import Dict, List, Any, Optional, Tuple
import json
import re
from dataclasses import dataclass
from abc import ABC, abstractmethod
@dataclass
class PromptTemplate:
"""提示词模板"""
name: str
template: str
variables: List[str]
description: str
examples: List[Dict] = None
class PromptOptimizer:
"""提示词优化器"""
def __init__(self):
self.templates = self._initialize_templates()
self.optimization_strategies = {
'length_optimization': self._optimize_length,
'clarity_enhancement': self._enhance_clarity,
'example_selection': self._optimize_examples,
'structure_improvement': self._improve_structure
}
def _initialize_templates(self) -> Dict[str, PromptTemplate]:
"""初始化提示词模板"""
templates = {}
# 零样本模板
templates['zero_shot'] = PromptTemplate(
name="zero_shot",
template="""
你是一个SQL专家。根据给定的数据库Schema和自然语言问题,生成准确的SQL查询。
数据库Schema:
{schema}
问题: {question}
要求:
1. 只返回SQL查询语句
2. 确保语法正确
3. 使用标准SQL语法
SQL:
""",
variables=["schema", "question"],
description="基础零样本提示词"
)
# 少样本模板
templates['few_shot'] = PromptTemplate(
name="few_shot",
template="""
你是一个SQL专家。以下是一些示例,请学习其中的模式:
{examples}
现在请根据以下信息生成SQL查询:
数据库Schema:
{schema}
问题: {question}
SQL:
""",
variables=["examples", "schema", "question"],
description="少样本学习提示词"
)
# 思维链模板
templates['chain_of_thought'] = PromptTemplate(
name="chain_of_thought",
template="""
你是一个SQL专家。请按照以下步骤分析并生成SQL查询:
步骤1: 理解问题 - 分析用户想要什么信息
步骤2: 识别表和列 - 确定需要哪些表和列
步骤3: 确定关系 - 分析表之间的关联
步骤4: 构建查询 - 组装完整的SQL语句
数据库Schema:
{schema}
问题: {question}
请按步骤分析:
步骤1 - 理解问题:
步骤2 - 识别表和列:
步骤3 - 确定关系:
步骤4 - 构建查询:
""",
variables=["schema", "question"],
description="思维链推理提示词"
)
# 角色扮演模板
templates['role_playing'] = PromptTemplate(
name="role_playing",
template="""
你是一位经验丰富的数据库管理员,拥有20年的SQL开发经验。
你的任务是帮助用户将自然语言问题转换为高效、准确的SQL查询。
你的专业特点:
- 深度理解数据库设计原理
- 熟练掌握SQL优化技巧
- 注重查询性能和准确性
- 善于处理复杂的多表关联
当前数据库Schema:
{schema}
用户问题: {question}
请以专业DBA的角度,生成最优的SQL查询:
""",
variables=["schema", "question"],
description="角色扮演提示词"
)
return templates
def optimize_prompt(self,
template_name: str,
variables: Dict[str, Any],
strategies: List[str] = None) -> str:
"""优化提示词"""
if template_name not in self.templates:
raise ValueError(f"未知的模板: {template_name}")
template = self.templates[template_name]
prompt = template.template.format(**variables)
# 应用优化策略
if strategies:
for strategy in strategies:
if strategy in self.optimization_strategies:
prompt = self.optimization_strategies[strategy](prompt, variables)
return prompt
def _optimize_length(self, prompt: str, variables: Dict[str, Any]) -> str:
"""优化提示词长度"""
# 移除多余的空行
prompt = re.sub(r'\n\s*\n\s*\n', '\n\n', prompt)
# 简化重复的说明
prompt = prompt.replace('请生成SQL查询', '生成SQL')
prompt = prompt.replace('根据给定的', '根据')
return prompt
def _enhance_clarity(self, prompt: str, variables: Dict[str, Any]) -> str:
"""增强提示词清晰度"""
# 添加明确的输出格式说明
if 'SQL:' in prompt and '输出格式' not in prompt:
format_instruction = "\n输出格式: 只返回SQL语句,不包含解释或其他文本。\n"
prompt = prompt.replace('SQL:', format_instruction + 'SQL:')
return prompt
def _optimize_examples(self, prompt: str, variables: Dict[str, Any]) -> str:
"""优化示例选择"""
if 'examples' in variables:
# 这里可以实现智能示例选择逻辑
pass
return prompt
def _improve_structure(self, prompt: str, variables: Dict[str, Any]) -> str:
"""改进提示词结构"""
# 确保关键信息突出显示
if 'Schema:' in prompt:
prompt = prompt.replace('Schema:', '**数据库Schema:**')
if '问题:' in prompt:
prompt = prompt.replace('问题:', '**用户问题:**')
return prompt
class AdaptivePromptGenerator:
"""自适应提示词生成器"""
def __init__(self):
self.prompt_optimizer = PromptOptimizer()
self.performance_tracker = PromptPerformanceTracker()
self.difficulty_analyzer = QueryDifficultyAnalyzer()
def generate_adaptive_prompt(self,
question: str,
schema: Dict[str, Any],
context: Dict[str, Any] = None) -> str:
"""生成自适应提示词"""
# 分析查询难度
difficulty = self.difficulty_analyzer.analyze_difficulty(question, schema)
# 根据难度选择模板
template_name = self._select_template_by_difficulty(difficulty)
# 准备变量
variables = {
'question': question,
'schema': self._format_schema(schema)
}
# 添加示例(如果需要)
if template_name == 'few_shot':
examples = self._select_relevant_examples(question, schema, difficulty)
variables['examples'] = self._format_examples(examples)
# 生成优化的提示词
optimization_strategies = self._select_optimization_strategies(difficulty)
prompt = self.prompt_optimizer.optimize_prompt(
template_name, variables, optimization_strategies
)
return prompt
def _select_template_by_difficulty(self, difficulty: Dict[str, float]) -> str:
"""根据难度选择模板"""
complexity_score = difficulty.get('complexity', 0.5)
if complexity_score < 0.3:
return 'zero_shot'
elif complexity_score < 0.7:
return 'few_shot'
else:
return 'chain_of_thought'
def _select_optimization_strategies(self, difficulty: Dict[str, float]) -> List[str]:
"""选择优化策略"""
strategies = ['clarity_enhancement']
complexity_score = difficulty.get('complexity', 0.5)
if complexity_score > 0.6:
strategies.append('structure_improvement')
if complexity_score < 0.4:
strategies.append('length_optimization')
return strategies
def _format_schema(self, schema: Dict[str, Any]) -> str:
"""格式化Schema"""
formatted_tables = []
for table_name, table_info in schema.items():
columns = table_info.get('columns', [])
if isinstance(columns, list) and columns:
if isinstance(columns[0], dict):
# 详细列信息
col_details = []
for col in columns:
col_str = f"{col['name']} ({col.get('type', 'unknown')})"
if col.get('primary_key'):
col_str += " [PK]"
if col.get('foreign_key'):
col_str += f" [FK -> {col['foreign_key']}]"
col_details.append(col_str)
column_str = ', '.join(col_details)
else:
# 简单列名
column_str = ', '.join(columns)
formatted_tables.append(f"Table {table_name}: {column_str}")
return '\n'.join(formatted_tables)
def _select_relevant_examples(self,
question: str,
schema: Dict[str, Any],
difficulty: Dict[str, float]) -> List[Dict]:
"""选择相关示例"""
# 这里应该实现智能示例选择逻辑
# 简化实现
return [
{
'question': '查找所有用户的姓名',
'sql': 'SELECT name FROM users;'
},
{
'question': '统计每个部门的员工数量',
'sql': 'SELECT department, COUNT(*) FROM employees GROUP BY department;'
}
]
def _format_examples(self, examples: List[Dict]) -> str:
"""格式化示例"""
formatted = []
for i, example in enumerate(examples, 1):
formatted.append(
f"示例{i}:\n"
f"问题: {example['question']}\n"
f"SQL: {example['sql']}\n"
)
return '\n'.join(formatted)
class QueryDifficultyAnalyzer:
"""查询难度分析器"""
def analyze_difficulty(self, question: str, schema: Dict[str, Any]) -> Dict[str, float]:
"""分析查询难度"""
difficulty_factors = {
'complexity': self._analyze_complexity(question),
'ambiguity': self._analyze_ambiguity(question),
'schema_complexity': self._analyze_schema_complexity(schema),
'join_complexity': self._analyze_join_complexity(question, schema)
}
# 计算总体难度
overall_difficulty = sum(difficulty_factors.values()) / len(difficulty_factors)
difficulty_factors['overall'] = overall_difficulty
return difficulty_factors
def _analyze_complexity(self, question: str) -> float:
"""分析问题复杂度"""
complexity_indicators = {
'aggregation': ['统计', '计算', '平均', '最大', '最小', '总和'],
'grouping': ['每个', '按照', '分组', '各个'],
'filtering': ['条件', '满足', '大于', '小于', '等于'],
'sorting': ['排序', '最高', '最低', 'top', '前'],
'joining': ['关联', '连接', '对应', '相关']
}
complexity_score = 0
question_lower = question.lower()
for category, indicators in complexity_indicators.items():
if any(indicator in question_lower for indicator in indicators):
complexity_score += 0.2
return min(complexity_score, 1.0)
def _analyze_ambiguity(self, question: str) -> float:
"""分析问题歧义性"""
ambiguous_terms = ['这个', '那个', '相关', '合适', '好的', '最佳']
question_lower = question.lower()
ambiguity_count = sum(1 for term in ambiguous_terms if term in question_lower)
return min(ambiguity_count * 0.3, 1.0)
def _analyze_schema_complexity(self, schema: Dict[str, Any]) -> float:
"""分析Schema复杂度"""
table_count = len(schema)
total_columns = sum(len(table_info.get('columns', [])) for table_info in schema.values())
# 基于表数量和列数量计算复杂度
complexity = (table_count * 0.1 + total_columns * 0.02)
return min(complexity, 1.0)
def _analyze_join_complexity(self, question: str, schema: Dict[str, Any]) -> float:
"""分析JOIN复杂度"""
# 简化实现:基于问题中提到的表数量
mentioned_tables = 0
question_lower = question.lower()
for table_name in schema.keys():
if table_name.lower() in question_lower:
mentioned_tables += 1
if mentioned_tables > 1:
return min(mentioned_tables * 0.3, 1.0)
return 0.0
class PromptPerformanceTracker:
"""提示词性能跟踪器"""
def __init__(self):
self.performance_history = []
self.template_stats = {}
def record_performance(self,
template_name: str,
prompt: str,
question: str,
generated_sql: str,
execution_success: bool,
accuracy_score: float):
"""记录性能数据"""
record = {
'timestamp': self._get_timestamp(),
'template_name': template_name,
'prompt_length': len(prompt),
'question': question,
'generated_sql': generated_sql,
'execution_success': execution_success,
'accuracy_score': accuracy_score
}
self.performance_history.append(record)
self._update_template_stats(template_name, record)
def _update_template_stats(self, template_name: str, record: Dict):
"""更新模板统计信息"""
if template_name not in self.template_stats:
self.template_stats[template_name] = {
'total_uses': 0,
'success_count': 0,
'accuracy_sum': 0.0,
'avg_accuracy': 0.0,
'success_rate': 0.0
}
stats = self.template_stats[template_name]
stats['total_uses'] += 1
if record['execution_success']:
stats['success_count'] += 1
stats['accuracy_sum'] += record['accuracy_score']
stats['avg_accuracy'] = stats['accuracy_sum'] / stats['total_uses']
stats['success_rate'] = stats['success_count'] / stats['total_uses']
def get_best_template(self, criteria: str = 'accuracy') -> str:
"""获取最佳模板"""
if not self.template_stats:
return 'zero_shot' # 默认模板
if criteria == 'accuracy':
best_template = max(self.template_stats.items(),
key=lambda x: x[1]['avg_accuracy'])
elif criteria == 'success_rate':
best_template = max(self.template_stats.items(),
key=lambda x: x[1]['success_rate'])
else:
best_template = max(self.template_stats.items(),
key=lambda x: x[1]['avg_accuracy'])
return best_template[0]
def _get_timestamp(self) -> str:
"""获取时间戳"""
from datetime import datetime
return datetime.now().isoformat()
print("提示词工程优化实现完成")
10.3 Reinforcement Learning and Text2SQL
10.3.1 SQL Generation with Reinforcement Learning
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, List, Any, Tuple, Optional
from collections import deque
import random
from dataclasses import dataclass
@dataclass
class RLConfig:
"""强化学习配置"""
learning_rate: float = 0.001
gamma: float = 0.99 # 折扣因子
epsilon: float = 0.1 # 探索率
epsilon_decay: float = 0.995
epsilon_min: float = 0.01
memory_size: int = 10000
batch_size: int = 32
target_update_freq: int = 100
class Text2SQLEnvironment:
"""Text2SQL强化学习环境"""
def __init__(self, schema: Dict[str, Any], sql_executor):
self.schema = schema
self.sql_executor = sql_executor
self.current_question = None
self.current_state = None
self.target_result = None
# 动作空间定义
self.action_space = self._build_action_space()
# 状态空间定义
self.state_dim = self._calculate_state_dim()
# 奖励函数参数
self.reward_weights = {
'execution_success': 10.0,
'result_accuracy': 20.0,
'syntax_correctness': 5.0,
'efficiency': 3.0,
'step_penalty': -0.1
}
def _build_action_space(self) -> Dict[str, List[str]]:
"""构建动作空间"""
return {
'select_clause': ['SELECT', 'SELECT DISTINCT'],
'columns': self._get_all_columns(),
'from_clause': ['FROM'],
'tables': list(self.schema.keys()),
'where_clause': ['WHERE'],
'conditions': self._get_condition_templates(),
'group_by': ['GROUP BY'],
'having': ['HAVING'],
'order_by': ['ORDER BY', 'ORDER BY DESC'],
'limit': ['LIMIT'],
'join_types': ['INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN'],
'aggregations': ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN'],
'operators': ['=', '>', '<', '>=', '<=', '!=', 'LIKE', 'IN'],
'logical': ['AND', 'OR', 'NOT']
}
def _get_all_columns(self) -> List[str]:
"""获取所有列名"""
columns = []
for table_name, table_info in self.schema.items():
table_columns = table_info.get('columns', [])
for col in table_columns:
if isinstance(col, dict):
columns.append(f"{table_name}.{col['name']}")
else:
columns.append(f"{table_name}.{col}")
return columns
def _get_condition_templates(self) -> List[str]:
"""获取条件模板"""
return [
"{column} {operator} {value}",
"{column} BETWEEN {value1} AND {value2}",
"{column} IN ({values})",
"{column} IS NULL",
"{column} IS NOT NULL"
]
def _calculate_state_dim(self) -> int:
"""计算状态维度"""
# 简化实现:基于问题编码 + Schema编码 + 当前SQL状态
return 512 # 假设使用512维向量表示状态
def reset(self, question: str, target_result: Any = None) -> np.ndarray:
"""重置环境"""
self.current_question = question
self.target_result = target_result
self.current_state = self._encode_initial_state(question)
return self.current_state
def _encode_initial_state(self, question: str) -> np.ndarray:
"""编码初始状态"""
# 简化实现:使用随机向量表示状态
# 实际应该使用问题编码 + Schema编码
return np.random.randn(self.state_dim)
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""执行动作"""
# 解码动作
decoded_action = self._decode_action(action)
# 更新SQL状态
sql_fragment = self._action_to_sql(decoded_action)
# 计算奖励
reward = self._calculate_reward(sql_fragment, decoded_action)
# 检查是否完成
done = self._is_episode_done(sql_fragment)
# 更新状态
next_state = self._update_state(sql_fragment)
info = {
'action': decoded_action,
'sql_fragment': sql_fragment,
'reward_breakdown': self._get_reward_breakdown(sql_fragment)
}
return next_state, reward, done, info
def _decode_action(self, action: int) -> Dict[str, Any]:
"""解码动作"""
# 简化实现:将动作ID映射到具体的SQL组件
action_types = list(self.action_space.keys())
action_type = action_types[action % len(action_types)]
action_values = self.action_space[action_type]
action_value = action_values[action % len(action_values)]
return {
'type': action_type,
'value': action_value
}
def _action_to_sql(self, action: Dict[str, Any]) -> str:
"""将动作转换为SQL片段"""
action_type = action['type']
action_value = action['value']
if action_type == 'select_clause':
return action_value
elif action_type == 'columns':
return action_value
elif action_type == 'from_clause':
return action_value
elif action_type == 'tables':
return action_value
else:
return action_value
def _calculate_reward(self, sql_fragment: str, action: Dict[str, Any]) -> float:
"""计算奖励"""
reward = 0.0
# 步骤惩罚
reward += self.reward_weights['step_penalty']
# 语法正确性奖励
if self._is_syntactically_correct(sql_fragment):
reward += self.reward_weights['syntax_correctness']
# 如果是完整的SQL,计算执行奖励
if self._is_complete_sql(sql_fragment):
try:
result = self.sql_executor.execute(sql_fragment)
reward += self.reward_weights['execution_success']
# 结果准确性奖励
if self.target_result is not None:
accuracy = self._calculate_result_accuracy(result, self.target_result)
reward += self.reward_weights['result_accuracy'] * accuracy
# 效率奖励
efficiency_score = self._calculate_efficiency(sql_fragment)
reward += self.reward_weights['efficiency'] * efficiency_score
except Exception:
reward -= 5.0 # 执行失败惩罚
return reward
def _is_syntactically_correct(self, sql_fragment: str) -> bool:
"""检查语法正确性"""
# 简化实现
return len(sql_fragment.strip()) > 0
def _is_complete_sql(self, sql_fragment: str) -> bool:
"""检查是否为完整SQL"""
sql_upper = sql_fragment.upper()
return 'SELECT' in sql_upper and 'FROM' in sql_upper
def _calculate_result_accuracy(self, result: Any, target: Any) -> float:
"""计算结果准确性"""
# 简化实现
if result == target:
return 1.0
return 0.0
def _calculate_efficiency(self, sql: str) -> float:
"""计算效率分数"""
# 简化实现:基于SQL长度和复杂度
length_penalty = len(sql) / 1000.0
return max(0.0, 1.0 - length_penalty)
def _is_episode_done(self, sql_fragment: str) -> bool:
"""检查回合是否结束"""
return self._is_complete_sql(sql_fragment)
def _update_state(self, sql_fragment: str) -> np.ndarray:
"""更新状态"""
# 简化实现:基于当前SQL状态更新
self.current_state = np.random.randn(self.state_dim)
return self.current_state
def _get_reward_breakdown(self, sql_fragment: str) -> Dict[str, float]:
"""获取奖励分解"""
return {
'syntax': 1.0 if self._is_syntactically_correct(sql_fragment) else 0.0,
'completeness': 1.0 if self._is_complete_sql(sql_fragment) else 0.0,
'efficiency': self._calculate_efficiency(sql_fragment)
}
class DQNAgent:
"""深度Q网络智能体"""
def __init__(self, state_dim: int, action_dim: int, config: RLConfig):
self.state_dim = state_dim
self.action_dim = action_dim
self.config = config
# 主网络和目标网络
self.q_network = self._build_network()
self.target_network = self._build_network()
self._update_target_network()
# 优化器
self.optimizer = optim.Adam(self.q_network.parameters(), lr=config.learning_rate)
# 经验回放缓冲区
self.memory = deque(maxlen=config.memory_size)
# 探索参数
self.epsilon = config.epsilon
# 训练计数器
self.training_step = 0
def _build_network(self) -> nn.Module:
"""构建神经网络"""
return nn.Sequential(
nn.Linear(self.state_dim, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, self.action_dim)
)
def select_action(self, state: np.ndarray, training: bool = True) -> int:
"""选择动作"""
if training and random.random() < self.epsilon:
# 探索:随机选择动作
return random.randint(0, self.action_dim - 1)
else:
# 利用:选择Q值最大的动作
state_tensor = torch.FloatTensor(state).unsqueeze(0)
q_values = self.q_network(state_tensor)
return q_values.argmax().item()
def store_experience(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool):
"""存储经验"""
self.memory.append((state, action, reward, next_state, done))
def train(self) -> Dict[str, float]:
"""训练智能体"""
if len(self.memory) < self.config.batch_size:
return {'loss': 0.0}
# 采样批次
batch = random.sample(self.memory, self.config.batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
# 转换为张量
states = torch.FloatTensor(np.array(states))
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.BoolTensor(dones)
        # 计算当前Q值
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1))
        # 计算目标Q值(回合结束时不再引导后续状态的Q值)
        next_q_values = self.target_network(next_states).max(1)[0].detach()
        target_q_values = rewards + (self.config.gamma * next_q_values * (~dones).float())
# 计算损失
loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# 更新探索率
self.epsilon = max(self.config.epsilon_min,
self.epsilon * self.config.epsilon_decay)
# 更新目标网络
self.training_step += 1
if self.training_step % self.config.target_update_freq == 0:
self._update_target_network()
return {'loss': loss.item(), 'epsilon': self.epsilon}
def _update_target_network(self):
"""更新目标网络"""
self.target_network.load_state_dict(self.q_network.state_dict())
class RLText2SQLTrainer:
"""强化学习Text2SQL训练器"""
def __init__(self, environment: Text2SQLEnvironment, config: RLConfig):
self.environment = environment
self.config = config
# 创建智能体
self.agent = DQNAgent(
state_dim=environment.state_dim,
action_dim=len(environment.action_space),
config=config
)
# 训练统计
self.training_stats = {
'episode_rewards': [],
'episode_lengths': [],
'success_rate': [],
'loss_history': []
}
def train(self, questions: List[str], num_episodes: int = 1000) -> Dict[str, Any]:
"""训练模型"""
for episode in range(num_episodes):
# 随机选择问题
question = random.choice(questions)
# 重置环境
state = self.environment.reset(question)
episode_reward = 0
episode_length = 0
while True:
# 选择动作
action = self.agent.select_action(state, training=True)
# 执行动作
next_state, reward, done, info = self.environment.step(action)
# 存储经验
self.agent.store_experience(state, action, reward, next_state, done)
# 训练智能体
train_info = self.agent.train()
# 更新统计
episode_reward += reward
episode_length += 1
if train_info['loss'] > 0:
self.training_stats['loss_history'].append(train_info['loss'])
# 更新状态
state = next_state
if done or episode_length > 50: # 最大步数限制
break
# 记录回合统计
self.training_stats['episode_rewards'].append(episode_reward)
self.training_stats['episode_lengths'].append(episode_length)
# 计算成功率
if episode % 100 == 0:
recent_rewards = self.training_stats['episode_rewards'][-100:]
success_rate = sum(1 for r in recent_rewards if r > 10) / len(recent_rewards)
self.training_stats['success_rate'].append(success_rate)
print(f"Episode {episode}, Avg Reward: {np.mean(recent_rewards):.2f}, "
f"Success Rate: {success_rate:.2f}, Epsilon: {self.agent.epsilon:.3f}")
return self.training_stats
def generate_sql(self, question: str) -> str:
"""生成SQL查询"""
state = self.environment.reset(question)
sql_parts = []
for _ in range(20): # 最大步数
action = self.agent.select_action(state, training=False)
next_state, reward, done, info = self.environment.step(action)
sql_fragment = info.get('sql_fragment', '')
if sql_fragment:
sql_parts.append(sql_fragment)
state = next_state
if done:
break
return ' '.join(sql_parts)
print("强化学习Text2SQL实现完成")
10.4 Federated Learning and Text2SQL
10.4.1 The Federated Learning Framework
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, List, Any, Optional, Tuple
import copy
import numpy as np
from dataclasses import dataclass
import hashlib
import json
@dataclass
class FederatedConfig:
"""联邦学习配置"""
num_clients: int = 10
num_rounds: int = 100
client_fraction: float = 0.3 # 每轮参与的客户端比例
local_epochs: int = 5
learning_rate: float = 0.001
aggregation_method: str = 'fedavg' # fedavg, fedprox, scaffold
privacy_budget: float = 1.0 # 差分隐私预算
secure_aggregation: bool = True
min_clients: int = 3 # 最小参与客户端数
class PrivacyMechanism:
"""隐私保护机制"""
def __init__(self, epsilon: float = 1.0, delta: float = 1e-5):
self.epsilon = epsilon # 隐私预算
self.delta = delta
self.noise_scale = self._calculate_noise_scale()
def _calculate_noise_scale(self) -> float:
"""计算噪声尺度"""
# 基于差分隐私理论计算
sensitivity = 1.0 # 假设L2敏感度为1
return sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon
def add_noise(self, tensor: torch.Tensor) -> torch.Tensor:
"""添加高斯噪声"""
noise = torch.normal(0, self.noise_scale, tensor.shape)
return tensor + noise
def clip_gradients(self, gradients: torch.Tensor, clip_norm: float = 1.0) -> torch.Tensor:
"""梯度裁剪"""
grad_norm = torch.norm(gradients)
if grad_norm > clip_norm:
gradients = gradients * (clip_norm / grad_norm)
return gradients
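# 数值示例(示意):高斯机制的噪声尺度为 sigma = S * sqrt(2 * ln(1.25/delta)) / epsilon。
# 取 S=1、epsilon=1.0、delta=1e-5 时 sigma ≈ 4.84;epsilon越小(隐私越强),噪声越大
pm_demo = PrivacyMechanism(epsilon=1.0, delta=1e-5)
print(f"噪声尺度: {pm_demo.noise_scale:.2f}")   # 约4.84
print(pm_demo.add_noise(torch.zeros(3)))        # 对零向量加噪后的输出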
class SecureAggregator:
"""安全聚合器"""
def __init__(self, num_clients: int):
self.num_clients = num_clients
self.client_keys = self._generate_keys()
def _generate_keys(self) -> Dict[str, str]:
"""生成客户端密钥"""
keys = {}
for i in range(self.num_clients):
key = hashlib.sha256(f"client_{i}".encode()).hexdigest()
keys[f"client_{i}"] = key
return keys
def encrypt_model(self, model_state: Dict[str, torch.Tensor],
client_id: str) -> Dict[str, torch.Tensor]:
"""加密模型参数"""
# 简化实现:使用异或加密
encrypted_state = {}
key = self.client_keys.get(client_id, "default_key")
key_hash = int(hashlib.md5(key.encode()).hexdigest(), 16)
for name, param in model_state.items():
# 简单的异或加密
encrypted_param = param.clone()
encrypted_param += torch.tensor(key_hash % 1000 / 1000.0)
encrypted_state[name] = encrypted_param
return encrypted_state
def decrypt_and_aggregate(self, encrypted_models: List[Dict[str, torch.Tensor]],
client_ids: List[str]) -> Dict[str, torch.Tensor]:
"""解密并聚合模型"""
# 解密
decrypted_models = []
for encrypted_model, client_id in zip(encrypted_models, client_ids):
key = self.client_keys.get(client_id, "default_key")
key_hash = int(hashlib.md5(key.encode()).hexdigest(), 16)
decrypted_model = {}
for name, param in encrypted_model.items():
decrypted_param = param.clone()
decrypted_param -= torch.tensor(key_hash % 1000 / 1000.0)
decrypted_model[name] = decrypted_param
decrypted_models.append(decrypted_model)
# 聚合
return self._federated_averaging(decrypted_models)
def _federated_averaging(self, models: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""联邦平均"""
if not models:
return {}
# 初始化聚合模型
aggregated_model = {}
for name in models[0].keys():
aggregated_model[name] = torch.zeros_like(models[0][name])
# 平均聚合
for model in models:
for name, param in model.items():
aggregated_model[name] += param
for name in aggregated_model.keys():
aggregated_model[name] /= len(models)
return aggregated_model
class FederatedClient:
"""联邦学习客户端"""
def __init__(self, client_id: str, model: nn.Module,
train_data: Any, config: FederatedConfig):
self.client_id = client_id
self.model = model
self.train_data = train_data
self.config = config
# 优化器
self.optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# 隐私保护
self.privacy_mechanism = PrivacyMechanism(epsilon=config.privacy_budget)
# 本地训练统计
self.training_stats = {
'local_losses': [],
'data_size': len(train_data) if hasattr(train_data, '__len__') else 0
}
def local_train(self, global_model_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""本地训练"""
# 加载全局模型
self.model.load_state_dict(global_model_state)
# 本地训练
self.model.train()
local_losses = []
for epoch in range(self.config.local_epochs):
epoch_loss = 0.0
num_batches = 0
for batch in self.train_data:
# 假设batch包含question, sql, schema等
question, sql, schema = batch
# 前向传播
self.optimizer.zero_grad()
output = self.model(question, schema)
loss = self._calculate_loss(output, sql)
# 反向传播
loss.backward()
# 梯度裁剪(隐私保护)
for param in self.model.parameters():
if param.grad is not None:
param.grad = self.privacy_mechanism.clip_gradients(param.grad)
self.optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_loss = epoch_loss / num_batches if num_batches > 0 else 0
local_losses.append(avg_loss)
self.training_stats['local_losses'].extend(local_losses)
# 获取更新后的模型参数
model_state = self.model.state_dict()
# 添加差分隐私噪声
if self.config.privacy_budget > 0:
for name, param in model_state.items():
model_state[name] = self.privacy_mechanism.add_noise(param)
return model_state
def _calculate_loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""计算损失"""
# 简化实现
criterion = nn.CrossEntropyLoss()
return criterion(output, target)
def get_data_size(self) -> int:
"""获取数据大小"""
return self.training_stats['data_size']
def get_training_stats(self) -> Dict[str, Any]:
"""获取训练统计"""
return self.training_stats
class FederatedServer:
"""联邦学习服务器"""
def __init__(self, global_model: nn.Module, config: FederatedConfig):
self.global_model = global_model
self.config = config
# 安全聚合器
self.secure_aggregator = SecureAggregator(config.num_clients)
# 客户端管理
self.clients: Dict[str, FederatedClient] = {}
# 训练统计
self.training_stats = {
'round_losses': [],
'client_participation': [],
'aggregation_time': [],
'model_accuracy': []
}
def register_client(self, client: FederatedClient):
"""注册客户端"""
self.clients[client.client_id] = client
def select_clients(self, round_num: int) -> List[str]:
"""选择参与训练的客户端"""
num_selected = max(self.config.min_clients,
int(len(self.clients) * self.config.client_fraction))
# 随机选择客户端
available_clients = list(self.clients.keys())
selected_clients = np.random.choice(
available_clients,
size=min(num_selected, len(available_clients)),
replace=False
).tolist()
return selected_clients
def federated_train(self, num_rounds: Optional[int] = None) -> Dict[str, Any]:
"""联邦训练"""
if num_rounds is None:
num_rounds = self.config.num_rounds
for round_num in range(num_rounds):
print(f"\n=== 联邦学习轮次 {round_num + 1}/{num_rounds} ===")
# 选择客户端
selected_clients = self.select_clients(round_num)
print(f"选择的客户端: {selected_clients}")
# 获取全局模型状态
global_model_state = self.global_model.state_dict()
# 客户端本地训练
client_models = []
client_data_sizes = []
for client_id in selected_clients:
client = self.clients[client_id]
# 本地训练
local_model_state = client.local_train(global_model_state)
# 安全聚合(如果启用)
if self.config.secure_aggregation:
encrypted_model = self.secure_aggregator.encrypt_model(
local_model_state, client_id
)
client_models.append(encrypted_model)
else:
client_models.append(local_model_state)
client_data_sizes.append(client.get_data_size())
# 模型聚合
if self.config.secure_aggregation:
aggregated_model = self.secure_aggregator.decrypt_and_aggregate(
client_models, selected_clients
)
else:
aggregated_model = self._weighted_federated_averaging(
client_models, client_data_sizes
)
# 更新全局模型
self.global_model.load_state_dict(aggregated_model)
# 评估模型性能
accuracy = self._evaluate_global_model()
# 记录统计信息
self.training_stats['client_participation'].append(len(selected_clients))
self.training_stats['model_accuracy'].append(accuracy)
print(f"轮次 {round_num + 1} 完成,模型准确率: {accuracy:.4f}")
return self.training_stats
def _weighted_federated_averaging(self, models: List[Dict[str, torch.Tensor]],
data_sizes: List[int]) -> Dict[str, torch.Tensor]:
"""加权联邦平均"""
if not models:
return {}
total_data_size = sum(data_sizes)
weights = [size / total_data_size for size in data_sizes]
# 初始化聚合模型
aggregated_model = {}
for name in models[0].keys():
aggregated_model[name] = torch.zeros_like(models[0][name])
# 加权平均聚合
for model, weight in zip(models, weights):
for name, param in model.items():
aggregated_model[name] += weight * param
return aggregated_model
def _evaluate_global_model(self) -> float:
"""评估全局模型"""
# 简化实现:返回随机准确率
# 实际应该在验证集上评估
return np.random.uniform(0.7, 0.95)
def get_global_model(self) -> nn.Module:
"""获取全局模型"""
return self.global_model
def get_training_stats(self) -> Dict[str, Any]:
"""获取训练统计"""
return self.training_stats
class FederatedText2SQLSystem:
"""联邦Text2SQL系统"""
def __init__(self, model_class: type, config: FederatedConfig):
self.model_class = model_class
self.config = config
# 创建全局模型
self.global_model = model_class()
# 创建联邦服务器
self.server = FederatedServer(self.global_model, config)
# 系统统计
self.system_stats = {
'total_clients': 0,
'total_rounds': 0,
'privacy_budget_used': 0.0
}
def add_client(self, client_id: str, train_data: Any) -> FederatedClient:
"""添加客户端"""
# 创建客户端模型(与全局模型相同架构)
client_model = self.model_class()
# 创建客户端
client = FederatedClient(client_id, client_model, train_data, self.config)
# 注册到服务器
self.server.register_client(client)
self.system_stats['total_clients'] += 1
return client
def start_federated_training(self, num_rounds: int = None) -> Dict[str, Any]:
"""开始联邦训练"""
print(f"开始联邦学习训练,客户端数量: {self.system_stats['total_clients']}")
# 执行联邦训练
training_stats = self.server.federated_train(num_rounds)
self.system_stats['total_rounds'] = len(training_stats['model_accuracy'])
self.system_stats['privacy_budget_used'] = self.config.privacy_budget
return {
'training_stats': training_stats,
'system_stats': self.system_stats
}
def predict(self, question: str, schema: Dict[str, Any]) -> str:
"""使用全局模型进行预测"""
global_model = self.server.get_global_model()
global_model.eval()
with torch.no_grad():
# 简化实现
output = global_model(question, schema)
# 解码输出为SQL
sql = self._decode_output(output)
return sql
def _decode_output(self, output: torch.Tensor) -> str:
"""解码模型输出为SQL"""
# 简化实现
return "SELECT * FROM table WHERE condition"
def save_global_model(self, path: str):
"""保存全局模型"""
torch.save(self.server.get_global_model().state_dict(), path)
print(f"全局模型已保存到: {path}")
def load_global_model(self, path: str):
"""加载全局模型"""
state_dict = torch.load(path)
self.server.get_global_model().load_state_dict(state_dict)
print(f"全局模型已从 {path} 加载")
print("联邦学习Text2SQL实现完成")
10.5 Knowledge-Graph-Enhanced Text2SQL
10.5.1 Knowledge Graph Construction and Integration
import torch
import torch.nn as nn
from typing import Dict, List, Any, Tuple, Optional, Set
import networkx as nx
import numpy as np
from dataclasses import dataclass
import json
from collections import defaultdict
import re
@dataclass
class KnowledgeGraphConfig:
"""知识图谱配置"""
entity_embedding_dim: int = 256
relation_embedding_dim: int = 128
graph_attention_heads: int = 8
graph_layers: int = 3
max_path_length: int = 3
similarity_threshold: float = 0.8
use_external_kg: bool = True
kg_weight: float = 0.3
class Entity:
"""实体类"""
def __init__(self, entity_id: str, entity_type: str,
name: str, attributes: Dict[str, Any] = None):
self.entity_id = entity_id
self.entity_type = entity_type
self.name = name
self.attributes = attributes or {}
self.aliases = set([name])
def add_alias(self, alias: str):
"""添加别名"""
self.aliases.add(alias)
def __str__(self):
return f"Entity({self.entity_id}, {self.name}, {self.entity_type})"
def __repr__(self):
return self.__str__()
class Relation:
"""关系类"""
def __init__(self, relation_id: str, relation_type: str,
head_entity: str, tail_entity: str,
properties: Dict[str, Any] = None):
self.relation_id = relation_id
self.relation_type = relation_type
self.head_entity = head_entity
self.tail_entity = tail_entity
self.properties = properties or {}
def __str__(self):
return f"Relation({self.head_entity} --{self.relation_type}--> {self.tail_entity})"
def __repr__(self):
return self.__str__()
class KnowledgeGraph:
    """A knowledge graph with entity/relation indexes over a NetworkX multigraph."""
    def __init__(self):
        self.entities: Dict[str, Entity] = {}
        self.relations: Dict[str, Relation] = {}
        self.graph = nx.MultiDiGraph()
        # Lookup indexes
        self.entity_name_index: Dict[str, Set[str]] = defaultdict(set)
        self.entity_type_index: Dict[str, Set[str]] = defaultdict(set)
        self.relation_type_index: Dict[str, Set[str]] = defaultdict(set)

    def add_entity(self, entity: Entity):
        """Add an entity and update the name/type indexes."""
        self.entities[entity.entity_id] = entity
        self.graph.add_node(entity.entity_id, **entity.attributes)
        for alias in entity.aliases:
            self.entity_name_index[alias.lower()].add(entity.entity_id)
        self.entity_type_index[entity.entity_type].add(entity.entity_id)

    def add_relation(self, relation: Relation):
        """Add a relation as a typed edge."""
        self.relations[relation.relation_id] = relation
        self.graph.add_edge(
            relation.head_entity,
            relation.tail_entity,
            relation_id=relation.relation_id,
            relation_type=relation.relation_type,
            **relation.properties
        )
        self.relation_type_index[relation.relation_type].add(relation.relation_id)

    def find_entities_by_name(self, name: str, fuzzy: bool = True) -> List[Entity]:
        """Find entities by name, falling back to fuzzy matching."""
        name_lower = name.lower()
        entity_ids = set()
        # Exact match first
        if name_lower in self.entity_name_index:
            entity_ids.update(self.entity_name_index[name_lower])
        # Fuzzy match only when the exact match comes up empty
        if fuzzy and not entity_ids:
            for indexed_name, ids in self.entity_name_index.items():
                if self._calculate_similarity(name_lower, indexed_name) > 0.8:
                    entity_ids.update(ids)
        return [self.entities[eid] for eid in entity_ids if eid in self.entities]

    def find_entities_by_type(self, entity_type: str) -> List[Entity]:
        """Find all entities of a given type."""
        entity_ids = self.entity_type_index.get(entity_type, set())
        return [self.entities[eid] for eid in entity_ids if eid in self.entities]

    def find_relations_between(self, entity1_id: str, entity2_id: str) -> List[Relation]:
        """Find all relations between two entities (checking both directions)."""
        relations = []
        for head, tail in ((entity1_id, entity2_id), (entity2_id, entity1_id)):
            if self.graph.has_edge(head, tail):
                edge_data = self.graph.get_edge_data(head, tail)
                for key, data in edge_data.items():
                    relation_id = data.get('relation_id')
                    if relation_id and relation_id in self.relations:
                        relations.append(self.relations[relation_id])
        return relations

    def find_shortest_path(self, start_entity: str, end_entity: str,
                           max_length: int = 3) -> List[List[str]]:
        """Find up to five shortest simple paths within max_length hops."""
        try:
            paths = list(nx.all_simple_paths(
                self.graph, start_entity, end_entity, cutoff=max_length
            ))
            return sorted(paths, key=len)[:5]
        except nx.NodeNotFound:
            # One of the endpoints is missing from the graph
            return []

    def get_neighbors(self, entity_id: str, relation_types: List[str] = None) -> List[str]:
        """Get (out-)neighbor entities, optionally filtered by relation type."""
        neighbors = []
        for neighbor in self.graph.neighbors(entity_id):
            if relation_types:
                edge_data = self.graph.get_edge_data(entity_id, neighbor)
                for data in edge_data.values():
                    if data.get('relation_type') in relation_types:
                        neighbors.append(neighbor)
                        break
            else:
                neighbors.append(neighbor)
        return neighbors

    def _calculate_similarity(self, str1: str, str2: str) -> float:
        """String similarity based on normalized edit distance."""
        if len(str1) == 0 or len(str2) == 0:
            return 0.0
        max_len = max(len(str1), len(str2))
        edit_distance = self._edit_distance(str1, str2)
        return 1.0 - edit_distance / max_len

    def _edit_distance(self, str1: str, str2: str) -> int:
        """Levenshtein distance via dynamic programming."""
        m, n = len(str1), len(str2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if str1[i-1] == str2[j-1]:
                    dp[i][j] = dp[i-1][j-1]
                else:
                    dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
        return dp[m][n]

    def get_subgraph(self, entity_ids: List[str], max_hops: int = 2) -> 'KnowledgeGraph':
        """Extract the subgraph within max_hops of the given entities."""
        subgraph_kg = KnowledgeGraph()
        # Expand the seed entity set hop by hop
        expanded_entities = set(entity_ids)
        for hop in range(max_hops):
            new_entities = set()
            for entity_id in expanded_entities:
                new_entities.update(self.get_neighbors(entity_id))
            expanded_entities.update(new_entities)
        # Copy the entities
        for entity_id in expanded_entities:
            if entity_id in self.entities:
                subgraph_kg.add_entity(self.entities[entity_id])
        # Copy relations whose endpoints are both in the subgraph
        for relation in self.relations.values():
            if (relation.head_entity in expanded_entities and
                    relation.tail_entity in expanded_entities):
                subgraph_kg.add_relation(relation)
        return subgraph_kg
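A quick illustration of the graph API above on a toy schema graph (entity and relation IDs are made up for the example):

kg = KnowledgeGraph()
kg.add_entity(Entity("e1", "table", "customers"))
kg.add_entity(Entity("e2", "table", "orders"))
kg.add_relation(Relation("r1", "foreign_key", "e2", "e1"))

print(kg.find_entities_by_name("customer"))   # fuzzy match hits "customers"
print(kg.find_relations_between("e1", "e2"))  # [Relation(e2 --foreign_key--> e1)]
print(kg.find_shortest_path("e2", "e1"))      # [['e2', 'e1']]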
class KnowledgeGraphEmbedding(nn.Module):
    """Knowledge graph embedding model."""
    def __init__(self, num_entities: int, num_relations: int, config: KnowledgeGraphConfig):
        super().__init__()
        self.config = config
        # Entity and relation embedding tables
        self.entity_embeddings = nn.Embedding(num_entities, config.entity_embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, config.relation_embedding_dim)
        # Stacked graph attention layers
        self.graph_attention_layers = nn.ModuleList([
            GraphAttentionLayer(
                config.entity_embedding_dim,
                config.entity_embedding_dim // config.graph_attention_heads,
                config.graph_attention_heads
            ) for _ in range(config.graph_layers)
        ])
        # Output projection
        self.output_projection = nn.Linear(
            config.entity_embedding_dim,
            config.entity_embedding_dim
        )

    def forward(self, entity_ids: torch.Tensor,
                adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """Propagate entity embeddings through the graph attention stack."""
        # Look up the initial entity embeddings
        entity_embeds = self.entity_embeddings(entity_ids)
        # Refine them with graph attention over the adjacency structure
        for gat_layer in self.graph_attention_layers:
            entity_embeds = gat_layer(entity_embeds, adjacency_matrix)
        return self.output_projection(entity_embeds)
class GraphAttentionLayer(nn.Module):
    """A multi-head graph attention (GAT-style) layer."""
    def __init__(self, input_dim: int, output_dim: int, num_heads: int):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        # Per-head linear transforms and attention scorers
        self.attention_heads = nn.ModuleList([
            nn.Linear(input_dim, output_dim) for _ in range(num_heads)
        ])
        self.attention_weights = nn.ModuleList([
            nn.Linear(2 * output_dim, 1) for _ in range(num_heads)
        ])
        self.output_projection = nn.Linear(num_heads * output_dim, input_dim)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, node_features: torch.Tensor,
                adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """Attend over graph neighbors and combine the heads."""
        head_outputs = []
        for head_idx in range(self.num_heads):
            # Per-head linear transform
            transformed_features = self.attention_heads[head_idx](node_features)
            # Attention weights restricted to graph neighbors
            attention_scores = self._compute_attention(
                transformed_features, adjacency_matrix, head_idx
            )
            # Weighted aggregation of neighbor features
            head_outputs.append(torch.bmm(attention_scores, transformed_features))
        # Concatenate the heads and project back to the input dimension
        multi_head_output = torch.cat(head_outputs, dim=-1)
        output = self.output_projection(multi_head_output)
        # Residual connection and layer normalization
        output = self.layer_norm(output + node_features)
        return self.dropout(output)

    def _compute_attention(self, features: torch.Tensor,
                           adjacency_matrix: torch.Tensor,
                           head_idx: int) -> torch.Tensor:
        """Compute masked attention weights for one head."""
        batch_size, num_nodes, feature_dim = features.shape
        # Pairwise feature concatenation for every (i, j) node pair
        features_i = features.unsqueeze(2).expand(-1, -1, num_nodes, -1)
        features_j = features.unsqueeze(1).expand(-1, num_nodes, -1, -1)
        concatenated_features = torch.cat([features_i, features_j], dim=-1)
        # Raw attention scores with the LeakyReLU nonlinearity used in GAT
        attention_scores = torch.nn.functional.leaky_relu(
            self.attention_weights[head_idx](concatenated_features).squeeze(-1),
            negative_slope=0.2
        )
        # Mask out non-edges; added self-loops keep isolated nodes from
        # producing an all -inf row (which would make softmax return NaN)
        eye = torch.eye(num_nodes, device=features.device).unsqueeze(0)
        mask = (adjacency_matrix + eye) > 0
        attention_scores = attention_scores.masked_fill(~mask, float('-inf'))
        # Softmax normalization over neighbors
        return torch.softmax(attention_scores, dim=-1)
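This follows the standard GAT attention, α_ij = softmax_j(LeakyReLU(aᵀ[W·h_i ‖ W·h_j])), computed only over edges of the adjacency matrix. A quick shape check of the two modules above on synthetic inputs:

cfg = KnowledgeGraphConfig()
kge = KnowledgeGraphEmbedding(num_entities=100, num_relations=10, config=cfg)

entity_ids = torch.randint(0, 100, (2, 6))       # batch of 2 graphs, 6 nodes each
adjacency = (torch.rand(2, 6, 6) > 0.5).float()  # random adjacency matrices

out = kge(entity_ids, adjacency)
print(out.shape)  # torch.Size([2, 6, 256])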
class KnowledgeEnhancedText2SQL(nn.Module):
    """Text2SQL model enhanced with knowledge graph information."""
    def __init__(self, base_model: nn.Module, knowledge_graph: KnowledgeGraph,
                 config: KnowledgeGraphConfig):
        super().__init__()
        self.base_model = base_model
        self.knowledge_graph = knowledge_graph
        self.config = config
        # Map entity IDs to rows of the embedding table
        self.entity_id_to_idx = {
            eid: idx for idx, eid in enumerate(knowledge_graph.entities)
        }
        # Knowledge graph embeddings (sizes floored at 1 for empty graphs)
        self.kg_embedding = KnowledgeGraphEmbedding(
            max(len(knowledge_graph.entities), 1),
            max(len(knowledge_graph.relations), 1),
            config
        )
        # Knowledge fusion layer; assumes the base model's hidden size
        # equals entity_embedding_dim
        self.knowledge_fusion = nn.MultiheadAttention(
            embed_dim=config.entity_embedding_dim,
            num_heads=config.graph_attention_heads
        )
        # Output fusion
        self.output_fusion = nn.Linear(
            config.entity_embedding_dim * 2,
            config.entity_embedding_dim
        )

    def forward(self, question: str, schema: Dict[str, Any]) -> torch.Tensor:
        """Forward pass: base output, optionally fused with KG knowledge."""
        base_output = self.base_model(question, schema)
        # Retrieve entities relevant to the question and schema
        relevant_knowledge = self._extract_relevant_knowledge(question, schema)
        if relevant_knowledge:
            kg_embeddings = self._get_knowledge_embeddings(relevant_knowledge)
            return self._fuse_knowledge(base_output, kg_embeddings)
        return base_output

    def _extract_relevant_knowledge(self, question: str,
                                    schema: Dict[str, Any]) -> List[Entity]:
        """Collect KG entities mentioned in the question or the schema."""
        relevant_entities = []
        # Entities mentioned in the question
        relevant_entities.extend(self._extract_entities_from_text(question))
        # Entities matching table and column names
        relevant_entities.extend(self._extract_entities_from_schema(schema))
        # Expand with graph neighbors
        return self._expand_entities(relevant_entities)

    def _extract_entities_from_text(self, text: str) -> List[Entity]:
        """Link words in the text to KG entities (simplified word-level matching)."""
        entities = []
        for word in text.lower().split():
            entities.extend(self.knowledge_graph.find_entities_by_name(word))
        return entities

    def _extract_entities_from_schema(self, schema: Dict[str, Any]) -> List[Entity]:
        """Link table and column names to KG entities."""
        entities = []
        for table_name, table_info in schema.items():
            # Table-name entities
            entities.extend(self.knowledge_graph.find_entities_by_name(table_name))
            # Column-name entities
            for column in table_info.get('columns', []):
                if isinstance(column, dict):
                    column_name = column.get('name', '')
                else:
                    column_name = str(column)
                entities.extend(self.knowledge_graph.find_entities_by_name(column_name))
        return entities

    def _expand_entities(self, entities: List[Entity]) -> List[Entity]:
        """Expand the entity set with graph neighbors."""
        expanded = set(entities)
        for entity in entities:
            neighbors = self.knowledge_graph.get_neighbors(entity.entity_id)
            for neighbor_id in neighbors[:5]:  # cap the number of neighbors
                if neighbor_id in self.knowledge_graph.entities:
                    expanded.add(self.knowledge_graph.entities[neighbor_id])
        return list(expanded)

    def _get_knowledge_embeddings(self, entities: List[Entity]) -> torch.Tensor:
        """Look up entity embeddings from the KG embedding table."""
        indices = torch.tensor(
            [self.entity_id_to_idx[e.entity_id] for e in entities],
            dtype=torch.long
        )
        return self.kg_embedding.entity_embeddings(indices)

    def _fuse_knowledge(self, base_output: torch.Tensor,
                        kg_embeddings: torch.Tensor) -> torch.Tensor:
        """Fuse the base model output with KG embeddings via cross-attention."""
        if kg_embeddings.size(0) == 0:
            return base_output
        # nn.MultiheadAttention expects (seq_len, batch, embed_dim)
        base_expanded = base_output.unsqueeze(1)  # (L, 1, E)
        kg_expanded = kg_embeddings.unsqueeze(1)  # (K, 1, E)
        # The base states attend over the retrieved KG embeddings
        fused_features, _ = self.knowledge_fusion(
            base_expanded, kg_expanded, kg_expanded
        )
        # Concatenate and project back to the hidden size
        concatenated = torch.cat([base_output, fused_features.squeeze(1)], dim=-1)
        return self.output_fusion(concatenated)
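To see the whole pipeline run end to end, here is a minimal smoke test with a stand-in base model (the dummy model and its (L, 256) output shape are assumptions for illustration):

class DummyBaseModel(nn.Module):
    """Stand-in base model: returns fake decoder states of shape (4, 256)."""
    def forward(self, question, schema):
        return torch.randn(4, 256)

kg = KnowledgeGraph()
kg.add_entity(Entity("e1", "table", "orders"))

model = KnowledgeEnhancedText2SQL(DummyBaseModel(), kg, KnowledgeGraphConfig())
out = model("total orders per customer", {"orders": {"columns": ["id", "customer_id"]}})
print(out.shape)  # torch.Size([4, 256]) -- the "orders" entity was retrieved and fused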
print("知识图谱增强Text2SQL实现完成")
10.6 Chapter Summary
This chapter explored advanced techniques and frontier developments in Text2SQL, covering the following major directions:
10.6.1 Technical Summary
Multimodal Text2SQL
- Chart understanding and SQL generation: converting charts into SQL queries
- Speech-to-SQL: Text2SQL systems that accept spoken input
- Multimodal fusion: integrating text, image, and speech input modes
Large language model integration
- GPT-family models: leveraging the capabilities of pretrained large models
- Prompt engineering: designing effective prompt templates and strategies
- In-context learning: improving performance through demonstration examples
Reinforcement learning methods
- Environment modeling: building a Text2SQL reinforcement learning environment
- Reward design: balancing syntactic correctness, execution success, and efficiency
- Deep Q-networks: DQN-based SQL generation policies
Federated learning framework
- Privacy protection: differential privacy and secure aggregation mechanisms
- Distributed training: multi-client collaborative training
- Model aggregation: strategies such as federated averaging
Knowledge graph enhancement
- Knowledge graph construction: building domain knowledge graphs
- Graph embedding learning: learning vector representations of entities and relations
- Knowledge fusion: injecting external knowledge into Text2SQL models
10.6.2 Technical Advantages
- Versatility: support for multiple input modes and application scenarios
- Intelligence: large models and reinforcement learning raise system capability
- Privacy: federated learning protects data privacy
- Accuracy: knowledge enhancement improves understanding accuracy
- Scalability: support for large-scale distributed deployment
10.6.3 Application Outlook
These frontier techniques open up broad application prospects for Text2SQL systems:
- Enterprise applications: supporting complex enterprise data analysis needs
- Intelligent assistants: building smarter data query assistants
- Education and training: assisting SQL learning and database teaching
- Research tools: enabling fast querying and analysis of research data
- Cross-domain applications: adapting to the specific needs of different industries
10.6.4 Development Trends
Future directions for Text2SQL technology include:
- Stronger generalization: adapting to more database types and query scenarios
- Better interaction: supporting conversational, natural-language querying
- Higher accuracy: combining multiple techniques to raise accuracy
- Real-time capability: supporting real-time queries over streaming data
- Broader applicability: expanding into more vertical domains
After studying this chapter, readers should be able to:
- understand frontier developments in the Text2SQL field
- master advanced methods such as multimodal input, reinforcement learning, and federated learning
- understand the role of knowledge graphs in Text2SQL
- build advanced Text2SQL systems
- grasp the technology's development trends and application prospects
Mastering these techniques lays a solid foundation for deeper research and practical application in the Text2SQL field.