8.1 缓存策略概述
8.1.1 缓存基础概念
缓存是提高Web应用性能的重要手段,通过将频繁访问的数据存储在快速访问的存储介质中,减少数据库查询和计算开销。
# 缓存层次结构示例
class CacheLayer:
"""缓存层次结构"""
def __init__(self):
self.levels = {
'browser': 'HTTP缓存头',
'cdn': 'CDN边缘缓存',
'reverse_proxy': 'Nginx/Apache缓存',
'application': 'Flask应用缓存',
'database': '数据库查询缓存',
'memory': '内存缓存(Redis/Memcached)',
'disk': '磁盘缓存'
}
def get_cache_strategy(self, data_type):
"""根据数据类型选择缓存策略"""
strategies = {
'static_files': ['browser', 'cdn', 'reverse_proxy'],
'user_session': ['memory'],
'database_query': ['memory', 'application'],
'computed_data': ['memory', 'application'],
'api_response': ['memory', 'application', 'browser']
}
return strategies.get(data_type, ['application'])
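下面用一个简短的示例演示get_cache_strategy的调用方式(输出注释为对应的返回值):
# 使用示例:为不同数据类型选择缓存层组合
layers = CacheLayer()
print(layers.get_cache_strategy('static_files'))  # ['browser', 'cdn', 'reverse_proxy']
print(layers.get_cache_strategy('user_session'))  # ['memory']
print(layers.get_cache_strategy('unknown'))       # 未注册的类型回退到 ['application']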
8.1.2 缓存模式
# cache/patterns.py
from abc import ABC, abstractmethod
from typing import Any, Optional, Callable
import time
import json
import hashlib
class CachePattern(ABC):
"""缓存模式基类"""
@abstractmethod
def get(self, key: str) -> Optional[Any]:
pass
@abstractmethod
def set(self, key: str, value: Any, timeout: int = None) -> bool:
pass
@abstractmethod
def delete(self, key: str) -> bool:
pass
class CacheAside(CachePattern):
"""Cache-Aside模式(旁路缓存)"""
def __init__(self, cache_client, data_source):
self.cache = cache_client
self.data_source = data_source
def get_data(self, key: str, loader: Callable = None):
"""获取数据"""
# 1. 先从缓存获取
data = self.cache.get(key)
if data is not None:
return data
# 2. 缓存未命中,从数据源获取
if loader:
data = loader(key)
else:
data = self.data_source.get(key)
# 3. 将数据写入缓存
if data is not None:
self.cache.set(key, data)
return data
def update_data(self, key: str, value: Any):
"""更新数据"""
# 1. 更新数据源
self.data_source.update(key, value)
# 2. 删除缓存(让下次读取时重新加载)
self.cache.delete(key)
def get(self, key: str) -> Optional[Any]:
return self.cache.get(key)
def set(self, key: str, value: Any, timeout: int = None) -> bool:
return self.cache.set(key, value, timeout)
def delete(self, key: str) -> bool:
return self.cache.delete(key)
class WriteThrough(CachePattern):
"""Write-Through模式(写穿透)"""
def __init__(self, cache_client, data_source):
self.cache = cache_client
self.data_source = data_source
def get_data(self, key: str):
"""获取数据"""
# 先从缓存获取
data = self.cache.get(key)
if data is not None:
return data
# 缓存未命中,从数据源获取并写入缓存
data = self.data_source.get(key)
if data is not None:
self.cache.set(key, data)
return data
def update_data(self, key: str, value: Any):
"""更新数据"""
# 同时更新缓存和数据源
self.data_source.update(key, value)
self.cache.set(key, value)
def get(self, key: str) -> Optional[Any]:
return self.cache.get(key)
def set(self, key: str, value: Any, timeout: int = None) -> bool:
return self.cache.set(key, value, timeout)
def delete(self, key: str) -> bool:
return self.cache.delete(key)
class WriteBack(CachePattern):
"""Write-Back模式(写回)"""
def __init__(self, cache_client, data_source, sync_interval=300):
self.cache = cache_client
self.data_source = data_source
self.sync_interval = sync_interval
self.dirty_keys = set()
self.last_sync = time.time()
def get_data(self, key: str):
"""获取数据"""
data = self.cache.get(key)
if data is not None:
return data
data = self.data_source.get(key)
if data is not None:
self.cache.set(key, data)
return data
def update_data(self, key: str, value: Any):
"""更新数据"""
# 只更新缓存,标记为脏数据
self.cache.set(key, value)
self.dirty_keys.add(key)
# 检查是否需要同步
if time.time() - self.last_sync > self.sync_interval:
self.sync_to_data_source()
def sync_to_data_source(self):
"""同步脏数据到数据源"""
for key in self.dirty_keys:
value = self.cache.get(key)
if value is not None:
self.data_source.update(key, value)
self.dirty_keys.clear()
self.last_sync = time.time()
def get(self, key: str) -> Optional[Any]:
return self.cache.get(key)
def set(self, key: str, value: Any, timeout: int = None) -> bool:
return self.cache.set(key, value, timeout)
def delete(self, key: str) -> bool:
return self.cache.delete(key)
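下面给出Cache-Aside模式的一个最小使用示例。其中DictStore是为演示假设的内存实现,同时充当缓存客户端和数据源:
# 演示用的内存存储(仅用于说明模式的调用方式,非生产实现)
class DictStore:
    def __init__(self):
        self._data = {}
    def get(self, key):
        return self._data.get(key)
    def set(self, key, value, timeout=None):
        self._data[key] = value
        return True
    def delete(self, key):
        return self._data.pop(key, None) is not None
    def update(self, key, value):
        self._data[key] = value

db = DictStore()
db.update('user:1', {'name': 'Alice'})

aside = CacheAside(cache_client=DictStore(), data_source=db)
print(aside.get_data('user:1'))               # 缓存未命中,从数据源加载并回填缓存
aside.update_data('user:1', {'name': 'Bob'})  # 更新数据源并删除缓存
print(aside.get_data('user:1'))               # 重新从数据源加载,得到新值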
8.2 Redis缓存集成
8.2.1 Redis配置和连接
# config/redis_config.py
import redis
from redis.sentinel import Sentinel
from flask import current_app
import json
import pickle
from typing import Any, Optional
class RedisConfig:
"""Redis配置类"""
def __init__(self, app=None):
self.redis_client = None
if app is not None:
self.init_app(app)
def init_app(self, app):
"""初始化Redis连接"""
redis_config = app.config.get('REDIS_CONFIG', {})
# 单机模式
if redis_config.get('mode') == 'standalone':
self.redis_client = redis.Redis(
host=redis_config.get('host', 'localhost'),
port=redis_config.get('port', 6379),
db=redis_config.get('db', 0),
password=redis_config.get('password'),
decode_responses=True,
socket_connect_timeout=redis_config.get('connect_timeout', 5),
socket_timeout=redis_config.get('socket_timeout', 5),
retry_on_timeout=True,
health_check_interval=30
)
# 哨兵模式
elif redis_config.get('mode') == 'sentinel':
sentinels = redis_config.get('sentinels', [('localhost', 26379)])
sentinel = Sentinel(sentinels)
self.redis_client = sentinel.master_for(
redis_config.get('service_name', 'mymaster'),
decode_responses=True
)
# 集群模式
elif redis_config.get('mode') == 'cluster':
from rediscluster import RedisCluster
startup_nodes = redis_config.get('startup_nodes', [{'host': 'localhost', 'port': 7000}])
self.redis_client = RedisCluster(
startup_nodes=startup_nodes,
decode_responses=True,
skip_full_coverage_check=True
)
# 默认单机模式
else:
self.redis_client = redis.Redis(
host='localhost',
port=6379,
db=0,
decode_responses=True
)
# 测试连接
try:
self.redis_client.ping()
app.logger.info('Redis连接成功')
except Exception as e:
app.logger.error(f'Redis连接失败: {e}')
self.redis_client = None
def get_client(self):
"""获取Redis客户端"""
return self.redis_client
class RedisCache:
"""Redis缓存操作类"""
def __init__(self, redis_client, prefix='flask_cache:', serializer='json'):
self.redis = redis_client
self.prefix = prefix
self.serializer = serializer
def _make_key(self, key: str) -> str:
"""生成缓存键"""
return f"{self.prefix}{key}"
    def _serialize(self, value: Any) -> Any:
        """序列化值(注意: pickle返回bytes,要求Redis客户端以decode_responses=False创建)"""
        if self.serializer == 'json':
            return json.dumps(value, ensure_ascii=False)
        elif self.serializer == 'pickle':
            return pickle.dumps(value)
        else:
            return str(value)
    def _deserialize(self, value: Any) -> Any:
        """反序列化值"""
        if value is None:
            return None
        try:
            if self.serializer == 'json':
                return json.loads(value)
            elif self.serializer == 'pickle':
                return pickle.loads(value)
            else:
                return value
        except (ValueError, TypeError, pickle.UnpicklingError):
            # 反序列化失败时退回原始值,避免缓存层异常影响业务
            return value
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
if not self.redis:
return None
try:
value = self.redis.get(self._make_key(key))
return self._deserialize(value)
except Exception as e:
current_app.logger.error(f'Redis get error: {e}')
return None
def set(self, key: str, value: Any, timeout: int = None) -> bool:
"""设置缓存值"""
if not self.redis:
return False
try:
serialized_value = self._serialize(value)
cache_key = self._make_key(key)
if timeout:
return self.redis.setex(cache_key, timeout, serialized_value)
else:
return self.redis.set(cache_key, serialized_value)
except Exception as e:
current_app.logger.error(f'Redis set error: {e}')
return False
def delete(self, key: str) -> bool:
"""删除缓存值"""
if not self.redis:
return False
try:
return bool(self.redis.delete(self._make_key(key)))
except Exception as e:
current_app.logger.error(f'Redis delete error: {e}')
return False
def exists(self, key: str) -> bool:
"""检查键是否存在"""
if not self.redis:
return False
try:
return bool(self.redis.exists(self._make_key(key)))
except Exception as e:
current_app.logger.error(f'Redis exists error: {e}')
return False
def expire(self, key: str, timeout: int) -> bool:
"""设置键过期时间"""
if not self.redis:
return False
try:
return bool(self.redis.expire(self._make_key(key), timeout))
except Exception as e:
current_app.logger.error(f'Redis expire error: {e}')
return False
def ttl(self, key: str) -> int:
"""获取键剩余生存时间"""
if not self.redis:
return -1
try:
return self.redis.ttl(self._make_key(key))
except Exception as e:
current_app.logger.error(f'Redis ttl error: {e}')
return -1
    def clear_pattern(self, pattern: str) -> int:
        """清除匹配模式的键(使用SCAN增量遍历,避免KEYS命令阻塞Redis)"""
        if not self.redis:
            return 0
        try:
            count = 0
            for key in self.redis.scan_iter(match=self._make_key(pattern)):
                count += self.redis.delete(key)
            return count
        except Exception as e:
            current_app.logger.error(f'Redis clear_pattern error: {e}')
            return 0
def get_stats(self) -> dict:
"""获取Redis统计信息"""
if not self.redis:
return {}
try:
info = self.redis.info()
return {
'used_memory': info.get('used_memory_human'),
'connected_clients': info.get('connected_clients'),
'total_commands_processed': info.get('total_commands_processed'),
'keyspace_hits': info.get('keyspace_hits'),
'keyspace_misses': info.get('keyspace_misses'),
'hit_rate': info.get('keyspace_hits', 0) / max(info.get('keyspace_hits', 0) + info.get('keyspace_misses', 0), 1)
}
except Exception as e:
current_app.logger.error(f'Redis get_stats error: {e}')
return {}
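一个简要的装配示例(假设已有Flask应用实例app,连接参数为示例值):
app.config['REDIS_CONFIG'] = {
    'mode': 'standalone',
    'host': 'localhost',
    'port': 6379,
    'db': 0,
}

redis_config = RedisConfig(app)
cache = RedisCache(redis_config.get_client(), prefix='myapp:', serializer='json')
cache.set('greeting', {'msg': '你好'}, timeout=60)
print(cache.get('greeting'))  # {'msg': '你好'}
print(cache.ttl('greeting'))  # 剩余生存时间(秒)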
8.2.2 缓存装饰器
# cache/decorators.py
from functools import wraps
from flask import current_app, request, g
import hashlib
import inspect
import time
from typing import Callable, Any, Optional
def cache_key_generator(func_name: str, args: tuple, kwargs: dict,
include_user: bool = False,
include_request: bool = False) -> str:
"""生成缓存键"""
key_parts = [func_name]
# 添加参数
if args:
key_parts.extend([str(arg) for arg in args])
if kwargs:
sorted_kwargs = sorted(kwargs.items())
key_parts.extend([f"{k}:{v}" for k, v in sorted_kwargs])
# 添加用户信息
if include_user and hasattr(g, 'current_user') and g.current_user:
key_parts.append(f"user:{g.current_user.id}")
# 添加请求信息
if include_request:
key_parts.append(f"path:{request.path}")
if request.args:
sorted_args = sorted(request.args.items())
key_parts.extend([f"{k}:{v}" for k, v in sorted_args])
# 生成哈希键
key_string = ":".join(key_parts)
return hashlib.md5(key_string.encode('utf-8')).hexdigest()
def cached(timeout: int = 300,
key_prefix: str = None,
include_user: bool = False,
include_request: bool = False,
condition: Callable = None,
cache_client: str = 'default'):
"""缓存装饰器"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
# 检查条件
if condition and not condition(*args, **kwargs):
return func(*args, **kwargs)
# 获取缓存客户端
cache = getattr(current_app, f'cache_{cache_client}', None)
if not cache:
return func(*args, **kwargs)
# 生成缓存键
func_name = f"{func.__module__}.{func.__name__}"
if key_prefix:
func_name = f"{key_prefix}:{func_name}"
cache_key = cache_key_generator(
func_name, args, kwargs,
include_user, include_request
)
# 尝试从缓存获取
cached_result = cache.get(cache_key)
if cached_result is not None:
return cached_result
# 执行函数并缓存结果
result = func(*args, **kwargs)
cache.set(cache_key, result, timeout)
return result
# 添加缓存控制方法
def invalidate(*args, **kwargs):
"""使缓存失效"""
cache = getattr(current_app, f'cache_{cache_client}', None)
if cache:
func_name = f"{func.__module__}.{func.__name__}"
if key_prefix:
func_name = f"{key_prefix}:{func_name}"
cache_key = cache_key_generator(
func_name, args, kwargs,
include_user, include_request
)
cache.delete(cache_key)
wrapper.invalidate = invalidate
wrapper._cache_config = {
'timeout': timeout,
'key_prefix': key_prefix,
'include_user': include_user,
'include_request': include_request
}
return wrapper
return decorator
def cache_memoize(timeout: int = 300, key_prefix: str = None):
"""记忆化缓存装饰器"""
def decorator(func: Callable) -> Callable:
cache_dict = {}
@wraps(func)
def wrapper(*args, **kwargs):
# 生成缓存键
func_name = f"{func.__module__}.{func.__name__}"
if key_prefix:
func_name = f"{key_prefix}:{func_name}"
cache_key = cache_key_generator(func_name, args, kwargs)
# 检查内存缓存
if cache_key in cache_dict:
cached_time, cached_result = cache_dict[cache_key]
if time.time() - cached_time < timeout:
return cached_result
else:
del cache_dict[cache_key]
# 执行函数并缓存结果
result = func(*args, **kwargs)
cache_dict[cache_key] = (time.time(), result)
return result
def clear_cache():
"""清除缓存"""
cache_dict.clear()
wrapper.clear_cache = clear_cache
return wrapper
return decorator
def cache_region(region_name: str, timeout: int = 300):
"""区域缓存装饰器"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
cache = getattr(current_app, 'cache_default', None)
if not cache:
return func(*args, **kwargs)
# 生成区域缓存键
func_name = f"{func.__module__}.{func.__name__}"
cache_key = f"region:{region_name}:{cache_key_generator(func_name, args, kwargs)}"
# 尝试从缓存获取
cached_result = cache.get(cache_key)
if cached_result is not None:
return cached_result
# 执行函数并缓存结果
result = func(*args, **kwargs)
cache.set(cache_key, result, timeout)
return result
def invalidate_region():
"""使整个区域缓存失效"""
cache = getattr(current_app, 'cache_default', None)
if cache:
cache.clear_pattern(f"region:{region_name}:*")
wrapper.invalidate_region = invalidate_region
return wrapper
return decorator
def conditional_cache(condition_func: Callable, timeout: int = 300):
    """条件缓存装饰器"""
    def decorator(func: Callable) -> Callable:
        # 只包装一次,避免每次调用都重新构造cached装饰器
        cached_func = cached(timeout=timeout)(func)
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 不满足缓存条件时直接执行原函数
            if not condition_func(*args, **kwargs):
                return func(*args, **kwargs)
            return cached_func(*args, **kwargs)
        return wrapper
    return decorator
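@cached装饰器的使用示意(假设应用初始化时已将缓存实例挂载为current_app.cache_default,Article为假设的SQLAlchemy模型;返回字典列表以便序列化):
# 将开销较大的查询结果缓存10分钟
@cached(timeout=600, key_prefix='articles')
def get_hot_articles(limit=10):
    articles = Article.query.order_by(Article.views.desc()).limit(limit).all()
    return [a.to_dict() for a in articles]

# 数据变更后,用相同参数使对应缓存条目失效
get_hot_articles.invalidate(limit=10)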
8.6 数据库性能优化
8.6.1 查询优化
# database/optimization.py
from sqlalchemy import event, text
from sqlalchemy.engine import Engine
from flask_sqlalchemy import SQLAlchemy
from flask import current_app
import time
import logging
from typing import List, Dict, Any
class DatabaseOptimizer:
"""数据库优化器"""
def __init__(self, db: SQLAlchemy):
self.db = db
self.slow_queries = []
self.query_stats = {}
self.setup_query_monitoring()
def setup_query_monitoring(self):
"""设置查询监控"""
@event.listens_for(Engine, "before_cursor_execute")
def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
context._query_start_time = time.time()
@event.listens_for(Engine, "after_cursor_execute")
def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
total = time.time() - context._query_start_time
# 记录慢查询
slow_query_threshold = current_app.config.get('SLOW_QUERY_THRESHOLD', 1.0)
if total > slow_query_threshold:
self.slow_queries.append({
'query': statement,
'parameters': parameters,
'duration': total,
'timestamp': time.time()
})
# 只保留最近100个慢查询
if len(self.slow_queries) > 100:
self.slow_queries = self.slow_queries[-100:]
# 更新查询统计
query_hash = hash(statement)
if query_hash not in self.query_stats:
self.query_stats[query_hash] = {
'query': statement[:200], # 只保留前200个字符
'count': 0,
'total_time': 0,
'avg_time': 0,
'max_time': 0
}
stats = self.query_stats[query_hash]
stats['count'] += 1
stats['total_time'] += total
stats['avg_time'] = stats['total_time'] / stats['count']
stats['max_time'] = max(stats['max_time'], total)
def get_slow_queries(self, limit: int = 50) -> List[Dict]:
"""获取慢查询列表"""
return sorted(self.slow_queries, key=lambda x: x['duration'], reverse=True)[:limit]
def get_query_stats(self) -> List[Dict]:
"""获取查询统计"""
return sorted(self.query_stats.values(), key=lambda x: x['avg_time'], reverse=True)
def analyze_table_usage(self) -> Dict[str, Any]:
"""分析表使用情况"""
try:
# 获取表大小信息(MySQL示例)
result = self.db.session.execute(text("""
SELECT
table_name,
table_rows,
data_length,
index_length,
(data_length + index_length) as total_size
FROM information_schema.tables
WHERE table_schema = DATABASE()
ORDER BY total_size DESC
"""))
tables = []
for row in result:
tables.append({
'table_name': row.table_name,
'rows': row.table_rows,
'data_size': row.data_length,
'index_size': row.index_length,
'total_size': row.total_size
})
return {'tables': tables}
except Exception as e:
current_app.logger.error(f'分析表使用情况失败: {e}')
return {'error': str(e)}
def suggest_indexes(self) -> List[Dict]:
"""建议索引优化"""
suggestions = []
# 分析慢查询中的WHERE条件
for query_info in self.slow_queries:
query = query_info['query'].upper()
# 简单的WHERE条件分析
if 'WHERE' in query:
# 提取WHERE条件中的列名(简化版本)
where_part = query.split('WHERE')[1].split('ORDER BY')[0].split('GROUP BY')[0]
# 查找可能需要索引的列
import re
columns = re.findall(r'(\w+)\s*[=<>]', where_part)
for column in columns:
suggestions.append({
'type': 'index',
'column': column,
'reason': f'在慢查询中频繁使用WHERE条件',
'query_sample': query_info['query'][:100]
})
# 去重
unique_suggestions = []
seen = set()
for suggestion in suggestions:
key = (suggestion['type'], suggestion['column'])
if key not in seen:
seen.add(key)
unique_suggestions.append(suggestion)
return unique_suggestions[:10] # 返回前10个建议
class ConnectionPoolOptimizer:
"""连接池优化器"""
@staticmethod
def optimize_pool_settings(app):
"""优化连接池设置"""
# 根据应用规模调整连接池大小
pool_size = app.config.get('SQLALCHEMY_ENGINE_OPTIONS', {}).get('pool_size', 10)
max_overflow = app.config.get('SQLALCHEMY_ENGINE_OPTIONS', {}).get('max_overflow', 20)
# 动态调整建议
import psutil
cpu_count = psutil.cpu_count()
recommended_pool_size = min(cpu_count * 2, 20)
recommended_max_overflow = recommended_pool_size
if pool_size != recommended_pool_size:
app.logger.info(f'建议调整连接池大小: 当前={pool_size}, 建议={recommended_pool_size}')
if max_overflow != recommended_max_overflow:
app.logger.info(f'建议调整最大溢出连接数: 当前={max_overflow}, 建议={recommended_max_overflow}')
return {
'current_pool_size': pool_size,
'recommended_pool_size': recommended_pool_size,
'current_max_overflow': max_overflow,
'recommended_max_overflow': recommended_max_overflow
}
@staticmethod
def monitor_connection_usage(db: SQLAlchemy):
"""监控连接使用情况"""
try:
engine = db.engine
pool = engine.pool
            return {
                'pool_size': pool.size(),
                'checked_in': pool.checkedin(),
                'checked_out': pool.checkedout(),
                'overflow': pool.overflow(),
                # QueuePool没有invalid()方法,用status()的摘要字符串代替
                'status': pool.status()
            }
except Exception as e:
current_app.logger.error(f'监控连接使用情况失败: {e}')
return {'error': str(e)}
class QueryOptimizer:
"""查询优化器"""
@staticmethod
def optimize_pagination(query, page: int, per_page: int, count_query=None):
"""优化分页查询"""
# 使用窗口函数优化大偏移量分页
if page > 100: # 大偏移量时使用游标分页
# 这里需要根据具体业务逻辑实现游标分页
pass
# 标准分页
offset = (page - 1) * per_page
items = query.offset(offset).limit(per_page).all()
# 优化计数查询
if count_query is None:
# 对于大表,可以使用估算计数
total = query.count()
else:
total = count_query.scalar()
return {
'items': items,
'total': total,
'page': page,
'per_page': per_page,
'pages': (total + per_page - 1) // per_page
}
@staticmethod
def optimize_joins(query):
"""优化JOIN查询"""
# 使用joinedload预加载关联数据
from sqlalchemy.orm import joinedload, selectinload
# 示例:优化用户和文章的关联查询
# query = query.options(joinedload(User.articles))
return query
@staticmethod
def batch_operations(model_class, operations: List[Dict], batch_size: int = 1000):
"""批量操作优化"""
from sqlalchemy import insert, update, delete
results = []
for i in range(0, len(operations), batch_size):
batch = operations[i:i + batch_size]
try:
if batch[0]['operation'] == 'insert':
# 批量插入
stmt = insert(model_class).values([op['data'] for op in batch])
result = current_app.db.session.execute(stmt)
results.append(result.rowcount)
elif batch[0]['operation'] == 'update':
# 批量更新
for op in batch:
stmt = update(model_class).where(
model_class.id == op['id']
).values(op['data'])
current_app.db.session.execute(stmt)
results.append(len(batch))
elif batch[0]['operation'] == 'delete':
# 批量删除
ids = [op['id'] for op in batch]
stmt = delete(model_class).where(model_class.id.in_(ids))
result = current_app.db.session.execute(stmt)
results.append(result.rowcount)
current_app.db.session.commit()
except Exception as e:
current_app.db.session.rollback()
current_app.logger.error(f'批量操作失败: {e}')
raise
return sum(results)
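batch_operations的调用示意(假设User为SQLAlchemy模型,且应用已按上文约定把SQLAlchemy实例挂载为current_app.db):
# 分批插入5000条记录,每批1000条提交一次
operations = [
    {'operation': 'insert', 'data': {'name': f'user_{i}'}}
    for i in range(5000)
]
affected = QueryOptimizer.batch_operations(User, operations, batch_size=1000)
print(f'共写入 {affected} 行')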
8.6.2 数据库连接优化
# database/connection.py
from sqlalchemy import create_engine, text
from sqlalchemy.pool import QueuePool, NullPool
from flask import current_app
import time
from typing import Dict, Any
class DatabaseConnectionManager:
"""数据库连接管理器"""
def __init__(self):
self.connection_stats = {
'total_connections': 0,
'active_connections': 0,
'connection_errors': 0,
'avg_connection_time': 0
}
def create_optimized_engine(self, database_url: str, **kwargs) -> Any:
"""创建优化的数据库引擎"""
# 默认优化配置
default_config = {
'poolclass': QueuePool,
'pool_size': 10,
'max_overflow': 20,
'pool_pre_ping': True,
'pool_recycle': 3600, # 1小时回收连接
'pool_timeout': 30,
'echo': current_app.config.get('SQLALCHEMY_ECHO', False),
'echo_pool': current_app.config.get('SQLALCHEMY_ECHO_POOL', False)
}
# 合并用户配置
config = {**default_config, **kwargs}
# 根据数据库类型优化
if 'mysql' in database_url:
config.update({
'connect_args': {
'charset': 'utf8mb4',
'autocommit': True,
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30
}
})
elif 'postgresql' in database_url:
config.update({
'connect_args': {
'connect_timeout': 10,
'application_name': 'flask_app'
}
})
engine = create_engine(database_url, **config)
# 添加连接事件监听
self._setup_connection_events(engine)
return engine
def _setup_connection_events(self, engine):
"""设置连接事件监听"""
from sqlalchemy import event
@event.listens_for(engine, "connect")
def receive_connect(dbapi_connection, connection_record):
self.connection_stats['total_connections'] += 1
connection_record.info['connect_time'] = time.time()
@event.listens_for(engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
self.connection_stats['active_connections'] += 1
@event.listens_for(engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
self.connection_stats['active_connections'] -= 1
# 计算连接使用时间
if 'connect_time' in connection_record.info:
duration = time.time() - connection_record.info['connect_time']
# 更新平均连接时间
current_avg = self.connection_stats['avg_connection_time']
total_connections = self.connection_stats['total_connections']
self.connection_stats['avg_connection_time'] = (
(current_avg * (total_connections - 1) + duration) / total_connections
)
@event.listens_for(engine, "invalidate")
def receive_invalidate(dbapi_connection, connection_record, exception):
self.connection_stats['connection_errors'] += 1
current_app.logger.error(f'数据库连接失效: {exception}')
def get_connection_stats(self) -> Dict[str, Any]:
"""获取连接统计信息"""
return self.connection_stats.copy()
def health_check(self, engine) -> Dict[str, Any]:
"""数据库健康检查"""
try:
start_time = time.time()
            # 执行简单查询测试连接(SQLAlchemy 1.4+要求用text()包装原始SQL)
            with engine.connect() as conn:
                result = conn.execute(text('SELECT 1'))
                result.fetchone()
response_time = time.time() - start_time
# 获取连接池状态
            pool_status = {
                'size': engine.pool.size(),
                'checked_in': engine.pool.checkedin(),
                'checked_out': engine.pool.checkedout(),
                'overflow': engine.pool.overflow(),
                # QueuePool没有invalid()方法,用status()的摘要字符串代替
                'status': engine.pool.status()
            }
return {
'status': 'healthy',
'response_time': response_time,
'pool_status': pool_status,
'connection_stats': self.get_connection_stats()
}
except Exception as e:
return {
'status': 'unhealthy',
'error': str(e),
'connection_stats': self.get_connection_stats()
}
class ReadWriteSplitter:
"""读写分离器"""
def __init__(self, write_engine, read_engines: list):
self.write_engine = write_engine
self.read_engines = read_engines
self.current_read_index = 0
def get_read_engine(self):
"""获取读数据库引擎(轮询)"""
if not self.read_engines:
return self.write_engine
engine = self.read_engines[self.current_read_index]
self.current_read_index = (self.current_read_index + 1) % len(self.read_engines)
return engine
def get_write_engine(self):
"""获取写数据库引擎"""
return self.write_engine
def execute_read_query(self, query, params=None):
"""执行读查询"""
engine = self.get_read_engine()
with engine.connect() as conn:
return conn.execute(query, params or {})
def execute_write_query(self, query, params=None):
"""执行写查询"""
engine = self.get_write_engine()
with engine.connect() as conn:
return conn.execute(query, params or {})
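读写分离器的装配示例(数据库地址为假设值;原始SQL用text()包装以兼容SQLAlchemy 1.4+):
from sqlalchemy import create_engine, text

write_engine = create_engine('postgresql://user:pwd@primary:5432/app')
read_engines = [
    create_engine('postgresql://user:pwd@replica1:5432/app'),
    create_engine('postgresql://user:pwd@replica2:5432/app'),
]
splitter = ReadWriteSplitter(write_engine, read_engines)

# 读请求在两个从库之间轮询,写请求固定走主库
rows = splitter.execute_read_query(text('SELECT id, name FROM users LIMIT 10'))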
8.7 部署优化
8.7.1 WSGI服务器优化
# deployment/wsgi_config.py
import multiprocessing
import os
from typing import Dict, Any
class WSGIOptimizer:
"""WSGI服务器优化器"""
@staticmethod
def get_gunicorn_config() -> Dict[str, Any]:
"""获取Gunicorn优化配置"""
# 计算最佳worker数量
cpu_count = multiprocessing.cpu_count()
workers = min(cpu_count * 2 + 1, 8) # 限制最大worker数
return {
# Worker配置
'workers': workers,
'worker_class': 'gevent', # 使用异步worker
'worker_connections': 1000,
'max_requests': 1000,
'max_requests_jitter': 100,
'timeout': 30,
'keepalive': 2,
            # 内存管理(上面的max_requests已起到定期重启worker、防止内存泄漏的作用)
            'preload_app': True,
# 日志配置
'accesslog': '-',
'errorlog': '-',
'loglevel': 'info',
'access_log_format': '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(D)s',
# 绑定配置
'bind': '0.0.0.0:8000',
'backlog': 2048,
# 安全配置
'limit_request_line': 4094,
'limit_request_fields': 100,
'limit_request_field_size': 8190,
}
@staticmethod
def get_uwsgi_config() -> Dict[str, Any]:
"""获取uWSGI优化配置"""
cpu_count = multiprocessing.cpu_count()
return {
# 进程配置
'processes': cpu_count,
'threads': 2,
'enable-threads': True,
# 内存管理
'max-requests': 1000,
'reload-on-rss': 512, # 512MB时重启worker
'evil-reload-on-rss': 1024, # 1GB时强制重启
# 网络配置
'http': '0.0.0.0:8000',
'listen': 1024,
'buffer-size': 32768,
# 性能优化
'master': True,
'vacuum': True,
'single-interpreter': True,
'lazy-apps': True,
# 日志配置
'disable-logging': False,
'log-4xx': True,
'log-5xx': True,
}
@staticmethod
def generate_gunicorn_config_file(config: Dict[str, Any], file_path: str):
"""生成Gunicorn配置文件"""
with open(file_path, 'w') as f:
f.write("# Gunicorn配置文件\n")
f.write("# 自动生成,请勿手动修改\n\n")
for key, value in config.items():
if isinstance(value, str):
f.write(f'{key} = "{value}"\n')
else:
f.write(f'{key} = {value}\n')
@staticmethod
def generate_uwsgi_config_file(config: Dict[str, Any], file_path: str):
"""生成uWSGI配置文件"""
with open(file_path, 'w') as f:
f.write("[uwsgi]\n")
f.write("# uWSGI配置文件\n")
f.write("# 自动生成,请勿手动修改\n\n")
for key, value in config.items():
f.write(f'{key} = {value}\n')
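生成并使用Gunicorn配置文件的示意:
# 生成gunicorn.conf.py后以其启动应用
config = WSGIOptimizer.get_gunicorn_config()
WSGIOptimizer.generate_gunicorn_config_file(config, 'gunicorn.conf.py')
# 命令行启动: gunicorn --config gunicorn.conf.py app:app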
class NginxOptimizer:
"""Nginx优化器"""
@staticmethod
def generate_nginx_config(app_name: str, upstream_servers: list,
static_path: str = None) -> str:
"""生成Nginx配置"""
config = f"""
# Nginx配置 - {app_name}
# 自动生成,请根据实际情况调整
upstream {app_name}_backend {{
# 负载均衡配置
least_conn;
keepalive 32;
"""
# 添加上游服务器
for server in upstream_servers:
config += f" server {server} max_fails=3 fail_timeout=30s;\n"
config += f"""
}}
server {{
listen 80;
server_name {app_name}.example.com;
# 安全头
add_header X-Frame-Options DENY;
add_header X-Content-Type-Options nosniff;
add_header X-XSS-Protection "1; mode=block";
# 压缩配置
gzip on;
gzip_vary on;
gzip_min_length 1024;
gzip_types text/plain text/css text/xml text/javascript application/javascript application/xml+rss application/json;
# 静态文件缓存
location /static/ {{
alias {static_path or '/app/static/'};
expires 30d;
add_header Cache-Control "public, immutable";
# 启用gzip
gzip_static on;
}}
# API请求
location / {{
proxy_pass http://{app_name}_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# 连接配置
proxy_connect_timeout 30s;
proxy_send_timeout 30s;
proxy_read_timeout 30s;
# 缓冲配置
proxy_buffering on;
proxy_buffer_size 4k;
proxy_buffers 8 4k;
# HTTP/1.1支持
proxy_http_version 1.1;
proxy_set_header Connection "";
}}
# 健康检查
location /health {{
access_log off;
proxy_pass http://{app_name}_backend/health;
}}
# 错误页面
error_page 500 502 503 504 /50x.html;
location = /50x.html {{
root /usr/share/nginx/html;
}}
}}
"""
return config
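生成Nginx配置的调用示例(上游地址与静态目录为假设值):
nginx_conf = NginxOptimizer.generate_nginx_config(
    'myapp',
    upstream_servers=['127.0.0.1:8000', '127.0.0.1:8001'],
    static_path='/srv/myapp/static/'
)
# 写入配置后执行 nginx -t 校验,再reload生效
with open('/etc/nginx/conf.d/myapp.conf', 'w') as f:
    f.write(nginx_conf)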
class DockerOptimizer:
"""Docker优化器"""
@staticmethod
def generate_dockerfile(python_version: str = "3.9",
requirements_file: str = "requirements.txt") -> str:
"""生成优化的Dockerfile"""
return f"""
# 多阶段构建Dockerfile
# 构建阶段
FROM python:{python_version}-slim as builder
# 设置工作目录
WORKDIR /app
# 安装系统依赖
RUN apt-get update && apt-get install -y \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY {requirements_file} .
# 安装Python依赖
RUN pip install --no-cache-dir --user -r {requirements_file}
# 运行阶段
FROM python:{python_version}-slim
# 安装curl供HEALTHCHECK使用(slim镜像默认不含curl)
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
# 创建非root用户
RUN groupadd -r appuser && useradd -r -g appuser appuser
# 设置工作目录
WORKDIR /app
# 从构建阶段复制依赖
COPY --from=builder /root/.local /home/appuser/.local
# 复制应用代码
COPY . .
# 设置环境变量
ENV PATH=/home/appuser/.local/bin:$PATH
ENV PYTHONPATH=/app
ENV FLASK_APP=app.py
ENV FLASK_ENV=production
# 更改文件所有者
RUN chown -R appuser:appuser /app
# 切换到非root用户
USER appuser
# 暴露端口
EXPOSE 8000
# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# 启动命令
CMD ["gunicorn", "--config", "gunicorn.conf.py", "app:app"]
"""
@staticmethod
def generate_docker_compose(app_name: str,
redis_enabled: bool = True,
postgres_enabled: bool = True) -> str:
"""生成Docker Compose配置"""
config = f"""
version: '3.8'
services:
{app_name}:
build: .
ports:
- "8000:8000"
environment:
- FLASK_ENV=production
- DATABASE_URL=postgresql://postgres:password@postgres:5432/{app_name}
- REDIS_URL=redis://redis:6379/0
depends_on:
- postgres
- redis
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
deploy:
resources:
limits:
memory: 512M
reservations:
memory: 256M
"""
if postgres_enabled:
config += f"""
postgres:
image: postgres:13
environment:
- POSTGRES_DB={app_name}
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=password
volumes:
- postgres_data:/var/lib/postgresql/data
restart: unless-stopped
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 30s
timeout: 10s
retries: 3
"""
if redis_enabled:
config += f"""
redis:
image: redis:6-alpine
command: redis-server --appendonly yes
volumes:
- redis_data:/data
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 10s
retries: 3
"""
config += """
volumes:
postgres_data:
redis_data:
"""
return config
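生成构建文件的调用示例:
dockerfile = DockerOptimizer.generate_dockerfile(python_version='3.11')
compose = DockerOptimizer.generate_docker_compose('myapp', redis_enabled=True, postgres_enabled=True)

with open('Dockerfile', 'w') as f:
    f.write(dockerfile)
with open('docker-compose.yml', 'w') as f:
    f.write(compose)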
8.8 性能最佳实践
8.8.1 代码优化最佳实践
# optimization/best_practices.py
from functools import wraps
from flask import request, jsonify, current_app
import time
import gc
from typing import Dict, Any, List
import psutil
import threading
from collections import defaultdict
class PerformanceBestPractices:
"""性能最佳实践工具类"""
@staticmethod
def lazy_loading_decorator(loader_func):
"""延迟加载装饰器"""
def decorator(func):
_cache = {}
_lock = threading.Lock()
@wraps(func)
def wrapper(*args, **kwargs):
cache_key = str(args) + str(sorted(kwargs.items()))
if cache_key not in _cache:
with _lock:
if cache_key not in _cache:
_cache[cache_key] = loader_func(*args, **kwargs)
return func(_cache[cache_key], *args, **kwargs)
return wrapper
return decorator
@staticmethod
def batch_processor(batch_size: int = 100, timeout: float = 1.0):
"""批处理装饰器"""
def decorator(func):
_batch = []
_lock = threading.Lock()
_last_process_time = time.time()
def process_batch():
nonlocal _batch, _last_process_time
if _batch:
batch_copy = _batch.copy()
_batch.clear()
_last_process_time = time.time()
return func(batch_copy)
return None
@wraps(func)
def wrapper(item):
with _lock:
_batch.append(item)
# 检查是否需要处理批次
should_process = (
len(_batch) >= batch_size or
time.time() - _last_process_time >= timeout
)
if should_process:
return process_batch()
return None
            # 添加强制处理方法(加锁,避免与wrapper内的批处理并发冲突)
            def flush():
                with _lock:
                    return process_batch()
            wrapper.flush = flush
            return wrapper
return decorator
@staticmethod
def memory_efficient_iterator(data_source, chunk_size: int = 1000):
"""内存高效的迭代器"""
if hasattr(data_source, '__iter__'):
# 处理可迭代对象
chunk = []
for item in data_source:
chunk.append(item)
if len(chunk) >= chunk_size:
yield chunk
chunk = []
gc.collect() # 强制垃圾回收
if chunk:
yield chunk
elif hasattr(data_source, 'query'):
# 处理SQLAlchemy查询
offset = 0
while True:
chunk = data_source.offset(offset).limit(chunk_size).all()
if not chunk:
break
yield chunk
offset += chunk_size
gc.collect()
@staticmethod
def optimize_json_response(data: Any, use_orjson: bool = True) -> Any:
"""优化JSON响应"""
if use_orjson:
try:
import orjson
return orjson.dumps(data).decode('utf-8')
except ImportError:
pass
# 回退到标准json
import json
return json.dumps(data, separators=(',', ':'), ensure_ascii=False)
@staticmethod
def compress_response_data(data: str, min_size: int = 1024) -> tuple:
"""压缩响应数据"""
if len(data) < min_size:
return data, None
import gzip
compressed = gzip.compress(data.encode('utf-8'))
# 只有在压缩效果明显时才使用压缩
if len(compressed) < len(data) * 0.8:
return compressed, 'gzip'
return data, None
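memory_efficient_iterator的使用示意(User为假设的SQLAlchemy模型,handle为假设的处理函数):
# 按500条一批流式遍历大结果集,避免一次性载入内存
for chunk in PerformanceBestPractices.memory_efficient_iterator(User.query, chunk_size=500):
    for user in chunk:
        handle(user)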
class ResourceMonitor:
"""资源监控器"""
def __init__(self):
self.stats = defaultdict(list)
self.alerts = []
self.thresholds = {
'cpu_percent': 80.0,
'memory_percent': 85.0,
'disk_percent': 90.0,
'response_time': 2.0
}
def collect_system_stats(self) -> Dict[str, Any]:
"""收集系统统计信息"""
try:
# CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
# 内存使用情况
memory = psutil.virtual_memory()
# 磁盘使用情况
disk = psutil.disk_usage('/')
# 网络IO
network = psutil.net_io_counters()
stats = {
'timestamp': time.time(),
'cpu_percent': cpu_percent,
'memory_percent': memory.percent,
'memory_available': memory.available,
'disk_percent': disk.percent,
'disk_free': disk.free,
'network_bytes_sent': network.bytes_sent,
'network_bytes_recv': network.bytes_recv
}
# 检查阈值
self._check_thresholds(stats)
# 保存统计信息
for key, value in stats.items():
if key != 'timestamp':
self.stats[key].append(value)
# 只保留最近100个数据点
if len(self.stats[key]) > 100:
self.stats[key] = self.stats[key][-100:]
return stats
except Exception as e:
current_app.logger.error(f'收集系统统计信息失败: {e}')
return {'error': str(e)}
def _check_thresholds(self, stats: Dict[str, Any]):
"""检查阈值并生成告警"""
for metric, threshold in self.thresholds.items():
if metric in stats and stats[metric] > threshold:
alert = {
'timestamp': stats['timestamp'],
'metric': metric,
'value': stats[metric],
'threshold': threshold,
'severity': 'warning' if stats[metric] < threshold * 1.2 else 'critical'
}
self.alerts.append(alert)
# 只保留最近50个告警
if len(self.alerts) > 50:
self.alerts = self.alerts[-50:]
current_app.logger.warning(
f'性能告警: {metric}={stats[metric]:.2f} 超过阈值 {threshold}'
)
def get_performance_summary(self) -> Dict[str, Any]:
"""获取性能摘要"""
summary = {}
for metric, values in self.stats.items():
if values:
summary[metric] = {
'current': values[-1],
'average': sum(values) / len(values),
'min': min(values),
'max': max(values),
'trend': 'up' if len(values) > 1 and values[-1] > values[-2] else 'down'
}
summary['alerts_count'] = len(self.alerts)
summary['recent_alerts'] = self.alerts[-5:] if self.alerts else []
return summary
class OptimizationRecommendations:
"""优化建议生成器"""
@staticmethod
def analyze_performance_data(monitor: ResourceMonitor) -> List[Dict[str, Any]]:
"""分析性能数据并生成建议"""
recommendations = []
summary = monitor.get_performance_summary()
# CPU优化建议
if 'cpu_percent' in summary:
cpu_stats = summary['cpu_percent']
if cpu_stats['average'] > 70:
recommendations.append({
'category': 'CPU',
'priority': 'high',
'issue': f'CPU平均使用率过高: {cpu_stats["average"]:.1f}%',
'recommendations': [
'考虑使用异步处理减少CPU阻塞',
'优化算法复杂度',
'使用缓存减少重复计算',
'考虑水平扩展增加服务器实例'
]
})
# 内存优化建议
if 'memory_percent' in summary:
memory_stats = summary['memory_percent']
if memory_stats['average'] > 80:
recommendations.append({
'category': '内存',
'priority': 'high',
'issue': f'内存平均使用率过高: {memory_stats["average"]:.1f}%',
'recommendations': [
'检查内存泄漏',
'优化数据结构使用',
'实现对象池减少内存分配',
'使用生成器替代列表减少内存占用',
'考虑增加服务器内存'
]
})
# 磁盘优化建议
if 'disk_percent' in summary:
disk_stats = summary['disk_percent']
if disk_stats['current'] > 85:
recommendations.append({
'category': '磁盘',
'priority': 'medium',
'issue': f'磁盘使用率过高: {disk_stats["current"]:.1f}%',
'recommendations': [
'清理临时文件和日志',
'实现日志轮转',
'压缩或归档旧数据',
'考虑增加存储空间'
]
})
# 告警分析
if summary['alerts_count'] > 10:
recommendations.append({
'category': '告警',
'priority': 'medium',
'issue': f'告警数量过多: {summary["alerts_count"]}',
'recommendations': [
'调整告警阈值避免误报',
'分析告警模式找出根本原因',
'实现自动化响应机制',
'优化监控策略'
]
})
return recommendations
@staticmethod
def generate_optimization_report(recommendations: List[Dict[str, Any]]) -> str:
"""生成优化报告"""
if not recommendations:
return "系统性能良好,暂无优化建议。"
report = "# 性能优化报告\n\n"
# 按优先级分组
high_priority = [r for r in recommendations if r['priority'] == 'high']
medium_priority = [r for r in recommendations if r['priority'] == 'medium']
low_priority = [r for r in recommendations if r['priority'] == 'low']
if high_priority:
report += "## 🔴 高优先级问题\n\n"
for rec in high_priority:
report += f"### {rec['category']}\n"
report += f"**问题**: {rec['issue']}\n\n"
report += "**建议**:\n"
for suggestion in rec['recommendations']:
report += f"- {suggestion}\n"
report += "\n"
if medium_priority:
report += "## 🟡 中优先级问题\n\n"
for rec in medium_priority:
report += f"### {rec['category']}\n"
report += f"**问题**: {rec['issue']}\n\n"
report += "**建议**:\n"
for suggestion in rec['recommendations']:
report += f"- {suggestion}\n"
report += "\n"
if low_priority:
report += "## 🟢 低优先级问题\n\n"
for rec in low_priority:
report += f"### {rec['category']}\n"
report += f"**问题**: {rec['issue']}\n\n"
report += "**建议**:\n"
for suggestion in rec['recommendations']:
report += f"- {suggestion}\n"
report += "\n"
return report
# 使用示例
def setup_performance_monitoring(app):
"""设置性能监控"""
monitor = ResourceMonitor()
@app.before_request
def before_request():
request.start_time = time.time()
@app.after_request
def after_request(response):
# 记录响应时间
if hasattr(request, 'start_time'):
response_time = time.time() - request.start_time
monitor.stats['response_time'].append(response_time)
# 检查响应时间阈值
if response_time > monitor.thresholds['response_time']:
monitor.alerts.append({
'timestamp': time.time(),
'metric': 'response_time',
'value': response_time,
'threshold': monitor.thresholds['response_time'],
'url': request.url,
'method': request.method
})
return response
# 定期收集系统统计信息
import threading
def collect_stats():
while True:
monitor.collect_system_stats()
time.sleep(60) # 每分钟收集一次
stats_thread = threading.Thread(target=collect_stats, daemon=True)
stats_thread.start()
# 添加性能报告端点
@app.route('/admin/performance')
def performance_report():
summary = monitor.get_performance_summary()
recommendations = OptimizationRecommendations.analyze_performance_data(monitor)
report = OptimizationRecommendations.generate_optimization_report(recommendations)
return jsonify({
'summary': summary,
'recommendations': recommendations,
'report': report
})
return monitor
8.8.2 配置优化
# config/optimization.py
import os
from typing import Dict, Any
class OptimizedConfig:
"""优化的配置类"""
# 基础配置
SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key'
# 数据库优化配置
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL')
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_RECORD_QUERIES = True
SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': 10,
'pool_timeout': 20,
'pool_recycle': 3600,
'max_overflow': 20,
'pool_pre_ping': True,
'echo': False
}
# Redis配置
REDIS_URL = os.environ.get('REDIS_URL') or 'redis://localhost:6379/0'
# 缓存配置
CACHE_TYPE = 'redis'
CACHE_REDIS_URL = REDIS_URL
CACHE_DEFAULT_TIMEOUT = 300
CACHE_KEY_PREFIX = 'flask_cache_'
# 会话配置
SESSION_TYPE = 'redis'
SESSION_REDIS = None # 将在应用初始化时设置
SESSION_PERMANENT = False
SESSION_USE_SIGNER = True
SESSION_KEY_PREFIX = 'flask_session_'
PERMANENT_SESSION_LIFETIME = 3600
# 安全配置
WTF_CSRF_ENABLED = True
WTF_CSRF_TIME_LIMIT = 3600
# JSON配置
JSON_SORT_KEYS = False
JSONIFY_PRETTYPRINT_REGULAR = False
# 上传配置
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB
# 性能监控配置
SLOW_QUERY_THRESHOLD = 1.0
ENABLE_PROFILING = False
# 日志配置
LOG_LEVEL = 'INFO'
LOG_FORMAT = '%(asctime)s %(levelname)s %(name)s %(message)s'
@staticmethod
def init_app(app):
"""初始化应用配置"""
        # 注意: app.json_encoder需要json.JSONEncoder子类,直接赋值orjson.dumps无效;
        # Flask 2.2+可通过自定义JSONProvider集成orjson(以下为一个最小写法)
        try:
            import orjson
            from flask.json.provider import JSONProvider

            class OrjsonProvider(JSONProvider):
                def dumps(self, obj, **kwargs):
                    return orjson.dumps(obj).decode('utf-8')
                def loads(self, s, **kwargs):
                    return orjson.loads(s)

            app.json = OrjsonProvider(app)
        except ImportError:
            pass
# 配置日志
import logging
logging.basicConfig(
level=getattr(logging, app.config['LOG_LEVEL']),
format=app.config['LOG_FORMAT']
)
class ProductionConfig(OptimizedConfig):
"""生产环境配置"""
DEBUG = False
TESTING = False
# 生产环境数据库配置
SQLALCHEMY_ENGINE_OPTIONS = {
**OptimizedConfig.SQLALCHEMY_ENGINE_OPTIONS,
'pool_size': 20,
'max_overflow': 40,
'echo': False
}
# 生产环境缓存配置
CACHE_DEFAULT_TIMEOUT = 3600
# 安全配置
SESSION_COOKIE_SECURE = True
SESSION_COOKIE_HTTPONLY = True
SESSION_COOKIE_SAMESITE = 'Lax'
# 性能配置
SEND_FILE_MAX_AGE_DEFAULT = 31536000 # 1年
class DevelopmentConfig(OptimizedConfig):
"""开发环境配置"""
DEBUG = True
TESTING = False
# 开发环境数据库配置
SQLALCHEMY_ENGINE_OPTIONS = {
**OptimizedConfig.SQLALCHEMY_ENGINE_OPTIONS,
'echo': True, # 开发环境显示SQL
'pool_size': 5
}
# 开发环境缓存配置
CACHE_DEFAULT_TIMEOUT = 60
# 性能监控
ENABLE_PROFILING = True
SLOW_QUERY_THRESHOLD = 0.5
class TestingConfig(OptimizedConfig):
"""测试环境配置"""
DEBUG = True
TESTING = True
# 测试数据库
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
# 禁用CSRF保护
WTF_CSRF_ENABLED = False
# 测试缓存
CACHE_TYPE = 'simple'
# 快速会话过期
PERMANENT_SESSION_LIFETIME = 60
config = {
'development': DevelopmentConfig,
'testing': TestingConfig,
'production': ProductionConfig,
'default': DevelopmentConfig
}
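在应用工厂中按环境加载配置的示意:
import os
from flask import Flask

def create_app(config_name=None):
    config_name = config_name or os.environ.get('FLASK_CONFIG', 'default')
    app = Flask(__name__)
    app.config.from_object(config[config_name])
    config[config_name].init_app(app)
    return app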
8.9 本章小结
8.9.1 技术要点总结
本章详细介绍了Flask应用的缓存与性能优化技术,主要包括:
缓存策略:
- 缓存基础概念和常用缓存模式
- Redis缓存集成和操作
- 应用级缓存和HTTP缓存
- CDN集成和静态资源优化
性能监控:
- 请求性能监控和分析
- 自定义指标收集
- APM集成和错误追踪
- 系统资源监控
数据库优化:
- 查询优化和慢查询分析
- 连接池优化和读写分离
- 批量操作和分页优化
- 索引建议和表分析
部署优化:
- WSGI服务器配置优化
- Nginx反向代理配置
- Docker容器化部署
- 负载均衡和高可用
最佳实践:
- 代码优化技巧
- 资源监控和告警
- 性能分析和优化建议
- 配置管理和环境区分
8.9.2 性能优化原则
- 测量优先:先测量再优化,避免过早优化
- 缓存策略:合理使用多级缓存提升响应速度
- 数据库优化:优化查询、使用索引、连接池管理
- 资源管理:监控系统资源,及时发现瓶颈
- 代码质量:编写高效代码,避免性能陷阱
8.9.3 监控和维护
- 建立完善的监控体系
- 设置合理的告警阈值
- 定期分析性能数据
- 持续优化和改进
下一章预告
下一章将介绍API设计与实现,包括:
- RESTful API设计原则
- Flask-RESTful框架使用
- API认证与授权
- API文档生成
- API测试和版本管理
练习题
基础练习
缓存实现:
- 实现一个简单的内存缓存装饰器
- 为用户信息查询添加Redis缓存
- 实现缓存预热功能
性能监控:
- 创建请求响应时间监控
- 实现慢查询日志记录
- 添加系统资源监控端点
进阶练习
数据库优化:
- 分析并优化一个复杂查询
- 实现数据库连接池监控
- 设计读写分离方案
部署优化:
- 配置Nginx反向代理
- 编写Docker多阶段构建文件
- 实现健康检查端点
综合项目
- 性能优化项目:
- 为现有Flask应用添加完整的缓存系统
- 实现性能监控和告警系统
- 编写性能测试和基准测试
- 生成性能优化报告
思考题
- 如何在缓存一致性和性能之间找到平衡?
- 什么情况下应该使用CDN,什么情况下不适合?
- 如何设计一个可扩展的性能监控系统?
- 在微服务架构中如何实现分布式缓存?
- 如何评估性能优化的效果和ROI?
8.3 应用级缓存
8.3.1 Flask-Caching集成
# cache/flask_cache.py
from flask_caching import Cache
from flask import current_app
import time
from typing import Any, Optional, List
class FlaskCacheManager:
"""Flask缓存管理器"""
def __init__(self, app=None):
self.cache = Cache()
if app is not None:
self.init_app(app)
def init_app(self, app):
"""初始化缓存"""
# 配置缓存
cache_config = {
'CACHE_TYPE': app.config.get('CACHE_TYPE', 'simple'),
'CACHE_DEFAULT_TIMEOUT': app.config.get('CACHE_DEFAULT_TIMEOUT', 300),
'CACHE_KEY_PREFIX': app.config.get('CACHE_KEY_PREFIX', 'flask_cache_'),
}
# Redis缓存配置
if cache_config['CACHE_TYPE'] == 'redis':
cache_config.update({
'CACHE_REDIS_HOST': app.config.get('CACHE_REDIS_HOST', 'localhost'),
'CACHE_REDIS_PORT': app.config.get('CACHE_REDIS_PORT', 6379),
'CACHE_REDIS_DB': app.config.get('CACHE_REDIS_DB', 0),
'CACHE_REDIS_PASSWORD': app.config.get('CACHE_REDIS_PASSWORD'),
})
# Memcached缓存配置
elif cache_config['CACHE_TYPE'] == 'memcached':
cache_config.update({
'CACHE_MEMCACHED_SERVERS': app.config.get('CACHE_MEMCACHED_SERVERS', ['127.0.0.1:11211']),
})
app.config.update(cache_config)
self.cache.init_app(app)
# 注册到应用
app.cache = self.cache
def get_cache(self):
"""获取缓存实例"""
return self.cache
class CacheService:
"""缓存服务类"""
def __init__(self, cache_instance=None):
self.cache = cache_instance or current_app.cache
def get_or_set(self, key: str, func: callable, timeout: int = 300) -> Any:
"""获取或设置缓存"""
value = self.cache.get(key)
if value is None:
value = func()
self.cache.set(key, value, timeout=timeout)
return value
def get_many(self, keys: List[str]) -> dict:
"""批量获取缓存"""
return self.cache.get_many(*keys)
def set_many(self, mapping: dict, timeout: int = 300) -> bool:
"""批量设置缓存"""
return self.cache.set_many(mapping, timeout=timeout)
def delete_many(self, keys: List[str]) -> bool:
"""批量删除缓存"""
return self.cache.delete_many(*keys)
    def increment(self, key: str, delta: int = 1) -> Optional[int]:
        """递增计数器"""
        try:
            return self.cache.inc(key, delta)
        except Exception:
            # 如果键不存在,设置初始值
            self.cache.set(key, delta)
            return delta
    def decrement(self, key: str, delta: int = 1) -> Optional[int]:
        """递减计数器"""
        try:
            return self.cache.dec(key, delta)
        except Exception:
            # 如果键不存在,设置初始值
            self.cache.set(key, -delta)
            return -delta
def add(self, key: str, value: Any, timeout: int = 300) -> bool:
"""仅当键不存在时添加"""
return self.cache.add(key, value, timeout=timeout)
def touch(self, key: str, timeout: int = 300) -> bool:
"""更新键的过期时间"""
value = self.cache.get(key)
if value is not None:
return self.cache.set(key, value, timeout=timeout)
return False
def get_stats(self) -> dict:
"""获取缓存统计信息"""
# 这里需要根据具体的缓存后端实现
return {
'cache_type': current_app.config.get('CACHE_TYPE'),
'default_timeout': current_app.config.get('CACHE_DEFAULT_TIMEOUT'),
'key_prefix': current_app.config.get('CACHE_KEY_PREFIX'),
}
class CacheWarmer:
"""缓存预热器"""
def __init__(self, cache_service: CacheService):
self.cache = cache_service
self.warming_tasks = []
def register_warming_task(self, key: str, func: callable, timeout: int = 300):
"""注册预热任务"""
self.warming_tasks.append({
'key': key,
'func': func,
'timeout': timeout
})
def warm_cache(self, keys: List[str] = None):
"""执行缓存预热"""
tasks_to_run = self.warming_tasks
if keys:
tasks_to_run = [task for task in self.warming_tasks if task['key'] in keys]
for task in tasks_to_run:
try:
start_time = time.time()
value = task['func']()
self.cache.cache.set(task['key'], value, timeout=task['timeout'])
duration = time.time() - start_time
current_app.logger.info(f"缓存预热完成: {task['key']} (耗时: {duration:.2f}s)")
except Exception as e:
current_app.logger.error(f"缓存预热失败: {task['key']} - {e}")
def schedule_warming(self, interval: int = 3600):
"""定时预热缓存"""
from threading import Timer
def warm_and_schedule():
self.warm_cache()
# 递归调度下次预热
Timer(interval, warm_and_schedule).start()
Timer(interval, warm_and_schedule).start()
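缓存预热的注册与触发示意(load_site_stats为假设的数据加载函数,需在应用上下文中执行):
warmer = CacheWarmer(CacheService())
warmer.register_warming_task('site:stats', load_site_stats, timeout=3600)

# 启动时立即预热一次,之后每小时自动刷新
warmer.warm_cache()
warmer.schedule_warming(interval=3600)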
8.3.2 查询结果缓存
# cache/query_cache.py
from sqlalchemy import event
from sqlalchemy.orm import Query
from flask_sqlalchemy import SQLAlchemy
from flask import current_app
import hashlib
import json
from typing import Any, Optional
class QueryCache:
"""查询结果缓存"""
def __init__(self, db: SQLAlchemy, cache_service):
self.db = db
self.cache = cache_service
self.enabled = True
self._setup_events()
def _setup_events(self):
"""设置数据库事件监听"""
# 监听数据变更,自动清除相关缓存
@event.listens_for(self.db.session, 'after_commit')
def clear_cache_after_commit(session):
if not self.enabled:
return
# 获取变更的表
changed_tables = set()
for obj in session.new:
changed_tables.add(obj.__tablename__)
for obj in session.dirty:
changed_tables.add(obj.__tablename__)
for obj in session.deleted:
changed_tables.add(obj.__tablename__)
# 清除相关缓存
for table in changed_tables:
self.invalidate_table_cache(table)
def _generate_cache_key(self, query: Query, params: dict = None) -> str:
"""生成查询缓存键"""
# 获取SQL语句和参数
compiled = query.statement.compile(compile_kwargs={"literal_binds": True})
sql = str(compiled)
# 添加参数
if params:
sql += json.dumps(params, sort_keys=True)
# 生成哈希键
return f"query_cache:{hashlib.md5(sql.encode()).hexdigest()}"
def get_cached_result(self, query: Query, timeout: int = 300) -> Optional[Any]:
"""获取缓存的查询结果"""
if not self.enabled:
return None
cache_key = self._generate_cache_key(query)
return self.cache.cache.get(cache_key)
def cache_query_result(self, query: Query, result: Any, timeout: int = 300):
"""缓存查询结果"""
if not self.enabled:
return
cache_key = self._generate_cache_key(query)
# 序列化结果
if hasattr(result, '__iter__') and not isinstance(result, (str, bytes)):
# 处理查询结果列表
serialized_result = []
for item in result:
if hasattr(item, 'to_dict'):
serialized_result.append(item.to_dict())
else:
serialized_result.append(str(item))
elif hasattr(result, 'to_dict'):
serialized_result = result.to_dict()
else:
serialized_result = result
self.cache.cache.set(cache_key, serialized_result, timeout=timeout)
# 记录表关联
table_name = self._get_table_from_query(query)
if table_name:
self._add_table_cache_key(table_name, cache_key)
    def _get_table_from_query(self, query: Query) -> Optional[str]:
        """从查询中获取表名"""
        try:
            if hasattr(query, 'column_descriptions'):
                for desc in query.column_descriptions:
                    if desc.get('entity') and hasattr(desc['entity'], '__tablename__'):
                        return desc['entity'].__tablename__
            return None
        except Exception:
            return None
    def _add_table_cache_key(self, table_name: str, cache_key: str):
        """添加表缓存键关联"""
        # 用list存储以兼容JSON序列化的缓存后端(set无法JSON序列化)
        table_cache_keys = set(self.cache.cache.get(f"table_cache_keys:{table_name}") or [])
        table_cache_keys.add(cache_key)
        self.cache.cache.set(f"table_cache_keys:{table_name}", list(table_cache_keys), timeout=86400)
def invalidate_table_cache(self, table_name: str):
"""清除表相关的所有缓存"""
table_cache_keys = self.cache.cache.get(f"table_cache_keys:{table_name}")
if table_cache_keys:
# 删除所有相关缓存键
for cache_key in table_cache_keys:
self.cache.cache.delete(cache_key)
# 清除表缓存键记录
self.cache.cache.delete(f"table_cache_keys:{table_name}")
def enable(self):
"""启用查询缓存"""
self.enabled = True
def disable(self):
"""禁用查询缓存"""
self.enabled = False
def clear_all(self):
"""清除所有查询缓存"""
# 这里需要根据具体的缓存实现来清除所有query_cache:*键
if hasattr(self.cache.cache, 'clear'):
# 简单缓存可以直接清除
self.cache.cache.clear()
else:
# Redis等需要模式匹配删除
if hasattr(self.cache, 'clear_pattern'):
self.cache.clear_pattern('query_cache:*')
self.cache.clear_pattern('table_cache_keys:*')
def cached_query(timeout: int = 300, key_suffix: str = None):
"""查询缓存装饰器"""
def decorator(func):
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
# 生成缓存键
cache_key = f"cached_query:{func.__name__}"
if key_suffix:
cache_key += f":{key_suffix}"
# 添加参数到缓存键
if args or kwargs:
import hashlib
params_str = str(args) + str(sorted(kwargs.items()))
params_hash = hashlib.md5(params_str.encode()).hexdigest()[:8]
cache_key += f":{params_hash}"
# 尝试从缓存获取
cache_service = CacheService()
cached_result = cache_service.cache.get(cache_key)
if cached_result is not None:
return cached_result
# 执行查询并缓存结果
result = func(*args, **kwargs)
cache_service.cache.set(cache_key, result, timeout=timeout)
return result
return wrapper
return decorator
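cached_query装饰器的使用示意(Article为假设的模型,返回值需可被缓存后端序列化):
@cached_query(timeout=600, key_suffix='hot')
def get_hot_article_dicts(limit=10):
    articles = Article.query.order_by(Article.views.desc()).limit(limit).all()
    return [a.to_dict() for a in articles]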
8.4 HTTP缓存
8.4.1 HTTP缓存头
# cache/http_cache.py
from flask import make_response, request, current_app
from datetime import datetime, timedelta
from functools import wraps
import hashlib
from typing import Optional, Union
class HTTPCacheManager:
"""HTTP缓存管理器"""
@staticmethod
def set_cache_headers(response,
max_age: int = None,
public: bool = True,
private: bool = False,
no_cache: bool = False,
no_store: bool = False,
must_revalidate: bool = False,
etag: str = None,
last_modified: datetime = None):
"""设置缓存头"""
# Cache-Control头
cache_control = []
if no_store:
cache_control.append('no-store')
elif no_cache:
cache_control.append('no-cache')
else:
if public:
cache_control.append('public')
elif private:
cache_control.append('private')
if max_age is not None:
cache_control.append(f'max-age={max_age}')
if must_revalidate:
cache_control.append('must-revalidate')
if cache_control:
response.headers['Cache-Control'] = ', '.join(cache_control)
# ETag头
if etag:
response.headers['ETag'] = f'"{etag}"'
# Last-Modified头
if last_modified:
response.headers['Last-Modified'] = last_modified.strftime('%a, %d %b %Y %H:%M:%S GMT')
# Expires头(如果设置了max_age)
if max_age is not None:
expires = datetime.utcnow() + timedelta(seconds=max_age)
response.headers['Expires'] = expires.strftime('%a, %d %b %Y %H:%M:%S GMT')
return response
@staticmethod
def generate_etag(content: Union[str, bytes]) -> str:
"""生成ETag"""
if isinstance(content, str):
content = content.encode('utf-8')
return hashlib.md5(content).hexdigest()
@staticmethod
def check_etag(etag: str) -> bool:
"""检查ETag是否匹配"""
client_etag = request.headers.get('If-None-Match')
if client_etag:
# 移除引号
client_etag = client_etag.strip('"')
return client_etag == etag
return False
@staticmethod
def check_last_modified(last_modified: datetime) -> bool:
"""检查Last-Modified是否匹配"""
client_modified = request.headers.get('If-Modified-Since')
if client_modified:
try:
client_time = datetime.strptime(client_modified, '%a, %d %b %Y %H:%M:%S GMT')
return last_modified <= client_time
except ValueError:
return False
return False
def cache_control(max_age: int = 300,
public: bool = True,
private: bool = False,
no_cache: bool = False,
no_store: bool = False,
must_revalidate: bool = False):
"""缓存控制装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
response = make_response(func(*args, **kwargs))
HTTPCacheManager.set_cache_headers(
response,
max_age=max_age,
public=public,
private=private,
no_cache=no_cache,
no_store=no_store,
must_revalidate=must_revalidate
)
return response
return wrapper
return decorator
def etag_cache(func):
"""ETag缓存装饰器"""
@wraps(func)
def wrapper(*args, **kwargs):
# 执行函数获取响应
response = make_response(func(*args, **kwargs))
# 生成ETag
if response.data:
etag = HTTPCacheManager.generate_etag(response.data)
# 检查客户端ETag
if HTTPCacheManager.check_etag(etag):
# 返回304 Not Modified
response = make_response('', 304)
# 设置ETag头
response.headers['ETag'] = f'"{etag}"'
return response
return wrapper
def last_modified_cache(get_last_modified):
"""Last-Modified缓存装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 获取最后修改时间
last_modified = get_last_modified(*args, **kwargs)
# 检查客户端Last-Modified
if last_modified and HTTPCacheManager.check_last_modified(last_modified):
# 返回304 Not Modified
response = make_response('', 304)
response.headers['Last-Modified'] = last_modified.strftime('%a, %d %b %Y %H:%M:%S GMT')
return response
# 执行函数获取响应
response = make_response(func(*args, **kwargs))
# 设置Last-Modified头
if last_modified:
response.headers['Last-Modified'] = last_modified.strftime('%a, %d %b %Y %H:%M:%S GMT')
return response
return wrapper
return decorator
def conditional_cache(condition_func, max_age: int = 300):
"""条件缓存装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 检查缓存条件
should_cache = condition_func(*args, **kwargs)
response = make_response(func(*args, **kwargs))
if should_cache:
HTTPCacheManager.set_cache_headers(
response,
max_age=max_age,
public=True
)
else:
HTTPCacheManager.set_cache_headers(
response,
no_cache=True
)
return response
return wrapper
return decorator
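在视图上应用这些装饰器的示意(假设已有Flask应用实例app与Article模型):
from flask import jsonify

@app.route('/api/articles')
@cache_control(max_age=120, public=True)
def list_articles():
    return jsonify([a.to_dict() for a in Article.query.limit(20).all()])

@app.route('/api/articles/<int:article_id>')
@etag_cache
def get_article(article_id):
    return jsonify(Article.query.get_or_404(article_id).to_dict())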
class StaticFileCache:
"""静态文件缓存"""
@staticmethod
def setup_static_cache(app):
"""设置静态文件缓存"""
@app.after_request
def add_static_cache_headers(response):
# 只对静态文件添加缓存头
if request.endpoint == 'static':
# 根据文件类型设置不同的缓存时间
file_ext = request.path.split('.')[-1].lower()
cache_times = {
'css': 86400 * 7, # 7天
'js': 86400 * 7, # 7天
'png': 86400 * 30, # 30天
'jpg': 86400 * 30, # 30天
'jpeg': 86400 * 30, # 30天
'gif': 86400 * 30, # 30天
'ico': 86400 * 30, # 30天
'woff': 86400 * 30, # 30天
'woff2': 86400 * 30, # 30天
'ttf': 86400 * 30, # 30天
}
max_age = cache_times.get(file_ext, 86400) # 默认1天
HTTPCacheManager.set_cache_headers(
response,
max_age=max_age,
public=True
)
return response
8.4.2 CDN集成
# cache/cdn.py
from flask import current_app, url_for
from typing import Dict, Any
import requests

class CDNManager:
    """CDN manager"""

    def __init__(self, app=None):
        self.cdn_config = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize the CDN configuration"""
        self.cdn_config = app.config.get('CDN_CONFIG', {})
        # Register template helpers
        app.jinja_env.globals['cdn_url'] = self.cdn_url
        app.jinja_env.globals['static_url'] = self.static_url

    def cdn_url(self, filename: str, _external: bool = True) -> str:
        """Build a CDN URL, falling back to url_for when the CDN is disabled"""
        if not self.cdn_config.get('enabled', False):
            return url_for('static', filename=filename, _external=_external)
        cdn_domain = self.cdn_config.get('domain')
        if not cdn_domain:
            return url_for('static', filename=filename, _external=_external)
        # Append a version query string for cache busting
        version = self.cdn_config.get('version', '')
        if version:
            filename = f"{filename}?v={version}"
        return f"{cdn_domain.rstrip('/')}/static/{filename}"

    def static_url(self, filename: str, **kwargs) -> str:
        """Build a static file URL (CDN-aware)"""
        return self.cdn_url(filename, **kwargs)

    def purge_cache(self, urls: list) -> Dict[str, Any]:
        """Purge CDN cache entries"""
        if not self.cdn_config.get('enabled', False):
            return {'success': False, 'message': 'CDN is not enabled'}
        purge_config = self.cdn_config.get('purge', {})
        if not purge_config:
            return {'success': False, 'message': 'CDN purge configuration is missing'}
        try:
            # Alibaba Cloud CDN is used as the example here
            if purge_config.get('provider') == 'aliyun':
                return self._purge_aliyun_cdn(urls, purge_config)
            # Other CDN providers can be added here
            elif purge_config.get('provider') == 'cloudflare':
                return self._purge_cloudflare_cdn(urls, purge_config)
            else:
                return {'success': False, 'message': 'Unsupported CDN provider'}
        except Exception as e:
            current_app.logger.error(f'CDN cache purge failed: {e}')
            return {'success': False, 'message': str(e)}

    def _purge_aliyun_cdn(self, urls: list, config: dict) -> Dict[str, Any]:
        """Purge Alibaba Cloud CDN cache"""
        # Illustrative only: a real integration should use the official
        # Alibaba Cloud SDK; its API requires signed requests, and the
        # bearer-token auth shown here is a simplification
        api_url = 'https://cdn.aliyuncs.com/'
        headers = {
            'Authorization': f"Bearer {config.get('access_token')}",
            'Content-Type': 'application/json'
        }
        data = {
            'Action': 'RefreshObjectCaches',
            'ObjectPath': '\n'.join(urls),
            'ObjectType': 'File'
        }
        response = requests.post(api_url, json=data, headers=headers)
        if response.status_code == 200:
            return {'success': True, 'message': 'CDN cache purged successfully'}
        else:
            return {'success': False, 'message': f'CDN cache purge failed: {response.text}'}

    def _purge_cloudflare_cdn(self, urls: list, config: dict) -> Dict[str, Any]:
        """Purge Cloudflare CDN cache"""
        zone_id = config.get('zone_id')
        api_token = config.get('api_token')
        if not zone_id or not api_token:
            return {'success': False, 'message': 'Cloudflare configuration is incomplete'}
        api_url = f'https://api.cloudflare.com/client/v4/zones/{zone_id}/purge_cache'
        headers = {
            'Authorization': f'Bearer {api_token}',
            'Content-Type': 'application/json'
        }
        data = {
            'files': urls
        }
        response = requests.post(api_url, json=data, headers=headers)
        if response.status_code == 200:
            result = response.json()
            if result.get('success'):
                return {'success': True, 'message': 'CDN cache purged successfully'}
            else:
                return {'success': False, 'message': f"CDN cache purge failed: {result.get('errors')}"}
        else:
            return {'success': False, 'message': f'CDN cache purge failed: {response.text}'}
class AssetVersioning:
    """Static asset version management"""

    def __init__(self, app=None):
        self.version_map = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize asset versioning"""
        import json
        import os
        # Load the version map file; ASSET_VERSION_FILE is a filename
        # relative to the static folder (the original code read the config
        # value but never used it)
        version_file = app.config.get('ASSET_VERSION_FILE', 'version.json')
        version_path = os.path.join(app.static_folder, version_file)
        try:
            if os.path.exists(version_path):
                with open(version_path, 'r') as f:
                    self.version_map = json.load(f)
        except Exception as e:
            app.logger.warning(f'Could not load asset version file: {e}')
        # Register the template helper
        app.jinja_env.globals['versioned_url'] = self.versioned_url

    def versioned_url(self, filename: str) -> str:
        """Build a URL carrying a version query parameter"""
        version = self.version_map.get(filename)
        if version:
            # Pass the version as a query argument; embedding "?v=..."
            # in the filename would be percent-encoded by url_for and
            # break the URL
            return url_for('static', filename=filename, v=version)
        return url_for('static', filename=filename)

    def generate_version_map(self, static_folder: str) -> dict:
        """Generate the asset version map"""
        import os
        import hashlib
        version_map = {}
        for root, dirs, files in os.walk(static_folder):
            for file in files:
                if file.endswith(('.css', '.js', '.png', '.jpg', '.jpeg', '.gif')):
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, static_folder)
                    # Use a short content hash as the version number
                    with open(file_path, 'rb') as f:
                        file_hash = hashlib.md5(f.read()).hexdigest()[:8]
                    version_map[relative_path.replace('\\', '/')] = file_hash
        return version_map

    def save_version_map(self, static_folder: str):
        """Persist the version map to a file"""
        import json
        import os
        version_map = self.generate_version_map(static_folder)
        version_file = os.path.join(static_folder, 'version.json')
        with open(version_file, 'w') as f:
            json.dump(version_map, f, indent=2)
        self.version_map = version_map
        current_app.logger.info(f'Asset version map saved: {len(version_map)} files')
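A minimal wiring sketch for the two extensions above. The configuration values shown (the CDN_CONFIG contents, the cdn.example.com domain, the CLI command name) are placeholder assumptions, not prescribed values.

# app setup (illustrative values)
from flask import Flask
from cache.cdn import CDNManager, AssetVersioning

app = Flask(__name__)
app.config['CDN_CONFIG'] = {
    'enabled': True,
    'domain': 'https://cdn.example.com',  # placeholder domain
    'version': '20240101',
}
cdn = CDNManager(app)
versioning = AssetVersioning(app)

# regenerate the version map at deploy time via a Flask CLI command;
# app.cli commands run inside an application context by default
@app.cli.command('build-version-map')
def build_version_map():
    versioning.save_version_map(app.static_folder)

# in templates:
#   <link rel="stylesheet" href="{{ cdn_url('css/app.css') }}">
#   <script src="{{ versioned_url('js/app.js') }}"></script>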
8.5 Performance Monitoring and Analysis
8.5.1 Performance Monitoring Tools
# monitoring/performance.py
from flask import Flask, request, g, current_app
import time
import psutil
import threading
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Dict, List, Any
import json
class PerformanceMonitor:
    """Performance monitor"""

    def __init__(self, app=None):
        self.metrics = defaultdict(list)
        self.request_times = deque(maxlen=1000)  # keep the last 1000 requests
        self.error_count = defaultdict(int)
        self.endpoint_stats = defaultdict(lambda: {'count': 0, 'total_time': 0, 'errors': 0})
        self.system_stats = deque(maxlen=100)  # keep the last 100 system snapshots
        self._lock = threading.Lock()
        if app is not None:
            self.init_app(app)

    def init_app(self, app: Flask):
        """Initialize performance monitoring"""
        app.before_request(self._before_request)
        app.after_request(self._after_request)
        # teardown_request rather than teardown_appcontext, so the hook runs
        # exactly once per request while request state is still available
        app.teardown_request(self._teardown_request)
        # Start the background system-monitoring thread
        self._start_system_monitoring(app)
        # Attach to the application
        app.performance_monitor = self
    def _before_request(self):
        """Pre-request handling"""
        g.start_time = time.time()
        g.request_id = f"{int(time.time() * 1000000)}_{threading.get_ident()}"

    def _after_request(self, response):
        """Post-request handling"""
        if hasattr(g, 'start_time'):
            duration = time.time() - g.start_time
            with self._lock:
                # Record the request timing
                self.request_times.append({
                    'timestamp': datetime.utcnow(),
                    'duration': duration,
                    'endpoint': request.endpoint,
                    'method': request.method,
                    'status_code': response.status_code,
                    'path': request.path
                })
                # Update per-endpoint statistics
                endpoint_key = f"{request.method} {request.endpoint or request.path}"
                stats = self.endpoint_stats[endpoint_key]
                stats['count'] += 1
                stats['total_time'] += duration
                if response.status_code >= 400:
                    stats['errors'] += 1
                    self.error_count[response.status_code] += 1
        return response
    def _teardown_request(self, exception):
        """Request teardown"""
        if exception:
            with self._lock:
                self.error_count['exceptions'] += 1

    def _start_system_monitoring(self, app: Flask):
        """Start the system-monitoring thread"""
        def monitor_system():
            while True:
                try:
                    cpu_percent = psutil.cpu_percent(interval=1)
                    memory = psutil.virtual_memory()
                    disk = psutil.disk_usage('/')
                    with self._lock:
                        self.system_stats.append({
                            'timestamp': datetime.utcnow(),
                            'cpu_percent': cpu_percent,
                            'memory_percent': memory.percent,
                            'memory_used': memory.used,
                            'memory_total': memory.total,
                            'disk_percent': disk.percent,
                            'disk_used': disk.used,
                            'disk_total': disk.total
                        })
                    time.sleep(60)  # sample once per minute
                except Exception as e:
                    # Use app.logger directly: this thread has no application
                    # context, so current_app would raise a RuntimeError
                    app.logger.error(f'System monitoring error: {e}')
                    time.sleep(60)

        thread = threading.Thread(target=monitor_system, daemon=True)
        thread.start()
    def get_performance_stats(self) -> Dict[str, Any]:
        """Collect performance statistics"""
        with self._lock:
            if not self.request_times:
                return {'message': 'No performance data yet'}
            # Request statistics (sorted once for min/max/percentiles)
            durations = sorted(req['duration'] for req in self.request_times)
            recent_requests = [req for req in self.request_times
                               if req['timestamp'] > datetime.utcnow() - timedelta(minutes=5)]
            # Rank endpoints by performance
            endpoint_performance = []
            for endpoint, stats in self.endpoint_stats.items():
                if stats['count'] > 0:
                    avg_time = stats['total_time'] / stats['count']
                    error_rate = stats['errors'] / stats['count'] * 100
                    endpoint_performance.append({
                        'endpoint': endpoint,
                        'count': stats['count'],
                        'avg_time': round(avg_time, 4),
                        'total_time': round(stats['total_time'], 4),
                        'error_rate': round(error_rate, 2)
                    })
            # Sort by average response time, slowest first
            endpoint_performance.sort(key=lambda x: x['avg_time'], reverse=True)
            # Latest system snapshot
            latest_system = self.system_stats[-1] if self.system_stats else None
            return {
                'request_stats': {
                    'total_requests': len(self.request_times),
                    'recent_requests_5min': len(recent_requests),
                    'avg_response_time': round(sum(durations) / len(durations), 4),
                    'min_response_time': round(durations[0], 4),
                    'max_response_time': round(durations[-1], 4),
                    'p95_response_time': round(durations[int(len(durations) * 0.95)], 4),
                    'p99_response_time': round(durations[int(len(durations) * 0.99)], 4)
                },
                'error_stats': dict(self.error_count),
                'endpoint_performance': endpoint_performance[:10],  # ten slowest endpoints
                'system_stats': latest_system,
                'timestamp': datetime.utcnow().isoformat()
            }
    def get_slow_requests(self, threshold: float = 1.0, limit: int = 50) -> List[Dict]:
        """List slow requests"""
        with self._lock:
            slow_requests = [
                req for req in self.request_times
                if req['duration'] > threshold
            ]
        # Sort by duration, slowest first
        slow_requests.sort(key=lambda x: x['duration'], reverse=True)
        return slow_requests[:limit]

    def export_metrics(self, format_type: str = 'json') -> str:
        """Export performance metrics"""
        stats = self.get_performance_stats()
        if format_type == 'json':
            return json.dumps(stats, indent=2, default=str)
        elif format_type == 'prometheus':
            return self._export_prometheus_format(stats)
        else:
            raise ValueError(f'Unsupported export format: {format_type}')
    def _export_prometheus_format(self, stats: Dict) -> str:
        """Export metrics in Prometheus text format"""
        lines = []
        # Request metrics
        request_stats = stats.get('request_stats', {})
        lines.append("# HELP flask_requests_total Total number of requests")
        lines.append("# TYPE flask_requests_total counter")
        lines.append(f"flask_requests_total {request_stats.get('total_requests', 0)}")
        # Expose the average as a gauge; the original emitted it as a
        # histogram _sum line, which it is not
        lines.append("# HELP flask_request_duration_seconds_avg Average request duration in seconds")
        lines.append("# TYPE flask_request_duration_seconds_avg gauge")
        lines.append(f"flask_request_duration_seconds_avg {request_stats.get('avg_response_time', 0)}")
        # Error metrics
        error_stats = stats.get('error_stats', {})
        for status_code, count in error_stats.items():
            lines.append(f'flask_errors_total{{status_code="{status_code}"}} {count}')
        # System metrics
        system_stats = stats.get('system_stats', {})
        if system_stats:
            lines.append(f"flask_cpu_usage_percent {system_stats.get('cpu_percent', 0)}")
            lines.append(f"flask_memory_usage_percent {system_stats.get('memory_percent', 0)}")
        return '\n'.join(lines)
class RequestProfiler:
    """Request profiler"""

    def __init__(self):
        self.profiles = {}

    def profile_request(self, func):
        """Decorator that profiles a request handler with cProfile"""
        from functools import wraps
        import cProfile
        import pstats
        import io

        @wraps(func)
        def wrapper(*args, **kwargs):
            if current_app.config.get('ENABLE_PROFILING', False):
                pr = cProfile.Profile()
                pr.enable()
                try:
                    result = func(*args, **kwargs)
                finally:
                    pr.disable()
                    # Build the profiling report
                    s = io.StringIO()
                    ps = pstats.Stats(pr, stream=s)
                    ps.sort_stats('cumulative')
                    ps.print_stats(20)  # show the top 20 functions
                    # Store the result
                    profile_key = f"{func.__name__}_{int(time.time())}"
                    self.profiles[profile_key] = {
                        'timestamp': datetime.utcnow(),
                        'function': func.__name__,
                        'stats': s.getvalue()
                    }
                    # Keep only the 50 most recent profiles
                    if len(self.profiles) > 50:
                        oldest_key = min(self.profiles.keys(),
                                         key=lambda k: self.profiles[k]['timestamp'])
                        del self.profiles[oldest_key]
                return result
            else:
                return func(*args, **kwargs)

        return wrapper

    def get_profiles(self) -> Dict:
        """Return stored profiling results"""
        return self.profiles

    def clear_profiles(self):
        """Clear stored profiling results"""
        self.profiles.clear()
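A short sketch of exposing the collected data over HTTP. The route paths and the monitor variable name are assumptions; in practice such endpoints should sit behind authentication or an internal network.

# monitoring endpoints sketch (route names are illustrative)
from flask import Flask, Response, jsonify
from monitoring.performance import PerformanceMonitor

app = Flask(__name__)
monitor = PerformanceMonitor(app)

@app.route('/internal/metrics')
def prometheus_metrics():
    # Prometheus scrapes this endpoint as plain text
    return Response(monitor.export_metrics('prometheus'), mimetype='text/plain')

@app.route('/internal/slow-requests')
def slow_requests():
    # requests slower than 500 ms, slowest first
    return jsonify(monitor.get_slow_requests(threshold=0.5, limit=20))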
8.5.2 APM Integration
# monitoring/apm.py
from flask import Flask, request, g
import time
import requests
from typing import Dict, Any
class APMIntegration:
    """APM (Application Performance Monitoring) integration"""

    def __init__(self, app=None):
        self.config = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app: Flask):
        """Initialize the APM integration"""
        self.config = app.config.get('APM_CONFIG', {})
        if self.config.get('enabled', False):
            app.before_request(self._before_request)
            app.after_request(self._after_request)
            # teardown_request, not teardown_appcontext: the error hook
            # reads the request object, which is no longer available once
            # the request context has been popped
            app.teardown_request(self._teardown_request)
    def _before_request(self):
        """Pre-request handling"""
        g.apm_start_time = time.time()
        g.apm_trace_id = self._generate_trace_id()

    def _after_request(self, response):
        """Post-request handling"""
        if hasattr(g, 'apm_start_time'):
            duration = time.time() - g.apm_start_time
            # Ship the timing data to the APM service
            self._send_performance_data({
                'trace_id': getattr(g, 'apm_trace_id', ''),
                'duration': duration,
                'endpoint': request.endpoint,
                'method': request.method,
                'status_code': response.status_code,
                'path': request.path,
                'timestamp': time.time()
            })
        return response
    def _teardown_request(self, exception):
        """Request teardown"""
        if exception:
            # Ship the error details to the APM service
            self._send_error_data({
                'trace_id': getattr(g, 'apm_trace_id', ''),
                'exception': str(exception),
                'exception_type': type(exception).__name__,
                'endpoint': request.endpoint,
                'path': request.path,
                'timestamp': time.time()
            })

    def _generate_trace_id(self) -> str:
        """Generate a trace ID"""
        import uuid
        return str(uuid.uuid4())
    def _send_performance_data(self, data: Dict[str, Any]):
        """Send timing data. Note that this blocking POST runs inside the
        request cycle; in production, queue the data or send it from a
        background worker instead."""
        if not self.config.get('performance_endpoint'):
            return
        try:
            requests.post(
                self.config['performance_endpoint'],
                json=data,
                headers=self._get_headers(),
                timeout=5
            )
        except Exception:
            # Swallow APM errors so they never break the application
            pass

    def _send_error_data(self, data: Dict[str, Any]):
        """Send error data"""
        if not self.config.get('error_endpoint'):
            return
        try:
            requests.post(
                self.config['error_endpoint'],
                json=data,
                headers=self._get_headers(),
                timeout=5
            )
        except Exception:
            # Swallow APM errors
            pass

    def _get_headers(self) -> Dict[str, str]:
        """Build request headers"""
        headers = {'Content-Type': 'application/json'}
        if self.config.get('api_key'):
            headers['Authorization'] = f"Bearer {self.config['api_key']}"
        return headers
class CustomMetrics:
    """Custom metrics collector"""

    def __init__(self):
        self.counters = {}
        self.gauges = {}
        self.histograms = {}

    def increment_counter(self, name: str, value: int = 1, tags: Dict[str, str] = None):
        """Increment a counter"""
        key = self._make_key(name, tags)
        self.counters[key] = self.counters.get(key, 0) + value

    def set_gauge(self, name: str, value: float, tags: Dict[str, str] = None):
        """Set a gauge value"""
        key = self._make_key(name, tags)
        self.gauges[key] = value

    def record_histogram(self, name: str, value: float, tags: Dict[str, str] = None):
        """Record a histogram observation"""
        key = self._make_key(name, tags)
        if key not in self.histograms:
            self.histograms[key] = []
        self.histograms[key].append(value)
        # Keep only the most recent 1000 observations
        if len(self.histograms[key]) > 1000:
            self.histograms[key] = self.histograms[key][-1000:]

    def _make_key(self, name: str, tags: Dict[str, str] = None) -> str:
        """Build a metric key from the name and sorted tags"""
        if not tags:
            return name
        tag_str = ','.join(f'{k}={v}' for k, v in sorted(tags.items()))
        return f'{name}[{tag_str}]'
    def get_metrics(self) -> Dict[str, Any]:
        """Return all collected metrics"""
        histogram_stats = {}
        for key, values in self.histograms.items():
            if values:
                sorted_values = sorted(values)
                histogram_stats[key] = {
                    'count': len(values),
                    'sum': sum(values),
                    'avg': sum(values) / len(values),
                    'min': sorted_values[0],
                    'max': sorted_values[-1],
                    'p95': sorted_values[int(len(values) * 0.95)],
                    'p99': sorted_values[int(len(values) * 0.99)]
                }
        return {
            'counters': self.counters,
            'gauges': self.gauges,
            'histograms': histogram_stats,
            'timestamp': time.time()
        }

    def reset_metrics(self):
        """Reset all metrics"""
        self.counters.clear()
        self.gauges.clear()
        self.histograms.clear()
# Global metrics instance
custom_metrics = CustomMetrics()

def track_custom_metric(metric_type: str, name: str, value: Any = 1, tags: Dict[str, str] = None):
    """Decorator factory for tracking custom metrics"""
    def decorator(func):
        from functools import wraps

        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                # Record success metrics
                if metric_type == 'counter':
                    custom_metrics.increment_counter(f'{name}_success', value, tags)
                elif metric_type == 'histogram':
                    duration = time.time() - start_time
                    custom_metrics.record_histogram(f'{name}_duration', duration, tags)
                return result
            except Exception as e:
                # Record failure metrics, tagged with the exception type
                error_tags = (tags or {}).copy()
                error_tags['error_type'] = type(e).__name__
                custom_metrics.increment_counter(f'{name}_error', 1, error_tags)
                raise
        return wrapper
    return decorator
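To close the section, a short usage sketch for the APM pieces. The endpoint URLs, the api_key value, and the create_order function are made-up examples, not values the text prescribes.

# usage sketch (endpoint URLs and function are illustrative)
from flask import Flask
from monitoring.apm import APMIntegration, track_custom_metric, custom_metrics

app = Flask(__name__)
app.config['APM_CONFIG'] = {
    'enabled': True,
    'performance_endpoint': 'https://apm.example.com/v1/traces',  # placeholder
    'error_endpoint': 'https://apm.example.com/v1/errors',        # placeholder
    'api_key': 'change-me',
}
apm = APMIntegration(app)

@track_custom_metric('histogram', 'create_order')
def create_order(payload):
    # business logic goes here; the decorator records
    # create_order_duration on success and create_order_error on failure
    ...

# inspect collected metrics, e.g. from a debug route or a shell
print(custom_metrics.get_metrics())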