10.1 测试基础
10.1.1 测试概念和重要性
测试是软件开发过程中的重要环节,它确保代码的质量、可靠性和可维护性。在Flask应用开发中,良好的测试策略包括:
测试类型:
- 单元测试:测试单个函数或方法
- 集成测试:测试多个组件之间的交互
- 功能测试:测试完整的用户场景
- 性能测试:测试应用的性能表现
测试原则:
- 自动化:测试应该能够自动运行
- 独立性:测试之间不应该相互依赖,每个测试都应能单独运行
- 可重复:在相同条件下多次运行,测试结果应该一致
- 快速:测试应该快速执行,以便频繁运行
10.1.2 Flask测试环境配置
# tests/conftest.py
import pytest
import tempfile
import os
from app import create_app, db
from app.models import User, Role
from flask import current_app
@pytest.fixture(scope='session')
def app():
    """Create the application instance shared by the whole test session.

    A temporary SQLite file backs the database. The schema and the seed data
    (``create_test_data``) are created once. On teardown the schema is
    dropped, pooled connections are released, and the temporary file is
    closed and deleted. The file cleanup runs even if app setup raises.
    """
    # mkstemp returns an already-open OS-level descriptor plus the file path.
    db_fd, db_path = tempfile.mkstemp(suffix='.db')
    # Test configuration (create_app is expected to accept a mapping of
    # config overrides).
    test_config = {
        'TESTING': True,
        'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
        'SQLALCHEMY_TRACK_MODIFICATIONS': False,
        'WTF_CSRF_ENABLED': False,
        'SECRET_KEY': 'test-secret-key',
        # 'SimpleCache' is the current Flask-Caching name; 'simple' is a
        # deprecated alias.
        'CACHE_TYPE': 'SimpleCache',
        'REDIS_URL': 'redis://localhost:6379/15',  # separate Redis DB for tests
    }
    try:
        app = create_app(test_config)
        with app.app_context():
            db.create_all()
            # Seed data shared by all tests
            create_test_data()
            yield app
            # Teardown
            db.session.remove()
            db.drop_all()
            # Release pooled connections so the SQLite file can be deleted
            # (required on Windows, harmless elsewhere).
            db.engine.dispose()
    finally:
        os.close(db_fd)
        os.unlink(db_path)
@pytest.fixture
def client(app):
    """Flask test client for issuing requests against ``app`` without a real server."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def runner(app):
    """Click runner for invoking ``app``'s CLI commands in-process."""
    cli_runner = app.test_cli_runner()
    return cli_runner
@pytest.fixture
def auth_headers(client):
    """Log in as the seeded test user via the JSON API and return a Bearer header.

    If the login endpoint does not return an ``access_token``, the fixture
    fails with the response status and body. Without this check the failure
    would surface as an opaque ``KeyError``/``TypeError``.
    """
    response = client.post('/api/auth/login', json={
        'email': 'test@example.com',
        'password': 'password123',
    })
    # silent=True: a non-JSON error page yields None instead of raising.
    payload = response.get_json(silent=True) or {}
    assert 'access_token' in payload, (
        f'API login failed: status={response.status_code}, body={response.data!r}'
    )
    return {'Authorization': f"Bearer {payload['access_token']}"}
def create_test_data():
    """Seed the database with the roles and users the tests rely on.

    The roles are committed first so that their primary keys exist before
    the users that reference them via ``role_id`` are built.
    """
    # Roles
    admin_role = Role(name='admin', description='管理员')
    user_role = Role(name='user', description='普通用户')
    db.session.add_all([admin_role, user_role])
    db.session.commit()

    # Users, as (username, email, role, plain-text password)
    seed_accounts = (
        ('admin', 'admin@example.com', admin_role, 'admin123'),
        ('testuser', 'test@example.com', user_role, 'password123'),
    )
    accounts = []
    for username, email, role, password in seed_accounts:
        account = User(username=username, email=email, role_id=role.id)
        account.set_password(password)
        accounts.append(account)
    db.session.add_all(accounts)
    db.session.commit()
@pytest.fixture
def sample_user():
    """Registration data for a user that does not exist in the seed data."""
    return dict(
        username='newuser',
        email='newuser@example.com',
        password='newpassword123',
    )
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db_session):
    """Autouse shim that requests ``db_session`` for every test.

    Depending on ``db_session`` is the whole point; the fixture itself does
    nothing else.
    """
@pytest.fixture(scope='function')
def db_session(app):
    """Run each test inside an outer transaction that is rolled back afterwards.

    ``db.session`` is temporarily replaced by a scoped session bound to one
    connection with an open transaction. On teardown the original session
    object is restored. Otherwise every later test and fixture keeps using a
    session bound to an already-closed connection. The session is then
    removed, and finally the transaction is rolled back and the connection
    closed.
    """
    with app.app_context():
        connection = db.engine.connect()
        transaction = connection.begin()
        # Flask-SQLAlchemy 3.x removed the public create_scoped_session();
        # its (private) replacement accepts the same ``options`` argument.
        make_scoped_session = (
            getattr(db, 'create_scoped_session', None)
            or db._make_scoped_session
        )
        session = make_scoped_session(options={'bind': connection, 'binds': {}})

        # Swap the global session, remembering the original for teardown.
        original_session = db.session
        db.session = session
        try:
            yield session
        finally:
            db.session = original_session
            session.remove()
            # A test may already have ended the outer transaction.
            if transaction.is_active:
                transaction.rollback()
            connection.close()
10.1.3 测试工具和库
# requirements-test.txt
pytest>=6.2.0
pytest-flask>=1.2.0
pytest-cov>=2.12.0
pytest-mock>=3.6.0
factory-boy>=3.2.0
faker>=8.10.0
responses>=0.13.0
freezegun>=1.1.0
# 端到端测试、性能测试和覆盖率检查中直接导入的库
selenium>=4.0.0
requests>=2.25.0
psutil>=5.8.0
coverage>=5.5
10.2 单元测试
10.2.1 模型测试
# tests/test_models.py
import pytest
from datetime import datetime, timedelta
from app.models import User, Role, Article, Comment
from app import db
class TestUserModel:
    """User model tests.

    Usernames, e-mails and role names that get persisted here deliberately
    differ from the seed data created by ``create_test_data`` ('admin',
    'testuser', 'admin@example.com', 'test@example.com', roles 'admin' and
    'user'). Reusing them would make inserts fail on unique constraints
    instead of exercising the behaviour under test.
    """

    def test_user_creation(self, app):
        """A committed user gets an id and keeps its attributes and password."""
        with app.app_context():
            user = User(
                username='createduser',
                email='created@example.com'
            )
            user.set_password('password123')
            db.session.add(user)
            db.session.commit()

            assert user.id is not None
            assert user.username == 'createduser'
            assert user.email == 'created@example.com'
            assert user.check_password('password123')
            assert not user.check_password('wrongpassword')

    def test_password_hashing(self, app):
        """The stored hash differs from the plain text but still verifies it."""
        with app.app_context():
            user = User(username='hashuser', email='hash@example.com')
            user.set_password('password123')
            # The password must be stored hashed
            assert user.password_hash != 'password123'
            assert user.check_password('password123')
            assert not user.check_password('wrongpassword')

    def test_user_repr(self, app):
        """repr() includes the username."""
        with app.app_context():
            # Never added to the session, so reusing seed values is harmless.
            user = User(username='testuser', email='test@example.com')
            assert repr(user) == '<User testuser>'

    def test_user_role_relationship(self, app):
        """A user's role is reachable from both sides of the relationship."""
        with app.app_context():
            role = Role(name='relationship_role', description='管理员')
            user = User(
                username='roleuser',
                email='roleuser@example.com',
                role=role
            )
            db.session.add(role)
            db.session.add(user)
            db.session.commit()

            assert user.role.name == 'relationship_role'
            assert role.users.count() == 1
            assert role.users.first() == user

    def test_user_validation(self, app):
        """A second user with an existing e-mail is rejected by the database."""
        from sqlalchemy.exc import IntegrityError

        with app.app_context():
            # E-mail uniqueness: both users share an address unused elsewhere.
            user1 = User(username='dupe_user1', email='duplicate@example.com')
            user2 = User(username='dupe_user2', email='duplicate@example.com')
            db.session.add(user1)
            db.session.commit()

            db.session.add(user2)
            with pytest.raises(IntegrityError):
                db.session.commit()
            # A failed flush leaves the session unusable until rolled back.
            db.session.rollback()
class TestArticleModel:
    """Article model tests: persistence, slug generation and comments."""

    @staticmethod
    def _save(*instances):
        """Add the given instances to the session, in order, then commit."""
        for instance in instances:
            db.session.add(instance)
        db.session.commit()

    def test_article_creation(self, app):
        """Persisting an article assigns an id and a creation timestamp."""
        with app.app_context():
            writer = User(username='author', email='author@example.com')
            post = Article(title='测试文章', content='这是一篇测试文章的内容', author=writer)
            self._save(writer, post)

            assert post.id is not None
            assert post.title == '测试文章'
            assert post.author.username == 'author'
            assert post.created_at is not None

    def test_article_slug_generation(self, app):
        """A non-empty slug is produced when an article is saved."""
        with app.app_context():
            writer = User(username='author', email='author@example.com')
            post = Article(title='测试文章标题', content='内容', author=writer)
            self._save(writer, post)

            # The slug must exist and must not be empty
            assert post.slug is not None
            assert len(post.slug) > 0

    def test_article_comments_relationship(self, app):
        """A comment is reachable from its article and vice versa."""
        with app.app_context():
            commenter = User(username='user', email='user@example.com')
            post = Article(title='文章', content='内容', author=commenter)
            remark = Comment(content='评论内容', author=commenter, article=post)
            self._save(commenter, post, remark)

            assert post.comments.count() == 1
            assert post.comments.first() == remark
            assert remark.article == post
class TestRoleModel:
    """Role model tests.

    Role names differ from the seeded 'admin'/'user' roles created by
    ``create_test_data``. If ``Role.name`` is unique (presumably; verify
    against the model), reusing 'admin' would fail the insert instead of
    testing the role.
    """

    def test_role_creation(self, app):
        """A committed role gets an id and keeps its name and description."""
        with app.app_context():
            role = Role(name='created_role', description='管理员角色')
            db.session.add(role)
            db.session.commit()

            assert role.id is not None
            assert role.name == 'created_role'
            assert role.description == '管理员角色'

    def test_role_permissions(self, app):
        """Added permissions are reported by has_permission; others are not."""
        with app.app_context():
            role = Role(name='permission_role', description='管理员')
            role.add_permission('read')
            role.add_permission('write')
            role.add_permission('delete')
            db.session.add(role)
            db.session.commit()

            assert role.has_permission('read')
            assert role.has_permission('write')
            assert role.has_permission('delete')
            assert not role.has_permission('nonexistent')
10.2.2 视图函数测试
# tests/test_views.py
import pytest
import json
from flask import url_for
from app.models import User, Article
from app import db
class TestAuthViews:
    """Tests for the HTML authentication views (register / login / logout)."""

    @staticmethod
    def _registration_data(user_data):
        """Return registration POST data that includes ``confirm_password``.

        RegisterForm requires ``confirm_password`` to match ``password`` (see
        the RegisterForm tests). The ``sample_user`` fixture only carries
        username/email/password, so posting it unchanged would always fail
        form validation.
        """
        data = dict(user_data)
        data.setdefault('confirm_password', data['password'])
        return data

    @staticmethod
    def _login(client, password='password123'):
        """POST the seeded test user's credentials to the login view."""
        return client.post('/auth/login', data={
            'email': 'test@example.com',
            'password': password,
        })

    def test_register_success(self, client, sample_user):
        """A valid registration redirects and creates the user."""
        response = client.post('/auth/register', data=self._registration_data(sample_user))
        assert response.status_code == 302  # redirect after registering

        # The user must now exist
        user = User.query.filter_by(email=sample_user['email']).first()
        assert user is not None
        assert user.username == sample_user['username']

    def test_register_duplicate_email(self, client, sample_user):
        """Registering the same e-mail twice re-renders the form with an error."""
        form_data = self._registration_data(sample_user)
        # First registration
        client.post('/auth/register', data=form_data)
        # Second registration with the same e-mail
        response = client.post('/auth/register', data=form_data)
        assert response.status_code == 200  # registration page shown again
        assert b'Email already registered' in response.data

    def test_login_success(self, client):
        """Correct credentials redirect (to the home page)."""
        response = self._login(client)
        assert response.status_code == 302

    def test_login_invalid_credentials(self, client):
        """A wrong password re-renders the login page with an error."""
        response = self._login(client, password='wrongpassword')
        assert response.status_code == 200  # login page shown again
        assert b'Invalid email or password' in response.data

    def test_logout(self, client):
        """Logging out after a successful login redirects."""
        login_response = self._login(client)
        assert login_response.status_code == 302, 'precondition: login must succeed'

        response = client.get('/auth/logout')
        assert response.status_code == 302  # redirect
class TestArticleViews:
    """Tests for the HTML article views (list / detail / create / edit / delete)."""

    @staticmethod
    def _login(client):
        """Log in as the seeded test user and fail fast if login does not redirect."""
        response = client.post('/auth/login', data={
            'email': 'test@example.com',
            'password': 'password123',
        })
        assert response.status_code == 302, 'precondition: login must succeed'
        return response

    def test_article_list(self, client):
        """The article list page renders."""
        response = client.get('/articles')
        assert response.status_code == 200
        assert b'Articles' in response.data

    def test_article_detail(self, client, app):
        """An article's detail page renders its title."""
        with app.app_context():
            # Create a test article
            user = User.query.first()
            article = Article(
                title='测试文章',
                content='测试内容',
                author=user
            )
            db.session.add(article)
            db.session.commit()
            # Capture the id while the instance is still attached to the session.
            article_id = article.id

        response = client.get(f'/articles/{article_id}')
        assert response.status_code == 200
        # b'测试文章' is a SyntaxError (bytes literals must be ASCII);
        # compare against the decoded body instead.
        assert '测试文章' in response.get_data(as_text=True)

    def test_article_create_authenticated(self, client):
        """A logged-in user can create an article."""
        self._login(client)

        response = client.post('/articles/create', data={
            'title': '新文章',
            'content': '新文章内容'
        })
        assert response.status_code == 302  # redirect to the article page

        # The article must now exist
        article = Article.query.filter_by(title='新文章').first()
        assert article is not None

    def test_article_create_unauthenticated(self, client):
        """An anonymous user is redirected (to the login page)."""
        response = client.post('/articles/create', data={
            'title': '新文章',
            'content': '新文章内容'
        })
        assert response.status_code == 302

    def test_article_edit_owner(self, client, app):
        """The author can edit their own article."""
        with app.app_context():
            self._login(client)

            user = User.query.filter_by(email='test@example.com').first()
            article = Article(
                title='原标题',
                content='原内容',
                author=user
            )
            db.session.add(article)
            db.session.commit()
            article_id = article.id

            response = client.post(f'/articles/{article_id}/edit', data={
                'title': '新标题',
                'content': '新内容'
            })
            assert response.status_code == 302

            # The article must be updated
            updated_article = Article.query.get(article_id)
            assert updated_article.title == '新标题'
            assert updated_article.content == '新内容'

    def test_article_delete_owner(self, client, app):
        """The author can delete their own article."""
        with app.app_context():
            self._login(client)

            user = User.query.filter_by(email='test@example.com').first()
            article = Article(
                title='待删除文章',
                content='内容',
                author=user
            )
            db.session.add(article)
            db.session.commit()
            article_id = article.id

            response = client.post(f'/articles/{article_id}/delete')
            assert response.status_code == 302

            # The article must be gone
            deleted_article = Article.query.get(article_id)
            assert deleted_article is None
class TestAPIViews:
    """JSON API tests for the articles endpoint."""

    def test_api_article_list(self, client):
        """GET /api/articles returns an object holding an ``articles`` list."""
        resp = client.get('/api/articles')
        assert resp.status_code == 200

        body = json.loads(resp.data)
        assert 'articles' in body
        assert isinstance(body['articles'], list)

    def test_api_article_create_with_auth(self, client, auth_headers):
        """An authenticated POST creates the article (201) and echoes its title."""
        resp = client.post(
            '/api/articles',
            json={'title': 'API文章', 'content': 'API文章内容'},
            headers=auth_headers,
        )
        assert resp.status_code == 201

        body = json.loads(resp.data)
        assert body['title'] == 'API文章'

    def test_api_article_create_without_auth(self, client):
        """An unauthenticated POST is rejected with 401 and an error payload."""
        resp = client.post(
            '/api/articles',
            json={'title': 'API文章', 'content': 'API文章内容'},
        )
        assert resp.status_code == 401

        body = json.loads(resp.data)
        assert 'error' in body
10.4 功能测试
10.4.1 端到端测试
# tests/test_e2e.py
import pytest
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
import time
class TestE2E:
    """End-to-end browser tests: headless Chrome against the live server.

    Tests never call each other. Shared steps (resetting the browser session,
    logging in, creating an article) are private helpers, so each test can
    run on its own and a failure is reported by the test that actually broke.
    The WebDriver is class-scoped, so the helpers clear cookies first.
    Otherwise a session left over from an earlier test can make
    ``/auth/login`` redirect away from the login form.
    """

    @pytest.fixture(scope='class')
    def driver(self):
        """Create a headless Chrome WebDriver shared by all tests in the class."""
        options = Options()
        options.add_argument('--headless')  # no visible browser window
        options.add_argument('--no-sandbox')
        options.add_argument('--disable-dev-shm-usage')

        driver = webdriver.Chrome(options=options)
        driver.implicitly_wait(10)
        yield driver
        driver.quit()

    @staticmethod
    def _base_url(live_server):
        """Base URL of the live server.

        Supports both a ``url`` string attribute and pytest-flask's
        ``url()`` method.
        """
        url = live_server.url
        return url() if callable(url) else url

    def _open_fresh(self, driver, live_server, path):
        """Open *path* on the live server without any cookies (logged out)."""
        url = f'{self._base_url(live_server)}{path}'
        # delete_all_cookies() only affects the current domain, so load a page
        # of the app first, clear its cookies, then reload without a session.
        driver.get(url)
        driver.delete_all_cookies()
        driver.get(url)

    @staticmethod
    def _submit(driver, selector='input[type="submit"]'):
        """Click the submit control matching the CSS *selector*."""
        driver.find_element(By.CSS_SELECTOR, selector).click()

    def _login(self, driver, live_server, email='test@example.com', password='password123'):
        """Log in through the login form and wait for the dashboard."""
        self._open_fresh(driver, live_server, '/auth/login')
        driver.find_element(By.NAME, 'email').send_keys(email)
        driver.find_element(By.NAME, 'password').send_keys(password)
        self._submit(driver)
        WebDriverWait(driver, 10).until(EC.url_contains('/dashboard'))

    def _create_article(self, driver, live_server, title, content):
        """Log in, submit the article form and wait for the article page."""
        self._login(driver, live_server)
        driver.get(f'{self._base_url(live_server)}/articles/create')
        driver.find_element(By.NAME, 'title').send_keys(title)
        driver.find_element(By.NAME, 'content').send_keys(content)
        self._submit(driver)
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.CLASS_NAME, 'article-title'))
        )

    def test_user_registration_flow(self, driver, live_server):
        """Registering a new account lands on the dashboard."""
        self._open_fresh(driver, live_server, '/auth/register')

        registration = {
            'username': 'e2euser',
            'email': 'e2euser@example.com',
            'password': 'password123',
            'confirm_password': 'password123',
        }
        for field_name, value in registration.items():
            driver.find_element(By.NAME, field_name).send_keys(value)
        self._submit(driver)

        WebDriverWait(driver, 10).until(EC.url_contains('/dashboard'))
        assert '/dashboard' in driver.current_url

    def test_user_login_flow(self, driver, live_server):
        """Logging in lands on the dashboard and shows the user menu."""
        self._login(driver, live_server)

        assert '/dashboard' in driver.current_url
        user_menu = driver.find_element(By.CLASS_NAME, 'user-menu')
        assert user_menu.is_displayed()

    def test_article_creation_flow(self, driver, live_server):
        """A created article is shown with its title and content."""
        self._create_article(
            driver, live_server,
            'E2E测试文章', '这是一篇通过E2E测试创建的文章内容。',
        )

        article_title = driver.find_element(By.CLASS_NAME, 'article-title')
        assert article_title.text == 'E2E测试文章'
        article_content = driver.find_element(By.CLASS_NAME, 'article-content')
        assert 'E2E测试创建的文章' in article_content.text

    def test_comment_system_flow(self, driver, live_server):
        """A comment posted on a new article appears in the comment list."""
        self._create_article(
            driver, live_server,
            'E2E测试文章', '这是一篇通过E2E测试创建的文章内容。',
        )

        driver.find_element(By.NAME, 'comment_content').send_keys('这是一条E2E测试评论。')
        self._submit(driver, '.comment-form input[type="submit"]')

        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.CLASS_NAME, 'comment-item'))
        )
        comment_items = driver.find_elements(By.CLASS_NAME, 'comment-item')
        assert len(comment_items) > 0
        assert 'E2E测试评论' in comment_items[-1].text

    def test_search_functionality(self, driver, live_server):
        """Searching for '测试' shows at least one result."""
        # Make sure at least one article matches the search term.
        self._create_article(
            driver, live_server,
            'E2E测试文章', '这是一篇通过E2E测试创建的文章内容。',
        )
        driver.get(f'{self._base_url(live_server)}/')

        driver.find_element(By.NAME, 'search').send_keys('测试')
        driver.find_element(By.CSS_SELECTOR, '.search-form button').click()

        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.CLASS_NAME, 'search-results'))
        )
        search_results = driver.find_element(By.CLASS_NAME, 'search-results')
        assert search_results.is_displayed()
        result_items = driver.find_elements(By.CLASS_NAME, 'result-item')
        assert len(result_items) > 0
@pytest.fixture(scope='session')
def live_server(app):
    """Serve ``app`` from a background thread for browser and HTTP tests.

    Yields an object with:
      * ``url``: the base URL, e.g. ``http://127.0.0.1:54321``. Tests format
        it as ``f'{live_server.url}/path'``, and a bare werkzeug server has
        no such attribute.
      * ``server``: the underlying werkzeug server.

    Port 0 lets the OS choose a free port, so nothing already listening on
    5000 can break the tests. The server is threaded, so concurrent-request
    tests are not serialised.
    """
    import threading
    from types import SimpleNamespace
    from werkzeug.serving import make_server

    server = make_server('127.0.0.1', 0, app, threaded=True)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        yield SimpleNamespace(
            url=f'http://127.0.0.1:{server.server_port}',
            server=server,
        )
    finally:
        server.shutdown()
        thread.join(timeout=5)
        server.server_close()
10.4.2 表单测试
# tests/test_forms.py
import pytest
from app.forms import (
LoginForm, RegisterForm, ArticleForm,
CommentForm, ProfileForm, PasswordChangeForm
)
from wtforms.validators import ValidationError
class TestLoginForm:
    """LoginForm validation tests.

    Forms are built inside ``app.test_request_context()`` because FlaskForm
    reads ``flask.request`` when it is constructed. Error assertions use
    substring matching because WTForms' built-in messages end with a period
    (e.g. "This field is required."). An exact ``in field.errors`` membership
    test against the bare text would therefore never match.
    """

    @staticmethod
    def _has_error(field, text):
        """Return True if any validation error on *field* contains *text*."""
        return any(text in str(error) for error in field.errors)

    def test_valid_login_form(self, app):
        """A well-formed e-mail plus a password validates."""
        with app.test_request_context():
            form_data = {
                'email': 'test@example.com',
                'password': 'password123'
            }
            form = LoginForm(data=form_data)
            assert form.validate()

    def test_invalid_email_format(self, app):
        """A malformed e-mail is rejected."""
        with app.test_request_context():
            form_data = {
                'email': 'invalid-email',
                'password': 'password123'
            }
            form = LoginForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.email, 'Invalid email address')

    def test_missing_password(self, app):
        """An empty password is rejected as required."""
        with app.test_request_context():
            form_data = {
                'email': 'test@example.com',
                'password': ''
            }
            form = LoginForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.password, 'This field is required')
class TestRegisterForm:
    """RegisterForm validation tests.

    Forms are built inside ``app.test_request_context()`` because FlaskForm
    reads ``flask.request`` when constructed. The context also provides the
    app context that the duplicate-username check presumably needs for its
    database lookup against the seeded 'testuser'. Error assertions use
    substring matching because WTForms' built-in messages end with a period.
    """

    @staticmethod
    def _has_error(field, text):
        """Return True if any validation error on *field* contains *text*."""
        return any(text in str(error) for error in field.errors)

    def test_valid_register_form(self, app):
        """Fresh username/e-mail with matching strong passwords validates."""
        with app.test_request_context():
            form_data = {
                'username': 'newuser',
                'email': 'newuser@example.com',
                'password': 'password123',
                'confirm_password': 'password123'
            }
            form = RegisterForm(data=form_data)
            assert form.validate()

    def test_password_mismatch(self, app):
        """Differing password and confirmation are rejected."""
        with app.test_request_context():
            form_data = {
                'username': 'newuser',
                'email': 'newuser@example.com',
                'password': 'password123',
                'confirm_password': 'different_password'
            }
            form = RegisterForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.confirm_password, 'Passwords must match')

    def test_weak_password(self, app):
        """A password shorter than 8 characters is rejected."""
        with app.test_request_context():
            form_data = {
                'username': 'newuser',
                'email': 'newuser@example.com',
                'password': '123',  # too short
                'confirm_password': '123'
            }
            form = RegisterForm(data=form_data)
            assert not form.validate()
            assert self._has_error(
                form.password, 'Field must be at least 8 characters long'
            )

    def test_duplicate_username(self, app):
        """An already-registered username is rejected."""
        with app.test_request_context():
            form_data = {
                'username': 'testuser',  # seeded by create_test_data
                'email': 'newuser@example.com',
                'password': 'password123',
                'confirm_password': 'password123'
            }
            form = RegisterForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.username, 'Username already exists')
class TestArticleForm:
    """ArticleForm validation tests.

    Forms are built inside ``app.test_request_context()`` because FlaskForm
    reads ``flask.request`` when constructed. Error assertions use substring
    matching because WTForms' built-in messages end with a period.
    """

    @staticmethod
    def _has_error(field, text):
        """Return True if any validation error on *field* contains *text*."""
        return any(text in str(error) for error in field.errors)

    def test_valid_article_form(self, app):
        """Title, sufficiently long content and tags validate."""
        with app.test_request_context():
            form_data = {
                'title': '测试文章标题',
                'content': '这是文章的内容部分,应该足够长以满足验证要求。',
                'tags': 'python, flask, 测试'
            }
            form = ArticleForm(data=form_data)
            assert form.validate()

    def test_empty_title(self, app):
        """An empty title is rejected as required."""
        with app.test_request_context():
            form_data = {
                'title': '',
                'content': '文章内容',
                'tags': 'tag1, tag2'
            }
            form = ArticleForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.title, 'This field is required')

    def test_content_too_short(self, app):
        """Content below the minimum length is rejected."""
        with app.test_request_context():
            form_data = {
                'title': '标题',
                'content': '短',  # too short
                'tags': 'tag1'
            }
            form = ArticleForm(data=form_data)
            assert not form.validate()
            assert self._has_error(
                form.content, 'Content must be at least 10 characters long'
            )
class TestCommentForm:
    """CommentForm validation tests.

    Forms are built inside ``app.test_request_context()`` because FlaskForm
    reads ``flask.request`` when constructed. Error assertions use substring
    matching because WTForms' built-in messages end with a period
    (e.g. "Field cannot be longer than 1000 characters.").
    """

    @staticmethod
    def _has_error(field, text):
        """Return True if any validation error on *field* contains *text*."""
        return any(text in str(error) for error in field.errors)

    def test_valid_comment_form(self, app):
        """Ordinary comment text validates."""
        with app.test_request_context():
            form_data = {
                'content': '这是一条有效的评论内容。'
            }
            form = CommentForm(data=form_data)
            assert form.validate()

    def test_empty_comment(self, app):
        """An empty comment is rejected as required."""
        with app.test_request_context():
            form_data = {
                'content': ''
            }
            form = CommentForm(data=form_data)
            assert not form.validate()
            assert self._has_error(form.content, 'This field is required')

    def test_comment_too_long(self, app):
        """A comment over the 1000-character limit is rejected."""
        with app.test_request_context():
            form_data = {
                'content': 'x' * 1001  # exceeds the 1000-character limit
            }
            form = CommentForm(data=form_data)
            assert not form.validate()
            assert self._has_error(
                form.content, 'Field cannot be longer than 1000 characters'
            )
10.5 性能测试
10.5.1 负载测试
# tests/test_performance.py
import pytest
import time
import threading
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from app import create_app, db
from app.models import User, Article
class TestPerformance:
    """Coarse performance checks with wall-clock thresholds.

    Durations are measured with ``time.perf_counter``, a monotonic
    high-resolution clock. ``time.time`` can jump when the system clock is
    adjusted.
    """

    @staticmethod
    def _base_url(live_server):
        """Return the base URL of the running live server.

        pytest-flask's ``live_server`` exposes ``url()`` as a method, while a
        hand-written fixture may expose a plain ``url`` string. Formatting
        the method object into a URL would produce garbage, so both forms are
        accepted.
        """
        url = live_server.url
        return url() if callable(url) else url

    def test_database_query_performance(self, app):
        """A filtered, ordered, limited query over 1000 users takes < 0.5 s."""
        with app.app_context():
            # Create a batch of test rows
            users = [
                User(username=f'perfuser{i}', email=f'perfuser{i}@example.com')
                for i in range(1000)
            ]
            db.session.add_all(users)
            db.session.commit()

            # Time a representative query
            start_time = time.perf_counter()
            result = User.query.filter(
                User.username.like('perfuser%')
            ).order_by(User.created_at.desc()).limit(50).all()
            query_time = time.perf_counter() - start_time

            assert len(result) == 50
            assert query_time < 0.5  # should finish within 0.5 s

    def test_api_response_time(self, client):
        """Each API endpoint answers (200 or 401) within 1 s."""
        endpoints = [
            '/api/articles',
            '/api/users',
            '/api/auth/status'
        ]
        for endpoint in endpoints:
            start_time = time.perf_counter()
            response = client.get(endpoint)
            response_time = time.perf_counter() - start_time

            assert response.status_code in [200, 401]  # may require auth
            assert response_time < 1.0, f'{endpoint} took {response_time:.3f}s'

    def test_concurrent_requests(self, live_server):
        """50 requests over 10 threads: >= 95% succeed with avg time < 2 s."""
        def make_request(url):
            """Send one GET and return its status, duration and success flag."""
            try:
                start_time = time.perf_counter()
                response = requests.get(url, timeout=10)
                end_time = time.perf_counter()
                return {
                    'status_code': response.status_code,
                    'response_time': end_time - start_time,
                    'success': response.status_code == 200
                }
            except requests.RequestException as e:
                # Network-level failures count as failed requests; anything
                # else is a bug and propagates through executor.map.
                return {
                    'status_code': 0,
                    'response_time': 0,
                    'success': False,
                    'error': str(e)
                }

        # Concurrency parameters
        num_threads = 10
        num_requests_per_thread = 5
        url = f'{self._base_url(live_server)}/api/articles'

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            results = list(executor.map(
                make_request, [url] * (num_threads * num_requests_per_thread)
            ))

        # Analyse results
        successful_requests = [r for r in results if r['success']]
        failed_requests = [r for r in results if not r['success']]
        # Fail with a readable message rather than a ZeroDivisionError below.
        assert successful_requests, (
            f'all {len(results)} requests failed, e.g. {failed_requests[:3]}'
        )
        success_rate = len(successful_requests) / len(results)
        avg_response_time = (
            sum(r['response_time'] for r in successful_requests)
            / len(successful_requests)
        )

        assert success_rate >= 0.95  # at least 95% success
        assert avg_response_time < 2.0  # average under 2 s
        assert len(failed_requests) < 3  # fewer than 3 failures

    def test_memory_usage(self, app):
        """Inserting and loading 10,000 users grows RSS by less than 100 MB."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        with app.app_context():
            # Memory-heavy workload: bulk insert, then load everything
            large_data = [
                User(username=f'memuser{i}', email=f'memuser{i}@example.com')
                for i in range(10000)
            ]
            db.session.add_all(large_data)
            db.session.commit()

            all_users = User.query.all()

            final_memory = process.memory_info().rss / 1024 / 1024  # MB
            memory_increase = final_memory - initial_memory

            assert memory_increase < 100  # growth should stay under 100 MB
            assert len(all_users) >= 10000
class TestLoadTesting:
    """Load test: worker threads repeatedly request the article list."""

    @staticmethod
    def _base_url(live_server):
        """Return the live server base URL.

        Accepts pytest-flask's ``url()`` method or a plain ``url`` string.
        """
        url = live_server.url
        return url() if callable(url) else url

    def test_stress_test(self, live_server):
        """5 workers x 20 requests: >= 90% succeed, avg < 3 s, max < 10 s."""
        url = f'{self._base_url(live_server)}/api/articles'

        def worker(worker_id, num_requests, results):
            """Send *num_requests* sequential requests and append one record each to *results*."""
            worker_results = []
            for i in range(num_requests):
                timestamp = time.time()  # wall-clock start, for the report
                try:
                    start_time = time.perf_counter()
                    # A timeout keeps one hung request from stalling the test forever.
                    response = requests.get(url, timeout=10)
                    end_time = time.perf_counter()
                    worker_results.append({
                        'worker_id': worker_id,
                        'request_id': i,
                        'status_code': response.status_code,
                        'response_time': end_time - start_time,
                        'timestamp': timestamp
                    })
                    # Short pause between requests
                    time.sleep(0.1)
                except Exception as e:
                    # Broad on purpose: an exception escaping a plain Thread
                    # would be lost; record it as a failed request instead.
                    worker_results.append({
                        'worker_id': worker_id,
                        'request_id': i,
                        'status_code': 0,
                        'response_time': 0,
                        'error': str(e),
                        'timestamp': time.time()
                    })
            results.extend(worker_results)

        # Stress-test parameters
        num_workers = 5
        requests_per_worker = 20
        results = []

        threads = [
            threading.Thread(target=worker, args=(worker_id, requests_per_worker, results))
            for worker_id in range(num_workers)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Analyse results
        total_requests = len(results)
        successful_requests = [r for r in results if r.get('status_code') == 200]
        failed_requests = [r for r in results if r.get('status_code') != 200]
        # Fail with a readable message rather than a ZeroDivisionError / max() of empty.
        assert successful_requests, (
            f'no successful requests out of {total_requests}, e.g. {failed_requests[:3]}'
        )

        success_rate = len(successful_requests) / total_requests
        response_times = [r['response_time'] for r in successful_requests]
        avg_response_time = sum(response_times) / len(response_times)
        max_response_time = max(response_times)

        # Report
        print("\n压力测试报告:")
        print(f"总请求数: {total_requests}")
        print(f"成功请求数: {len(successful_requests)}")
        print(f"失败请求数: {len(failed_requests)}")
        print(f"成功率: {success_rate:.2%}")
        print(f"平均响应时间: {avg_response_time:.3f}秒")
        print(f"最大响应时间: {max_response_time:.3f}秒")

        assert success_rate >= 0.90  # at least 90% success
        assert avg_response_time < 3.0  # average under 3 s
        assert max_response_time < 10.0  # worst case under 10 s
class TestBenchmark:
    """Benchmark of bulk insert / query / update / delete on the users table."""

    def test_database_operations_benchmark(self, app):
        """Time CRUD on 1000 rows and assert generous upper bounds."""
        with app.app_context():
            # Insert
            start_time = time.perf_counter()
            users = [
                User(username=f'benchuser{i}', email=f'benchuser{i}@example.com')
                for i in range(1000)
            ]
            db.session.add_all(users)
            db.session.commit()
            insert_time = time.perf_counter() - start_time

            bench_rows = User.username.like('benchuser%')

            # Query
            start_time = time.perf_counter()
            all_users = User.query.filter(bench_rows).all()
            query_time = time.perf_counter() - start_time

            # Update. Every row gets a distinct address: setting all rows to
            # one e-mail would violate the unique e-mail constraint that
            # TestUserModel.test_user_validation expects.
            # synchronize_session=False: the default 'evaluate' strategy of
            # SQLAlchemy < 2.0 cannot evaluate a LIKE criterion in Python.
            start_time = time.perf_counter()
            User.query.filter(bench_rows).update(
                {'email': 'updated_' + User.username + '@example.com'},
                synchronize_session=False,
            )
            db.session.commit()
            update_time = time.perf_counter() - start_time

            # Delete
            start_time = time.perf_counter()
            User.query.filter(bench_rows).delete(synchronize_session=False)
            db.session.commit()
            delete_time = time.perf_counter() - start_time

            # Report
            print("\n数据库操作基准测试报告:")
            print(f"插入1000条记录: {insert_time:.3f}秒")
            print(f"查询1000条记录: {query_time:.3f}秒")
            print(f"更新1000条记录: {update_time:.3f}秒")
            print(f"删除1000条记录: {delete_time:.3f}秒")

            assert insert_time < 5.0  # insert within 5 s
            assert query_time < 1.0  # query within 1 s
            assert update_time < 3.0  # update within 3 s
            assert delete_time < 2.0  # delete within 2 s
            assert len(all_users) == 1000
10.6 调试技巧
10.6.1 Flask调试模式
# app/__init__.py
from flask import Flask
import logging
from logging.handlers import RotatingFileHandler
import os
def create_app(config_name=None):
    """Application factory.

    Args:
        config_name: Either an environment name, or a mapping of config
            overrides (such as the dict built by the test fixtures).
            Environment names: ``'development'`` enables DEBUG,
            ``'testing'`` enables TESTING, and anything else (including
            ``None``) disables both. For a mapping, DEBUG and TESTING default
            to False and every key in the mapping is then applied, so it may
            set them itself.

    Returns:
        The configured Flask application with logging set up.
    """
    from collections.abc import Mapping

    app = Flask(__name__)

    overrides = {}
    if isinstance(config_name, Mapping):
        # Previously a dict fell through to the production branch and its
        # settings (database URI, TESTING, ...) were silently ignored.
        overrides = dict(config_name)
        config_name = None

    # Debug / testing flags
    if config_name == 'development':
        debug, testing = True, False
    elif config_name == 'testing':
        debug, testing = False, True
    else:
        debug, testing = False, False
    app.config['DEBUG'] = debug
    app.config['TESTING'] = testing
    app.config.update(overrides)

    # Logging
    configure_logging(app)
    return app
def configure_logging(app):
    """Attach log handlers to ``app.logger`` according to the app's mode.

    - Production (neither ``app.debug`` nor ``app.testing``): INFO records go
      to ``logs/flask_app.log``, rotated at 10,240,000 bytes (~10 MB) with
      10 backups.
    - Development/testing: DEBUG records go to a console handler. This
      handler replaces Flask's default stderr handler so lines are not
      printed twice.

    Flask's ``app.logger`` is the logger named after the app, so it is shared
    by every app created with the same import name. Each handler added here
    is tagged and installed only once. Repeated ``create_app()`` calls, which
    are common in tests, therefore do not duplicate log output or leak open
    log files.
    """
    logger = app.logger

    def _has_handler(tag):
        # True if a handler previously installed here with *tag* is attached.
        return any(
            getattr(handler, '_configure_logging_tag', None) == tag
            for handler in logger.handlers
        )

    if not app.debug and not app.testing:
        # Production logging; exist_ok avoids a race with other processes.
        os.makedirs('logs', exist_ok=True)
        if not _has_handler('file'):
            file_handler = RotatingFileHandler(
                'logs/flask_app.log',
                maxBytes=10240000,  # ~10 MB
                backupCount=10
            )
            file_handler.setFormatter(logging.Formatter(
                '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
            ))
            file_handler.setLevel(logging.INFO)
            file_handler._configure_logging_tag = 'file'
            logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)
        logger.info('Flask application startup')
    else:
        # Development / testing logging
        logger.setLevel(logging.DEBUG)
        if not _has_handler('console'):
            from flask.logging import default_handler

            # Our console handler replaces Flask's default stderr handler.
            logger.removeHandler(default_handler)
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.DEBUG)
            console_handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            console_handler._configure_logging_tag = 'console'
            logger.addHandler(console_handler)
10.6.2 调试工具和技巧
# app/utils/debug.py
import functools
import time
import traceback
from flask import current_app, request, g
import logging
def debug_timer(func):
    """Decorator that logs how long each call to *func* takes.

    Successful calls are logged at DEBUG level. Failures are logged at ERROR
    level and the exception is re-raised unchanged. Timing uses
    ``time.perf_counter``, a monotonic high-resolution clock that is
    unaffected by system clock changes. Logging goes through
    ``current_app.logger``, so calls need an active Flask app context.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            elapsed = time.perf_counter() - start
            current_app.logger.error(
                'Function %s failed after %.4f seconds: %s',
                func.__name__, elapsed, exc
            )
            raise
        elapsed = time.perf_counter() - start
        current_app.logger.debug(
            'Function %s executed in %.4f seconds', func.__name__, elapsed
        )
        return result
    return wrapper
def debug_request_info():
    """Log the current request's URL, method, headers and body at DEBUG level.

    Only active when ``current_app.debug`` is true. Credential-bearing
    headers and password form fields are masked, so tokens, cookies and
    passwords do not end up in log files. A malformed JSON body is logged as
    ``None`` instead of raising a 400 error from inside a logging helper,
    which would change the request's outcome.
    """
    if not current_app.debug:
        return

    logger = current_app.logger
    masked_headers = {'authorization', 'proxy-authorization', 'cookie'}
    headers = {
        name: '***' if name.lower() in masked_headers else value
        for name, value in request.headers.items()
    }

    logger.debug(f'Request URL: {request.url}')
    logger.debug(f'Request Method: {request.method}')
    logger.debug(f'Request Headers: {headers}')

    if request.is_json:
        logger.debug(f'Request JSON: {request.get_json(silent=True)}')
    elif request.form:
        form = {
            key: '***' if 'password' in key.lower() else value
            for key, value in request.form.items()
        }
        logger.debug(f'Request Form: {form}')
def debug_database_queries():
    """Log every SQL statement, its parameters and its duration (debug mode only).

    The listeners are attached to the SQLAlchemy ``Engine`` class, which
    affects every engine in the process. They are therefore installed at
    most once, however often this function is called. Per-query start times
    are kept in the connection's ``info`` dict instead of ``flask.g``, and
    the logger is captured up front. This way queries issued outside an
    application context (CLI commands, startup code) do not raise
    ``RuntimeError``.

    NOTE(review): the captured logger belongs to the first debug app that
    calls this function.
    """
    if not current_app.debug:
        return
    if getattr(debug_database_queries, '_installed', False):
        return

    from sqlalchemy import event
    from sqlalchemy.engine import Engine

    logger = current_app.logger

    @event.listens_for(Engine, "before_cursor_execute")
    def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        # A stack per connection tolerates nested executions.
        conn.info.setdefault('query_start_time', []).append(time.perf_counter())
        logger.debug(f'SQL Query: {statement}')
        logger.debug(f'Parameters: {parameters}')

    @event.listens_for(Engine, "after_cursor_execute")
    def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        total = time.perf_counter() - conn.info['query_start_time'].pop()
        logger.debug(f'Query executed in {total:.4f} seconds')

    debug_database_queries._installed = True
class DebugMiddleware:
    """Per-request timing and logging hooks, installed only when ``app.debug``.

    Supports both ``DebugMiddleware(app)`` and the deferred
    ``DebugMiddleware()`` + ``init_app(app)`` extension pattern.
    """

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the request hooks on *app* if it runs in debug mode."""
        if app.debug:
            app.before_request(self.before_request)
            app.after_request(self.after_request)
            app.teardown_request(self.teardown_request)

    def before_request(self):
        """Record the start time and log the incoming request."""
        g.start_time = time.perf_counter()
        debug_request_info()

    def after_request(self, response):
        """Log the request duration and expose it as an ``X-Response-Time`` header."""
        start_time = g.get('start_time')
        if start_time is None:
            # This hook's before_request never ran. That happens when an
            # earlier before_request handler returned a response or raised.
            # There is no start time to measure from.
            return response

        total_time = time.perf_counter() - start_time
        current_app.logger.debug(
            f'Request completed in {total_time:.4f} seconds '
            f'with status {response.status_code}'
        )
        # Debug header
        response.headers['X-Response-Time'] = f'{total_time:.4f}s'
        return response

    def teardown_request(self, exception):
        """Log an unhandled exception that ended the request, with its traceback."""
        if exception is not None:
            current_app.logger.error(
                f'Request failed with exception: {str(exception)}'
            )
            # Teardown runs outside the original ``except`` block, so
            # traceback.format_exc() would only print "NoneType: None".
            # Format the exception object itself instead.
            current_app.logger.error(''.join(traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )))
def setup_debug_toolbar(app):
    """Enable Flask-DebugToolbar on *app* when it runs in debug mode.

    Logs through ``app.logger`` rather than ``current_app``, so it can be
    called during app setup before an application context exists. A missing
    Flask-DebugToolbar package only produces a warning.
    """
    if not app.debug:
        return
    try:
        from flask_debugtoolbar import DebugToolbarExtension
    except ImportError:
        app.logger.warning('Flask-DebugToolbar not installed')
        return

    app.config['DEBUG_TB_INTERCEPT_REDIRECTS'] = False
    app.config['DEBUG_TB_PROFILER_ENABLED'] = True
    DebugToolbarExtension(app)
    app.logger.info('Debug toolbar enabled')
# Usage example
@debug_timer
def slow_function():
    """Simulated slow operation: sleep for one second, then return a marker string."""
    time.sleep(1)
    outcome = "完成"
    return outcome
10.6.3 错误处理和异常调试
# app/utils/error_handler.py
from flask import current_app, request, jsonify
import traceback
import sys
from datetime import datetime
class ErrorHandler:
    """Registers 404, 500 and catch-all exception handlers on a Flask app.

    Every handled error is logged and passed to ``send_to_monitoring_service``.
    The response is JSON when the request carries a JSON body
    (``request.is_json``) and an HTML error page otherwise.
    """

    def __init__(self, app=None):
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the error handlers on *app*."""
        app.register_error_handler(404, self.handle_404)
        app.register_error_handler(500, self.handle_500)
        app.register_error_handler(Exception, self.handle_exception)

    def handle_404(self, error):
        """Handle 404 Not Found."""
        # render_template was used but never imported in this module.
        from flask import render_template

        self.log_error(error, 404)
        if request.is_json:
            return jsonify({
                'error': 'Resource not found',
                'status_code': 404
            }), 404
        return render_template('errors/404.html'), 404

    def handle_500(self, error):
        """Handle 500 Internal Server Error."""
        from flask import render_template

        self.log_error(error, 500)
        if request.is_json:
            return jsonify({
                'error': 'Internal server error',
                'status_code': 500
            }), 500
        return render_template('errors/500.html'), 500

    def handle_exception(self, error):
        """Handle any exception not matched by a more specific handler.

        Flask also routes HTTP errors without a dedicated handler (401, 403,
        405, ...) to a generic ``Exception`` handler. These are returned
        unchanged so the client gets the real status instead of a 500.

        In debug mode, JSON requests get the error and traceback. Other
        requests re-raise the exception so the interactive debugger shows it.
        """
        from flask import render_template
        from werkzeug.exceptions import HTTPException

        if isinstance(error, HTTPException):
            return error

        self.log_error(error, 500)

        if current_app.debug:
            # Development: detailed error information
            if request.is_json:
                return jsonify({
                    'error': str(error),
                    'traceback': self._format_traceback(error),
                    'status_code': 500
                }), 500
            # Re-raise so the debugger shows the original error
            raise error

        # Production: generic error message
        if request.is_json:
            return jsonify({
                'error': 'An unexpected error occurred',
                'status_code': 500
            }), 500
        return render_template('errors/500.html'), 500

    @staticmethod
    def _format_traceback(error):
        """Format *error*'s own traceback.

        This does not depend on whether an exception is currently being
        handled, which is what ``format_exc()`` relies on.
        """
        return ''.join(traceback.format_exception(
            type(error), error, error.__traceback__
        ))

    def log_error(self, error, status_code):
        """Log request details, the error and its traceback, then forward them to monitoring."""
        from datetime import timezone

        error_info = {
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'url': request.url,
            'method': request.method,
            'ip': request.remote_addr,
            'user_agent': request.headers.get('User-Agent'),
            'error': str(error),
            'status_code': status_code,
            'traceback': self._format_traceback(error)
        }
        current_app.logger.error(f'Error {status_code}: {error_info}')
        # Hook for an external error-monitoring service
        self.send_to_monitoring_service(error_info)

    def send_to_monitoring_service(self, error_info):
        """Forward error details to a monitoring service (no-op placeholder)."""
        # Integrate Sentry, Rollbar, etc. here
        pass
def setup_error_monitoring(app):
    """Initialise Sentry error monitoring for *app* when possible.

    The DSN comes from ``app.config['SENTRY_DSN']`` and falls back to the
    ``SENTRY_DSN`` environment variable, which is also the SDK's own
    fallback. With no DSN, or without sentry-sdk installed, the function only
    logs a warning. This avoids claiming monitoring is "enabled" when no
    events will be sent. Logging uses ``app.logger``, so this can be called
    before an application context is pushed.
    """
    import os

    dsn = app.config.get('SENTRY_DSN') or os.environ.get('SENTRY_DSN')
    if not dsn:
        app.logger.warning('SENTRY_DSN not configured; Sentry error monitoring disabled')
        return

    try:
        import sentry_sdk
        from sentry_sdk.integrations.flask import FlaskIntegration
        from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
    except ImportError:
        app.logger.warning('Sentry SDK not installed')
        return

    sentry_sdk.init(
        dsn=dsn,
        integrations=[
            FlaskIntegration(),
            SqlalchemyIntegration()
        ],
        traces_sample_rate=0.1,
        environment=app.config.get('ENVIRONMENT', 'development')
    )
    app.logger.info('Sentry error monitoring enabled')
10.7 测试覆盖率
10.7.1 代码覆盖率配置
# .coveragerc
# coverage.py configuration. Continuation lines of multi-line values
# (omit, exclude_lines) must be indented; otherwise they are parsed as new keys.

[run]
# Measure only the application package.
source = app
omit =
    app/__init__.py
    app/config.py
    */venv/*
    */virtualenv/*
    */tests/*
    */migrations/*

[report]
# Setting exclude_lines replaces coverage.py's default exclusion list, so the
# default "pragma: no cover" pattern is repeated explicitly. Entries are regexes.
exclude_lines =
    pragma: no cover
    def __repr__
    raise AssertionError
    raise NotImplementedError
    if __name__ == .__main__.:
    class .*\(Protocol\):
    @(abc\.)?abstractmethod

[html]
directory = htmlcov

[xml]
output = coverage.xml
# tests/test_coverage.py
import pytest
import coverage
import os
class TestCoverage:
    """Enforce coverage thresholds using the saved ``.coverage`` data file.

    NOTE(review): this reads data written by a previous ``coverage run``. When the
    suite itself runs under coverage, the current run's data has not been saved
    yet, so these checks reflect the previous run.
    """

    MINIMUM_TOTAL = 80.0
    MINIMUM_CRITICAL = 90.0
    CRITICAL_MODULES = (
        'app/models.py',
        'app/views.py',
        'app/api.py',
        'app/auth.py',
    )

    @staticmethod
    def _load_coverage():
        """Return a Coverage object with saved data loaded; skip if there is none."""
        cov = coverage.Coverage()
        cov.load()
        if not cov.get_data().measured_files():
            pytest.skip('没有可用的覆盖率数据,请先运行 coverage run')
        return cov

    @staticmethod
    def _percent_covered(statements, missing):
        """Percent of *statements* not in *missing*, or None if there are none.

        ``statements`` from ``Coverage.analysis2`` already includes the missing
        lines, so it alone is the denominator.
        """
        if not statements:
            return None
        return (len(statements) - len(missing)) / len(statements) * 100

    def test_minimum_coverage(self):
        """Total coverage must meet the minimum."""
        cov = self._load_coverage()
        # report() prints the table and returns the total percentage.
        total_coverage = cov.report(show_missing=False)
        assert total_coverage >= self.MINIMUM_TOTAL, (
            f"代码覆盖率 {total_coverage:.1f}% 低于要求的{self.MINIMUM_TOTAL:.0f}%"
        )

    def test_critical_modules_coverage(self):
        """Each existing critical module must meet the stricter minimum."""
        cov = self._load_coverage()
        failures = []
        for module in self.CRITICAL_MODULES:
            if not os.path.exists(module):
                continue
            # analysis2 -> (filename, statements, excluded, missing, missing_text)
            _, statements, _, missing, _ = cov.analysis2(module)
            percent = self._percent_covered(statements, missing)
            if percent is not None and percent < self.MINIMUM_CRITICAL:
                failures.append(
                    f"关键模块 {module} 覆盖率 {percent:.1f}% "
                    f"低于要求的{self.MINIMUM_CRITICAL:.0f}%"
                )
        # Report every failing module at once instead of stopping at the first.
        assert not failures, "\n".join(failures)
10.7.2 覆盖率报告生成
#!/bin/bash
# scripts/run_tests_with_coverage.sh
# Run the test suite under coverage and produce terminal, HTML and XML reports.
# Exits non-zero if the tests fail OR total coverage is below 80%.
# (The shebang must be the first line of the file to take effect.)

# Clear previous coverage data.
coverage erase

# Run the tests under coverage. Remember pytest's exit status so reports are
# still generated for a failing run, but the script fails at the end.
coverage run -m pytest tests/ -v
test_status=$?

# Terminal report. printf interprets \n; bash's builtin echo does not without -e.
printf '\n=== 覆盖率报告 ===\n'
coverage report -m

# HTML report.
coverage html
printf '\nHTML报告已生成到 htmlcov/ 目录\n'

# XML report (for CI/CD).
coverage xml
echo "XML报告已生成: coverage.xml"

# Enforce the coverage threshold.
coverage report --fail-under=80
coverage_status=$?

if [ "$test_status" -ne 0 ]; then
    exit "$test_status"
fi
exit "$coverage_status"
# scripts/coverage_analysis.py
import coverage
import json
import os
from datetime import datetime
def generate_coverage_report(cov=None, source_dir='app', output_path='coverage_report.json'):
    """Build a per-module coverage report, save it as JSON and print a summary.

    Args:
        cov: Object exposing coverage.py's ``analysis2(path)``. Defaults to a
            ``coverage.Coverage()`` with its saved data loaded; injectable for tests.
        source_dir: Directory walked for ``*.py`` files. Files whose names start
            with ``__`` (e.g. ``__init__.py``) are skipped.
        output_path: Path of the JSON report to write.

    Returns:
        dict: The report that was written, with modules sorted from lowest to
        highest coverage.
    """
    from datetime import timezone  # local: the script header only imports datetime

    if cov is None:
        cov = coverage.Coverage()
        cov.load()

    summary = {
        'total_statements': 0,
        'covered_statements': 0,
        'missing_statements': 0
    }
    report_data = {
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'total_coverage': 0,
        'modules': [],
        'summary': summary
    }

    for root, _dirs, files in os.walk(source_dir):
        for file in sorted(files):
            if not file.endswith('.py') or file.startswith('__'):
                continue
            module_path = os.path.join(root, file)
            try:
                # analysis2 -> (filename, statements, excluded, missing, missing_text).
                # ``statements`` is every executable line and already includes the
                # missing ones, so it alone is the denominator.
                _, statements, _, missing, _ = cov.analysis2(module_path)
            except Exception as e:
                # Best effort: report the failure and keep analysing other modules.
                print(f"分析模块 {module_path} 时出错: {e}")
                continue

            total_statements = len(statements)
            if total_statements == 0:
                continue
            missing_count = len(missing)
            covered_statements = total_statements - missing_count

            report_data['modules'].append({
                'module': module_path,
                'coverage_percent': round(covered_statements / total_statements * 100, 2),
                'total_statements': total_statements,
                'covered_statements': covered_statements,
                'missing_statements': missing_count,
                'missing_lines': list(missing)
            })
            summary['total_statements'] += total_statements
            summary['covered_statements'] += covered_statements
            summary['missing_statements'] += missing_count

    if summary['total_statements'] > 0:
        report_data['total_coverage'] = round(
            summary['covered_statements'] / summary['total_statements'] * 100, 2
        )

    # Lowest coverage first; the module path breaks ties for a deterministic order.
    report_data['modules'].sort(key=lambda m: (m['coverage_percent'], m['module']))

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(report_data, f, indent=2, ensure_ascii=False)

    print("\n=== 覆盖率分析报告 ===")
    print(f"总覆盖率: {report_data['total_coverage']}%")
    print(f"总语句数: {summary['total_statements']}")
    print(f"已覆盖语句: {summary['covered_statements']}")
    print(f"未覆盖语句: {summary['missing_statements']}")
    print("\n=== 覆盖率最低的模块 ===")
    for module in report_data['modules'][:5]:
        print(f"{module['module']}: {module['coverage_percent']}%")

    return report_data


if __name__ == '__main__':
    generate_coverage_report()
10.8 持续集成
10.8.1 GitHub Actions配置
# .github/workflows/test.yml
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    # publish-unit-test-result-action needs write access to checks and PR comments.
    permissions:
      contents: read
      checks: write
      pull-requests: write
    strategy:
      matrix:
        # Quote every version: an unquoted 3.10 would be parsed by YAML as 3.1.
        python-version: ['3.8', '3.9', '3.10', '3.11']

    services:
      postgres:
        image: postgres:13
        env:
          POSTGRES_PASSWORD: postgres
          POSTGRES_DB: test_db
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432

      redis:
        image: redis:6
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 6379:6379

    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Cache pip dependencies
        uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          # Include the Python version so matrix jobs do not share or clobber one cache.
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements*.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-${{ matrix.python-version }}-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-test.txt

      - name: Set up environment variables
        run: |
          echo "DATABASE_URL=postgresql://postgres:postgres@localhost:5432/test_db" >> "$GITHUB_ENV"
          echo "REDIS_URL=redis://localhost:6379/0" >> "$GITHUB_ENV"
          echo "FLASK_ENV=testing" >> "$GITHUB_ENV"

      - name: Run linting
        run: |
          flake8 app tests
          black --check app tests
          isort --check-only app tests

      - name: Run type checking
        run: |
          mypy app

      - name: Run security checks
        run: |
          bandit -r app
          safety check

      - name: Run tests with coverage
        # Emit the JUnit report from this run instead of re-running the whole
        # suite in a separate step just to produce test-results.xml.
        run: |
          coverage run -m pytest tests/ -v --junitxml=test-results.xml
          coverage report
          coverage xml

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          files: ./coverage.xml
          flags: unittests
          name: codecov-umbrella
          fail_ci_if_error: true

      - name: Run integration tests
        run: |
          pytest tests/integration/ -v

      - name: Run performance tests
        run: |
          pytest tests/performance/ -v --benchmark-only

      - name: Publish test results
        uses: EnricoMi/publish-unit-test-result-action@v2
        if: always()
        with:
          files: test-results.xml

  deploy:
    needs: test
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/main' && github.event_name == 'push'

    steps:
      - uses: actions/checkout@v3

      - name: Deploy to staging
        run: |
          echo "部署到测试环境"
          # Add the deployment script here.
10.8.2 测试配置管理
# tests/conftest.py (完整版)
import pytest
import tempfile
import os
from app import create_app, db
from app.models import User, Role, Article
from flask import current_app
import redis
from unittest.mock import Mock
@pytest.fixture(scope='session')
def app():
    """Session-wide Flask app backed by a temporary SQLite database file.

    The schema is created and seeded with ``create_test_data`` once, the app is
    yielded with an application context pushed, and the tables are dropped at
    the end of the session. The temporary file is closed and deleted in a
    ``finally`` so it is cleaned up even if app creation or setup fails.
    """
    db_fd, db_path = tempfile.mkstemp()
    try:
        test_config = {
            'TESTING': True,
            'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
            'SQLALCHEMY_TRACK_MODIFICATIONS': False,
            'WTF_CSRF_ENABLED': False,
            'SECRET_KEY': 'test-secret-key',
            'CACHE_TYPE': 'simple',
            'REDIS_URL': 'redis://localhost:6379/15',
            'MAIL_SUPPRESS_SEND': True,
            'CELERY_ALWAYS_EAGER': True,
            'CELERY_EAGER_PROPAGATES_EXCEPTIONS': True
        }
        app = create_app(test_config)
        with app.app_context():
            db.create_all()
            create_test_data()
            yield app
            # Release the session and pooled connections so the SQLite file is
            # no longer held open when it is deleted below.
            db.session.remove()
            db.drop_all()
            db.engine.dispose()
    finally:
        os.close(db_fd)
        os.unlink(db_path)
@pytest.fixture
def client(app):
    """Flask test client bound to the session-wide ``app`` fixture."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def runner(app):
    """Click CLI runner for invoking the ``app`` fixture's commands."""
    cli_runner = app.test_cli_runner()
    return cli_runner
@pytest.fixture
def mock_redis():
    """Return a Mock standing in for a Redis client.

    Canned replies: ``get`` -> None, ``set`` -> True, ``delete`` -> True,
    ``exists`` -> False.
    """
    canned_replies = {'get': None, 'set': True, 'delete': True, 'exists': False}
    fake_client = Mock()
    for command, reply in canned_replies.items():
        getattr(fake_client, command).return_value = reply
    return fake_client
@pytest.fixture
def mock_celery():
    """Return a Mock task whose ``delay()`` yields a result with id 'test-task-id'."""
    queued_result = Mock(id='test-task-id')
    return Mock(**{'delay.return_value': queued_result})
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db_session):
    """Autouse shim that makes every test request the ``db_session`` fixture.

    The body does nothing; depending on ``db_session`` is the whole point.
    """
    return None
@pytest.fixture(scope='function')
def db_session(app):
    """Run one test inside an outer transaction that is rolled back afterwards.

    The global ``db.session`` is swapped for a scoped session bound to that
    transaction while the test runs, then restored, so later tests and the
    session-level teardown see the original session again.

    NOTE(review): ``db.create_scoped_session`` is the Flask-SQLAlchemy 2.x API
    (removed in 3.x) — confirm the installed version.
    """
    with app.app_context():
        connection = db.engine.connect()
        transaction = connection.begin()
        # Bind the session to the connection so all work joins this transaction.
        session = db.create_scoped_session(
            options={'bind': connection, 'binds': {}}
        )
        original_session = db.session
        db.session = session
        try:
            yield session
        finally:
            # Close the session before rolling back and releasing the connection.
            session.remove()
            if transaction.is_active:
                transaction.rollback()
            connection.close()
            db.session = original_session
def create_test_data():
    """Seed the two roles ('admin', 'user') and one user for each role.

    Roles are committed first so their primary keys are available for the
    users' ``role_id``.
    """
    roles = {
        'admin': Role(name='admin', description='管理员'),
        'user': Role(name='user', description='普通用户'),
    }
    db.session.add_all(roles.values())
    db.session.commit()

    seed_accounts = [
        ('admin', 'admin@example.com', 'admin', 'admin123'),
        ('testuser', 'test@example.com', 'user', 'password123'),
    ]
    accounts = []
    for username, email, role_key, password in seed_accounts:
        account = User(
            username=username,
            email=email,
            role_id=roles[role_key].id
        )
        account.set_password(password)
        accounts.append(account)
    db.session.add_all(accounts)
    db.session.commit()
# pytest configuration hooks
def pytest_configure(config):
    """Register this suite's custom markers (slow, integration, e2e).

    Registered markers are accepted under ``--strict-markers`` and listed by
    ``pytest --markers``. The inner double quotes in the 'slow' help text are
    escaped; unescaped they terminated the string literal and made this module
    a SyntaxError.
    """
    config.addinivalue_line(
        "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
    )
    config.addinivalue_line(
        "markers", "integration: marks tests as integration tests"
    )
    config.addinivalue_line(
        "markers", "e2e: marks tests as end-to-end tests"
    )
def pytest_collection_modifyitems(config, items):
    """Apply the ``slow`` marker to every collected test whose node id contains "slow"."""
    marker = pytest.mark.slow
    slow_items = [candidate for candidate in items if "slow" in candidate.nodeid]
    for candidate in slow_items:
        candidate.add_marker(marker)
本章小结
技术要点总结
测试基础
- 测试环境配置和fixture管理
- 测试数据创建和清理
- 测试工具和库的使用
单元测试
- 模型测试:数据验证、关系测试
- 视图测试:路由、权限、响应验证
- 工具函数测试:边界条件、异常处理
集成测试
- 数据库集成:事务、约束、性能
- API集成:完整流程、错误处理
- 组件交互测试
功能测试
- 端到端测试:用户场景模拟
- 表单验证测试
- 浏览器自动化测试
性能测试
- 负载测试:并发、压力测试
- 基准测试:数据库操作性能
- 内存和响应时间监控
调试技巧
- 调试模式配置
- 日志记录和分析
- 错误处理和监控
测试覆盖率
- 覆盖率配置和报告
- 关键模块覆盖率要求
- 覆盖率分析和改进
持续集成
- CI/CD流水线配置
- 自动化测试执行
- 代码质量检查
测试最佳实践
测试策略
- 测试金字塔:以大量单元测试为基础,辅以适量的集成测试和少量的端到端测试
- 测试驱动开发(TDD)
- 行为驱动开发(BDD)
测试设计
- 测试独立性和可重复性
- 边界条件和异常情况测试
- 正向和负向测试用例
测试维护
- 定期更新测试用例
- 重构测试代码
- 测试文档维护
性能考虑
- 测试执行速度优化
- 并行测试执行
- 测试数据管理
调试和监控
开发环境调试
- 断点调试
- 日志分析
- 性能分析
生产环境监控
- 错误监控和报警
- 性能指标监控
- 用户行为分析
问题诊断
- 日志聚合和分析
- 分布式追踪
- 根因分析
下一章预告
第11章将介绍部署与运维,包括:
- 部署环境配置
- Docker容器化
- 云平台部署
- 监控和日志管理
- 性能优化
- 安全加固
- 备份和恢复
- 运维自动化
练习题
基础练习
- 为用户注册功能编写完整的测试用例
- 实现一个简单的性能测试脚本
- 配置测试覆盖率检查
进阶练习
- 设计并实现API的集成测试
- 编写端到端测试脚本
- 实现自定义的调试中间件
项目练习
- 为整个Flask应用建立完整的测试体系
- 配置CI/CD流水线
- 实现错误监控和报警系统
思考题
- 如何平衡测试覆盖率和开发效率?
- 在微服务架构中如何进行集成测试?
- 如何设计有效的性能测试策略?
10.2.3 工具函数测试
# tests/test_utils.py
import pytest
from datetime import datetime, timedelta
from app.utils import (
generate_token, verify_token,
send_email, format_datetime,
slugify, validate_email,
calculate_reading_time
)
from unittest.mock import patch, MagicMock
class TestTokenUtils:
    """Tests for ``generate_token`` / ``verify_token``."""

    def test_generate_and_verify_token(self, app):
        """A fresh token is a non-empty string that decodes back to the user id."""
        expected_id = 123
        with app.app_context():
            issued = generate_token(expected_id)
            assert issued is not None
            assert isinstance(issued, str)
            assert verify_token(issued) == expected_id

    def test_verify_invalid_token(self, app):
        """A malformed token is rejected with None."""
        with app.app_context():
            assert verify_token('invalid.token.here') is None

    def test_verify_expired_token(self, app):
        """A token issued with a negative lifetime is already expired and rejected."""
        with app.app_context():
            already_expired = generate_token(123, expires_in=-1)
            assert verify_token(already_expired) is None
class TestEmailUtils:
    """Tests for ``send_email`` with ``app.utils.mail.send`` patched out."""

    @patch('app.utils.mail.send')
    def test_send_email_success(self, mock_send, app):
        """A successful send returns True and calls the mail backend exactly once."""
        message = {
            'to': 'test@example.com',
            'subject': '测试邮件',
            'template': 'email/test.html',
            'user': '测试用户',
        }
        with app.app_context():
            outcome = send_email(**message)
        assert outcome is True
        mock_send.assert_called_once()

    @patch('app.utils.mail.send')
    def test_send_email_failure(self, mock_send, app):
        """When the backend raises, send_email is expected to return False."""
        mock_send.side_effect = Exception('SMTP错误')
        message = {
            'to': 'test@example.com',
            'subject': '测试邮件',
            'template': 'email/test.html',
        }
        with app.app_context():
            outcome = send_email(**message)
        assert outcome is False
class TestStringUtils:
    """Tests for ``slugify`` and ``validate_email``.

    E-mail cases are parametrized so each address is reported as its own test;
    the previous in-loop asserts stopped at the first failure and did not say
    which input failed.
    """

    def test_slugify_english(self):
        """English titles are lower-cased and hyphen-joined."""
        assert slugify('This is a Test Title') == 'this-is-a-test-title'

    def test_slugify_chinese(self):
        """Chinese titles still produce a non-empty slug."""
        slug = slugify('这是一个测试标题')
        assert slug is not None
        assert len(slug) > 0

    def test_slugify_special_characters(self):
        """Punctuation and symbols are stripped from the slug."""
        slug = slugify('Title with @#$% special chars!')
        for char in '@#$%!':
            assert char not in slug, f"{char!r} should be stripped from {slug!r}"

    @pytest.mark.parametrize('email', [
        'test@example.com',
        'user.name@domain.co.uk',
        'user+tag@example.org',
    ])
    def test_validate_email_valid(self, email):
        """Well-formed addresses are accepted."""
        assert validate_email(email) is True

    @pytest.mark.parametrize('email', [
        'invalid-email',
        '@example.com',
        'user@',
        'user..name@example.com',
    ])
    def test_validate_email_invalid(self, email):
        """Malformed addresses are rejected."""
        assert validate_email(email) is False
class TestDateTimeUtils:
    """Tests for ``format_datetime``."""

    # datetime is immutable, so one shared sample is safe across tests.
    SAMPLE_DT = datetime(2023, 12, 25, 15, 30, 45)

    def test_format_datetime_default(self):
        """The default format contains the year, month and day."""
        rendered = format_datetime(self.SAMPLE_DT)
        for fragment in ('2023', '12', '25'):
            assert fragment in rendered

    def test_format_datetime_custom_format(self):
        """An explicit strftime pattern is honoured."""
        assert format_datetime(self.SAMPLE_DT, '%Y-%m-%d') == '2023-12-25'

    def test_format_datetime_none(self):
        """None renders as an empty string."""
        assert format_datetime(None) == ''
class TestContentUtils:
    """Tests for ``calculate_reading_time`` (result in whole minutes)."""

    def test_calculate_reading_time_short(self):
        """Very short text still reports the 1-minute floor."""
        assert calculate_reading_time('这是一个简短的测试文本。') == 1

    def test_calculate_reading_time_long(self):
        """About 500 words should read in more than 1 and at most 3 minutes."""
        text = ' '.join('word' for _ in range(500))
        minutes = calculate_reading_time(text)
        assert 1 < minutes <= 3

    def test_calculate_reading_time_empty(self):
        """Empty text reports the 1-minute floor."""
        assert calculate_reading_time('') == 1

    def test_calculate_reading_time_chinese(self):
        """Long Chinese text yields at least one minute."""
        text = '这是一段中文测试文本。' * 100
        assert calculate_reading_time(text) >= 1
10.3 集成测试
10.3.1 数据库集成测试
# tests/test_database_integration.py
import pytest
from app import db
from app.models import User, Article, Comment, Role
from sqlalchemy.exc import IntegrityError
class TestDatabaseIntegration:
    """Database integration tests: relationships, constraints, transactions.

    Usernames, e-mails and role names avoid the rows seeded by
    ``create_test_data`` in conftest:
      - roles 'admin' and 'user';
      - user 'admin' / admin@example.com;
      - user 'testuser' / test@example.com.
    Reusing them trips unique constraints such as the e-mail uniqueness checked
    in ``test_database_constraints``.
    """

    def test_user_article_relationship(self, app):
        """An author's articles are reachable from both sides of the relationship."""
        with app.app_context():
            author = User(username='author', email='author@example.com')
            author.set_password('password')
            first = Article(title='文章1', content='内容1', author=author)
            second = Article(title='文章2', content='内容2', author=author)
            db.session.add_all([author, first, second])
            db.session.commit()

            assert author.articles.count() == 2
            assert first.author == author
            assert second.author == author

            author_articles = Article.query.filter_by(author_id=author.id).all()
            assert len(author_articles) == 2

    def test_article_comment_relationship(self, app):
        """Comments belong to their article and are cascade-deleted with it."""
        with app.app_context():
            commenter = User(username='commenter', email='commenter@example.com')
            article = Article(title='文章', content='内容', author=commenter)
            comments = [
                Comment(content='评论1', author=commenter, article=article),
                Comment(content='评论2', author=commenter, article=article),
            ]
            db.session.add_all([commenter, article, *comments])
            db.session.commit()

            assert article.comments.count() == 2
            for comment in comments:
                assert comment.article == article

            comment_ids = [comment.id for comment in comments]
            db.session.delete(article)
            db.session.commit()

            # Check only this article's comments, so unrelated rows cannot break the test.
            assert Comment.query.filter(Comment.id.in_(comment_ids)).count() == 0

    def test_role_permission_system(self, app):
        """Permissions granted to a role are visible through the user's role."""
        with app.app_context():
            admin_role = Role(name='perm_admin', description='管理员')
            user_role = Role(name='perm_user', description='普通用户')

            admin_role.add_permission('read')
            admin_role.add_permission('write')
            admin_role.add_permission('delete')
            user_role.add_permission('read')

            admin = User(username='perm_admin', email='perm_admin@example.com', role=admin_role)
            user = User(username='perm_user', email='perm_user@example.com', role=user_role)
            db.session.add_all([admin_role, user_role, admin, user])
            db.session.commit()

            assert admin.role.has_permission('delete')
            assert not user.role.has_permission('delete')
            assert user.role.has_permission('read')

    def test_database_constraints(self, app):
        """A duplicate e-mail violates the unique constraint on commit."""
        with app.app_context():
            user1 = User(username='user1', email='same@example.com')
            user2 = User(username='user2', email='same@example.com')
            db.session.add(user1)
            db.session.commit()

            db.session.add(user2)
            with pytest.raises(IntegrityError):
                db.session.commit()
            db.session.rollback()

    def test_database_transactions(self, app):
        """Work that fails before commit is discarded by rollback."""
        with app.app_context():
            user = User(username='transuser', email='trans@example.com')

            def failing_unit_of_work():
                db.session.add(user)
                raise RuntimeError('模拟错误')  # fails before commit is reached

            try:
                failing_unit_of_work()
                db.session.commit()
            except RuntimeError:
                db.session.rollback()

            assert User.query.filter_by(username='transuser').first() is None

    def test_query_performance(self, app):
        """Loading all users after a 100-row bulk insert stays under one second."""
        import time

        with app.app_context():
            # Seeded users already exist, so count relative to a baseline.
            baseline = User.query.count()
            db.session.add_all([
                User(username=f'perfuser{i}', email=f'perfuser{i}@example.com')
                for i in range(100)
            ])
            db.session.commit()

            start = time.perf_counter()  # monotonic, suited to measuring durations
            all_users = User.query.all()
            elapsed = time.perf_counter() - start

            assert len(all_users) == baseline + 100
            assert elapsed < 1.0
10.3.2 API集成测试
# tests/test_api_integration.py
import pytest
import json
from app.models import User, Article
from app import db
class TestAPIIntegration:
    """API integration flows driven through the Flask test client.

    Registration and login live in ``_register_and_login`` so other tests can
    obtain a token without calling a test method. Test functions that return
    values trigger pytest warnings, and chaining tests couples them.
    """

    def _register_and_login(self, client):
        """Register ``newuser`` via the API, log in, and return the access token."""
        register_data = {
            'username': 'newuser',
            'email': 'newuser@example.com',
            'password': 'password123'
        }
        response = client.post('/api/auth/register', json=register_data)
        assert response.status_code == 201
        data = response.get_json()
        assert 'user' in data
        assert data['user']['username'] == 'newuser'

        login_data = {
            'email': 'newuser@example.com',
            'password': 'password123'
        }
        response = client.post('/api/auth/login', json=login_data)
        assert response.status_code == 200
        data = response.get_json()
        assert 'access_token' in data
        assert 'refresh_token' in data
        return data['access_token']

    def _auth_headers(self, client):
        """Return a Bearer Authorization header for a freshly registered user."""
        return {'Authorization': f'Bearer {self._register_and_login(client)}'}

    def test_user_registration_and_login_flow(self, client):
        """A new user can register and then log in to receive tokens."""
        self._register_and_login(client)

    def test_article_crud_flow(self, client):
        """Create, read, update and delete an article through the API."""
        headers = self._auth_headers(client)

        article_data = {
            'title': '集成测试文章',
            'content': '这是一篇集成测试文章的内容'
        }
        response = client.post('/api/articles', json=article_data, headers=headers)
        assert response.status_code == 201
        data = response.get_json()
        article_id = data['id']
        assert data['title'] == '集成测试文章'

        response = client.get(f'/api/articles/{article_id}')
        assert response.status_code == 200
        assert response.get_json()['title'] == '集成测试文章'

        update_data = {
            'title': '更新后的文章标题',
            'content': '更新后的文章内容'
        }
        response = client.put(f'/api/articles/{article_id}', json=update_data, headers=headers)
        assert response.status_code == 200
        assert response.get_json()['title'] == '更新后的文章标题'

        response = client.delete(f'/api/articles/{article_id}', headers=headers)
        assert response.status_code == 204

        response = client.get(f'/api/articles/{article_id}')
        assert response.status_code == 404

    def test_comment_system_flow(self, client):
        """Add, list and delete a comment on an article."""
        headers = self._auth_headers(client)

        article_data = {
            'title': '评论测试文章',
            'content': '文章内容'
        }
        response = client.post('/api/articles', json=article_data, headers=headers)
        article_id = response.get_json()['id']

        comment_data = {
            'content': '这是一条测试评论'
        }
        response = client.post(
            f'/api/articles/{article_id}/comments',
            json=comment_data,
            headers=headers
        )
        assert response.status_code == 201
        data = response.get_json()
        comment_id = data['id']
        assert data['content'] == '这是一条测试评论'

        response = client.get(f'/api/articles/{article_id}/comments')
        assert response.status_code == 200
        data = response.get_json()
        assert len(data['comments']) == 1
        assert data['comments'][0]['content'] == '这是一条测试评论'

        response = client.delete(f'/api/comments/{comment_id}', headers=headers)
        assert response.status_code == 204

        response = client.get(f'/api/articles/{article_id}/comments')
        assert len(response.get_json()['comments']) == 0

    def test_pagination_flow(self, client, app):
        """25 articles paginate into pages of 10, 10 and 5."""
        with app.app_context():
            # Distinct from the seeded 'testuser' / test@example.com account to
            # avoid unique-constraint violations.
            author = User(username='pageuser', email='pageuser@example.com')
            author.set_password('password')
            db.session.add(author)
            db.session.add_all([
                Article(title=f'文章 {n}', content=f'文章 {n} 的内容', author=author)
                for n in range(1, 26)
            ])
            db.session.commit()

            response = client.get('/api/articles?page=1&per_page=10')
            assert response.status_code == 200
            data = response.get_json()
            assert len(data['articles']) == 10
            assert data['pagination']['page'] == 1
            assert data['pagination']['pages'] == 3  # 25 articles, 10 per page
            assert data['pagination']['total'] == 25

            response = client.get('/api/articles?page=2&per_page=10')
            assert response.status_code == 200
            data = response.get_json()
            assert len(data['articles']) == 10
            assert data['pagination']['page'] == 2

            response = client.get('/api/articles?page=3&per_page=10')
            assert response.status_code == 200
            data = response.get_json()
            assert len(data['articles']) == 5  # last page holds the remaining 5
            assert data['pagination']['page'] == 3

    def test_error_handling_flow(self, client):
        """404 for unknown ids, 401 without auth, 400 for invalid payloads."""
        response = client.get('/api/articles/999999')
        assert response.status_code == 404
        assert 'error' in response.get_json()

        response = client.post('/api/articles', json={
            'title': '测试文章',
            'content': '内容'
        })
        assert response.status_code == 401

        headers = self._auth_headers(client)
        response = client.post('/api/articles', json={
            'title': '',  # empty title
            'content': '内容'
        }, headers=headers)
        assert response.status_code == 400
        assert 'error' in response.get_json()