6.1 SQLAlchemy基础
6.1.1 安装和配置
# 安装Flask-SQLAlchemy和数据库驱动
pip install Flask-SQLAlchemy
pip install PyMySQL # MySQL
pip install psycopg2-binary # PostgreSQL
# SQLite 无需额外安装:sqlite3 模块随 Python 标准库内置(请勿执行 pip install sqlite3)
# config.py
import os
from urllib.parse import quote_plus
class Config:
    """Base settings shared by every environment."""

    # Database settings
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    SQLALCHEMY_RECORD_QUERIES = True
    # Pool tuning intended for client/server databases. ``pool_size`` and
    # ``max_overflow`` are QueuePool arguments; the SQLite configurations
    # below override this dict because SQLite URLs may be given a pool class
    # that rejects them.
    SQLALCHEMY_ENGINE_OPTIONS = {
        'pool_size': 10,
        'pool_recycle': 120,
        'pool_pre_ping': True,
        'max_overflow': 20,
    }


class DevelopmentConfig(Config):
    """Local development: file-based SQLite with SQL echo enabled."""

    DEBUG = True
    SQLALCHEMY_DATABASE_URI = 'sqlite:///app.db'
    SQLALCHEMY_ECHO = True  # print emitted SQL statements
    # Keep only the options that every pool implementation accepts.
    SQLALCHEMY_ENGINE_OPTIONS = {
        'pool_recycle': 120,
        'pool_pre_ping': True,
    }


class ProductionConfig(Config):
    """Production: MySQL via PyMySQL, configured from environment variables."""

    DEBUG = False
    # MySQL connection settings (each falls back to a local default)
    DB_USER = os.environ.get('DB_USER', 'root')
    DB_PASSWORD = os.environ.get('DB_PASSWORD', '')
    DB_HOST = os.environ.get('DB_HOST', 'localhost')
    DB_PORT = os.environ.get('DB_PORT', '3306')
    DB_NAME = os.environ.get('DB_NAME', 'flask_app')
    # Both credentials are URL-escaped so characters such as '@', ':' or '/'
    # cannot corrupt the URL structure.
    SQLALCHEMY_DATABASE_URI = (
        f'mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}'
        f'@{DB_HOST}:{DB_PORT}/{DB_NAME}?charset=utf8mb4'
    )


class TestingConfig(Config):
    """Test runs: in-memory SQLite."""

    TESTING = True
    SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
    # Same override as DevelopmentConfig: no QueuePool-only arguments.
    SQLALCHEMY_ENGINE_OPTIONS = {
        'pool_recycle': 120,
        'pool_pre_ping': True,
    }


# Lookup table from environment name to configuration class.
config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'testing': TestingConfig,
    'default': DevelopmentConfig,
}
# app.py
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from config import config
# Extension objects are created without an application and attached to one
# inside the factory.
db = SQLAlchemy()
migrate = Migrate()


def create_app(config_name='default'):
    """Application factory.

    Args:
        config_name: Key into the ``config`` mapping imported from
            ``config``; an unknown key raises ``KeyError``.

    Returns:
        The Flask application with ``db`` and ``migrate`` initialised on it.
    """
    flask_app = Flask(__name__)
    settings = config[config_name]
    flask_app.config.from_object(settings)

    # Attach the extensions to this application instance.
    db.init_app(flask_app)
    migrate.init_app(flask_app, db)
    return flask_app


# Module-level application built with the default configuration.
app = create_app()
6.1.2 数据库连接管理
# database.py
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.pool import Pool
import time
import logging
logger = logging.getLogger(__name__)


class DatabaseManager:
    """Owns a Flask-SQLAlchemy object and wires up listeners and CLI commands."""

    # Statements slower than this many seconds are logged as warnings.
    SLOW_QUERY_SECONDS = 0.1
    # The listeners are attached to the global ``Engine`` / ``Pool`` classes,
    # so they must be registered only once per process no matter how many
    # times ``init_app`` runs.
    _events_registered = False

    def __init__(self, app=None):
        """Create the SQLAlchemy object; bind it immediately if ``app`` is given."""
        self.db = SQLAlchemy()
        if app:
            self.init_app(app)

    def init_app(self, app):
        """Bind the database to ``app`` and register listeners and CLI commands."""
        self.db.init_app(app)
        # Register the event listeners
        self.register_events()
        # Register the CLI commands
        self.register_commands(app)

    def register_events(self):
        """Attach timing/logging and SQLite PRAGMA listeners (idempotent)."""
        if DatabaseManager._events_registered:
            return
        DatabaseManager._events_registered = True

        import sqlite3

        slow_threshold = self.SLOW_QUERY_SECONDS

        @event.listens_for(Engine, "before_cursor_execute")
        def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # Push the start time onto a per-connection stack so the matching
            # "after" hook can pop it.
            conn.info.setdefault('query_start_time', []).append(time.perf_counter())
            logger.debug("SQL执行开始: %s...", statement[:100])

        @event.listens_for(Engine, "after_cursor_execute")
        def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            total = time.perf_counter() - conn.info['query_start_time'].pop(-1)
            if total > slow_threshold:  # slow-query warning
                logger.warning("慢查询检测 (%.3fs): %s...", total, statement[:100])
            else:
                logger.debug("SQL执行完成 (%.3fs)", total)

        @event.listens_for(Pool, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            # Only touch real sqlite3 connections; other drivers are skipped.
            if not isinstance(dbapi_connection, sqlite3.Connection):
                return
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.execute("PRAGMA cache_size=10000")
                cursor.execute("PRAGMA temp_store=MEMORY")
            finally:
                cursor.close()

    def register_commands(self, app):
        """Register the ``init_db`` / ``drop_db`` / ``reset_db`` CLI commands on ``app``."""
        # NOTE: the command docstrings double as CLI help text, so they are
        # kept exactly as written.

        @app.cli.command()
        def init_db():
            """初始化数据库"""
            self.db.create_all()
            print('数据库初始化完成')

        @app.cli.command()
        def drop_db():
            """删除数据库"""
            self.db.drop_all()
            print('数据库已删除')

        @app.cli.command()
        def reset_db():
            """重置数据库"""
            self.db.drop_all()
            self.db.create_all()
            print('数据库已重置')

    def get_engine_info(self):
        """Return a dict describing the engine and its connection pool.

        Pool statistics are read defensively: a counter the current pool
        class does not provide is reported as ``None`` instead of raising
        ``AttributeError``. The ``'invalid'`` key is kept for compatibility
        with existing callers.
        """
        engine = self.db.engine
        pool = engine.pool

        def pool_stat(method_name):
            method = getattr(pool, method_name, None)
            return method() if callable(method) else None

        return {
            'url': str(engine.url),
            'driver': engine.driver,
            'pool_size': pool_stat('size'),
            'checked_in': pool_stat('checkedin'),
            'checked_out': pool_stat('checkedout'),
            'overflow': pool_stat('overflow'),
            'invalid': pool_stat('invalid'),
        }
# Module-level manager instance; ``db`` is an alias for the SQLAlchemy object
# that the manager created in its constructor.
db_manager = DatabaseManager()
db = db_manager.db
6.2 数据模型设计
6.2.1 基础模型
# models/base.py
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Column, Integer, DateTime, String
from sqlalchemy.ext.declarative import declared_attr
db = SQLAlchemy()


class TimestampMixin:
    """Mixin that contributes ``created_at`` and ``updated_at`` columns."""

    # Defaults to the current UTC time when no value is supplied.
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        default=datetime.utcnow,
    )
    # Same default as ``created_at``; also refreshed through ``onupdate``.
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )
class BaseModel(db.Model, TimestampMixin):
    """Abstract base model: integer primary key, timestamps and CRUD helpers."""

    __abstract__ = True

    id = db.Column(db.Integer, primary_key=True)

    @declared_attr
    def __tablename__(cls):
        """Default table name: the lower-cased class name."""
        return cls.__name__.lower()

    @staticmethod
    def _commit():
        """Commit the session; roll back and re-raise if the commit fails.

        Without the rollback a failed flush would leave the session unusable
        for the rest of the request.
        """
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise

    def save(self):
        """Add this instance to the session and commit. Returns ``self``."""
        db.session.add(self)
        self._commit()
        return self

    def delete(self):
        """Delete this instance and commit."""
        db.session.delete(self)
        self._commit()

    def update(self, **kwargs):
        """Set the given attributes, bump ``updated_at`` and commit.

        Keys that are not existing attributes of the instance are ignored.
        Returns ``self``.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        self.updated_at = datetime.utcnow()
        self._commit()
        return self

    def to_dict(self, exclude=None, include_relationships=False):
        """Serialise the column values (and optionally relationships) to a dict.

        Args:
            exclude: Column / relationship names to leave out.
            include_relationships: When true, related objects are serialised
                with their own ``to_dict()`` (called without arguments).

        Returns:
            A dict keyed by column name; ``datetime`` values are converted
            with ``isoformat()``.
        """
        exclude = exclude or []
        result = {}
        for column in self.__table__.columns:
            if column.name in exclude:
                continue
            value = getattr(self, column.name)
            if isinstance(value, datetime):
                value = value.isoformat()
            result[column.name] = value

        if include_relationships:
            for relationship in self.__mapper__.relationships:
                if relationship.key in exclude:
                    continue
                value = getattr(self, relationship.key)
                if value is None:
                    continue
                if relationship.uselist:
                    result[relationship.key] = [item.to_dict() for item in value]
                else:
                    result[relationship.key] = value.to_dict()
        return result

    @classmethod
    def create(cls, **kwargs):
        """Instantiate the model with ``kwargs``, save it and return it."""
        instance = cls(**kwargs)
        return instance.save()

    @classmethod
    def get_or_404(cls, id):
        """Fetch by primary key, delegating to ``Query.get_or_404``."""
        return cls.query.get_or_404(id)

    @classmethod
    def get_or_create(cls, defaults=None, **kwargs):
        """Return ``(instance, created)`` for the row matching ``kwargs``.

        When no row matches, a new one is created from ``kwargs`` updated
        with ``defaults`` and saved.
        """
        instance = cls.query.filter_by(**kwargs).first()
        if instance:
            return instance, False
        params = {**kwargs, **(defaults or {})}
        instance = cls(**params)
        instance.save()
        return instance, True

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.id}>'
6.2.2 用户模型
# models/user.py
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import UserMixin
from sqlalchemy import event
from models.base import BaseModel, db
import re
class User(BaseModel, UserMixin):
    """User account model."""

    __tablename__ = 'users'

    # Basic information
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)

    # Profile
    first_name = db.Column(db.String(50))
    last_name = db.Column(db.String(50))
    phone = db.Column(db.String(20))
    avatar = db.Column(db.String(255))
    bio = db.Column(db.Text)

    # Status flags
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    is_verified = db.Column(db.Boolean, default=False, nullable=False)
    is_admin = db.Column(db.Boolean, default=False, nullable=False)

    # Login statistics
    login_count = db.Column(db.Integer, default=0)
    last_login_at = db.Column(db.DateTime)
    last_login_ip = db.Column(db.String(45))

    # Relationships
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')
    comments = db.relationship('Comment', backref='author', lazy='dynamic', cascade='all, delete-orphan')

    def set_password(self, password):
        """Hash and store ``password``.

        Raises:
            ValueError: if the password fails ``validate_password_strength``.
        """
        if not self.validate_password_strength(password):
            raise ValueError('密码强度不够')
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True when ``password`` matches the stored hash.

        A user without a stored hash never matches.
        """
        if not self.password_hash:
            return False
        return check_password_hash(self.password_hash, password)

    @staticmethod
    def validate_password_strength(password):
        """Require at least 8 characters with an upper-case, lower-case and digit.

        Non-string input (e.g. ``None``) is treated as too weak.
        """
        if not isinstance(password, str) or len(password) < 8:
            return False
        required_patterns = (r'[A-Z]', r'[a-z]', r'\d')
        return all(re.search(pattern, password) for pattern in required_patterns)

    @property
    def full_name(self):
        """First and last name when both are set, otherwise the username."""
        if self.first_name and self.last_name:
            return f'{self.first_name} {self.last_name}'
        return self.username

    @property
    def is_authenticated(self):
        """Always True for a ``User`` instance."""
        return True

    @property
    def is_anonymous(self):
        """Always False for a ``User`` instance."""
        return False

    def get_id(self):
        """Return the primary key as a string."""
        return str(self.id)

    def update_login_info(self, ip_address):
        """Record a login: bump the counter, store time and IP, then commit."""
        # Local import: this module's import block does not bring in datetime.
        from datetime import datetime

        # ``login_count`` is still None on an instance that has never been
        # flushed, because the column default has not been applied yet.
        self.login_count = (self.login_count or 0) + 1
        self.last_login_at = datetime.utcnow()
        self.last_login_ip = ip_address
        db.session.commit()

    def get_posts_count(self):
        """Number of posts authored by this user."""
        return self.posts.count()

    def get_comments_count(self):
        """Number of comments authored by this user."""
        return self.comments.count()

    def to_dict(self, include_sensitive=False):
        """Serialise the user; ``password_hash`` is omitted unless requested.

        Adds ``full_name``, ``posts_count`` and ``comments_count`` to the
        column values.
        """
        exclude = [] if include_sensitive else ['password_hash']
        result = super().to_dict(exclude=exclude)
        result['full_name'] = self.full_name
        result['posts_count'] = self.get_posts_count()
        result['comments_count'] = self.get_comments_count()
        return result

    def __repr__(self):
        return f'<User {self.username}>'
# Attribute event listeners
@event.listens_for(User.username, 'set')
def validate_username(target, value, oldvalue, initiator):
    """Reject a non-empty username that is not 3-20 letters, digits or underscores."""
    if not value:
        return
    if re.match(r'^[a-zA-Z0-9_]{3,20}$', value) is None:
        raise ValueError('用户名只能包含字母、数字和下划线,长度3-20个字符')
@event.listens_for(User.email, 'set')
def validate_email(target, value, oldvalue, initiator):
    """Reject a non-empty e-mail address that does not match the basic pattern."""
    if not value:
        return
    if re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', value) is None:
        raise ValueError('邮箱格式不正确')
6.2.3 文章和评论模型
# models/blog.py
from datetime import datetime

from sqlalchemy import event, text

from models.base import BaseModel, db
# Association table linking posts and tags; the two foreign keys together
# form the composite primary key.
post_tags = db.Table('post_tags',
    db.Column('post_id', db.Integer, db.ForeignKey('posts.id'), primary_key=True),
    db.Column('tag_id', db.Integer, db.ForeignKey('tags.id'), primary_key=True)
)
class Category(BaseModel):
    """Post category; categories can be nested through ``parent_id``."""

    __tablename__ = 'categories'

    name = db.Column(db.String(50), unique=True, nullable=False)
    slug = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)
    color = db.Column(db.String(7), default='#007bff')  # hex colour

    # Hierarchy. ``remote_side`` is given as a string because the ``id``
    # column is inherited from BaseModel: a bare ``id`` inside this class
    # body would resolve to the builtin ``id`` function, not the column.
    parent_id = db.Column(db.Integer, db.ForeignKey('categories.id'))
    parent = db.relationship('Category', remote_side='Category.id', backref='children')

    # Cached statistics
    posts_count = db.Column(db.Integer, default=0)

    # Relationships
    posts = db.relationship('Post', backref='category', lazy='dynamic')

    def update_posts_count(self):
        """Recount the posts in this category, store the result and commit."""
        self.posts_count = self.posts.count()
        db.session.commit()

    def get_all_posts(self):
        """Return a query over posts in this category and its direct children."""
        category_ids = [self.id, *(child.id for child in self.children)]
        return Post.query.filter(Post.category_id.in_(category_ids))

    def __repr__(self):
        return f'<Category {self.name}>'
class Tag(BaseModel):
    """Tag that can be attached to posts through ``post_tags``."""

    __tablename__ = 'tags'

    name = db.Column(db.String(30), nullable=False, unique=True)
    color = db.Column(db.String(7), default='#6c757d')

    # Cached statistics
    posts_count = db.Column(db.Integer, default=0)

    def update_posts_count(self):
        """Recount this tag's rows in ``post_tags``, store the result and commit."""
        link_rows = db.session.query(post_tags).filter_by(tag_id=self.id)
        self.posts_count = link_rows.count()
        db.session.commit()

    def __repr__(self):
        return '<Tag {}>'.format(self.name)
class Post(BaseModel):
    """Blog post model."""

    __tablename__ = 'posts'

    # Basic information
    title = db.Column(db.String(200), nullable=False)
    slug = db.Column(db.String(200), unique=True, nullable=False)
    summary = db.Column(db.Text)
    content = db.Column(db.Text, nullable=False)

    # Status
    status = db.Column(db.Enum('draft', 'published', 'archived', name='post_status'),
                       default='draft', nullable=False)
    is_featured = db.Column(db.Boolean, default=False)
    allow_comments = db.Column(db.Boolean, default=True)

    # Timestamps
    published_at = db.Column(db.DateTime)

    # Cached statistics
    views_count = db.Column(db.Integer, default=0)
    likes_count = db.Column(db.Integer, default=0)
    comments_count = db.Column(db.Integer, default=0)

    # Foreign keys
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    category_id = db.Column(db.Integer, db.ForeignKey('categories.id'))

    # Relationships
    tags = db.relationship('Tag', secondary=post_tags, backref='posts')
    comments = db.relationship('Comment', backref='post', lazy='dynamic',
                               cascade='all, delete-orphan')

    def publish(self):
        """Mark the post as published, stamp ``published_at`` and commit."""
        self.status = 'published'
        self.published_at = datetime.utcnow()
        db.session.commit()

    def archive(self):
        """Mark the post as archived and commit."""
        self.status = 'archived'
        db.session.commit()

    def increment_views(self):
        """Increase ``views_count`` by one and commit.

        The increment is assigned as a SQL expression so it is computed by
        the database in the UPDATE statement. Concurrent requests therefore
        cannot overwrite each other's increments, and a NULL counter is
        treated as 0.
        """
        model = type(self)
        self.views_count = db.func.coalesce(model.views_count, 0) + 1
        db.session.commit()

    def update_comments_count(self):
        """Recount approved comments, store the result and commit."""
        self.comments_count = self.comments.filter_by(status='approved').count()
        db.session.commit()

    def get_reading_time(self):
        """Estimated reading time in minutes (whitespace-separated words / 200).

        Never less than 1; missing content counts as zero words.
        """
        words_per_minute = 200
        word_count = len((self.content or '').split())
        return max(1, round(word_count / words_per_minute))

    @property
    def is_published(self):
        """True when the status is 'published' and ``published_at`` is set."""
        return self.status == 'published' and self.published_at is not None

    def to_dict(self, include_content=True):
        """Serialise the post, adding ``reading_time`` and ``is_published``.

        Args:
            include_content: When false, the ``content`` key is dropped.
        """
        result = super().to_dict()
        result['reading_time'] = self.get_reading_time()
        result['is_published'] = self.is_published
        if not include_content:
            result.pop('content', None)
        return result

    def __repr__(self):
        return f'<Post {self.title}>'
class Comment(BaseModel):
    """Comment on a post, written by a registered user or a guest."""

    __tablename__ = 'comments'

    content = db.Column(db.Text, nullable=False)
    author_name = db.Column(db.String(50))    # used for guest comments
    author_email = db.Column(db.String(120))  # used for guest comments
    author_ip = db.Column(db.String(45))

    # Moderation status
    status = db.Column(db.Enum('pending', 'approved', 'rejected', name='comment_status'),
                       default='pending', nullable=False)

    # Reply hierarchy. ``remote_side`` is given as a string because the
    # ``id`` column is inherited from BaseModel: a bare ``id`` inside this
    # class body would resolve to the builtin ``id`` function.
    parent_id = db.Column(db.Integer, db.ForeignKey('comments.id'))
    parent = db.relationship('Comment', remote_side='Comment.id', backref='replies')

    # Foreign keys
    post_id = db.Column(db.Integer, db.ForeignKey('posts.id'), nullable=False)
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'))  # optional: registered author

    def approve(self):
        """Approve the comment, commit, then refresh the post's comment count."""
        self.status = 'approved'
        db.session.commit()
        # Refresh the cached comment count on the post
        self.post.update_comments_count()

    def reject(self):
        """Reject the comment, commit, then refresh the post's comment count."""
        self.status = 'rejected'
        db.session.commit()
        self.post.update_comments_count()

    @property
    def author_display_name(self):
        """Registered author's full name, else the guest name, else a placeholder."""
        if self.author:
            return self.author.full_name
        return self.author_name or '匿名用户'

    def __repr__(self):
        return f'<Comment {self.id} on Post {self.post_id}>'
# Mapper event listeners
@event.listens_for(Post, 'before_insert')
def generate_slug(mapper, connection, target):
    """Derive ``target.slug`` from the title when no slug was supplied.

    The title is lower-cased, characters other than word characters,
    whitespace and hyphens are removed, runs of whitespace/hyphens collapse
    to a single hyphen, and leading/trailing hyphens are stripped. If nothing
    usable remains (missing or punctuation-only title), a random hex string
    is used so the slug is never empty.
    """
    if target.slug:
        return

    import re
    import uuid

    slug = re.sub(r'[^\w\s-]', '', (target.title or '').lower())
    slug = re.sub(r'[-\s]+', '-', slug).strip('-')
    target.slug = slug or uuid.uuid4().hex
@event.listens_for(Comment, 'after_insert')
def update_post_comments_count_after_insert(mapper, connection, target):
    """After an approved comment is inserted, add one to its post's counter."""
    if target.status != 'approved':
        return
    bump_counter = text("UPDATE posts SET comments_count = comments_count + 1 WHERE id = :post_id")
    connection.execute(bump_counter, {'post_id': target.post_id})
@event.listens_for(Comment, 'after_delete')
def update_post_comments_count_after_delete(mapper, connection, target):
    """After an approved comment is deleted, subtract one from its post's counter."""
    if target.status != 'approved':
        return
    drop_counter = text("UPDATE posts SET comments_count = comments_count - 1 WHERE id = :post_id")
    connection.execute(drop_counter, {'post_id': target.post_id})
6.3 数据库迁移
6.3.1 Flask-Migrate配置
# 安装Flask-Migrate
pip install Flask-Migrate
# 初始化迁移仓库
flask db init
# 生成迁移文件
flask db migrate -m "Initial migration"
# 应用迁移
flask db upgrade
# 查看迁移历史
flask db history
# 回滚迁移
flask db downgrade
6.3.2 迁移管理
# migrations/migration_manager.py
from flask_migrate import Migrate, upgrade, downgrade, current, history
from flask import current_app
import os
class MigrationManager:
    """Wraps Flask-Migrate and adds status, upgrade, backup and restore commands."""

    # Directory (relative to the working directory) that holds backup files.
    BACKUP_DIR = 'backups'

    def __init__(self, app=None, db=None):
        """Create the Migrate object; initialise it when both ``app`` and ``db`` are given."""
        self.migrate = Migrate()
        if app and db:
            self.init_app(app, db)

    def init_app(self, app, db):
        """Bind Flask-Migrate to ``app``/``db`` and register the CLI commands."""
        self.migrate.init_app(app, db)
        self.register_commands(app)

    def register_commands(self, app):
        """Register the ``db_*`` helper commands on ``app``'s CLI."""
        # NOTE: the command docstrings double as CLI help text, so they are
        # kept exactly as written. A failing command prints the error and
        # exits with status 1 so calling scripts can detect the failure.

        @app.cli.command()
        def db_status():
            """显示数据库状态"""
            try:
                # flask_migrate's current()/history() print their output and
                # return None, so they are called for effect, not for a value.
                print("当前版本:")
                current()
                print("\n迁移历史:")
                history()
            except Exception as e:
                print(f"获取数据库状态失败: {e}")
                raise SystemExit(1)

        @app.cli.command()
        def db_upgrade():
            """升级数据库"""
            try:
                upgrade()
                print("数据库升级成功")
            except Exception as e:
                print(f"数据库升级失败: {e}")
                raise SystemExit(1)

        @app.cli.command()
        def db_downgrade():
            """降级数据库"""
            try:
                downgrade()
                print("数据库降级成功")
            except Exception as e:
                print(f"数据库降级失败: {e}")
                raise SystemExit(1)

        @app.cli.command()
        def db_backup():
            """备份数据库"""
            self.backup_database()

        @app.cli.command()
        def db_restore():
            """恢复数据库"""
            self.restore_database()

    @staticmethod
    def _mysql_client_args(db_url):
        """Translate a MySQL SQLAlchemy URL into mysql/mysqldump arguments.

        Returns ``[-h<host>, -P<port>, -u<user>, (-p<password>), <database>]``.
        User and password are URL-decoded, and the password option is omitted
        when the URL carries none.

        NOTE(review): a password passed on the command line is visible in the
        process list; consider a client option file for production use.
        """
        from urllib.parse import unquote_plus, urlparse

        parsed = urlparse(db_url)
        args = [
            f'-h{parsed.hostname}',
            f'-P{parsed.port or 3306}',
            f'-u{unquote_plus(parsed.username or "")}',
        ]
        if parsed.password:
            args.append(f'-p{unquote_plus(parsed.password)}')
        args.append(parsed.path[1:])  # drop the leading '/'
        return args

    @staticmethod
    def _sqlite_path(db_url):
        """Return the file path portion of a ``sqlite:///`` URL."""
        return db_url.replace('sqlite:///', '')

    def backup_database(self):
        """Back up the configured database into ``BACKUP_DIR``.

        SQLite databases are copied as ``backup_<timestamp>.db``; MySQL
        databases are dumped with ``mysqldump`` to ``backup_<timestamp>.sql``.
        Other database types print a notice and do nothing.

        Raises:
            subprocess.CalledProcessError: if ``mysqldump`` exits non-zero
                (the partial dump file is removed first).
        """
        from datetime import datetime
        import shutil
        import subprocess

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_dir = self.BACKUP_DIR
        os.makedirs(backup_dir, exist_ok=True)

        db_url = current_app.config['SQLALCHEMY_DATABASE_URI']

        if db_url.startswith('sqlite'):
            # SQLite backup: copy the database file
            db_path = self._sqlite_path(db_url)
            if not os.path.isfile(db_path):
                # Covers in-memory databases and paths that do not exist.
                print(f"未找到SQLite数据库文件: {db_path}")
                return
            backup_path = f"{backup_dir}/backup_{timestamp}.db"
            shutil.copy2(db_path, backup_path)
            print(f"SQLite数据库已备份到: {backup_path}")
        elif db_url.startswith('mysql'):
            # MySQL backup: dump through mysqldump
            backup_path = f"{backup_dir}/backup_{timestamp}.sql"
            cmd = ['mysqldump', *self._mysql_client_args(db_url)]
            try:
                with open(backup_path, 'w') as f:
                    subprocess.run(cmd, stdout=f, check=True)
            except (OSError, subprocess.CalledProcessError):
                # Do not leave a truncated dump that restore could pick up.
                if os.path.exists(backup_path):
                    os.remove(backup_path)
                raise
            print(f"MySQL数据库已备份到: {backup_path}")
        else:
            # Print only the scheme so credentials are never echoed.
            print(f"不支持备份该数据库类型: {db_url.split(':', 1)[0]}")

    def restore_database(self, backup_file=None):
        """Restore the configured database from a backup file.

        Args:
            backup_file: Path of the backup to restore. When omitted, the
                newest ``backup_*`` file in ``BACKUP_DIR`` whose extension
                matches the configured database type is used.
        """
        backup_dir = self.BACKUP_DIR
        if not backup_file and not os.path.exists(backup_dir):
            print("没有找到备份目录")
            return

        db_url = current_app.config['SQLALCHEMY_DATABASE_URI']
        if db_url.startswith('sqlite'):
            suffix = '.db'
        elif db_url.startswith('mysql'):
            suffix = '.sql'
        else:
            print(f"不支持恢复该数据库类型: {db_url.split(':', 1)[0]}")
            return

        if not backup_file:
            # Pick the newest backup of the matching type; the timestamp in
            # the file name makes lexicographic order chronological.
            candidates = sorted(
                name for name in os.listdir(backup_dir)
                if name.startswith('backup_') and name.endswith(suffix)
            )
            if not candidates:
                print("没有找到备份文件")
                return
            backup_file = os.path.join(backup_dir, candidates[-1])
        elif not backup_file.endswith(suffix):
            print(f"备份文件 {backup_file} 与当前数据库类型不匹配")
            return

        if suffix == '.db':
            # SQLite restore: copy the backup over the database file
            import shutil
            shutil.copy2(backup_file, self._sqlite_path(db_url))
            print(f"SQLite数据库已从 {backup_file} 恢复")
        else:
            # MySQL restore: feed the dump to the mysql client
            import subprocess
            cmd = ['mysql', *self._mysql_client_args(db_url)]
            with open(backup_file, 'r') as f:
                subprocess.run(cmd, stdin=f, check=True)
            print(f"MySQL数据库已从 {backup_file} 恢复")
6.3.3 自定义迁移脚本
# migrations/versions/001_add_user_indexes.py
"""添加用户表索引
Revision ID: 001
Revises:
Create Date: 2024-01-01 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers; ``down_revision`` is None, so this revision has no
# parent.
revision = '001'
down_revision = None
branch_labels = None
depends_on = None
# Explicit name for the MySQL full-text index so the downgrade does not have
# to rely on the name MySQL would otherwise pick implicitly.
# NOTE(review): a database upgraded with the earlier, unnamed form of this
# migration has the index under MySQL's implicit name instead.
_FULLTEXT_INDEX = 'ft_users_profile'


def _is_mysql():
    """True when the migration is running against MySQL/MariaDB."""
    return op.get_context().dialect.name in ('mysql', 'mariadb')


def upgrade():
    """Add lookup indexes on ``users`` (plus a full-text index on MySQL)."""
    # Single-column indexes
    op.create_index('idx_users_email', 'users', ['email'])
    op.create_index('idx_users_username', 'users', ['username'])
    op.create_index('idx_users_created_at', 'users', ['created_at'])
    # Composite index
    op.create_index('idx_users_active_verified', 'users', ['is_active', 'is_verified'])
    # Full-text index: MySQL-specific DDL, skipped on every other dialect
    if _is_mysql():
        op.execute(
            f"ALTER TABLE users ADD FULLTEXT INDEX {_FULLTEXT_INDEX} "
            "(first_name, last_name, bio)"
        )


def downgrade():
    """Remove everything ``upgrade`` created, in reverse order."""
    # Drop the full-text index first (MySQL only)
    if _is_mysql():
        op.execute(f"ALTER TABLE users DROP INDEX {_FULLTEXT_INDEX}")
    # Drop the regular indexes
    op.drop_index('idx_users_active_verified', 'users')
    op.drop_index('idx_users_created_at', 'users')
    op.drop_index('idx_users_username', 'users')
    op.drop_index('idx_users_email', 'users')
# migrations/versions/002_add_post_search.py
"""添加文章搜索功能
Revision ID: 002
Revises: 001
Create Date: 2024-01-02 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers; this revision follows '001'.
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None
def upgrade():
    """Add a trigger-maintained full-text search column to ``posts`` (PostgreSQL).

    The column is a real ``tsvector`` with a GIN index, so it can serve
    full-text (``@@``) queries. Existing rows are backfilled before the
    index is built.
    """
    from sqlalchemy.dialects import postgresql

    # Search vector column (PostgreSQL-specific type)
    op.add_column('posts', sa.Column('search_vector', postgresql.TSVECTOR))

    # Trigger function that recomputes the vector from title and content
    op.execute("""
    CREATE OR REPLACE FUNCTION update_post_search_vector() RETURNS trigger AS $$
    BEGIN
    NEW.search_vector := to_tsvector('english',
    COALESCE(NEW.title, '') || ' ' || COALESCE(NEW.content, ''));
    RETURN NEW;
    END;
    $$ LANGUAGE plpgsql;
    """)
    op.execute("""
    CREATE TRIGGER update_post_search_vector_trigger
    BEFORE INSERT OR UPDATE ON posts
    FOR EACH ROW EXECUTE FUNCTION update_post_search_vector();
    """)

    # Backfill existing rows
    op.execute("""
    UPDATE posts SET search_vector = to_tsvector('english',
    COALESCE(title, '') || ' ' || COALESCE(content, ''));
    """)

    # Search index, created after the backfill so it is built in one pass
    op.create_index('idx_posts_search', 'posts', ['search_vector'],
                    postgresql_using='gin')
def downgrade():
    """Remove the search trigger, its function, the index and the column."""
    # Trigger first, then the function it calls
    teardown_statements = (
        "DROP TRIGGER IF EXISTS update_post_search_vector_trigger ON posts;",
        "DROP FUNCTION IF EXISTS update_post_search_vector();",
    )
    for statement in teardown_statements:
        op.execute(statement)

    # Index before the column it covers
    op.drop_index('idx_posts_search', 'posts')
    op.drop_column('posts', 'search_vector')
6.4 查询优化
6.4.1 查询构建器
# utils/query_builder.py
from sqlalchemy import and_, or_, not_, func, desc, asc
from sqlalchemy.orm import joinedload, selectinload, contains_eager
from flask_sqlalchemy import Pagination
class QueryBuilder:
    """Fluent builder that collects query parts and applies them in ``build()``.

    Every configuration method returns ``self`` so calls can be chained.
    """

    def __init__(self, model):
        """Start from ``model.query`` with no filters, ordering, joins or options."""
        self.model = model
        self.query = model.query
        self._filters = []
        self._orders = []
        self._joins = []
        self._options = []

    @staticmethod
    def _like_pattern(term):
        """Wrap ``term`` in ``%...%`` with LIKE wildcards escaped by a backslash."""
        escaped = (
            term.replace('\\', '\\\\')
                .replace('%', '\\%')
                .replace('_', '\\_')
        )
        return f'%{escaped}%'

    def _resolve_relationship(self, rel):
        """Turn a relationship name into the model attribute; pass other values through.

        Loader options such as ``joinedload`` need class-bound attributes on
        SQLAlchemy 2.0, which no longer accepts plain strings.
        """
        if isinstance(rel, str):
            return getattr(self.model, rel)
        return rel

    def filter_by(self, **kwargs):
        """Add equality filters; names that are not model attributes are ignored."""
        for key, value in kwargs.items():
            if hasattr(self.model, key):
                self._filters.append(getattr(self.model, key) == value)
        return self

    def filter(self, *conditions):
        """Add arbitrary filter expressions."""
        self._filters.extend(conditions)
        return self

    def search(self, term, fields):
        """Add a case-insensitive substring match of ``term`` over ``fields`` (OR-ed).

        ``%`` and ``_`` in ``term`` are matched literally. Field names that
        are not model attributes are ignored; an empty term adds nothing.
        """
        if term:
            pattern = self._like_pattern(term)
            search_conditions = [
                getattr(self.model, field).ilike(pattern, escape='\\')
                for field in fields
                if hasattr(self.model, field)
            ]
            if search_conditions:
                self._filters.append(or_(*search_conditions))
        return self

    def date_range(self, field, start_date=None, end_date=None):
        """Restrict ``field`` to ``[start_date, end_date]``; either bound may be omitted."""
        if hasattr(self.model, field):
            attr = getattr(self.model, field)
            if start_date:
                self._filters.append(attr >= start_date)
            if end_date:
                self._filters.append(attr <= end_date)
        return self

    def order_by(self, field, direction='asc'):
        """Order by ``field``; ``direction`` 'desc' (any case) sorts descending."""
        if hasattr(self.model, field):
            attr = getattr(self.model, field)
            ordering = desc(attr) if direction.lower() == 'desc' else asc(attr)
            self._orders.append(ordering)
        return self

    def join(self, *args, **kwargs):
        """Record a join; arguments are passed to ``Query.join`` in ``build()``."""
        self._joins.append((args, kwargs))
        return self

    def options(self, *options):
        """Add loader options as-is."""
        self._options.extend(options)
        return self

    def eager_load(self, *relationships):
        """Load the relationships with ``joinedload``.

        Each entry may be a relationship name or a class-bound attribute.
        """
        for rel in relationships:
            self._options.append(joinedload(self._resolve_relationship(rel)))
        return self

    def select_load(self, *relationships):
        """Load the relationships with ``selectinload``.

        Each entry may be a relationship name or a class-bound attribute.
        """
        for rel in relationships:
            self._options.append(selectinload(self._resolve_relationship(rel)))
        return self

    def build(self):
        """Apply joins, filters, ordering and options, in that order, and return the query."""
        query = self.query
        # Joins
        for args, kwargs in self._joins:
            query = query.join(*args, **kwargs)
        # Filters (AND-ed together)
        if self._filters:
            query = query.filter(and_(*self._filters))
        # Ordering
        if self._orders:
            query = query.order_by(*self._orders)
        # Loader options
        if self._options:
            query = query.options(*self._options)
        return query

    def paginate(self, page=1, per_page=20, error_out=False):
        """Build the query and paginate it."""
        return self.build().paginate(
            page=page,
            per_page=per_page,
            error_out=error_out,
        )

    def all(self):
        """Return all results."""
        return self.build().all()

    def first(self):
        """Return the first result, or None."""
        return self.build().first()

    def count(self):
        """Return the number of matching rows."""
        return self.build().count()

    def exists(self):
        """Return True when at least one row matches."""
        return self.build().first() is not None
# Usage examples
def get_posts_query():
    """Return a fresh ``QueryBuilder`` for the ``Post`` model."""
    builder = QueryBuilder(Post)
    return builder
def search_posts(search_term=None, category_id=None, tag_ids=None,
                 author_id=None, status='published', page=1, per_page=10):
    """Search posts and return one page of results.

    Args:
        search_term: Text matched against title, content and summary.
        category_id: Restrict to this category.
        tag_ids: Restrict to posts carrying at least one of these tags.
        author_id: Restrict to this author.
        status: Post status to match (default 'published').
        page: 1-based page number.
        per_page: Page size.

    Returns:
        The pagination object produced by ``QueryBuilder.paginate``, ordered
        by ``created_at`` descending.
    """
    builder = get_posts_query()

    # Base filter
    builder.filter_by(status=status)

    # Full-text-ish search
    if search_term:
        builder.search(search_term, ['title', 'content', 'summary'])

    # Category filter
    if category_id:
        builder.filter_by(category_id=category_id)

    # Tag filter. An EXISTS sub-select is used instead of a JOIN so that a
    # post matching several of the requested tags is returned (and counted
    # by pagination) only once.
    if tag_ids:
        builder.filter(Post.tags.any(Tag.id.in_(tag_ids)))

    # Author filter
    if author_id:
        builder.filter_by(author_id=author_id)

    # Eager-load relationships. Class-bound attributes are passed because
    # SQLAlchemy 2.0 no longer accepts string names in loader options.
    builder.eager_load(Post.author, Post.category, Post.tags)

    # Newest first
    builder.order_by('created_at', 'desc')

    return builder.paginate(page=page, per_page=per_page)
6.4.2 查询优化技巧
# utils/query_optimizer.py
from sqlalchemy import func, text
from sqlalchemy.orm import load_only, defer
from flask import current_app
import time
class QueryOptimizer:
    """Collection of static helpers for tuning and inspecting ORM queries."""

    # Pages beyond this number use the ID-first strategy in optimize_pagination.
    DEEP_PAGE_THRESHOLD = 100
    # Queries slower than this many seconds are logged by profile_query.
    SLOW_QUERY_SECONDS = 0.1

    @staticmethod
    def optimize_pagination(query, page, per_page):
        """Paginate ``query``, switching strategy for deep pages.

        Returns the object from ``query.paginate`` for pages up to
        ``DEEP_PAGE_THRESHOLD``; deeper pages return the plain list produced
        by ``cursor_paginate``.
        """
        if page > QueryOptimizer.DEEP_PAGE_THRESHOLD:
            return QueryOptimizer.cursor_paginate(query, page, per_page)
        return query.paginate(page=page, per_page=per_page, error_out=False)

    @staticmethod
    def cursor_paginate(query, page, per_page):
        """Fetch one page in two steps: primary keys first, then full rows.

        Despite the name this is offset-based: the OFFSET/LIMIT is applied to
        a query that selects only ``id``, and the full rows are then loaded
        with an ``IN`` filter on those ids.

        Returns:
            A list of entities (empty when the page is out of range).
        """
        entity = query.column_descriptions[0]['entity']
        offset = (page - 1) * per_page

        # Run the id-only query. Eager-load options are disabled because
        # they are irrelevant to a query that returns a single column.
        id_rows = (
            query.enable_eagerloads(False)
            .with_entities(entity.id)
            .offset(offset)
            .limit(per_page)
            .all()
        )
        ids = [row[0] for row in id_rows]
        if not ids:
            return []
        return query.filter(entity.id.in_(ids)).all()

    @staticmethod
    def optimize_n_plus_one(query, relationships):
        """Attach eager-loading options to avoid N+1 queries.

        Args:
            query: The query to extend.
            relationships: Items that are each a relationship name, a
                class-bound relationship attribute, or a dict
                ``{'name': ..., 'strategy': 'joined' | 'select'}``.

        Names are resolved against the query's first entity, because
        SQLAlchemy 2.0 rejects plain strings in loader options.
        """
        from sqlalchemy.orm import joinedload, selectinload

        descriptions = query.column_descriptions
        entity = descriptions[0]['entity'] if descriptions else None

        def as_attribute(rel):
            # Fall back to the raw value when it cannot be resolved.
            if isinstance(rel, str) and entity is not None:
                return getattr(entity, rel, rel)
            return rel

        options = []
        for rel in relationships:
            if isinstance(rel, dict):
                # Detailed configuration
                target = as_attribute(rel['name'])
                if rel.get('strategy', 'joined') == 'select':
                    options.append(selectinload(target))
                else:
                    options.append(joinedload(target))
            else:
                # Plain name or attribute: joined loading
                options.append(joinedload(as_attribute(rel)))
        return query.options(*options)

    @staticmethod
    def optimize_columns(query, model, include_fields=None, exclude_fields=None):
        """Limit the loaded columns.

        ``include_fields`` loads only those columns; otherwise
        ``exclude_fields`` are deferred. Names that are not attributes of
        ``model`` are ignored. With neither argument the query is returned
        unchanged.
        """
        if include_fields:
            columns = [getattr(model, field) for field in include_fields
                       if hasattr(model, field)]
            return query.options(load_only(*columns))
        if exclude_fields:
            deferred_columns = [defer(getattr(model, field)) for field in exclude_fields
                                if hasattr(model, field)]
            return query.options(*deferred_columns)
        return query

    @staticmethod
    def add_query_cache(query, cache_key, timeout=300):
        """Return ``query.all()``, going through ``current_app.cache`` when present.

        A cached value of ``None`` is treated as a miss.
        """
        if not hasattr(current_app, 'cache'):
            return query.all()
        cached_result = current_app.cache.get(cache_key)
        if cached_result is not None:
            return cached_result
        result = query.all()
        current_app.cache.set(cache_key, result, timeout=timeout)
        return result

    @staticmethod
    def explain_query(query):
        """Run EXPLAIN on ``query`` and return the plan rows as dicts.

        The statement is compiled for the dialect of the query's own session
        bind, and executed on that session.
        """
        session = query.session
        dialect = session.get_bind().dialect
        sql = str(query.statement.compile(
            dialect=dialect,
            compile_kwargs={"literal_binds": True},
        ))
        if dialect.name == 'mysql':
            explain_sql = f"EXPLAIN {sql}"
        elif dialect.name == 'postgresql':
            explain_sql = f"EXPLAIN ANALYZE {sql}"
        else:
            explain_sql = f"EXPLAIN QUERY PLAN {sql}"
        result = session.execute(text(explain_sql))
        # Use the row's mapping view when it has one; older row objects are
        # converted directly.
        return [dict(getattr(row, '_mapping', row)) for row in result]

    @staticmethod
    def profile_query(query, description=""):
        """Execute ``query.all()`` and time it.

        Returns:
            ``(result, seconds)``. Runs slower than ``SLOW_QUERY_SECONDS``
            are logged as warnings on the application logger.
        """
        start_time = time.perf_counter()
        result = query.all()
        execution_time = time.perf_counter() - start_time
        if execution_time > QueryOptimizer.SLOW_QUERY_SECONDS:  # slow-query warning
            current_app.logger.warning(
                f"慢查询检测 ({execution_time:.3f}s): {description}"
            )
        return result, execution_time
# Usage examples
def get_optimized_posts(page=1, per_page=10, include_author=True, include_tags=True):
    """Return one page of published posts with tuned loading.

    Args:
        page: 1-based page number.
        per_page: Page size.
        include_author: Eager-load each post's author.
        include_tags: Eager-load each post's tags.

    The category is always eager-loaded and the ``content`` column is
    deferred.
    """
    query = Post.query.filter_by(status='published')

    # Avoid N+1 queries on the relationships the caller asked for
    relationships = ['category']
    if include_author:
        relationships.insert(0, 'author')
    if include_tags:
        relationships.append('tags')
    query = QueryOptimizer.optimize_n_plus_one(query, relationships)

    # Skip the large text column
    query = QueryOptimizer.optimize_columns(
        query, Post, exclude_fields=['content']
    )

    # Paginate
    return QueryOptimizer.optimize_pagination(query, page, per_page)
def get_cached_popular_posts(limit=10):
    """Return the ``limit`` most-viewed published posts, cached for 600 seconds."""
    published = Post.query.filter_by(status='published')
    most_viewed = published.order_by(Post.views_count.desc()).limit(limit)
    return QueryOptimizer.add_query_cache(
        most_viewed,
        f"popular_posts_{limit}",
        timeout=600,
    )
本章小结
本章详细介绍了Flask的数据库集成,包括:
- SQLAlchemy基础:安装配置、连接管理、事件监听
- 数据模型设计:基础模型、用户模型、关系设计
- 数据库迁移:Flask-Migrate使用、迁移管理、自定义脚本
- 查询优化:查询构建器、性能优化、缓存策略
掌握这些技能能够帮助你构建高效、可维护的数据库应用。
下一章预告
下一章我们将学习用户认证与授权,包括:
- Flask-Login用户会话管理
- 密码加密和验证
- 角色权限系统
- OAuth第三方登录
- JWT令牌认证
练习题
- 模型设计:设计一个电商系统的数据模型(商品、订单、用户等)
- 查询优化:优化一个复杂的多表关联查询
- 迁移脚本:编写一个数据迁移脚本,将旧数据格式转换为新格式
- 性能分析:分析并优化一个慢查询
- 缓存策略:为文章系统实现多级缓存策略