## 14.1 分库分表概述

### 14.1.1 什么是分库分表
分库分表是一种数据库水平扩展技术,通过将数据分散到多个数据库实例或表中来解决单一数据库的性能瓶颈和存储限制。
**分库(Database Sharding)**:

- 将数据分散到多个数据库实例中
- 每个数据库实例通常部署在不同的服务器上
- 提高并发处理能力和整体存储容量

**分表(Table Sharding)**:

- 将单个大表拆分成多个小表
- 可以在同一数据库实例中进行,与 MySQL 原生的分区表(Partitioning)是不同的概念
- 减少单表数据量,提高查询性能
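以最常见的哈希取模路由为例,下面用一小段 Python 示意如何根据分片键定位数据所在的库和表(库数、表数等参数仅为示例假设):

```python
# 哈希取模路由示意:假设 4 个库、每库 8 张表,分片键为 user_id
DB_COUNT = 4
TABLE_COUNT_PER_DB = 8


def route(user_id: int) -> tuple[str, str]:
    """根据 user_id 计算目标库名和表名。"""
    slot = user_id % (DB_COUNT * TABLE_COUNT_PER_DB)   # 先计算总槽位
    db_index = slot // TABLE_COUNT_PER_DB              # 槽位所属的库
    table_index = slot % TABLE_COUNT_PER_DB            # 库内的表
    return f"user_db_{db_index}", f"users_{table_index}"


if __name__ == "__main__":
    print(route(10086))   # 输出 ('user_db_0', 'users_6')
```

路由规则一旦确定就应保持稳定,否则已写入的数据将无法被正确定位。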
### 14.1.2 为什么需要分库分表

**性能瓶颈**

- 单表数据量过大(通常超过 1000 万行)
- 查询响应时间过长
- 并发访问压力大
- 索引维护成本高

**存储限制**

- 单机存储容量不足
- 备份和恢复时间过长
- 硬件升级成本高

**可用性要求**

- 避免单点故障
- 提高系统容错能力
- 支持水平扩展
### 14.1.3 分库分表策略分析器
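一种简单的策略分析思路是:结合当前单表行数、数据增长速度等指标,估算未来规模并给出是否需要分库分表以及建议的分片数。下面是一个极简的分析器示意(阈值、增长率、函数名均为示例假设,实际取值需结合业务与硬件评估):

```python
import math


def analyze_sharding(row_count: int, daily_growth: int, horizon_days: int = 365,
                     rows_per_shard: int = 10_000_000) -> dict:
    """根据当前行数和日增长量估算一年后的规模,并给出分片建议(仅为示意)。"""
    projected_rows = row_count + daily_growth * horizon_days
    need_sharding = projected_rows > rows_per_shard
    # 分片数取 2 的幂,便于后续按倍数扩容
    shard_count = 1
    if need_sharding:
        shard_count = 2 ** math.ceil(math.log2(projected_rows / rows_per_shard))
    return {
        "projected_rows": projected_rows,
        "need_sharding": need_sharding,
        "suggested_shards": shard_count,
    }


if __name__ == "__main__":
    # 当前 800 万行,每日新增 10 万行
    print(analyze_sharding(8_000_000, 100_000))
```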
## 14.3 垂直分片策略
### 14.3.1 垂直分库
垂直分库是将不同的表分布到不同的数据库中,通常按业务模块划分。
**优点**:
- 业务隔离性好
- 便于团队协作
- 减少单库压力
- 支持异构数据库
**缺点**:
- 跨库事务复杂
- 数据一致性难保证
- 跨库查询性能差
```sql
-- 垂直分库示例:电商系统按业务模块分库
-- 用户数据库 (user_db)
CREATE DATABASE user_db;
USE user_db;
CREATE TABLE users (
user_id BIGINT PRIMARY KEY AUTO_INCREMENT,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
password_hash VARCHAR(255) NOT NULL,
phone VARCHAR(20),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_email (email),
INDEX idx_phone (phone)
) ENGINE=InnoDB;
CREATE TABLE user_profiles (
user_id BIGINT PRIMARY KEY,
real_name VARCHAR(100),
gender ENUM('M', 'F', 'U') DEFAULT 'U',
birthday DATE,
avatar_url VARCHAR(255),
bio TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
) ENGINE=InnoDB;
-- 商品数据库 (product_db)
CREATE DATABASE product_db;
USE product_db;
CREATE TABLE categories (
category_id INT PRIMARY KEY AUTO_INCREMENT,
name VARCHAR(100) NOT NULL,
parent_id INT,
level INT DEFAULT 1,
sort_order INT DEFAULT 0,
is_active BOOLEAN DEFAULT TRUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_parent_id (parent_id),
INDEX idx_level (level)
) ENGINE=InnoDB;
CREATE TABLE products (
product_id BIGINT PRIMARY KEY AUTO_INCREMENT,
category_id INT NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
price DECIMAL(10,2) NOT NULL,
stock_quantity INT DEFAULT 0,
sku VARCHAR(100) UNIQUE,
status ENUM('active', 'inactive', 'deleted') DEFAULT 'active',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_category_id (category_id),
INDEX idx_sku (sku),
INDEX idx_status (status),
INDEX idx_price (price)
) ENGINE=InnoDB;
-- 订单数据库 (order_db)
CREATE DATABASE order_db;
USE order_db;
CREATE TABLE orders (
order_id BIGINT PRIMARY KEY AUTO_INCREMENT,
user_id BIGINT NOT NULL,
order_no VARCHAR(32) UNIQUE NOT NULL,
total_amount DECIMAL(10,2) NOT NULL,
status ENUM('pending', 'paid', 'shipped', 'delivered', 'cancelled') DEFAULT 'pending',
payment_method VARCHAR(50),
shipping_address TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_user_id (user_id),
INDEX idx_order_no (order_no),
INDEX idx_status (status),
INDEX idx_created_at (created_at)
) ENGINE=InnoDB;
CREATE TABLE order_items (
item_id BIGINT PRIMARY KEY AUTO_INCREMENT,
order_id BIGINT NOT NULL,
product_id BIGINT NOT NULL,
quantity INT NOT NULL,
unit_price DECIMAL(10,2) NOT NULL,
total_price DECIMAL(10,2) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (order_id) REFERENCES orders(order_id) ON DELETE CASCADE,
INDEX idx_order_id (order_id),
INDEX idx_product_id (product_id)
) ENGINE=InnoDB;
```
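按业务模块拆库之后,订单和用户分布在不同的库中,跨库 JOIN 通常需要在应用层完成:先查订单库,再按用户 ID 批量回查用户库,最后在内存中拼装。下面用一小段 Python 示意这一过程(order_conn、user_conn 等连接对象为示例假设):

```python
from typing import Dict, List

import pymysql


def recent_orders_with_user(order_conn: pymysql.connections.Connection,
                            user_conn: pymysql.connections.Connection,
                            limit: int = 20) -> List[Dict]:
    """应用层跨库"JOIN":先查订单库,再按 user_id 回查用户库,在内存中合并。"""
    with order_conn.cursor(pymysql.cursors.DictCursor) as cur:
        cur.execute(
            "SELECT order_id, user_id, total_amount FROM orders "
            "ORDER BY created_at DESC LIMIT %s", (limit,))
        orders = cur.fetchall()

    user_ids = sorted({o["user_id"] for o in orders})
    users = {}
    if user_ids:
        placeholders = ", ".join(["%s"] * len(user_ids))
        with user_conn.cursor(pymysql.cursors.DictCursor) as cur:
            cur.execute(
                f"SELECT user_id, username, email FROM users "
                f"WHERE user_id IN ({placeholders})", user_ids)
            users = {u["user_id"]: u for u in cur.fetchall()}

    # 内存中合并两份结果,相当于跨库的 LEFT JOIN
    return [{**o, **users.get(o["user_id"], {})} for o in orders]
```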
### 14.3.2 垂直分表
垂直分表是将一个表的不同列分布到不同的表中,通常按访问频率划分。

```sql
-- 垂直分表示例:用户表按访问频率分表
-- 用户基本信息表(高频访问)
CREATE TABLE user_basic (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100) NOT NULL,
status ENUM('active', 'inactive', 'banned') DEFAULT 'active',
last_login_at TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_username (username),
INDEX idx_email (email),
INDEX idx_status (status)
) ENGINE=InnoDB;
-- 用户详细信息表(低频访问)
CREATE TABLE user_detail (
user_id BIGINT PRIMARY KEY,
real_name VARCHAR(100),
id_card VARCHAR(18),
address TEXT,
company VARCHAR(200),
education VARCHAR(50),
income_range VARCHAR(20),
interests TEXT,
preferences JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES user_basic(user_id) ON DELETE CASCADE
) ENGINE=InnoDB;
-- 用户统计信息表(分析用)
CREATE TABLE user_stats (
user_id BIGINT PRIMARY KEY,
login_count INT DEFAULT 0,
order_count INT DEFAULT 0,
total_spent DECIMAL(12,2) DEFAULT 0.00,
avg_order_amount DECIMAL(10,2) DEFAULT 0.00,
last_order_at TIMESTAMP NULL,
favorite_category_id INT,
risk_score DECIMAL(3,2) DEFAULT 0.00,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES user_basic(user_id) ON DELETE CASCADE,
INDEX idx_login_count (login_count),
INDEX idx_order_count (order_count),
INDEX idx_total_spent (total_spent)
) ENGINE=InnoDB;
-- 垂直分表查询视图
CREATE VIEW user_complete AS
SELECT
b.user_id,
b.username,
b.email,
b.status,
b.last_login_at,
d.real_name,
d.address,
d.company,
s.login_count,
s.order_count,
s.total_spent,
b.created_at
FROM user_basic b
LEFT JOIN user_detail d ON b.user_id = d.user_id
LEFT JOIN user_stats s ON b.user_id = s.user_id;
```
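垂直分表后,应用读取时应优先只访问高频的 user_basic,仅在确实需要时再回查 user_detail / user_stats,避免视图带来的固定 JOIN 开销。下面是一个按需加载的 Python 示意(函数名与参数均为示例):

```python
from typing import Dict, Optional

import pymysql
from pymysql.cursors import DictCursor


def load_user(conn: pymysql.connections.Connection, user_id: int,
              with_detail: bool = False) -> Optional[Dict]:
    """默认只读高频的 user_basic;需要详情时才回查低频的 user_detail。"""
    with conn.cursor(DictCursor) as cur:
        cur.execute(
            "SELECT user_id, username, email, status, last_login_at "
            "FROM user_basic WHERE user_id = %s", (user_id,))
        user = cur.fetchone()
        if user and with_detail:
            cur.execute(
                "SELECT real_name, address, company "
                "FROM user_detail WHERE user_id = %s", (user_id,))
            user["detail"] = cur.fetchone()  # 可能为 None
    return user
```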
## 14.4 分库分表实施方案

### 14.4.1 自动化分库分表部署脚本

```bash
#!/bin/bash
# MySQL分库分表自动化部署脚本
# 支持水平分片和垂直分片的自动部署
set -e
# 配置参数
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
CONFIG_FILE="${SCRIPT_DIR}/sharding_config.json"
LOG_FILE="${SCRIPT_DIR}/sharding_deployment.log"
BACKUP_DIR="${SCRIPT_DIR}/backup"
SQL_DIR="${SCRIPT_DIR}/sql"
# 日志函数
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" | tee -a "$LOG_FILE"
}
error() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] ERROR: $1" | tee -a "$LOG_FILE"
exit 1
}
# 检查依赖
check_dependencies() {
log "检查依赖项..."
command -v mysql >/dev/null 2>&1 || error "MySQL客户端未安装"
command -v jq >/dev/null 2>&1 || error "jq未安装"
if [ ! -f "$CONFIG_FILE" ]; then
error "配置文件不存在: $CONFIG_FILE"
fi
log "依赖检查完成"
}
# 创建目录结构
setup_directories() {
log "创建目录结构..."
mkdir -p "$BACKUP_DIR"
mkdir -p "$SQL_DIR"
mkdir -p "${SCRIPT_DIR}/scripts"
mkdir -p "${SCRIPT_DIR}/monitoring"
log "目录结构创建完成"
}
# 读取配置
read_config() {
log "读取配置文件..."
# 验证JSON格式
if ! jq empty "$CONFIG_FILE" 2>/dev/null; then
error "配置文件JSON格式错误"
fi
# 读取基本配置
SHARDING_TYPE=$(jq -r '.sharding_type' "$CONFIG_FILE")
SOURCE_DB=$(jq -r '.source.database' "$CONFIG_FILE")
SOURCE_HOST=$(jq -r '.source.host' "$CONFIG_FILE")
SOURCE_PORT=$(jq -r '.source.port' "$CONFIG_FILE")
SOURCE_USER=$(jq -r '.source.user' "$CONFIG_FILE")
SOURCE_PASS=$(jq -r '.source.password' "$CONFIG_FILE")
log "配置读取完成: 分片类型=$SHARDING_TYPE, 源数据库=$SOURCE_DB"
}
# 备份原始数据
backup_original_data() {
log "备份原始数据..."
local backup_file="${BACKUP_DIR}/original_${SOURCE_DB}_$(date +%Y%m%d_%H%M%S).sql"
mysqldump -h"$SOURCE_HOST" -P"$SOURCE_PORT" -u"$SOURCE_USER" -p"$SOURCE_PASS" \
--single-transaction --routines --triggers --events \
"$SOURCE_DB" > "$backup_file"
if [ $? -eq 0 ]; then
log "数据备份完成: $backup_file"
else
error "数据备份失败"
fi
}
# 创建分片数据库
create_shard_databases() {
log "创建分片数据库..."
local shard_count=$(jq -r '.shards | length' "$CONFIG_FILE")
for ((i=0; i<shard_count; i++)); do
local shard_config=$(jq -r ".shards[$i]" "$CONFIG_FILE")
local shard_name=$(echo "$shard_config" | jq -r '.database')
local shard_host=$(echo "$shard_config" | jq -r '.host')
local shard_port=$(echo "$shard_config" | jq -r '.port')
local shard_user=$(echo "$shard_config" | jq -r '.user')
local shard_pass=$(echo "$shard_config" | jq -r '.password')
log "创建分片数据库: $shard_name @ $shard_host:$shard_port"
# 创建数据库
mysql -h"$shard_host" -P"$shard_port" -u"$shard_user" -p"$shard_pass" \
-e "CREATE DATABASE IF NOT EXISTS \`$shard_name\` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;"
if [ $? -eq 0 ]; then
log "分片数据库创建成功: $shard_name"
else
error "分片数据库创建失败: $shard_name"
fi
done
}
# 生成分片表结构
generate_shard_schema() {
log "生成分片表结构..."
local tables=$(jq -r '.tables[] | .name' "$CONFIG_FILE")
for table in $tables; do
log "处理表: $table"
# 获取原表结构
        # 使用 --raw 避免输出中的换行被转义为 \n,--skip-column-names 去掉表头行
        local create_sql=$(mysql -h"$SOURCE_HOST" -P"$SOURCE_PORT" -u"$SOURCE_USER" -p"$SOURCE_PASS" \
            --raw --skip-column-names -e "SHOW CREATE TABLE \`$SOURCE_DB\`.\`$table\`;" | cut -f2-)
# 获取表的分片配置
local table_config=$(jq -r ".tables[] | select(.name == \"$table\")" "$CONFIG_FILE")
local shard_strategy=$(echo "$table_config" | jq -r '.strategy')
local shard_key=$(echo "$table_config" | jq -r '.shard_key')
# 为每个分片生成表结构
local shard_count=$(jq -r '.shards | length' "$CONFIG_FILE")
for ((i=0; i<shard_count; i++)); do
local shard_config=$(jq -r ".shards[$i]" "$CONFIG_FILE")
local shard_name=$(echo "$shard_config" | jq -r '.database')
# 生成分片表名
local shard_table_name
if [ "$SHARDING_TYPE" = "horizontal" ]; then
shard_table_name="${table}_${i}"
else
shard_table_name="$table"
fi
# 修改表结构SQL
local shard_create_sql=$(echo "$create_sql" | sed "s/CREATE TABLE \`$table\`/CREATE TABLE \`$shard_table_name\`/")
# 添加分片约束(如果是范围分片)
if [ "$shard_strategy" = "range" ]; then
local range_config=$(echo "$table_config" | jq -r ".ranges[$i]")
local min_val=$(echo "$range_config" | jq -r '.min')
local max_val=$(echo "$range_config" | jq -r '.max')
if [ "$min_val" != "null" ] && [ "$max_val" != "null" ]; then
shard_create_sql=$(echo "$shard_create_sql" | sed "s/) ENGINE=/,\n CONSTRAINT chk_${shard_key}_range CHECK ($shard_key BETWEEN $min_val AND $max_val)\n) ENGINE=/")
fi
fi
# 保存SQL到文件
local sql_file="${SQL_DIR}/${shard_name}_${shard_table_name}.sql"
echo "USE \`$shard_name\`;" > "$sql_file"
echo "$shard_create_sql;" >> "$sql_file"
log "生成分片表结构: $shard_name.$shard_table_name"
done
done
}
# 执行分片表创建
execute_shard_creation() {
log "执行分片表创建..."
local shard_count=$(jq -r '.shards | length' "$CONFIG_FILE")
for ((i=0; i<shard_count; i++)); do
local shard_config=$(jq -r ".shards[$i]" "$CONFIG_FILE")
local shard_name=$(echo "$shard_config" | jq -r '.database')
local shard_host=$(echo "$shard_config" | jq -r '.host')
local shard_port=$(echo "$shard_config" | jq -r '.port')
local shard_user=$(echo "$shard_config" | jq -r '.user')
local shard_pass=$(echo "$shard_config" | jq -r '.password')
# 执行该分片的所有SQL文件
for sql_file in "${SQL_DIR}/${shard_name}_"*.sql; do
if [ -f "$sql_file" ]; then
log "执行SQL文件: $sql_file"
mysql -h"$shard_host" -P"$shard_port" -u"$shard_user" -p"$shard_pass" < "$sql_file"
if [ $? -eq 0 ]; then
log "SQL执行成功: $sql_file"
else
error "SQL执行失败: $sql_file"
fi
fi
done
done
}
# 生成数据迁移脚本
generate_migration_scripts() {
log "生成数据迁移脚本..."
local migration_script="${SCRIPT_DIR}/scripts/migrate_data.py"
cat > "$migration_script" << 'EOF'
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表数据迁移脚本
"""
import json
import sys
import pymysql
from typing import Dict, List, Any
import hashlib
from datetime import datetime
class ShardingMigrator:
def __init__(self, config_file: str):
with open(config_file, 'r') as f:
self.config = json.load(f)
self.source_conn = None
self.shard_conns = {}
def connect_source(self):
"""连接源数据库"""
source_config = self.config['source']
self.source_conn = pymysql.connect(
host=source_config['host'],
port=source_config['port'],
user=source_config['user'],
password=source_config['password'],
database=source_config['database'],
charset='utf8mb4'
)
def connect_shards(self):
"""连接所有分片数据库"""
for i, shard_config in enumerate(self.config['shards']):
conn = pymysql.connect(
host=shard_config['host'],
port=shard_config['port'],
user=shard_config['user'],
password=shard_config['password'],
database=shard_config['database'],
charset='utf8mb4'
)
self.shard_conns[i] = conn
def get_shard_index(self, table_config: Dict, row: Dict) -> int:
"""根据分片策略计算分片索引"""
strategy = table_config['strategy']
shard_key = table_config['shard_key']
shard_value = row[shard_key]
if strategy == 'hash':
# 哈希分片
hash_value = int(hashlib.md5(str(shard_value).encode()).hexdigest(), 16)
return hash_value % len(self.config['shards'])
elif strategy == 'range':
# 范围分片
for i, range_config in enumerate(table_config['ranges']):
if range_config['min'] <= shard_value <= range_config['max']:
return i
return 0 # 默认分片
else:
return 0 # 默认分片
def migrate_table(self, table_name: str):
"""迁移单个表"""
print(f"开始迁移表: {table_name}")
# 获取表配置
table_config = None
for table in self.config['tables']:
if table['name'] == table_name:
table_config = table
break
if not table_config:
print(f"未找到表配置: {table_name}")
return
# 查询源表数据
with self.source_conn.cursor(pymysql.cursors.DictCursor) as cursor:
cursor.execute(f"SELECT * FROM `{table_name}`")
batch_size = 1000
batch_count = 0
while True:
rows = cursor.fetchmany(batch_size)
if not rows:
break
batch_count += 1
print(f"处理批次 {batch_count}, 记录数: {len(rows)}")
# 按分片分组数据
shard_data = {}
for row in rows:
shard_index = self.get_shard_index(table_config, row)
if shard_index not in shard_data:
shard_data[shard_index] = []
shard_data[shard_index].append(row)
# 插入到对应分片
for shard_index, shard_rows in shard_data.items():
self.insert_to_shard(table_config, shard_index, shard_rows)
print(f"表 {table_name} 迁移完成")
def insert_to_shard(self, table_config: Dict, shard_index: int, rows: List[Dict]):
"""插入数据到分片"""
if not rows:
return
table_name = table_config['name']
shard_table_name = f"{table_name}_{shard_index}" if self.config['sharding_type'] == 'horizontal' else table_name
conn = self.shard_conns[shard_index]
# 构建插入SQL
columns = list(rows[0].keys())
placeholders = ', '.join(['%s'] * len(columns))
sql = f"INSERT INTO `{shard_table_name}` (`{'`, `'.join(columns)}`) VALUES ({placeholders})"
# 批量插入
values = []
for row in rows:
values.append([row[col] for col in columns])
with conn.cursor() as cursor:
cursor.executemany(sql, values)
conn.commit()
def migrate_all(self):
"""迁移所有表"""
self.connect_source()
self.connect_shards()
try:
for table_config in self.config['tables']:
self.migrate_table(table_config['name'])
finally:
if self.source_conn:
self.source_conn.close()
for conn in self.shard_conns.values():
conn.close()
if __name__ == '__main__':
if len(sys.argv) != 2:
print("用法: python migrate_data.py <config_file>")
sys.exit(1)
migrator = ShardingMigrator(sys.argv[1])
migrator.migrate_all()
EOF
chmod +x "$migration_script"
log "数据迁移脚本生成完成: $migration_script"
}
# 生成分片路由中间件
generate_routing_middleware() {
log "生成分片路由中间件..."
local middleware_script="${SCRIPT_DIR}/scripts/sharding_middleware.py"
cat > "$middleware_script" << 'EOF'
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表路由中间件
"""
import json
import re
import hashlib
from typing import Dict, List, Tuple, Optional
import pymysql
from pymysql.cursors import DictCursor
class ShardingRouter:
def __init__(self, config_file: str):
with open(config_file, 'r') as f:
self.config = json.load(f)
self.connections = {}
self._init_connections()
def _init_connections(self):
"""初始化数据库连接池"""
for i, shard_config in enumerate(self.config['shards']):
self.connections[i] = pymysql.connect(
host=shard_config['host'],
port=shard_config['port'],
user=shard_config['user'],
password=shard_config['password'],
database=shard_config['database'],
charset='utf8mb4',
autocommit=True
)
def parse_sql(self, sql: str) -> Dict:
"""解析SQL语句"""
sql = sql.strip()
# 提取操作类型
operation = sql.split()[0].upper()
# 提取表名
table_pattern = r'(?:FROM|INTO|UPDATE|TABLE)\s+`?([a-zA-Z_][a-zA-Z0-9_]*)`?'
table_match = re.search(table_pattern, sql, re.IGNORECASE)
table_name = table_match.group(1) if table_match else None
# 提取WHERE条件中的分片键值
shard_key_values = {}
for table_config in self.config['tables']:
if table_config['name'] == table_name:
shard_key = table_config['shard_key']
# 查找分片键的值
patterns = [
rf'{shard_key}\s*=\s*([\'\"]?)([^\s\'\";]+)\1',
rf'{shard_key}\s+IN\s*\(([^)]+)\)',
rf'{shard_key}\s+BETWEEN\s+([^\s]+)\s+AND\s+([^\s]+)'
]
for pattern in patterns:
match = re.search(pattern, sql, re.IGNORECASE)
if match:
if 'BETWEEN' in pattern:
shard_key_values[shard_key] = [match.group(1), match.group(2)]
elif 'IN' in pattern:
values = [v.strip().strip('\'"') for v in match.group(1).split(',')]
shard_key_values[shard_key] = values
else:
shard_key_values[shard_key] = [match.group(2)]
break
return {
'operation': operation,
'table': table_name,
'shard_key_values': shard_key_values
}
def get_shard_indices(self, table_name: str, shard_key_values: Dict) -> List[int]:
"""获取需要查询的分片索引"""
# 获取表配置
table_config = None
for table in self.config['tables']:
if table['name'] == table_name:
table_config = table
break
if not table_config:
return list(range(len(self.config['shards']))) # 查询所有分片
shard_key = table_config['shard_key']
strategy = table_config['strategy']
if shard_key not in shard_key_values:
return list(range(len(self.config['shards']))) # 查询所有分片
values = shard_key_values[shard_key]
shard_indices = set()
for value in values:
if strategy == 'hash':
hash_value = int(hashlib.md5(str(value).encode()).hexdigest(), 16)
shard_index = hash_value % len(self.config['shards'])
shard_indices.add(shard_index)
elif strategy == 'range':
for i, range_config in enumerate(table_config['ranges']):
if range_config['min'] <= int(value) <= range_config['max']:
shard_indices.add(i)
break
return list(shard_indices) if shard_indices else [0]
def execute_query(self, sql: str) -> List[Dict]:
"""执行查询"""
parsed = self.parse_sql(sql)
if parsed['operation'] in ['SELECT']:
return self._execute_select(sql, parsed)
elif parsed['operation'] in ['INSERT', 'UPDATE', 'DELETE']:
return self._execute_modify(sql, parsed)
else:
# DDL语句在所有分片执行
return self._execute_on_all_shards(sql)
def _execute_select(self, sql: str, parsed: Dict) -> List[Dict]:
"""执行SELECT查询"""
table_name = parsed['table']
shard_indices = self.get_shard_indices(table_name, parsed['shard_key_values'])
results = []
for shard_index in shard_indices:
# 修改SQL中的表名
shard_table_name = f"{table_name}_{shard_index}" if self.config['sharding_type'] == 'horizontal' else table_name
            shard_sql = sql.replace(f'`{table_name}`', f'`{shard_table_name}`')
            shard_sql = shard_sql.replace(f' {table_name} ', f' {shard_table_name} ')
conn = self.connections[shard_index]
with conn.cursor(DictCursor) as cursor:
cursor.execute(shard_sql)
shard_results = cursor.fetchall()
results.extend(shard_results)
return results
def _execute_modify(self, sql: str, parsed: Dict) -> List[Dict]:
"""执行修改操作"""
table_name = parsed['table']
shard_indices = self.get_shard_indices(table_name, parsed['shard_key_values'])
affected_rows = 0
for shard_index in shard_indices:
# 修改SQL中的表名
shard_table_name = f"{table_name}_{shard_index}" if self.config['sharding_type'] == 'horizontal' else table_name
            shard_sql = sql.replace(f'`{table_name}`', f'`{shard_table_name}`')
            shard_sql = shard_sql.replace(f' {table_name} ', f' {shard_table_name} ')
conn = self.connections[shard_index]
with conn.cursor() as cursor:
cursor.execute(shard_sql)
affected_rows += cursor.rowcount
return [{'affected_rows': affected_rows}]
def _execute_on_all_shards(self, sql: str) -> List[Dict]:
"""在所有分片执行SQL"""
results = []
for shard_index, conn in self.connections.items():
with conn.cursor(DictCursor) as cursor:
cursor.execute(sql)
if cursor.description:
results.extend(cursor.fetchall())
return results
def close_connections(self):
"""关闭所有连接"""
for conn in self.connections.values():
conn.close()
# Flask Web API示例
from flask import Flask, request, jsonify
app = Flask(__name__)
router = ShardingRouter('sharding_config.json')
@app.route('/query', methods=['POST'])
def execute_query():
try:
sql = request.json.get('sql')
if not sql:
return jsonify({'error': 'SQL语句不能为空'}), 400
results = router.execute_query(sql)
return jsonify({'results': results})
except Exception as e:
return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
return jsonify({'status': 'ok'})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8080, debug=True)
EOF
chmod +x "$middleware_script"
log "分片路由中间件生成完成: $middleware_script"
}
# 生成监控脚本
generate_monitoring_scripts() {
log "生成监控脚本..."
local monitor_script="${SCRIPT_DIR}/monitoring/shard_monitor.py"
cat > "$monitor_script" << 'EOF'
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表监控脚本
"""
import json
import time
import pymysql
from datetime import datetime
from typing import Dict, List
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
class ShardingMonitor:
def __init__(self, config_file: str):
with open(config_file, 'r') as f:
self.config = json.load(f)
self.connections = {}
self.alerts = []
def connect_shards(self):
"""连接所有分片"""
for i, shard_config in enumerate(self.config['shards']):
try:
conn = pymysql.connect(
host=shard_config['host'],
port=shard_config['port'],
user=shard_config['user'],
password=shard_config['password'],
database=shard_config['database'],
charset='utf8mb4'
)
self.connections[i] = conn
except Exception as e:
self.alerts.append(f"分片 {i} 连接失败: {e}")
def check_shard_health(self) -> Dict:
"""检查分片健康状态"""
health_status = {
'timestamp': datetime.now().isoformat(),
'total_shards': len(self.config['shards']),
'healthy_shards': 0,
'shard_details': []
}
for i, conn in self.connections.items():
shard_detail = {
'shard_index': i,
'status': 'unknown',
'response_time_ms': 0,
'connections': 0,
'queries_per_second': 0,
'data_size_mb': 0
}
try:
start_time = time.time()
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
# 检查连接状态
cursor.execute("SELECT 1")
# 获取连接数
cursor.execute("SHOW STATUS LIKE 'Threads_connected'")
connections = cursor.fetchone()['Value']
# 获取QPS
cursor.execute("SHOW STATUS LIKE 'Queries'")
queries = cursor.fetchone()['Value']
# 获取数据库大小
cursor.execute(f"""
SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) as size_mb
FROM information_schema.tables
WHERE table_schema = '{self.config['shards'][i]['database']}'
""")
size_result = cursor.fetchone()
data_size = size_result['size_mb'] if size_result['size_mb'] else 0
response_time = (time.time() - start_time) * 1000
shard_detail.update({
'status': 'healthy',
'response_time_ms': round(response_time, 2),
'connections': int(connections),
'data_size_mb': float(data_size)
})
health_status['healthy_shards'] += 1
except Exception as e:
shard_detail['status'] = f'error: {e}'
self.alerts.append(f"分片 {i} 健康检查失败: {e}")
health_status['shard_details'].append(shard_detail)
return health_status
def check_data_balance(self) -> Dict:
"""检查数据平衡性"""
balance_status = {
'timestamp': datetime.now().isoformat(),
'tables': []
}
for table_config in self.config['tables']:
table_name = table_config['name']
table_balance = {
'table_name': table_name,
'shard_counts': [],
'balance_ratio': 0.0,
'recommendation': ''
}
shard_counts = []
for i, conn in self.connections.items():
try:
shard_table_name = f"{table_name}_{i}" if self.config['sharding_type'] == 'horizontal' else table_name
with conn.cursor() as cursor:
cursor.execute(f"SELECT COUNT(*) as count FROM `{shard_table_name}`")
count = cursor.fetchone()[0]
shard_counts.append(count)
except Exception as e:
shard_counts.append(0)
self.alerts.append(f"表 {table_name} 分片 {i} 计数失败: {e}")
table_balance['shard_counts'] = shard_counts
# 计算平衡比例
if shard_counts and max(shard_counts) > 0:
balance_ratio = min(shard_counts) / max(shard_counts)
table_balance['balance_ratio'] = round(balance_ratio, 3)
if balance_ratio < 0.7:
table_balance['recommendation'] = '数据分布不均,建议重新平衡'
self.alerts.append(f"表 {table_name} 数据分布不均,平衡比例: {balance_ratio:.3f}")
else:
table_balance['recommendation'] = '数据分布良好'
balance_status['tables'].append(table_balance)
return balance_status
def check_performance_metrics(self) -> Dict:
"""检查性能指标"""
performance_status = {
'timestamp': datetime.now().isoformat(),
'metrics': []
}
for i, conn in self.connections.items():
shard_metrics = {
'shard_index': i,
'slow_queries': 0,
'avg_query_time': 0.0,
'lock_waits': 0,
'deadlocks': 0
}
try:
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
# 慢查询数量
cursor.execute("SHOW STATUS LIKE 'Slow_queries'")
slow_queries = cursor.fetchone()['Value']
# 锁等待
cursor.execute("SHOW STATUS LIKE 'Table_locks_waited'")
lock_waits = cursor.fetchone()['Value']
# 死锁
cursor.execute("SHOW STATUS LIKE 'Innodb_deadlocks'")
deadlocks = cursor.fetchone()['Value']
shard_metrics.update({
'slow_queries': int(slow_queries),
'lock_waits': int(lock_waits),
'deadlocks': int(deadlocks)
})
# 检查告警阈值
if int(slow_queries) > 100:
self.alerts.append(f"分片 {i} 慢查询过多: {slow_queries}")
if int(deadlocks) > 10:
self.alerts.append(f"分片 {i} 死锁过多: {deadlocks}")
except Exception as e:
self.alerts.append(f"分片 {i} 性能指标获取失败: {e}")
performance_status['metrics'].append(shard_metrics)
return performance_status
def send_alerts(self):
"""发送告警"""
if not self.alerts:
return
# 这里可以集成邮件、短信、钉钉等告警方式
alert_message = "\n".join(self.alerts)
print(f"\n=== 分库分表告警 ===\n{alert_message}\n")
# 清空告警列表
self.alerts = []
def generate_report(self) -> Dict:
"""生成监控报告"""
self.connect_shards()
report = {
'timestamp': datetime.now().isoformat(),
'health_status': self.check_shard_health(),
'data_balance': self.check_data_balance(),
'performance_metrics': self.check_performance_metrics(),
'alerts': self.alerts.copy()
}
# 发送告警
self.send_alerts()
# 关闭连接
for conn in self.connections.values():
conn.close()
return report
def save_report(self, report: Dict, filename: str = None):
"""保存监控报告"""
if not filename:
filename = f"shard_monitor_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(filename, 'w', encoding='utf-8') as f:
json.dump(report, f, ensure_ascii=False, indent=2)
return filename
if __name__ == '__main__':
import sys
if len(sys.argv) != 2:
print("用法: python shard_monitor.py <config_file>")
sys.exit(1)
monitor = ShardingMonitor(sys.argv[1])
report = monitor.generate_report()
# 打印摘要
print(f"\n=== 分库分表监控报告 ===")
print(f"时间: {report['timestamp']}")
print(f"总分片数: {report['health_status']['total_shards']}")
print(f"健康分片数: {report['health_status']['healthy_shards']}")
# 保存报告
report_file = monitor.save_report(report)
print(f"详细报告已保存到: {report_file}")
EOF
chmod +x "$monitor_script"
log "监控脚本生成完成: $monitor_script"
}
# 生成配置文件模板
generate_config_template() {
log "生成配置文件模板..."
local template_file="${SCRIPT_DIR}/sharding_config_template.json"
cat > "$template_file" << 'EOF'
{
"sharding_type": "horizontal",
"source": {
"host": "localhost",
"port": 3306,
"user": "root",
"password": "password",
"database": "original_db"
},
"shards": [
{
"host": "localhost",
"port": 3306,
"user": "root",
"password": "password",
"database": "shard_db_0"
},
{
"host": "localhost",
"port": 3306,
"user": "root",
"password": "password",
"database": "shard_db_1"
}
],
"tables": [
{
"name": "users",
"strategy": "hash",
"shard_key": "user_id"
},
{
"name": "orders",
"strategy": "range",
"shard_key": "order_id",
"ranges": [
{"min": 1, "max": 1000000},
{"min": 1000001, "max": 2000000}
]
}
]
}
EOF
log "配置文件模板生成完成: $template_file"
}
# 验证部署结果
validate_deployment() {
log "验证部署结果..."
local validation_passed=true
# 检查分片数据库连接
local shard_count=$(jq -r '.shards | length' "$CONFIG_FILE")
for ((i=0; i<shard_count; i++)); do
local shard_config=$(jq -r ".shards[$i]" "$CONFIG_FILE")
local shard_name=$(echo "$shard_config" | jq -r '.database')
local shard_host=$(echo "$shard_config" | jq -r '.host')
local shard_port=$(echo "$shard_config" | jq -r '.port')
local shard_user=$(echo "$shard_config" | jq -r '.user')
local shard_pass=$(echo "$shard_config" | jq -r '.password')
log "验证分片数据库: $shard_name"
# 测试连接
if mysql -h"$shard_host" -P"$shard_port" -u"$shard_user" -p"$shard_pass" \
-e "SELECT 1" "$shard_name" >/dev/null 2>&1; then
log "分片数据库连接成功: $shard_name"
else
log "分片数据库连接失败: $shard_name"
validation_passed=false
fi
# 检查表是否创建
local tables=$(jq -r '.tables[] | .name' "$CONFIG_FILE")
for table in $tables; do
local shard_table_name
if [ "$SHARDING_TYPE" = "horizontal" ]; then
shard_table_name="${table}_${i}"
else
shard_table_name="$table"
fi
if mysql -h"$shard_host" -P"$shard_port" -u"$shard_user" -p"$shard_pass" \
-e "DESCRIBE \`$shard_table_name\`" "$shard_name" >/dev/null 2>&1; then
log "分片表验证成功: $shard_name.$shard_table_name"
else
log "分片表验证失败: $shard_name.$shard_table_name"
validation_passed=false
fi
done
done
if [ "$validation_passed" = true ]; then
log "部署验证通过"
return 0
else
error "部署验证失败"
return 1
fi
}
# 生成部署报告
generate_deployment_report() {
log "生成部署报告..."
local report_file="${SCRIPT_DIR}/deployment_report_$(date +%Y%m%d_%H%M%S).md"
cat > "$report_file" << EOF
# MySQL分库分表部署报告
## 部署信息
- 部署时间: $(date '+%Y-%m-%d %H:%M:%S')
- 分片类型: $SHARDING_TYPE
- 源数据库: $SOURCE_DB
- 分片数量: $(jq -r '.shards | length' "$CONFIG_FILE")
## 分片配置
$(jq -r '.shards[] | "- 数据库: " + .database + " @ " + .host + ":" + (.port | tostring)' "$CONFIG_FILE")
## 表配置
$(jq -r '.tables[] | "- 表: " + .name + ", 策略: " + .strategy + ", 分片键: " + .shard_key' "$CONFIG_FILE")
## 生成的文件
- 配置文件: $CONFIG_FILE
- SQL文件目录: $SQL_DIR
- 脚本目录: ${SCRIPT_DIR}/scripts
- 监控脚本: ${SCRIPT_DIR}/monitoring
- 备份文件: $BACKUP_DIR
## 下一步操作
1. 执行数据迁移: python scripts/migrate_data.py sharding_config.json
2. 启动路由中间件: python scripts/sharding_middleware.py
3. 运行监控脚本: python monitoring/shard_monitor.py sharding_config.json
4. 验证数据一致性
5. 切换应用程序连接
## 注意事项
- 在生产环境切换前,请充分测试
- 建议在低峰期进行数据迁移
- 保持原始数据备份直到确认迁移成功
- 监控分片性能和数据平衡性
EOF
log "部署报告生成完成: $report_file"
}
# 主函数
main() {
log "开始MySQL分库分表自动化部署"
# 检查参数
if [ $# -lt 1 ]; then
echo "用法: $0 <config_file> [--skip-backup] [--skip-migration]"
echo "选项:"
echo " --skip-backup 跳过数据备份"
echo " --skip-migration 跳过数据迁移"
exit 1
fi
CONFIG_FILE="$1"
SKIP_BACKUP=false
SKIP_MIGRATION=false
# 解析选项
shift
while [ $# -gt 0 ]; do
case $1 in
--skip-backup)
SKIP_BACKUP=true
shift
;;
--skip-migration)
SKIP_MIGRATION=true
shift
;;
*)
echo "未知选项: $1"
exit 1
;;
esac
done
# 执行部署步骤
check_dependencies
setup_directories
read_config
if [ "$SKIP_BACKUP" = false ]; then
backup_original_data
fi
create_shard_databases
generate_shard_schema
execute_shard_creation
generate_migration_scripts
generate_routing_middleware
generate_monitoring_scripts
generate_config_template
validate_deployment
generate_deployment_report
log "MySQL分库分表部署完成"
if [ "$SKIP_MIGRATION" = false ]; then
log "开始数据迁移..."
python3 "${SCRIPT_DIR}/scripts/migrate_data.py" "$CONFIG_FILE"
log "数据迁移完成"
fi
log "部署成功!请查看部署报告了解下一步操作。"
}
# 执行主函数
main "$@"
```

### 14.4.2 分库分表中间件

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表中间件
支持多种分片策略和数据库连接池管理
"""
import json
import hashlib
import threading
import time
from typing import Dict, List, Any, Optional, Tuple
import pymysql
from pymysql.cursors import DictCursor
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from dataclasses import dataclass
from enum import Enum
class ShardStrategy(Enum):
HASH = "hash"
RANGE = "range"
DIRECTORY = "directory"
CONSISTENT_HASH = "consistent_hash"
@dataclass
class ShardConfig:
host: str
port: int
user: str
password: str
database: str
max_connections: int = 10
@dataclass
class TableConfig:
name: str
strategy: ShardStrategy
shard_key: str
ranges: Optional[List[Dict]] = None
hash_slots: Optional[int] = None
class ConnectionPool:
"""数据库连接池"""
def __init__(self, config: ShardConfig):
self.config = config
self.pool = []
self.used = set()
self.lock = threading.Lock()
self._init_pool()
def _init_pool(self):
"""初始化连接池"""
for _ in range(self.config.max_connections):
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
database=self.config.database,
charset='utf8mb4',
autocommit=True
)
self.pool.append(conn)
def get_connection(self):
"""获取连接"""
with self.lock:
if self.pool:
conn = self.pool.pop()
self.used.add(conn)
return conn
else:
# 如果池中没有连接,创建新连接
conn = pymysql.connect(
host=self.config.host,
port=self.config.port,
user=self.config.user,
password=self.config.password,
database=self.config.database,
charset='utf8mb4',
autocommit=True
)
self.used.add(conn)
return conn
def return_connection(self, conn):
"""归还连接"""
with self.lock:
if conn in self.used:
self.used.remove(conn)
if len(self.pool) < self.config.max_connections:
self.pool.append(conn)
else:
conn.close()
def close_all(self):
"""关闭所有连接"""
with self.lock:
for conn in self.pool + list(self.used):
conn.close()
self.pool.clear()
self.used.clear()
class ConsistentHash:
"""一致性哈希实现"""
def __init__(self, nodes: List[int], replicas: int = 150):
self.replicas = replicas
self.ring = {}
self.sorted_keys = []
for node in nodes:
self.add_node(node)
def _hash(self, key: str) -> int:
"""计算哈希值"""
return int(hashlib.md5(key.encode()).hexdigest(), 16)
def add_node(self, node: int):
"""添加节点"""
for i in range(self.replicas):
key = self._hash(f"{node}:{i}")
self.ring[key] = node
self.sorted_keys = sorted(self.ring.keys())
def remove_node(self, node: int):
"""移除节点"""
for i in range(self.replicas):
key = self._hash(f"{node}:{i}")
if key in self.ring:
del self.ring[key]
self.sorted_keys = sorted(self.ring.keys())
def get_node(self, key: str) -> int:
"""获取键对应的节点"""
if not self.ring:
return 0
hash_key = self._hash(key)
# 找到第一个大于等于hash_key的节点
for ring_key in self.sorted_keys:
if ring_key >= hash_key:
return self.ring[ring_key]
# 如果没找到,返回第一个节点(环形结构)
return self.ring[self.sorted_keys[0]]
class ShardingMiddleware:
"""分库分表中间件"""
def __init__(self, config_file: str):
self.logger = self._setup_logger()
self.config = self._load_config(config_file)
self.connection_pools = {}
self.table_configs = {}
self.consistent_hash = None
self._init_connection_pools()
self._init_table_configs()
self._init_consistent_hash()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('ShardingMiddleware')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _load_config(self, config_file: str) -> Dict:
"""加载配置文件"""
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
def _init_connection_pools(self):
"""初始化连接池"""
for i, shard_config in enumerate(self.config['shards']):
config = ShardConfig(
host=shard_config['host'],
port=shard_config['port'],
user=shard_config['user'],
password=shard_config['password'],
database=shard_config['database'],
max_connections=shard_config.get('max_connections', 10)
)
self.connection_pools[i] = ConnectionPool(config)
def _init_table_configs(self):
"""初始化表配置"""
for table_config in self.config['tables']:
config = TableConfig(
name=table_config['name'],
strategy=ShardStrategy(table_config['strategy']),
shard_key=table_config['shard_key'],
ranges=table_config.get('ranges'),
hash_slots=table_config.get('hash_slots', 1024)
)
self.table_configs[table_config['name']] = config
def _init_consistent_hash(self):
"""初始化一致性哈希"""
nodes = list(range(len(self.config['shards'])))
self.consistent_hash = ConsistentHash(nodes)
def get_shard_index(self, table_name: str, shard_key_value: Any) -> int:
"""获取分片索引"""
if table_name not in self.table_configs:
return 0
table_config = self.table_configs[table_name]
strategy = table_config.strategy
if strategy == ShardStrategy.HASH:
hash_value = int(hashlib.md5(str(shard_key_value).encode()).hexdigest(), 16)
return hash_value % len(self.config['shards'])
elif strategy == ShardStrategy.RANGE:
if table_config.ranges:
for i, range_config in enumerate(table_config.ranges):
if range_config['min'] <= shard_key_value <= range_config['max']:
return i
return 0
elif strategy == ShardStrategy.CONSISTENT_HASH:
return self.consistent_hash.get_node(str(shard_key_value))
else:
return 0
def execute_query(self, sql: str, params: Tuple = None) -> List[Dict]:
"""执行查询"""
try:
parsed = self._parse_sql(sql)
if parsed['operation'] in ['SELECT']:
return self._execute_select(sql, parsed, params)
elif parsed['operation'] in ['INSERT', 'UPDATE', 'DELETE']:
return self._execute_modify(sql, parsed, params)
else:
return self._execute_ddl(sql, params)
except Exception as e:
self.logger.error(f"执行查询失败: {e}")
raise
def _parse_sql(self, sql: str) -> Dict:
"""解析SQL语句"""
import re
sql = sql.strip()
operation = sql.split()[0].upper()
# 提取表名
table_patterns = [
r'FROM\s+`?([a-zA-Z_][a-zA-Z0-9_]*)`?',
r'INTO\s+`?([a-zA-Z_][a-zA-Z0-9_]*)`?',
r'UPDATE\s+`?([a-zA-Z_][a-zA-Z0-9_]*)`?',
r'TABLE\s+`?([a-zA-Z_][a-zA-Z0-9_]*)`?'
]
table_name = None
for pattern in table_patterns:
match = re.search(pattern, sql, re.IGNORECASE)
if match:
table_name = match.group(1)
break
# 提取分片键值
shard_key_values = {}
if table_name and table_name in self.table_configs:
shard_key = self.table_configs[table_name].shard_key
# 简单的分片键值提取(可以根据需要扩展)
patterns = [
rf'{shard_key}\s*=\s*([\'\"]?)([^\s\'\";]+)\1',
rf'{shard_key}\s+IN\s*\(([^)]+)\)'
]
for pattern in patterns:
match = re.search(pattern, sql, re.IGNORECASE)
if match:
if 'IN' in pattern:
values = [v.strip().strip('\'"') for v in match.group(1).split(',')]
shard_key_values[shard_key] = values
else:
shard_key_values[shard_key] = [match.group(2)]
break
return {
'operation': operation,
'table': table_name,
'shard_key_values': shard_key_values
}
def _execute_select(self, sql: str, parsed: Dict, params: Tuple = None) -> List[Dict]:
"""执行SELECT查询"""
table_name = parsed['table']
shard_key_values = parsed['shard_key_values']
# 确定需要查询的分片
if table_name in self.table_configs and shard_key_values:
shard_key = self.table_configs[table_name].shard_key
if shard_key in shard_key_values:
shard_indices = []
for value in shard_key_values[shard_key]:
shard_index = self.get_shard_index(table_name, value)
if shard_index not in shard_indices:
shard_indices.append(shard_index)
else:
shard_indices = list(range(len(self.config['shards'])))
else:
shard_indices = list(range(len(self.config['shards'])))
# 并行查询多个分片
results = []
with ThreadPoolExecutor(max_workers=len(shard_indices)) as executor:
futures = []
for shard_index in shard_indices:
future = executor.submit(
self._execute_on_shard,
shard_index,
self._modify_sql_for_shard(sql, table_name, shard_index),
params
)
futures.append(future)
for future in as_completed(futures):
try:
shard_results = future.result()
if shard_results:
results.extend(shard_results)
except Exception as e:
self.logger.error(f"分片查询失败: {e}")
return results
def _execute_modify(self, sql: str, parsed: Dict, params: Tuple = None) -> List[Dict]:
"""执行修改操作"""
table_name = parsed['table']
shard_key_values = parsed['shard_key_values']
# 确定需要修改的分片
if table_name in self.table_configs and shard_key_values:
shard_key = self.table_configs[table_name].shard_key
if shard_key in shard_key_values:
shard_indices = []
for value in shard_key_values[shard_key]:
shard_index = self.get_shard_index(table_name, value)
if shard_index not in shard_indices:
shard_indices.append(shard_index)
else:
shard_indices = list(range(len(self.config['shards'])))
else:
shard_indices = list(range(len(self.config['shards'])))
total_affected = 0
for shard_index in shard_indices:
try:
result = self._execute_on_shard(
shard_index,
self._modify_sql_for_shard(sql, table_name, shard_index),
params,
return_affected=True
)
if result:
total_affected += result
except Exception as e:
self.logger.error(f"分片修改失败: {e}")
raise
return [{'affected_rows': total_affected}]
def _execute_ddl(self, sql: str, params: Tuple = None) -> List[Dict]:
"""执行DDL语句"""
results = []
for shard_index in range(len(self.config['shards'])):
try:
result = self._execute_on_shard(shard_index, sql, params)
if result:
results.extend(result)
except Exception as e:
self.logger.error(f"DDL执行失败: {e}")
raise
return results
def _execute_on_shard(self, shard_index: int, sql: str, params: Tuple = None, return_affected: bool = False):
"""在指定分片执行SQL"""
pool = self.connection_pools[shard_index]
conn = pool.get_connection()
try:
with conn.cursor(DictCursor) as cursor:
cursor.execute(sql, params)
if return_affected:
return cursor.rowcount
elif cursor.description:
return cursor.fetchall()
else:
return None
finally:
pool.return_connection(conn)
def _modify_sql_for_shard(self, sql: str, table_name: str, shard_index: int) -> str:
"""修改SQL以适应分片表名"""
if not table_name or self.config['sharding_type'] != 'horizontal':
return sql
shard_table_name = f"{table_name}_{shard_index}"
# 替换表名
import re
patterns = [
(rf'\b{table_name}\b', shard_table_name),
(rf'`{table_name}`', f'`{shard_table_name}`')
]
modified_sql = sql
for pattern, replacement in patterns:
modified_sql = re.sub(pattern, replacement, modified_sql, flags=re.IGNORECASE)
return modified_sql
def begin_transaction(self) -> str:
"""开始分布式事务"""
transaction_id = f"txn_{int(time.time() * 1000)}_{threading.current_thread().ident}"
# 在所有分片开始事务
for shard_index in range(len(self.config['shards'])):
try:
self._execute_on_shard(shard_index, "BEGIN")
except Exception as e:
self.logger.error(f"开始事务失败: {e}")
# 回滚已开始的事务
self.rollback_transaction(transaction_id)
raise
return transaction_id
def commit_transaction(self, transaction_id: str):
"""提交分布式事务"""
failed_shards = []
# 两阶段提交:准备阶段
for shard_index in range(len(self.config['shards'])):
try:
# 这里可以实现XA事务的PREPARE阶段
pass
except Exception as e:
failed_shards.append(shard_index)
self.logger.error(f"事务准备失败: {e}")
if failed_shards:
self.rollback_transaction(transaction_id)
raise Exception(f"事务准备失败,分片: {failed_shards}")
# 提交阶段
for shard_index in range(len(self.config['shards'])):
try:
self._execute_on_shard(shard_index, "COMMIT")
except Exception as e:
self.logger.error(f"事务提交失败: {e}")
# 注意:这里可能需要补偿机制
def rollback_transaction(self, transaction_id: str):
"""回滚分布式事务"""
for shard_index in range(len(self.config['shards'])):
try:
self._execute_on_shard(shard_index, "ROLLBACK")
except Exception as e:
self.logger.error(f"事务回滚失败: {e}")
def get_statistics(self) -> Dict:
"""获取统计信息"""
stats = {
'timestamp': time.time(),
'shards': [],
'total_connections': 0,
'active_connections': 0
}
for shard_index, pool in self.connection_pools.items():
shard_stats = {
'shard_index': shard_index,
'pool_size': len(pool.pool),
'active_connections': len(pool.used),
'total_connections': len(pool.pool) + len(pool.used)
}
stats['shards'].append(shard_stats)
stats['total_connections'] += shard_stats['total_connections']
stats['active_connections'] += shard_stats['active_connections']
return stats
def close(self):
"""关闭中间件"""
for pool in self.connection_pools.values():
pool.close_all()
self.logger.info("分库分表中间件已关闭")
# Web API接口
from flask import Flask, request, jsonify
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
# 全局中间件实例
# 注:Flask 2.3 起已移除 before_first_request,这里改为在模块加载时直接初始化
middleware = ShardingMiddleware('sharding_config.json')
@app.route('/api/query', methods=['POST'])
def execute_query():
try:
data = request.get_json()
sql = data.get('sql')
params = data.get('params')
if not sql:
return jsonify({'error': 'SQL语句不能为空'}), 400
results = middleware.execute_query(sql, params)
return jsonify({
'success': True,
'results': results,
'count': len(results)
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/transaction/begin', methods=['POST'])
def begin_transaction():
try:
transaction_id = middleware.begin_transaction()
return jsonify({
'success': True,
'transaction_id': transaction_id
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/transaction/commit', methods=['POST'])
def commit_transaction():
try:
data = request.get_json()
transaction_id = data.get('transaction_id')
if not transaction_id:
return jsonify({'error': '事务ID不能为空'}), 400
middleware.commit_transaction(transaction_id)
return jsonify({'success': True})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/transaction/rollback', methods=['POST'])
def rollback_transaction():
try:
data = request.get_json()
transaction_id = data.get('transaction_id')
if not transaction_id:
return jsonify({'error': '事务ID不能为空'}), 400
middleware.rollback_transaction(transaction_id)
return jsonify({'success': True})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/statistics', methods=['GET'])
def get_statistics():
try:
stats = middleware.get_statistics()
return jsonify({
'success': True,
'statistics': stats
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/health', methods=['GET'])
def health_check():
return jsonify({
'status': 'healthy',
'timestamp': time.time()
})
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080, debug=True)
```
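中间件以 HTTP 接口对外提供服务,应用侧可以像下面这样调用(requests 库与 8080 端口均为示例假设,接口路径对应上面定义的 Flask 路由):

```python
import requests

BASE_URL = "http://127.0.0.1:8080"

# 带分片键的查询只会被路由到对应分片
resp = requests.post(f"{BASE_URL}/api/query", json={
    "sql": "SELECT * FROM users WHERE user_id = 10086"
})
print(resp.json())

# 查看各分片连接池的统计信息
print(requests.get(f"{BASE_URL}/api/statistics").json())
```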
## 14.5 数据一致性保证

### 14.5.1 分布式事务管理

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分布式事务管理器
支持两阶段提交(2PC)、三阶段提交(3PC)和Saga模式
"""
import json
import time
import uuid
import threading
from typing import Dict, List, Any, Optional, Callable
from enum import Enum
from dataclasses import dataclass, asdict
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from datetime import datetime, timedelta
class TransactionState(Enum):
INIT = "init"
PREPARING = "preparing"
PREPARED = "prepared"
COMMITTING = "committing"
COMMITTED = "committed"
ABORTING = "aborting"
ABORTED = "aborted"
TIMEOUT = "timeout"
class TransactionMode(Enum):
TWO_PHASE_COMMIT = "2pc"
THREE_PHASE_COMMIT = "3pc"
SAGA = "saga"
BEST_EFFORT = "best_effort"
@dataclass
class TransactionParticipant:
shard_id: str
host: str
port: int
database: str
user: str
password: str
state: TransactionState = TransactionState.INIT
xa_id: Optional[str] = None
error: Optional[str] = None
@dataclass
class SagaStep:
step_id: str
forward_sql: str
compensate_sql: str
participant: str
executed: bool = False
compensated: bool = False
error: Optional[str] = None
@dataclass
class DistributedTransaction:
transaction_id: str
mode: TransactionMode
state: TransactionState
participants: List[TransactionParticipant]
saga_steps: Optional[List[SagaStep]] = None
created_at: datetime = None
timeout_seconds: int = 300
coordinator_id: str = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now()
if self.coordinator_id is None:
self.coordinator_id = str(uuid.uuid4())
class DistributedTransactionManager:
"""分布式事务管理器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
self.active_transactions: Dict[str, DistributedTransaction] = {}
self.transaction_log = []
self.lock = threading.RLock()
# 启动超时检查线程
self.timeout_checker = threading.Thread(target=self._timeout_checker, daemon=True)
self.timeout_checker.start()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('DistributedTransactionManager')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def begin_transaction(self, participants: List[Dict], mode: TransactionMode = TransactionMode.TWO_PHASE_COMMIT, timeout: int = 300) -> str:
"""开始分布式事务"""
transaction_id = str(uuid.uuid4())
# 创建参与者列表
participant_objects = []
for p in participants:
participant = TransactionParticipant(
shard_id=p['shard_id'],
host=p['host'],
port=p['port'],
database=p['database'],
user=p['user'],
password=p['password']
)
participant_objects.append(participant)
# 创建分布式事务
transaction = DistributedTransaction(
transaction_id=transaction_id,
mode=mode,
state=TransactionState.INIT,
participants=participant_objects,
timeout_seconds=timeout
)
with self.lock:
self.active_transactions[transaction_id] = transaction
self.transaction_log.append({
'timestamp': datetime.now(),
'transaction_id': transaction_id,
'action': 'BEGIN',
'mode': mode.value,
'participants': len(participant_objects)
})
self.logger.info(f"开始分布式事务: {transaction_id}, 模式: {mode.value}")
return transaction_id
def execute_sql(self, transaction_id: str, sql_commands: Dict[str, str]) -> bool:
"""在事务中执行SQL命令"""
with self.lock:
if transaction_id not in self.active_transactions:
raise ValueError(f"事务不存在: {transaction_id}")
transaction = self.active_transactions[transaction_id]
if transaction.mode == TransactionMode.SAGA:
return self._execute_saga(transaction, sql_commands)
else:
return self._execute_xa(transaction, sql_commands)
def _execute_xa(self, transaction: DistributedTransaction, sql_commands: Dict[str, str]) -> bool:
"""执行XA事务"""
transaction.state = TransactionState.PREPARING
# 为每个参与者生成XA ID
for participant in transaction.participants:
participant.xa_id = f"xa_{transaction.transaction_id}_{participant.shard_id}"
# 在所有参与者上开始XA事务
failed_participants = []
with ThreadPoolExecutor(max_workers=len(transaction.participants)) as executor:
futures = []
for participant in transaction.participants:
if participant.shard_id in sql_commands:
future = executor.submit(
self._execute_xa_on_participant,
participant,
sql_commands[participant.shard_id]
)
futures.append((future, participant))
for future, participant in futures:
try:
success = future.result(timeout=30)
if not success:
failed_participants.append(participant)
participant.state = TransactionState.ABORTED
except Exception as e:
self.logger.error(f"参与者 {participant.shard_id} 执行失败: {e}")
participant.error = str(e)
participant.state = TransactionState.ABORTED
failed_participants.append(participant)
if failed_participants:
self.logger.error(f"事务 {transaction.transaction_id} 执行失败,失败参与者: {[p.shard_id for p in failed_participants]}")
transaction.state = TransactionState.ABORTED
return False
return True
def _execute_xa_on_participant(self, participant: TransactionParticipant, sql: str) -> bool:
"""在单个参与者上执行XA事务"""
conn = None
try:
conn = pymysql.connect(
host=participant.host,
port=participant.port,
user=participant.user,
password=participant.password,
database=participant.database,
autocommit=False
)
with conn.cursor() as cursor:
# 开始XA事务
cursor.execute(f"XA START '{participant.xa_id}'")
# 执行业务SQL
cursor.execute(sql)
# 结束XA事务
cursor.execute(f"XA END '{participant.xa_id}'")
# 准备XA事务
cursor.execute(f"XA PREPARE '{participant.xa_id}'")
participant.state = TransactionState.PREPARED
return True
except Exception as e:
self.logger.error(f"XA事务执行失败: {e}")
participant.error = str(e)
if conn:
try:
with conn.cursor() as cursor:
cursor.execute(f"XA ROLLBACK '{participant.xa_id}'")
except:
pass
return False
finally:
if conn:
conn.close()
def _execute_saga(self, transaction: DistributedTransaction, sql_commands: Dict[str, str]) -> bool:
"""执行Saga事务"""
if not transaction.saga_steps:
# 如果没有预定义的Saga步骤,根据SQL命令创建
transaction.saga_steps = []
for shard_id, sql in sql_commands.items():
step = SagaStep(
step_id=str(uuid.uuid4()),
forward_sql=sql,
compensate_sql=self._generate_compensate_sql(sql),
participant=shard_id
)
transaction.saga_steps.append(step)
# 按顺序执行Saga步骤
for step in transaction.saga_steps:
try:
participant = next(p for p in transaction.participants if p.shard_id == step.participant)
success = self._execute_saga_step(participant, step)
if not success:
# 执行补偿操作
self._compensate_saga(transaction, step)
transaction.state = TransactionState.ABORTED
return False
step.executed = True
except Exception as e:
self.logger.error(f"Saga步骤执行失败: {e}")
step.error = str(e)
self._compensate_saga(transaction, step)
transaction.state = TransactionState.ABORTED
return False
transaction.state = TransactionState.COMMITTED
return True
def _execute_saga_step(self, participant: TransactionParticipant, step: SagaStep) -> bool:
"""执行单个Saga步骤"""
conn = None
try:
conn = pymysql.connect(
host=participant.host,
port=participant.port,
user=participant.user,
password=participant.password,
database=participant.database,
autocommit=True
)
with conn.cursor() as cursor:
cursor.execute(step.forward_sql)
return True
except Exception as e:
self.logger.error(f"Saga步骤执行失败: {e}")
step.error = str(e)
return False
finally:
if conn:
conn.close()
def _compensate_saga(self, transaction: DistributedTransaction, failed_step: SagaStep):
"""执行Saga补偿"""
# 逆序执行已完成步骤的补偿操作
executed_steps = [step for step in transaction.saga_steps if step.executed and step != failed_step]
executed_steps.reverse()
for step in executed_steps:
try:
participant = next(p for p in transaction.participants if p.shard_id == step.participant)
self._execute_compensate_step(participant, step)
step.compensated = True
except Exception as e:
self.logger.error(f"补偿操作失败: {e}")
step.error = str(e)
def _execute_compensate_step(self, participant: TransactionParticipant, step: SagaStep) -> bool:
"""执行补偿步骤"""
conn = None
try:
conn = pymysql.connect(
host=participant.host,
port=participant.port,
user=participant.user,
password=participant.password,
database=participant.database,
autocommit=True
)
with conn.cursor() as cursor:
cursor.execute(step.compensate_sql)
return True
except Exception as e:
self.logger.error(f"补偿步骤执行失败: {e}")
return False
finally:
if conn:
conn.close()
def _generate_compensate_sql(self, forward_sql: str) -> str:
"""生成补偿SQL(简化实现)"""
# 这里是一个简化的实现,实际应用中需要更复杂的逻辑
forward_sql = forward_sql.strip().upper()
if forward_sql.startswith('INSERT'):
# INSERT的补偿是DELETE
return "-- DELETE compensation for INSERT"
elif forward_sql.startswith('UPDATE'):
# UPDATE的补偿是恢复原值
return "-- UPDATE compensation for UPDATE"
elif forward_sql.startswith('DELETE'):
# DELETE的补偿是INSERT
return "-- INSERT compensation for DELETE"
else:
return "-- No compensation defined"
def commit_transaction(self, transaction_id: str) -> bool:
"""提交分布式事务"""
with self.lock:
if transaction_id not in self.active_transactions:
raise ValueError(f"事务不存在: {transaction_id}")
transaction = self.active_transactions[transaction_id]
if transaction.mode == TransactionMode.SAGA:
# Saga模式在执行时已经提交
return transaction.state == TransactionState.COMMITTED
# XA事务的两阶段提交
return self._commit_xa_transaction(transaction)
def _commit_xa_transaction(self, transaction: DistributedTransaction) -> bool:
"""提交XA事务"""
transaction.state = TransactionState.COMMITTING
# 检查所有参与者是否都已准备好
unprepared = [p for p in transaction.participants if p.state != TransactionState.PREPARED]
if unprepared:
self.logger.error(f"事务 {transaction.transaction_id} 有未准备的参与者: {[p.shard_id for p in unprepared]}")
            self.rollback_transaction(transaction.transaction_id)
            return False  # 存在未准备好的参与者,提交失败
# 提交所有参与者
failed_commits = []
with ThreadPoolExecutor(max_workers=len(transaction.participants)) as executor:
futures = []
for participant in transaction.participants:
future = executor.submit(self._commit_xa_participant, participant)
futures.append((future, participant))
for future, participant in futures:
try:
success = future.result(timeout=30)
if not success:
failed_commits.append(participant)
except Exception as e:
self.logger.error(f"参与者 {participant.shard_id} 提交失败: {e}")
participant.error = str(e)
failed_commits.append(participant)
if failed_commits:
self.logger.error(f"事务 {transaction.transaction_id} 提交失败,失败参与者: {[p.shard_id for p in failed_commits]}")
transaction.state = TransactionState.ABORTED
return False
transaction.state = TransactionState.COMMITTED
# 记录事务日志
with self.lock:
self.transaction_log.append({
'timestamp': datetime.now(),
'transaction_id': transaction.transaction_id,
'action': 'COMMIT',
'state': transaction.state.value
})
# 清理已完成的事务
del self.active_transactions[transaction.transaction_id]
self.logger.info(f"事务 {transaction.transaction_id} 提交成功")
return True
def _commit_xa_participant(self, participant: TransactionParticipant) -> bool:
"""提交单个XA参与者"""
conn = None
try:
conn = pymysql.connect(
host=participant.host,
port=participant.port,
user=participant.user,
password=participant.password,
database=participant.database
)
with conn.cursor() as cursor:
cursor.execute(f"XA COMMIT '{participant.xa_id}'")
participant.state = TransactionState.COMMITTED
return True
except Exception as e:
self.logger.error(f"XA提交失败: {e}")
participant.error = str(e)
return False
finally:
if conn:
conn.close()
def rollback_transaction(self, transaction_id: str) -> bool:
"""回滚分布式事务"""
with self.lock:
if transaction_id not in self.active_transactions:
raise ValueError(f"事务不存在: {transaction_id}")
transaction = self.active_transactions[transaction_id]
transaction.state = TransactionState.ABORTING
if transaction.mode == TransactionMode.SAGA:
# Saga模式执行补偿
if transaction.saga_steps:
executed_steps = [step for step in transaction.saga_steps if step.executed]
executed_steps.reverse()
for step in executed_steps:
try:
participant = next(p for p in transaction.participants if p.shard_id == step.participant)
self._execute_compensate_step(participant, step)
step.compensated = True
except Exception as e:
self.logger.error(f"补偿操作失败: {e}")
else:
# XA事务回滚
with ThreadPoolExecutor(max_workers=len(transaction.participants)) as executor:
futures = []
for participant in transaction.participants:
if participant.xa_id:
future = executor.submit(self._rollback_xa_participant, participant)
futures.append(future)
for future in futures:
try:
future.result(timeout=30)
except Exception as e:
self.logger.error(f"XA回滚失败: {e}")
transaction.state = TransactionState.ABORTED
# 记录事务日志
with self.lock:
self.transaction_log.append({
'timestamp': datetime.now(),
'transaction_id': transaction.transaction_id,
'action': 'ROLLBACK',
'state': transaction.state.value
})
# 清理已回滚的事务
del self.active_transactions[transaction.transaction_id]
self.logger.info(f"事务 {transaction.transaction_id} 回滚成功")
return True
def _rollback_xa_participant(self, participant: TransactionParticipant) -> bool:
"""回滚单个XA参与者"""
conn = None
try:
conn = pymysql.connect(
host=participant.host,
port=participant.port,
user=participant.user,
password=participant.password,
database=participant.database
)
with conn.cursor() as cursor:
cursor.execute(f"XA ROLLBACK '{participant.xa_id}'")
participant.state = TransactionState.ABORTED
return True
except Exception as e:
self.logger.error(f"XA回滚失败: {e}")
participant.error = str(e)
return False
finally:
if conn:
conn.close()
def _timeout_checker(self):
"""超时检查线程"""
while True:
try:
time.sleep(10) # 每10秒检查一次
current_time = datetime.now()
timeout_transactions = []
with self.lock:
for transaction_id, transaction in self.active_transactions.items():
if current_time - transaction.created_at > timedelta(seconds=transaction.timeout_seconds):
timeout_transactions.append(transaction_id)
# 处理超时事务
for transaction_id in timeout_transactions:
self.logger.warning(f"事务 {transaction_id} 超时,自动回滚")
try:
self.rollback_transaction(transaction_id)
except Exception as e:
self.logger.error(f"超时事务回滚失败: {e}")
except Exception as e:
self.logger.error(f"超时检查失败: {e}")
def get_transaction_status(self, transaction_id: str) -> Optional[Dict]:
"""获取事务状态"""
with self.lock:
if transaction_id not in self.active_transactions:
return None
transaction = self.active_transactions[transaction_id]
return {
'transaction_id': transaction.transaction_id,
'mode': transaction.mode.value,
'state': transaction.state.value,
'created_at': transaction.created_at.isoformat(),
'timeout_seconds': transaction.timeout_seconds,
'participants': [
{
'shard_id': p.shard_id,
'state': p.state.value,
'error': p.error
} for p in transaction.participants
],
'saga_steps': [
{
'step_id': s.step_id,
'participant': s.participant,
'executed': s.executed,
'compensated': s.compensated,
'error': s.error
} for s in transaction.saga_steps
] if transaction.saga_steps else None
}
def get_statistics(self) -> Dict:
"""获取统计信息"""
with self.lock:
active_count = len(self.active_transactions)
total_count = len(self.transaction_log)
# 统计各种状态的事务数量
state_counts = {}
for transaction in self.active_transactions.values():
state = transaction.state.value
state_counts[state] = state_counts.get(state, 0) + 1
# 统计各种模式的事务数量
mode_counts = {}
for transaction in self.active_transactions.values():
mode = transaction.mode.value
mode_counts[mode] = mode_counts.get(mode, 0) + 1
return {
'timestamp': datetime.now().isoformat(),
'active_transactions': active_count,
'total_transactions': total_count,
'state_distribution': state_counts,
'mode_distribution': mode_counts,
'recent_logs': self.transaction_log[-10:] # 最近10条日志
}
# 使用示例
if __name__ == '__main__':
# 配置
config = {
'coordinator_id': 'coord_001',
'log_level': 'INFO'
}
# 创建事务管理器
tm = DistributedTransactionManager(config)
# 定义参与者
participants = [
{
'shard_id': 'shard_0',
'host': 'localhost',
'port': 3306,
'database': 'test_shard_0',
'user': 'root',
'password': 'password'
},
{
'shard_id': 'shard_1',
'host': 'localhost',
'port': 3307,
'database': 'test_shard_1',
'user': 'root',
'password': 'password'
}
]
try:
# 开始2PC事务
txn_id = tm.begin_transaction(participants, TransactionMode.TWO_PHASE_COMMIT)
print(f"开始事务: {txn_id}")
# 执行SQL
sql_commands = {
'shard_0': "INSERT INTO users (id, name) VALUES (1, 'Alice')",
'shard_1': "INSERT INTO orders (id, user_id, amount) VALUES (1, 1, 100.00)"
}
success = tm.execute_sql(txn_id, sql_commands)
if success:
# 提交事务
commit_success = tm.commit_transaction(txn_id)
print(f"事务提交: {'成功' if commit_success else '失败'}")
else:
# 回滚事务
rollback_success = tm.rollback_transaction(txn_id)
print(f"事务回滚: {'成功' if rollback_success else '失败'}")
# 获取统计信息
stats = tm.get_statistics()
print(f"统计信息: {json.dumps(stats, indent=2, ensure_ascii=False)}")
except Exception as e:
print(f"事务执行失败: {e}")
### 14.5.2 数据一致性检查
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表数据一致性检查工具
支持数据校验、修复和监控
"""
import json
import hashlib
import time
from typing import Dict, List, Any, Optional, Tuple, Set
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from dataclasses import dataclass
from datetime import datetime
import threading
@dataclass
class ConsistencyCheckResult:
table_name: str
shard_id: str
check_type: str
status: str # 'consistent', 'inconsistent', 'error'
details: Dict[str, Any]
timestamp: datetime
error_message: Optional[str] = None
@dataclass
class DataInconsistency:
table_name: str
primary_key: Any
shard_ids: List[str]
inconsistency_type: str # 'missing', 'different', 'duplicate'
details: Dict[str, Any]
severity: str # 'low', 'medium', 'high', 'critical'
auto_fixable: bool = False
class DataConsistencyChecker:
"""数据一致性检查器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
self.shard_connections = {}
self.check_results = []
self.inconsistencies = []
self.lock = threading.RLock()
self._init_connections()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('DataConsistencyChecker')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _init_connections(self):
"""初始化数据库连接"""
for shard_config in self.config['shards']:
shard_id = shard_config['shard_id']
try:
conn = pymysql.connect(
host=shard_config['host'],
port=shard_config['port'],
user=shard_config['user'],
password=shard_config['password'],
database=shard_config['database'],
charset='utf8mb4'
)
self.shard_connections[shard_id] = conn
self.logger.info(f"连接分片 {shard_id} 成功")
except Exception as e:
self.logger.error(f"连接分片 {shard_id} 失败: {e}")
def check_table_consistency(self, table_name: str, primary_key: str, check_types: List[str] = None) -> List[ConsistencyCheckResult]:
"""检查表的数据一致性"""
if check_types is None:
check_types = ['count', 'checksum', 'sample_data']
results = []
for check_type in check_types:
if check_type == 'count':
results.extend(self._check_row_count(table_name))
elif check_type == 'checksum':
results.extend(self._check_data_checksum(table_name, primary_key))
elif check_type == 'sample_data':
results.extend(self._check_sample_data(table_name, primary_key))
elif check_type == 'foreign_key':
results.extend(self._check_foreign_key_consistency(table_name))
with self.lock:
self.check_results.extend(results)
return results
def _check_row_count(self, table_name: str) -> List[ConsistencyCheckResult]:
"""检查行数一致性"""
results = []
shard_counts = {}
# 获取每个分片的行数
with ThreadPoolExecutor(max_workers=len(self.shard_connections)) as executor:
futures = []
for shard_id, conn in self.shard_connections.items():
future = executor.submit(self._get_table_count, shard_id, conn, table_name)
futures.append((future, shard_id))
for future, shard_id in futures:
try:
count = future.result()
shard_counts[shard_id] = count
except Exception as e:
self.logger.error(f"获取分片 {shard_id} 行数失败: {e}")
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='count',
status='error',
details={},
timestamp=datetime.now(),
error_message=str(e)
))
# 分析行数分布
if shard_counts:
total_count = sum(shard_counts.values())
avg_count = total_count / len(shard_counts)
for shard_id, count in shard_counts.items():
# 检查是否存在明显的数据倾斜
deviation = abs(count - avg_count) / avg_count if avg_count > 0 else 0
status = 'consistent'
if deviation > 0.5: # 偏差超过50%
status = 'inconsistent'
# 记录数据不一致
inconsistency = DataInconsistency(
table_name=table_name,
primary_key='N/A',
shard_ids=[shard_id],
inconsistency_type='data_skew',
details={
'shard_count': count,
'average_count': avg_count,
'deviation': deviation
},
severity='medium' if deviation > 0.8 else 'low'
)
with self.lock:
self.inconsistencies.append(inconsistency)
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='count',
status=status,
details={
'count': count,
'total_count': total_count,
'average_count': avg_count,
'deviation': deviation
},
timestamp=datetime.now()
))
return results
def _get_table_count(self, shard_id: str, conn, table_name: str) -> int:
"""获取表的行数"""
with conn.cursor() as cursor:
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
return cursor.fetchone()[0]
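    # 补充说明(非原文实现):对千万行以上的大表,COUNT(*) 会做全表扫描,代价较高。
    # 若只是为了观察分片间的数据倾斜趋势,可以用 information_schema 中的估算行数代替:
    #   cursor.execute(
    #       "SELECT TABLE_ROWS FROM information_schema.tables "
    #       "WHERE table_schema = DATABASE() AND table_name = %s", (table_name,))
    # 注意 TABLE_ROWS 对 InnoDB 只是估算值,不能用于精确的一致性校验。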
def _check_data_checksum(self, table_name: str, primary_key: str) -> List[ConsistencyCheckResult]:
"""检查数据校验和"""
results = []
shard_checksums = {}
# 获取每个分片的数据校验和
with ThreadPoolExecutor(max_workers=len(self.shard_connections)) as executor:
futures = []
for shard_id, conn in self.shard_connections.items():
future = executor.submit(self._calculate_checksum, shard_id, conn, table_name, primary_key)
futures.append((future, shard_id))
for future, shard_id in futures:
try:
checksum_data = future.result()
shard_checksums[shard_id] = checksum_data
except Exception as e:
self.logger.error(f"计算分片 {shard_id} 校验和失败: {e}")
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='checksum',
status='error',
details={},
timestamp=datetime.now(),
error_message=str(e)
))
# 比较校验和
for shard_id, checksum_data in shard_checksums.items():
# 检查是否有重复的主键在不同分片中
overlapping_keys = set()
for other_shard_id, other_checksum_data in shard_checksums.items():
if shard_id != other_shard_id:
overlap = set(checksum_data['primary_keys']) & set(other_checksum_data['primary_keys'])
overlapping_keys.update(overlap)
status = 'consistent' if not overlapping_keys else 'inconsistent'
if overlapping_keys:
# 记录重复主键不一致
inconsistency = DataInconsistency(
table_name=table_name,
primary_key=list(overlapping_keys),
shard_ids=[shard_id],
inconsistency_type='duplicate',
details={
'overlapping_keys': list(overlapping_keys),
'count': len(overlapping_keys)
},
severity='high',
auto_fixable=True
)
with self.lock:
self.inconsistencies.append(inconsistency)
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='checksum',
status=status,
details={
'total_checksum': checksum_data['total_checksum'],
'row_count': len(checksum_data['primary_keys']),
'overlapping_keys': list(overlapping_keys)
},
timestamp=datetime.now()
))
return results
def _calculate_checksum(self, shard_id: str, conn, table_name: str, primary_key: str) -> Dict:
"""计算数据校验和"""
with conn.cursor() as cursor:
            # CONCAT_WS 不支持 * 通配符,先查询列名再拼接逐行校验表达式
            cursor.execute(
                "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
                "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s "
                "ORDER BY ORDINAL_POSITION", (table_name,))
            column_expr = ', '.join(row[0] for row in cursor.fetchall())
            # 获取所有行的校验和
            cursor.execute(f"""
                SELECT {primary_key}, MD5(CONCAT_WS('|', {column_expr}))
                FROM {table_name}
                ORDER BY {primary_key}
            """)
rows = cursor.fetchall()
primary_keys = [row[0] for row in rows]
row_checksums = [row[1] for row in rows]
# 计算总校验和
total_checksum = hashlib.md5(''.join(row_checksums).encode()).hexdigest()
return {
'primary_keys': primary_keys,
'row_checksums': row_checksums,
'total_checksum': total_checksum
}
def _check_sample_data(self, table_name: str, primary_key: str, sample_size: int = 100) -> List[ConsistencyCheckResult]:
"""检查样本数据一致性"""
results = []
# 获取样本主键
sample_keys = self._get_sample_keys(table_name, primary_key, sample_size)
if not sample_keys:
return results
# 检查每个样本在各分片中的数据
inconsistent_records = []
for key in sample_keys:
shard_data = {}
# 获取该主键在各分片中的数据
with ThreadPoolExecutor(max_workers=len(self.shard_connections)) as executor:
futures = []
for shard_id, conn in self.shard_connections.items():
future = executor.submit(self._get_record_by_key, conn, table_name, primary_key, key)
futures.append((future, shard_id))
for future, shard_id in futures:
try:
record = future.result()
if record:
shard_data[shard_id] = record
except Exception as e:
self.logger.error(f"获取记录失败: {e}")
# 分析数据一致性
if len(shard_data) > 1:
# 同一主键在多个分片中存在
inconsistent_records.append({
'primary_key': key,
'type': 'duplicate',
'shards': list(shard_data.keys()),
'data': shard_data
})
elif len(shard_data) == 0:
# 主键不存在于任何分片中(可能是删除后的残留)
inconsistent_records.append({
'primary_key': key,
'type': 'missing',
'shards': [],
'data': {}
})
# 生成检查结果
for shard_id in self.shard_connections.keys():
status = 'consistent' if not inconsistent_records else 'inconsistent'
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='sample_data',
status=status,
details={
'sample_size': len(sample_keys),
'inconsistent_count': len(inconsistent_records),
'inconsistent_records': inconsistent_records[:10] # 只显示前10个
},
timestamp=datetime.now()
))
# 记录不一致数据
for record in inconsistent_records:
inconsistency = DataInconsistency(
table_name=table_name,
primary_key=record['primary_key'],
shard_ids=record['shards'],
inconsistency_type=record['type'],
details=record,
severity='medium' if record['type'] == 'duplicate' else 'low',
auto_fixable=True
)
with self.lock:
self.inconsistencies.append(inconsistency)
return results
def _get_sample_keys(self, table_name: str, primary_key: str, sample_size: int) -> List[Any]:
"""获取样本主键"""
all_keys = set()
# 从所有分片收集主键
for shard_id, conn in self.shard_connections.items():
try:
with conn.cursor() as cursor:
cursor.execute(f"SELECT {primary_key} FROM {table_name} ORDER BY RAND() LIMIT {sample_size}")
keys = [row[0] for row in cursor.fetchall()]
all_keys.update(keys)
except Exception as e:
self.logger.error(f"获取分片 {shard_id} 样本主键失败: {e}")
return list(all_keys)[:sample_size]
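    # 说明:ORDER BY RAND() 需要对全表做随机排序,大表上开销很高;
    # 若主键为自增整数,可改为在 [MIN(id), MAX(id)] 区间内随机生成若干取值再精确查询,代价更低。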
def _get_record_by_key(self, conn, table_name: str, primary_key: str, key_value: Any) -> Optional[Dict]:
"""根据主键获取记录"""
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
cursor.execute(f"SELECT * FROM {table_name} WHERE {primary_key} = %s", (key_value,))
return cursor.fetchone()
def _check_foreign_key_consistency(self, table_name: str) -> List[ConsistencyCheckResult]:
"""检查外键一致性"""
results = []
# 获取表的外键信息
foreign_keys = self._get_foreign_keys(table_name)
for fk in foreign_keys:
fk_results = self._check_single_foreign_key(table_name, fk)
results.extend(fk_results)
return results
def _get_foreign_keys(self, table_name: str) -> List[Dict]:
"""获取表的外键信息"""
foreign_keys = []
# 从第一个分片获取外键信息(假设所有分片结构相同)
first_shard = next(iter(self.shard_connections.values()))
with first_shard.cursor() as cursor:
            # 限定当前库,并使用参数绑定避免 SQL 拼接
            cursor.execute("""
                SELECT
                    COLUMN_NAME,
                    REFERENCED_TABLE_NAME,
                    REFERENCED_COLUMN_NAME
                FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                WHERE TABLE_SCHEMA = DATABASE()
                  AND TABLE_NAME = %s
                  AND REFERENCED_TABLE_NAME IS NOT NULL
            """, (table_name,))
for row in cursor.fetchall():
foreign_keys.append({
'column': row[0],
'referenced_table': row[1],
'referenced_column': row[2]
})
return foreign_keys
def _check_single_foreign_key(self, table_name: str, fk: Dict) -> List[ConsistencyCheckResult]:
"""检查单个外键的一致性"""
results = []
# 检查每个分片的外键约束
for shard_id, conn in self.shard_connections.items():
try:
orphaned_records = self._find_orphaned_records(conn, table_name, fk)
status = 'consistent' if not orphaned_records else 'inconsistent'
if orphaned_records:
# 记录外键不一致
inconsistency = DataInconsistency(
table_name=table_name,
primary_key=orphaned_records,
shard_ids=[shard_id],
inconsistency_type='foreign_key_violation',
details={
'foreign_key': fk,
'orphaned_count': len(orphaned_records)
},
severity='high',
auto_fixable=False
)
with self.lock:
self.inconsistencies.append(inconsistency)
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='foreign_key',
status=status,
details={
'foreign_key': fk,
'orphaned_count': len(orphaned_records),
'orphaned_records': orphaned_records[:10] # 只显示前10个
},
timestamp=datetime.now()
))
except Exception as e:
self.logger.error(f"检查分片 {shard_id} 外键失败: {e}")
results.append(ConsistencyCheckResult(
table_name=table_name,
shard_id=shard_id,
check_type='foreign_key',
status='error',
details={'foreign_key': fk},
timestamp=datetime.now(),
error_message=str(e)
))
return results
def _find_orphaned_records(self, conn, table_name: str, fk: Dict) -> List[Any]:
"""查找孤立记录"""
with conn.cursor() as cursor:
cursor.execute(f"""
SELECT t.{fk['column']}
FROM {table_name} t
LEFT JOIN {fk['referenced_table']} r
ON t.{fk['column']} = r.{fk['referenced_column']}
WHERE r.{fk['referenced_column']} IS NULL
AND t.{fk['column']} IS NOT NULL
""")
return [row[0] for row in cursor.fetchall()]
def fix_inconsistencies(self, auto_fix: bool = False) -> Dict[str, int]:
"""修复数据不一致"""
fix_results = {
'attempted': 0,
'successful': 0,
'failed': 0,
'skipped': 0
}
with self.lock:
inconsistencies_to_fix = self.inconsistencies.copy()
for inconsistency in inconsistencies_to_fix:
fix_results['attempted'] += 1
            # 仅当显式允许自动修复且该项标记为可自动修复时才尝试,其余跳过等待人工处理
            if not auto_fix or not inconsistency.auto_fixable:
                fix_results['skipped'] += 1
                continue
try:
success = self._fix_single_inconsistency(inconsistency)
if success:
fix_results['successful'] += 1
# 从不一致列表中移除已修复的项
with self.lock:
if inconsistency in self.inconsistencies:
self.inconsistencies.remove(inconsistency)
else:
fix_results['failed'] += 1
except Exception as e:
self.logger.error(f"修复不一致失败: {e}")
fix_results['failed'] += 1
return fix_results
def _fix_single_inconsistency(self, inconsistency: DataInconsistency) -> bool:
"""修复单个数据不一致"""
if inconsistency.inconsistency_type == 'duplicate':
return self._fix_duplicate_records(inconsistency)
elif inconsistency.inconsistency_type == 'missing':
return self._fix_missing_records(inconsistency)
else:
self.logger.warning(f"不支持的不一致类型: {inconsistency.inconsistency_type}")
return False
def _fix_duplicate_records(self, inconsistency: DataInconsistency) -> bool:
"""修复重复记录"""
# 简化实现:删除除第一个分片外的重复记录
if len(inconsistency.shard_ids) <= 1:
return True
primary_shard = inconsistency.shard_ids[0]
duplicate_shards = inconsistency.shard_ids[1:]
for shard_id in duplicate_shards:
try:
conn = self.shard_connections[shard_id]
with conn.cursor() as cursor:
# 删除重复记录
if isinstance(inconsistency.primary_key, list):
for pk in inconsistency.primary_key:
cursor.execute(f"DELETE FROM {inconsistency.table_name} WHERE id = %s", (pk,))
else:
cursor.execute(f"DELETE FROM {inconsistency.table_name} WHERE id = %s", (inconsistency.primary_key,))
conn.commit()
self.logger.info(f"删除分片 {shard_id} 中的重复记录")
except Exception as e:
self.logger.error(f"删除重复记录失败: {e}")
return False
return True
def _fix_missing_records(self, inconsistency: DataInconsistency) -> bool:
"""修复缺失记录"""
# 这里需要根据具体业务逻辑实现
self.logger.info(f"缺失记录修复需要手动处理: {inconsistency.primary_key}")
return False
def generate_report(self) -> Dict:
"""生成一致性检查报告"""
with self.lock:
total_checks = len(self.check_results)
consistent_checks = len([r for r in self.check_results if r.status == 'consistent'])
inconsistent_checks = len([r for r in self.check_results if r.status == 'inconsistent'])
error_checks = len([r for r in self.check_results if r.status == 'error'])
total_inconsistencies = len(self.inconsistencies)
# 按严重程度分组
severity_counts = {}
for inconsistency in self.inconsistencies:
severity = inconsistency.severity
severity_counts[severity] = severity_counts.get(severity, 0) + 1
# 按类型分组
type_counts = {}
for inconsistency in self.inconsistencies:
inc_type = inconsistency.inconsistency_type
type_counts[inc_type] = type_counts.get(inc_type, 0) + 1
return {
'timestamp': datetime.now().isoformat(),
'summary': {
'total_checks': total_checks,
'consistent_checks': consistent_checks,
'inconsistent_checks': inconsistent_checks,
'error_checks': error_checks,
'consistency_rate': consistent_checks / total_checks if total_checks > 0 else 0
},
'inconsistencies': {
'total': total_inconsistencies,
'by_severity': severity_counts,
'by_type': type_counts
},
'recent_checks': [
{
'table_name': r.table_name,
'shard_id': r.shard_id,
'check_type': r.check_type,
'status': r.status,
'timestamp': r.timestamp.isoformat()
} for r in self.check_results[-20:] # 最近20次检查
],
'critical_inconsistencies': [
{
'table_name': i.table_name,
'primary_key': i.primary_key,
'type': i.inconsistency_type,
'severity': i.severity,
'auto_fixable': i.auto_fixable
} for i in self.inconsistencies if i.severity == 'critical'
]
}
def close(self):
"""关闭所有连接"""
for conn in self.shard_connections.values():
conn.close()
self.logger.info("数据一致性检查器已关闭")
# 使用示例
if __name__ == '__main__':
# 配置
config = {
'shards': [
{
'shard_id': 'shard_0',
'host': 'localhost',
'port': 3306,
'database': 'test_shard_0',
'user': 'root',
'password': 'password'
},
{
'shard_id': 'shard_1',
'host': 'localhost',
'port': 3307,
'database': 'test_shard_1',
'user': 'root',
'password': 'password'
}
]
}
# 创建一致性检查器
checker = DataConsistencyChecker(config)
try:
# 检查用户表的一致性
results = checker.check_table_consistency('users', 'id', ['count', 'checksum', 'sample_data'])
print(f"检查完成,共 {len(results)} 个结果")
# 修复不一致(仅自动修复)
fix_results = checker.fix_inconsistencies(auto_fix=True)
print(f"修复结果: {fix_results}")
# 生成报告
report = checker.generate_report()
print(f"一致性报告: {json.dumps(report, indent=2, ensure_ascii=False)}")
finally:
checker.close()
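    # 补充用法示意(仅注释):外键一致性检查同样通过 check_table_consistency 触发,
    # 例如检查 order_items 表(假设该表定义了外键)的孤立记录:
    #   fk_results = checker.check_table_consistency('order_items', 'item_id', ['foreign_key'])
    #   for r in fk_results:
    #       print(r.shard_id, r.status, r.details.get('orphaned_count'))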
## 14.6 数据迁移和扩容
### 14.6.1 在线数据迁移
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表在线数据迁移工具
支持双写、数据同步和一致性验证
"""
import json
import time
import threading
from typing import Dict, List, Any, Optional, Callable, Tuple
from enum import Enum
from dataclasses import dataclass, asdict
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from datetime import datetime, timedelta
import hashlib
import queue
class MigrationState(Enum):
INIT = "init"
PREPARING = "preparing"
SYNCING = "syncing"
DOUBLE_WRITING = "double_writing"
VERIFYING = "verifying"
SWITCHING = "switching"
COMPLETED = "completed"
FAILED = "failed"
ROLLBACK = "rollback"
class MigrationMode(Enum):
FULL_MIGRATION = "full_migration" # 全量迁移
INCREMENTAL = "incremental" # 增量迁移
DOUBLE_WRITE = "double_write" # 双写模式
SHADOW_TABLE = "shadow_table" # 影子表模式
@dataclass
class MigrationTask:
task_id: str
source_config: Dict
target_config: Dict
table_name: str
migration_mode: MigrationMode
state: MigrationState
progress: float = 0.0
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
error_message: Optional[str] = None
batch_size: int = 1000
parallel_workers: int = 4
@dataclass
class MigrationProgress:
total_rows: int
migrated_rows: int
failed_rows: int
current_batch: int
total_batches: int
speed_rows_per_second: float
estimated_remaining_time: float
class OnlineDataMigrator:
"""在线数据迁移器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
self.migration_tasks: Dict[str, MigrationTask] = {}
self.migration_progress: Dict[str, MigrationProgress] = {}
self.double_write_enabled = False
self.verification_results = {}
self.lock = threading.RLock()
# 数据同步队列
self.sync_queue = queue.Queue(maxsize=10000)
self.sync_workers = []
# 启动同步工作线程
self._start_sync_workers()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('OnlineDataMigrator')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _start_sync_workers(self):
"""启动同步工作线程"""
worker_count = self.config.get('sync_workers', 2)
for i in range(worker_count):
worker = threading.Thread(
target=self._sync_worker,
name=f"SyncWorker-{i}",
daemon=True
)
worker.start()
self.sync_workers.append(worker)
self.logger.info(f"启动了 {worker_count} 个同步工作线程")
def _sync_worker(self):
"""同步工作线程"""
while True:
try:
# 从队列获取同步任务
sync_task = self.sync_queue.get(timeout=1)
if sync_task is None: # 停止信号
break
# 执行同步
self._execute_sync_task(sync_task)
self.sync_queue.task_done()
except queue.Empty:
continue
except Exception as e:
self.logger.error(f"同步工作线程错误: {e}")
def _execute_sync_task(self, sync_task: Dict):
"""执行同步任务"""
try:
operation = sync_task['operation']
table_name = sync_task['table_name']
data = sync_task['data']
target_config = sync_task['target_config']
# 连接目标数据库
target_conn = pymysql.connect(**target_config)
try:
with target_conn.cursor() as cursor:
if operation == 'INSERT':
self._sync_insert(cursor, table_name, data)
elif operation == 'UPDATE':
self._sync_update(cursor, table_name, data)
elif operation == 'DELETE':
self._sync_delete(cursor, table_name, data)
target_conn.commit()
finally:
target_conn.close()
except Exception as e:
self.logger.error(f"同步任务执行失败: {e}")
def _sync_insert(self, cursor, table_name: str, data: Dict):
"""同步插入操作"""
columns = list(data.keys())
values = list(data.values())
placeholders = ', '.join(['%s'] * len(values))
sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
cursor.execute(sql, values)
def _sync_update(self, cursor, table_name: str, data: Dict):
"""同步更新操作"""
primary_key = data.pop('_primary_key')
primary_value = data.pop('_primary_value')
set_clause = ', '.join([f"{col} = %s" for col in data.keys()])
sql = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key} = %s"
values = list(data.values()) + [primary_value]
cursor.execute(sql, values)
def _sync_delete(self, cursor, table_name: str, data: Dict):
"""同步删除操作"""
primary_key = data['_primary_key']
primary_value = data['_primary_value']
sql = f"DELETE FROM {table_name} WHERE {primary_key} = %s"
cursor.execute(sql, [primary_value])
def create_migration_task(self, source_config: Dict, target_config: Dict,
table_name: str, migration_mode: MigrationMode,
**kwargs) -> str:
"""创建迁移任务"""
task_id = f"migration_{int(time.time())}_{table_name}"
task = MigrationTask(
task_id=task_id,
source_config=source_config,
target_config=target_config,
table_name=table_name,
migration_mode=migration_mode,
state=MigrationState.INIT,
batch_size=kwargs.get('batch_size', 1000),
parallel_workers=kwargs.get('parallel_workers', 4)
)
with self.lock:
self.migration_tasks[task_id] = task
self.logger.info(f"创建迁移任务: {task_id}, 表: {table_name}, 模式: {migration_mode.value}")
return task_id
def start_migration(self, task_id: str) -> bool:
"""开始迁移"""
with self.lock:
if task_id not in self.migration_tasks:
raise ValueError(f"迁移任务不存在: {task_id}")
task = self.migration_tasks[task_id]
if task.state != MigrationState.INIT:
self.logger.warning(f"任务 {task_id} 状态不正确: {task.state}")
return False
task.state = MigrationState.PREPARING
task.start_time = datetime.now()
try:
# 根据迁移模式执行不同的策略
if task.migration_mode == MigrationMode.FULL_MIGRATION:
return self._full_migration(task)
elif task.migration_mode == MigrationMode.INCREMENTAL:
return self._incremental_migration(task)
elif task.migration_mode == MigrationMode.DOUBLE_WRITE:
return self._double_write_migration(task)
elif task.migration_mode == MigrationMode.SHADOW_TABLE:
return self._shadow_table_migration(task)
else:
raise ValueError(f"不支持的迁移模式: {task.migration_mode}")
except Exception as e:
self.logger.error(f"迁移任务 {task_id} 失败: {e}")
task.state = MigrationState.FAILED
task.error_message = str(e)
return False
def _full_migration(self, task: MigrationTask) -> bool:
"""全量迁移"""
self.logger.info(f"开始全量迁移: {task.task_id}")
# 获取源表总行数
total_rows = self._get_table_row_count(task.source_config, task.table_name)
total_batches = (total_rows + task.batch_size - 1) // task.batch_size
# 初始化进度
progress = MigrationProgress(
total_rows=total_rows,
migrated_rows=0,
failed_rows=0,
current_batch=0,
total_batches=total_batches,
speed_rows_per_second=0.0,
estimated_remaining_time=0.0
)
with self.lock:
self.migration_progress[task.task_id] = progress
task.state = MigrationState.SYNCING
# 获取主键列名
primary_key = self._get_primary_key(task.source_config, task.table_name)
# 分批迁移数据
start_time = time.time()
with ThreadPoolExecutor(max_workers=task.parallel_workers) as executor:
futures = []
for batch_num in range(total_batches):
offset = batch_num * task.batch_size
future = executor.submit(
self._migrate_batch,
task,
primary_key,
offset,
task.batch_size,
batch_num
)
futures.append(future)
# 等待所有批次完成
for future in as_completed(futures):
try:
batch_result = future.result()
with self.lock:
progress.migrated_rows += batch_result['migrated_rows']
progress.failed_rows += batch_result['failed_rows']
progress.current_batch += 1
# 更新速度和预估时间
elapsed_time = time.time() - start_time
if elapsed_time > 0:
progress.speed_rows_per_second = progress.migrated_rows / elapsed_time
remaining_rows = total_rows - progress.migrated_rows
if progress.speed_rows_per_second > 0:
progress.estimated_remaining_time = remaining_rows / progress.speed_rows_per_second
# 更新任务进度
task.progress = progress.migrated_rows / total_rows * 100
except Exception as e:
self.logger.error(f"批次迁移失败: {e}")
with self.lock:
progress.failed_rows += task.batch_size
# 检查迁移结果
if progress.failed_rows == 0:
task.state = MigrationState.COMPLETED
task.end_time = datetime.now()
self.logger.info(f"全量迁移完成: {task.task_id}")
return True
else:
task.state = MigrationState.FAILED
task.error_message = f"迁移失败,失败行数: {progress.failed_rows}"
self.logger.error(f"全量迁移失败: {task.task_id}, 失败行数: {progress.failed_rows}")
return False
def _migrate_batch(self, task: MigrationTask, primary_key: str,
offset: int, limit: int, batch_num: int) -> Dict:
"""迁移单个批次"""
migrated_rows = 0
failed_rows = 0
source_conn = None
target_conn = None
try:
# 连接源和目标数据库
source_conn = pymysql.connect(**task.source_config)
target_conn = pymysql.connect(**task.target_config)
# 获取批次数据
with source_conn.cursor(pymysql.cursors.DictCursor) as source_cursor:
sql = f"SELECT * FROM {task.table_name} ORDER BY {primary_key} LIMIT {limit} OFFSET {offset}"
source_cursor.execute(sql)
rows = source_cursor.fetchall()
# 插入目标数据库
if rows:
with target_conn.cursor() as target_cursor:
for row in rows:
try:
columns = list(row.keys())
values = list(row.values())
placeholders = ', '.join(['%s'] * len(values))
insert_sql = f"INSERT INTO {task.table_name} ({', '.join(columns)}) VALUES ({placeholders})"
target_cursor.execute(insert_sql, values)
migrated_rows += 1
except Exception as e:
self.logger.error(f"插入行失败: {e}")
failed_rows += 1
target_conn.commit()
self.logger.debug(f"批次 {batch_num} 完成,迁移 {migrated_rows} 行,失败 {failed_rows} 行")
except Exception as e:
self.logger.error(f"批次 {batch_num} 迁移失败: {e}")
failed_rows = limit
finally:
if source_conn:
source_conn.close()
if target_conn:
target_conn.close()
return {
'migrated_rows': migrated_rows,
'failed_rows': failed_rows,
'batch_num': batch_num
}
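    # 说明:上面的批次读取使用 LIMIT/OFFSET 分页,OFFSET 越大扫描代价越高;
    # 若主键是自增整数,可改为基于上一批最大主键值的游标式分页(WHERE {pk} > last_max)。
    # 目标端使用普通 INSERT,任务重跑会因主键冲突报错;如需幂等,
    # 可考虑 INSERT IGNORE 或 INSERT ... ON DUPLICATE KEY UPDATE(此处为建议,非原文实现)。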
def _incremental_migration(self, task: MigrationTask) -> bool:
"""增量迁移"""
self.logger.info(f"开始增量迁移: {task.task_id}")
# 增量迁移通常基于binlog或时间戳
# 这里实现一个基于时间戳的简化版本
task.state = MigrationState.SYNCING
# 获取最后同步时间
last_sync_time = self._get_last_sync_time(task.task_id)
# 查询增量数据
incremental_data = self._get_incremental_data(
task.source_config,
task.table_name,
last_sync_time
)
if not incremental_data:
self.logger.info(f"没有增量数据需要同步: {task.task_id}")
task.state = MigrationState.COMPLETED
return True
# 同步增量数据
success_count = 0
failed_count = 0
target_conn = pymysql.connect(**task.target_config)
try:
with target_conn.cursor() as cursor:
for row in incremental_data:
try:
# 根据操作类型执行相应的SQL
if row['_operation'] == 'INSERT':
self._execute_insert(cursor, task.table_name, row)
elif row['_operation'] == 'UPDATE':
self._execute_update(cursor, task.table_name, row)
elif row['_operation'] == 'DELETE':
self._execute_delete(cursor, task.table_name, row)
success_count += 1
except Exception as e:
self.logger.error(f"同步行失败: {e}")
failed_count += 1
target_conn.commit()
finally:
target_conn.close()
# 更新最后同步时间
self._update_last_sync_time(task.task_id, datetime.now())
if failed_count == 0:
task.state = MigrationState.COMPLETED
self.logger.info(f"增量迁移完成: {task.task_id}, 同步 {success_count} 行")
return True
else:
task.state = MigrationState.FAILED
task.error_message = f"增量迁移失败,失败行数: {failed_count}"
return False
def _double_write_migration(self, task: MigrationTask) -> bool:
"""双写迁移"""
self.logger.info(f"开始双写迁移: {task.task_id}")
# 首先执行全量迁移
if not self._full_migration(task):
return False
# 启用双写模式
task.state = MigrationState.DOUBLE_WRITING
self.double_write_enabled = True
# 在实际应用中,这里需要修改应用程序的数据访问层
# 使其同时写入源和目标数据库
self.logger.info(f"双写模式已启用: {task.task_id}")
# 等待一段时间确保数据同步
time.sleep(self.config.get('double_write_duration', 300)) # 默认5分钟
# 验证数据一致性
if self._verify_data_consistency(task):
task.state = MigrationState.COMPLETED
self.logger.info(f"双写迁移完成: {task.task_id}")
return True
else:
task.state = MigrationState.FAILED
task.error_message = "数据一致性验证失败"
return False
def _shadow_table_migration(self, task: MigrationTask) -> bool:
"""影子表迁移"""
self.logger.info(f"开始影子表迁移: {task.task_id}")
shadow_table_name = f"{task.table_name}_shadow"
try:
# 创建影子表
self._create_shadow_table(task.target_config, task.table_name, shadow_table_name)
# 迁移数据到影子表
shadow_task = MigrationTask(
task_id=f"{task.task_id}_shadow",
source_config=task.source_config,
target_config=task.target_config,
table_name=shadow_table_name,
migration_mode=MigrationMode.FULL_MIGRATION,
state=MigrationState.INIT,
batch_size=task.batch_size,
parallel_workers=task.parallel_workers
)
if not self._full_migration(shadow_task):
return False
# 验证数据一致性
if self._verify_shadow_table_consistency(task, shadow_table_name):
# 原子性切换表名
self._atomic_table_switch(task.target_config, task.table_name, shadow_table_name)
task.state = MigrationState.COMPLETED
self.logger.info(f"影子表迁移完成: {task.task_id}")
return True
else:
task.state = MigrationState.FAILED
task.error_message = "影子表数据一致性验证失败"
return False
except Exception as e:
self.logger.error(f"影子表迁移失败: {e}")
task.state = MigrationState.FAILED
task.error_message = str(e)
return False
def _get_table_row_count(self, db_config: Dict, table_name: str) -> int:
"""获取表的行数"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor() as cursor:
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
return cursor.fetchone()[0]
finally:
conn.close()
def _get_primary_key(self, db_config: Dict, table_name: str) -> str:
"""获取表的主键列名"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor() as cursor:
                # 限定当前库,并使用参数绑定避免 SQL 拼接
                cursor.execute("""
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                      AND TABLE_NAME = %s
                      AND CONSTRAINT_NAME = 'PRIMARY'
                    ORDER BY ORDINAL_POSITION
                    LIMIT 1
                """, (table_name,))
result = cursor.fetchone()
return result[0] if result else 'id'
finally:
conn.close()
def _get_last_sync_time(self, task_id: str) -> datetime:
"""获取最后同步时间"""
# 这里应该从持久化存储中获取
# 简化实现返回1小时前
return datetime.now() - timedelta(hours=1)
def _update_last_sync_time(self, task_id: str, sync_time: datetime):
"""更新最后同步时间"""
# 这里应该持久化存储同步时间
pass
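    # 补充示意(非原文实现):同步位点必须持久化,否则进程重启后会重复或漏同步。
    # 一个最简单的做法是写入本地 JSON 文件,例如:
    #   with open(f"/tmp/{task_id}_sync_time.json", 'w') as f:
    #       json.dump({'last_sync_time': sync_time.isoformat()}, f)
    # 生产环境更常见的做法是写入独立的元数据表或配置中心。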
def _get_incremental_data(self, db_config: Dict, table_name: str,
last_sync_time: datetime) -> List[Dict]:
"""获取增量数据"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
# 假设表有updated_at字段
cursor.execute(f"""
SELECT *, 'UPDATE' as _operation
FROM {table_name}
WHERE updated_at > %s
ORDER BY updated_at
""", (last_sync_time,))
return cursor.fetchall()
finally:
conn.close()
def _execute_insert(self, cursor, table_name: str, row: Dict):
"""执行插入操作"""
row.pop('_operation', None)
columns = list(row.keys())
values = list(row.values())
placeholders = ', '.join(['%s'] * len(values))
sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
cursor.execute(sql, values)
def _execute_update(self, cursor, table_name: str, row: Dict):
"""执行更新操作"""
row.pop('_operation', None)
primary_key = 'id' # 假设主键是id
primary_value = row.pop(primary_key)
set_clause = ', '.join([f"{col} = %s" for col in row.keys()])
sql = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key} = %s"
values = list(row.values()) + [primary_value]
cursor.execute(sql, values)
def _execute_delete(self, cursor, table_name: str, row: Dict):
"""执行删除操作"""
primary_key = 'id' # 假设主键是id
primary_value = row[primary_key]
sql = f"DELETE FROM {table_name} WHERE {primary_key} = %s"
cursor.execute(sql, [primary_value])
def _verify_data_consistency(self, task: MigrationTask) -> bool:
"""验证数据一致性"""
self.logger.info(f"开始验证数据一致性: {task.task_id}")
# 比较源表和目标表的行数
source_count = self._get_table_row_count(task.source_config, task.table_name)
target_count = self._get_table_row_count(task.target_config, task.table_name)
if source_count != target_count:
self.logger.error(f"行数不一致: 源表 {source_count}, 目标表 {target_count}")
return False
# 比较数据校验和
source_checksum = self._calculate_table_checksum(task.source_config, task.table_name)
target_checksum = self._calculate_table_checksum(task.target_config, task.table_name)
if source_checksum != target_checksum:
self.logger.error(f"数据校验和不一致: 源表 {source_checksum}, 目标表 {target_checksum}")
return False
self.logger.info(f"数据一致性验证通过: {task.task_id}")
return True
def _calculate_table_checksum(self, db_config: Dict, table_name: str) -> str:
"""计算表的数据校验和"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor() as cursor:
cursor.execute(f"CHECKSUM TABLE {table_name}")
result = cursor.fetchone()
return str(result[1]) if result else "0"
finally:
conn.close()
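    # 说明:CHECKSUM TABLE 需要全表扫描,且只有在两端表结构(列顺序、类型、行格式)一致时
    # 结果才有可比性;对超大表建议在业务低峰期执行,或改为分段/抽样校验。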
def _create_shadow_table(self, db_config: Dict, original_table: str, shadow_table: str):
"""创建影子表"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor() as cursor:
# 删除已存在的影子表
cursor.execute(f"DROP TABLE IF EXISTS {shadow_table}")
# 创建影子表(复制原表结构)
cursor.execute(f"CREATE TABLE {shadow_table} LIKE {original_table}")
conn.commit()
self.logger.info(f"创建影子表: {shadow_table}")
finally:
conn.close()
def _verify_shadow_table_consistency(self, task: MigrationTask, shadow_table: str) -> bool:
"""验证影子表数据一致性"""
# 比较源表和影子表的数据
source_checksum = self._calculate_table_checksum(task.source_config, task.table_name)
shadow_checksum = self._calculate_table_checksum(task.target_config, shadow_table)
return source_checksum == shadow_checksum
def _atomic_table_switch(self, db_config: Dict, original_table: str, shadow_table: str):
"""原子性切换表名"""
conn = pymysql.connect(**db_config)
try:
with conn.cursor() as cursor:
# 使用RENAME TABLE进行原子性切换
backup_table = f"{original_table}_backup_{int(time.time())}"
cursor.execute(f"""
RENAME TABLE
{original_table} TO {backup_table},
{shadow_table} TO {original_table}
""")
conn.commit()
self.logger.info(f"原子性切换完成: {shadow_table} -> {original_table}")
finally:
conn.close()
def get_migration_status(self, task_id: str) -> Optional[Dict]:
"""获取迁移状态"""
with self.lock:
if task_id not in self.migration_tasks:
return None
task = self.migration_tasks[task_id]
progress = self.migration_progress.get(task_id)
status = {
'task_id': task.task_id,
'table_name': task.table_name,
'migration_mode': task.migration_mode.value,
'state': task.state.value,
'progress': task.progress,
'start_time': task.start_time.isoformat() if task.start_time else None,
'end_time': task.end_time.isoformat() if task.end_time else None,
'error_message': task.error_message
}
if progress:
status['progress_details'] = {
'total_rows': progress.total_rows,
'migrated_rows': progress.migrated_rows,
'failed_rows': progress.failed_rows,
'current_batch': progress.current_batch,
'total_batches': progress.total_batches,
'speed_rows_per_second': progress.speed_rows_per_second,
'estimated_remaining_time': progress.estimated_remaining_time
}
return status
def stop_migration(self, task_id: str) -> bool:
"""停止迁移任务"""
with self.lock:
if task_id not in self.migration_tasks:
return False
task = self.migration_tasks[task_id]
if task.state in [MigrationState.COMPLETED, MigrationState.FAILED]:
return True
task.state = MigrationState.FAILED
task.error_message = "用户手动停止"
task.end_time = datetime.now()
self.logger.info(f"停止迁移任务: {task_id}")
return True
def rollback_migration(self, task_id: str) -> bool:
"""回滚迁移"""
with self.lock:
if task_id not in self.migration_tasks:
return False
task = self.migration_tasks[task_id]
if task.state != MigrationState.COMPLETED:
self.logger.warning(f"任务 {task_id} 未完成,无法回滚")
return False
task.state = MigrationState.ROLLBACK
try:
# 删除目标表数据
conn = pymysql.connect(**task.target_config)
try:
with conn.cursor() as cursor:
cursor.execute(f"TRUNCATE TABLE {task.table_name}")
conn.commit()
self.logger.info(f"回滚完成: {task_id}")
return True
finally:
conn.close()
except Exception as e:
self.logger.error(f"回滚失败: {e}")
task.error_message = f"回滚失败: {e}"
return False
def cleanup(self):
"""清理资源"""
# 停止同步工作线程
for _ in self.sync_workers:
self.sync_queue.put(None)
for worker in self.sync_workers:
worker.join(timeout=5)
self.logger.info("数据迁移器已清理")
# 使用示例
if __name__ == '__main__':
# 配置
config = {
'sync_workers': 2,
'double_write_duration': 300
}
# 创建迁移器
migrator = OnlineDataMigrator(config)
# 源数据库配置
source_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': 'password',
'database': 'source_db',
'charset': 'utf8mb4'
}
# 目标数据库配置
target_config = {
'host': 'localhost',
'port': 3307,
'user': 'root',
'password': 'password',
'database': 'target_db',
'charset': 'utf8mb4'
}
try:
# 创建全量迁移任务
task_id = migrator.create_migration_task(
source_config=source_config,
target_config=target_config,
table_name='users',
migration_mode=MigrationMode.FULL_MIGRATION,
batch_size=1000,
parallel_workers=4
)
print(f"创建迁移任务: {task_id}")
# 开始迁移
success = migrator.start_migration(task_id)
print(f"迁移结果: {'成功' if success else '失败'}")
# 监控迁移进度
while True:
status = migrator.get_migration_status(task_id)
if status:
print(f"迁移状态: {status['state']}, 进度: {status['progress']:.2f}%")
if status['state'] in ['completed', 'failed']:
break
time.sleep(5)
# 获取最终状态
final_status = migrator.get_migration_status(task_id)
print(f"最终状态: {json.dumps(final_status, indent=2, ensure_ascii=False)}")
except Exception as e:
print(f"迁移失败: {e}")
finally:
migrator.cleanup()
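    # 补充用法示意(仅注释):全量迁移完成后,可以再创建一个增量任务
    # 追平迁移期间产生的变更(该实现假设表带有 updated_at 字段,见 _get_incremental_data):
    #   inc_task_id = migrator.create_migration_task(
    #       source_config=source_config,
    #       target_config=target_config,
    #       table_name='users',
    #       migration_mode=MigrationMode.INCREMENTAL
    #   )
    #   migrator.start_migration(inc_task_id)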
### 14.6.2 分片扩容策略
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表扩容管理器
支持水平扩容、垂直扩容和数据重平衡
"""
import json
import time
import threading
from typing import Dict, List, Any, Optional, Tuple
from enum import Enum
from dataclasses import dataclass, asdict
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from datetime import datetime
import hashlib
import math
class ExpansionType(Enum):
HORIZONTAL = "horizontal" # 水平扩容
VERTICAL = "vertical" # 垂直扩容
REBALANCE = "rebalance" # 数据重平衡
class ExpansionState(Enum):
INIT = "init"
PLANNING = "planning"
PREPARING = "preparing"
MIGRATING = "migrating"
VERIFYING = "verifying"
SWITCHING = "switching"
COMPLETED = "completed"
FAILED = "failed"
ROLLBACK = "rollback"
@dataclass
class ShardInfo:
shard_id: str
host: str
port: int
database: str
table_name: str
shard_key_range: Tuple[Any, Any] # (min_value, max_value)
row_count: int = 0
data_size_mb: float = 0.0
load_factor: float = 0.0
@dataclass
class ExpansionPlan:
plan_id: str
expansion_type: ExpansionType
source_shards: List[ShardInfo]
target_shards: List[ShardInfo]
migration_strategy: str
estimated_duration: float
risk_level: str
rollback_plan: Dict
class ShardExpansionManager:
"""分片扩容管理器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
self.expansion_plans: Dict[str, ExpansionPlan] = {}
self.expansion_states: Dict[str, ExpansionState] = {}
self.shard_registry: Dict[str, ShardInfo] = {}
self.lock = threading.RLock()
# 负载监控
self.load_monitor = threading.Thread(
target=self._monitor_shard_loads,
daemon=True
)
self.load_monitor.start()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('ShardExpansionManager')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def register_shard(self, shard_info: ShardInfo):
"""注册分片信息"""
with self.lock:
self.shard_registry[shard_info.shard_id] = shard_info
self.logger.info(f"注册分片: {shard_info.shard_id}")
def _monitor_shard_loads(self):
"""监控分片负载"""
while True:
try:
self._update_shard_metrics()
self._check_expansion_triggers()
time.sleep(self.config.get('monitor_interval', 300)) # 默认5分钟
except Exception as e:
self.logger.error(f"负载监控错误: {e}")
def _update_shard_metrics(self):
"""更新分片指标"""
with self.lock:
for shard_id, shard_info in self.shard_registry.items():
try:
# 连接分片数据库
conn = pymysql.connect(
host=shard_info.host,
port=shard_info.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=shard_info.database
)
with conn.cursor() as cursor:
# 获取行数
cursor.execute(f"SELECT COUNT(*) FROM {shard_info.table_name}")
shard_info.row_count = cursor.fetchone()[0]
# 获取数据大小
cursor.execute(f"""
SELECT
ROUND(((data_length + index_length) / 1024 / 1024), 2) AS size_mb
FROM information_schema.tables
WHERE table_schema = '{shard_info.database}'
AND table_name = '{shard_info.table_name}'
""")
result = cursor.fetchone()
shard_info.data_size_mb = result[0] if result else 0.0
# 计算负载因子(基于行数和数据大小)
max_rows = self.config.get('max_rows_per_shard', 1000000)
max_size_mb = self.config.get('max_size_mb_per_shard', 1024)
row_factor = shard_info.row_count / max_rows
size_factor = shard_info.data_size_mb / max_size_mb
shard_info.load_factor = max(row_factor, size_factor)
conn.close()
except Exception as e:
self.logger.error(f"更新分片 {shard_id} 指标失败: {e}")
def _check_expansion_triggers(self):
"""检查扩容触发条件"""
high_load_threshold = self.config.get('high_load_threshold', 0.8)
for shard_id, shard_info in self.shard_registry.items():
if shard_info.load_factor > high_load_threshold:
self.logger.warning(
f"分片 {shard_id} 负载过高: {shard_info.load_factor:.2f}, "
f"行数: {shard_info.row_count}, 大小: {shard_info.data_size_mb}MB"
)
# 自动创建扩容计划
if self.config.get('auto_expansion', False):
self._create_auto_expansion_plan(shard_info)
def _create_auto_expansion_plan(self, overloaded_shard: ShardInfo):
"""创建自动扩容计划"""
        # 检查该分片是否已出现在某个扩容计划的源分片中,避免重复创建
        if any(overloaded_shard.shard_id in [s.shard_id for s in plan.source_shards]
               for plan in self.expansion_plans.values()):
            return
        # 创建水平扩容计划(create_horizontal_expansion_plan 返回计划ID)
        plan_id = self.create_horizontal_expansion_plan(
            source_shard_ids=[overloaded_shard.shard_id],
            target_shard_count=2
        )
        if plan_id:
            self.logger.info(f"创建自动扩容计划: {plan_id}")
def create_horizontal_expansion_plan(self, source_shard_ids: List[str],
target_shard_count: int) -> Optional[str]:
"""创建水平扩容计划"""
plan_id = f"horizontal_expansion_{int(time.time())}"
try:
# 获取源分片信息
source_shards = []
for shard_id in source_shard_ids:
if shard_id in self.shard_registry:
source_shards.append(self.shard_registry[shard_id])
else:
raise ValueError(f"分片不存在: {shard_id}")
# 计算目标分片配置
target_shards = self._calculate_target_shards(
source_shards, target_shard_count, ExpansionType.HORIZONTAL
)
# 估算迁移时间
total_rows = sum(shard.row_count for shard in source_shards)
migration_speed = self.config.get('migration_speed_rows_per_second', 1000)
estimated_duration = total_rows / migration_speed
# 评估风险等级
risk_level = self._assess_risk_level(source_shards, target_shards)
# 创建回滚计划
rollback_plan = self._create_rollback_plan(source_shards, target_shards)
expansion_plan = ExpansionPlan(
plan_id=plan_id,
expansion_type=ExpansionType.HORIZONTAL,
source_shards=source_shards,
target_shards=target_shards,
migration_strategy="range_split",
estimated_duration=estimated_duration,
risk_level=risk_level,
rollback_plan=rollback_plan
)
with self.lock:
self.expansion_plans[plan_id] = expansion_plan
self.expansion_states[plan_id] = ExpansionState.INIT
self.logger.info(f"创建水平扩容计划: {plan_id}")
return plan_id
except Exception as e:
self.logger.error(f"创建水平扩容计划失败: {e}")
return None
def create_vertical_expansion_plan(self, table_name: str,
split_columns: List[str]) -> Optional[str]:
"""创建垂直扩容计划"""
plan_id = f"vertical_expansion_{int(time.time())}"
try:
# 获取相关分片
source_shards = [shard for shard in self.shard_registry.values()
if shard.table_name == table_name]
if not source_shards:
raise ValueError(f"未找到表 {table_name} 的分片")
# 计算垂直分割后的目标分片
target_shards = self._calculate_vertical_split_shards(
source_shards, split_columns
)
# 估算迁移时间
total_rows = sum(shard.row_count for shard in source_shards)
migration_speed = self.config.get('migration_speed_rows_per_second', 800)
estimated_duration = total_rows / migration_speed * 1.5 # 垂直分割更复杂
# 评估风险等级
risk_level = self._assess_risk_level(source_shards, target_shards)
# 创建回滚计划
rollback_plan = self._create_rollback_plan(source_shards, target_shards)
expansion_plan = ExpansionPlan(
plan_id=plan_id,
expansion_type=ExpansionType.VERTICAL,
source_shards=source_shards,
target_shards=target_shards,
migration_strategy="column_split",
estimated_duration=estimated_duration,
risk_level=risk_level,
rollback_plan=rollback_plan
)
with self.lock:
self.expansion_plans[plan_id] = expansion_plan
self.expansion_states[plan_id] = ExpansionState.INIT
self.logger.info(f"创建垂直扩容计划: {plan_id}")
return plan_id
except Exception as e:
self.logger.error(f"创建垂直扩容计划失败: {e}")
return None
def create_rebalance_plan(self, target_load_factor: float = 0.7) -> Optional[str]:
"""创建数据重平衡计划"""
plan_id = f"rebalance_{int(time.time())}"
try:
# 找出负载不均衡的分片
overloaded_shards = []
underloaded_shards = []
for shard in self.shard_registry.values():
if shard.load_factor > target_load_factor:
overloaded_shards.append(shard)
elif shard.load_factor < target_load_factor * 0.5:
underloaded_shards.append(shard)
if not overloaded_shards:
self.logger.info("无需重平衡,所有分片负载正常")
return None
# 计算重平衡后的目标分片配置
target_shards = self._calculate_rebalance_shards(
overloaded_shards, underloaded_shards, target_load_factor
)
# 估算迁移时间
migration_rows = sum(
max(0, shard.row_count - int(shard.row_count * target_load_factor))
for shard in overloaded_shards
)
migration_speed = self.config.get('migration_speed_rows_per_second', 1200)
estimated_duration = migration_rows / migration_speed
# 评估风险等级
risk_level = "LOW" # 重平衡风险相对较低
# 创建回滚计划
rollback_plan = self._create_rollback_plan(overloaded_shards, target_shards)
expansion_plan = ExpansionPlan(
plan_id=plan_id,
expansion_type=ExpansionType.REBALANCE,
source_shards=overloaded_shards,
target_shards=target_shards,
migration_strategy="load_balance",
estimated_duration=estimated_duration,
risk_level=risk_level,
rollback_plan=rollback_plan
)
with self.lock:
self.expansion_plans[plan_id] = expansion_plan
self.expansion_states[plan_id] = ExpansionState.INIT
self.logger.info(f"创建重平衡计划: {plan_id}")
return plan_id
except Exception as e:
self.logger.error(f"创建重平衡计划失败: {e}")
return None
def _calculate_target_shards(self, source_shards: List[ShardInfo],
target_count: int, expansion_type: ExpansionType) -> List[ShardInfo]:
"""计算目标分片配置"""
target_shards = []
if expansion_type == ExpansionType.HORIZONTAL:
# 水平扩容:按范围分割
for i, source_shard in enumerate(source_shards):
min_val, max_val = source_shard.shard_key_range
range_size = (max_val - min_val) / target_count
for j in range(target_count):
new_min = min_val + j * range_size
new_max = min_val + (j + 1) * range_size if j < target_count - 1 else max_val
target_shard = ShardInfo(
shard_id=f"{source_shard.shard_id}_split_{j}",
host=self._get_next_available_host(),
port=3306,
database=f"{source_shard.database}_split_{j}",
table_name=source_shard.table_name,
shard_key_range=(new_min, new_max),
row_count=source_shard.row_count // target_count,
data_size_mb=source_shard.data_size_mb / target_count
)
target_shards.append(target_shard)
return target_shards
def _calculate_vertical_split_shards(self, source_shards: List[ShardInfo],
split_columns: List[str]) -> List[ShardInfo]:
"""计算垂直分割后的分片配置"""
target_shards = []
for source_shard in source_shards:
# 主表(保留主要列)
main_shard = ShardInfo(
shard_id=f"{source_shard.shard_id}_main",
host=source_shard.host,
port=source_shard.port,
database=source_shard.database,
table_name=f"{source_shard.table_name}_main",
shard_key_range=source_shard.shard_key_range,
row_count=source_shard.row_count,
data_size_mb=source_shard.data_size_mb * 0.6 # 估算
)
target_shards.append(main_shard)
# 扩展表(分离的列)
ext_shard = ShardInfo(
shard_id=f"{source_shard.shard_id}_ext",
host=self._get_next_available_host(),
port=3306,
database=f"{source_shard.database}_ext",
table_name=f"{source_shard.table_name}_ext",
shard_key_range=source_shard.shard_key_range,
row_count=source_shard.row_count,
data_size_mb=source_shard.data_size_mb * 0.4 # 估算
)
target_shards.append(ext_shard)
return target_shards
def _calculate_rebalance_shards(self, overloaded_shards: List[ShardInfo],
underloaded_shards: List[ShardInfo],
target_load_factor: float) -> List[ShardInfo]:
"""计算重平衡后的分片配置"""
target_shards = []
# 计算需要迁移的数据量
total_excess_rows = sum(
max(0, shard.row_count - int(shard.row_count * target_load_factor))
for shard in overloaded_shards
)
# 分配到负载较低的分片
available_capacity = sum(
max(0, int(shard.row_count / target_load_factor) - shard.row_count)
for shard in underloaded_shards
)
if available_capacity < total_excess_rows:
# 需要创建新分片
new_shards_needed = math.ceil(
(total_excess_rows - available_capacity) /
self.config.get('max_rows_per_shard', 1000000)
)
for i in range(new_shards_needed):
new_shard = ShardInfo(
shard_id=f"rebalance_new_{int(time.time())}_{i}",
host=self._get_next_available_host(),
port=3306,
database=f"rebalance_db_{i}",
table_name=overloaded_shards[0].table_name,
shard_key_range=(0, 0), # 动态分配
row_count=0
)
target_shards.append(new_shard)
return target_shards
def _get_next_available_host(self) -> str:
"""获取下一个可用主机"""
available_hosts = self.config.get('available_hosts', ['localhost'])
# 简单的轮询策略
host_usage = {}
for shard in self.shard_registry.values():
host_usage[shard.host] = host_usage.get(shard.host, 0) + 1
# 选择使用最少的主机
return min(available_hosts, key=lambda h: host_usage.get(h, 0))
def _assess_risk_level(self, source_shards: List[ShardInfo],
target_shards: List[ShardInfo]) -> str:
"""评估风险等级"""
total_data_size = sum(shard.data_size_mb for shard in source_shards)
total_rows = sum(shard.row_count for shard in source_shards)
if total_data_size > 10000 or total_rows > 10000000: # 10GB 或 1000万行
return "HIGH"
elif total_data_size > 1000 or total_rows > 1000000: # 1GB 或 100万行
return "MEDIUM"
else:
return "LOW"
def _create_rollback_plan(self, source_shards: List[ShardInfo],
target_shards: List[ShardInfo]) -> Dict:
"""创建回滚计划"""
return {
'backup_locations': [f"/backup/{shard.shard_id}" for shard in source_shards],
'rollback_scripts': [f"rollback_{shard.shard_id}.sql" for shard in source_shards],
'estimated_rollback_time': sum(shard.data_size_mb for shard in source_shards) / 100 # 估算
}
def execute_expansion_plan(self, plan_id: str) -> bool:
"""执行扩容计划"""
if plan_id not in self.expansion_plans:
self.logger.error(f"扩容计划不存在: {plan_id}")
return False
plan = self.expansion_plans[plan_id]
        try:
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.PLANNING
            self.logger.info(f"开始执行扩容计划: {plan_id}")
            # 每个阶段开始前先更新状态,外部查询到的状态才能反映真实进度
            # 1. 准备阶段
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.PREPARING
            if not self._prepare_expansion(plan):
                raise Exception("准备阶段失败")
            # 2. 数据迁移阶段
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.MIGRATING
            if not self._migrate_data(plan):
                raise Exception("数据迁移失败")
            # 3. 验证阶段
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.VERIFYING
            if not self._verify_expansion(plan):
                raise Exception("验证阶段失败")
            # 4. 切换阶段
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.SWITCHING
            if not self._switch_traffic(plan):
                raise Exception("流量切换失败")
            with self.lock:
                self.expansion_states[plan_id] = ExpansionState.COMPLETED
self.logger.info(f"扩容计划执行完成: {plan_id}")
return True
except Exception as e:
self.logger.error(f"扩容计划执行失败: {plan_id}, 错误: {e}")
with self.lock:
self.expansion_states[plan_id] = ExpansionState.FAILED
# 尝试回滚
self._rollback_expansion(plan)
return False
def _prepare_expansion(self, plan: ExpansionPlan) -> bool:
"""准备扩容"""
try:
# 创建目标数据库和表
for target_shard in plan.target_shards:
self._create_target_shard(target_shard, plan.source_shards[0])
# 创建备份
for source_shard in plan.source_shards:
self._create_backup(source_shard)
return True
except Exception as e:
self.logger.error(f"准备扩容失败: {e}")
return False
def _create_target_shard(self, target_shard: ShardInfo, template_shard: ShardInfo):
"""创建目标分片"""
# 连接目标主机
conn = pymysql.connect(
host=target_shard.host,
port=target_shard.port,
user=self.config['db_user'],
password=self.config['db_password']
)
try:
with conn.cursor() as cursor:
# 创建数据库
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {target_shard.database}")
cursor.execute(f"USE {target_shard.database}")
# 获取源表结构
template_conn = pymysql.connect(
host=template_shard.host,
port=template_shard.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=template_shard.database
)
try:
with template_conn.cursor() as template_cursor:
template_cursor.execute(f"SHOW CREATE TABLE {template_shard.table_name}")
create_sql = template_cursor.fetchone()[1]
# 修改表名
create_sql = create_sql.replace(
f"CREATE TABLE `{template_shard.table_name}`",
f"CREATE TABLE `{target_shard.table_name}`"
)
# 执行创建表语句
cursor.execute(create_sql)
finally:
template_conn.close()
conn.commit()
self.logger.info(f"创建目标分片: {target_shard.shard_id}")
finally:
conn.close()
def _create_backup(self, shard: ShardInfo):
"""创建备份"""
backup_file = f"/backup/{shard.shard_id}_{int(time.time())}.sql"
# 这里应该调用mysqldump或其他备份工具
# 简化实现
self.logger.info(f"创建备份: {shard.shard_id} -> {backup_file}")
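    # 补充示意(非原文实现):实际备份可以调用 mysqldump,例如:
    #   import subprocess
    #   with open(backup_file, 'w') as f:
    #       subprocess.run([
    #           'mysqldump', '--single-transaction',
    #           '-h', shard.host, '-P', str(shard.port),
    #           '-u', self.config['db_user'], f"--password={self.config['db_password']}",
    #           shard.database, shard.table_name
    #       ], stdout=f, check=True)
    # 命令行明文传递密码有泄露风险,生产环境建议改用 --defaults-extra-file 提供凭据。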
def _migrate_data(self, plan: ExpansionPlan) -> bool:
"""迁移数据"""
try:
if plan.expansion_type == ExpansionType.HORIZONTAL:
return self._migrate_horizontal_data(plan)
elif plan.expansion_type == ExpansionType.VERTICAL:
return self._migrate_vertical_data(plan)
elif plan.expansion_type == ExpansionType.REBALANCE:
return self._migrate_rebalance_data(plan)
else:
raise ValueError(f"不支持的扩容类型: {plan.expansion_type}")
except Exception as e:
self.logger.error(f"数据迁移失败: {e}")
return False
def _migrate_horizontal_data(self, plan: ExpansionPlan) -> bool:
"""水平数据迁移"""
for source_shard in plan.source_shards:
# 获取相关的目标分片
related_targets = [
target for target in plan.target_shards
if target.shard_id.startswith(source_shard.shard_id)
]
# 按范围迁移数据
for target_shard in related_targets:
min_val, max_val = target_shard.shard_key_range
# 连接源和目标数据库
source_conn = pymysql.connect(
host=source_shard.host,
port=source_shard.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=source_shard.database
)
target_conn = pymysql.connect(
host=target_shard.host,
port=target_shard.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=target_shard.database
)
try:
# 分批迁移数据
batch_size = self.config.get('migration_batch_size', 1000)
offset = 0
while True:
with source_conn.cursor(pymysql.cursors.DictCursor) as source_cursor:
# 假设分片键是id
source_cursor.execute(f"""
SELECT * FROM {source_shard.table_name}
WHERE id >= {min_val} AND id < {max_val}
LIMIT {batch_size} OFFSET {offset}
""")
rows = source_cursor.fetchall()
if not rows:
break
# 插入目标分片
with target_conn.cursor() as target_cursor:
for row in rows:
columns = list(row.keys())
values = list(row.values())
placeholders = ', '.join(['%s'] * len(values))
insert_sql = f"""
INSERT INTO {target_shard.table_name}
({', '.join(columns)}) VALUES ({placeholders})
"""
target_cursor.execute(insert_sql, values)
target_conn.commit()
offset += batch_size
self.logger.debug(f"迁移批次完成: {len(rows)} 行")
finally:
source_conn.close()
target_conn.close()
return True
def _migrate_vertical_data(self, plan: ExpansionPlan) -> bool:
"""垂直数据迁移"""
# 垂直分割的具体实现
# 这里简化处理
self.logger.info("执行垂直数据迁移")
return True
def _migrate_rebalance_data(self, plan: ExpansionPlan) -> bool:
"""重平衡数据迁移"""
# 重平衡的具体实现
# 这里简化处理
self.logger.info("执行重平衡数据迁移")
return True
def _verify_expansion(self, plan: ExpansionPlan) -> bool:
"""验证扩容结果"""
try:
# 验证数据完整性
for target_shard in plan.target_shards:
if not self._verify_shard_data(target_shard):
return False
# 验证数据一致性
if not self._verify_data_consistency(plan):
return False
return True
except Exception as e:
self.logger.error(f"验证扩容失败: {e}")
return False
def _verify_shard_data(self, shard: ShardInfo) -> bool:
"""验证分片数据"""
try:
conn = pymysql.connect(
host=shard.host,
port=shard.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=shard.database
)
try:
with conn.cursor() as cursor:
# 检查表是否存在
cursor.execute(f"SHOW TABLES LIKE '{shard.table_name}'")
if not cursor.fetchone():
return False
# 检查数据行数
cursor.execute(f"SELECT COUNT(*) FROM {shard.table_name}")
row_count = cursor.fetchone()[0]
if row_count == 0 and shard.row_count > 0:
self.logger.warning(f"分片 {shard.shard_id} 数据为空")
return False
return True
finally:
# 无论校验是否通过都关闭连接,避免提前 return 造成连接泄漏
conn.close()
except Exception as e:
self.logger.error(f"验证分片数据失败: {e}")
return False
def _verify_data_consistency(self, plan: ExpansionPlan) -> bool:
"""验证数据一致性"""
# 比较源分片和目标分片的数据总量
source_total_rows = sum(shard.row_count for shard in plan.source_shards)
target_total_rows = 0
for target_shard in plan.target_shards:
conn = pymysql.connect(
host=target_shard.host,
port=target_shard.port,
user=self.config['db_user'],
password=self.config['db_password'],
database=target_shard.database
)
try:
with conn.cursor() as cursor:
cursor.execute(f"SELECT COUNT(*) FROM {target_shard.table_name}")
target_total_rows += cursor.fetchone()[0]
finally:
conn.close()
if source_total_rows != target_total_rows:
self.logger.error(
f"数据行数不一致: 源 {source_total_rows}, 目标 {target_total_rows}"
)
return False
return True
def _switch_traffic(self, plan: ExpansionPlan) -> bool:
"""切换流量"""
try:
# 更新分片注册表
with self.lock:
# 移除源分片
for source_shard in plan.source_shards:
if source_shard.shard_id in self.shard_registry:
del self.shard_registry[source_shard.shard_id]
# 添加目标分片
for target_shard in plan.target_shards:
self.shard_registry[target_shard.shard_id] = target_shard
# 更新路由配置(这里应该通知应用程序更新路由)
self._update_routing_config(plan)
self.logger.info(f"流量切换完成: {plan.plan_id}")
return True
except Exception as e:
self.logger.error(f"流量切换失败: {e}")
return False
def _update_routing_config(self, plan: ExpansionPlan):
"""更新路由配置"""
# 生成新的路由配置
routing_config = {
'shards': [
{
'shard_id': shard.shard_id,
'host': shard.host,
'port': shard.port,
'database': shard.database,
'table_name': shard.table_name,
'key_range': shard.shard_key_range
}
for shard in plan.target_shards
]
}
# 保存配置文件
config_file = self.config.get('routing_config_file', '/etc/sharding/routing.json')
with open(config_file, 'w') as f:
json.dump(routing_config, f, indent=2)
self.logger.info(f"更新路由配置: {config_file}")
def _rollback_expansion(self, plan: ExpansionPlan):
"""回滚扩容"""
try:
self.logger.info(f"开始回滚扩容: {plan.plan_id}")
# 恢复备份数据
for source_shard in plan.source_shards:
self._restore_backup(source_shard)
# 删除目标分片
for target_shard in plan.target_shards:
self._cleanup_target_shard(target_shard)
# 恢复路由配置
self._restore_routing_config(plan)
with self.lock:
self.expansion_states[plan.plan_id] = ExpansionState.ROLLBACK
self.logger.info(f"回滚完成: {plan.plan_id}")
except Exception as e:
self.logger.error(f"回滚失败: {e}")
def _restore_backup(self, shard: ShardInfo):
"""恢复备份"""
# 这里应该调用mysql恢复命令
self.logger.info(f"恢复备份: {shard.shard_id}")
def _cleanup_target_shard(self, shard: ShardInfo):
"""清理目标分片"""
try:
conn = pymysql.connect(
host=shard.host,
port=shard.port,
user=self.config['db_user'],
password=self.config['db_password']
)
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {shard.database}")
conn.commit()
conn.close()
self.logger.info(f"清理目标分片: {shard.shard_id}")
except Exception as e:
self.logger.error(f"清理目标分片失败: {e}")
def _restore_routing_config(self, plan: ExpansionPlan):
"""恢复路由配置"""
# 恢复原始路由配置
self.logger.info(f"恢复路由配置: {plan.plan_id}")
def get_expansion_status(self, plan_id: str) -> Optional[Dict]:
"""获取扩容状态"""
if plan_id not in self.expansion_plans:
return None
plan = self.expansion_plans[plan_id]
state = self.expansion_states.get(plan_id, ExpansionState.INIT)
return {
'plan_id': plan.plan_id,
'expansion_type': plan.expansion_type.value,
'state': state.value,
'source_shards': [shard.shard_id for shard in plan.source_shards],
'target_shards': [shard.shard_id for shard in plan.target_shards],
'migration_strategy': plan.migration_strategy,
'estimated_duration': plan.estimated_duration,
'risk_level': plan.risk_level
}
def get_shard_metrics(self) -> Dict:
"""获取分片指标"""
with self.lock:
metrics = {
'total_shards': len(self.shard_registry),
'total_rows': sum(shard.row_count for shard in self.shard_registry.values()),
'total_size_mb': sum(shard.data_size_mb for shard in self.shard_registry.values()),
'average_load_factor': sum(shard.load_factor for shard in self.shard_registry.values()) / len(self.shard_registry) if self.shard_registry else 0,
'high_load_shards': [
shard.shard_id for shard in self.shard_registry.values()
if shard.load_factor > 0.8
],
'low_load_shards': [
shard.shard_id for shard in self.shard_registry.values()
if shard.load_factor < 0.3
]
}
return metrics
def cleanup(self):
"""清理资源"""
self.logger.info("分片扩容管理器已清理")
# 使用示例
if __name__ == '__main__':
# 配置
config = {
'db_user': 'root',
'db_password': 'password',
'monitor_interval': 300,
'high_load_threshold': 0.8,
'auto_expansion': True,
'max_rows_per_shard': 1000000,
'max_size_mb_per_shard': 1024,
'migration_speed_rows_per_second': 1000,
'migration_batch_size': 1000,
'available_hosts': ['host1', 'host2', 'host3'],
'routing_config_file': '/etc/sharding/routing.json'
}
# 创建扩容管理器
expansion_manager = ShardExpansionManager(config)
try:
# 注册分片
shard1 = ShardInfo(
shard_id='shard_001',
host='localhost',
port=3306,
database='db_shard_001',
table_name='users',
shard_key_range=(0, 1000000),
row_count=800000,
data_size_mb=512,
load_factor=0.85
)
expansion_manager.register_shard(shard1)
# 创建水平扩容计划
plan_id = expansion_manager.create_horizontal_expansion_plan(
source_shard_ids=['shard_001'],
target_shard_count=2
)
if plan_id:
print(f"创建扩容计划: {plan_id}")
# 获取计划状态
status = expansion_manager.get_expansion_status(plan_id)
print(f"计划状态: {json.dumps(status, indent=2, ensure_ascii=False)}")
# 执行扩容计划
success = expansion_manager.execute_expansion_plan(plan_id)
print(f"扩容结果: {'成功' if success else '失败'}")
# 获取分片指标
metrics = expansion_manager.get_shard_metrics()
print(f"分片指标: {json.dumps(metrics, indent=2, ensure_ascii=False)}")
except Exception as e:
print(f"扩容失败: {e}")
finally:
expansion_manager.cleanup()
"""
import json
import math
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import pymysql
from datetime import datetime, timedelta
class ShardingType(Enum):
"""分片类型"""
HORIZONTAL = "horizontal" # 水平分片
VERTICAL = "vertical" # 垂直分片
HYBRID = "hybrid" # 混合分片
class ShardingStrategy(Enum):
"""分片策略"""
RANGE = "range" # 范围分片
HASH = "hash" # 哈希分片
DIRECTORY = "directory" # 目录分片
TIME = "time" # 时间分片
@dataclass
class TableAnalysis:
"""表分析结果"""
table_name: str
row_count: int
data_size_mb: float
index_size_mb: float
avg_row_length: float
growth_rate_per_month: float
query_patterns: List[str]
hot_columns: List[str]
recommended_sharding: bool
recommended_strategy: Optional[ShardingStrategy]
recommended_shard_count: int
estimated_performance_gain: float
class MySQLShardingAnalyzer:
"""MySQL分库分表策略分析器"""
def __init__(self, host: str, port: int, user: str, password: str, database: str):
self.host = host
self.port = port
self.user = user
self.password = password
self.database = database
self.connection = None
# 分析阈值
self.thresholds = {
'max_table_size_mb': 10240, # 10GB
'max_row_count': 10000000, # 1000万行
'max_query_time_ms': 1000, # 1秒
'min_performance_gain': 0.3, # 30%性能提升
'optimal_shard_size_mb': 2048, # 2GB每分片
'max_shard_count': 1024 # 最大分片数
}
def connect(self):
"""连接数据库"""
try:
self.connection = pymysql.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
return True
except Exception as e:
print(f"数据库连接失败: {e}")
return False
def disconnect(self):
"""断开数据库连接"""
if self.connection:
self.connection.close()
def get_table_info(self, table_name: str) -> Optional[Dict]:
"""获取表基本信息"""
with self.connection.cursor() as cursor:
# 获取表状态信息
cursor.execute(f"SHOW TABLE STATUS LIKE '{table_name}'")
table_status = cursor.fetchone()
if not table_status:
return None
# 获取表结构信息
cursor.execute(f"DESCRIBE {table_name}")
columns = cursor.fetchall()
# 获取索引信息
cursor.execute(f"SHOW INDEX FROM {table_name}")
indexes = cursor.fetchall()
return {
'status': table_status,
'columns': columns,
'indexes': indexes
}
def analyze_query_patterns(self, table_name: str, days: int = 7) -> List[str]:
"""分析查询模式"""
patterns = []
try:
with self.connection.cursor() as cursor:
# 从 performance_schema 的语句摘要统计中分析(如果可用)
cursor.execute("""
SELECT DIGEST_TEXT AS sql_text, COUNT_STAR AS exec_count,
AVG_TIMER_WAIT/1000000000 AS avg_time_sec
FROM performance_schema.events_statements_summary_by_digest
WHERE DIGEST_TEXT LIKE %s
ORDER BY COUNT_STAR DESC
LIMIT 10
""", (f'%{table_name}%',))
queries = cursor.fetchall()
for query in queries:
if query['avg_time_sec'] > 1.0: # 超过1秒的查询
patterns.append(f"慢查询: {query['sql_text'][:100]}...")
if 'WHERE' in query['sql_text'].upper():
# 提取WHERE条件中的列
patterns.append("频繁WHERE查询")
if 'ORDER BY' in query['sql_text'].upper():
patterns.append("频繁排序查询")
if 'GROUP BY' in query['sql_text'].upper():
patterns.append("频繁聚合查询")
except Exception as e:
print(f"查询模式分析失败: {e}")
# 使用默认模式
patterns = ["SELECT查询", "INSERT操作", "UPDATE操作"]
return list(set(patterns))
def identify_hot_columns(self, table_name: str) -> List[str]:
"""识别热点列"""
hot_columns = []
try:
with self.connection.cursor() as cursor:
# 分析索引使用情况
cursor.execute(f"""
SELECT COLUMN_NAME, CARDINALITY
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_SCHEMA = '{self.database}'
AND TABLE_NAME = '{table_name}'
AND CARDINALITY > 0
ORDER BY CARDINALITY DESC
""")
index_stats = cursor.fetchall()
# 获取主键列
cursor.execute(f"""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = '{self.database}'
AND TABLE_NAME = '{table_name}'
AND CONSTRAINT_NAME = 'PRIMARY'
""")
primary_keys = [row['COLUMN_NAME'] for row in cursor.fetchall()]
hot_columns.extend(primary_keys)
# 添加高基数索引列
for stat in index_stats[:3]: # 取前3个高基数列
if stat['COLUMN_NAME'] not in hot_columns:
hot_columns.append(stat['COLUMN_NAME'])
# 分析时间类型列
cursor.execute(f"""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = '{self.database}'
AND TABLE_NAME = '{table_name}'
AND DATA_TYPE IN ('datetime', 'timestamp', 'date')
""")
time_columns = [row['COLUMN_NAME'] for row in cursor.fetchall()]
hot_columns.extend(time_columns)
except Exception as e:
print(f"热点列分析失败: {e}")
return list(set(hot_columns))
def estimate_growth_rate(self, table_name: str) -> float:
"""估算表增长率(每月)"""
try:
with self.connection.cursor() as cursor:
# 尝试从时间列估算增长率
cursor.execute(f"""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = '{self.database}'
AND TABLE_NAME = '{table_name}'
AND DATA_TYPE IN ('datetime', 'timestamp')
AND COLUMN_DEFAULT = 'CURRENT_TIMESTAMP'
LIMIT 1
""")
time_column = cursor.fetchone()
if time_column:
col_name = time_column['COLUMN_NAME']
# 计算最近30天的增长
cursor.execute(f"""
SELECT COUNT(*) as recent_count
FROM {table_name}
WHERE {col_name} >= DATE_SUB(NOW(), INTERVAL 30 DAY)
""")
recent_count = cursor.fetchone()['recent_count']
# 计算总数
cursor.execute(f"SELECT COUNT(*) as total_count FROM {table_name}")
total_count = cursor.fetchone()['total_count']
if total_count > 0:
return recent_count / total_count
except Exception as e:
print(f"增长率估算失败: {e}")
# 默认增长率 5%
return 0.05
def recommend_sharding_strategy(self, analysis: TableAnalysis) -> Tuple[ShardingStrategy, int]:
"""推荐分片策略"""
# 基于热点列推荐策略
if any('id' in col.lower() for col in analysis.hot_columns):
# 有ID列,推荐哈希分片
strategy = ShardingStrategy.HASH
elif any(col.lower() in ['created_at', 'updated_at', 'date', 'time'] for col in analysis.hot_columns):
# 有时间列,推荐时间分片
strategy = ShardingStrategy.TIME
elif len(analysis.hot_columns) > 0:
# 有其他热点列,推荐范围分片
strategy = ShardingStrategy.RANGE
else:
# 默认哈希分片
strategy = ShardingStrategy.HASH
# 计算推荐分片数
total_size_mb = analysis.data_size_mb + analysis.index_size_mb
optimal_shard_count = max(2, math.ceil(total_size_mb / self.thresholds['optimal_shard_size_mb']))
# 限制最大分片数
shard_count = min(optimal_shard_count, self.thresholds['max_shard_count'])
# 确保是2的幂(对哈希分片更友好)
if strategy == ShardingStrategy.HASH:
shard_count = 2 ** math.ceil(math.log2(shard_count))
return strategy, shard_count
def estimate_performance_gain(self, analysis: TableAnalysis, shard_count: int) -> float:
"""估算性能提升"""
# 基于数据量减少的性能提升
data_reduction_factor = 1.0 / shard_count
# 查询性能提升(非线性)
query_improvement = 1 - (data_reduction_factor ** 0.7)
# 并发性能提升
concurrency_improvement = min(0.5, (shard_count - 1) * 0.1)
# 总体性能提升
total_improvement = query_improvement + concurrency_improvement
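# 计算示例(示意):shard_count=2 时 query_improvement≈1-0.5**0.7≈0.38,
# concurrency_improvement=0.1,总提升约48%;shard_count=8 时计算值超过上限,将被下式封顶为90%。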
return min(0.9, total_improvement) # 最大90%提升
def analyze_table(self, table_name: str) -> TableAnalysis:
"""分析单个表"""
table_info = self.get_table_info(table_name)
if not table_info:
raise ValueError(f"表 {table_name} 不存在")
status = table_info['status']
# 基本信息
row_count = status['Rows'] or 0
data_size_mb = (status['Data_length'] or 0) / 1024 / 1024
index_size_mb = (status['Index_length'] or 0) / 1024 / 1024
avg_row_length = status['Avg_row_length'] or 0
# 分析查询模式和热点列
query_patterns = self.analyze_query_patterns(table_name)
hot_columns = self.identify_hot_columns(table_name)
growth_rate = self.estimate_growth_rate(table_name)
# 判断是否需要分片
total_size_mb = data_size_mb + index_size_mb
needs_sharding = (
row_count > self.thresholds['max_row_count'] or
total_size_mb > self.thresholds['max_table_size_mb']
)
# 推荐策略
strategy = None
shard_count = 1
performance_gain = 0.0
analysis = TableAnalysis(
table_name=table_name,
row_count=row_count,
data_size_mb=data_size_mb,
index_size_mb=index_size_mb,
avg_row_length=avg_row_length,
growth_rate_per_month=growth_rate,
query_patterns=query_patterns,
hot_columns=hot_columns,
recommended_sharding=needs_sharding,
recommended_strategy=strategy,
recommended_shard_count=shard_count,
estimated_performance_gain=performance_gain
)
if needs_sharding:
# 先基于初步分析结果推荐策略与分片数,再估算性能提升,最后回填到分析结果中
strategy, shard_count = self.recommend_sharding_strategy(analysis)
performance_gain = self.estimate_performance_gain(analysis, shard_count)
analysis.recommended_strategy = strategy
analysis.recommended_shard_count = shard_count
analysis.estimated_performance_gain = performance_gain
return analysis
def analyze_database(self) -> List[TableAnalysis]:
"""分析整个数据库"""
results = []
with self.connection.cursor() as cursor:
# 获取所有表
cursor.execute(f"SHOW TABLES FROM {self.database}")
tables = [row[f'Tables_in_{self.database}'] for row in cursor.fetchall()]
for table_name in tables:
try:
analysis = self.analyze_table(table_name)
results.append(analysis)
except Exception as e:
print(f"分析表 {table_name} 失败: {e}")
return results
def generate_sharding_plan(self, analyses: List[TableAnalysis]) -> Dict:
"""生成分库分表方案"""
plan = {
'timestamp': datetime.now().isoformat(),
'database': self.database,
'total_tables': len(analyses),
'tables_need_sharding': 0,
'estimated_total_performance_gain': 0.0,
'recommended_shard_databases': 1,
'sharding_recommendations': [],
'implementation_priority': [],
'estimated_costs': {
'development_days': 0,
'migration_hours': 0,
'additional_servers': 0
}
}
total_gain = 0.0
tables_needing_sharding = []
for analysis in analyses:
if analysis.recommended_sharding:
tables_needing_sharding.append(analysis)
total_gain += analysis.estimated_performance_gain
recommendation = {
'table_name': analysis.table_name,
'current_size_mb': analysis.data_size_mb + analysis.index_size_mb,
'current_rows': analysis.row_count,
'strategy': analysis.recommended_strategy.value if analysis.recommended_strategy else None,
'shard_count': analysis.recommended_shard_count,
'shard_key_candidates': analysis.hot_columns,
'expected_performance_gain': f"{analysis.estimated_performance_gain:.1%}",
'priority': self._calculate_priority(analysis),
'implementation_complexity': self._estimate_complexity(analysis)
}
plan['sharding_recommendations'].append(recommendation)
plan['tables_need_sharding'] = len(tables_needing_sharding)
plan['estimated_total_performance_gain'] = total_gain / len(analyses) if analyses else 0
# 推荐数据库分片数
if len(tables_needing_sharding) > 5:
plan['recommended_shard_databases'] = min(8, len(tables_needing_sharding) // 3)
# 按优先级排序
plan['sharding_recommendations'].sort(key=lambda x: x['priority'], reverse=True)
plan['implementation_priority'] = [r['table_name'] for r in plan['sharding_recommendations']]
# 估算成本
plan['estimated_costs'] = self._estimate_implementation_costs(tables_needing_sharding)
return plan
def _calculate_priority(self, analysis: TableAnalysis) -> float:
"""计算实施优先级"""
# 基于数据量、性能提升和增长率计算优先级
size_score = min(1.0, (analysis.data_size_mb + analysis.index_size_mb) / 10240) # 10GB为满分
performance_score = analysis.estimated_performance_gain
growth_score = min(1.0, analysis.growth_rate_per_month * 10) # 10%增长率为满分
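# 计算示例(示意):表总大小5120MB(size_score=0.5)、预估性能提升48%、月增长率8%(growth_score=0.8)时,
# 下式得到优先级 ≈ (0.5*0.4 + 0.48*0.4 + 0.8*0.2) * 100 = 55.2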
return (size_score * 0.4 + performance_score * 0.4 + growth_score * 0.2) * 100
def _estimate_complexity(self, analysis: TableAnalysis) -> str:
"""估算实施复杂度"""
complexity_score = 0
# 基于查询模式复杂度
if len(analysis.query_patterns) > 5:
complexity_score += 2
elif len(analysis.query_patterns) > 3:
complexity_score += 1
# 基于分片数量
if analysis.recommended_shard_count > 64:
complexity_score += 2
elif analysis.recommended_shard_count > 16:
complexity_score += 1
# 基于热点列数量
if len(analysis.hot_columns) > 5:
complexity_score += 1
if complexity_score >= 4:
return "高"
elif complexity_score >= 2:
return "中"
else:
return "低"
def _estimate_implementation_costs(self, analyses: List[TableAnalysis]) -> Dict:
"""估算实施成本"""
total_tables = len(analyses)
# 开发时间(天)
development_days = total_tables * 3 + 10 # 每表3天 + 基础架构10天
# 迁移时间(小时)
migration_hours = sum(max(2, (a.data_size_mb + a.index_size_mb) / 1024) for a in analyses)
# 额外服务器数量
max_shards = max((a.recommended_shard_count for a in analyses), default=1)
additional_servers = max(2, max_shards // 4) # 每4个分片一台服务器
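# 估算示例(示意):3张待分片表,总大小分别约5GB/0.5GB/12GB,最大分片数为8时:
# development_days = 3*3 + 10 = 19 天;migration_hours ≈ 5 + 2 + 12 = 19 小时;
# additional_servers = max(2, 8 // 4) = 2 台。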
return {
'development_days': int(development_days),
'migration_hours': int(migration_hours),
'additional_servers': additional_servers
}
def export_report(self, plan: Dict, filename: str = None) -> str:
"""导出分析报告"""
if not filename:
filename = f"sharding_analysis_{self.database}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(filename, 'w', encoding='utf-8') as f:
json.dump(plan, f, ensure_ascii=False, indent=2)
return filename
def print_summary(self, plan: Dict):
"""打印分析摘要"""
print(f"\n=== MySQL分库分表分析报告 ===")
print(f"数据库: {plan['database']}")
print(f"分析时间: {plan['timestamp']}")
print(f"总表数: {plan['total_tables']}")
print(f"需要分片的表: {plan['tables_need_sharding']}")
print(f"预估总体性能提升: {plan['estimated_total_performance_gain']:.1%}")
print(f"推荐分库数: {plan['recommended_shard_databases']}")
print(f"\n=== 实施成本估算 ===")
costs = plan['estimated_costs']
print(f"开发时间: {costs['development_days']} 天")
print(f"迁移时间: {costs['migration_hours']} 小时")
print(f"额外服务器: {costs['additional_servers']} 台")
print(f"\n=== 优先级排序 ===")
for i, table_name in enumerate(plan['implementation_priority'][:5], 1):
rec = next(r for r in plan['sharding_recommendations'] if r['table_name'] == table_name)
print(f"{i}. {table_name} - 策略: {rec['strategy']}, 分片数: {rec['shard_count']}, 优先级: {rec['priority']:.1f}")
# 使用示例
if __name__ == '__main__':
import sys
if len(sys.argv) < 6:
print("用法: python sharding_analyzer.py <host> <port> <user> <password> <database> [table_name]")
sys.exit(1)
host = sys.argv[1]
port = int(sys.argv[2])
user = sys.argv[3]
password = sys.argv[4]
database = sys.argv[5]
table_name = sys.argv[6] if len(sys.argv) > 6 else None
analyzer = MySQLShardingAnalyzer(host, port, user, password, database)
if not analyzer.connect():
sys.exit(1)
try:
if table_name:
# 分析单个表
analysis = analyzer.analyze_table(table_name)
print(f"\n表 {table_name} 分析结果:")
print(f"行数: {analysis.row_count:,}")
print(f"数据大小: {analysis.data_size_mb:.2f} MB")
print(f"索引大小: {analysis.index_size_mb:.2f} MB")
print(f"月增长率: {analysis.growth_rate_per_month:.1%}")
print(f"推荐分片: {'是' if analysis.recommended_sharding else '否'}")
if analysis.recommended_sharding:
print(f"推荐策略: {analysis.recommended_strategy.value}")
print(f"推荐分片数: {analysis.recommended_shard_count}")
print(f"预估性能提升: {analysis.estimated_performance_gain:.1%}")
print(f"热点列: {', '.join(analysis.hot_columns)}")
else:
# 分析整个数据库
analyses = analyzer.analyze_database()
plan = analyzer.generate_sharding_plan(analyses)
analyzer.print_summary(plan)
# 导出报告
report_file = analyzer.export_report(plan)
print(f"\n详细报告已导出到: {report_file}")
finally:
analyzer.disconnect()
14.2 水平分片策略
14.2.1 范围分片(Range Sharding)
范围分片根据分片键的值范围将数据分配到不同的分片中。
优点:
- 范围查询效率高
- 数据分布相对均匀
- 易于理解和实现
缺点:
- 可能出现热点问题
- 需要预先规划范围
- 扩容相对复杂
-- 范围分片示例:按用户ID分片
-- 分片1:user_id 1-1000000
-- 分片2:user_id 1000001-2000000
-- 分片3:user_id 2000001-3000000
-- 创建分片表结构
CREATE TABLE users_shard_1 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT chk_user_id_range CHECK (user_id BETWEEN 1 AND 1000000)
) ENGINE=InnoDB;
CREATE TABLE users_shard_2 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT chk_user_id_range CHECK (user_id BETWEEN 1000001 AND 2000000)
) ENGINE=InnoDB;
CREATE TABLE users_shard_3 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT chk_user_id_range CHECK (user_id BETWEEN 2000001 AND 3000000)
) ENGINE=InnoDB;
-- 范围分片路由逻辑
DELIMITER //
CREATE FUNCTION get_user_shard_by_range(user_id BIGINT)
RETURNS VARCHAR(20)
READS SQL DATA
DETERMINISTIC
BEGIN
DECLARE shard_name VARCHAR(20);
CASE
WHEN user_id BETWEEN 1 AND 1000000 THEN
SET shard_name = 'users_shard_1';
WHEN user_id BETWEEN 1000001 AND 2000000 THEN
SET shard_name = 'users_shard_2';
WHEN user_id BETWEEN 2000001 AND 3000000 THEN
SET shard_name = 'users_shard_3';
ELSE
SET shard_name = 'users_shard_1'; -- 默认分片
END CASE;
RETURN shard_name;
END//
DELIMITER ;
-- 时间范围分片示例:按月分片
CREATE TABLE orders_202401 (
order_id BIGINT PRIMARY KEY,
user_id BIGINT NOT NULL,
amount DECIMAL(10,2),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_user_id (user_id),
INDEX idx_created_at (created_at),
CONSTRAINT chk_date_range CHECK (created_at >= '2024-01-01' AND created_at < '2024-02-01')
) ENGINE=InnoDB;
CREATE TABLE orders_202402 (
order_id BIGINT PRIMARY KEY,
user_id BIGINT NOT NULL,
amount DECIMAL(10,2),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_user_id (user_id),
INDEX idx_created_at (created_at),
CONSTRAINT chk_date_range CHECK (created_at >= '2024-02-01' AND created_at < '2024-03-01')
) ENGINE=InnoDB;
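分片路由通常在应用层完成。下面给出一个与上述范围划分对应的应用层路由示意(表名与区间边界均为示例值,仅为说明思路的草图):

```python
# 应用层范围分片路由示意:根据 user_id 所在区间返回目标分片表名
# 注意:区间边界为示例值,应与建表时的 CHECK 约束保持一致
RANGE_SHARDS = [
    (1, 1_000_000, "users_shard_1"),
    (1_000_001, 2_000_000, "users_shard_2"),
    (2_000_001, 3_000_000, "users_shard_3"),
]

def route_by_range(user_id: int) -> str:
    """按范围路由:返回 user_id 对应的分片表名"""
    for low, high, table in RANGE_SHARDS:
        if low <= user_id <= high:
            return table
    # 超出已规划范围时直接报错,便于及时发现需要新增分片
    raise ValueError(f"user_id {user_id} 超出已规划的分片范围")

print(route_by_range(1500000))  # 输出: users_shard_2
```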
14.2.2 哈希分片(Hash Sharding)
哈希分片通过对分片键进行哈希运算来确定数据分片。
优点:
- 数据分布均匀
- 避免热点问题
- 实现简单
缺点:
- 范围查询效率低
- 扩容需要重新分布数据
- 跨分片查询复杂
-- 哈希分片示例:按用户ID哈希分片
-- 使用CRC32哈希函数,分4个分片
-- 创建哈希分片表
CREATE TABLE users_hash_0 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB;
CREATE TABLE users_hash_1 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB;
CREATE TABLE users_hash_2 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB;
CREATE TABLE users_hash_3 (
user_id BIGINT PRIMARY KEY,
username VARCHAR(50) NOT NULL,
email VARCHAR(100),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB;
-- 哈希分片路由函数
DELIMITER //
CREATE FUNCTION get_user_shard_by_hash(user_id BIGINT)
RETURNS VARCHAR(20)
READS SQL DATA
DETERMINISTIC
BEGIN
DECLARE shard_index INT;
DECLARE shard_name VARCHAR(20);
-- 使用CRC32哈希并取模
SET shard_index = CRC32(user_id) % 4;
SET shard_name = CONCAT('users_hash_', shard_index);
RETURN shard_name;
END//
DELIMITER ;
-- 通用哈希分片函数:对任意分片键取哈希后按分片数取模
-- (注意:这并非严格意义上的一致性哈希,分片数变化时大部分键仍会重新映射;应用层的一致性哈希示例见下文)
DELIMITER //
CREATE FUNCTION get_consistent_hash_shard(shard_key VARCHAR(255), shard_count INT)
RETURNS INT
READS SQL DATA
DETERMINISTIC
BEGIN
DECLARE hash_value BIGINT;
DECLARE shard_index INT;
-- 使用SHA1哈希的前8位作为哈希值
SET hash_value = CONV(SUBSTRING(SHA1(shard_key), 1, 8), 16, 10);
SET shard_index = hash_value % shard_count;
RETURN shard_index;
END//
DELIMITER ;
-- 测试哈希分布
SELECT
get_user_shard_by_hash(1) as shard_1,
get_user_shard_by_hash(2) as shard_2,
get_user_shard_by_hash(3) as shard_3,
get_user_shard_by_hash(4) as shard_4;
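需要说明的是,上面的 get_consistent_hash_shard 实际上仍是"哈希后取模",分片数变化时大部分键都会重新映射。若希望扩容时只迁移少量数据,通常在应用层维护一致性哈希环,下面是一个最小示意(节点名与虚拟节点数均为假设值):

```python
# 一致性哈希环的最小示意:为每个物理节点生成若干虚拟节点以改善分布均匀性
import bisect
import hashlib

class ConsistentHashRing:
    def __init__(self, nodes, virtual_nodes=100):
        self._ring = {}          # 虚拟节点哈希值 -> 物理节点
        self._sorted_keys = []   # 有序哈希值列表,便于二分查找
        for node in nodes:
            self.add_node(node, virtual_nodes)

    @staticmethod
    def _hash(key: str) -> int:
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_node(self, node: str, virtual_nodes: int = 100):
        for i in range(virtual_nodes):
            h = self._hash(f"{node}#{i}")
            self._ring[h] = node
            bisect.insort(self._sorted_keys, h)

    def get_node(self, key) -> str:
        """沿环顺时针找到第一个虚拟节点,返回其对应的物理节点"""
        h = self._hash(str(key))
        idx = bisect.bisect(self._sorted_keys, h) % len(self._sorted_keys)
        return self._ring[self._sorted_keys[idx]]

ring = ConsistentHashRing(["db_shard_1", "db_shard_2", "db_shard_3"])
print(ring.get_node(12345))
```

新增节点时,只有落在新节点与其前驱虚拟节点之间的键需要迁移,其余键的映射保持不变,这正是一致性哈希相对于"哈希取模"的优势。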
14.2.3 目录分片(Directory Sharding)
目录分片使用独立的查找表来维护分片键与分片的映射关系。
优点:
- 灵活性高
- 支持复杂的分片逻辑
- 易于重新平衡
缺点:
- 需要额外的查找开销
- 目录表成为潜在瓶颈
- 实现复杂度高
-- 目录分片示例
-- 创建分片目录表
CREATE TABLE shard_directory (
shard_key VARCHAR(255) PRIMARY KEY,
shard_name VARCHAR(50) NOT NULL,
shard_database VARCHAR(50) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_shard_name (shard_name)
) ENGINE=InnoDB;
-- 插入分片映射
INSERT INTO shard_directory (shard_key, shard_name, shard_database) VALUES
('user_1', 'users_shard_1', 'db_shard_1'),
('user_2', 'users_shard_2', 'db_shard_2'),
('user_3', 'users_shard_1', 'db_shard_1'),
('user_4', 'users_shard_2', 'db_shard_2');
-- 分片查找函数
DELIMITER //
CREATE FUNCTION get_shard_by_directory(lookup_key VARCHAR(255))
RETURNS VARCHAR(100)
READS SQL DATA
BEGIN
DECLARE shard_info VARCHAR(100);
SELECT CONCAT(shard_database, '.', shard_name)
INTO shard_info
FROM shard_directory
WHERE shard_key = lookup_key;
RETURN IFNULL(shard_info, 'default.users');
END//
DELIMITER ;
-- 动态分片重平衡存储过程
DELIMITER //
CREATE PROCEDURE rebalance_shards()
BEGIN
DECLARE done INT DEFAULT FALSE;
DECLARE v_shard_key VARCHAR(255);
DECLARE v_current_shard VARCHAR(50);
DECLARE v_new_shard VARCHAR(50);
DECLARE cur CURSOR FOR
SELECT shard_key, shard_name
FROM shard_directory;
DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = TRUE;
OPEN cur;
read_loop: LOOP
FETCH cur INTO v_shard_key, v_current_shard;
IF done THEN
LEAVE read_loop;
END IF;
-- 计算新的分片(calculate_optimal_shard 为需另行实现的自定义函数,可基于各分片当前负载选择目标分片)
SET v_new_shard = calculate_optimal_shard(v_shard_key);
-- 如果需要迁移
IF v_current_shard != v_new_shard THEN
UPDATE shard_directory
SET shard_name = v_new_shard,
updated_at = CURRENT_TIMESTAMP
WHERE shard_key = v_shard_key;
END IF;
END LOOP;
CLOSE cur;
END//
DELIMITER ;
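目录分片的路由同样可以放在应用层完成,并对 shard_directory 做本地缓存以降低查找开销。下面是一个带简单 TTL 缓存的查找示意(连接参数与缓存时间均为假设值):

```python
# 目录分片路由示意:查询 shard_directory 得到"库.表",并在本地缓存一段时间
import time
import pymysql

class DirectoryRouter:
    def __init__(self, directory_conn_params: dict, cache_ttl: int = 60):
        self._conn_params = directory_conn_params   # 指向存放 shard_directory 的库
        self._cache_ttl = cache_ttl
        self._cache = {}  # shard_key -> (目标"库.表", 过期时间戳)

    def route(self, shard_key: str) -> str:
        cached = self._cache.get(shard_key)
        if cached and cached[1] > time.time():
            return cached[0]
        conn = pymysql.connect(**self._conn_params)
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT CONCAT(shard_database, '.', shard_name) "
                    "FROM shard_directory WHERE shard_key = %s",
                    (shard_key,)
                )
                row = cursor.fetchone()
        finally:
            conn.close()
        target = row[0] if row else "default.users"  # 未命中时与上面 SQL 函数保持一致的默认值
        self._cache[shard_key] = (target, time.time() + self._cache_ttl)
        return target
```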
14.7 性能监控与优化
14.7.1 分片性能监控
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表性能监控系统
实时监控分片性能指标并提供优化建议
"""
import json
import time
import threading
import psutil
from typing import Dict, List, Any, Optional, Tuple
from enum import Enum
from dataclasses import dataclass, asdict
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from datetime import datetime, timedelta
import statistics
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict, deque
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
class AlertLevel(Enum):
INFO = "info"
WARNING = "warning"
CRITICAL = "critical"
class MetricType(Enum):
QPS = "qps" # 每秒查询数
TPS = "tps" # 每秒事务数
RESPONSE_TIME = "response_time" # 响应时间
CONNECTION_COUNT = "connection_count" # 连接数
CPU_USAGE = "cpu_usage" # CPU使用率
MEMORY_USAGE = "memory_usage" # 内存使用率
DISK_IO = "disk_io" # 磁盘IO
SLOW_QUERY = "slow_query" # 慢查询
LOCK_WAIT = "lock_wait" # 锁等待
REPLICATION_LAG = "replication_lag" # 复制延迟
@dataclass
class MetricData:
timestamp: datetime
shard_id: str
metric_type: MetricType
value: float
unit: str
tags: Dict[str, str] = None
@dataclass
class AlertRule:
rule_id: str
metric_type: MetricType
threshold: float
operator: str # >, <, >=, <=, ==
duration: int # 持续时间(秒)
level: AlertLevel
enabled: bool = True
description: str = ""
@dataclass
class Alert:
alert_id: str
rule_id: str
shard_id: str
level: AlertLevel
message: str
timestamp: datetime
resolved: bool = False
resolved_timestamp: Optional[datetime] = None
class ShardPerformanceMonitor:
"""分片性能监控器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
# 监控数据存储
self.metrics_buffer: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
self.alerts: List[Alert] = []
self.alert_rules: Dict[str, AlertRule] = {}
# 分片连接池
self.shard_connections: Dict[str, pymysql.Connection] = {}
# 监控线程
self.monitoring_active = True
self.monitor_threads = []
# 初始化默认告警规则
self._init_default_alert_rules()
# 启动监控线程
self._start_monitoring_threads()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('ShardPerformanceMonitor')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _init_default_alert_rules(self):
"""初始化默认告警规则"""
default_rules = [
AlertRule(
rule_id="high_qps",
metric_type=MetricType.QPS,
threshold=1000,
operator=">",
duration=300,
level=AlertLevel.WARNING,
description="QPS过高"
),
AlertRule(
rule_id="high_response_time",
metric_type=MetricType.RESPONSE_TIME,
threshold=1000, # 1秒
operator=">",
duration=60,
level=AlertLevel.CRITICAL,
description="响应时间过长"
),
AlertRule(
rule_id="high_cpu_usage",
metric_type=MetricType.CPU_USAGE,
threshold=80,
operator=">",
duration=300,
level=AlertLevel.WARNING,
description="CPU使用率过高"
),
AlertRule(
rule_id="high_memory_usage",
metric_type=MetricType.MEMORY_USAGE,
threshold=85,
operator=">",
duration=300,
level=AlertLevel.WARNING,
description="内存使用率过高"
),
AlertRule(
rule_id="too_many_connections",
metric_type=MetricType.CONNECTION_COUNT,
threshold=500,
operator=">",
duration=60,
level=AlertLevel.CRITICAL,
description="连接数过多"
),
AlertRule(
rule_id="slow_query_detected",
metric_type=MetricType.SLOW_QUERY,
threshold=10,
operator=">",
duration=60,
level=AlertLevel.WARNING,
description="慢查询过多"
),
AlertRule(
rule_id="replication_lag",
metric_type=MetricType.REPLICATION_LAG,
threshold=5, # 5秒
operator=">",
duration=120,
level=AlertLevel.CRITICAL,
description="复制延迟过高"
)
]
for rule in default_rules:
self.alert_rules[rule.rule_id] = rule
def _start_monitoring_threads(self):
"""启动监控线程"""
# 指标收集线程
metrics_thread = threading.Thread(
target=self._collect_metrics_loop,
daemon=True
)
metrics_thread.start()
self.monitor_threads.append(metrics_thread)
# 告警检查线程
alert_thread = threading.Thread(
target=self._check_alerts_loop,
daemon=True
)
alert_thread.start()
self.monitor_threads.append(alert_thread)
# 数据清理线程
cleanup_thread = threading.Thread(
target=self._cleanup_old_data_loop,
daemon=True
)
cleanup_thread.start()
self.monitor_threads.append(cleanup_thread)
self.logger.info("监控线程已启动")
def add_shard(self, shard_id: str, host: str, port: int,
database: str, username: str, password: str):
"""添加分片监控"""
try:
conn = pymysql.connect(
host=host,
port=port,
user=username,
password=password,
database=database,
autocommit=True
)
self.shard_connections[shard_id] = conn
self.logger.info(f"添加分片监控: {shard_id}")
except Exception as e:
self.logger.error(f"添加分片监控失败 {shard_id}: {e}")
def remove_shard(self, shard_id: str):
"""移除分片监控"""
if shard_id in self.shard_connections:
try:
self.shard_connections[shard_id].close()
del self.shard_connections[shard_id]
# 清理相关数据
keys_to_remove = [key for key in self.metrics_buffer.keys()
if key.startswith(f"{shard_id}_")]
for key in keys_to_remove:
del self.metrics_buffer[key]
self.logger.info(f"移除分片监控: {shard_id}")
except Exception as e:
self.logger.error(f"移除分片监控失败 {shard_id}: {e}")
def _collect_metrics_loop(self):
"""指标收集循环"""
while self.monitoring_active:
try:
self._collect_all_metrics()
time.sleep(self.config.get('metrics_interval', 30)) # 默认30秒
except Exception as e:
self.logger.error(f"指标收集错误: {e}")
def _collect_all_metrics(self):
"""收集所有分片的指标"""
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for shard_id in self.shard_connections.keys():
future = executor.submit(self._collect_shard_metrics, shard_id)
futures.append(future)
for future in as_completed(futures):
try:
future.result()
except Exception as e:
self.logger.error(f"收集分片指标失败: {e}")
def _collect_shard_metrics(self, shard_id: str):
"""收集单个分片的指标"""
if shard_id not in self.shard_connections:
return
conn = self.shard_connections[shard_id]
timestamp = datetime.now()
try:
# 检查连接是否有效
conn.ping(reconnect=True)
with conn.cursor(pymysql.cursors.DictCursor) as cursor:
# 收集MySQL状态指标
self._collect_mysql_status(cursor, shard_id, timestamp)
# 收集性能指标
self._collect_performance_metrics(cursor, shard_id, timestamp)
# 收集慢查询指标
self._collect_slow_query_metrics(cursor, shard_id, timestamp)
# 收集复制延迟(如果是从库)
self._collect_replication_metrics(cursor, shard_id, timestamp)
# 收集系统资源指标
self._collect_system_metrics(shard_id, timestamp)
except Exception as e:
self.logger.error(f"收集分片 {shard_id} 指标失败: {e}")
def _collect_mysql_status(self, cursor, shard_id: str, timestamp: datetime):
"""收集MySQL状态指标"""
# 获取全局状态
cursor.execute("SHOW GLOBAL STATUS")
status_data = {row['Variable_name']: row['Value'] for row in cursor.fetchall()}
# QPS计算(Questions/Uptime 为服务启动以来的平均值;如需实时QPS/TPS,应取两次采样的差值再除以采样间隔)
questions = int(status_data.get('Questions', 0))
uptime = int(status_data.get('Uptime', 1))
qps = questions / uptime if uptime > 0 else 0
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.QPS,
value=qps,
unit="queries/sec"
))
# TPS计算
com_commit = int(status_data.get('Com_commit', 0))
com_rollback = int(status_data.get('Com_rollback', 0))
tps = (com_commit + com_rollback) / uptime if uptime > 0 else 0
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.TPS,
value=tps,
unit="transactions/sec"
))
# 连接数
threads_connected = int(status_data.get('Threads_connected', 0))
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.CONNECTION_COUNT,
value=threads_connected,
unit="connections"
))
# 锁等待
innodb_row_lock_waits = int(status_data.get('Innodb_row_lock_waits', 0))
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.LOCK_WAIT,
value=innodb_row_lock_waits,
unit="waits"
))
def _collect_performance_metrics(self, cursor, shard_id: str, timestamp: datetime):
"""收集性能指标"""
# 平均查询时间(通过performance_schema)
try:
cursor.execute("""
SELECT AVG(TIMER_WAIT/1000000000) as avg_response_time
FROM performance_schema.events_statements_summary_global_by_event_name
WHERE EVENT_NAME LIKE 'statement/sql/%'
AND COUNT_STAR > 0
""")
result = cursor.fetchone()
if result and result['avg_response_time']:
avg_response_time = float(result['avg_response_time']) * 1000 # 转换为毫秒
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.RESPONSE_TIME,
value=avg_response_time,
unit="ms"
))
except Exception as e:
self.logger.debug(f"收集响应时间指标失败: {e}")
def _collect_slow_query_metrics(self, cursor, shard_id: str, timestamp: datetime):
"""收集慢查询指标"""
try:
# 获取慢查询数量
cursor.execute("SHOW GLOBAL STATUS LIKE 'Slow_queries'")
result = cursor.fetchone()
if result:
slow_queries = int(result['Value'])
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.SLOW_QUERY,
value=slow_queries,
unit="queries"
))
except Exception as e:
self.logger.debug(f"收集慢查询指标失败: {e}")
def _collect_replication_metrics(self, cursor, shard_id: str, timestamp: datetime):
"""收集复制延迟指标"""
try:
cursor.execute("SHOW SLAVE STATUS")
result = cursor.fetchone()
if result and result.get('Seconds_Behind_Master') is not None:
lag = int(result['Seconds_Behind_Master'])
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.REPLICATION_LAG,
value=lag,
unit="seconds"
))
except Exception as e:
self.logger.debug(f"收集复制延迟指标失败: {e}")
def _collect_system_metrics(self, shard_id: str, timestamp: datetime):
"""收集系统资源指标"""
try:
# CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.CPU_USAGE,
value=cpu_percent,
unit="percent"
))
# 内存使用率
memory = psutil.virtual_memory()
memory_percent = memory.percent
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.MEMORY_USAGE,
value=memory_percent,
unit="percent"
))
# 磁盘IO
disk_io = psutil.disk_io_counters()
if disk_io:
io_util = (disk_io.read_bytes + disk_io.write_bytes) / (1024 * 1024)  # 自系统启动以来的累计读写量(MB),如需速率应取两次采样的差值
self._store_metric(MetricData(
timestamp=timestamp,
shard_id=shard_id,
metric_type=MetricType.DISK_IO,
value=io_util,
unit="MB"
))
except Exception as e:
self.logger.debug(f"收集系统指标失败: {e}")
def _store_metric(self, metric: MetricData):
"""存储指标数据"""
key = f"{metric.shard_id}_{metric.metric_type.value}"
self.metrics_buffer[key].append(metric)
def _check_alerts_loop(self):
"""告警检查循环"""
while self.monitoring_active:
try:
self._check_all_alerts()
time.sleep(self.config.get('alert_check_interval', 60)) # 默认1分钟
except Exception as e:
self.logger.error(f"告警检查错误: {e}")
def _check_all_alerts(self):
"""检查所有告警规则"""
for rule_id, rule in self.alert_rules.items():
if not rule.enabled:
continue
for shard_id in self.shard_connections.keys():
self._check_alert_rule(shard_id, rule)
def _check_alert_rule(self, shard_id: str, rule: AlertRule):
"""检查单个告警规则"""
key = f"{shard_id}_{rule.metric_type.value}"
if key not in self.metrics_buffer or not self.metrics_buffer[key]:
return
# 获取最近的指标数据
recent_metrics = list(self.metrics_buffer[key])
if not recent_metrics:
return
# 检查是否满足告警条件
now = datetime.now()
duration_threshold = now - timedelta(seconds=rule.duration)
# 过滤出时间范围内的指标
relevant_metrics = [
m for m in recent_metrics
if m.timestamp >= duration_threshold
]
if not relevant_metrics:
return
# 检查是否所有指标都满足条件
all_match = True
for metric in relevant_metrics:
if not self._evaluate_condition(metric.value, rule.threshold, rule.operator):
all_match = False
break
if all_match:
# 触发告警
self._trigger_alert(shard_id, rule, relevant_metrics[-1].value)
def _evaluate_condition(self, value: float, threshold: float, operator: str) -> bool:
"""评估告警条件"""
if operator == ">":
return value > threshold
elif operator == "<":
return value < threshold
elif operator == ">=":
return value >= threshold
elif operator == "<=":
return value <= threshold
elif operator == "==":
return value == threshold
else:
return False
def _trigger_alert(self, shard_id: str, rule: AlertRule, current_value: float):
"""触发告警"""
# 检查是否已有相同的未解决告警
existing_alert = None
for alert in self.alerts:
if (alert.rule_id == rule.rule_id and
alert.shard_id == shard_id and
not alert.resolved):
existing_alert = alert
break
if existing_alert:
return # 已有未解决的告警,不重复发送
# 创建新告警
alert_id = f"{rule.rule_id}_{shard_id}_{int(time.time())}"
message = f"分片 {shard_id} {rule.description}: 当前值 {current_value:.2f} {rule.operator} 阈值 {rule.threshold}"
alert = Alert(
alert_id=alert_id,
rule_id=rule.rule_id,
shard_id=shard_id,
level=rule.level,
message=message,
timestamp=datetime.now()
)
self.alerts.append(alert)
# 发送告警通知
self._send_alert_notification(alert)
self.logger.warning(f"触发告警: {message}")
def _send_alert_notification(self, alert: Alert):
"""发送告警通知"""
try:
# 邮件通知
if self.config.get('email_alerts_enabled', False):
self._send_email_alert(alert)
# Webhook通知
if self.config.get('webhook_alerts_enabled', False):
self._send_webhook_alert(alert)
# 短信通知(高级别告警)
if (alert.level == AlertLevel.CRITICAL and
self.config.get('sms_alerts_enabled', False)):
self._send_sms_alert(alert)
except Exception as e:
self.logger.error(f"发送告警通知失败: {e}")
def _send_email_alert(self, alert: Alert):
"""发送邮件告警"""
smtp_config = self.config.get('smtp_config', {})
if not smtp_config:
return
msg = MIMEMultipart()
msg['From'] = smtp_config['from_email']
msg['To'] = ', '.join(smtp_config['to_emails'])
msg['Subject'] = f"[{alert.level.value.upper()}] MySQL分片告警 - {alert.shard_id}"
body = f"""
告警详情:
分片ID: {alert.shard_id}
告警级别: {alert.level.value.upper()}
告警时间: {alert.timestamp.strftime('%Y-%m-%d %H:%M:%S')}
告警消息: {alert.message}
请及时处理!
"""
msg.attach(MIMEText(body, 'plain', 'utf-8'))
try:
server = smtplib.SMTP(smtp_config['smtp_server'], smtp_config['smtp_port'])
if smtp_config.get('use_tls', False):
server.starttls()
if smtp_config.get('username') and smtp_config.get('password'):
server.login(smtp_config['username'], smtp_config['password'])
server.send_message(msg)
server.quit()
self.logger.info(f"邮件告警已发送: {alert.alert_id}")
except Exception as e:
self.logger.error(f"发送邮件告警失败: {e}")
def _send_webhook_alert(self, alert: Alert):
"""发送Webhook告警"""
# 这里可以集成Slack、钉钉、企业微信等
webhook_url = self.config.get('webhook_url')
if not webhook_url:
return
import requests
payload = {
'alert_id': alert.alert_id,
'shard_id': alert.shard_id,
'level': alert.level.value,
'message': alert.message,
'timestamp': alert.timestamp.isoformat()
}
try:
response = requests.post(webhook_url, json=payload, timeout=10)
response.raise_for_status()
self.logger.info(f"Webhook告警已发送: {alert.alert_id}")
except Exception as e:
self.logger.error(f"发送Webhook告警失败: {e}")
def _send_sms_alert(self, alert: Alert):
"""发送短信告警"""
# 这里可以集成短信服务提供商API
self.logger.info(f"短信告警: {alert.message}")
def _cleanup_old_data_loop(self):
"""清理旧数据循环"""
while self.monitoring_active:
try:
self._cleanup_old_metrics()
self._cleanup_old_alerts()
time.sleep(self.config.get('cleanup_interval', 3600)) # 默认1小时
except Exception as e:
self.logger.error(f"数据清理错误: {e}")
def _cleanup_old_metrics(self):
"""清理旧的指标数据"""
retention_hours = self.config.get('metrics_retention_hours', 24)
cutoff_time = datetime.now() - timedelta(hours=retention_hours)
for key, metrics_deque in self.metrics_buffer.items():
# 移除过期的指标
while metrics_deque and metrics_deque[0].timestamp < cutoff_time:
metrics_deque.popleft()
def _cleanup_old_alerts(self):
"""清理旧的告警数据"""
retention_days = self.config.get('alerts_retention_days', 7)
cutoff_time = datetime.now() - timedelta(days=retention_days)
# 移除过期的已解决告警
self.alerts = [
alert for alert in self.alerts
if not (alert.resolved and alert.resolved_timestamp and
alert.resolved_timestamp < cutoff_time)
]
def get_metrics(self, shard_id: str, metric_type: MetricType,
start_time: datetime = None, end_time: datetime = None) -> List[MetricData]:
"""获取指标数据"""
key = f"{shard_id}_{metric_type.value}"
if key not in self.metrics_buffer:
return []
metrics = list(self.metrics_buffer[key])
# 时间过滤
if start_time:
metrics = [m for m in metrics if m.timestamp >= start_time]
if end_time:
metrics = [m for m in metrics if m.timestamp <= end_time]
return metrics
def get_current_metrics(self, shard_id: str) -> Dict[str, Any]:
"""获取当前指标快照"""
current_metrics = {}
for metric_type in MetricType:
key = f"{shard_id}_{metric_type.value}"
if key in self.metrics_buffer and self.metrics_buffer[key]:
latest_metric = self.metrics_buffer[key][-1]
current_metrics[metric_type.value] = {
'value': latest_metric.value,
'unit': latest_metric.unit,
'timestamp': latest_metric.timestamp.isoformat()
}
return current_metrics
def get_alerts(self, shard_id: str = None, resolved: bool = None) -> List[Alert]:
"""获取告警列表"""
alerts = self.alerts
if shard_id:
alerts = [a for a in alerts if a.shard_id == shard_id]
if resolved is not None:
alerts = [a for a in alerts if a.resolved == resolved]
return alerts
def resolve_alert(self, alert_id: str) -> bool:
"""解决告警"""
for alert in self.alerts:
if alert.alert_id == alert_id:
alert.resolved = True
alert.resolved_timestamp = datetime.now()
self.logger.info(f"告警已解决: {alert_id}")
return True
return False
def add_alert_rule(self, rule: AlertRule):
"""添加告警规则"""
self.alert_rules[rule.rule_id] = rule
self.logger.info(f"添加告警规则: {rule.rule_id}")
def remove_alert_rule(self, rule_id: str) -> bool:
"""移除告警规则"""
if rule_id in self.alert_rules:
del self.alert_rules[rule_id]
self.logger.info(f"移除告警规则: {rule_id}")
return True
return False
def generate_performance_report(self, shard_id: str, hours: int = 24) -> Dict[str, Any]:
"""生成性能报告"""
end_time = datetime.now()
start_time = end_time - timedelta(hours=hours)
report = {
'shard_id': shard_id,
'report_period': {
'start': start_time.isoformat(),
'end': end_time.isoformat(),
'duration_hours': hours
},
'metrics_summary': {},
'alerts_summary': {},
'recommendations': []
}
# 指标汇总
for metric_type in MetricType:
metrics = self.get_metrics(shard_id, metric_type, start_time, end_time)
if metrics:
values = [m.value for m in metrics]
report['metrics_summary'][metric_type.value] = {
'count': len(values),
'avg': statistics.mean(values),
'min': min(values),
'max': max(values),
'median': statistics.median(values),
'unit': metrics[0].unit
}
# 告警汇总
period_alerts = [
a for a in self.alerts
if a.shard_id == shard_id and start_time <= a.timestamp <= end_time
]
alert_counts = defaultdict(int)
for alert in period_alerts:
alert_counts[alert.level.value] += 1
report['alerts_summary'] = {
'total': len(period_alerts),
'by_level': dict(alert_counts),
'resolved': len([a for a in period_alerts if a.resolved]),
'unresolved': len([a for a in period_alerts if not a.resolved])
}
# 生成优化建议
report['recommendations'] = self._generate_recommendations(shard_id, report['metrics_summary'])
return report
def _generate_recommendations(self, shard_id: str, metrics_summary: Dict) -> List[str]:
"""生成优化建议"""
recommendations = []
# QPS建议
if 'qps' in metrics_summary:
avg_qps = metrics_summary['qps']['avg']
max_qps = metrics_summary['qps']['max']
if avg_qps > 800:
recommendations.append("QPS较高,建议考虑读写分离或增加缓存")
if max_qps > 1500:
recommendations.append("峰值QPS过高,建议进行分片扩容")
# 响应时间建议
if 'response_time' in metrics_summary:
avg_response_time = metrics_summary['response_time']['avg']
if avg_response_time > 500: # 500ms
recommendations.append("平均响应时间较长,建议优化慢查询")
# CPU使用率建议
if 'cpu_usage' in metrics_summary:
avg_cpu = metrics_summary['cpu_usage']['avg']
max_cpu = metrics_summary['cpu_usage']['max']
if avg_cpu > 70:
recommendations.append("CPU使用率较高,建议优化查询或增加服务器资源")
if max_cpu > 90:
recommendations.append("CPU峰值使用率过高,存在性能瓶颈风险")
# 内存使用率建议
if 'memory_usage' in metrics_summary:
avg_memory = metrics_summary['memory_usage']['avg']
if avg_memory > 80:
recommendations.append("内存使用率较高,建议调整缓存配置或增加内存")
# 连接数建议
if 'connection_count' in metrics_summary:
avg_connections = metrics_summary['connection_count']['avg']
max_connections = metrics_summary['connection_count']['max']
if avg_connections > 300:
recommendations.append("连接数较多,建议使用连接池优化")
if max_connections > 500:
recommendations.append("峰值连接数过高,可能存在连接泄漏")
# 慢查询建议
if 'slow_query' in metrics_summary:
avg_slow_queries = metrics_summary['slow_query']['avg']
if avg_slow_queries > 5:
recommendations.append("慢查询较多,建议分析并优化SQL语句")
# 复制延迟建议
if 'replication_lag' in metrics_summary:
avg_lag = metrics_summary['replication_lag']['avg']
if avg_lag > 2:
recommendations.append("复制延迟较高,建议检查网络和主从配置")
return recommendations
def export_metrics_to_csv(self, shard_id: str, metric_type: MetricType,
filename: str, hours: int = 24):
"""导出指标数据到CSV"""
end_time = datetime.now()
start_time = end_time - timedelta(hours=hours)
metrics = self.get_metrics(shard_id, metric_type, start_time, end_time)
if not metrics:
self.logger.warning(f"没有找到指标数据: {shard_id} - {metric_type.value}")
return
# 转换为DataFrame
data = [
{
'timestamp': m.timestamp.isoformat(),
'shard_id': m.shard_id,
'metric_type': m.metric_type.value,
'value': m.value,
'unit': m.unit
}
for m in metrics
]
df = pd.DataFrame(data)
df.to_csv(filename, index=False)
self.logger.info(f"指标数据已导出到: {filename}")
def plot_metrics(self, shard_id: str, metric_types: List[MetricType],
hours: int = 24, save_path: str = None):
"""绘制指标图表"""
end_time = datetime.now()
start_time = end_time - timedelta(hours=hours)
fig, axes = plt.subplots(len(metric_types), 1, figsize=(12, 6 * len(metric_types)))
if len(metric_types) == 1:
axes = [axes]
for i, metric_type in enumerate(metric_types):
metrics = self.get_metrics(shard_id, metric_type, start_time, end_time)
if metrics:
timestamps = [m.timestamp for m in metrics]
values = [m.value for m in metrics]
unit = metrics[0].unit
axes[i].plot(timestamps, values, marker='o', markersize=2)
axes[i].set_title(f'{shard_id} - {metric_type.value.replace("_", " ").title()}')
axes[i].set_ylabel(f'Value ({unit})')
axes[i].grid(True, alpha=0.3)
# 格式化x轴时间显示
axes[i].tick_params(axis='x', rotation=45)
else:
axes[i].text(0.5, 0.5, 'No data available',
horizontalalignment='center', verticalalignment='center',
transform=axes[i].transAxes)
axes[i].set_title(f'{shard_id} - {metric_type.value.replace("_", " ").title()}')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
self.logger.info(f"图表已保存到: {save_path}")
else:
plt.show()
def stop_monitoring(self):
"""停止监控"""
self.monitoring_active = False
# 关闭数据库连接
for conn in self.shard_connections.values():
try:
conn.close()
except Exception:
pass
self.logger.info("性能监控已停止")
14.7.2 查询优化策略
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表查询优化器
提供智能查询路由和优化建议
"""
import re
import json
import time
from typing import Dict, List, Any, Optional, Tuple, Set
from enum import Enum
from dataclasses import dataclass
import pymysql
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from datetime import datetime
import hashlib
from collections import defaultdict
class QueryType(Enum):
SELECT = "select"
INSERT = "insert"
UPDATE = "update"
DELETE = "delete"
DDL = "ddl"
class OptimizationLevel(Enum):
BASIC = "basic"
ADVANCED = "advanced"
AGGRESSIVE = "aggressive"
@dataclass
class QueryPlan:
query_id: str
original_query: str
optimized_query: str
target_shards: List[str]
execution_order: List[str]
estimated_cost: float
optimization_suggestions: List[str]
cache_key: Optional[str] = None
@dataclass
class QueryStats:
query_hash: str
execution_count: int
total_time: float
avg_time: float
min_time: float
max_time: float
error_count: int
last_executed: datetime
class ShardQueryOptimizer:
"""分片查询优化器"""
def __init__(self, config: Dict):
self.config = config
self.logger = self._setup_logger()
# 查询统计
self.query_stats: Dict[str, QueryStats] = {}
# 查询缓存
self.query_cache: Dict[str, Any] = {}
self.cache_ttl = config.get('cache_ttl', 300) # 5分钟
# 分片信息
self.shard_info: Dict[str, Dict] = {}
# SQL解析器
self.sql_patterns = self._init_sql_patterns()
# 优化规则
self.optimization_rules = self._init_optimization_rules()
def _setup_logger(self) -> logging.Logger:
"""设置日志"""
logger = logging.getLogger('ShardQueryOptimizer')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def _init_sql_patterns(self) -> Dict[str, re.Pattern]:
"""初始化SQL解析模式"""
return {
'select': re.compile(r'\bSELECT\b.*?\bFROM\b\s+(\w+)', re.IGNORECASE | re.DOTALL),
'insert': re.compile(r'\bINSERT\s+INTO\s+(\w+)', re.IGNORECASE),
'update': re.compile(r'\bUPDATE\s+(\w+)', re.IGNORECASE),
'delete': re.compile(r'\bDELETE\s+FROM\s+(\w+)', re.IGNORECASE),
'where': re.compile(r'\bWHERE\b(.+?)(?:\bORDER\s+BY\b|\bGROUP\s+BY\b|\bHAVING\b|\bLIMIT\b|$)', re.IGNORECASE | re.DOTALL),
'join': re.compile(r'\b(?:INNER\s+|LEFT\s+|RIGHT\s+|FULL\s+)?JOIN\s+(\w+)', re.IGNORECASE),
'order_by': re.compile(r'\bORDER\s+BY\b(.+?)(?:\bLIMIT\b|$)', re.IGNORECASE | re.DOTALL),
'group_by': re.compile(r'\bGROUP\s+BY\b(.+?)(?:\bHAVING\b|\bORDER\s+BY\b|\bLIMIT\b|$)', re.IGNORECASE | re.DOTALL),
'limit': re.compile(r'\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?', re.IGNORECASE)
}
def _init_optimization_rules(self) -> List[Dict]:
"""初始化优化规则"""
return [
{
'name': 'avoid_cross_shard_join',
'description': '避免跨分片JOIN',
'priority': 1,
'check_func': self._check_cross_shard_join
},
{
'name': 'use_shard_key_in_where',
'description': '在WHERE子句中使用分片键',
'priority': 2,
'check_func': self._check_shard_key_usage
},
{
'name': 'optimize_limit_offset',
'description': '优化LIMIT和OFFSET',
'priority': 3,
'check_func': self._check_limit_optimization
},
{
'name': 'use_covering_index',
'description': '使用覆盖索引',
'priority': 4,
'check_func': self._check_covering_index
},
{
'name': 'avoid_function_in_where',
'description': '避免在WHERE子句中使用函数',
'priority': 5,
'check_func': self._check_function_in_where
}
]
def add_shard_info(self, shard_id: str, info: Dict):
"""添加分片信息"""
self.shard_info[shard_id] = info
self.logger.info(f"添加分片信息: {shard_id}")
def parse_query(self, sql: str) -> Dict[str, Any]:
"""解析SQL查询"""
sql = sql.strip()
# 确定查询类型
query_type = self._detect_query_type(sql)
# 提取表名
tables = self._extract_tables(sql, query_type)
# 提取WHERE条件
where_conditions = self._extract_where_conditions(sql)
# 提取JOIN信息
joins = self._extract_joins(sql)
# 提取ORDER BY
order_by = self._extract_order_by(sql)
# 提取GROUP BY
group_by = self._extract_group_by(sql)
# 提取LIMIT
limit_info = self._extract_limit(sql)
return {
'query_type': query_type,
'tables': tables,
'where_conditions': where_conditions,
'joins': joins,
'order_by': order_by,
'group_by': group_by,
'limit': limit_info,
'original_sql': sql
}
def _detect_query_type(self, sql: str) -> QueryType:
"""检测查询类型"""
sql_upper = sql.upper().strip()
if sql_upper.startswith('SELECT'):
return QueryType.SELECT
elif sql_upper.startswith('INSERT'):
return QueryType.INSERT
elif sql_upper.startswith('UPDATE'):
return QueryType.UPDATE
elif sql_upper.startswith('DELETE'):
return QueryType.DELETE
else:
return QueryType.DDL
def _extract_tables(self, sql: str, query_type: QueryType) -> List[str]:
"""提取表名"""
tables = []
if query_type == QueryType.SELECT:
match = self.sql_patterns['select'].search(sql)
if match:
tables.append(match.group(1))
elif query_type == QueryType.INSERT:
match = self.sql_patterns['insert'].search(sql)
if match:
tables.append(match.group(1))
elif query_type == QueryType.UPDATE:
match = self.sql_patterns['update'].search(sql)
if match:
tables.append(match.group(1))
elif query_type == QueryType.DELETE:
match = self.sql_patterns['delete'].search(sql)
if match:
tables.append(match.group(1))
# 提取JOIN中的表
join_matches = self.sql_patterns['join'].findall(sql)
tables.extend(join_matches)
return list(set(tables)) # 去重
def _extract_where_conditions(self, sql: str) -> List[Dict]:
"""提取WHERE条件"""
conditions = []
match = self.sql_patterns['where'].search(sql)
if match:
where_clause = match.group(1).strip()
# 简单的条件解析(可以扩展为更复杂的解析器)
# 这里只处理基本的等值条件
condition_pattern = re.compile(r'(\w+)\s*=\s*([\'\"]?[^\s\'\"]+[\'\"]?)', re.IGNORECASE)
for match in condition_pattern.finditer(where_clause):
conditions.append({
'column': match.group(1),
'operator': '=',
'value': match.group(2).strip('\'\"`')
})
return conditions
def _extract_joins(self, sql: str) -> List[str]:
"""提取JOIN信息"""
return self.sql_patterns['join'].findall(sql)
def _extract_order_by(self, sql: str) -> Optional[str]:
"""提取ORDER BY"""
match = self.sql_patterns['order_by'].search(sql)
return match.group(1).strip() if match else None
def _extract_group_by(self, sql: str) -> Optional[str]:
"""提取GROUP BY"""
match = self.sql_patterns['group_by'].search(sql)
return match.group(1).strip() if match else None
def _extract_limit(self, sql: str) -> Optional[Dict]:
"""提取LIMIT信息"""
match = self.sql_patterns['limit'].search(sql)
if match:
if match.group(2): # LIMIT offset, count
return {
'offset': int(match.group(1)),
'count': int(match.group(2))
}
else: # LIMIT count
return {
'offset': 0,
'count': int(match.group(1))
}
return None
def optimize_query(self, sql: str, optimization_level: OptimizationLevel = OptimizationLevel.BASIC) -> QueryPlan:
"""优化查询"""
query_id = self._generate_query_id(sql)
# 解析查询
parsed_query = self.parse_query(sql)
# 确定目标分片
target_shards = self._determine_target_shards(parsed_query)
# 生成优化建议
suggestions = self._generate_optimization_suggestions(parsed_query, optimization_level)
# 优化SQL
optimized_sql = self._apply_optimizations(sql, parsed_query, suggestions)
# 计算执行成本
estimated_cost = self._estimate_execution_cost(parsed_query, target_shards)
# 确定执行顺序
execution_order = self._determine_execution_order(target_shards, parsed_query)
# 生成缓存键
cache_key = self._generate_cache_key(optimized_sql, target_shards)
return QueryPlan(
query_id=query_id,
original_query=sql,
optimized_query=optimized_sql,
target_shards=target_shards,
execution_order=execution_order,
estimated_cost=estimated_cost,
optimization_suggestions=suggestions,
cache_key=cache_key
)
def _generate_query_id(self, sql: str) -> str:
"""生成查询ID"""
return hashlib.md5(sql.encode()).hexdigest()[:16]
def _determine_target_shards(self, parsed_query: Dict) -> List[str]:
"""确定目标分片"""
target_shards = []
# 根据WHERE条件中的分片键确定目标分片
for condition in parsed_query['where_conditions']:
if self._is_shard_key(condition['column']):
shard_id = self._calculate_shard_id(condition['value'])
if shard_id not in target_shards:
target_shards.append(shard_id)
# 如果没有分片键条件,需要查询所有分片
if not target_shards:
target_shards = list(self.shard_info.keys())
return target_shards
def _is_shard_key(self, column: str) -> bool:
"""检查是否为分片键"""
# 这里需要根据实际的分片配置来判断
shard_keys = self.config.get('shard_keys', ['user_id', 'id'])
return column.lower() in [key.lower() for key in shard_keys]
def _calculate_shard_id(self, value: str) -> str:
"""计算分片ID"""
# 简单的哈希分片算法
shard_count = len(self.shard_info)
if shard_count == 0:
return 'shard_0'
hash_value = int(hashlib.md5(str(value).encode()).hexdigest(), 16)
shard_index = hash_value % shard_count
return f'shard_{shard_index}'
def _generate_optimization_suggestions(self, parsed_query: Dict, level: OptimizationLevel) -> List[str]:
"""生成优化建议"""
suggestions = []
# 根据优化级别应用不同的规则
rules_to_apply = self.optimization_rules
if level == OptimizationLevel.BASIC:
rules_to_apply = [r for r in self.optimization_rules if r['priority'] <= 3]
elif level == OptimizationLevel.ADVANCED:
rules_to_apply = [r for r in self.optimization_rules if r['priority'] <= 5]
for rule in rules_to_apply:
try:
rule_suggestions = rule['check_func'](parsed_query)
suggestions.extend(rule_suggestions)
except Exception as e:
self.logger.error(f"应用优化规则 {rule['name']} 失败: {e}")
return suggestions
def _check_cross_shard_join(self, parsed_query: Dict) -> List[str]:
"""检查跨分片JOIN"""
suggestions = []
if len(parsed_query['joins']) > 0 and len(parsed_query['tables']) > 1:
# 检查是否为跨分片JOIN
table_shards = {}
for table in parsed_query['tables']:
# 这里需要根据表的分片策略来判断
table_shards[table] = self._get_table_shards(table)
# 如果不同表分布在不同分片上,则为跨分片JOIN
all_shards = set()
for shards in table_shards.values():
all_shards.update(shards)
if len(all_shards) > 1:
suggestions.append("检测到跨分片JOIN,建议重新设计查询或使用应用层聚合")
return suggestions
def _check_shard_key_usage(self, parsed_query: Dict) -> List[str]:
"""检查分片键使用"""
suggestions = []
# 检查WHERE条件中是否包含分片键
has_shard_key = False
for condition in parsed_query['where_conditions']:
if self._is_shard_key(condition['column']):
has_shard_key = True
break
if not has_shard_key and parsed_query['query_type'] in [QueryType.SELECT, QueryType.UPDATE, QueryType.DELETE]:
suggestions.append("建议在WHERE子句中包含分片键以提高查询性能")
return suggestions
def _check_limit_optimization(self, parsed_query: Dict) -> List[str]:
"""检查LIMIT优化"""
suggestions = []
if parsed_query['limit'] and parsed_query['limit']['offset'] > 1000:
suggestions.append("OFFSET值过大,建议使用游标分页或其他分页策略")
if parsed_query['limit'] and parsed_query['limit']['count'] > 10000:
suggestions.append("LIMIT值过大,可能影响性能,建议分批处理")
return suggestions
def _check_covering_index(self, parsed_query: Dict) -> List[str]:
"""检查覆盖索引"""
suggestions = []
# 这里需要根据实际的索引信息来判断
# 简化实现,仅提供建议
if parsed_query['query_type'] == QueryType.SELECT:
suggestions.append("建议为常用查询创建覆盖索引以提高性能")
return suggestions
    def _check_function_in_where(self, parsed_query: Dict) -> List[str]:
        """检查WHERE子句中的函数"""
        suggestions = []
        # 注意:where_conditions 中只保存裸列名,必须回到原始SQL的WHERE子句中检查函数调用
        where_match = self.sql_patterns['where'].search(parsed_query['original_sql'])
        if where_match:
            function_pattern = re.compile(r'\b(?:UPPER|LOWER|SUBSTRING|DATE|YEAR|MONTH)\s*\(', re.IGNORECASE)
            if function_pattern.search(where_match.group(1)):
                suggestions.append("避免在WHERE子句中对列使用函数,这会导致无法使用索引")
        return suggestions
def _get_table_shards(self, table: str) -> List[str]:
"""获取表的分片列表"""
# 这里需要根据实际的分片配置来实现
# 简化实现,假设所有表都分布在所有分片上
return list(self.shard_info.keys())
    def _apply_optimizations(self, sql: str, parsed_query: Dict, suggestions: List[str]) -> str:
        """应用优化"""
        optimized_sql = sql
        # 这里可以根据建议自动应用一些优化
        # 例如:添加LIMIT、重写JOIN等
        # 示例:如果没有LIMIT且为SELECT查询,先去掉末尾分号再追加默认LIMIT,避免语法错误
        if (parsed_query['query_type'] == QueryType.SELECT and
                not parsed_query['limit'] and
                'LIMIT' not in sql.upper()):
            optimized_sql = sql.rstrip('; ') + ' LIMIT 1000'
        return optimized_sql
def _estimate_execution_cost(self, parsed_query: Dict, target_shards: List[str]) -> float:
"""估算执行成本"""
base_cost = 1.0
# 分片数量影响成本
shard_cost = len(target_shards) * 0.5
# JOIN影响成本
join_cost = len(parsed_query['joins']) * 2.0
# ORDER BY影响成本
order_cost = 1.5 if parsed_query['order_by'] else 0
# GROUP BY影响成本
group_cost = 2.0 if parsed_query['group_by'] else 0
return base_cost + shard_cost + join_cost + order_cost + group_cost
def _determine_execution_order(self, target_shards: List[str], parsed_query: Dict) -> List[str]:
"""确定执行顺序"""
# 简单的执行顺序策略
# 可以根据分片负载、网络延迟等因素优化
return sorted(target_shards)
def _generate_cache_key(self, sql: str, target_shards: List[str]) -> str:
"""生成缓存键"""
cache_data = {
'sql': sql,
'shards': sorted(target_shards)
}
return hashlib.md5(json.dumps(cache_data, sort_keys=True).encode()).hexdigest()
def execute_optimized_query(self, query_plan: QueryPlan, connections: Dict[str, pymysql.Connection]) -> List[Dict]:
"""执行优化后的查询"""
start_time = time.time()
try:
# 检查缓存
if query_plan.cache_key and query_plan.cache_key in self.query_cache:
cache_entry = self.query_cache[query_plan.cache_key]
if time.time() - cache_entry['timestamp'] < self.cache_ttl:
self.logger.info(f"查询命中缓存: {query_plan.query_id}")
return cache_entry['result']
results = []
if len(query_plan.target_shards) == 1:
# 单分片查询
shard_id = query_plan.target_shards[0]
if shard_id in connections:
result = self._execute_single_shard_query(
connections[shard_id],
query_plan.optimized_query
)
results.extend(result)
else:
# 多分片查询
results = self._execute_multi_shard_query(
query_plan, connections
)
# 缓存结果
if query_plan.cache_key:
self.query_cache[query_plan.cache_key] = {
'result': results,
'timestamp': time.time()
}
# 更新统计信息
execution_time = time.time() - start_time
self._update_query_stats(query_plan.query_id, execution_time, success=True)
return results
except Exception as e:
execution_time = time.time() - start_time
self._update_query_stats(query_plan.query_id, execution_time, success=False)
self.logger.error(f"查询执行失败 {query_plan.query_id}: {e}")
raise
def _execute_single_shard_query(self, connection: pymysql.Connection, sql: str) -> List[Dict]:
"""执行单分片查询"""
with connection.cursor(pymysql.cursors.DictCursor) as cursor:
cursor.execute(sql)
return cursor.fetchall()
def _execute_multi_shard_query(self, query_plan: QueryPlan, connections: Dict[str, pymysql.Connection]) -> List[Dict]:
"""执行多分片查询"""
all_results = []
# 并行执行查询
with ThreadPoolExecutor(max_workers=len(query_plan.target_shards)) as executor:
futures = {}
for shard_id in query_plan.execution_order:
if shard_id in connections:
future = executor.submit(
self._execute_single_shard_query,
connections[shard_id],
query_plan.optimized_query
)
futures[future] = shard_id
for future in as_completed(futures):
try:
shard_id = futures[future]
result = future.result()
all_results.extend(result)
except Exception as e:
shard_id = futures[future]
self.logger.error(f"分片 {shard_id} 查询失败: {e}")
# 合并和排序结果
return self._merge_results(all_results, query_plan)
def _merge_results(self, results: List[Dict], query_plan: QueryPlan) -> List[Dict]:
"""合并查询结果"""
if not results:
return []
# 解析原始查询以获取ORDER BY和LIMIT信息
parsed_query = self.parse_query(query_plan.original_query)
# 应用ORDER BY
if parsed_query['order_by']:
results = self._apply_order_by(results, parsed_query['order_by'])
# 应用LIMIT
if parsed_query['limit']:
offset = parsed_query['limit']['offset']
count = parsed_query['limit']['count']
results = results[offset:offset + count]
return results
    def _apply_order_by(self, results: List[Dict], order_by: str) -> List[Dict]:
        """应用ORDER BY排序"""
        # 简单的排序实现:从后往前依次应用各排序字段,借助稳定排序实现多字段排序
        order_parts = [part.strip() for part in order_by.split(',')]
        for part in reversed(order_parts):
            tokens = part.split()
            column = tokens[0].strip('`')
            descending = len(tokens) > 1 and tokens[-1].upper() == 'DESC'
            results.sort(key=lambda x: x.get(column, ''), reverse=descending)
        return results
def _update_query_stats(self, query_id: str, execution_time: float, success: bool):
"""更新查询统计信息"""
query_hash = query_id
if query_hash not in self.query_stats:
self.query_stats[query_hash] = QueryStats(
query_hash=query_hash,
execution_count=0,
total_time=0.0,
avg_time=0.0,
min_time=float('inf'),
max_time=0.0,
error_count=0,
last_executed=datetime.now()
)
stats = self.query_stats[query_hash]
stats.execution_count += 1
stats.last_executed = datetime.now()
if success:
stats.total_time += execution_time
stats.avg_time = stats.total_time / stats.execution_count
stats.min_time = min(stats.min_time, execution_time)
stats.max_time = max(stats.max_time, execution_time)
else:
stats.error_count += 1
    def get_query_stats(self, query_id: str = None) -> Dict[str, QueryStats]:
        """获取查询统计信息"""
        if query_id:
            stats = self.query_stats.get(query_id)
            return {query_id: stats} if stats else {}
        return self.query_stats.copy()
def clear_cache(self):
"""清理缓存"""
self.query_cache.clear()
self.logger.info("查询缓存已清理")
def get_optimization_report(self) -> Dict[str, Any]:
"""获取优化报告"""
total_queries = len(self.query_stats)
total_errors = sum(stats.error_count for stats in self.query_stats.values())
avg_execution_time = sum(stats.avg_time for stats in self.query_stats.values()) / total_queries if total_queries > 0 else 0
# 找出最慢的查询
slowest_queries = sorted(
self.query_stats.items(),
key=lambda x: x[1].avg_time,
reverse=True
)[:10]
# 找出错误最多的查询
error_queries = sorted(
self.query_stats.items(),
key=lambda x: x[1].error_count,
reverse=True
)[:10]
return {
'summary': {
'total_queries': total_queries,
'total_errors': total_errors,
'error_rate': total_errors / total_queries if total_queries > 0 else 0,
'avg_execution_time': avg_execution_time,
'cache_hit_rate': len(self.query_cache) / total_queries if total_queries > 0 else 0
},
'slowest_queries': [
{
'query_id': qid,
'avg_time': stats.avg_time,
'execution_count': stats.execution_count
}
for qid, stats in slowest_queries if stats.error_count == 0
],
'error_queries': [
{
'query_id': qid,
'error_count': stats.error_count,
'execution_count': stats.execution_count
}
for qid, stats in error_queries if stats.error_count > 0
]
        }
```
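为了直观起见,下面补充一段示意性的调用流程,把分片注册、查询优化与执行串联起来。其中优化器类名 `ShardingQueryOptimizer` 及其构造参数仅为假设,应以前文类的实际定义为准:
```python
# 调用流程示意(ShardingQueryOptimizer 为假设的类名,config 参数结构也是假设)
optimizer = ShardingQueryOptimizer(config={'shard_keys': ['user_id', 'id']})

# 注册 4 个分片的信息
for i in range(4):
    optimizer.add_shard_info(f'shard_{i}', {'host': f'10.0.0.{i + 1}', 'port': 3306})

# WHERE 中带分片键:只会路由到单个分片
plan = optimizer.optimize_query(
    "SELECT order_id, total_amount FROM orders "
    "WHERE user_id = 12345 ORDER BY created_at DESC",
    optimization_level=OptimizationLevel.ADVANCED
)
print(plan.target_shards)             # 例如 ['shard_2']
print(plan.optimization_suggestions)  # 查看触发的优化建议

# connections 为 {分片ID: pymysql 连接} 字典,可由连接池提供
# results = optimizer.execute_optimized_query(plan, connections)
```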
## 14.8 最佳实践与总结
### 14.8.1 分库分表最佳实践
**1. 设计原则**
```markdown
# MySQL分库分表设计原则
## 1. 分片策略选择
### 水平分片优先
- 优先考虑水平分片,保持表结构一致性
- 根据业务特点选择合适的分片键
- 避免热点数据集中在单个分片
### 分片键选择标准
- 查询频率高的字段
- 数据分布均匀的字段
- 业务逻辑相关性强的字段
- 避免经常变更的字段
## 2. 分片数量规划
### 初始分片数
- 根据当前数据量和增长预期确定
- 考虑单表数据量限制(建议500万-1000万行)
- 预留扩容空间,建议初始分片数为2的幂次
### 扩容策略
- 制定明确的扩容触发条件
- 设计平滑的数据迁移方案
- 考虑业务低峰期进行扩容操作
## 3. 跨分片查询优化
### 避免跨分片JOIN
- 重新设计数据模型,减少表间关联
- 使用冗余字段替代JOIN查询
- 在应用层进行数据聚合
### 分页查询优化
- 避免使用OFFSET进行深度分页
- 使用游标分页或范围查询
- 在应用层合并分页结果
## 4. 数据一致性保证
### 分布式事务
- 尽量避免跨分片事务
- 使用最终一致性替代强一致性
- 实现补偿机制处理异常情况
### 数据同步
- 建立主从复制保证数据可靠性
- 实现跨分片数据一致性检查
- 定期进行数据校验和修复
```
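上面“分片键选择标准”强调数据分布要均匀。下面给出一个评估现有分片键均匀性的最小草图:统计各分片的行数并计算最大/最小比值,比值越接近 1 说明分布越均匀(connections 为 {分片ID: DB-API 连接} 字典,函数名为示意):
```python
def shard_balance_report(connections, table):
    """统计各分片行数,评估分片键的数据分布是否均匀"""
    counts = {}
    for shard_id, conn in connections.items():
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM {table}")
            counts[shard_id] = cursor.fetchone()[0]
    max_rows, min_rows = max(counts.values()), min(counts.values())
    # 倾斜度:最大分片与最小分片的行数之比,明显大于 1 说明存在热点分片
    skew_ratio = max_rows / min_rows if min_rows else float('inf')
    return {'counts': counts, 'skew_ratio': skew_ratio}
```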
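针对“分页查询优化”中“使用游标分页、在应用层合并分页结果”的建议,下面是一个示意草图:各分片按游标条件各取一页候选数据,再在应用层归并排序后截取结果。表结构沿用前文的 orders 表,函数名与参数均为假设:
```python
import pymysql

def fetch_orders_by_cursor(connections, cursor_time, cursor_id, page_size=20):
    """游标分页示意:用 (created_at, order_id) 作为游标,避免深度 OFFSET"""
    # 首页可传入足够大的时间(如当前时间)与最大ID作为初始游标
    sql = (
        "SELECT order_id, user_id, total_amount, created_at FROM orders "
        "WHERE (created_at, order_id) < (%s, %s) "
        "ORDER BY created_at DESC, order_id DESC LIMIT %s"
    )
    candidates = []
    for shard_id, conn in connections.items():
        with conn.cursor(pymysql.cursors.DictCursor) as cursor:
            cursor.execute(sql, (cursor_time, cursor_id, page_size))
            candidates.extend(cursor.fetchall())  # 每个分片只取 page_size 条候选
    # 应用层归并:全局排序后截取一页,并返回下一页游标
    candidates.sort(key=lambda r: (r['created_at'], r['order_id']), reverse=True)
    page = candidates[:page_size]
    next_cursor = (page[-1]['created_at'], page[-1]['order_id']) if page else None
    return page, next_cursor
```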
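针对“数据同步”中跨库/主从一致性检查的建议,下面是一个基于行数与 CHECKSUM TABLE 对比的示意草图(CHECKSUM TABLE 为 MySQL 内置语句,连接与表名参数均为示意):
```python
def verify_table_consistency(source_conn, replica_conn, table):
    """对比两个库中同一张表的行数与校验和,返回是否一致"""
    def snapshot(conn):
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM {table}")
            row_count = cursor.fetchone()[0]
            cursor.execute(f"CHECKSUM TABLE {table}")
            checksum = cursor.fetchone()[1]  # 返回 (表名, 校验和)
        return row_count, checksum
    # 注意:校验和对表结构与行格式敏感,两边表定义必须完全一致
    return snapshot(source_conn) == snapshot(replica_conn)
```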
**2. 性能优化建议**
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MySQL分库分表性能优化建议生成器
"""
from typing import Dict, List, Any
from dataclasses import dataclass
from enum import Enum
class OptimizationType(Enum):
QUERY = "query"
INDEX = "index"
SCHEMA = "schema"
HARDWARE = "hardware"
CONFIGURATION = "configuration"
@dataclass
class OptimizationSuggestion:
type: OptimizationType
priority: int # 1-5, 1为最高优先级
title: str
description: str
implementation: str
expected_improvement: str
risk_level: str # low, medium, high
class PerformanceOptimizationAdvisor:
"""性能优化建议生成器"""
def __init__(self):
self.suggestions = self._init_suggestions()
def _init_suggestions(self) -> List[OptimizationSuggestion]:
"""初始化优化建议"""
return [
# 查询优化
OptimizationSuggestion(
type=OptimizationType.QUERY,
priority=1,
title="使用分片键进行查询",
description="在WHERE子句中包含分片键,避免全分片扫描",
implementation="""-- 优化前
SELECT * FROM user_orders WHERE order_date > '2023-01-01';
-- 优化后
SELECT * FROM user_orders WHERE user_id = 12345 AND order_date > '2023-01-01';""",
expected_improvement="查询性能提升80-90%",
risk_level="low"
),
OptimizationSuggestion(
type=OptimizationType.QUERY,
priority=2,
title="避免跨分片JOIN",
description="重构查询逻辑,在应用层进行数据聚合",
implementation="""-- 避免这样的查询
SELECT u.name, o.total
FROM users u
JOIN orders o ON u.id = o.user_id;
-- 改为应用层查询
1. SELECT * FROM users WHERE id IN (1,2,3);
2. SELECT * FROM orders WHERE user_id IN (1,2,3);
3. 在应用层合并结果""",
expected_improvement="避免跨分片网络开销",
risk_level="medium"
),
# 索引优化
OptimizationSuggestion(
type=OptimizationType.INDEX,
priority=1,
title="创建复合索引",
description="为常用查询条件创建复合索引",
implementation="""-- 分析查询模式
SELECT * FROM orders WHERE user_id = ? AND status = ? ORDER BY created_at DESC;
-- 创建复合索引
CREATE INDEX idx_user_status_time ON orders(user_id, status, created_at);""",
expected_improvement="查询性能提升50-70%",
risk_level="low"
),
OptimizationSuggestion(
type=OptimizationType.INDEX,
priority=2,
title="使用覆盖索引",
description="创建包含所有查询字段的索引,避免回表查询",
implementation="""-- 查询只需要特定字段
SELECT user_id, order_id, total FROM orders WHERE status = 'completed';
-- 创建覆盖索引
CREATE INDEX idx_status_covering ON orders(status, user_id, order_id, total);""",
expected_improvement="减少磁盘I/O,性能提升30-50%",
risk_level="low"
),
# 表结构优化
OptimizationSuggestion(
type=OptimizationType.SCHEMA,
priority=2,
title="垂直分表",
description="将大表按字段访问频率进行垂直拆分",
implementation="""-- 原始表
CREATE TABLE user_profiles (
id BIGINT PRIMARY KEY,
username VARCHAR(50),
email VARCHAR(100),
profile_data TEXT, -- 大字段,访问频率低
settings JSON, -- 大字段,访问频率低
created_at TIMESTAMP
);
-- 拆分后
CREATE TABLE user_basic (
id BIGINT PRIMARY KEY,
username VARCHAR(50),
email VARCHAR(100),
created_at TIMESTAMP
);
CREATE TABLE user_details (
user_id BIGINT PRIMARY KEY,
profile_data TEXT,
settings JSON,
FOREIGN KEY (user_id) REFERENCES user_basic(id)
);""",
expected_improvement="减少常用查询的数据传输量",
risk_level="medium"
),
# 硬件优化
OptimizationSuggestion(
type=OptimizationType.HARDWARE,
priority=3,
title="使用SSD存储",
description="将数据库文件存储在SSD上,提升I/O性能",
implementation="""# 迁移数据到SSD
1. 停止MySQL服务
2. 复制数据文件到SSD
3. 修改配置文件中的数据目录
4. 重启MySQL服务
# 配置示例
[mysqld]
datadir = /ssd/mysql/data
innodb_data_home_dir = /ssd/mysql/data
innodb_log_group_home_dir = /ssd/mysql/logs""",
expected_improvement="I/O性能提升5-10倍",
risk_level="low"
),
# 配置优化
OptimizationSuggestion(
type=OptimizationType.CONFIGURATION,
priority=2,
title="优化InnoDB缓冲池",
description="调整innodb_buffer_pool_size以充分利用内存",
implementation="""# 计算合适的缓冲池大小
# 建议设置为系统内存的70-80%
[mysqld]
# 对于16GB内存的服务器
innodb_buffer_pool_size = 12G
innodb_buffer_pool_instances = 8
innodb_buffer_pool_chunk_size = 128M""",
expected_improvement="减少磁盘I/O,性能提升20-40%",
risk_level="low"
),
OptimizationSuggestion(
type=OptimizationType.CONFIGURATION,
priority=3,
title="优化连接池配置",
description="合理配置连接池大小和超时参数",
implementation="""# 应用层连接池配置
max_connections = 200
min_connections = 10
max_idle_time = 300
validation_query = SELECT 1
# MySQL服务器配置
[mysqld]
max_connections = 1000
max_connect_errors = 100000
connect_timeout = 10
wait_timeout = 28800""",
expected_improvement="提升并发处理能力",
risk_level="low"
)
]
def get_suggestions_by_type(self, optimization_type: OptimizationType) -> List[OptimizationSuggestion]:
"""根据类型获取优化建议"""
return [s for s in self.suggestions if s.type == optimization_type]
def get_high_priority_suggestions(self) -> List[OptimizationSuggestion]:
"""获取高优先级优化建议"""
return [s for s in self.suggestions if s.priority <= 2]
def generate_optimization_plan(self, current_issues: List[str]) -> Dict[str, Any]:
"""生成优化计划"""
plan = {
'immediate_actions': [],
'short_term_goals': [],
'long_term_goals': []
}
# 根据当前问题匹配建议
for issue in current_issues:
if 'slow query' in issue.lower():
plan['immediate_actions'].extend(
self.get_suggestions_by_type(OptimizationType.QUERY)
)
elif 'high cpu' in issue.lower():
plan['immediate_actions'].extend(
self.get_suggestions_by_type(OptimizationType.INDEX)
)
elif 'memory' in issue.lower():
plan['short_term_goals'].extend(
self.get_suggestions_by_type(OptimizationType.CONFIGURATION)
)
# 添加长期优化目标
plan['long_term_goals'].extend(
self.get_suggestions_by_type(OptimizationType.HARDWARE)
)
plan['long_term_goals'].extend(
self.get_suggestions_by_type(OptimizationType.SCHEMA)
)
        return plan
```
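下面补充一个简单的调用示例,演示如何根据当前观测到的问题生成优化计划并查看高优先级建议(current_issues 中的问题描述仅为示意数据):
```python
advisor = PerformanceOptimizationAdvisor()

plan = advisor.generate_optimization_plan([
    "orders 分片上出现大量 slow query",
    "shard_3 节点 high CPU",
])

# 先查看高优先级建议
for s in advisor.get_high_priority_suggestions():
    print(f"[P{s.priority}][{s.type.value}] {s.title}: {s.expected_improvement}")

# 再查看需要立即执行的动作
for s in plan['immediate_actions']:
    print(f"立即执行: {s.title}(风险等级: {s.risk_level})")
```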
### 14.8.2 总结
MySQL分库分表技术是解决大规模数据存储和高并发访问的重要手段。通过本教程的学习,我们掌握了:
**核心技术要点**
1. 分片策略设计
   - 水平分片:按行分割数据
   - 垂直分片:按列分割数据
   - 混合分片:结合水平和垂直分片
2. 分片键选择
   - 数据分布均匀性
   - 查询路由效率
   - 业务逻辑相关性
3. 中间件实现
   - SQL解析与路由
   - 连接池管理
   - 结果集合并
4. 数据一致性
   - 分布式事务管理
   - 数据同步机制
   - 一致性检查
5. 运维管理
   - 在线数据迁移
   - 平滑扩容缩容
   - 性能监控
**实施建议**
1. 渐进式实施
   - 从单库单表开始
   - 逐步引入分库分表
   - 持续优化和调整
2. 充分测试
   - 功能测试
   - 性能测试
   - 压力测试
   - 故障恢复测试
3. 监控告警
   - 建立完善的监控体系
   - 设置合理的告警阈值
   - 制定应急响应预案
4. 团队培训
   - 开发团队技术培训
   - 运维团队操作培训
   - 建立最佳实践文档
**注意事项**
1. 复杂性管理
   - 分库分表会增加系统复杂性
   - 需要权衡性能收益和维护成本
   - 建立完善的开发和运维流程
2. 数据治理
   - 制定数据分片规范
   - 建立数据质量监控
   - 定期进行数据清理
3. 技术演进
   - 关注新技术发展
   - 评估技术栈升级
   - 保持架构的可扩展性
通过合理的设计和实施,MySQL分库分表技术能够有效解决大规模数据处理的挑战,为业务发展提供强有力的技术支撑。