7.1 Flask-Login基础
7.1.1 安装和配置
# Install Flask-Login (session-based login management)
pip install Flask-Login
pip install Flask-Bcrypt # password hashing
pip install PyJWT # JWT tokens
# extensions.py
from flask_login import LoginManager
from flask_bcrypt import Bcrypt
login_manager = LoginManager()
bcrypt = Bcrypt()


def init_extensions(app):
    """Initialise the authentication extensions and bind them to *app*.

    Registers Flask-Login and Flask-Bcrypt with the application and configures
    the login manager: login endpoint, flash message and category, session
    protection level, and the lifetime of the "remember me" cookie.
    """
    # timedelta is not among this module's top-level imports; without this
    # local import the remember-cookie configuration raises NameError.
    from datetime import timedelta

    login_manager.init_app(app)
    bcrypt.init_app(app)

    # Configure the login manager
    login_manager.login_view = 'auth.login'  # endpoint anonymous users are sent to
    login_manager.login_message = '请先登录访问此页面'
    login_manager.login_message_category = 'info'
    login_manager.session_protection = 'strong'  # strictest Flask-Login protection mode

    # Flask-Login reads the remember-cookie lifetime from the app config key
    # REMEMBER_COOKIE_DURATION, not from a LoginManager attribute, so the
    # attribute alone never took effect. setdefault keeps an explicit config
    # value if the application already provides one.
    remember_duration = timedelta(days=7)
    app.config.setdefault('REMEMBER_COOKIE_DURATION', remember_duration)
    # Kept for backward compatibility with code that may read this attribute.
    login_manager.remember_cookie_duration = remember_duration
# User loader callback
@login_manager.user_loader
def load_user(user_id):
    """Reload the user object for the user id stored in the session.

    Args:
        user_id: the id string Flask-Login kept in the session.

    Returns:
        The matching User, or None if the id is not an integer (e.g. a
        tampered or stale session) or no such user exists.
    """
    try:
        pk = int(user_id)
    except (TypeError, ValueError):
        # Flask-Login expects None for an invalid id rather than an exception.
        return None
    # Imported lazily: models/user.py imports `bcrypt` from this module, so a
    # top-level import here would be circular.
    from models.user import User
    return User.query.get(pk)
# Handler for requests that require a logged-in user
@login_manager.unauthorized_handler
def unauthorized():
    """Respond to an unauthenticated request.

    JSON requests get a 401 error payload; all other requests are redirected
    to the login page with the current URL passed along as ``next``.
    """
    from flask import request, jsonify, redirect, url_for

    if not request.is_json:
        login_url = url_for('auth.login', next=request.url)
        return redirect(login_url)
    return jsonify({'error': '需要登录'}), 401
7.1.2 用户模型扩展
# models/user.py
from flask_login import UserMixin
from werkzeug.security import generate_password_hash, check_password_hash
from extensions import bcrypt
from datetime import datetime, timedelta
import secrets
class User(BaseModel, UserMixin):
    """Application user account.

    Holds credentials (a bcrypt password hash), profile fields, status flags,
    single-use tokens for password reset and email verification, and login
    statistics that drive the brute-force account lockout.
    """
    __tablename__ = 'users'

    # Lockout policy: MAX_FAILED_LOGINS consecutive failures lock the account
    # for LOCKOUT_DURATION.
    MAX_FAILED_LOGINS = 5
    LOCKOUT_DURATION = timedelta(minutes=30)

    # Basic information
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    # Profile information
    first_name = db.Column(db.String(50))
    last_name = db.Column(db.String(50))
    phone = db.Column(db.String(20))
    avatar = db.Column(db.String(255))
    bio = db.Column(db.Text)
    # Status flags. As a class attribute, the is_active column overrides
    # UserMixin's is_active.
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    is_verified = db.Column(db.Boolean, default=False, nullable=False)
    is_admin = db.Column(db.Boolean, default=False, nullable=False)
    # Single-use tokens and their expiry times (naive UTC, set from utcnow())
    password_reset_token = db.Column(db.String(255))
    password_reset_expires = db.Column(db.DateTime)
    email_verification_token = db.Column(db.String(255))
    email_verification_expires = db.Column(db.DateTime)
    # Login statistics
    login_count = db.Column(db.Integer, default=0)
    last_login_at = db.Column(db.DateTime)
    last_login_ip = db.Column(db.String(45))
    failed_login_attempts = db.Column(db.Integer, default=0)
    locked_until = db.Column(db.DateTime)
    # Relationships
    roles = db.relationship('Role', secondary='user_roles', backref='users')

    @staticmethod
    def _token_is_valid(stored_token, candidate, expires_at):
        """Return True if *candidate* equals *stored_token* and has not expired."""
        if not stored_token or not candidate or not expires_at:
            return False
        # Constant-time comparison so response timing does not leak how much
        # of a guessed token is correct. Compare bytes, because
        # compare_digest only accepts str when it is ASCII-only.
        if not secrets.compare_digest(
            str(stored_token).encode('utf-8'), str(candidate).encode('utf-8')
        ):
            return False
        return expires_at > datetime.utcnow()

    def set_password(self, password):
        """Hash *password* with bcrypt and store the hash."""
        # Flask-Bcrypt returns bytes; store the text form.
        self.password_hash = bcrypt.generate_password_hash(password).decode('utf-8')

    def check_password(self, password):
        """Return True if *password* matches the stored bcrypt hash."""
        return bcrypt.check_password_hash(self.password_hash, password)

    def generate_password_reset_token(self):
        """Create, commit and return a password-reset token valid for one hour."""
        self.password_reset_token = secrets.token_urlsafe(32)
        self.password_reset_expires = datetime.utcnow() + timedelta(hours=1)
        db.session.commit()
        return self.password_reset_token

    def verify_password_reset_token(self, token):
        """Return True if *token* matches the stored reset token and has not expired."""
        return self._token_is_valid(
            self.password_reset_token, token, self.password_reset_expires
        )

    def reset_password(self, new_password):
        """Set a new password, consume the reset token, clear any lockout and commit."""
        self.set_password(new_password)
        self.password_reset_token = None
        self.password_reset_expires = None
        self.failed_login_attempts = 0
        self.locked_until = None
        db.session.commit()

    def generate_email_verification_token(self):
        """Create, commit and return an email-verification token valid for one day."""
        self.email_verification_token = secrets.token_urlsafe(32)
        self.email_verification_expires = datetime.utcnow() + timedelta(days=1)
        db.session.commit()
        return self.email_verification_token

    def verify_email(self, token):
        """Mark the email as verified if *token* is valid and unexpired.

        Returns:
            bool: True if verification succeeded (the token is then consumed).
        """
        if not self._token_is_valid(
            self.email_verification_token, token, self.email_verification_expires
        ):
            return False
        self.is_verified = True
        self.email_verification_token = None
        self.email_verification_expires = None
        db.session.commit()
        return True

    def record_login_attempt(self, success=True, ip_address=None):
        """Record a login attempt and apply the lockout policy.

        A success resets the failure counter and lock and updates the login
        statistics. A failure increments the counter; reaching
        MAX_FAILED_LOGINS locks the account for LOCKOUT_DURATION.
        """
        now = datetime.utcnow()
        if success:
            # `or 0`: column defaults only apply on INSERT, so a pending
            # (unflushed) instance still holds None here.
            self.login_count = (self.login_count or 0) + 1
            self.last_login_at = now
            self.last_login_ip = ip_address
            self.failed_login_attempts = 0
            self.locked_until = None
        else:
            if self.locked_until and self.locked_until <= now:
                # The previous lockout has expired: start a fresh series, so a
                # single further mistake does not immediately re-lock.
                self.failed_login_attempts = 0
                self.locked_until = None
            self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
            if self.failed_login_attempts >= self.MAX_FAILED_LOGINS:
                self.locked_until = now + self.LOCKOUT_DURATION
        db.session.commit()

    def is_locked(self):
        """Return True while ``locked_until`` lies in the future."""
        return bool(self.locked_until and self.locked_until > datetime.utcnow())

    def unlock_account(self):
        """Clear the failure counter and any lock, then commit."""
        self.failed_login_attempts = 0
        self.locked_until = None
        db.session.commit()

    def has_role(self, role_name):
        """Return True if the user holds a role named *role_name*."""
        return any(role.name == role_name for role in self.roles)

    def has_permission(self, permission_name):
        """Return True if any of the user's roles grants *permission_name*."""
        return any(role.has_permission(permission_name) for role in self.roles)

    @property
    def is_authenticated(self):
        """Always True for real user objects."""
        return True

    @property
    def is_anonymous(self):
        """Always False for real user objects."""
        return False

    def get_id(self):
        """Return the user id as a string (the value stored in the session)."""
        return str(self.id)

    def __repr__(self):
        return f'<User {self.username}>'
7.1.3 认证视图
# views/auth.py
from flask import Blueprint, render_template, request, redirect, url_for, flash, session, current_app
from flask_login import login_user, logout_user, login_required, current_user
from forms.auth import LoginForm, RegisterForm, ForgotPasswordForm, ResetPasswordForm
from models.user import User
from utils.email import send_email
from utils.decorators import anonymous_required
auth_bp = Blueprint('auth', __name__, url_prefix='/auth')


@auth_bp.route('/login', methods=['GET', 'POST'])
@anonymous_required
def login():
    """Sign a user in with a username or email address plus password.

    On success the login is recorded, the user is logged in (honouring
    "remember me") and redirected to a same-host ``next`` URL or the home
    page. Unknown, locked or disabled accounts and wrong passwords re-render
    the form with a flash message.
    """
    form = LoginForm()
    if not form.validate_on_submit():
        return render_template('auth/login.html', form=form)

    login_name = form.username.data
    user = User.query.filter(
        (User.username == login_name) | (User.email == login_name)
    ).first()

    if not user:
        # Same message as a wrong password, so unknown usernames are not revealed.
        flash('用户名或密码错误', 'error')
        return render_template('auth/login.html', form=form)

    if user.is_locked():
        flash('账户已被锁定,请稍后再试', 'error')
        return render_template('auth/login.html', form=form)

    if not user.is_active:
        flash('账户已被禁用', 'error')
        return render_template('auth/login.html', form=form)

    if not user.check_password(form.password.data):
        user.record_login_attempt(success=False)
        flash('用户名或密码错误', 'error')
        return render_template('auth/login.html', form=form)

    user.record_login_attempt(success=True, ip_address=request.remote_addr)
    login_user(user, remember=form.remember_me.data)
    flash(f'欢迎回来,{user.username}!', 'success')

    # Follow ?next= only when it points back to this host (open-redirect guard).
    target = request.args.get('next')
    if target and is_safe_url(target):
        return redirect(target)
    return redirect(url_for('main.index'))
@auth_bp.route('/register', methods=['GET', 'POST'])
@anonymous_required
def register():
    """Register a new account and email the verification link.

    The new user receives the default role (if one is configured), because
    User.has_permission only consults roles: a user without roles has no
    permissions at all.
    """
    # `db` is not among this module's imports (NameError without this);
    # imported locally like the other lazily-loaded dependencies.
    from models.base import db
    from utils.permission_manager import PermissionManager

    form = RegisterForm()
    if form.validate_on_submit():
        # Create the new user
        user = User(
            username=form.username.data,
            email=form.email.data,
            first_name=form.first_name.data,
            last_name=form.last_name.data,
        )
        user.set_password(form.password.data)
        # Add to the session first so the commit performed inside
        # generate_email_verification_token() persists the user together
        # with its token.
        db.session.add(user)
        token = user.generate_email_verification_token()
        PermissionManager.assign_default_role(user)

        # Send the verification email
        send_verification_email(user, token)
        flash('注册成功!请检查邮箱并验证您的账户', 'success')
        return redirect(url_for('auth.login'))
    return render_template('auth/register.html', form=form)
@auth_bp.route('/logout')
@login_required
def logout():
    """Log the current user out and flash a goodbye message."""
    # Read the name before logging out, while current_user is still the user.
    departing_name = current_user.username
    logout_user()
    flash(f'再见,{departing_name}!', 'info')
    home = url_for('main.index')
    return redirect(home)
@auth_bp.route('/verify-email/<token>')
def verify_email(token):
    """Confirm a user's email address from the emailed verification link."""
    user = User.query.filter_by(email_verification_token=token).first()
    if not (user and user.verify_email(token)):
        flash('验证链接无效或已过期', 'error')
        return redirect(url_for('main.index'))
    flash('邮箱验证成功!', 'success')
    return redirect(url_for('auth.login'))
@auth_bp.route('/forgot-password', methods=['GET', 'POST'])
@anonymous_required
def forgot_password():
    """Email a password-reset link if the submitted address has an account."""
    form = ForgotPasswordForm()
    if not form.validate_on_submit():
        return render_template('auth/forgot_password.html', form=form)

    account = User.query.filter_by(email=form.email.data).first()
    if account:
        send_password_reset_email(account, account.generate_password_reset_token())
    # Identical message whether or not the account exists, so the response
    # does not reveal which addresses are registered.
    flash('如果该邮箱存在,您将收到密码重置链接', 'info')
    return redirect(url_for('auth.login'))
@auth_bp.route('/reset-password/<token>', methods=['GET', 'POST'])
@anonymous_required
def reset_password(token):
    """Let the user choose a new password through a valid reset link."""
    user = User.query.filter_by(password_reset_token=token).first()
    if not (user and user.verify_password_reset_token(token)):
        flash('重置链接无效或已过期', 'error')
        return redirect(url_for('auth.forgot_password'))

    form = ResetPasswordForm()
    if not form.validate_on_submit():
        return render_template('auth/reset_password.html', form=form)

    user.reset_password(form.password.data)
    flash('密码重置成功!请使用新密码登录', 'success')
    return redirect(url_for('auth.login'))
@auth_bp.route('/resend-verification')
@login_required
def resend_verification():
    """Send a fresh verification email to the logged-in user if still unverified."""
    if not current_user.is_verified:
        fresh_token = current_user.generate_email_verification_token()
        send_verification_email(current_user, fresh_token)
        flash('验证邮件已重新发送', 'info')
    else:
        flash('您的邮箱已经验证过了', 'info')
    return redirect(url_for('main.index'))
def is_safe_url(target, host_url=None):
    """Return True if redirecting to *target* keeps the user on this host.

    Args:
        target: candidate redirect URL (typically the ``next`` query arg).
        host_url: base URL of this site; defaults to ``request.host_url``.

    Returns:
        bool: True only for http(s) URLs whose host equals this site's host.

    Browsers are more lenient than ``urllib.parse``. They treat ``\\`` like
    ``/``, treat ``///host`` and ``http:///host`` as pointing at ``host``,
    and ignore leading control characters. Each of those forms is rejected
    here, so it cannot smuggle in an off-site redirect.
    """
    import unicodedata
    from urllib.parse import urljoin, urlparse

    if not target:
        return False
    target = target.strip()
    if not target:
        return False
    if host_url is None:
        host_url = request.host_url
    ref_netloc = urlparse(host_url).netloc

    def _is_same_host(url):
        # Browsers read three or more leading slashes as an absolute URL.
        if url.startswith('///'):
            return False
        # Some browsers skip leading control characters, making the rest
        # scheme-relative.
        if unicodedata.category(url[0])[0] == 'C':
            return False
        try:
            own = urlparse(url)
            joined = urlparse(urljoin(host_url, url))
        except ValueError:  # e.g. malformed IPv6 netloc
            return False
        # "http:///evil.com": a scheme without a host is resolved by urljoin
        # against our host, but browsers go to evil.com.
        if own.scheme and not own.netloc:
            return False
        return joined.scheme in ('http', 'https') and joined.netloc == ref_netloc

    # Check both the raw value and the backslash-normalised form a browser sees.
    return _is_same_host(target) and _is_same_host(target.replace('\\', '/'))
def send_verification_email(user, token):
    """Email *user* an absolute link that verifies their address with *token*."""
    link = url_for('auth.verify_email', token=token, _external=True)
    message = dict(
        to=user.email,
        subject='验证您的邮箱',
        template='auth/email/verify_email.html',
        user=user,
        verification_url=link,
    )
    send_email(**message)
def send_password_reset_email(user, token):
    """Email *user* an absolute password-reset link carrying *token*."""
    link = url_for('auth.reset_password', token=token, _external=True)
    message = dict(
        to=user.email,
        subject='重置您的密码',
        template='auth/email/reset_password.html',
        user=user,
        reset_url=link,
    )
    send_email(**message)
7.2 角色权限系统
7.2.1 角色权限模型
# models/auth.py
from models.base import BaseModel, db
# Association table for the many-to-many User <-> Role relationship
# (the User model refers to it by name via secondary='user_roles').
# The composite primary key prevents duplicate (user, role) pairs.
user_roles = db.Table('user_roles',
    db.Column('user_id', db.Integer, db.ForeignKey('users.id'), primary_key=True),
    db.Column('role_id', db.Integer, db.ForeignKey('roles.id'), primary_key=True)
)
# Association table for the many-to-many Role <-> Permission relationship;
# the composite primary key prevents duplicate (role, permission) pairs.
role_permissions = db.Table('role_permissions',
    db.Column('role_id', db.Integer, db.ForeignKey('roles.id'), primary_key=True),
    db.Column('permission_id', db.Integer, db.ForeignKey('permissions.id'), primary_key=True)
)
class Role(BaseModel):
    """A named set of permissions that can be granted to users."""
    __tablename__ = 'roles'

    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.String(255))
    is_default = db.Column(db.Boolean, default=False)  # looked up by get_default_role()

    # Relationships
    permissions = db.relationship('Permission', secondary=role_permissions, backref='roles')

    def add_permission(self, permission):
        """Grant *permission* and commit; no-op if one with that name is already granted."""
        if self.has_permission(permission.name):
            return
        self.permissions.append(permission)
        db.session.commit()

    def remove_permission(self, permission):
        """Revoke *permission* and commit; no-op if no permission with that name is granted."""
        if not self.has_permission(permission.name):
            return
        self.permissions.remove(permission)
        db.session.commit()

    def has_permission(self, permission_name):
        """Return True if this role grants a permission named *permission_name*."""
        for granted in self.permissions:
            if granted.name == permission_name:
                return True
        return False

    def get_permissions_list(self):
        """Return the names of all permissions granted to this role."""
        return [granted.name for granted in self.permissions]

    @staticmethod
    def get_default_role():
        """Return the first role flagged ``is_default``, or None."""
        query = Role.query.filter_by(is_default=True)
        return query.first()

    def __repr__(self):
        return f'<Role {self.name}>'
class Permission(BaseModel):
    """A single named permission.

    PermissionManager names permissions "<resource>:<action>" (e.g.
    'post:edit') and fills ``resource`` / ``action`` from that name.
    """
    __tablename__ = 'permissions'
    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.String(255))
    resource = db.Column(db.String(50)) # resource type, e.g. 'post'
    action = db.Column(db.String(50)) # action type, e.g. 'edit'

    def __repr__(self):
        return f'<Permission {self.name}>'
class UserSession(BaseModel):
    """A tracked login session belonging to a user."""
    __tablename__ = 'user_sessions'

    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    session_id = db.Column(db.String(255), unique=True, nullable=False)
    ip_address = db.Column(db.String(45))
    user_agent = db.Column(db.Text)
    expires_at = db.Column(db.DateTime, nullable=False)
    is_active = db.Column(db.Boolean, default=True)  # cleared by revoke()

    # Relationships
    user = db.relationship('User', backref='sessions')

    def is_expired(self):
        """Return True if ``expires_at`` (naive UTC) is already in the past."""
        from datetime import datetime
        return datetime.utcnow() > self.expires_at

    def revoke(self):
        """Mark this session inactive and commit."""
        self.is_active = False
        db.session.commit()

    @classmethod
    def cleanup_expired(cls):
        """Delete every session whose ``expires_at`` has passed, then commit.

        Returns:
            int: number of sessions deleted.
        """
        from datetime import datetime
        # Issue one bulk DELETE instead of loading every expired row into
        # memory and deleting the objects one by one. Bulk deletes skip
        # ORM-level cascades/events; none are configured on this model as
        # visible here. synchronize_session=False leaves in-memory objects
        # untouched; the commit below expires them (default expire_on_commit).
        deleted = cls.query.filter(
            cls.expires_at < datetime.utcnow()
        ).delete(synchronize_session=False)
        db.session.commit()
        return deleted

    def __repr__(self):
        return f'<UserSession {self.session_id}>'
7.2.2 权限装饰器
# utils/decorators.py
from functools import wraps
from flask import abort, redirect, url_for, request, jsonify
from flask_login import current_user, login_required
def anonymous_required(f):
    """Restrict a view to anonymous visitors.

    Logged-in users are redirected to the home page instead of reaching *f*.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        if not current_user.is_authenticated:
            return f(*args, **kwargs)
        return redirect(url_for('main.index'))
    return guarded
def verified_required(f):
    """Require a logged-in user whose email address is verified.

    Unverified users get a 403 JSON error for JSON requests; all other
    requests are redirected to the resend-verification endpoint.
    """
    @wraps(f)
    @login_required
    def guarded(*args, **kwargs):
        if current_user.is_verified:
            return f(*args, **kwargs)
        if request.is_json:
            return jsonify({'error': '需要验证邮箱'}), 403
        return redirect(url_for('auth.resend_verification'))
    return guarded
def admin_required(f):
    """Require a logged-in user with the ``is_admin`` flag; otherwise abort with 403."""
    @wraps(f)
    @login_required
    def guarded(*args, **kwargs):
        is_admin = current_user.is_admin
        if not is_admin:
            abort(403)
        return f(*args, **kwargs)
    return guarded
def role_required(*role_names):
    """Decorator factory: require the logged-in user to hold any of *role_names*.

    Users holding none of them get a 403.
    """
    def decorator(f):
        @wraps(f)
        @login_required
        def guarded(*args, **kwargs):
            holds_any = any(map(current_user.has_role, role_names))
            if not holds_any:
                abort(403)
            return f(*args, **kwargs)
        return guarded
    return decorator
def permission_required(permission_name):
    """Decorator factory: require the logged-in user to hold *permission_name* (403 otherwise)."""
    def decorator(f):
        @wraps(f)
        @login_required
        def guarded(*args, **kwargs):
            allowed = current_user.has_permission(permission_name)
            if not allowed:
                abort(403)
            return f(*args, **kwargs)
        return guarded
    return decorator
def resource_permission_required(resource, action):
    """Decorator factory: require the ``"<resource>:<action>"`` permission.

    Example: ``resource_permission_required('post', 'edit')`` is the same as
    ``permission_required('post:edit')``.
    """
    return permission_required(f"{resource}:{action}")
def owner_or_permission_required(permission_name, get_owner_id=None):
    """Allow users holding *permission_name*, or the owner of the resource.

    Args:
        permission_name: permission that grants access regardless of ownership.
        get_owner_id: optional callable receiving the view's positional and
            keyword arguments and returning the owning user's id. Without it
            only the permission check applies.
    """
    def decorator(f):
        @wraps(f)
        @login_required
        def guarded(*args, **kwargs):
            allowed = current_user.has_permission(permission_name)
            if not allowed and get_owner_id:
                owner_id = get_owner_id(*args, **kwargs)
                allowed = current_user.id == owner_id
            if not allowed:
                abort(403)
            return f(*args, **kwargs)
        return guarded
    return decorator
# Usage examples (view bodies are placeholders; a real view must return a response)
@app.route('/admin/users')
@admin_required
def admin_users():
    """Admin-only user list."""
    pass

@app.route('/posts/create')
@permission_required('post:create')
def create_post():
    """Create a post; requires the 'post:create' permission."""
    pass

# The owner getter receives the view's arguments (here post_id) and returns the post author's id.
@app.route('/posts/<int:post_id>/edit')
@owner_or_permission_required('post:edit', lambda post_id: Post.query.get_or_404(post_id).author_id)
def edit_post(post_id):
    """Edit a post; allowed for its author or holders of 'post:edit'."""
    pass
7.2.3 权限管理
# utils/permission_manager.py
from models.auth import Role, Permission, user_roles
from models.user import User
from models.base import db  # every method below uses db.session
from flask import current_app


class PermissionManager:
    """Seed and manage roles and permissions.

    Holds the default permission catalogue and role definitions, plus helpers
    to create roles and permissions and to assign roles to users.
    """

    # Predefined permissions: "<resource>:<action>" name -> description
    DEFAULT_PERMISSIONS = {
        # User permissions
        'user:view': '查看用户',
        'user:edit': '编辑用户',
        'user:delete': '删除用户',
        'user:manage': '管理用户',
        # Post permissions
        'post:view': '查看文章',
        'post:create': '创建文章',
        'post:edit': '编辑文章',
        'post:delete': '删除文章',
        'post:publish': '发布文章',
        'post:manage': '管理文章',
        # Comment permissions
        'comment:view': '查看评论',
        'comment:create': '创建评论',
        'comment:edit': '编辑评论',
        'comment:delete': '删除评论',
        'comment:moderate': '审核评论',
        # System permissions
        'system:admin': '系统管理',
        'system:config': '系统配置',
        'system:backup': '系统备份',
    }

    # Predefined roles: name -> description, permission names, default flag
    DEFAULT_ROLES = {
        'user': {
            'description': '普通用户',
            'permissions': [
                'post:view', 'post:create',
                'comment:view', 'comment:create'
            ],
            'is_default': True
        },
        'author': {
            'description': '作者',
            'permissions': [
                'post:view', 'post:create', 'post:edit', 'post:delete',
                'comment:view', 'comment:create', 'comment:edit'
            ]
        },
        'editor': {
            'description': '编辑',
            'permissions': [
                'post:view', 'post:create', 'post:edit', 'post:delete', 'post:publish', 'post:manage',
                'comment:view', 'comment:create', 'comment:edit', 'comment:delete', 'comment:moderate',
                'user:view'
            ]
        },
        'admin': {
            'description': '管理员',
            'permissions': list(DEFAULT_PERMISSIONS.keys())
        }
    }

    @classmethod
    def init_permissions(cls):
        """Create every permission in DEFAULT_PERMISSIONS that does not exist yet."""
        for name, description in cls.DEFAULT_PERMISSIONS.items():
            permission = Permission.query.filter_by(name=name).first()
            if not permission:
                resource, action = name.split(':', 1)
                permission = Permission(
                    name=name,
                    description=description,
                    resource=resource,
                    action=action
                )
                db.session.add(permission)
        db.session.commit()
        current_app.logger.info('权限初始化完成')

    @classmethod
    def init_roles(cls):
        """Create the DEFAULT_ROLES (and their permissions) that are missing.

        Existing roles are kept; any listed permission they lack is added.
        """
        cls.init_permissions()  # roles reference permissions by name
        for role_name, role_config in cls.DEFAULT_ROLES.items():
            role = Role.query.filter_by(name=role_name).first()
            if not role:
                role = Role(
                    name=role_name,
                    description=role_config['description'],
                    is_default=role_config.get('is_default', False)
                )
                db.session.add(role)
                db.session.flush()  # assign the primary key before linking permissions
            # Add permissions
            for perm_name in role_config['permissions']:
                permission = Permission.query.filter_by(name=perm_name).first()
                if permission and not role.has_permission(perm_name):
                    role.permissions.append(permission)
        db.session.commit()
        current_app.logger.info('角色初始化完成')

    @classmethod
    def assign_default_role(cls, user):
        """Give *user* the default role (if one exists and is not already held)."""
        default_role = Role.get_default_role()
        if default_role and not user.has_role(default_role.name):
            user.roles.append(default_role)
            db.session.commit()

    @classmethod
    def create_permission(cls, name, description, resource=None, action=None):
        """Create and commit a new Permission.

        A missing *resource* or *action* is derived from a
        "<resource>:<action>" style *name*; values passed explicitly win.
        For a *name* without ':' missing parts stay as passed.
        """
        if not resource or not action:
            name_resource, sep, name_action = name.partition(':')
            if sep:
                resource = resource or name_resource
                action = action or name_action
        permission = Permission(
            name=name,
            description=description,
            resource=resource,
            action=action
        )
        db.session.add(permission)
        db.session.commit()
        return permission

    @classmethod
    def create_role(cls, name, description, permissions=None):
        """Create and commit a role granting the named *permissions* that exist."""
        role = Role(
            name=name,
            description=description
        )
        db.session.add(role)
        db.session.flush()
        if permissions:
            for perm_name in permissions:
                permission = Permission.query.filter_by(name=perm_name).first()
                if permission:
                    role.permissions.append(permission)
        db.session.commit()
        return role

    @classmethod
    def assign_role(cls, user, role_name):
        """Give *user* the role *role_name*; return True if it was newly assigned."""
        role = Role.query.filter_by(name=role_name).first()
        if role and not user.has_role(role_name):
            user.roles.append(role)
            db.session.commit()
            return True
        return False

    @classmethod
    def remove_role(cls, user, role_name):
        """Remove role *role_name* from *user*; return True if it was removed."""
        role = Role.query.filter_by(name=role_name).first()
        if role and user.has_role(role_name):
            user.roles.remove(role)
            db.session.commit()
            return True
        return False

    @classmethod
    def get_user_permissions(cls, user):
        """Return the distinct permission names granted by all of *user*'s roles."""
        permissions = set()
        for role in user.roles:
            permissions.update(role.get_permissions_list())
        return list(permissions)

    @classmethod
    def check_permission(cls, user, permission_name):
        """Return True if *user* holds *permission_name*."""
        return user.has_permission(permission_name)

    @classmethod
    def get_role_hierarchy(cls):
        """Map each role to the roles below it that it may manage."""
        return {
            'admin': ['editor', 'author', 'user'],
            'editor': ['author', 'user'],
            'author': ['user'],
            'user': []
        }

    @classmethod
    def user_can_manage_role(cls, user, target_role_name):
        """Return True if *user* is an admin or holds a role ranked above *target_role_name*."""
        if user.is_admin:
            return True
        hierarchy = cls.get_role_hierarchy()
        # Iterate role names directly; the old local named `user_roles`
        # shadowed the association table imported above.
        return any(
            target_role_name in hierarchy.get(role.name, [])
            for role in user.roles
        )
7.3 JWT令牌认证
7.3.1 JWT配置
# utils/jwt_manager.py
import jwt
from datetime import datetime, timedelta
from flask import current_app, request, jsonify
from functools import wraps
from models.user import User
class JWTManager:
    """Issue and validate HS256-signed JWT access and refresh tokens.

    Tokens are signed with the app's ``SECRET_KEY``. The ``type`` claim marks
    each token as 'access' or 'refresh' so one kind cannot stand in for the
    other.
    """

    @staticmethod
    def _encode(claims):
        """Sign *claims* with the application secret using HS256."""
        return jwt.encode(claims, current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def generate_token(user, expires_delta=None):
        """Return a signed access token for *user* (valid 24 hours unless *expires_delta* is given)."""
        lifetime = timedelta(hours=24) if expires_delta is None else expires_delta
        claims = {
            'user_id': user.id,
            'username': user.username,
            'email': user.email,
            'roles': [r.name for r in user.roles],
            'exp': datetime.utcnow() + lifetime,
            'iat': datetime.utcnow(),
            'type': 'access',
        }
        return JWTManager._encode(claims)

    @staticmethod
    def generate_refresh_token(user, expires_delta=None):
        """Return a signed refresh token for *user* (valid 30 days unless *expires_delta* is given)."""
        lifetime = timedelta(days=30) if expires_delta is None else expires_delta
        claims = {
            'user_id': user.id,
            'exp': datetime.utcnow() + lifetime,
            'iat': datetime.utcnow(),
            'type': 'refresh',
        }
        return JWTManager._encode(claims)

    @staticmethod
    def decode_token(token):
        """Decode and verify *token*.

        Returns the claims dict, or ``{'error': <message>}`` when the token is
        expired or otherwise invalid.
        """
        secret = current_app.config['SECRET_KEY']
        try:
            return jwt.decode(token, secret, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return {'error': 'Token已过期'}
        except jwt.InvalidTokenError:
            return {'error': 'Token无效'}

    @staticmethod
    def verify_token(token, token_type='access'):
        """Validate *token* as a *token_type* token and load its active user.

        Returns:
            tuple: (user, None) on success, (None, error_message) otherwise.
        """
        claims = JWTManager.decode_token(token)
        if 'error' in claims:
            return None, claims['error']
        if claims.get('type') != token_type:
            return None, 'Token类型错误'
        user = User.query.get(claims['user_id'])
        if user and user.is_active:
            return user, None
        return None, '用户不存在或已禁用'

    @staticmethod
    def refresh_access_token(refresh_token):
        """Exchange a valid refresh token for a new access token.

        Returns:
            tuple: (access_token, None) on success, (None, error_message) otherwise.
        """
        user, error = JWTManager.verify_token(refresh_token, 'refresh')
        if error:
            return None, error
        return JWTManager.generate_token(user), None
def _parse_bearer_token(auth_header):
    """Extract the token from an ``Authorization`` header value.

    Returns:
        tuple: (token, None) for ``Bearer <token>``; (None, None) when the
        header is absent or empty; (None, 'Token格式错误') when it is present
        but uses another scheme or carries no token.
    """
    if not auth_header:
        return None, None
    scheme, _, credentials = auth_header.strip().partition(' ')
    credentials = credentials.strip()
    # Accept only the Bearer scheme (case-insensitive); previously any
    # "<word> <value>" header, e.g. "Basic ...", was treated as a JWT.
    if scheme.lower() != 'bearer' or not credentials:
        return None, 'Token格式错误'
    return credentials, None


def jwt_required(f):
    """Require a valid access token in an ``Authorization: Bearer`` header.

    On success the resolved user is stored as ``request.current_user`` before
    the view runs. Missing, malformed or invalid tokens get a 401 JSON error.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        token, header_error = _parse_bearer_token(request.headers.get('Authorization'))
        if header_error:
            return jsonify({'error': header_error}), 401
        if not token:
            return jsonify({'error': '缺少认证令牌'}), 401
        # Validate the token
        user, error = JWTManager.verify_token(token)
        if error:
            return jsonify({'error': error}), 401
        # Expose the user to the view through the request object
        request.current_user = user
        return f(*args, **kwargs)
    return decorated_function
def jwt_permission_required(permission_name):
    """Decorator factory: require a valid JWT whose user holds *permission_name*.

    Authentication failures are handled by jwt_required; a missing
    permission yields a 403 JSON error.
    """
    def decorator(f):
        @wraps(f)
        @jwt_required
        def guarded(*args, **kwargs):
            user = request.current_user
            if user.has_permission(permission_name):
                return f(*args, **kwargs)
            return jsonify({'error': '权限不足'}), 403
        return guarded
    return decorator
7.3.2 API认证视图
# views/api_auth.py
from flask import Blueprint, request, jsonify
from utils.jwt_manager import JWTManager, jwt_required
from models.user import User
from forms.auth import LoginForm
api_auth_bp = Blueprint('api_auth', __name__, url_prefix='/api/auth')


@api_auth_bp.route('/login', methods=['POST'])
def api_login():
    """API login: exchange username/email and password for JWT tokens.

    Expects a JSON object with ``username`` (username or email) and
    ``password``. Returns the access and refresh tokens plus ``user.to_dict()``.
    Errors: 400 missing fields, 401 bad credentials, 423 locked, 403 disabled.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising an HTTP error, so the client always gets the JSON 400 below.
    data = request.get_json(silent=True)
    # Valid JSON that is not an object (e.g. a list) has no .get(); treat it as missing.
    if not isinstance(data, dict) or not data.get('username') or not data.get('password'):
        return jsonify({'error': '用户名和密码不能为空'}), 400

    user = User.query.filter(
        (User.username == data['username']) |
        (User.email == data['username'])
    ).first()
    if not user:
        return jsonify({'error': '用户名或密码错误'}), 401

    # Check account status
    if user.is_locked():
        return jsonify({'error': '账户已被锁定'}), 423
    if not user.is_active:
        return jsonify({'error': '账户已被禁用'}), 403

    # Verify the password
    if not user.check_password(data['password']):
        user.record_login_attempt(success=False)
        return jsonify({'error': '用户名或密码错误'}), 401

    # Record the successful login
    user.record_login_attempt(success=True, ip_address=request.remote_addr)

    # Issue the tokens
    access_token = JWTManager.generate_token(user)
    refresh_token = JWTManager.generate_refresh_token(user)
    return jsonify({
        'access_token': access_token,
        'refresh_token': refresh_token,
        'user': user.to_dict()
    })
@api_auth_bp.route('/refresh', methods=['POST'])
def api_refresh():
    """Exchange a refresh token (JSON field ``refresh_token``) for a new access token.

    Errors: 400 when the token is missing, 401 when it is invalid.
    """
    data = request.get_json(silent=True)
    # Without this guard an empty or non-object body makes data.get() raise
    # AttributeError and the endpoint answers 500 instead of 400.
    refresh_token = data.get('refresh_token') if isinstance(data, dict) else None
    if not refresh_token:
        return jsonify({'error': '缺少刷新令牌'}), 400

    new_access_token, error = JWTManager.refresh_access_token(refresh_token)
    if error:
        return jsonify({'error': error}), 401
    return jsonify({'access_token': new_access_token})
@api_auth_bp.route('/me', methods=['GET'])
@jwt_required
def api_me():
    """Return the authenticated user's profile and all permissions from their roles."""
    user = request.current_user
    # User (models/user.py) defines no get_permissions_list(); that helper
    # lives on Role. Collect the distinct names across all of the user's
    # roles, sorted for a stable response.
    permissions = sorted({
        name
        for role in user.roles
        for name in role.get_permissions_list()
    })
    return jsonify({
        'user': user.to_dict(),
        'permissions': permissions
    })
@api_auth_bp.route('/logout', methods=['POST'])
@jwt_required
def api_logout():
    """Acknowledge an API logout.

    Tokens are not revoked server-side here; the client is expected to
    discard them. A token blacklist could be added at this point.
    """
    body = {'message': '登出成功'}
    return jsonify(body)
7.4 OAuth第三方登录
7.4.1 OAuth配置
# utils/oauth.py
from authlib.integrations.flask_client import OAuth
from flask import current_app, url_for, session, request
import requests
oauth = OAuth()


def init_oauth(app):
    """Bind the shared Authlib registry to *app* and register the OAuth providers.

    Credentials come from the config keys GITHUB_CLIENT_ID/GITHUB_CLIENT_SECRET,
    GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET and WECHAT_APP_ID/WECHAT_APP_SECRET.
    """
    oauth.init_app(app)

    # GitHub is a plain OAuth 2.0 provider: it does not serve an
    # authorization-server metadata document at the previously configured
    # URL, so its endpoints are given explicitly. api_base_url lets
    # GitHubOAuth call client.get('user') / client.get('user/emails').
    oauth.register(
        name='github',
        client_id=app.config.get('GITHUB_CLIENT_ID'),
        client_secret=app.config.get('GITHUB_CLIENT_SECRET'),
        access_token_url='https://github.com/login/oauth/access_token',
        authorize_url='https://github.com/login/oauth/authorize',
        api_base_url='https://api.github.com/',
        client_kwargs={
            'scope': 'user:email'
        }
    )

    # Google: OpenID Connect discovery document. The path segment is
    # "openid-configuration" (hyphen). api_base_url makes the relative
    # client.get('userinfo') in GoogleOAuth hit the OIDC userinfo endpoint.
    oauth.register(
        name='google',
        client_id=app.config.get('GOOGLE_CLIENT_ID'),
        client_secret=app.config.get('GOOGLE_CLIENT_SECRET'),
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        api_base_url='https://openidconnect.googleapis.com/v1/',
        client_kwargs={
            'scope': 'openid email profile'
        }
    )

    # WeChat OAuth
    # NOTE(review): WeChat expects `appid`/`secret` rather than the standard
    # client_id/client_secret parameters, so the token request likely needs an
    # Authlib compliance fix — verify. Also confirm the authorize URL:
    # oauth2/authorize appears to target pages opened inside WeChat, while
    # website QR-code login uses connect/qrconnect with scope snsapi_login.
    oauth.register(
        name='wechat',
        client_id=app.config.get('WECHAT_APP_ID'),
        client_secret=app.config.get('WECHAT_APP_SECRET'),
        authorize_url='https://open.weixin.qq.com/connect/oauth2/authorize',
        access_token_url='https://api.weixin.qq.com/sns/oauth2/access_token',
        client_kwargs={
            'scope': 'snsapi_userinfo'
        }
    )
class OAuthProvider:
    """Base wrapper around an Authlib client registered on the shared ``oauth`` registry.

    Subclasses implement get_user_info() to normalise the provider's profile data.
    """

    def __init__(self, name):
        self.name = name
        # The client registered under *name* by init_oauth().
        self.client = oauth.create_client(name)

    def get_authorize_url(self, redirect_uri):
        """Build the provider's authorization URL and remember its state.

        Returns:
            tuple: (authorization_url, state), as the OAuth views unpack it.
        """
        # Authlib's create_authorization_url() returns a dict (url, state and
        # possibly code_verifier / nonce), not a tuple; unpacking it directly
        # yields the dict's keys. The data is also saved to the session, as
        # authorize_redirect() does, so authorize_access_token() can find the
        # state on the callback.
        rv = self.client.create_authorization_url(redirect_uri)
        self.client.save_authorize_data(redirect_uri=redirect_uri, **rv)
        return rv['url'], rv['state']

    def get_access_token(self, code, redirect_uri):
        """Exchange the authorization code of the current callback request for a token.

        *code* and *redirect_uri* are accepted for interface compatibility:
        Authlib reads the code and state from the current request and
        restores the redirect_uri saved by get_authorize_url(). Passing
        redirect_uri again would duplicate that keyword.
        """
        return self.client.authorize_access_token()

    def get_user_info(self, token):
        """Return the normalised user profile for *token*; subclasses must implement."""
        raise NotImplementedError
class GitHubOAuth(OAuthProvider):
    """GitHub login: profile from GET /user, email from GET /user/emails."""

    def __init__(self):
        super().__init__('github')

    def get_user_info(self, token):
        """Fetch and normalise the GitHub profile for *token*.

        The email is the primary address that GitHub reports as verified.
        If none exists, the profile's public email is used.

        Raises:
            requests.HTTPError: if the /user request fails.
        """
        resp = self.client.get('user', token=token)
        resp.raise_for_status()
        user_data = resp.json()

        # Fetch the email addresses. On an error response the body is a dict,
        # not a list, so fall back to no emails instead of crashing.
        emails_resp = self.client.get('user/emails', token=token)
        emails = emails_resp.json() if emails_resp.ok else []
        if not isinstance(emails, list):
            emails = []
        # Require verified: handle_oauth_login is not visible here; if it
        # matches existing users by email, an unverified address could be
        # used to take over someone else's account.
        primary_email = next(
            (entry.get('email') for entry in emails
             if isinstance(entry, dict) and entry.get('primary') and entry.get('verified')),
            None,
        )

        return {
            'provider': 'github',
            'provider_id': str(user_data['id']),
            'username': user_data['login'],
            'email': primary_email or user_data.get('email'),
            # GitHub returns "name": null when unset; keep the '' default then.
            'name': user_data.get('name') or '',
            'avatar': user_data.get('avatar_url'),
            'profile_url': user_data.get('html_url')
        }
class GoogleOAuth(OAuthProvider):
    """Google login using the OpenID Connect userinfo endpoint."""

    def __init__(self):
        super().__init__('google')

    def get_user_info(self, token):
        """Fetch and normalise the Google profile for *token*.

        The username is the local part of the email address.
        """
        profile = self.client.get('userinfo', token=token).json()
        return dict(
            provider='google',
            provider_id=profile['sub'],
            username=profile.get('email', '').split('@')[0],
            email=profile['email'],
            name=profile.get('name', ''),
            avatar=profile.get('picture'),
            verified=profile.get('email_verified', False),
        )
class WeChatOAuth(OAuthProvider):
    """WeChat login; the profile comes from the sns/userinfo API."""

    # Seconds to wait for the WeChat API before giving up.
    USERINFO_TIMEOUT = 10

    def __init__(self):
        super().__init__('wechat')

    def get_user_info(self, token):
        """Fetch and normalise the WeChat profile for *token*.

        *token* must provide 'access_token' and 'openid'.

        Raises:
            ValueError: if WeChat answers with an error payload (``errcode``).
            requests.RequestException: on network failure, timeout or HTTP error.
        """
        user_info_url = 'https://api.weixin.qq.com/sns/userinfo'
        params = {
            'access_token': token['access_token'],
            'openid': token['openid'],
            'lang': 'zh_CN'
        }
        # Without a timeout, a stalled connection blocks the worker indefinitely.
        resp = requests.get(user_info_url, params=params, timeout=self.USERINFO_TIMEOUT)
        resp.raise_for_status()
        # NOTE(review): WeChat is reported to omit the charset in its
        # Content-Type, which can make requests decode the body with the wrong
        # codec and garble Chinese nicknames. The payload is UTF-8 JSON, so
        # force that.
        resp.encoding = 'utf-8'
        user_data = resp.json()
        # WeChat reports errors as HTTP 200 with an errcode/errmsg body.
        if user_data.get('errcode'):
            raise ValueError(
                f"WeChat userinfo error {user_data.get('errcode')}: {user_data.get('errmsg')}"
            )

        return {
            'provider': 'wechat',
            'provider_id': user_data['openid'],
            'username': user_data.get('nickname', ''),
            'name': user_data.get('nickname', ''),
            'avatar': user_data.get('headimgurl'),
            'gender': user_data.get('sex'),
            'city': user_data.get('city'),
            'province': user_data.get('province'),
            'country': user_data.get('country')
        }
7.4.2 OAuth用户模型
# models/oauth.py
from models.base import BaseModel, db
from datetime import datetime
class OAuthAccount(BaseModel):
    """Link between a local user and an account at a third-party OAuth provider.

    A unique constraint allows at most one row per (provider, provider_id).
    """
    __tablename__ = 'oauth_accounts'

    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    provider = db.Column(db.String(50), nullable=False)  # provider key: github, google, wechat, ...
    provider_id = db.Column(db.String(100), nullable=False)  # user id on the provider's side
    provider_username = db.Column(db.String(100))  # username on the provider's side
    # NOTE(review): provider tokens are stored in plain text; consider encrypting at rest.
    access_token = db.Column(db.Text)
    refresh_token = db.Column(db.Text)
    expires_at = db.Column(db.DateTime)  # access-token expiry, compared against utcnow()

    # Extra profile data
    profile_url = db.Column(db.String(255))
    avatar_url = db.Column(db.String(255))
    extra_data = db.Column(db.JSON)  # remaining provider data

    # Relationships
    user = db.relationship('User', backref='oauth_accounts')

    __table_args__ = (
        db.UniqueConstraint('provider', 'provider_id', name='unique_provider_account'),
    )

    def is_token_expired(self):
        """Return True when an expiry is recorded and already past.

        Accounts without ``expires_at`` never count as expired.
        """
        if self.expires_at:
            return datetime.utcnow() > self.expires_at
        return False

    def update_token(self, access_token, refresh_token=None, expires_at=None):
        """Store a new access token and commit.

        *refresh_token* and *expires_at* overwrite the stored values only
        when a truthy value is passed.
        """
        self.access_token = access_token
        for attr, value in (('refresh_token', refresh_token), ('expires_at', expires_at)):
            if value:
                setattr(self, attr, value)
        db.session.commit()

    @classmethod
    def get_by_provider(cls, provider, provider_id):
        """Return the account linked to (*provider*, *provider_id*), or None."""
        query = cls.query.filter_by(provider=provider, provider_id=provider_id)
        return query.first()

    def __repr__(self):
        return f'<OAuthAccount {self.provider}:{self.provider_id}>'
7.4.3 OAuth认证视图
# views/oauth.py
from flask import Blueprint, redirect, url_for, session, flash, request, current_app
from flask_login import login_user, current_user, login_required
from utils.oauth import GitHubOAuth, GoogleOAuth, WeChatOAuth
from models.user import User
from models.oauth import OAuthAccount
from utils.permission_manager import PermissionManager
oauth_bp = Blueprint('oauth', __name__, url_prefix='/oauth')

# Maps the <provider> URL segment to its OAuthProvider subclass; the views
# instantiate a fresh provider object on every request.
OAUTH_PROVIDERS = {
    'github': GitHubOAuth,
    'google': GoogleOAuth,
    'wechat': WeChatOAuth
}
@oauth_bp.route('/login/<provider>')
def oauth_login(provider):
    """Start an OAuth login by redirecting to *provider*'s authorization page.

    The state value is stored in the session under '<provider>_oauth_state'
    for oauth_callback to check.
    """
    provider_cls = OAUTH_PROVIDERS.get(provider)
    if provider_cls is None:
        flash('不支持的登录方式', 'error')
        return redirect(url_for('auth.login'))

    client = provider_cls()
    callback_url = url_for('oauth.oauth_callback', provider=provider, _external=True)
    auth_url, state = client.get_authorize_url(callback_url)
    session[f'{provider}_oauth_state'] = state
    return redirect(auth_url)
@oauth_bp.route('/callback/<provider>')
def oauth_callback(provider):
    """Finish an OAuth login: check state, fetch the token and profile, log the user in."""
    if provider not in OAUTH_PROVIDERS:
        flash('不支持的登录方式', 'error')
        return redirect(url_for('auth.login'))

    # Pop the expected state so it is single-use, and reject a missing state.
    # Otherwise a callback without ?state= and without a stored state would
    # pass (None == None), defeating the CSRF protection.
    expected_state = session.pop(f'{provider}_oauth_state', None)
    state = request.args.get('state')
    if not state or state != expected_state:
        flash('OAuth状态验证失败', 'error')
        return redirect(url_for('auth.login'))

    oauth_provider = OAUTH_PROVIDERS[provider]()
    redirect_uri = url_for('oauth.oauth_callback', provider=provider, _external=True)
    try:
        # Obtain the access token, then the provider's user profile
        token = oauth_provider.get_access_token(
            request.args.get('code'),
            redirect_uri
        )
        user_info = oauth_provider.get_user_info(token)
        # Find or create the local user for this OAuth identity
        user = handle_oauth_login(user_info, token)
    except Exception as e:
        # Top-level boundary for provider/network failures: log with traceback.
        current_app.logger.exception(f'OAuth登录错误: {e}')
        flash('OAuth登录过程中发生错误', 'error')
        return redirect(url_for('auth.login'))

    if not user:
        flash('OAuth登录失败', 'error')
        return redirect(url_for('auth.login'))

    login_user(user, remember=True)
    flash(f'欢迎使用{provider.title()}登录!', 'success')
    # NOTE(review): next_page is redirected to unchecked; make sure whatever
    # stores it validates it (e.g. with an is_safe_url-style check).
    next_page = session.pop('next_page', None)
    if next_page:
        return redirect(next_page)
    return redirect(url_for('main.index'))
@oauth_bp.route('/bind/<provider>')
@login_required
def oauth_bind(provider):
    """Start linking a third-party account to the logged-in user.

    The state value is stored under '<provider>_bind_state' for
    oauth_bind_callback to check.
    """
    provider_cls = OAUTH_PROVIDERS.get(provider)
    if provider_cls is None:
        flash('不支持的绑定方式', 'error')
        return redirect(url_for('user.profile'))

    client = provider_cls()
    callback_url = url_for('oauth.oauth_bind_callback', provider=provider, _external=True)
    auth_url, state = client.get_authorize_url(callback_url)
    session[f'{provider}_bind_state'] = state
    return redirect(auth_url)
@oauth_bp.route('/bind-callback/<provider>')
@login_required
def oauth_bind_callback(provider):
    """Finish linking a third-party account to the logged-in user.

    Refuses accounts already linked to another user. Stores the provider
    tokens (and their expiry when the provider reports ``expires_in``).
    """
    # Neither `datetime` nor `db` is imported at the top of this module;
    # without these, computing the expiry and committing raise NameError,
    # which surfaced only as a generic bind error.
    from datetime import datetime, timedelta
    from models.base import db

    if provider not in OAUTH_PROVIDERS:
        flash('不支持的绑定方式', 'error')
        return redirect(url_for('user.profile'))

    # Single-use state; a missing state must not match a missing session value.
    expected_state = session.pop(f'{provider}_bind_state', None)
    state = request.args.get('state')
    if not state or state != expected_state:
        flash('OAuth状态验证失败', 'error')
        return redirect(url_for('user.profile'))

    oauth_provider = OAUTH_PROVIDERS[provider]()
    redirect_uri = url_for('oauth.oauth_bind_callback', provider=provider, _external=True)
    try:
        # Obtain the access token, then the provider's user profile
        token = oauth_provider.get_access_token(
            request.args.get('code'),
            redirect_uri
        )
        user_info = oauth_provider.get_user_info(token)

        # Is this provider account already linked?
        existing_oauth = OAuthAccount.get_by_provider(
            user_info['provider'],
            user_info['provider_id']
        )
        if existing_oauth and existing_oauth.user_id != current_user.id:
            flash(f'该{provider.title()}账户已绑定其他用户', 'error')
            return redirect(url_for('user.profile'))

        if existing_oauth:
            flash(f'{provider.title()}账户已经绑定', 'info')
        else:
            oauth_account = OAuthAccount(
                user_id=current_user.id,
                provider=user_info['provider'],
                provider_id=user_info['provider_id'],
                provider_username=user_info.get('username'),
                access_token=token.get('access_token'),
                refresh_token=token.get('refresh_token'),
                profile_url=user_info.get('profile_url'),
                avatar_url=user_info.get('avatar'),
                extra_data=user_info
            )
            if 'expires_in' in token:
                oauth_account.expires_at = (
                    datetime.utcnow() + timedelta(seconds=int(token['expires_in']))
                )
            db.session.add(oauth_account)
            db.session.commit()
            flash(f'{provider.title()}账户绑定成功!', 'success')
    except Exception as e:
        # Roll back so a failed commit (e.g. the unique constraint losing a
        # race) does not leave the session unusable for later requests.
        db.session.rollback()
        current_app.logger.exception(f'OAuth绑定错误: {e}')
        flash('OAuth绑定过程中发生错误', 'error')
    return redirect(url_for('user.profile'))
@oauth_bp.route('/unbind/<provider>', methods=['POST'])
@login_required
def oauth_unbind(provider):
    """Delete the current user's link to *provider* (POST only).

    Always redirects to ``user.profile``, flashing success or "not found".

    NOTE(review): users created through OAuth get a random password they never
    see; removing their last linked provider may lock them out. Consider
    refusing the unbind in that case.
    """
    linked = OAuthAccount.query.filter_by(
        user_id=current_user.id,
        provider=provider,
    ).first()

    if not linked:
        flash(f'未找到{provider.title()}绑定账户', 'error')
        return redirect(url_for('user.profile'))

    db.session.delete(linked)
    db.session.commit()
    flash(f'{provider.title()}账户解绑成功!', 'success')
    return redirect(url_for('user.profile'))
def _split_display_name(name, max_length=50):
    """Split a provider display name into ``(first_name, last_name)``.

    The first whitespace-separated word becomes the first name and the rest
    becomes the last name. Surrounding and repeated whitespace is ignored.
    Both parts are cut to *max_length*, which matches the ``String(50)``
    ``first_name``/``last_name`` columns on ``User``. Empty or None input
    yields ``('', '')``.
    """
    if not name:
        return '', ''
    parts = str(name).split(None, 1)
    if not parts:
        return '', ''
    first = parts[0]
    last = parts[1].strip() if len(parts) > 1 else ''
    return first[:max_length], last[:max_length]


def _build_oauth_account(user_id, user_info, token):
    """Create (without adding to the DB session) an OAuthAccount for *user_id*.

    Copies the provider identity from *user_info* and the tokens from *token*.
    When the token carries ``expires_in`` (seconds), it is converted into an
    absolute ``expires_at``.
    """
    from datetime import datetime, timedelta

    oauth_account = OAuthAccount(
        user_id=user_id,
        provider=user_info['provider'],
        provider_id=user_info['provider_id'],
        provider_username=user_info.get('username'),
        access_token=token.get('access_token'),
        refresh_token=token.get('refresh_token'),
        profile_url=user_info.get('profile_url'),
        avatar_url=user_info.get('avatar'),
        extra_data=user_info,
    )
    expires_in = token.get('expires_in')
    if expires_in is not None:
        # int() accepts both an int and a numeric string.
        oauth_account.expires_at = datetime.utcnow() + timedelta(seconds=int(expires_in))
    return oauth_account


def handle_oauth_login(user_info, token):
    """Resolve the local user for an OAuth login, linking or creating as needed.

    Resolution order:

    1. If an OAuthAccount is already linked to this provider identity, refresh
       its tokens, backfill the user's avatar if empty, and return that user.
    2. If a user's e-mail equals the provider e-mail, link a new OAuthAccount
       to them. This is refused when the provider explicitly reports
       ``verified`` as False.
    3. Otherwise, create a new User (random password, default role) plus its
       OAuthAccount.

    Args:
        user_info: Provider profile dict. It must contain ``provider`` and
            ``provider_id``; ``email``, ``username``, ``name``, ``avatar``,
            ``profile_url`` and ``verified`` are optional.
        token: Provider token dict (``access_token``, optional
            ``refresh_token`` / ``expires_in``).

    Returns:
        The User to log in, or None when linking was refused.

    Raises:
        Any DB error is re-raised after rolling back the session.
    """
    import secrets

    # Look for an existing OAuth account for this provider identity.
    oauth_account = OAuthAccount.get_by_provider(
        user_info['provider'],
        user_info['provider_id'],
    )
    try:
        if oauth_account:
            # Known identity: refresh tokens and backfill profile data.
            oauth_account.update_token(
                token.get('access_token'),
                token.get('refresh_token'),
            )
            user = oauth_account.user
            if user_info.get('avatar') and not user.avatar:
                user.avatar = user_info['avatar']
            db.session.commit()
            return user

        email = user_info.get('email')
        existing_user = User.query.filter_by(email=email).first() if email else None
        if existing_user:
            # Linking by e-mail trusts the provider's claim about that address.
            # Refuse when the provider says it is unverified; otherwise anyone
            # could add the address at the provider and take over this account.
            if user_info.get('verified') is False:
                current_app.logger.warning(
                    'Refusing to link %s account to existing user: e-mail not verified',
                    user_info['provider'],
                )
                return None
            db.session.add(_build_oauth_account(existing_user.id, user_info, token))
            db.session.commit()
            return existing_user

        # Brand-new user.
        first_name, last_name = _split_display_name(user_info.get('name'))
        user = User(
            # `or` also covers a present-but-None username; str() normalises
            # numeric provider ids for the string username column.
            username=generate_unique_username(
                user_info.get('username') or str(user_info['provider_id'])
            ),
            # NOTE(review): User.email is unique and non-nullable; a provider that
            # returns no e-mail yields '', which collides for a second such user.
            email=email or '',
            first_name=first_name,
            last_name=last_name,
            avatar=user_info.get('avatar'),
            is_verified=user_info.get('verified', True),  # OAuth users default to verified
        )
        # Random password: OAuth users do not log in with a password.
        user.set_password(secrets.token_urlsafe(32))
        db.session.add(user)
        db.session.flush()  # assigns user.id for the role and OAuthAccount rows
        PermissionManager.assign_default_role(user)
        db.session.add(_build_oauth_account(user.id, user_info, token))
        db.session.commit()
        return user
    except Exception:
        # Never leave a half-written user/link in the session.
        db.session.rollback()
        raise
def generate_unique_username(base_username, exists=None, max_length=80):
    """Return a username derived from *base_username* that is not yet taken.

    Tries *base_username* first, then appends 1, 2, 3, ... until the
    candidate is free. The base is truncated so that base + counter fits in
    *max_length* (80 matches the ``User.username`` column). A None or blank
    base falls back to ``'user'``; non-string bases (e.g. numeric provider
    ids) are converted with ``str()``.

    Args:
        base_username: Preferred username (any object; str() is applied).
        exists: Optional ``callable(name) -> bool`` used to test availability.
            Defaults to a ``User`` table lookup.
        max_length: Maximum length of the returned username.

    Returns:
        A username string of at most *max_length* characters.
    """
    if exists is None:
        def exists(name):
            return User.query.filter_by(username=name).first() is not None

    base = str(base_username).strip() if base_username is not None else ''
    base = base or 'user'

    candidate = base[:max_length]
    counter = 1
    while exists(candidate):
        suffix = str(counter)
        candidate = f"{base[:max_length - len(suffix)]}{suffix}"
        counter += 1
    return candidate
7.5 认证表单
7.5.1 登录注册表单
# forms/auth.py
from flask_wtf import FlaskForm
from wtforms import StringField, PasswordField, BooleanField, SubmitField, TextAreaField
from wtforms.validators import DataRequired, Email, Length, EqualTo, ValidationError
from models.user import User
import re
class LoginForm(FlaskForm):
    """Sign-in form.

    The ``username`` field is labelled for either a username or an e-mail;
    resolving which one was entered is left to the login view (not shown).
    """

    username = StringField(
        '用户名/邮箱',
        validators=[DataRequired(message='请输入用户名或邮箱')],
    )
    password = PasswordField(
        '密码',
        validators=[DataRequired(message='请输入密码')],
    )
    remember_me = BooleanField('记住我')
    submit = SubmitField('登录')
class RegisterForm(FlaskForm):
    """Registration form with username/e-mail uniqueness and password-strength checks."""

    username = StringField('用户名', validators=[
        DataRequired(message='请输入用户名'),
        Length(min=3, max=20, message='用户名长度必须在3-20个字符之间'),
    ])
    email = StringField('邮箱', validators=[
        DataRequired(message='请输入邮箱'),
        Email(message='请输入有效的邮箱地址'),
    ])
    first_name = StringField('名字', validators=[
        Length(max=50, message='名字长度不能超过50个字符'),
    ])
    last_name = StringField('姓氏', validators=[
        Length(max=50, message='姓氏长度不能超过50个字符'),
    ])
    password = PasswordField('密码', validators=[
        DataRequired(message='请输入密码'),
        Length(min=8, message='密码长度至少8个字符'),
    ])
    password2 = PasswordField('确认密码', validators=[
        DataRequired(message='请确认密码'),
        EqualTo('password', message='两次输入的密码不一致'),
    ])
    submit = SubmitField('注册')

    def validate_username(self, username):
        """Allow only ASCII letters, digits and '_' and require the name to be unused.

        ``re.fullmatch`` is used because ``re.match(r'^...$')`` also accepts a
        trailing newline ("alice\\n"), since ``$`` matches just before it.
        """
        if not re.fullmatch(r'[a-zA-Z0-9_]+', username.data):
            raise ValidationError('用户名只能包含字母、数字和下划线')
        # Username must not already be registered.
        if User.query.filter_by(username=username.data).first():
            raise ValidationError('该用户名已被使用')

    def validate_email(self, email):
        """Reject an e-mail address that is already registered."""
        if User.query.filter_by(email=email.data).first():
            raise ValidationError('该邮箱已被注册')

    def validate_password(self, password):
        """Require at least one digit, one ASCII letter and one special character.

        Checks run in that order; only the first failure is reported.
        """
        password_str = password.data
        if not re.search(r'\d', password_str):
            raise ValidationError('密码必须包含至少一个数字')
        if not re.search(r'[a-zA-Z]', password_str):
            raise ValidationError('密码必须包含至少一个字母')
        if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password_str):
            raise ValidationError('密码必须包含至少一个特殊字符')
class ForgotPasswordForm(FlaskForm):
    """Request a password-reset link for the given e-mail address."""

    email = StringField(
        '邮箱',
        validators=[
            DataRequired(message='请输入邮箱'),
            Email(message='请输入有效的邮箱地址'),
        ],
    )
    submit = SubmitField('发送重置链接')
class ResetPasswordForm(FlaskForm):
    """Choose a new password (entered twice) after following a reset link."""

    password = PasswordField(
        '新密码',
        validators=[
            DataRequired(message='请输入新密码'),
            Length(min=8, message='密码长度至少8个字符'),
        ],
    )
    password2 = PasswordField(
        '确认新密码',
        validators=[
            DataRequired(message='请确认新密码'),
            EqualTo('password', message='两次输入的密码不一致'),
        ],
    )
    submit = SubmitField('重置密码')

    def validate_password(self, password):
        """Require a digit, an ASCII letter and a special character.

        Rules are tried in order and the first unmet one is reported.
        """
        rules = (
            (r'\d', '密码必须包含至少一个数字'),
            (r'[a-zA-Z]', '密码必须包含至少一个字母'),
            (r'[!@#$%^&*(),.?":{}|<>]', '密码必须包含至少一个特殊字符'),
        )
        for pattern, message in rules:
            if re.search(pattern, password.data) is None:
                raise ValidationError(message)
class ChangePasswordForm(FlaskForm):
    """Change password: current password plus the new one entered twice."""

    current_password = PasswordField(
        '当前密码',
        validators=[DataRequired(message='请输入当前密码')],
    )
    new_password = PasswordField(
        '新密码',
        validators=[
            DataRequired(message='请输入新密码'),
            Length(min=8, message='密码长度至少8个字符'),
        ],
    )
    new_password2 = PasswordField(
        '确认新密码',
        validators=[
            DataRequired(message='请确认新密码'),
            EqualTo('new_password', message='两次输入的密码不一致'),
        ],
    )
    submit = SubmitField('修改密码')

    def validate_new_password(self, new_password):
        """Require a digit, an ASCII letter and a special character in the new password.

        Rules are tried in order and the first unmet one is reported.
        """
        rules = (
            (r'\d', '密码必须包含至少一个数字'),
            (r'[a-zA-Z]', '密码必须包含至少一个字母'),
            (r'[!@#$%^&*(),.?":{}|<>]', '密码必须包含至少一个特殊字符'),
        )
        for pattern, message in rules:
            if re.search(pattern, new_password.data) is None:
                raise ValidationError(message)
class TwoFactorSetupForm(FlaskForm):
    """Confirm 2FA enrolment with a 6-digit code from the authenticator app."""

    verification_code = StringField('验证码', validators=[
        DataRequired(message='请输入验证码'),
        Length(min=6, max=6, message='验证码必须是6位数字'),
    ])
    submit = SubmitField('启用双因素认证')

    def validate_verification_code(self, verification_code):
        """Accept ASCII digits 0-9 only.

        ``str.isdigit()`` also accepts characters such as '²' or Arabic-Indic
        digits, which can never match a TOTP code. Length is already enforced
        by the ``Length`` validator, so only the character set is checked here
        (this avoids reporting the same message twice).
        """
        if not re.fullmatch(r'[0-9]+', verification_code.data):
            raise ValidationError('验证码必须是6位数字')
class TwoFactorForm(FlaskForm):
    """Second login step: enter the current 6-digit 2FA code."""

    verification_code = StringField('验证码', validators=[
        DataRequired(message='请输入验证码'),
        Length(min=6, max=6, message='验证码必须是6位数字'),
    ])
    submit = SubmitField('验证')

    def validate_verification_code(self, verification_code):
        """Accept ASCII digits 0-9 only.

        ``str.isdigit()`` also accepts characters such as '²' or Arabic-Indic
        digits, which can never match a TOTP code. Length is already enforced
        by the ``Length`` validator, so only the character set is checked here.
        """
        if not re.fullmatch(r'[0-9]+', verification_code.data):
            raise ValidationError('验证码必须是6位数字')
7.5.2 认证模板
<!-- templates/auth/login.html -->
{% extends "base.html" %}
{% block title %}用户登录{% endblock %}
{% block content %}
{# Login page: username/e-mail + password form, account links, and third-party (OAuth) sign-in buttons. #}
<div class="auth-container">
<div class="auth-card">
<h2>用户登录</h2>
{# No action attribute: the browser POSTs back to the current URL, query string included. #}
<form method="POST" class="auth-form">
{# hidden_tag() renders the form's hidden fields, e.g. the CSRF token when Flask-WTF CSRF is enabled. #}
{{ form.hidden_tag() }}
<div class="form-group">
{{ form.username.label(class="form-label") }}
{{ form.username(class="form-control") }}
{# Field-level validation errors from WTForms, one span per message. #}
{% if form.username.errors %}
<div class="error-messages">
{% for error in form.username.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.password.label(class="form-label") }}
{{ form.password(class="form-control") }}
{% if form.password.errors %}
<div class="error-messages">
{% for error in form.password.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group checkbox-group">
{{ form.remember_me() }}
{{ form.remember_me.label() }}
</div>
<div class="form-group">
{{ form.submit(class="btn btn-primary btn-block") }}
</div>
</form>
<div class="auth-links">
<p><a href="{{ url_for('auth.forgot_password') }}">忘记密码?</a></p>
<p>还没有账户?<a href="{{ url_for('auth.register') }}">立即注册</a></p>
</div>
<!-- OAuth login -->
{# The provider values presumably must be keys of OAUTH_PROVIDERS; the oauth views flash an error for unknown providers. #}
<div class="oauth-section">
<div class="divider">
<span>或使用第三方登录</span>
</div>
<div class="oauth-buttons">
<a href="{{ url_for('oauth.oauth_login', provider='github') }}" class="btn btn-github">
<i class="fab fa-github"></i> GitHub登录
</a>
<a href="{{ url_for('oauth.oauth_login', provider='google') }}" class="btn btn-google">
<i class="fab fa-google"></i> Google登录
</a>
</div>
</div>
</div>
</div>
{% endblock %}
{% block styles %}
<style>
/* Page-local styles for the login card. */
.auth-container {
    display: flex;
    justify-content: center;
    align-items: center;
    min-height: 80vh;
    padding: 20px;
}
.auth-card {
    background: white;
    border-radius: 8px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 40px;
    width: 100%;
    max-width: 400px;
}
.auth-card h2 {
    text-align: center;
    margin-bottom: 30px;
    color: #333;
}
.form-group {
    margin-bottom: 20px;
}
.form-label {
    display: block;
    margin-bottom: 5px;
    font-weight: 500;
    color: #555;
}
.form-control {
    /* border-box keeps width: 100% from overflowing the card by the padding
       and border; harmless if base.html already sets it globally. */
    box-sizing: border-box;
    width: 100%;
    padding: 12px;
    border: 1px solid #ddd;
    border-radius: 4px;
    font-size: 14px;
    transition: border-color 0.3s;
}
.form-control:focus {
    outline: none;
    border-color: #007bff;
    box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
}
.checkbox-group {
    display: flex;
    align-items: center;
    gap: 8px;
}
.btn {
    padding: 12px 24px;
    border: none;
    border-radius: 4px;
    font-size: 14px;
    font-weight: 500;
    text-decoration: none;
    text-align: center;
    cursor: pointer;
    transition: all 0.3s;
    display: inline-block;
}
.btn-primary {
    background-color: #007bff;
    color: white;
}
.btn-primary:hover {
    background-color: #0056b3;
}
.btn-block {
    /* Same overflow fix as .form-control (the OAuth links are <a> elements). */
    box-sizing: border-box;
    width: 100%;
}
.auth-links {
    text-align: center;
    margin-top: 20px;
}
.auth-links p {
    margin: 10px 0;
}
.auth-links a {
    color: #007bff;
    text-decoration: none;
}
.auth-links a:hover {
    text-decoration: underline;
}
.oauth-section {
    margin-top: 30px;
}
.divider {
    text-align: center;
    margin: 20px 0;
    position: relative;
}
/* Horizontal rule drawn across the divider at mid-height. */
.divider::before {
    content: '';
    position: absolute;
    top: 50%;
    left: 0;
    right: 0;
    height: 1px;
    background: #ddd;
}
.divider span {
    /* Must be positioned: an absolutely positioned ::before paints above
       non-positioned inline content, so without this the rule is drawn
       through the label text instead of being masked by its background. */
    position: relative;
    background: white;
    padding: 0 15px;
    color: #666;
    font-size: 12px;
}
.oauth-buttons {
    display: flex;
    flex-direction: column;
    gap: 10px;
}
.btn-github {
    background-color: #333;
    color: white;
}
.btn-github:hover {
    background-color: #24292e;
}
.btn-google {
    background-color: #db4437;
    color: white;
}
.btn-google:hover {
    background-color: #c23321;
}
.error-messages {
    margin-top: 5px;
}
.error {
    color: #dc3545;
    font-size: 12px;
    display: block;
}
</style>
{% endblock %}
<!-- templates/auth/register.html -->
{% extends "base.html" %}
{% block title %}用户注册{% endblock %}
{% block content %}
{# Registration page rendering RegisterForm; server-side validation errors are listed under each field. #}
{# NOTE(review): this template defines no styles block, and login.html's page-local <style> does not apply here. Confirm base.html (or a shared stylesheet) provides .auth-*, .form-row/.half and the .strength-* meter classes. #}
<div class="auth-container">
<div class="auth-card">
<h2>用户注册</h2>
<form method="POST" class="auth-form">
{# hidden_tag() renders the form's hidden fields, e.g. the CSRF token when Flask-WTF CSRF is enabled. #}
{{ form.hidden_tag() }}
{# First and last name side by side. #}
<div class="form-row">
<div class="form-group half">
{{ form.first_name.label(class="form-label") }}
{{ form.first_name(class="form-control") }}
{% if form.first_name.errors %}
<div class="error-messages">
{% for error in form.first_name.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group half">
{{ form.last_name.label(class="form-label") }}
{{ form.last_name(class="form-control") }}
{% if form.last_name.errors %}
<div class="error-messages">
{% for error in form.last_name.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
</div>
<div class="form-group">
{{ form.username.label(class="form-label") }}
{{ form.username(class="form-control") }}
{% if form.username.errors %}
<div class="error-messages">
{% for error in form.username.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.email.label(class="form-label") }}
{{ form.email(class="form-control") }}
{% if form.email.errors %}
<div class="error-messages">
{% for error in form.email.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.password.label(class="form-label") }}
{# id is pinned to "password" because the strength-meter script in the scripts block looks the input up by that id and fills #password-strength. #}
{{ form.password(class="form-control", id="password") }}
<div class="password-strength" id="password-strength"></div>
{% if form.password.errors %}
<div class="error-messages">
{% for error in form.password.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.password2.label(class="form-label") }}
{{ form.password2(class="form-control") }}
{% if form.password2.errors %}
<div class="error-messages">
{% for error in form.password2.errors %}
<span class="error">{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.submit(class="btn btn-primary btn-block") }}
</div>
</form>
<div class="auth-links">
<p>已有账户?<a href="{{ url_for('auth.login') }}">立即登录</a></p>
</div>
</div>
</div>
{% endblock %}
{% block scripts %}
<script>
// Live password-strength meter. It is only a client-side hint; RegisterForm's
// server-side validators enforce the same digit/letter/special-character rules.
document.getElementById('password').addEventListener('input', function(e) {
const password = e.target.value;
const strengthDiv = document.getElementById('password-strength');
let strength = 0;  // number of satisfied rules, 0-4
let feedback = [];  // labels of unmet rules, shown as "需要: ..."
// Length check
if (password.length >= 8) {
strength += 1;
} else {
feedback.push('至少8个字符');
}
// Digit check
if (/\d/.test(password)) {
strength += 1;
} else {
feedback.push('包含数字');
}
// Letter check (ASCII letters only)
if (/[a-zA-Z]/.test(password)) {
strength += 1;
} else {
feedback.push('包含字母');
}
// Special-character check; same character class as the server-side regex
if (/[!@#$%^&*(),.?":{}|<>]/.test(password)) {
strength += 1;
} else {
feedback.push('包含特殊字符');
}
// Render the meter. Arrays are indexed by strength - 1, so strength 0 falls
// back to the weakest colour/label via `||`. innerHTML only interpolates
// numbers and the static strings above, never the typed password.
const strengthTexts = ['很弱', '弱', '中等', '强'];
const strengthColors = ['#dc3545', '#fd7e14', '#ffc107', '#28a745'];
if (password.length > 0) {
strengthDiv.innerHTML = `
<div class="strength-bar">
<div class="strength-fill" style="width: ${(strength / 4) * 100}%; background-color: ${strengthColors[strength - 1] || '#dc3545'}"></div>
</div>
<div class="strength-text">密码强度: ${strengthTexts[strength - 1] || '很弱'}</div>
${feedback.length > 0 ? `<div class="strength-feedback">需要: ${feedback.join(', ')}</div>` : ''}
`;
} else {
// Empty input: hide the meter entirely.
strengthDiv.innerHTML = '';
}
});
</script>
{% endblock %}
7.6 安全最佳实践
7.6.1 密码安全
# utils/security.py
import secrets
import hashlib
from datetime import datetime, timedelta
from flask import current_app, request
import re
class PasswordSecurity:
    """Password utilities: random generation, strength scoring and breach lookup."""

    @staticmethod
    def generate_secure_password(length=12):
        """Return a random password of exactly *length* characters.

        The result always contains at least one lowercase letter, one
        uppercase letter, one digit and one character from ``!@#$%^&*()``.
        It uses ``secrets`` for cryptographically strong randomness.

        Raises:
            ValueError: if *length* is smaller than 4 (one slot per required
                character class). Previously this silently returned a
                4-character password, longer than requested.
        """
        import string

        lowercase = string.ascii_lowercase
        uppercase = string.ascii_uppercase
        digits = string.digits
        special = '!@#$%^&*()'
        required_classes = [lowercase, uppercase, digits, special]
        if length < len(required_classes):
            raise ValueError(
                f'length must be at least {len(required_classes)}, got {length}'
            )

        # One character from each class, then fill the rest from all classes.
        password = [secrets.choice(chars) for chars in required_classes]
        all_chars = ''.join(required_classes)
        password.extend(
            secrets.choice(all_chars) for _ in range(length - len(required_classes))
        )
        # Shuffle so the guaranteed characters are not always at the front.
        secrets.SystemRandom().shuffle(password)
        return ''.join(password)

    @staticmethod
    def check_password_strength(password):
        """Score *password* from 0 to 8 and explain what is missing.

        One point each for: length >= 8, length >= 12, lowercase, uppercase,
        digit, special character, no run of 3+ identical characters, and not
        being in a small common-password list.

        Returns:
            dict with ``score``, ``max_score`` (8), ``strength`` (label),
            ``feedback`` (list of unmet suggestions) and ``is_strong``
            (score >= 6).
        """
        score = 0
        feedback = []

        # Length
        if len(password) >= 8:
            score += 1
        else:
            feedback.append('密码长度至少8个字符')
        if len(password) >= 12:
            score += 1

        # Character classes
        if re.search(r'[a-z]', password):
            score += 1
        else:
            feedback.append('包含小写字母')
        if re.search(r'[A-Z]', password):
            score += 1
        else:
            feedback.append('包含大写字母')
        if re.search(r'\d', password):
            score += 1
        else:
            feedback.append('包含数字')
        if re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            score += 1
        else:
            feedback.append('包含特殊字符')

        # No character repeated three or more times in a row
        if not re.search(r'(.)\1{2,}', password):
            score += 1
        else:
            feedback.append('避免连续重复字符')

        # Common passwords (case-insensitive)
        common_passwords = {
            'password', '123456', 'qwerty', 'abc123',
            'password123', 'admin', 'letmein',
        }
        if password.lower() not in common_passwords:
            score += 1
        else:
            feedback.append('避免使用常见密码')

        strength_levels = {
            0: '极弱',
            1: '很弱',
            2: '很弱',
            3: '弱',
            4: '弱',
            5: '中等',
            6: '强',
            7: '很强',
            8: '极强',
        }
        return {
            'score': score,
            'max_score': 8,
            'strength': strength_levels.get(score, '未知'),
            'feedback': feedback,
            'is_strong': score >= 6,
        }

    @staticmethod
    def is_password_compromised(password):
        """Check *password* against Have I Been Pwned (k-anonymity range API).

        Only the first 5 hex characters of the SHA-1 hash leave the process.
        The returned suffixes are compared locally.

        Returns:
            ``(True, count)`` if the password appears in the breach corpus,
            ``(False, 0)`` if not, and ``(None, 0)`` when the API is
            unreachable, answers with a non-200 status, or returns a body
            that cannot be parsed.
        """
        import requests

        sha1_hash = hashlib.sha1(password.encode('utf-8')).hexdigest().upper()
        prefix, suffix = sha1_hash[:5], sha1_hash[5:]

        try:
            response = requests.get(
                f'https://api.pwnedpasswords.com/range/{prefix}',
                timeout=5,
            )
        except requests.RequestException:
            return None, 0  # network error / timeout

        if response.status_code != 200:
            return None, 0  # API unavailable

        try:
            # Each line is "<35-char hash suffix>:<count>".
            for line in response.text.splitlines():
                candidate, _, count = line.partition(':')
                if candidate.strip() == suffix:
                    return True, int(count)
        except ValueError:
            return None, 0  # malformed count in the response body
        return False, 0
class SecurityHeaders:
    """Adds hardening HTTP response headers (intended for an ``after_request`` hook)."""

    # Content Security Policy sent on every response. frame-ancestors mirrors
    # X-Frame-Options: DENY for browsers that honour CSP; object-src/base-uri
    # close plugin and <base>-tag injection vectors.
    CONTENT_SECURITY_POLICY = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com; "
        "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com; "
        "img-src 'self' data: https:; "
        "font-src 'self' https://cdnjs.cloudflare.com; "
        "object-src 'none'; "
        "base-uri 'self'; "
        "frame-ancestors 'none'"
    )

    @staticmethod
    def add_security_headers(response):
        """Set security headers on *response* and return it.

        Existing values for these headers are overwritten.
        """
        # Stop MIME-type sniffing.
        response.headers['X-Content-Type-Options'] = 'nosniff'
        # Forbid framing (clickjacking).
        response.headers['X-Frame-Options'] = 'DENY'
        # Disable the legacy XSS auditor. It is removed from modern browsers
        # and could itself be abused; OWASP recommends "0" and relying on CSP.
        response.headers['X-XSS-Protection'] = '0'

        # HSTS is only meaningful (and only honoured) over HTTPS.
        if request.is_secure:
            response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'

        response.headers['Content-Security-Policy'] = SecurityHeaders.CONTENT_SECURITY_POLICY
        response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
        return response
class RateLimiter:
    """Fixed-window attempt counter stored in Redis.

    Without a Redis client, or when Redis raises, the limiter fails open and
    never limits. This is deliberate so that a Redis outage does not block
    logins.
    """

    def __init__(self, redis_client=None):
        self.redis = redis_client

    def is_rate_limited(self, key, limit, window):
        """Record one attempt for *key* and report whether it exceeds *limit*.

        Args:
            key: Redis key identifying the caller/action being limited.
            limit: Maximum attempts allowed within one window.
            window: Window length in seconds (TTL of the counter key).

        Returns:
            True from the ``limit + 1``-th attempt in the current window on.
        """
        if not self.redis:
            return False
        try:
            # INCR is atomic and creates the key at 1 when it is missing. This
            # avoids the GET/SETEX/INCR race, where concurrent requests reset
            # the counter, or where an expiry between GET and INCR left a
            # counter with no TTL that locked the caller out forever.
            count = int(self.redis.incr(key))
            if count == 1:
                self.redis.expire(key, window)
            elif self.redis.ttl(key) in (None, -1):
                # Key exists without a TTL (e.g. an earlier EXPIRE was lost).
                self.redis.expire(key, window)
            return count > limit
        except Exception:
            return False  # fail open: Redis unavailable

    def get_remaining_attempts(self, key, limit):
        """Return how many attempts are left for *key* in the current window (>= 0)."""
        if not self.redis:
            return limit
        try:
            current = self.redis.get(key)
            if current is None:
                return limit
            return max(0, limit - int(current))
        except Exception:
            return limit  # fail open: Redis unavailable
class IPSecurity:
    """Client-IP helpers: address resolution, a private/loopback check and security logging."""

    # IPv4 ranges treated as "suspicious" by is_suspicious_ip.
    SUSPICIOUS_PATTERNS = [
        r'^10\.',  # private network
        r'^172\.(1[6-9]|2[0-9]|3[0-1])\.',  # private network
        r'^192\.168\.',  # private network
        r'^127\.',  # loopback
    ]

    @staticmethod
    def get_real_ip(trusted_proxies=None):
        """Return the client IP, trusting only the configured number of proxies.

        X-Forwarded-For is appended to by each proxy, so only the right-most
        entries come from infrastructure we control. The left-most entry is
        whatever the client sent and is trivially spoofable. Trusting it
        would let any caller evade IP-based rate limits. With N trusted
        proxies, the client address is the N-th entry from the right (same
        rule as Werkzeug's ``ProxyFix``).

        Args:
            trusted_proxies: Number of reverse proxies in front of the app.
                Defaults to ``current_app.config['TRUSTED_PROXY_COUNT']``, or
                1 if unset. Use 0 when the app is exposed directly, so that
                forwarding headers are ignored.

        Returns:
            The resolved client IP string, falling back to ``request.remote_addr``.
        """
        if trusted_proxies is None:
            trusted_proxies = current_app.config.get('TRUSTED_PROXY_COUNT', 1)

        if trusted_proxies > 0:
            forwarded_ips = request.headers.get('X-Forwarded-For')
            if forwarded_ips:
                hops = [hop.strip() for hop in forwarded_ips.split(',') if hop.strip()]
                if len(hops) >= trusted_proxies:
                    return hops[-trusted_proxies]
            real_ip = request.headers.get('X-Real-IP')
            if real_ip:
                return real_ip.strip()

        return request.remote_addr

    @staticmethod
    def is_suspicious_ip(ip):
        """Return True if *ip* matches one of SUSPICIOUS_PATTERNS (falsy *ip* -> False)."""
        if not ip:
            return False
        return any(re.match(pattern, ip) for pattern in IPSecurity.SUSPICIOUS_PATTERNS)

    @staticmethod
    def log_security_event(event_type, details, ip=None):
        """Log a security event (timestamp, type, IP, User-Agent, details) as a warning.

        *ip* defaults to ``get_real_ip()`` for the current request.
        """
        if ip is None:
            ip = IPSecurity.get_real_ip()

        security_log = {
            'timestamp': datetime.utcnow().isoformat(),
            'event_type': event_type,
            'ip_address': ip,
            'user_agent': request.headers.get('User-Agent', ''),
            'details': details,
        }
        current_app.logger.warning(f'Security Event: {security_log}')
        # Hook point for forwarding to a SIEM / alerting system.
7.6.2 会话安全
# utils/session_security.py
from flask import session, request, current_app
from datetime import datetime, timedelta
import secrets
import hashlib
class SessionSecurity:
    """Session-hardening helpers built on marker values stored in Flask's ``session``."""

    @staticmethod
    def regenerate_session_id():
        """Stamp the session with a fresh random marker and creation time.

        All existing keys are preserved. NOTE(review): with Flask's default
        signed-cookie session there is no server-side session ID, so the
        clear/update round-trip only rewrites the cookie payload. The
        effective rotation is the new ``_session_id`` value. Confirm this is
        enough for the session backend in use.
        """
        preserved = dict(session)
        session.clear()
        session.update(preserved)
        session['_session_id'] = secrets.token_hex(16)
        session['_created_at'] = datetime.utcnow().isoformat()

    @staticmethod
    def validate_session():
        """Return whether the current session should still be trusted.

        The session is rejected when:

        * it has no ``_session_id`` marker;
        * it is older than ``SESSION_TIMEOUT`` hours (default 24), measured
          from ``_created_at``;
        * ``SESSION_IP_CHECK`` is enabled and the stored IP differs from
          ``request.remote_addr``;
        * ``SESSION_UA_CHECK`` is enabled and the stored User-Agent differs.
        """
        if '_session_id' not in session:
            return False

        created_at = session.get('_created_at')
        if created_at:
            started = datetime.fromisoformat(created_at)
            lifetime = timedelta(hours=current_app.config.get('SESSION_TIMEOUT', 24))
            if datetime.utcnow() - started > lifetime:
                return False

        # (config flag, session key, accessor for the live value), checked in order.
        fingerprint_checks = (
            ('SESSION_IP_CHECK', '_ip_address', lambda: request.remote_addr),
            ('SESSION_UA_CHECK', '_user_agent', lambda: request.headers.get('User-Agent', '')),
        )
        for flag, key, live_value in fingerprint_checks:
            if not current_app.config.get(flag, False):
                continue
            stored = session.get(key)
            if stored and stored != live_value():
                return False

        return True

    @staticmethod
    def secure_session():
        """Record the client fingerprint and activity time, and ensure a CSRF token exists."""
        session['_ip_address'] = request.remote_addr
        session['_user_agent'] = request.headers.get('User-Agent', '')
        session['_last_activity'] = datetime.utcnow().isoformat()

        if '_csrf_token' not in session:
            session['_csrf_token'] = secrets.token_hex(16)

    @staticmethod
    def cleanup_expired_sessions():
        """Placeholder: purge expired sessions from a server-side store (Redis, DB, ...).

        Does nothing here; it needs a custom session backend to implement.
        """
        pass
class CSRFProtection:
    """Session-based CSRF token generation, validation and a view decorator."""

    # Methods that must not change state and therefore skip the token check.
    SAFE_METHODS = frozenset({'GET', 'HEAD', 'OPTIONS', 'TRACE'})

    @staticmethod
    def generate_csrf_token():
        """Return the session's CSRF token, creating it on first use."""
        if '_csrf_token' not in session:
            session['_csrf_token'] = secrets.token_hex(16)
        return session['_csrf_token']

    @staticmethod
    def _tokens_match(expected, provided):
        """Constant-time comparison of two tokens; False for missing or non-str values.

        Both values are encoded to bytes first, because
        ``secrets.compare_digest`` raises TypeError for ``str`` arguments that
        contain non-ASCII characters. Without this, an attacker-supplied token
        like 'ä...' would produce a 500 instead of a rejection.
        """
        if not expected or not provided:
            return False
        if not isinstance(expected, str) or not isinstance(provided, str):
            return False
        return secrets.compare_digest(expected.encode('utf-8'), provided.encode('utf-8'))

    @staticmethod
    def validate_csrf_token(token):
        """Return True iff *token* equals the CSRF token stored in the session."""
        return CSRFProtection._tokens_match(session.get('_csrf_token'), token)

    @staticmethod
    def csrf_protect():
        """Decorator factory: reject unsafe-method requests without a valid token (403).

        The token is read from the ``csrf_token`` form field or the
        ``X-CSRF-Token`` header. Every method outside SAFE_METHODS is checked,
        not just POST, so PUT/PATCH/DELETE handlers are covered too.

        Usage: ``@CSRFProtection.csrf_protect()``.
        """
        def decorator(f):
            from functools import wraps

            @wraps(f)
            def decorated_function(*args, **kwargs):
                if request.method not in CSRFProtection.SAFE_METHODS:
                    token = request.form.get('csrf_token') or request.headers.get('X-CSRF-Token')
                    if not CSRFProtection.validate_csrf_token(token):
                        from flask import abort
                        abort(403)
                return f(*args, **kwargs)
            return decorated_function
        return decorator
7.7 本章小结
7.7.1 技术要点总结
本章详细介绍了Flask应用中的用户认证与授权系统,主要包括:
认证机制: - Flask-Login基础配置和用户模型扩展 - 传统用户名密码认证流程 - JWT令牌认证机制 - OAuth第三方登录集成
授权系统: - 基于角色的权限控制(RBAC) - 权限装饰器和中间件 - 动态权限管理 - 资源级别的访问控制
安全最佳实践: - 密码安全策略和强度检测 - 会话安全管理 - CSRF保护机制 - 安全头部配置 - 速率限制和IP安全
7.7.2 安全考虑
密码安全: - 使用强密码策略 - 密码哈希存储 - 已泄露密码检测(如Have I Been Pwned) - 账户锁定机制
会话安全: - 会话ID重新生成 - 会话超时控制 - 会话指纹验证 - 安全的会话存储
传输安全: - HTTPS强制使用 - 安全头部配置 - CSRF保护 - XSS防护
7.7.3 性能优化
认证优化: - JWT令牌减少数据库查询 - 权限缓存机制 - 会话存储优化 - OAuth令牌管理
授权优化: - 权限预加载 - 角色层级缓存 - 批量权限检查 - 权限继承机制
7.7.4 开发建议
- 安全优先:始终将安全性放在首位,不要为了便利而牺牲安全
- 最小权限原则:用户只应获得完成任务所需的最小权限
- 防御深度:实施多层安全防护,不依赖单一安全机制
- 定期审计:定期审查用户权限和安全日志
- 安全测试:进行渗透测试和安全扫描
7.8 下一章预告
下一章我们将学习Flask缓存与性能优化,内容包括:
- 缓存策略和实现
- Redis集成和使用
- 数据库查询优化
- 静态资源优化
- 应用性能监控
- 负载测试和调优
7.9 练习题
基础练习
用户认证系统:
- 实现完整的用户注册、登录、登出功能
- 添加邮箱验证和密码重置功能
- 实现记住我功能
权限管理:
- 创建角色权限管理界面
- 实现用户角色分配功能
- 添加权限检查装饰器
OAuth集成:
- 集成GitHub OAuth登录
- 实现账户绑定和解绑功能
- 处理OAuth登录错误
进阶练习
双因素认证:
- 实现TOTP双因素认证
- 添加备用恢复码
- 创建安全设置页面
安全加固:
- 实现登录尝试限制
- 添加可疑活动检测
- 创建安全日志系统
API认证:
- 实现JWT API认证
- 添加API密钥管理
- 创建API权限控制
挑战练习
单点登录(SSO):
- 实现SAML SSO
- 创建身份提供商集成
- 处理跨域认证
高级权限系统:
- 实现基于属性的访问控制(ABAC)
- 添加动态权限计算
- 创建权限审计系统
安全监控:
- 实现实时安全监控
- 添加异常行为检测
- 创建安全告警系统