6.1 SQLAlchemy基础

6.1.1 安装和配置

# 安装Flask-SQLAlchemy和数据库驱动
pip install Flask-SQLAlchemy
pip install PyMySQL          # MySQL
pip install psycopg2-binary  # PostgreSQL
# SQLite 无需安装驱动:sqlite3 模块已内置于 Python 标准库(不要执行 pip install sqlite3)
# config.py
import os
from urllib.parse import quote_plus

class Config:
    """Base settings shared by every environment-specific configuration."""

    # Flask-SQLAlchemy switches.
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    SQLALCHEMY_RECORD_QUERIES = True

    # Options handed to the SQLAlchemy engine (connection-pool tuning).
    SQLALCHEMY_ENGINE_OPTIONS = dict(
        pool_size=10,
        pool_recycle=120,
        pool_pre_ping=True,
        max_overflow=20,
    )

class DevelopmentConfig(Config):
    """Development settings: local SQLite file, debug mode, SQL echo."""

    SQLALCHEMY_DATABASE_URI = 'sqlite:///app.db'
    SQLALCHEMY_ECHO = True  # print emitted SQL statements
    DEBUG = True

class ProductionConfig(Config):
    """Production settings: MySQL through the PyMySQL driver.

    Connection parameters are read from the ``DB_USER``, ``DB_PASSWORD``,
    ``DB_HOST``, ``DB_PORT`` and ``DB_NAME`` environment variables when this
    class body executes (i.e. at import time), with the defaults shown below.
    """
    DEBUG = False

    # MySQL connection parameters
    DB_USER = os.environ.get('DB_USER', 'root')
    DB_PASSWORD = os.environ.get('DB_PASSWORD', '')
    DB_HOST = os.environ.get('DB_HOST', 'localhost')
    DB_PORT = os.environ.get('DB_PORT', '3306')
    DB_NAME = os.environ.get('DB_NAME', 'flask_app')

    # Both credentials are URL-escaped: a '@', ':' or '/' in either of them
    # would otherwise be parsed as URI structure.
    SQLALCHEMY_DATABASE_URI = (
        f'mysql+pymysql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}'
        f'@{DB_HOST}:{DB_PORT}/{DB_NAME}?charset=utf8mb4'
    )

class TestingConfig(Config):
    """Test settings: throwaway in-memory SQLite database."""
    TESTING = True
    SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'

    # 'pool_size' and 'max_overflow' are queue-pool sizing arguments. An
    # in-memory SQLite engine is not created with a queue pool, so passing
    # them makes create_engine() reject the configuration. Keep only the
    # remaining inherited options.
    SQLALCHEMY_ENGINE_OPTIONS = {
        key: value
        for key, value in Config.SQLALCHEMY_ENGINE_OPTIONS.items()
        if key not in ('pool_size', 'max_overflow')
    }

# Lookup table from environment name to configuration class; 'default'
# resolves to the development configuration.
config = dict(
    development=DevelopmentConfig,
    production=ProductionConfig,
    testing=TestingConfig,
    default=DevelopmentConfig,
)
# app.py
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from config import config

# Extension objects are created unbound here and attached to an application
# inside create_app().
db = SQLAlchemy()
migrate = Migrate()


def create_app(config_name='default'):
    """Application factory.

    Args:
        config_name: Key into the ``config`` mapping selecting the settings
            class to load.

    Returns:
        A Flask application with the database and migration extensions
        initialised.

    Raises:
        KeyError: If ``config_name`` is not a key of ``config``.
    """
    flask_app = Flask(__name__)
    settings_class = config[config_name]
    flask_app.config.from_object(settings_class)

    # Bind the extensions to this application instance.
    db.init_app(flask_app)
    migrate.init_app(flask_app, db)

    return flask_app


app = create_app()

6.1.2 数据库连接管理

# database.py
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.pool import Pool
import time
import logging

logger = logging.getLogger(__name__)

class DatabaseManager:
    """Owns a Flask-SQLAlchemy object and wires up engine events and CLI commands."""

    # The listeners below are attached to the SQLAlchemy ``Engine`` and
    # ``Pool`` classes themselves, so they apply process-wide. Registering
    # them again on every init_app() call would run each hook several times
    # per statement, hence this guard.
    _events_registered = False

    def __init__(self, app=None):
        """Create the SQLAlchemy object and optionally bind it to ``app``."""
        self.db = SQLAlchemy()
        if app:
            self.init_app(app)

    def init_app(self, app):
        """Bind the database to ``app`` and register events and CLI commands."""
        self.db.init_app(app)

        # Register event listeners
        self.register_events()

        # Register CLI commands
        self.register_commands(app)

    def register_events(self):
        """Register query-timing and SQLite-tuning listeners (once per process)."""
        if DatabaseManager._events_registered:
            return
        DatabaseManager._events_registered = True

        @event.listens_for(Engine, "before_cursor_execute")
        def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            """Record the start time of the statement on its execution context."""
            context._query_start_time = time.time()
            logger.debug(f"SQL执行开始: {statement[:100]}...")

        @event.listens_for(Engine, "after_cursor_execute")
        def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            """Log the elapsed time; statements slower than 0.1 s are warnings."""
            total = time.time() - context._query_start_time
            if total > 0.1:  # slow-query threshold
                logger.warning(f"慢查询检测 ({total:.3f}s): {statement[:100]}...")
            else:
                logger.debug(f"SQL执行完成 ({total:.3f}s)")

        @event.listens_for(Pool, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            """Apply SQLite PRAGMAs to every new connection that looks like SQLite."""
            # Driver detection is based on the connection's string form.
            if 'sqlite' not in str(dbapi_connection):
                return
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.execute("PRAGMA cache_size=10000")
                cursor.execute("PRAGMA temp_store=MEMORY")
            finally:
                # Release the cursor even if one of the PRAGMAs fails.
                cursor.close()

    def register_commands(self, app):
        """Register the init-db / drop-db / reset-db CLI commands on ``app``."""
        # The docstrings of the commands below are used as their CLI help
        # text, so they are runtime strings and are left as written.

        @app.cli.command()
        def init_db():
            """初始化数据库"""
            self.db.create_all()
            print('数据库初始化完成')

        @app.cli.command()
        def drop_db():
            """删除数据库"""
            self.db.drop_all()
            print('数据库已删除')

        @app.cli.command()
        def reset_db():
            """重置数据库"""
            self.db.drop_all()
            self.db.create_all()
            print('数据库已重置')

    def get_engine_info(self):
        """Return a snapshot of the engine URL, driver and pool counters.

        Returns:
            dict with keys ``url``, ``driver``, ``pool_size``, ``checked_in``,
            ``checked_out``, ``overflow`` and ``invalid``. A counter is
            ``None`` when the engine's pool class does not provide it.
        """
        engine = self.db.engine
        pool = engine.pool

        def pool_stat(method_name):
            # Not every pool implementation exposes every counter; report
            # None instead of raising AttributeError.
            method = getattr(pool, method_name, None)
            return method() if callable(method) else None

        return {
            'url': str(engine.url),
            'driver': engine.driver,
            'pool_size': pool_stat('size'),
            'checked_in': pool_stat('checkedin'),
            'checked_out': pool_stat('checkedout'),
            'overflow': pool_stat('overflow'),
            'invalid': pool_stat('invalid'),
        }

# Module-level DatabaseManager instance, created without an app; bind one
# later through db_manager.init_app(app).
db_manager = DatabaseManager()
# Alias for the manager's SQLAlchemy object.
db = db_manager.db

6.2 数据模型设计

6.2.1 基础模型

# models/base.py
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Column, Integer, DateTime, String
from sqlalchemy.ext.declarative import declared_attr

# SQLAlchemy object the models in this module are declared against.
# NOTE(review): app.py and database.py each create their own SQLAlchemy()
# instance as well — confirm which single instance the application binds.
db = SQLAlchemy()

class TimestampMixin:
    """Mixin contributing ``created_at`` / ``updated_at`` columns."""

    # Filled with datetime.utcnow() by default.
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        default=datetime.utcnow,
    )
    # Same default; additionally refreshed through onupdate.
    updated_at = db.Column(
        db.DateTime,
        nullable=False,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

class BaseModel(db.Model, TimestampMixin):
    """Abstract base model: integer primary key, timestamps and CRUD helpers."""
    __abstract__ = True

    id = db.Column(db.Integer, primary_key=True)

    @declared_attr
    def __tablename__(cls):
        """Default table name: the lower-cased class name."""
        return cls.__name__.lower()

    @staticmethod
    def _commit():
        """Commit the session; on failure roll back and re-raise.

        Without the rollback a failed commit leaves the session unusable
        until the caller rolls it back.
        """
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise

    def save(self):
        """Add this instance to the session, commit, and return it."""
        db.session.add(self)
        self._commit()
        return self

    def delete(self):
        """Delete this instance and commit."""
        db.session.delete(self)
        self._commit()

    def update(self, **kwargs):
        """Assign the given attributes, refresh ``updated_at``, and commit.

        Keys that are not existing attributes of the instance are ignored.

        Returns:
            This instance.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        self.updated_at = datetime.utcnow()
        self._commit()
        return self

    def to_dict(self, exclude=None, include_relationships=False):
        """Serialise the mapped columns (and optionally relationships) to a dict.

        Args:
            exclude: Column / relationship names to leave out.
            include_relationships: When true, related objects are serialised
                with their own ``to_dict()`` (one level deep); relationships
                whose value is ``None`` are omitted.

        Returns:
            dict keyed by column name; datetime values are ISO-8601 strings.
        """
        exclude = exclude or []
        result = {}

        for column in self.__table__.columns:
            if column.name in exclude:
                continue
            value = getattr(self, column.name)
            if isinstance(value, datetime):
                value = value.isoformat()
            result[column.name] = value

        if include_relationships:
            for relationship in self.__mapper__.relationships:
                if relationship.key in exclude:
                    continue
                value = getattr(self, relationship.key)
                if value is None:
                    continue
                if relationship.uselist:
                    result[relationship.key] = [item.to_dict() for item in value]
                else:
                    result[relationship.key] = value.to_dict()

        return result

    @classmethod
    def create(cls, **kwargs):
        """Instantiate the model with ``kwargs``, save it, and return it."""
        instance = cls(**kwargs)
        return instance.save()

    @classmethod
    def get_or_404(cls, id):
        """Delegate to ``cls.query.get_or_404(id)``."""
        return cls.query.get_or_404(id)

    @classmethod
    def get_or_create(cls, defaults=None, **kwargs):
        """Fetch the first row matching ``kwargs`` or create one.

        Args:
            defaults: Extra attribute values used only when creating.
            **kwargs: Lookup criteria, also used as attribute values on create.

        Returns:
            ``(instance, created)`` where ``created`` is True for a new row.
        """
        instance = cls.query.filter_by(**kwargs).first()
        if instance:
            return instance, False

        params = {**kwargs, **(defaults or {})}
        instance = cls(**params)
        instance.save()
        return instance, True

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.id}>'

6.2.2 用户模型

# models/user.py
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import UserMixin
from sqlalchemy import event
from models.base import BaseModel, db
import re

class User(BaseModel, UserMixin):
    """User account model."""
    __tablename__ = 'users'

    # Core identity
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)

    # Profile
    first_name = db.Column(db.String(50))
    last_name = db.Column(db.String(50))
    phone = db.Column(db.String(20))
    avatar = db.Column(db.String(255))
    bio = db.Column(db.Text)

    # Status flags
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    is_verified = db.Column(db.Boolean, default=False, nullable=False)
    is_admin = db.Column(db.Boolean, default=False, nullable=False)

    # Login statistics
    login_count = db.Column(db.Integer, default=0)
    last_login_at = db.Column(db.DateTime)
    last_login_ip = db.Column(db.String(45))

    # Relationships
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')
    comments = db.relationship('Comment', backref='author', lazy='dynamic', cascade='all, delete-orphan')

    def set_password(self, password):
        """Hash and store ``password``.

        Raises:
            ValueError: If the password fails ``validate_password_strength``.
        """
        if not self.validate_password_strength(password):
            raise ValueError('密码强度不够')
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True when ``password`` matches the stored hash.

        A user without a stored hash never matches.
        """
        if not self.password_hash:
            return False
        return check_password_hash(self.password_hash, password)

    @staticmethod
    def validate_password_strength(password):
        """Require at least 8 characters with an upper-case letter, a lower-case letter and a digit."""
        if len(password) < 8:
            return False
        required_patterns = (r'[A-Z]', r'[a-z]', r'\d')
        return all(re.search(pattern, password) for pattern in required_patterns)

    @property
    def full_name(self):
        """``"first last"`` when both names are set, otherwise the username."""
        if self.first_name and self.last_name:
            return f'{self.first_name} {self.last_name}'
        return self.username

    @property
    def is_authenticated(self):
        """Always True for a User instance."""
        return True

    @property
    def is_anonymous(self):
        """Always False for a User instance."""
        return False

    def get_id(self):
        """Return the primary key as a string."""
        return str(self.id)

    def update_login_info(self, ip_address):
        """Increment the login counter, stamp time and IP, and commit."""
        # Local import: the import block of models/user.py does not bring
        # ``datetime`` into scope.
        from datetime import datetime

        # login_count is None until the column default has been applied by
        # a flush, so treat a missing value as zero.
        self.login_count = (self.login_count or 0) + 1
        self.last_login_at = datetime.utcnow()
        self.last_login_ip = ip_address
        db.session.commit()

    def get_posts_count(self):
        """Number of posts written by this user."""
        return self.posts.count()

    def get_comments_count(self):
        """Number of comments written by this user."""
        return self.comments.count()

    def to_dict(self, include_sensitive=False):
        """Serialise the user.

        Args:
            include_sensitive: When false (default) ``password_hash`` is omitted.

        Returns:
            The column dict plus ``full_name``, ``posts_count`` and
            ``comments_count``.
        """
        exclude = [] if include_sensitive else ['password_hash']
        result = super().to_dict(exclude=exclude)
        result['full_name'] = self.full_name
        result['posts_count'] = self.get_posts_count()
        result['comments_count'] = self.get_comments_count()
        return result

    def __repr__(self):
        return f'<User {self.username}>'

# 事件监听器
@event.listens_for(User.username, 'set')
def validate_username(target, value, oldvalue, initiator):
    """Reject usernames that are not 3-20 letters, digits or underscores.

    Empty / None values are not validated here.

    Raises:
        ValueError: If a non-empty value has the wrong format.
    """
    # fullmatch instead of match with '^...$': '$' also matches just before
    # a trailing newline, which would let "name\n" through.
    if value and not re.fullmatch(r'[a-zA-Z0-9_]{3,20}', value):
        raise ValueError('用户名只能包含字母、数字和下划线,长度3-20个字符')

@event.listens_for(User.email, 'set')
def validate_email(target, value, oldvalue, initiator):
    """Reject e-mail addresses that do not look like ``local@domain.tld``.

    Empty / None values are not validated here.

    Raises:
        ValueError: If a non-empty value has the wrong format.
    """
    # fullmatch instead of match with '^...$': '$' also matches just before
    # a trailing newline.
    if value and not re.fullmatch(r'[\w\.-]+@[\w\.-]+\.\w+', value):
        raise ValueError('邮箱格式不正确')

6.2.3 文章和评论模型

# models/blog.py
from datetime import datetime

from sqlalchemy import event, text

from models.base import BaseModel, db

# Association table linking posts and tags (composite primary key of
# post_id + tag_id, each a foreign key to its table).
post_tags = db.Table('post_tags',
    db.Column('post_id', db.Integer, db.ForeignKey('posts.id'), primary_key=True),
    db.Column('tag_id', db.Integer, db.ForeignKey('tags.id'), primary_key=True)
)

class Category(BaseModel):
    """Post category, optionally nested under a parent category."""
    __tablename__ = 'categories'

    name = db.Column(db.String(50), unique=True, nullable=False)
    slug = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)
    color = db.Column(db.String(7), default='#007bff')  # hex colour

    # Hierarchy. ``id`` is inherited from BaseModel and is therefore not a
    # name in this class body: a bare ``remote_side=[id]`` would pick up the
    # builtin ``id`` function. The string form is resolved by SQLAlchemy
    # when the mappers are configured.
    parent_id = db.Column(db.Integer, db.ForeignKey('categories.id'))
    parent = db.relationship('Category', remote_side='Category.id', backref='children')

    # Cached statistics
    posts_count = db.Column(db.Integer, default=0)

    # Relationships
    posts = db.relationship('Post', backref='category', lazy='dynamic')

    def update_posts_count(self):
        """Recount this category's posts, store the result, and commit."""
        self.posts_count = self.posts.count()
        db.session.commit()

    def get_all_posts(self):
        """Return a query for posts in this category and its direct children.

        Only one level of children is included; deeper descendants are not.
        """
        category_ids = [self.id, *(child.id for child in self.children)]
        return Post.query.filter(Post.category_id.in_(category_ids))

    def __repr__(self):
        return f'<Category {self.name}>'

class Tag(BaseModel):
    """Tag that can be attached to posts."""
    __tablename__ = 'tags'

    name = db.Column(db.String(30), unique=True, nullable=False)
    color = db.Column(db.String(7), default='#6c757d')

    # Cached number of posts carrying this tag
    posts_count = db.Column(db.Integer, default=0)

    def update_posts_count(self):
        """Recount the association rows for this tag, store the result, commit."""
        tag_links = db.session.query(post_tags).filter_by(tag_id=self.id)
        self.posts_count = tag_links.count()
        db.session.commit()

    def __repr__(self):
        return f'<Tag {self.name}>'

class Post(BaseModel):
    """Blog post."""
    __tablename__ = 'posts'

    # Core fields
    title = db.Column(db.String(200), nullable=False)
    slug = db.Column(db.String(200), unique=True, nullable=False)
    summary = db.Column(db.Text)
    content = db.Column(db.Text, nullable=False)

    # Status
    status = db.Column(db.Enum('draft', 'published', 'archived', name='post_status'),
                      default='draft', nullable=False)
    is_featured = db.Column(db.Boolean, default=False)
    allow_comments = db.Column(db.Boolean, default=True)

    # Timestamps
    published_at = db.Column(db.DateTime)

    # Cached statistics
    views_count = db.Column(db.Integer, default=0)
    likes_count = db.Column(db.Integer, default=0)
    comments_count = db.Column(db.Integer, default=0)

    # Foreign keys
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    category_id = db.Column(db.Integer, db.ForeignKey('categories.id'))

    # Relationships
    tags = db.relationship('Tag', secondary=post_tags, backref='posts')
    comments = db.relationship('Comment', backref='post', lazy='dynamic',
                             cascade='all, delete-orphan')

    def publish(self):
        """Mark the post as published, stamp ``published_at``, and commit."""
        self.status = 'published'
        self.published_at = datetime.utcnow()
        db.session.commit()

    def archive(self):
        """Mark the post as archived and commit."""
        self.status = 'archived'
        db.session.commit()

    def increment_views(self):
        """Increase the view counter by one and commit.

        The increment is sent as a SQL expression
        (``views_count = COALESCE(views_count, 0) + 1``), so concurrent
        requests do not overwrite each other's update and a NULL counter
        counts as zero.
        """
        self.views_count = db.func.coalesce(Post.views_count, 0) + 1
        db.session.commit()

    def update_comments_count(self):
        """Recount approved comments, store the result, and commit."""
        self.comments_count = self.comments.filter_by(status='approved').count()
        db.session.commit()

    def get_reading_time(self):
        """Estimate the reading time in minutes (minimum 1).

        Uses 200 whitespace-separated words per minute; missing content
        counts as empty.
        """
        words_per_minute = 200
        word_count = len((self.content or '').split())
        return max(1, round(word_count / words_per_minute))

    @property
    def is_published(self):
        """True when the status is 'published' and ``published_at`` is set."""
        return self.status == 'published' and self.published_at is not None

    def to_dict(self, include_content=True):
        """Serialise the post.

        Args:
            include_content: When false the ``content`` field is dropped.

        Returns:
            The column dict plus ``reading_time`` and ``is_published``.
        """
        result = super().to_dict()
        result['reading_time'] = self.get_reading_time()
        result['is_published'] = self.is_published

        if not include_content:
            result.pop('content', None)

        return result

    def __repr__(self):
        return f'<Post {self.title}>'

class Comment(BaseModel):
    """Comment on a post, written by a registered user or a guest."""
    __tablename__ = 'comments'

    content = db.Column(db.Text, nullable=False)
    author_name = db.Column(db.String(50))  # used for guest comments
    author_email = db.Column(db.String(120))  # used for guest comments
    author_ip = db.Column(db.String(45))

    # Moderation status
    status = db.Column(db.Enum('pending', 'approved', 'rejected', name='comment_status'),
                      default='pending', nullable=False)

    # Threading. ``id`` is inherited from BaseModel and is therefore not a
    # name in this class body: a bare ``remote_side=[id]`` would pick up the
    # builtin ``id`` function. The string form is resolved by SQLAlchemy
    # when the mappers are configured.
    parent_id = db.Column(db.Integer, db.ForeignKey('comments.id'))
    parent = db.relationship('Comment', remote_side='Comment.id', backref='replies')

    # Foreign keys
    post_id = db.Column(db.Integer, db.ForeignKey('posts.id'), nullable=False)
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'))  # optional: set for registered users

    def approve(self):
        """Mark the comment approved, commit, then refresh the post's counter."""
        self.status = 'approved'
        db.session.commit()
        self.post.update_comments_count()

    def reject(self):
        """Mark the comment rejected, commit, then refresh the post's counter."""
        self.status = 'rejected'
        db.session.commit()
        self.post.update_comments_count()

    @property
    def author_display_name(self):
        """Registered author's full name, else the guest name, else '匿名用户'."""
        if self.author:
            return self.author.full_name
        return self.author_name or '匿名用户'

    def __repr__(self):
        return f'<Comment {self.id} on Post {self.post_id}>'

# 事件监听器
@event.listens_for(Post, 'before_insert')
def generate_slug(mapper, connection, target):
    """Derive ``target.slug`` from the title when no slug was provided.

    The title is lower-cased, characters other than word characters,
    whitespace and hyphens are removed, and runs of whitespace / hyphens
    become a single hyphen.
    """
    if not target.slug:
        import re
        slug = re.sub(r'[^\w\s-]', '', target.title.lower())
        slug = re.sub(r'[-\s]+', '-', slug)
        # Titles with leading/trailing spaces or punctuation would otherwise
        # produce slugs such as "-hello-world-".
        target.slug = slug.strip('-')

@event.listens_for(Comment, 'after_insert')
def update_post_comments_count_after_insert(mapper, connection, target):
    """After an approved comment is inserted, bump its post's comment counter."""
    if target.status != 'approved':
        return

    increment_stmt = text("UPDATE posts SET comments_count = comments_count + 1 WHERE id = :post_id")
    bind_values = {'post_id': target.post_id}
    connection.execute(increment_stmt, bind_values)

@event.listens_for(Comment, 'after_delete')
def update_post_comments_count_after_delete(mapper, connection, target):
    """After an approved comment is deleted, lower its post's comment counter."""
    if target.status != 'approved':
        return

    decrement_stmt = text("UPDATE posts SET comments_count = comments_count - 1 WHERE id = :post_id")
    bind_values = {'post_id': target.post_id}
    connection.execute(decrement_stmt, bind_values)

6.3 数据库迁移

6.3.1 Flask-Migrate配置

# 安装Flask-Migrate
pip install Flask-Migrate

# 初始化迁移仓库
flask db init

# 生成迁移文件
flask db migrate -m "Initial migration"

# 应用迁移
flask db upgrade

# 查看迁移历史
flask db history

# 回滚迁移
flask db downgrade

6.3.2 迁移管理

# migrations/migration_manager.py
from flask_migrate import Migrate, upgrade, downgrade, current, history
from flask import current_app
import os

class MigrationManager:
    """Wraps Flask-Migrate and adds status / upgrade / backup CLI commands."""

    def __init__(self, app=None, db=None):
        """Create the Migrate object; bind it when both ``app`` and ``db`` are given."""
        self.migrate = Migrate()
        if app and db:
            self.init_app(app, db)

    def init_app(self, app, db):
        """Bind Flask-Migrate to ``app``/``db`` and register the CLI commands."""
        self.migrate.init_app(app, db)
        self.register_commands(app)

    def register_commands(self, app):
        """Register the db-status / db-upgrade / db-downgrade / db-backup / db-restore commands."""
        # The docstrings of the commands below are used as their CLI help
        # text, so they are runtime strings and are left as written.

        @app.cli.command()
        def db_status():
            """显示数据库状态"""
            try:
                # flask_migrate.current() and history() print their report
                # themselves and return None, so their results can neither
                # be formatted nor iterated here.
                print("当前版本:")
                current()

                print("\n迁移历史:")
                history()

            except Exception as e:
                print(f"获取数据库状态失败: {e}")

        @app.cli.command()
        def db_upgrade():
            """升级数据库"""
            try:
                upgrade()
                print("数据库升级成功")
            except Exception as e:
                print(f"数据库升级失败: {e}")

        @app.cli.command()
        def db_downgrade():
            """降级数据库"""
            try:
                downgrade()
                print("数据库降级成功")
            except Exception as e:
                print(f"数据库降级失败: {e}")

        @app.cli.command()
        def db_backup():
            """备份数据库"""
            self.backup_database()

        @app.cli.command()
        def db_restore():
            """恢复数据库"""
            self.restore_database()

    @staticmethod
    def _mysql_client_args(db_url):
        """Translate a MySQL database URI into mysql/mysqldump arguments.

        Credentials are URL-decoded, since the URI carries them
        percent-encoded. The ``-u`` / ``-p`` options are omitted when the
        URI has no user / no (or an empty) password, instead of passing the
        literal text ``None`` or a bare ``-p`` that makes the client prompt.

        Returns:
            list of arguments: host, port (default 3306), optional
            credentials, then the database name.
        """
        from urllib.parse import unquote_plus, urlparse

        parsed = urlparse(db_url)
        args = [f'-h{parsed.hostname}', f'-P{parsed.port or 3306}']
        if parsed.username:
            args.append(f'-u{unquote_plus(parsed.username)}')
        if parsed.password:
            # NOTE: a password on the command line is visible to other local
            # users in the process list.
            args.append(f'-p{unquote_plus(parsed.password)}')
        args.append(parsed.path[1:])  # strip the leading '/'
        return args

    def backup_database(self):
        """Write a timestamped backup of the configured database into ``backups/``.

        SQLite databases are copied as files; MySQL databases are dumped
        with ``mysqldump``. Other database URIs are ignored.

        Raises:
            subprocess.CalledProcessError: If ``mysqldump`` exits non-zero.
            OSError: If a file cannot be read or written.
        """
        from datetime import datetime
        import shutil
        import subprocess

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_dir = 'backups'
        os.makedirs(backup_dir, exist_ok=True)

        db_url = current_app.config['SQLALCHEMY_DATABASE_URI']

        if db_url.startswith('sqlite'):
            # SQLite: copy the database file
            db_path = db_url.replace('sqlite:///', '')
            backup_path = f"{backup_dir}/backup_{timestamp}.db"

            shutil.copy2(db_path, backup_path)
            print(f"SQLite数据库已备份到: {backup_path}")

        elif db_url.startswith('mysql'):
            # MySQL: dump through mysqldump
            backup_path = f"{backup_dir}/backup_{timestamp}.sql"
            cmd = ['mysqldump', *self._mysql_client_args(db_url)]

            try:
                with open(backup_path, 'w') as f:
                    subprocess.run(cmd, stdout=f, check=True)
            except Exception:
                # Do not leave a truncated dump behind: restore_database()
                # picks the newest backup file by name.
                if os.path.exists(backup_path):
                    os.remove(backup_path)
                raise

            print(f"MySQL数据库已备份到: {backup_path}")

    def restore_database(self, backup_file=None):
        """Restore the configured database from a backup file.

        Args:
            backup_file: Path of the backup to restore. When omitted, the
                file in ``backups/`` whose name starts with ``backup_`` and
                sorts last is used.

        Nothing is restored when the backup's extension does not match the
        configured database type (``.db`` for SQLite, ``.sql`` for MySQL).

        Raises:
            subprocess.CalledProcessError: If ``mysql`` exits non-zero.
            OSError: If a file cannot be read or written.
        """
        if not backup_file:
            # Pick the newest backup (timestamps in the names sort lexically)
            backup_dir = 'backups'
            if not os.path.exists(backup_dir):
                print("没有找到备份目录")
                return

            backup_files = [f for f in os.listdir(backup_dir) if f.startswith('backup_')]
            if not backup_files:
                print("没有找到备份文件")
                return

            backup_file = os.path.join(backup_dir, max(backup_files))

        db_url = current_app.config['SQLALCHEMY_DATABASE_URI']

        if db_url.startswith('sqlite') and backup_file.endswith('.db'):
            # SQLite: copy the backup over the database file
            import shutil

            db_path = db_url.replace('sqlite:///', '')
            shutil.copy2(backup_file, db_path)
            print(f"SQLite数据库已从 {backup_file} 恢复")

        elif db_url.startswith('mysql') and backup_file.endswith('.sql'):
            # MySQL: feed the dump to the mysql client
            import subprocess

            cmd = ['mysql', *self._mysql_client_args(db_url)]

            with open(backup_file, 'r') as f:
                subprocess.run(cmd, stdin=f, check=True)

            print(f"MySQL数据库已从 {backup_file} 恢复")

6.3.3 自定义迁移脚本

# migrations/versions/001_add_user_indexes.py
"""添加用户表索引

Revision ID: 001
Revises: 
Create Date: 2024-01-01 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# Revision identifiers: this is revision '001' and it has no parent
# (down_revision is None).
revision = '001'
down_revision = None
branch_labels = None
depends_on = None

def upgrade():
    """Add lookup indexes on ``users`` and, on MySQL, a FULLTEXT index."""
    # Single-column indexes
    op.create_index('idx_users_email', 'users', ['email'])
    op.create_index('idx_users_username', 'users', ['username'])
    op.create_index('idx_users_created_at', 'users', ['created_at'])

    # Composite index
    op.create_index('idx_users_active_verified', 'users', ['is_active', 'is_verified'])

    # FULLTEXT is MySQL syntax; running it on another database (for example
    # the SQLite development database) would fail the whole migration.
    if op.get_context().dialect.name == 'mysql':
        op.execute("ALTER TABLE users ADD FULLTEXT(first_name, last_name, bio)")

def downgrade():
    """Remove the indexes created by ``upgrade()``, in reverse order."""
    # The FULLTEXT index exists only on MySQL.
    # NOTE(review): it was created without a name, so this relies on MySQL
    # naming it after its first column ('first_name') — confirm.
    if op.get_context().dialect.name == 'mysql':
        op.execute("ALTER TABLE users DROP INDEX first_name")

    op.drop_index('idx_users_active_verified', 'users')
    op.drop_index('idx_users_created_at', 'users')
    op.drop_index('idx_users_username', 'users')
    op.drop_index('idx_users_email', 'users')
# migrations/versions/002_add_post_search.py
"""添加文章搜索功能

Revision ID: 002
Revises: 001
Create Date: 2024-01-02 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# Revision identifiers: this is revision '002' and its parent is '001'.
revision = '002'
down_revision = '001'
branch_labels = None
depends_on = None

def upgrade():
    """Add a trigger-maintained full-text search vector to ``posts`` (PostgreSQL only)."""
    from sqlalchemy.dialects import postgresql

    # Store the vector in a real tsvector column with a GIN index. A Text
    # column with the default B-tree index cannot serve full-text queries,
    # and B-tree entries have a size limit that long posts would exceed.
    op.add_column('posts', sa.Column('search_vector', postgresql.TSVECTOR))
    op.create_index(
        'idx_posts_search', 'posts', ['search_vector'],
        postgresql_using='gin',
    )

    # Trigger function that recomputes the vector from title + content
    op.execute("""
        CREATE OR REPLACE FUNCTION update_post_search_vector() RETURNS trigger AS $$
        BEGIN
            NEW.search_vector := to_tsvector('english', 
                COALESCE(NEW.title, '') || ' ' || COALESCE(NEW.content, ''));
            RETURN NEW;
        END;
        $$ LANGUAGE plpgsql;
    """)

    op.execute("""
        CREATE TRIGGER update_post_search_vector_trigger
        BEFORE INSERT OR UPDATE ON posts
        FOR EACH ROW EXECUTE FUNCTION update_post_search_vector();
    """)

    # Back-fill rows that existed before the trigger was created
    op.execute("""
        UPDATE posts SET search_vector = to_tsvector('english', 
            COALESCE(title, '') || ' ' || COALESCE(content, ''));
    """)

def downgrade():
    """Undo ``upgrade()``: drop the trigger and its function, then the index and column."""
    teardown_statements = (
        "DROP TRIGGER IF EXISTS update_post_search_vector_trigger ON posts;",
        "DROP FUNCTION IF EXISTS update_post_search_vector();",
    )
    for statement in teardown_statements:
        op.execute(statement)

    # The index goes before the column it covers.
    op.drop_index('idx_posts_search', 'posts')
    op.drop_column('posts', 'search_vector')

6.4 查询优化

6.4.1 查询构建器

# utils/query_builder.py
from sqlalchemy import and_, or_, not_, func, desc, asc
from sqlalchemy.orm import joinedload, selectinload, contains_eager
from flask_sqlalchemy import Pagination

class QueryBuilder:
    """Chainable builder that collects criteria and applies them in ``build()``.

    Every mutator returns ``self``; nothing touches the database until
    ``build()`` or one of the terminal methods (``all``, ``first``,
    ``count``, ``exists``, ``paginate``) is called.
    """

    def __init__(self, model):
        """Start from ``model.query``."""
        self.model = model
        self.query = model.query
        self._filters = []
        self._orders = []
        self._joins = []
        self._options = []

    def filter_by(self, **kwargs):
        """Add equality filters; keys that are not attributes of the model are ignored."""
        for key, value in kwargs.items():
            if hasattr(self.model, key):
                self._filters.append(getattr(self.model, key) == value)
        return self

    def filter(self, *conditions):
        """Add arbitrary filter expressions."""
        self._filters.extend(conditions)
        return self

    def search(self, term, fields):
        """Add a case-insensitive substring match of ``term`` over ``fields`` (OR-ed).

        A falsy ``term`` adds nothing; unknown field names are skipped.
        """
        if term:
            search_conditions = [
                getattr(self.model, field).ilike(f'%{term}%')
                for field in fields
                if hasattr(self.model, field)
            ]
            if search_conditions:
                self._filters.append(or_(*search_conditions))
        return self

    def date_range(self, field, start_date=None, end_date=None):
        """Restrict ``field`` to the inclusive range [start_date, end_date]; either bound may be omitted."""
        if hasattr(self.model, field):
            attr = getattr(self.model, field)
            if start_date:
                self._filters.append(attr >= start_date)
            if end_date:
                self._filters.append(attr <= end_date)
        return self

    def order_by(self, field, direction='asc'):
        """Order by ``field``; any direction other than 'desc' (case-insensitive) means ascending."""
        if hasattr(self.model, field):
            attr = getattr(self.model, field)
            if direction.lower() == 'desc':
                self._orders.append(desc(attr))
            else:
                self._orders.append(asc(attr))
        return self

    def join(self, *args, **kwargs):
        """Record a join; the arguments are passed to ``Query.join`` unchanged."""
        self._joins.append((args, kwargs))
        return self

    def options(self, *options):
        """Add loader / query options."""
        self._options.extend(options)
        return self

    def _relationship_attr(self, rel):
        """Turn a relationship name into the model's mapped attribute.

        Loader options such as joinedload() no longer accept plain strings
        in SQLAlchemy 2.0, so names are resolved here. Non-string values
        are returned unchanged.
        """
        if not isinstance(rel, str):
            return rel
        # Relationships created through ``backref`` only appear on the class
        # once the mappers have been configured.
        from sqlalchemy.orm import configure_mappers
        configure_mappers()
        return getattr(self.model, rel)

    def eager_load(self, *relationships):
        """Load the given relationships (names or attributes) with joinedload."""
        for rel in relationships:
            self._options.append(joinedload(self._relationship_attr(rel)))
        return self

    def select_load(self, *relationships):
        """Load the given relationships (names or attributes) with selectinload."""
        for rel in relationships:
            self._options.append(selectinload(self._relationship_attr(rel)))
        return self

    def build(self):
        """Apply joins, filters, ordering and options (in that order) and return the query."""
        query = self.query

        for args, kwargs in self._joins:
            query = query.join(*args, **kwargs)

        if self._filters:
            query = query.filter(and_(*self._filters))

        if self._orders:
            query = query.order_by(*self._orders)

        if self._options:
            query = query.options(*self._options)

        return query

    def paginate(self, page=1, per_page=20, error_out=False):
        """Build the query and return the result of its ``paginate()`` call."""
        query = self.build()
        return query.paginate(
            page=page,
            per_page=per_page,
            error_out=error_out
        )

    def all(self):
        """Return all matching rows."""
        return self.build().all()

    def first(self):
        """Return the first matching row or None."""
        return self.build().first()

    def count(self):
        """Return the number of matching rows."""
        return self.build().count()

    def exists(self):
        """Return True when at least one row matches."""
        return self.build().first() is not None

# 使用示例
def get_posts_query():
    """Return a fresh QueryBuilder over the Post model."""
    post_builder = QueryBuilder(Post)
    return post_builder

def search_posts(search_term=None, category_id=None, tag_ids=None, 
                author_id=None, status='published', page=1, per_page=10):
    """Search posts and return one page of results, newest first.

    Args:
        search_term: Substring matched against title, content and summary.
        category_id: Restrict to this category.
        tag_ids: Restrict to posts carrying at least one of these tags.
        author_id: Restrict to this author.
        status: Post status to match.
        page: Page number passed to the paginator.
        per_page: Page size.

    Returns:
        The pagination object produced by ``QueryBuilder.paginate``.
    """
    builder = get_posts_query()

    # Base filter
    builder.filter_by(status=status)

    if search_term:
        builder.search(search_term, ['title', 'content', 'summary'])

    if category_id:
        builder.filter_by(category_id=category_id)

    if tag_ids:
        # EXISTS-style filter instead of a JOIN: a join returns one row per
        # matching tag, so a post with several of the requested tags would
        # appear repeatedly and skew the page contents and totals.
        builder.filter(Post.tags.any(Tag.id.in_(tag_ids)))

    if author_id:
        builder.filter_by(author_id=author_id)

    # Eager-load relationships. Mapped attributes are passed rather than
    # strings because string loader arguments are rejected by SQLAlchemy 2.0.
    builder.eager_load(Post.author, Post.category, Post.tags)

    builder.order_by('created_at', 'desc')

    return builder.paginate(page=page, per_page=per_page)

6.4.2 查询优化技巧

# utils/query_optimizer.py
from sqlalchemy import func, text
from sqlalchemy.orm import load_only, defer
from flask import current_app
import time

class QueryOptimizer:
    """Collection of static helpers for paginating, loading and profiling queries."""

    @staticmethod
    def optimize_pagination(query, page, per_page):
        """Paginate ``query``, using a two-step ID lookup for pages beyond 100.

        Returns:
            For ``page <= 100`` the object returned by ``query.paginate``;
            for larger pages the plain list returned by ``cursor_paginate``.
        """
        if page > 100:  # large offset
            return QueryOptimizer.cursor_paginate(query, page, per_page)
        return query.paginate(page=page, per_page=per_page, error_out=False)

    @staticmethod
    def cursor_paginate(query, page, per_page):
        """Fetch one page by first selecting only the IDs, then loading those rows.

        Despite the name this is offset based: the OFFSET/LIMIT scan is
        restricted to the ``id`` column and full rows are loaded by ID.

        Returns:
            list of entities for the page; empty when the page is out of range.
        """
        entity = query.column_descriptions[0]['entity']
        offset = (page - 1) * per_page

        # Step 1: execute the narrow ID query. (An unexecuted subquery
        # object cannot be iterated for rows.)
        id_rows = (
            query.with_entities(entity.id)
            .offset(offset)
            .limit(per_page)
            .all()
        )
        ids = [row[0] for row in id_rows]
        if not ids:
            return []

        # Step 2: load the full rows for those IDs.
        return query.filter(entity.id.in_(ids)).all()

    @staticmethod
    def optimize_n_plus_one(query, relationships):
        """Attach eager-loading options for ``relationships``.

        Args:
            query: Query whose first entity owns the relationships.
            relationships: Items are either a relationship name / attribute
                (loaded with joinedload) or a dict with ``'name'`` and an
                optional ``'strategy'`` (``'select'`` selects selectinload,
                anything else joinedload).

        Returns:
            The query with the loader options applied.
        """
        from sqlalchemy.orm import joinedload, selectinload

        def resolve(rel):
            # Loader options no longer accept plain strings in SQLAlchemy
            # 2.0; look the name up on the query's first entity.
            if isinstance(rel, str):
                return getattr(query.column_descriptions[0]['entity'], rel)
            return rel

        options = []
        for rel in relationships:
            if isinstance(rel, dict):
                loader = selectinload if rel.get('strategy', 'joined') == 'select' else joinedload
                options.append(loader(resolve(rel['name'])))
            else:
                options.append(joinedload(resolve(rel)))

        return query.options(*options)

    @staticmethod
    def optimize_columns(query, model, include_fields=None, exclude_fields=None):
        """Limit the loaded columns.

        ``include_fields`` (checked first) loads only those columns;
        otherwise ``exclude_fields`` defers those columns. Names that are
        not attributes of ``model`` are skipped. With neither argument the
        query is returned unchanged.
        """
        if include_fields:
            columns = [getattr(model, field) for field in include_fields 
                      if hasattr(model, field)]
            return query.options(load_only(*columns))

        elif exclude_fields:
            deferred_columns = [defer(getattr(model, field)) for field in exclude_fields 
                               if hasattr(model, field)]
            return query.options(*deferred_columns)

        return query

    @staticmethod
    def add_query_cache(query, cache_key, timeout=300):
        """Return ``query.all()``, served from ``current_app.cache`` when available.

        When the application has no ``cache`` attribute the query is simply
        executed. A cached value of ``None`` is treated as a miss.
        """
        from flask import current_app

        if hasattr(current_app, 'cache'):
            cached_result = current_app.cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = query.all()
            current_app.cache.set(cache_key, result, timeout=timeout)
            return result

        return query.all()

    @staticmethod
    def explain_query(query):
        """Run the database's EXPLAIN on ``query`` and return the plan rows as dicts.

        The query's own session and engine dialect are used, so no
        module-level ``db`` object is required. On PostgreSQL the statement
        is ``EXPLAIN ANALYZE``.
        """
        session = query.session
        dialect = session.get_bind().dialect
        sql = str(query.statement.compile(
            dialect=dialect, compile_kwargs={"literal_binds": True}
        ))

        if dialect.name in ('mysql', 'mariadb'):
            explain_sql = f"EXPLAIN {sql}"
        elif dialect.name == 'postgresql':
            explain_sql = f"EXPLAIN ANALYZE {sql}"
        else:
            explain_sql = f"EXPLAIN QUERY PLAN {sql}"

        result = session.execute(text(explain_sql))
        # Newer SQLAlchemy rows expose their dict view as ``_mapping``; fall
        # back to the row itself where that attribute is absent.
        return [dict(getattr(row, '_mapping', row)) for row in result]

    @staticmethod
    def profile_query(query, description=""):
        """Execute ``query.all()`` and time it.

        Executions slower than 0.1 s are logged as warnings on
        ``current_app.logger``.

        Returns:
            ``(result, execution_time_in_seconds)``.
        """
        start_time = time.time()
        result = query.all()
        execution_time = time.time() - start_time

        if execution_time > 0.1:  # slow-query threshold
            current_app.logger.warning(
                f"慢查询检测 ({execution_time:.3f}s): {description}"
            )

        return result, execution_time

# 使用示例
def get_optimized_posts(page=1, per_page=10, include_author=True, include_tags=True):
    """Return one page of published posts with eager loading and without ``content``.

    Args:
        page: Page number.
        per_page: Page size.
        include_author: Eager-load the author relationship.
        include_tags: Eager-load the tags relationship.

    Returns:
        Whatever ``QueryOptimizer.optimize_pagination`` returns for the page.
    """
    query = Post.query.filter_by(status='published')

    # Eager-load relationships to avoid N+1 queries. The category is always
    # loaded; author and tags follow their flags (include_author used to be
    # ignored).
    relationships = ['category']
    if include_author:
        relationships.insert(0, 'author')
    if include_tags:
        relationships.append('tags')

    query = QueryOptimizer.optimize_n_plus_one(query, relationships)

    # Defer the large text column
    query = QueryOptimizer.optimize_columns(
        query, Post, exclude_fields=['content']
    )

    return QueryOptimizer.optimize_pagination(query, page, per_page)

def get_cached_popular_posts(limit=10):
    """Return the ``limit`` most-viewed published posts via the query cache (600 s timeout)."""
    published = Post.query.filter_by(status='published')
    most_viewed = published.order_by(Post.views_count.desc()).limit(limit)

    return QueryOptimizer.add_query_cache(
        most_viewed,
        f"popular_posts_{limit}",
        timeout=600,
    )

本章小结

本章详细介绍了Flask的数据库集成,包括:

  1. SQLAlchemy基础:安装配置、连接管理、事件监听
  2. 数据模型设计:基础模型、用户模型、关系设计
  3. 数据库迁移:Flask-Migrate使用、迁移管理、自定义脚本
  4. 查询优化:查询构建器、性能优化、缓存策略

掌握这些技能能够帮助你构建高效、可维护的数据库应用。

下一章预告

下一章我们将学习用户认证与授权,包括:

  • Flask-Login用户会话管理
  • 密码加密和验证
  • 角色权限系统
  • OAuth第三方登录
  • JWT令牌认证

练习题

  1. 模型设计:设计一个电商系统的数据模型(商品、订单、用户等)
  2. 查询优化:优化一个复杂的多表关联查询
  3. 迁移脚本:编写一个数据迁移脚本,将旧数据格式转换为新格式
  4. 性能分析:分析并优化一个慢查询
  5. 缓存策略:为文章系统实现多级缓存策略