本章将通过一个完整的项目实战,综合运用前面学到的所有知识。我们将开发一个个人财务管理系统,包含数据管理、统计分析、图表展示等功能。

11.1 项目规划和设计

项目需求分析

# project_requirements.py
"""
个人财务管理系统需求分析

功能需求:
1. 收支记录管理(增删改查)
2. 分类管理(收入/支出分类)
3. 统计分析(月度/年度统计)
4. 数据可视化(图表展示)
5. 数据导入导出(CSV/JSON)
6. 预算管理和提醒
7. 报表生成

技术需求:
1. 面向对象设计
2. 数据持久化(SQLite)
3. 用户界面(命令行/Web)
4. 数据验证和异常处理
5. 单元测试
6. 日志记录
7. 配置管理
"""

class ProjectRequirements:
    """Catalogue of the finance manager's requirements, grouped by category.

    Both public attributes map a category name to an ordered list of
    requirement strings; dict insertion order is the order they are printed in.
    """

    def __init__(self):
        self.functional_requirements = {
            "核心功能": ["收支记录的增删改查", "收支分类管理", "数据统计和分析", "数据可视化展示"],
            "扩展功能": ["预算管理", "定期提醒", "数据导入导出", "报表生成"],
            "用户体验": ["简洁的用户界面", "快速的数据检索", "直观的数据展示", "友好的错误提示"],
        }

        self.technical_requirements = {
            "架构设计": ["模块化设计", "面向对象编程", "设计模式应用", "可扩展架构"],
            "数据管理": ["SQLite数据库", "数据模型设计", "数据验证", "事务处理"],
            "质量保证": ["单元测试", "异常处理", "日志记录", "代码规范"],
        }

    def display_requirements(self):
        """Print every requirement group to stdout as an indented bullet list."""
        print("=== 个人财务管理系统需求 ===")

        # Both sections share one layout: heading, then "group:" + bullets.
        sections = (
            ("\n功能需求:", self.functional_requirements),
            ("\n技术需求:", self.technical_requirements),
        )
        for heading, groups in sections:
            print(heading)
            for group, entries in groups.items():
                print(f"\n{group}:")
                for entry in entries:
                    print(f"  - {entry}")

if __name__ == "__main__":
    # Script entry point: dump the requirement checklist to stdout.
    ProjectRequirements().display_requirements()

项目结构设计

# project_structure.py
"""
项目结构设计

finance_manager/
├── src/
│   ├── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── transaction.py      # 交易模型
│   │   ├── category.py         # 分类模型
│   │   └── budget.py           # 预算模型
│   ├── database/
│   │   ├── __init__.py
│   │   ├── connection.py       # 数据库连接
│   │   ├── migrations.py       # 数据库迁移
│   │   └── repository.py       # 数据访问层
│   ├── services/
│   │   ├── __init__.py
│   │   ├── transaction_service.py  # 交易服务
│   │   ├── statistics_service.py   # 统计服务
│   │   └── export_service.py       # 导出服务
│   ├── ui/
│   │   ├── __init__.py
│   │   ├── cli.py              # 命令行界面
│   │   └── web.py              # Web界面
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── validators.py       # 数据验证
│   │   ├── formatters.py       # 格式化工具
│   │   └── logger.py           # 日志工具
│   └── config/
│       ├── __init__.py
│       └── settings.py         # 配置文件
├── tests/
│   ├── __init__.py
│   ├── test_models.py
│   ├── test_services.py
│   └── test_utils.py
├── data/
│   └── finance.db              # SQLite数据库
├── docs/
│   ├── README.md
│   ├── API.md
│   └── CHANGELOG.md
├── requirements.txt
├── setup.py
└── main.py                     # 程序入口
"""

from pathlib import Path

def create_project_structure(root="finance_manager"):
    """Create the finance_manager project skeleton (directories + empty files).

    The operation is idempotent: existing directories are reused and existing
    files are never overwritten; only newly created files are reported.

    Args:
        root: Directory (str or Path) to build the project in. Defaults to
            ``"finance_manager"`` relative to the current working directory.
            Missing parent directories are created as well.

    Returns:
        Path: the project root directory.
    """
    # Directory name -> either a list of file names, or a nested mapping of
    # sub-directory name -> contents (nesting of any depth is supported).
    project_structure = {
        "src": {
            "models": ["__init__.py", "transaction.py", "category.py", "budget.py"],
            "database": ["__init__.py", "connection.py", "migrations.py", "repository.py"],
            "services": ["__init__.py", "transaction_service.py", "statistics_service.py", "export_service.py"],
            "ui": ["__init__.py", "cli.py", "web.py"],
            "utils": ["__init__.py", "validators.py", "formatters.py", "logger.py"],
            "config": ["__init__.py", "settings.py"]
        },
        "tests": ["__init__.py", "test_models.py", "test_services.py", "test_utils.py"],
        "data": [],
        "docs": ["README.md", "API.md", "CHANGELOG.md"]
    }

    # Files placed directly in the project root.
    root_files = ["requirements.txt", "setup.py", "main.py", "__init__.py"]

    project_root = Path(root)
    project_root.mkdir(parents=True, exist_ok=True)
    print(f"创建项目目录: {project_root}")

    _create_tree(project_root, project_structure, depth=0)
    _touch_missing_files(project_root, root_files, indent="  ")

    return project_root


def _create_tree(parent, structure, depth):
    """Recursively create *structure* under *parent*; depth 0 is the top level."""
    indent = "  " * (depth + 1)
    # Top-level entries are reported as directories, deeper ones as sub-directories.
    label = "创建目录" if depth == 0 else "创建子目录"
    for name, content in structure.items():
        dir_path = parent / name
        dir_path.mkdir(exist_ok=True)
        print(f"{indent}{label}: {dir_path}")

        if isinstance(content, dict):
            _create_tree(dir_path, content, depth + 1)
        else:
            _touch_missing_files(dir_path, content, indent="  " * (depth + 2))


def _touch_missing_files(directory, file_names, indent):
    """Create each file in *file_names* under *directory* unless it already exists."""
    for file_name in file_names:
        file_path = directory / file_name
        if not file_path.exists():
            file_path.touch()
            print(f"{indent}创建文件: {file_path}")

if __name__ == "__main__":
    # Build the skeleton in the working directory and report where it went.
    created_root = create_project_structure()
    print(f"\n项目结构创建完成: {created_root.absolute()}")

运行项目规划:

python project_requirements.py
python project_structure.py

11.2 核心模块实现

数据模型设计

# models/transaction.py
"""
交易模型
"""

from datetime import datetime
from decimal import Decimal
from enum import Enum
from dataclasses import dataclass
from typing import Optional

class TransactionType(Enum):
    """Kind of a money movement; the value is what gets persisted."""
    INCOME = "income"      # income
    EXPENSE = "expense"    # expense
    TRANSFER = "transfer"  # transfer


@dataclass
class Transaction:
    """A single financial transaction (income, expense or transfer).

    ``date``, ``created_at`` and ``updated_at`` default to the moment the
    object is constructed, so a freshly created object has all three equal.
    """

    id: Optional[int] = None
    amount: Decimal = Decimal('0.00')
    transaction_type: TransactionType = TransactionType.EXPENSE
    category_id: Optional[int] = None
    description: str = ""
    date: Optional[datetime] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def __post_init__(self):
        """Fill any unset timestamp from a single clock reading."""
        # One reading keeps created_at == updated_at on a new object; separate
        # now() calls could differ by a few microseconds.
        now = datetime.now()
        if self.date is None:
            self.date = now

        if self.created_at is None:
            self.created_at = now

        if self.updated_at is None:
            self.updated_at = now

    def validate(self) -> bool:
        """Check the transaction's business rules.

        Returns:
            True when every rule passes.

        Raises:
            ValueError: listing all failed rules (non-positive amount,
                blank/missing description, date in the future).
        """
        errors = []

        if self.amount <= 0:
            errors.append("金额必须大于0")

        # A None description (e.g. an explicit null passed to from_dict) is
        # reported as blank instead of crashing with AttributeError.
        if not (self.description or "").strip():
            errors.append("描述不能为空")

        if self.date > datetime.now():
            errors.append("日期不能是未来时间")

        if errors:
            raise ValueError("数据验证失败: " + ", ".join(errors))

        return True

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (amount as float, datetimes as ISO strings)."""
        return {
            'id': self.id,
            'amount': float(self.amount),
            'transaction_type': self.transaction_type.value,
            'category_id': self.category_id,
            'description': self.description,
            'date': self.date.isoformat() if self.date else None,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Transaction':
        """Build a Transaction from a dict shaped like ``to_dict()`` output.

        Missing keys fall back to the dataclass defaults; missing timestamps
        keep the construction-time values set by ``__post_init__``.
        """
        transaction = cls()

        transaction.id = data.get('id')
        # Going through str() avoids binary-float artefacts in the Decimal.
        transaction.amount = Decimal(str(data.get('amount', 0)))
        transaction.transaction_type = TransactionType(data.get('transaction_type', 'expense'))
        transaction.category_id = data.get('category_id')
        transaction.description = data.get('description', '')

        if data.get('date'):
            transaction.date = datetime.fromisoformat(data['date'])

        if data.get('created_at'):
            transaction.created_at = datetime.fromisoformat(data['created_at'])

        if data.get('updated_at'):
            transaction.updated_at = datetime.fromisoformat(data['updated_at'])

        return transaction

    def __str__(self) -> str:
        """Return e.g. ``2024-01-31 -¥12.50 lunch``."""
        # Only INCOME is shown with "+"; expenses and transfers get "-".
        type_symbol = "+" if self.transaction_type == TransactionType.INCOME else "-"
        return f"{self.date.strftime('%Y-%m-%d')} {type_symbol}¥{self.amount} {self.description}"

# models/category.py
"""
分类模型
"""

from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List

from .transaction import TransactionType

@dataclass
class Category:
    """An income/expense category, optionally nested under a parent category."""

    id: Optional[int] = None
    name: str = ""
    transaction_type: TransactionType = TransactionType.EXPENSE
    parent_id: Optional[int] = None  # parent category id (enables a hierarchy)
    color: str = "#007bff"  # display colour, "#RRGGBB"
    icon: str = "💰"  # display icon
    description: str = ""
    is_active: bool = True
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def __post_init__(self):
        """Fill unset timestamps from a single clock reading."""
        # One reading keeps created_at == updated_at on a new object.
        now = datetime.now()
        if self.created_at is None:
            self.created_at = now

        if self.updated_at is None:
            self.updated_at = now

    def validate(self) -> bool:
        """Check the category's fields.

        Returns:
            True when every rule passes.

        Raises:
            ValueError: listing all failed rules (blank name, name longer
                than 50 characters, colour not of the form ``#RRGGBB``).
        """
        errors = []

        # Treat None like an empty string instead of crashing.
        name = self.name or ""
        if not name.strip():
            errors.append("分类名称不能为空")

        if len(name) > 50:
            errors.append("分类名称不能超过50个字符")

        # "#" followed by exactly six hex digits; previously any 7-char string
        # starting with "#" (e.g. "#zzzzzz") slipped through.
        color = self.color or ""
        is_hex_color = (
            len(color) == 7
            and color.startswith('#')
            and all(ch in "0123456789abcdefABCDEF" for ch in color[1:])
        )
        if not is_hex_color:
            errors.append("颜色格式必须是#RRGGBB")

        if errors:
            raise ValueError("数据验证失败: " + ", ".join(errors))

        return True

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (datetimes as ISO strings)."""
        return {
            'id': self.id,
            'name': self.name,
            'transaction_type': self.transaction_type.value,
            'parent_id': self.parent_id,
            'color': self.color,
            'icon': self.icon,
            'description': self.description,
            'is_active': self.is_active,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Category':
        """Build a Category from a dict shaped like ``to_dict()`` output."""
        category = cls()

        category.id = data.get('id')
        category.name = data.get('name', '')
        category.transaction_type = TransactionType(data.get('transaction_type', 'expense'))
        category.parent_id = data.get('parent_id')
        category.color = data.get('color', '#007bff')
        category.icon = data.get('icon', '💰')
        category.description = data.get('description', '')
        # Normalise 0/1 style flags (e.g. from SQLite) to a real bool.
        category.is_active = bool(data.get('is_active', True))

        if data.get('created_at'):
            category.created_at = datetime.fromisoformat(data['created_at'])

        if data.get('updated_at'):
            category.updated_at = datetime.fromisoformat(data['updated_at'])

        return category

    def __str__(self) -> str:
        """Return ``"<icon> <name>"``."""
        return f"{self.icon} {self.name}"

# models/budget.py
"""
预算模型
"""

from dataclasses import dataclass
from decimal import Decimal
from datetime import datetime, date
from typing import Optional
from enum import Enum

class BudgetPeriod(Enum):
    """Length of the window a budget covers."""
    MONTHLY = "monthly"      # monthly
    QUARTERLY = "quarterly"  # quarterly
    YEARLY = "yearly"        # yearly


@dataclass
class Budget:
    """A spending limit for one category over a date window.

    When ``start_date`` is omitted it defaults to the first day of the current
    month. When ``end_date`` is omitted it becomes the last day of the month
    that lies 0 (monthly), 2 (quarterly) or 11 (yearly) months after the start
    month.
    """

    id: Optional[int] = None
    category_id: Optional[int] = None
    amount: Decimal = Decimal('0.00')
    period: BudgetPeriod = BudgetPeriod.MONTHLY
    start_date: Optional[date] = None
    end_date: Optional[date] = None
    alert_threshold: float = 0.8  # warn once this fraction of the budget is used
    is_active: bool = True
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def __post_init__(self):
        """Fill default timestamps and derive missing period dates."""
        # Accept the raw value ("monthly", ...) as well as the enum member.
        if not isinstance(self.period, BudgetPeriod):
            self.period = BudgetPeriod(self.period)

        now = datetime.now()
        if self.created_at is None:
            self.created_at = now

        if self.updated_at is None:
            self.updated_at = now

        if self.start_date is None:
            self.start_date = date.today().replace(day=1)

        if self.end_date is None:
            self.end_date = self._period_end(self.start_date, self.period)

    @staticmethod
    def _period_end(start: date, period: BudgetPeriod) -> date:
        """Return the last day of the final calendar month of *period* starting at *start*."""
        months = {
            BudgetPeriod.MONTHLY: 1,
            BudgetPeriod.QUARTERLY: 3,
            BudgetPeriod.YEARLY: 12,
        }[period]
        # Zero-based absolute month index of the first month *after* the period.
        month_index = start.year * 12 + (start.month - 1) + months
        first_after = date(month_index // 12, month_index % 12 + 1, 1)
        # The day before the first day of that month closes the period.
        return date.fromordinal(first_after.toordinal() - 1)

    def validate(self) -> bool:
        """Check the budget's fields.

        Returns:
            True when every rule passes.

        Raises:
            ValueError: listing all failed rules (non-positive amount,
                start not before end, threshold outside (0, 1]).
        """
        errors = []

        if self.amount <= 0:
            errors.append("预算金额必须大于0")

        if self.start_date and self.end_date and self.start_date >= self.end_date:
            errors.append("开始日期必须早于结束日期")

        if not 0 < self.alert_threshold <= 1:
            errors.append("预警阈值必须在0-1之间")

        if errors:
            raise ValueError("数据验证失败: " + ", ".join(errors))

        return True

    def is_current_period(self) -> bool:
        """Return True if today falls within [start_date, end_date]."""
        today = date.today()
        return self.start_date <= today <= self.end_date

    def get_usage_percentage(self, spent_amount: Decimal) -> float:
        """Return spent/amount as a fraction (1.0 == 100%); 0.0 for a zero budget."""
        if self.amount == 0:
            return 0.0
        return float(spent_amount / self.amount)

    def is_over_threshold(self, spent_amount: Decimal) -> bool:
        """Return True once usage reaches ``alert_threshold``."""
        return self.get_usage_percentage(spent_amount) >= self.alert_threshold

    def is_over_budget(self, spent_amount: Decimal) -> bool:
        """Return True if *spent_amount* strictly exceeds the budget."""
        return spent_amount > self.amount

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (dates as ISO strings)."""
        return {
            'id': self.id,
            'category_id': self.category_id,
            'amount': float(self.amount),
            'period': self.period.value,
            'start_date': self.start_date.isoformat() if self.start_date else None,
            'end_date': self.end_date.isoformat() if self.end_date else None,
            'alert_threshold': self.alert_threshold,
            'is_active': self.is_active,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Budget':
        """Build a Budget from a dict shaped like ``to_dict()`` output.

        Everything is passed to the constructor so that a missing ``end_date``
        is derived from the *stored* period rather than the MONTHLY default.
        """
        def parse(key, parser):
            raw = data.get(key)
            return parser(raw) if raw else None

        return cls(
            id=data.get('id'),
            category_id=data.get('category_id'),
            amount=Decimal(str(data.get('amount', 0))),
            period=BudgetPeriod(data.get('period', 'monthly')),
            start_date=parse('start_date', date.fromisoformat),
            end_date=parse('end_date', date.fromisoformat),
            alert_threshold=data.get('alert_threshold', 0.8),
            is_active=data.get('is_active', True),
            created_at=parse('created_at', datetime.fromisoformat),
            updated_at=parse('updated_at', datetime.fromisoformat),
        )

    def __str__(self) -> str:
        """Return a one-line summary with amount, period and date window."""
        return f"预算: ¥{self.amount} ({self.period.value}) {self.start_date} - {self.end_date}"

# Model self-test: construct one object of each model and print its checks.
if __name__ == "__main__":
    # --- Transaction ---
    print("=== 测试交易模型 ===")

    lunch = Transaction(
        amount=Decimal('100.50'),
        transaction_type=TransactionType.EXPENSE,
        description="午餐",
        category_id=1,
    )
    print(f"交易对象: {lunch}")
    print(f"验证结果: {lunch.validate()}")
    print(f"字典格式: {lunch.to_dict()}")

    # --- Category ---
    print("\n=== 测试分类模型 ===")

    dining = Category(
        name="餐饮",
        transaction_type=TransactionType.EXPENSE,
        icon="🍽️",
        color="#ff6b6b",
    )
    print(f"分类对象: {dining}")
    print(f"验证结果: {dining.validate()}")

    # --- Budget ---
    print("\n=== 测试预算模型 ===")

    monthly_budget = Budget(
        category_id=1,
        amount=Decimal('1000.00'),
        period=BudgetPeriod.MONTHLY,
    )
    print(f"预算对象: {monthly_budget}")
    print(f"验证结果: {monthly_budget.validate()}")
    print(f"是否当前周期: {monthly_budget.is_current_period()}")

    # Budget usage against a sample spend.
    spent_so_far = Decimal('800.00')
    print(f"已花费: ¥{spent_so_far}")
    print(f"使用百分比: {monthly_budget.get_usage_percentage(spent_so_far):.1%}")
    print(f"是否超过预警: {monthly_budget.is_over_threshold(spent_so_far)}")
    print(f"是否超预算: {monthly_budget.is_over_budget(spent_so_far)}")

数据库层实现

# database/connection.py
"""
数据库连接管理
"""

import sqlite3
import threading
from pathlib import Path
from contextlib import contextmanager
from typing import Optional

class DatabaseConnection:
    """Process-wide SQLite access point that hands out one connection per thread.

    Instantiation always yields the same object; only the ``db_path`` given on
    the very first call takes effect, later arguments are ignored.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, db_path: str = "data/finance.db"):
        """Return the shared instance, creating it under a lock on first use."""
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            # Re-check: another thread may have won the race for the lock.
            if cls._instance is None:
                created = super().__new__(cls)
                created._initialized = False
                cls._instance = created
        return cls._instance

    def __init__(self, db_path: str = "data/finance.db"):
        """Set up the shared instance once; later calls are no-ops."""
        if self._initialized:
            return

        self.db_path = Path(db_path)
        # Make sure the directory that will hold the database file exists.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._local = threading.local()
        self._initialized = True

    def get_connection(self) -> sqlite3.Connection:
        """Return this thread's connection, opening it lazily on first use."""
        existing = getattr(self._local, 'connection', None)
        if existing is not None:
            return existing

        fresh = sqlite3.connect(str(self.db_path), check_same_thread=False)
        fresh.row_factory = sqlite3.Row
        # SQLite ships with foreign-key enforcement off; enable it per connection.
        fresh.execute("PRAGMA foreign_keys = ON")
        self._local.connection = fresh
        return fresh

    @contextmanager
    def get_cursor(self):
        """Yield a cursor; commit on normal exit, roll back and re-raise on error."""
        connection = self.get_connection()
        cur = connection.cursor()
        try:
            yield cur
            connection.commit()
        except Exception:
            connection.rollback()
            raise
        finally:
            cur.close()

    def close_connection(self):
        """Close and forget the calling thread's connection, if one is open."""
        existing = getattr(self._local, 'connection', None)
        if existing is not None:
            existing.close()
            del self._local.connection

    def execute_script(self, script: str):
        """Run a multi-statement SQL *script* inside ``get_cursor``."""
        with self.get_cursor() as cur:
            cur.executescript(script)

# database/migrations.py
"""
数据库迁移
"""

from .connection import DatabaseConnection

class DatabaseMigrations:
    """Creates the finance schema and seeds the default categories."""

    def __init__(self):
        self.db = DatabaseConnection()

    def create_tables(self):
        """Create all tables and indexes; every statement uses IF NOT EXISTS."""

        # Categories (self-referencing parent_id for a hierarchy).
        categories_sql = """
        CREATE TABLE IF NOT EXISTS categories (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name VARCHAR(50) NOT NULL,
            transaction_type VARCHAR(20) NOT NULL CHECK (transaction_type IN ('income', 'expense', 'transfer')),
            parent_id INTEGER,
            color VARCHAR(7) DEFAULT '#007bff',
            icon VARCHAR(10) DEFAULT '💰',
            description TEXT,
            is_active BOOLEAN DEFAULT 1,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (parent_id) REFERENCES categories (id)
        );
        """

        # Transactions.
        transactions_sql = """
        CREATE TABLE IF NOT EXISTS transactions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            amount DECIMAL(10, 2) NOT NULL CHECK (amount > 0),
            transaction_type VARCHAR(20) NOT NULL CHECK (transaction_type IN ('income', 'expense', 'transfer')),
            category_id INTEGER,
            description TEXT NOT NULL,
            date DATE NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (category_id) REFERENCES categories (id)
        );
        """

        # Budgets.
        budgets_sql = """
        CREATE TABLE IF NOT EXISTS budgets (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            category_id INTEGER NOT NULL,
            amount DECIMAL(10, 2) NOT NULL CHECK (amount > 0),
            period VARCHAR(20) NOT NULL CHECK (period IN ('monthly', 'quarterly', 'yearly')),
            start_date DATE NOT NULL,
            end_date DATE NOT NULL,
            alert_threshold REAL DEFAULT 0.8 CHECK (alert_threshold > 0 AND alert_threshold <= 1),
            is_active BOOLEAN DEFAULT 1,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (category_id) REFERENCES categories (id),
            CHECK (start_date < end_date)
        );
        """

        # Indexes for the common lookup columns.
        indexes_sql = """
        CREATE INDEX IF NOT EXISTS idx_transactions_date ON transactions (date);
        CREATE INDEX IF NOT EXISTS idx_transactions_category ON transactions (category_id);
        CREATE INDEX IF NOT EXISTS idx_transactions_type ON transactions (transaction_type);
        CREATE INDEX IF NOT EXISTS idx_budgets_period ON budgets (start_date, end_date);
        CREATE INDEX IF NOT EXISTS idx_categories_type ON categories (transaction_type);
        """

        # Order matters: tables referenced by foreign keys come first.
        ddl_scripts = (categories_sql, transactions_sql, budgets_sql, indexes_sql)
        with self.db.get_cursor() as cursor:
            for script in ddl_scripts:
                cursor.executescript(script)

        print("数据库表创建完成")

    def insert_default_data(self):
        """Seed the default categories, but only into an empty categories table."""

        # (name, transaction_type, parent_id, color, icon, description)
        seed_rows = [
            # Expense categories
            ('餐饮', 'expense', None, '#ff6b6b', '🍽️', '日常餐饮支出'),
            ('交通', 'expense', None, '#4ecdc4', '🚗', '交通出行费用'),
            ('购物', 'expense', None, '#45b7d1', '🛍️', '日常购物消费'),
            ('娱乐', 'expense', None, '#96ceb4', '🎮', '娱乐休闲支出'),
            ('医疗', 'expense', None, '#ffeaa7', '🏥', '医疗健康费用'),
            ('教育', 'expense', None, '#dda0dd', '📚', '教育学习支出'),
            ('住房', 'expense', None, '#98d8c8', '🏠', '房租房贷等'),
            ('其他支出', 'expense', None, '#f7dc6f', '💸', '其他支出'),

            # Income categories
            ('工资', 'income', None, '#52c41a', '💼', '工资收入'),
            ('奖金', 'income', None, '#1890ff', '🎁', '奖金收入'),
            ('投资', 'income', None, '#722ed1', '📈', '投资收益'),
            ('其他收入', 'income', None, '#fa8c16', '💰', '其他收入'),
        ]

        with self.db.get_cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM categories")
            existing_rows = cursor.fetchone()[0]

            if existing_rows != 0:
                print("数据库已有数据,跳过默认数据插入")
            else:
                cursor.executemany(
                    """
                    INSERT INTO categories (name, transaction_type, parent_id, color, icon, description)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    seed_rows
                )
                print(f"插入了 {len(seed_rows)} 个默认分类")

    def migrate(self):
        """Run the schema creation followed by the seeding step."""
        print("开始数据库迁移...")
        for step in (self.create_tables, self.insert_default_data):
            step()
        print("数据库迁移完成")

# Database self-test: connect, migrate, query, then remove the throwaway DB file.
if __name__ == "__main__":
    import os

    print("=== 测试数据库连接 ===")

    test_db_path = "test_finance.db"
    # NOTE: DatabaseConnection is a singleton, so this path is only honoured
    # if nothing in this process has created the connection manager yet.
    db = DatabaseConnection(test_db_path)

    try:
        # Connectivity check.
        with db.get_cursor() as cursor:
            cursor.execute("SELECT sqlite_version()")
            version = cursor.fetchone()[0]
            print(f"SQLite版本: {version}")

        # Schema + seed data.
        migrations = DatabaseMigrations()
        migrations.migrate()

        # Read back some of the seeded rows.
        with db.get_cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM categories")
            count = cursor.fetchone()[0]
            print(f"分类数量: {count}")

            cursor.execute("SELECT name, icon FROM categories LIMIT 5")
            categories = cursor.fetchall()
            print("前5个分类:")
            for cat in categories:
                print(f"  {cat['icon']} {cat['name']}")
    finally:
        # Close the cached thread-local connection before deleting the file:
        # removing an open SQLite file fails on Windows, and the cached
        # connection would otherwise keep pointing at the deleted database.
        db.close_connection()
        if os.path.exists(test_db_path):
            os.remove(test_db_path)
            print("清理测试数据库")

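运行核心模块。注意:数据库自测代码位于 database/migrations.py 末尾,而不在 connection.py 中。该模块使用了包内相对导入(from .connection import ...),因此需要在项目根目录 finance_manager 下以模块方式运行;直接执行文件路径会报错 "attempted relative import with no known parent package":

python models/transaction.py
python -m src.database.migrations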

11.3 业务逻辑层

数据访问层

# database/repository.py
"""
数据访问层
"""

from typing import List, Optional, Dict, Any
from datetime import date, datetime
from decimal import Decimal

from .connection import DatabaseConnection
from ..models.transaction import Transaction, TransactionType
from ..models.category import Category
from ..models.budget import Budget, BudgetPeriod

class BaseRepository:
    """Common parent of the repositories; exposes the database handle as ``self.db``."""

    def __init__(self):
        # DatabaseConnection() returns the process-wide shared instance.
        self.db = DatabaseConnection()

class CategoryRepository(BaseRepository):
    """CRUD access to the ``categories`` table; ``delete`` is a soft delete.

    All timestamps are written explicitly as ``YYYY-MM-DD HH:MM:SS[.ffffff]``
    strings taken from Python's local clock. This keeps ``created_at`` and
    ``updated_at`` in the same time zone: SQLite's ``CURRENT_TIMESTAMP``
    default is UTC, whereas updates use ``datetime.now()``, which is local
    time. It also avoids sqlite3's default datetime adapter, which is
    deprecated since Python 3.12.
    """

    @staticmethod
    def _to_db_timestamp(value: Optional[datetime]) -> str:
        """Serialize *value* (or "now" when None) in SQLite's timestamp layout."""
        moment = value if value is not None else datetime.now()
        # sep=" " matches CURRENT_TIMESTAMP's layout, keeping text comparisons consistent.
        return moment.isoformat(sep=" ")

    @staticmethod
    def _from_db_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse a stored timestamp; NULL/empty stays None."""
        return datetime.fromisoformat(value) if value else None

    def create(self, category: Category) -> int:
        """Validate and insert *category*; return the new row id.

        Raises:
            ValueError: if ``category.validate()`` rejects the data.
        """
        category.validate()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO categories (name, transaction_type, parent_id, color, icon, description, is_active,
                                        created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    category.name,
                    category.transaction_type.value,
                    category.parent_id,
                    category.color,
                    category.icon,
                    category.description,
                    category.is_active,
                    self._to_db_timestamp(category.created_at),
                    self._to_db_timestamp(category.updated_at)
                )
            )
            return cursor.lastrowid

    def get_by_id(self, category_id: int) -> Optional[Category]:
        """Return the category with *category_id* (active or not), or None."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "SELECT * FROM categories WHERE id = ?",
                (category_id,)
            )
            row = cursor.fetchone()

            if row:
                return self._row_to_category(row)
            return None

    def get_all(self, transaction_type: Optional[TransactionType] = None, active_only: bool = True) -> List[Category]:
        """Return categories ordered by name.

        Args:
            transaction_type: restrict to this type when given.
            active_only: skip soft-deleted categories when True.
        """
        sql = "SELECT * FROM categories WHERE 1=1"
        params = []

        if transaction_type:
            sql += " AND transaction_type = ?"
            params.append(transaction_type.value)

        if active_only:
            sql += " AND is_active = 1"

        sql += " ORDER BY name"

        with self.db.get_cursor() as cursor:
            cursor.execute(sql, params)
            rows = cursor.fetchall()

            return [self._row_to_category(row) for row in rows]

    def update(self, category: Category) -> bool:
        """Validate and save *category* (matched by id); return True if a row changed.

        Side effect: sets ``category.updated_at`` to the current local time.

        Raises:
            ValueError: if ``category.validate()`` rejects the data.
        """
        category.validate()
        category.updated_at = datetime.now()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                UPDATE categories 
                SET name = ?, transaction_type = ?, parent_id = ?, color = ?, 
                    icon = ?, description = ?, is_active = ?, updated_at = ?
                WHERE id = ?
                """,
                (
                    category.name,
                    category.transaction_type.value,
                    category.parent_id,
                    category.color,
                    category.icon,
                    category.description,
                    category.is_active,
                    self._to_db_timestamp(category.updated_at),
                    category.id
                )
            )
            return cursor.rowcount > 0

    def delete(self, category_id: int) -> bool:
        """Soft-delete: mark the category inactive; return True if a row changed."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "UPDATE categories SET is_active = 0, updated_at = ? WHERE id = ?",
                (self._to_db_timestamp(datetime.now()), category_id)
            )
            return cursor.rowcount > 0

    def _row_to_category(self, row) -> Category:
        """Convert a ``categories`` row (sqlite3.Row) into a Category."""
        return Category(
            id=row['id'],
            name=row['name'],
            transaction_type=TransactionType(row['transaction_type']),
            parent_id=row['parent_id'],
            color=row['color'],
            icon=row['icon'],
            description=row['description'],
            is_active=bool(row['is_active']),
            # NULL timestamps become None; Category.__post_init__ then fills them.
            created_at=self._from_db_timestamp(row['created_at']),
            updated_at=self._from_db_timestamp(row['updated_at'])
        )

class TransactionRepository(BaseRepository):
    """Transaction repository: data access for the ``transactions`` table.

    Every method opens a cursor through ``self.db.get_cursor()`` (provided by
    ``BaseRepository``).

    Date handling: ``date``/``datetime`` parameters are converted to ISO-8601
    text explicitly rather than through sqlite3's default adapters, which are
    deprecated since Python 3.12. The text produced is identical to what those
    adapters wrote for ``date`` values, so existing rows compare correctly.
    """

    @staticmethod
    def _date_param(value):
        """Return *value* as the ``YYYY-MM-DD`` text used in ``transactions.date``.

        A ``datetime`` is truncated to its calendar date first. ``datetime``
        is a subclass of ``date`` and so satisfies the ``date`` type hints.
        Bound as ``'YYYY-MM-DD HH:MM:SS'`` text, it would sort *after* the
        stored ``'YYYY-MM-DD'`` of the same day and silently exclude that day
        from ``BETWEEN`` ranges. Other values (e.g. ``None`` or a preformatted
        string) are passed through unchanged.
        """
        if isinstance(value, datetime):
            value = value.date()
        if isinstance(value, date):
            return value.isoformat()
        return value

    @staticmethod
    def _timestamp_param(value):
        """Return a ``datetime`` as ``isoformat(" ")`` text (the format sqlite3's
        deprecated default datetime adapter produced); other values pass through."""
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    def create(self, transaction: Transaction) -> int:
        """Insert *transaction* and return the new row id.

        Raises:
            Whatever ``transaction.validate()`` raises for invalid data.
        """
        transaction.validate()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO transactions (amount, transaction_type, category_id, description, date)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    # sqlite3 cannot bind Decimal natively, hence the float.
                    float(transaction.amount),
                    transaction.transaction_type.value,
                    transaction.category_id,
                    transaction.description,
                    self._date_param(transaction.date),
                ),
            )
            return cursor.lastrowid

    def get_by_id(self, transaction_id: int) -> Optional[Transaction]:
        """Return the transaction with *transaction_id*, or ``None`` if absent."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "SELECT * FROM transactions WHERE id = ?",
                (transaction_id,)
            )
            row = cursor.fetchone()

            if row:
                return self._row_to_transaction(row)
            return None

    def get_by_date_range(self, start_date: date, end_date: date,
                          transaction_type: Optional[TransactionType] = None,
                          category_id: Optional[int] = None) -> List[Transaction]:
        """Return transactions dated within ``[start_date, end_date]``.

        Both bounds are inclusive (SQL ``BETWEEN``). ``datetime`` bounds are
        truncated to whole days. Results are newest first.

        Args:
            start_date: First day of the range.
            end_date: Last day of the range.
            transaction_type: If given, only transactions of this type.
            category_id: If given (not ``None``), only this category.
        """
        sql = "SELECT * FROM transactions WHERE date BETWEEN ? AND ?"
        params: List[Any] = [self._date_param(start_date), self._date_param(end_date)]

        if transaction_type is not None:
            sql += " AND transaction_type = ?"
            params.append(transaction_type.value)

        if category_id is not None:
            sql += " AND category_id = ?"
            params.append(category_id)

        sql += " ORDER BY date DESC, created_at DESC"

        with self.db.get_cursor() as cursor:
            cursor.execute(sql, params)
            return [self._row_to_transaction(row) for row in cursor.fetchall()]

    def get_recent(self, limit: int = 10) -> List[Transaction]:
        """Return up to *limit* transactions, newest first."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "SELECT * FROM transactions ORDER BY date DESC, created_at DESC LIMIT ?",
                (limit,)
            )
            return [self._row_to_transaction(row) for row in cursor.fetchall()]

    def update(self, transaction: Transaction) -> bool:
        """Persist changes to an existing transaction and refresh ``updated_at``.

        ``transaction.updated_at`` is set to the current time as a side effect.

        Returns:
            ``True`` if a row was updated, ``False`` if no row has ``transaction.id``.

        Raises:
            Whatever ``transaction.validate()`` raises for invalid data.
        """
        transaction.validate()
        transaction.updated_at = datetime.now()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                UPDATE transactions 
                SET amount = ?, transaction_type = ?, category_id = ?, 
                    description = ?, date = ?, updated_at = ?
                WHERE id = ?
                """,
                (
                    float(transaction.amount),
                    transaction.transaction_type.value,
                    transaction.category_id,
                    transaction.description,
                    self._date_param(transaction.date),
                    self._timestamp_param(transaction.updated_at),
                    transaction.id,
                ),
            )
            return cursor.rowcount > 0

    def delete(self, transaction_id: int) -> bool:
        """Hard-delete a transaction; return ``True`` if a row was removed."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "DELETE FROM transactions WHERE id = ?",
                (transaction_id,)
            )
            return cursor.rowcount > 0

    def get_statistics(self, start_date: date, end_date: date) -> Dict[str, Any]:
        """Summarize income/expense totals for ``[start_date, end_date]`` (inclusive).

        Returns:
            A dict with float ``income``, ``expense`` and ``balance``
            (income - expense), and int ``income_count``, ``expense_count``
            and ``total_count``. Types are the same whether or not any rows
            matched.
        """
        with self.db.get_cursor() as cursor:
            # Totals and counts per transaction type.
            cursor.execute(
                """
                SELECT 
                    transaction_type,
                    SUM(amount) as total,
                    COUNT(*) as count
                FROM transactions 
                WHERE date BETWEEN ? AND ?
                GROUP BY transaction_type
                """,
                (self._date_param(start_date), self._date_param(end_date))
            )

            stats = {'income': 0.0, 'expense': 0.0, 'income_count': 0, 'expense_count': 0}

            for row in cursor.fetchall():
                # SUM() yields NULL when every amount in the group is NULL.
                total = float(row['total'] or 0)
                if row['transaction_type'] == 'income':
                    stats['income'] = total
                    stats['income_count'] = row['count']
                elif row['transaction_type'] == 'expense':
                    stats['expense'] = total
                    stats['expense_count'] = row['count']

            stats['balance'] = stats['income'] - stats['expense']
            stats['total_count'] = stats['income_count'] + stats['expense_count']

            return stats

    def get_category_statistics(self, start_date: date, end_date: date,
                                transaction_type: TransactionType) -> List[Dict[str, Any]]:
        """Per-category totals for one transaction type, largest total first.

        Transactions whose category is missing (the LEFT JOIN finds no match)
        are reported with placeholder name/icon/color values.
        """
        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                SELECT 
                    c.name as category_name,
                    c.icon as category_icon,
                    c.color as category_color,
                    SUM(t.amount) as total,
                    COUNT(t.id) as count
                FROM transactions t
                LEFT JOIN categories c ON t.category_id = c.id
                WHERE t.date BETWEEN ? AND ? AND t.transaction_type = ?
                GROUP BY t.category_id, c.name, c.icon, c.color
                ORDER BY total DESC
                """,
                (
                    self._date_param(start_date),
                    self._date_param(end_date),
                    transaction_type.value,
                ),
            )

            return [
                {
                    'category_name': row['category_name'] or '未分类',
                    'category_icon': row['category_icon'] or '❓',
                    'category_color': row['category_color'] or '#666666',
                    'total': float(row['total'] or 0),
                    'count': row['count']
                }
                for row in cursor.fetchall()
            ]

    def _row_to_transaction(self, row) -> Transaction:
        """Convert a ``transactions`` row (read back as text) into a ``Transaction``.

        ``amount`` goes through ``str`` so the Decimal reflects the stored
        float's shortest repr rather than its full binary expansion.
        """
        return Transaction(
            id=row['id'],
            amount=Decimal(str(row['amount'])),
            transaction_type=TransactionType(row['transaction_type']),
            category_id=row['category_id'],
            description=row['description'],
            date=datetime.fromisoformat(row['date']),
            created_at=datetime.fromisoformat(row['created_at']),
            updated_at=datetime.fromisoformat(row['updated_at'])
        )

class BudgetRepository(BaseRepository):
    """Budget repository: data access for the ``budgets`` table.

    ``start_date``/``end_date``/``current_date`` and ``updated_at`` are bound
    as explicit ISO-8601 text rather than through sqlite3's default date and
    datetime adapters (deprecated since Python 3.12). For ``date`` values the
    text is identical to what those adapters wrote, so existing rows still
    compare correctly.
    """

    @staticmethod
    def _date_param(value):
        """Return *value* as ``YYYY-MM-DD`` text.

        A ``datetime`` (which satisfies the ``date`` type hints because it is
        a subclass) is truncated to its calendar date first. Otherwise a
        time-stamped string would be stored, which ``date.fromisoformat`` in
        ``_row_to_budget`` cannot parse. It would also make the last day of a
        budget period compare as out of range. Non-date values pass through
        unchanged.
        """
        if isinstance(value, datetime):
            value = value.date()
        if isinstance(value, date):
            return value.isoformat()
        return value

    @staticmethod
    def _timestamp_param(value):
        """Return a ``datetime`` as ``isoformat(" ")`` text (the format sqlite3's
        deprecated default datetime adapter produced); other values pass through."""
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    def create(self, budget: Budget) -> int:
        """Insert *budget* and return the new row id.

        Raises:
            Whatever ``budget.validate()`` raises for invalid data.
        """
        budget.validate()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                INSERT INTO budgets (category_id, amount, period, start_date, end_date, alert_threshold, is_active)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    budget.category_id,
                    # sqlite3 cannot bind Decimal natively, hence the float.
                    float(budget.amount),
                    budget.period.value,
                    self._date_param(budget.start_date),
                    self._date_param(budget.end_date),
                    budget.alert_threshold,
                    budget.is_active,
                ),
            )
            return cursor.lastrowid

    def get_by_id(self, budget_id: int) -> Optional[Budget]:
        """Return the budget with *budget_id*, or ``None`` if absent."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "SELECT * FROM budgets WHERE id = ?",
                (budget_id,)
            )
            row = cursor.fetchone()

            if row:
                return self._row_to_budget(row)
            return None

    def get_active_budgets(self, current_date: Optional[date] = None) -> List[Budget]:
        """Return active budgets whose period contains *current_date*.

        Args:
            current_date: Day to test; both ``start_date`` and ``end_date``
                are inclusive. Defaults to today. A ``datetime`` is truncated
                to its date so a budget still matches on its last day.

        Returns:
            Matching budgets, most recent ``start_date`` first.
        """
        if current_date is None:
            current_date = date.today()
        day = self._date_param(current_date)

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                SELECT * FROM budgets 
                WHERE is_active = 1 AND start_date <= ? AND end_date >= ?
                ORDER BY start_date DESC
                """,
                (day, day)
            )
            return [self._row_to_budget(row) for row in cursor.fetchall()]

    def get_by_category(self, category_id: int) -> List[Budget]:
        """Return all budgets (active or not) for *category_id*, newest first."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "SELECT * FROM budgets WHERE category_id = ? ORDER BY start_date DESC",
                (category_id,)
            )
            return [self._row_to_budget(row) for row in cursor.fetchall()]

    def update(self, budget: Budget) -> bool:
        """Persist changes to an existing budget and refresh ``updated_at``.

        ``budget.updated_at`` is set to the current time as a side effect.

        Returns:
            ``True`` if a row was updated, ``False`` if no row has ``budget.id``.

        Raises:
            Whatever ``budget.validate()`` raises for invalid data.
        """
        budget.validate()
        budget.updated_at = datetime.now()

        with self.db.get_cursor() as cursor:
            cursor.execute(
                """
                UPDATE budgets 
                SET category_id = ?, amount = ?, period = ?, start_date = ?, 
                    end_date = ?, alert_threshold = ?, is_active = ?, updated_at = ?
                WHERE id = ?
                """,
                (
                    budget.category_id,
                    float(budget.amount),
                    budget.period.value,
                    self._date_param(budget.start_date),
                    self._date_param(budget.end_date),
                    budget.alert_threshold,
                    budget.is_active,
                    self._timestamp_param(budget.updated_at),
                    budget.id,
                ),
            )
            return cursor.rowcount > 0

    def delete(self, budget_id: int) -> bool:
        """Hard-delete a budget; return ``True`` if a row was removed."""
        with self.db.get_cursor() as cursor:
            cursor.execute(
                "DELETE FROM budgets WHERE id = ?",
                (budget_id,)
            )
            return cursor.rowcount > 0

    def _row_to_budget(self, row) -> Budget:
        """Convert a ``budgets`` row (dates read back as ISO text) into a ``Budget``."""
        return Budget(
            id=row['id'],
            category_id=row['category_id'],
            amount=Decimal(str(row['amount'])),
            period=BudgetPeriod(row['period']),
            start_date=date.fromisoformat(row['start_date']),
            end_date=date.fromisoformat(row['end_date']),
            alert_threshold=row['alert_threshold'],
            # SQLite stores booleans as 0/1 integers.
            is_active=bool(row['is_active']),
            created_at=datetime.fromisoformat(row['created_at']),
            updated_at=datetime.fromisoformat(row['updated_at'])
        )

# Repository smoke test
if __name__ == "__main__":
    # NOTE: a relative import only resolves when this module runs as part of
    # its package (e.g. via ``python -m``). Executing the file directly as a
    # script raises "ImportError: attempted relative import with no known
    # parent package".
    from ..database.migrations import DatabaseMigrations
    from datetime import timedelta

    # Create / upgrade the database schema before using the repositories.
    migrations = DatabaseMigrations()
    migrations.migrate()

    # --- Category repository ---
    print("=== 测试分类仓储 ===")

    category_repo = CategoryRepository()

    categories = category_repo.get_all()
    print(f"总分类数: {len(categories)}")

    for cat in categories[:5]:
        print(f"  {cat.icon} {cat.name} ({cat.transaction_type.value})")

    # --- Transaction repository ---
    print("\n=== 测试交易仓储 ===")

    transaction_repo = TransactionRepository()

    # The test record is an expense, so attach it to an expense category
    # instead of whichever category happens to be listed first (which may be
    # an income category).
    expense_category = next(
        (cat for cat in categories if cat.transaction_type == TransactionType.EXPENSE),
        None,
    )

    if expense_category is not None:
        test_transaction = Transaction(
            amount=Decimal('100.00'),
            transaction_type=TransactionType.EXPENSE,
            category_id=expense_category.id,
            description="测试交易"
        )

        transaction_id = transaction_repo.create(test_transaction)
        print(f"创建交易ID: {transaction_id}")

        try:
            retrieved_transaction = transaction_repo.get_by_id(transaction_id)
            if retrieved_transaction:
                print(f"获取交易: {retrieved_transaction}")

            # Statistics for the last 30 days (includes the test record).
            today = date.today()
            start_date = today - timedelta(days=30)

            stats = transaction_repo.get_statistics(start_date, today)
            print(f"统计信息: {stats}")
        finally:
            # Remove the smoke-test record so repeated runs don't pollute the
            # user's real financial data.
            transaction_repo.delete(transaction_id)

运行数据访问层(仓储)测试:

python database/repository.py

本章小结

本章我们通过一个完整的个人财务管理系统项目,学习了:

  1. 项目规划:需求分析、架构设计、目录结构
  2. 数据模型:使用dataclass和枚举设计业务模型
  3. 数据库设计:SQLite数据库、表结构、索引优化
  4. 数据访问层:仓储模式、CRUD操作、统计查询
  5. 业务逻辑:服务层设计、数据验证、异常处理

下一步计划

项目还需要完成:

  • 服务层实现(业务逻辑封装)
  • 用户界面(CLI和Web)
  • 数据可视化(图表展示)
  • 单元测试
  • 部署和打包

练习题

基础练习

  1. 扩展功能

    • 添加账户管理功能
    • 实现标签系统
    • 添加附件上传功能
  2. 数据分析

    • 实现趋势分析
    • 添加同比环比功能
    • 创建财务健康评分

进阶练习

  1. 性能优化

    • 实现数据分页
    • 添加缓存机制
    • 优化查询性能
  2. 系统集成

    • 集成第三方支付API
    • 实现数据同步
    • 添加通知系统

提示:项目实战是学习编程的最佳方式。通过完整项目的开发,你能够综合运用所学知识,理解软件开发的完整流程。建议按照本章的结构继续完善这个财务管理系统。