8.1 章节概述

行为型设计模式关注对象之间的通信和职责分配,它们描述了对象和类之间怎样相互协作来完成单个对象无法独立完成的任务。本章将介绍行为型模式的核心思想,并深入学习策略模式和模板方法模式。

8.1.1 行为型模式分类

行为型设计模式可以分为以下几类:

  1. 对象行为模式

    • 策略模式(Strategy)
    • 观察者模式(Observer)
    • 命令模式(Command)
    • 状态模式(State)
    • 责任链模式(Chain of Responsibility)
    • 访问者模式(Visitor)
    • 中介者模式(Mediator)
    • 备忘录模式(Memento)
    • 迭代器模式(Iterator)
  2. 类行为模式

    • 模板方法模式(Template Method)
    • 解释器模式(Interpreter)

8.1.2 行为型模式的核心思想

  1. 封装变化:将变化的部分封装起来,使其独立于不变的部分
  2. 对象协作:定义对象间的交互方式和通信协议
  3. 职责分离:将复杂的行为分解为多个简单的职责
  4. 算法族:将一系列算法封装起来,使它们可以互相替换

8.1.3 学习目标

  1. 理解行为型模式的设计思想和应用场景
  2. 掌握策略模式的实现和最佳实践
  3. 学会模板方法模式的设计和应用
  4. 了解两种模式的优缺点和选择标准
  5. 能够在实际项目中灵活运用这些模式

8.2 策略模式(Strategy Pattern)

8.2.1 模式定义与动机

定义: 策略模式定义了一系列算法,把它们一个个封装起来,并且使它们可相互替换。策略模式让算法的变化独立于使用算法的客户。

动机: - 有多种方式实现同一个功能 - 需要在运行时动态选择算法 - 避免使用多重条件判断 - 算法的实现细节应该对客户端透明

8.2.2 模式结构

策略模式包含以下角色:

  1. Strategy(抽象策略):定义所有具体策略的公共接口
  2. ConcreteStrategy(具体策略):实现具体的算法或行为
  3. Context(环境类):维护一个对策略对象的引用,定义策略对象的接口

8.2.3 Python实现示例:支付系统

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from decimal import Decimal
from datetime import datetime
import hashlib
import json

# 抽象策略:支付策略
class PaymentStrategy(ABC):
    """支付策略抽象基类"""
    
    @abstractmethod
    def pay(self, amount: Decimal, order_info: Dict[str, Any]) -> Dict[str, Any]:
        """执行支付"""
        pass
    
    @abstractmethod
    def validate_payment_info(self, payment_info: Dict[str, Any]) -> bool:
        """验证支付信息"""
        pass
    
    @abstractmethod
    def get_payment_method_name(self) -> str:
        """获取支付方式名称"""
        pass

# 具体策略:信用卡支付
class CreditCardPayment(PaymentStrategy):
    """信用卡支付策略"""
    
    def __init__(self, card_number: str, cvv: str, expiry_date: str, holder_name: str):
        self.card_number = card_number
        self.cvv = cvv
        self.expiry_date = expiry_date
        self.holder_name = holder_name
    
    def pay(self, amount: Decimal, order_info: Dict[str, Any]) -> Dict[str, Any]:
        """执行信用卡支付"""
        # 模拟信用卡支付流程
        transaction_id = self._generate_transaction_id()
        
        # 验证卡片信息
        if not self._validate_card():
            return {
                "success": False,
                "message": "信用卡信息验证失败",
                "transaction_id": None
            }
        
        # 检查余额(模拟)
        if not self._check_balance(amount):
            return {
                "success": False,
                "message": "信用卡余额不足",
                "transaction_id": None
            }
        
        # 执行支付
        print(f"正在通过信用卡支付 ${amount}...")
        print(f"卡号: ****-****-****-{self.card_number[-4:]}")
        print(f"持卡人: {self.holder_name}")
        
        return {
            "success": True,
            "message": "信用卡支付成功",
            "transaction_id": transaction_id,
            "payment_method": "Credit Card",
            "amount": float(amount),
            "timestamp": datetime.now().isoformat()
        }
    
    def validate_payment_info(self, payment_info: Dict[str, Any]) -> bool:
        """验证信用卡信息"""
        required_fields = ['card_number', 'cvv', 'expiry_date', 'holder_name']
        return all(field in payment_info for field in required_fields)
    
    def get_payment_method_name(self) -> str:
        return "Credit Card"
    
    def _validate_card(self) -> bool:
        """验证卡片有效性"""
        # 简单的卡号验证(Luhn算法的简化版本)
        return len(self.card_number.replace('-', '')) == 16
    
    def _check_balance(self, amount: Decimal) -> bool:
        """检查余额"""
        # 模拟余额检查
        return amount <= Decimal('10000')  # 假设信用额度为10000
    
    def _generate_transaction_id(self) -> str:
        """生成交易ID"""
        data = f"{self.card_number}{datetime.now().isoformat()}"
        return hashlib.md5(data.encode()).hexdigest()[:12].upper()

# 具体策略:PayPal支付
class PayPalPayment(PaymentStrategy):
    """PayPal支付策略"""
    
    def __init__(self, email: str, password: str):
        self.email = email
        self.password = password
    
    def pay(self, amount: Decimal, order_info: Dict[str, Any]) -> Dict[str, Any]:
        """执行PayPal支付"""
        transaction_id = self._generate_transaction_id()
        
        # 验证账户
        if not self._authenticate():
            return {
                "success": False,
                "message": "PayPal账户验证失败",
                "transaction_id": None
            }
        
        # 检查余额
        if not self._check_balance(amount):
            return {
                "success": False,
                "message": "PayPal账户余额不足",
                "transaction_id": None
            }
        
        print(f"正在通过PayPal支付 ${amount}...")
        print(f"账户: {self.email}")
        
        return {
            "success": True,
            "message": "PayPal支付成功",
            "transaction_id": transaction_id,
            "payment_method": "PayPal",
            "amount": float(amount),
            "timestamp": datetime.now().isoformat()
        }
    
    def validate_payment_info(self, payment_info: Dict[str, Any]) -> bool:
        """验证PayPal信息"""
        required_fields = ['email', 'password']
        return all(field in payment_info for field in required_fields)
    
    def get_payment_method_name(self) -> str:
        return "PayPal"
    
    def _authenticate(self) -> bool:
        """账户认证"""
        # 模拟认证过程
        return '@' in self.email and len(self.password) >= 6
    
    def _check_balance(self, amount: Decimal) -> bool:
        """检查余额"""
        # 模拟余额检查
        return amount <= Decimal('5000')  # 假设PayPal余额为5000
    
    def _generate_transaction_id(self) -> str:
        """生成交易ID"""
        data = f"{self.email}{datetime.now().isoformat()}"
        return f"PP{hashlib.md5(data.encode()).hexdigest()[:10].upper()}"

# 具体策略:银行转账
class BankTransferPayment(PaymentStrategy):
    """银行转账支付策略"""
    
    def __init__(self, account_number: str, routing_number: str, account_holder: str):
        self.account_number = account_number
        self.routing_number = routing_number
        self.account_holder = account_holder
    
    def pay(self, amount: Decimal, order_info: Dict[str, Any]) -> Dict[str, Any]:
        """执行银行转账支付"""
        transaction_id = self._generate_transaction_id()
        
        # 验证账户信息
        if not self._validate_account():
            return {
                "success": False,
                "message": "银行账户信息验证失败",
                "transaction_id": None
            }
        
        # 检查余额
        if not self._check_balance(amount):
            return {
                "success": False,
                "message": "银行账户余额不足",
                "transaction_id": None
            }
        
        print(f"正在通过银行转账支付 ${amount}...")
        print(f"账户: ****{self.account_number[-4:]}")
        print(f"户名: {self.account_holder}")
        print("注意: 银行转账可能需要1-3个工作日到账")
        
        return {
            "success": True,
            "message": "银行转账支付已提交",
            "transaction_id": transaction_id,
            "payment_method": "Bank Transfer",
            "amount": float(amount),
            "timestamp": datetime.now().isoformat(),
            "estimated_completion": "1-3 business days"
        }
    
    def validate_payment_info(self, payment_info: Dict[str, Any]) -> bool:
        """验证银行转账信息"""
        required_fields = ['account_number', 'routing_number', 'account_holder']
        return all(field in payment_info for field in required_fields)
    
    def get_payment_method_name(self) -> str:
        return "Bank Transfer"
    
    def _validate_account(self) -> bool:
        """验证账户"""
        return len(self.account_number) >= 8 and len(self.routing_number) == 9
    
    def _check_balance(self, amount: Decimal) -> bool:
        """检查余额"""
        # 模拟余额检查
        return amount <= Decimal('50000')  # 假设银行账户余额为50000
    
    def _generate_transaction_id(self) -> str:
        """生成交易ID"""
        data = f"{self.account_number}{datetime.now().isoformat()}"
        return f"BT{hashlib.md5(data.encode()).hexdigest()[:10].upper()}"

# 环境类:支付处理器
class PaymentProcessor:
    """支付处理器(环境类)"""
    
    def __init__(self):
        self._strategy: Optional[PaymentStrategy] = None
        self._transaction_history: List[Dict[str, Any]] = []
    
    def set_payment_strategy(self, strategy: PaymentStrategy) -> None:
        """设置支付策略"""
        self._strategy = strategy
        print(f"支付方式已设置为: {strategy.get_payment_method_name()}")
    
    def process_payment(self, amount: Decimal, order_info: Dict[str, Any]) -> Dict[str, Any]:
        """处理支付"""
        if not self._strategy:
            return {
                "success": False,
                "message": "未设置支付策略",
                "transaction_id": None
            }
        
        # 验证金额
        if amount <= 0:
            return {
                "success": False,
                "message": "支付金额必须大于0",
                "transaction_id": None
            }
        
        # 执行支付
        result = self._strategy.pay(amount, order_info)
        
        # 记录交易历史
        self._transaction_history.append({
            "timestamp": datetime.now().isoformat(),
            "payment_method": self._strategy.get_payment_method_name(),
            "amount": float(amount),
            "order_info": order_info,
            "result": result
        })
        
        return result
    
    def get_transaction_history(self) -> List[Dict[str, Any]]:
        """获取交易历史"""
        return self._transaction_history.copy()
    
    def get_current_payment_method(self) -> Optional[str]:
        """获取当前支付方式"""
        return self._strategy.get_payment_method_name() if self._strategy else None

# 支付方式工厂
class PaymentStrategyFactory:
    """支付策略工厂"""
    
    @staticmethod
    def create_credit_card_payment(card_info: Dict[str, str]) -> CreditCardPayment:
        """创建信用卡支付策略"""
        return CreditCardPayment(
            card_number=card_info['card_number'],
            cvv=card_info['cvv'],
            expiry_date=card_info['expiry_date'],
            holder_name=card_info['holder_name']
        )
    
    @staticmethod
    def create_paypal_payment(paypal_info: Dict[str, str]) -> PayPalPayment:
        """创建PayPal支付策略"""
        return PayPalPayment(
            email=paypal_info['email'],
            password=paypal_info['password']
        )
    
    @staticmethod
    def create_bank_transfer_payment(bank_info: Dict[str, str]) -> BankTransferPayment:
        """创建银行转账支付策略"""
        return BankTransferPayment(
            account_number=bank_info['account_number'],
            routing_number=bank_info['routing_number'],
            account_holder=bank_info['account_holder']
        )

# 使用示例
def demonstrate_strategy_pattern():
    """演示策略模式"""
    print("=== 策略模式演示:支付系统 ===")
    
    # 创建支付处理器
    processor = PaymentProcessor()
    
    # 订单信息
    order_info = {
        "order_id": "ORD-2024-001",
        "product": "笔记本电脑",
        "quantity": 1,
        "customer_id": "CUST-001"
    }
    
    # 测试信用卡支付
    print("\n--- 信用卡支付 ---")
    credit_card_info = {
        "card_number": "1234-5678-9012-3456",
        "cvv": "123",
        "expiry_date": "12/25",
        "holder_name": "张三"
    }
    
    credit_card_strategy = PaymentStrategyFactory.create_credit_card_payment(credit_card_info)
    processor.set_payment_strategy(credit_card_strategy)
    result1 = processor.process_payment(Decimal('1299.99'), order_info)
    print(f"支付结果: {result1['message']}")
    
    # 测试PayPal支付
    print("\n--- PayPal支付 ---")
    paypal_info = {
        "email": "user@example.com",
        "password": "password123"
    }
    
    paypal_strategy = PaymentStrategyFactory.create_paypal_payment(paypal_info)
    processor.set_payment_strategy(paypal_strategy)
    result2 = processor.process_payment(Decimal('899.50'), order_info)
    print(f"支付结果: {result2['message']}")
    
    # 测试银行转账支付
    print("\n--- 银行转账支付 ---")
    bank_info = {
        "account_number": "12345678901",
        "routing_number": "123456789",
        "account_holder": "李四"
    }
    
    bank_strategy = PaymentStrategyFactory.create_bank_transfer_payment(bank_info)
    processor.set_payment_strategy(bank_strategy)
    result3 = processor.process_payment(Decimal('2599.00'), order_info)
    print(f"支付结果: {result3['message']}")
    
    # 显示交易历史
    print("\n=== 交易历史 ===")
    history = processor.get_transaction_history()
    for i, transaction in enumerate(history, 1):
        print(f"交易 {i}:")
        print(f"  时间: {transaction['timestamp']}")
        print(f"  支付方式: {transaction['payment_method']}")
        print(f"  金额: ${transaction['amount']}")
        print(f"  状态: {'成功' if transaction['result']['success'] else '失败'}")
        print()

if __name__ == "__main__":
    demonstrate_strategy_pattern()

8.2.4 Java实现示例:排序算法

import java.util.*;
import java.util.function.Consumer;

// 抽象策略:排序策略
interface SortStrategy<T extends Comparable<T>> {
    void sort(List<T> data);
    String getAlgorithmName();
    String getTimeComplexity();
    String getSpaceComplexity();
}

// 具体策略:冒泡排序
class BubbleSortStrategy<T extends Comparable<T>> implements SortStrategy<T> {
    
    @Override
    public void sort(List<T> data) {
        int n = data.size();
        boolean swapped;
        
        for (int i = 0; i < n - 1; i++) {
            swapped = false;
            for (int j = 0; j < n - i - 1; j++) {
                if (data.get(j).compareTo(data.get(j + 1)) > 0) {
                    Collections.swap(data, j, j + 1);
                    swapped = true;
                }
            }
            // 如果没有交换,说明已经排序完成
            if (!swapped) break;
        }
    }
    
    @Override
    public String getAlgorithmName() {
        return "Bubble Sort";
    }
    
    @Override
    public String getTimeComplexity() {
        return "O(n²)";
    }
    
    @Override
    public String getSpaceComplexity() {
        return "O(1)";
    }
}

// 具体策略:快速排序
class QuickSortStrategy<T extends Comparable<T>> implements SortStrategy<T> {
    
    @Override
    public void sort(List<T> data) {
        if (data.size() <= 1) return;
        quickSort(data, 0, data.size() - 1);
    }
    
    private void quickSort(List<T> data, int low, int high) {
        if (low < high) {
            int pivotIndex = partition(data, low, high);
            quickSort(data, low, pivotIndex - 1);
            quickSort(data, pivotIndex + 1, high);
        }
    }
    
    private int partition(List<T> data, int low, int high) {
        T pivot = data.get(high);
        int i = low - 1;
        
        for (int j = low; j < high; j++) {
            if (data.get(j).compareTo(pivot) <= 0) {
                i++;
                Collections.swap(data, i, j);
            }
        }
        
        Collections.swap(data, i + 1, high);
        return i + 1;
    }
    
    @Override
    public String getAlgorithmName() {
        return "Quick Sort";
    }
    
    @Override
    public String getTimeComplexity() {
        return "O(n log n) average, O(n²) worst";
    }
    
    @Override
    public String getSpaceComplexity() {
        return "O(log n)";
    }
}

// 具体策略:归并排序
class MergeSortStrategy<T extends Comparable<T>> implements SortStrategy<T> {
    
    @Override
    public void sort(List<T> data) {
        if (data.size() <= 1) return;
        mergeSort(data, 0, data.size() - 1);
    }
    
    private void mergeSort(List<T> data, int left, int right) {
        if (left < right) {
            int mid = left + (right - left) / 2;
            mergeSort(data, left, mid);
            mergeSort(data, mid + 1, right);
            merge(data, left, mid, right);
        }
    }
    
    private void merge(List<T> data, int left, int mid, int right) {
        List<T> leftArray = new ArrayList<>(data.subList(left, mid + 1));
        List<T> rightArray = new ArrayList<>(data.subList(mid + 1, right + 1));
        
        int i = 0, j = 0, k = left;
        
        while (i < leftArray.size() && j < rightArray.size()) {
            if (leftArray.get(i).compareTo(rightArray.get(j)) <= 0) {
                data.set(k++, leftArray.get(i++));
            } else {
                data.set(k++, rightArray.get(j++));
            }
        }
        
        while (i < leftArray.size()) {
            data.set(k++, leftArray.get(i++));
        }
        
        while (j < rightArray.size()) {
            data.set(k++, rightArray.get(j++));
        }
    }
    
    @Override
    public String getAlgorithmName() {
        return "Merge Sort";
    }
    
    @Override
    public String getTimeComplexity() {
        return "O(n log n)";
    }
    
    @Override
    public String getSpaceComplexity() {
        return "O(n)";
    }
}

// 具体策略:堆排序
class HeapSortStrategy<T extends Comparable<T>> implements SortStrategy<T> {
    
    @Override
    public void sort(List<T> data) {
        int n = data.size();
        
        // 构建最大堆
        for (int i = n / 2 - 1; i >= 0; i--) {
            heapify(data, n, i);
        }
        
        // 逐个提取元素
        for (int i = n - 1; i > 0; i--) {
            Collections.swap(data, 0, i);
            heapify(data, i, 0);
        }
    }
    
    private void heapify(List<T> data, int n, int i) {
        int largest = i;
        int left = 2 * i + 1;
        int right = 2 * i + 2;
        
        if (left < n && data.get(left).compareTo(data.get(largest)) > 0) {
            largest = left;
        }
        
        if (right < n && data.get(right).compareTo(data.get(largest)) > 0) {
            largest = right;
        }
        
        if (largest != i) {
            Collections.swap(data, i, largest);
            heapify(data, n, largest);
        }
    }
    
    @Override
    public String getAlgorithmName() {
        return "Heap Sort";
    }
    
    @Override
    public String getTimeComplexity() {
        return "O(n log n)";
    }
    
    @Override
    public String getSpaceComplexity() {
        return "O(1)";
    }
}

// 环境类:排序器
class Sorter<T extends Comparable<T>> {
    private SortStrategy<T> strategy;
    private List<SortResult> sortHistory;
    
    public Sorter() {
        this.sortHistory = new ArrayList<>();
    }
    
    public void setStrategy(SortStrategy<T> strategy) {
        this.strategy = strategy;
        System.out.println("排序算法已设置为: " + strategy.getAlgorithmName());
    }
    
    public SortResult sort(List<T> data) {
        if (strategy == null) {
            throw new IllegalStateException("未设置排序策略");
        }
        
        List<T> originalData = new ArrayList<>(data);
        long startTime = System.nanoTime();
        
        strategy.sort(data);
        
        long endTime = System.nanoTime();
        long duration = endTime - startTime;
        
        SortResult result = new SortResult(
            strategy.getAlgorithmName(),
            originalData.size(),
            duration,
            strategy.getTimeComplexity(),
            strategy.getSpaceComplexity()
        );
        
        sortHistory.add(result);
        return result;
    }
    
    public List<SortResult> getSortHistory() {
        return new ArrayList<>(sortHistory);
    }
    
    public void clearHistory() {
        sortHistory.clear();
    }
    
    // 性能测试
    public void performanceTest(int dataSize, int iterations) {
        System.out.println("\n=== 性能测试 ===");
        System.out.printf("数据大小: %d, 测试次数: %d\n", dataSize, iterations);
        
        List<SortStrategy<Integer>> strategies = Arrays.asList(
            new BubbleSortStrategy<>(),
            new QuickSortStrategy<>(),
            new MergeSortStrategy<>(),
            new HeapSortStrategy<>()
        );
        
        for (SortStrategy<Integer> strategy : strategies) {
            long totalTime = 0;
            
            for (int i = 0; i < iterations; i++) {
                List<Integer> testData = generateRandomData(dataSize);
                setStrategy((SortStrategy<T>) strategy);
                
                long startTime = System.nanoTime();
                strategy.sort((List<Integer>) testData);
                long endTime = System.nanoTime();
                
                totalTime += (endTime - startTime);
            }
            
            double averageTime = totalTime / (double) iterations / 1_000_000; // 转换为毫秒
            System.out.printf("%s: 平均耗时 %.2f ms\n", 
                strategy.getAlgorithmName(), averageTime);
        }
    }
    
    private List<Integer> generateRandomData(int size) {
        List<Integer> data = new ArrayList<>();
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            data.add(random.nextInt(1000));
        }
        return data;
    }
}

// 排序结果类
class SortResult {
    private final String algorithmName;
    private final int dataSize;
    private final long duration; // 纳秒
    private final String timeComplexity;
    private final String spaceComplexity;
    private final long timestamp;
    
    public SortResult(String algorithmName, int dataSize, long duration, 
                     String timeComplexity, String spaceComplexity) {
        this.algorithmName = algorithmName;
        this.dataSize = dataSize;
        this.duration = duration;
        this.timeComplexity = timeComplexity;
        this.spaceComplexity = spaceComplexity;
        this.timestamp = System.currentTimeMillis();
    }
    
    public double getDurationInMillis() {
        return duration / 1_000_000.0;
    }
    
    @Override
    public String toString() {
        return String.format("%s: %d elements, %.2f ms, Time: %s, Space: %s",
            algorithmName, dataSize, getDurationInMillis(), timeComplexity, spaceComplexity);
    }
    
    // Getters
    public String getAlgorithmName() { return algorithmName; }
    public int getDataSize() { return dataSize; }
    public long getDuration() { return duration; }
    public String getTimeComplexity() { return timeComplexity; }
    public String getSpaceComplexity() { return spaceComplexity; }
    public long getTimestamp() { return timestamp; }
}

// 使用示例
public class StrategyPatternDemo {
    public static void main(String[] args) {
        System.out.println("=== 策略模式演示:排序算法 ===");
        
        // 创建排序器
        Sorter<Integer> sorter = new Sorter<>();
        
        // 测试数据
        List<Integer> testData = Arrays.asList(64, 34, 25, 12, 22, 11, 90, 88, 76, 50, 42);
        System.out.println("原始数据: " + testData);
        
        // 测试不同的排序算法
        List<SortStrategy<Integer>> strategies = Arrays.asList(
            new BubbleSortStrategy<>(),
            new QuickSortStrategy<>(),
            new MergeSortStrategy<>(),
            new HeapSortStrategy<>()
        );
        
        for (SortStrategy<Integer> strategy : strategies) {
            List<Integer> data = new ArrayList<>(testData);
            sorter.setStrategy(strategy);
            
            System.out.println("\n--- " + strategy.getAlgorithmName() + " ---");
            System.out.println("时间复杂度: " + strategy.getTimeComplexity());
            System.out.println("空间复杂度: " + strategy.getSpaceComplexity());
            
            SortResult result = sorter.sort(data);
            System.out.println("排序后: " + data);
            System.out.printf("耗时: %.2f ms\n", result.getDurationInMillis());
        }
        
        // 显示排序历史
        System.out.println("\n=== 排序历史 ===");
        List<SortResult> history = sorter.getSortHistory();
        for (int i = 0; i < history.size(); i++) {
            System.out.printf("%d. %s\n", i + 1, history.get(i));
        }
        
        // 性能测试(小数据集,避免冒泡排序耗时过长)
        sorter.performanceTest(100, 10);
    }
}

8.2.5 策略模式的优缺点

优点: 1. 算法可以自由切换:客户端可以在运行时选择算法 2. 避免使用多重条件判断:消除大量的if-else或switch语句 3. 扩展性良好:增加新算法不需要修改现有代码 4. 符合开闭原则:对扩展开放,对修改封闭

缺点: 1. 客户端必须知道所有策略:客户端需要了解各种策略的区别 2. 策略类数量增多:每个算法都需要一个策略类 3. 通信开销:策略和环境类之间可能需要传递大量数据

8.2.6 适用场景

  1. 多种算法实现:一个系统需要动态地在几种算法中选择一种
  2. 避免条件语句:需要避免多重条件选择语句
  3. 算法独立变化:算法的实现细节应该对客户端透明
  4. 运行时选择:需要在运行时根据不同情况选择不同的算法

8.3 模板方法模式(Template Method Pattern)

8.3.1 模式定义与动机

定义: 模板方法模式定义一个操作中算法的骨架,而将一些步骤延迟到子类中。模板方法使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。

动机: - 多个类有相似的处理流程,但具体实现不同 - 需要控制子类的扩展点 - 提取公共行为到父类,避免代码重复 - 实现”好莱坞原则”:Don’t call us, we’ll call you

8.3.2 模式结构

模板方法模式包含以下角色:

  1. AbstractClass(抽象类):定义模板方法和抽象方法
  2. ConcreteClass(具体类):实现抽象方法,完成算法中特定的步骤

8.3.3 Python实现示例:数据处理流水线

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import json
import csv
import xml.etree.ElementTree as ET
from datetime import datetime
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 抽象类:数据处理模板
class DataProcessor(ABC):
    """数据处理模板类"""
    
    def __init__(self, source_path: str, output_path: str):
        self.source_path = source_path
        self.output_path = output_path
        self.raw_data: Any = None
        self.processed_data: List[Dict[str, Any]] = []
        self.statistics: Dict[str, Any] = {}
    
    # 模板方法
    def process_data(self) -> Dict[str, Any]:
        """数据处理的模板方法"""
        start_time = datetime.now()
        
        try:
            # 1. 验证输入
            self.validate_input()
            
            # 2. 读取数据
            self.raw_data = self.read_data()
            
            # 3. 验证数据
            self.validate_data()
            
            # 4. 预处理数据
            self.preprocess_data()
            
            # 5. 处理数据
            self.processed_data = self.transform_data()
            
            # 6. 后处理
            self.postprocess_data()
            
            # 7. 验证结果
            self.validate_result()
            
            # 8. 保存结果
            self.save_result()
            
            # 9. 生成统计信息
            self.generate_statistics()
            
            # 10. 清理资源
            self.cleanup()
            
            end_time = datetime.now()
            processing_time = (end_time - start_time).total_seconds()
            
            return {
                "success": True,
                "message": "数据处理完成",
                "processing_time": processing_time,
                "statistics": self.statistics
            }
            
        except Exception as e:
            logging.error(f"数据处理失败: {str(e)}")
            self.cleanup()
            return {
                "success": False,
                "message": f"数据处理失败: {str(e)}",
                "processing_time": 0,
                "statistics": {}
            }
    
    # 具体方法(可以被子类重写)
    def validate_input(self) -> None:
        """验证输入参数"""
        if not self.source_path:
            raise ValueError("源文件路径不能为空")
        if not self.output_path:
            raise ValueError("输出文件路径不能为空")
        logging.info("输入验证通过")
    
    def validate_data(self) -> None:
        """验证读取的数据"""
        if self.raw_data is None:
            raise ValueError("读取的数据为空")
        logging.info("数据验证通过")
    
    def preprocess_data(self) -> None:
        """预处理数据(钩子方法)"""
        logging.info("执行数据预处理")
        # 默认实现为空,子类可以重写
        pass
    
    def postprocess_data(self) -> None:
        """后处理数据(钩子方法)"""
        logging.info("执行数据后处理")
        # 默认实现为空,子类可以重写
        pass
    
    def validate_result(self) -> None:
        """验证处理结果"""
        if not self.processed_data:
            raise ValueError("处理后的数据为空")
        logging.info("结果验证通过")
    
    def generate_statistics(self) -> None:
        """生成统计信息"""
        self.statistics = {
            "total_records": len(self.processed_data),
            "processing_timestamp": datetime.now().isoformat(),
            "source_file": self.source_path,
            "output_file": self.output_path
        }
        logging.info(f"生成统计信息: {self.statistics}")
    
    def cleanup(self) -> None:
        """清理资源(钩子方法)"""
        logging.info("清理资源")
        # 默认实现为空,子类可以重写
        pass
    
    # 抽象方法(必须由子类实现)
    @abstractmethod
    def read_data(self) -> Any:
        """读取数据"""
        pass
    
    @abstractmethod
    def transform_data(self) -> List[Dict[str, Any]]:
        """转换数据"""
        pass
    
    @abstractmethod
    def save_result(self) -> None:
        """保存结果"""
        pass

# 具体类:JSON数据处理器
class JSONDataProcessor(DataProcessor):
    """JSON数据处理器"""
    
    def read_data(self) -> Dict[str, Any]:
        """读取JSON数据"""
        logging.info(f"读取JSON文件: {self.source_path}")
        try:
            with open(self.source_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
            logging.info(f"成功读取JSON数据,包含 {len(data) if isinstance(data, list) else 1} 条记录")
            return data
        except FileNotFoundError:
            raise FileNotFoundError(f"文件不存在: {self.source_path}")
        except json.JSONDecodeError as e:
            raise ValueError(f"JSON格式错误: {str(e)}")
    
    def transform_data(self) -> List[Dict[str, Any]]:
        """转换JSON数据"""
        logging.info("转换JSON数据")
        
        if isinstance(self.raw_data, list):
            # 如果是数组,直接返回
            return self.raw_data
        elif isinstance(self.raw_data, dict):
            # 如果是对象,包装成数组
            return [self.raw_data]
        else:
            raise ValueError("不支持的JSON数据格式")
    
    def save_result(self) -> None:
        """保存JSON结果"""
        logging.info(f"保存JSON结果到: {self.output_path}")
        with open(self.output_path, 'w', encoding='utf-8') as file:
            json.dump(self.processed_data, file, ensure_ascii=False, indent=2)
        logging.info("JSON结果保存完成")
    
    def preprocess_data(self) -> None:
        """JSON数据预处理"""
        super().preprocess_data()
        # 可以添加JSON特定的预处理逻辑
        logging.info("执行JSON特定的预处理")

# 具体类:CSV数据处理器
class CSVDataProcessor(DataProcessor):
    """CSV数据处理器"""
    
    def __init__(self, source_path: str, output_path: str, delimiter: str = ','):
        super().__init__(source_path, output_path)
        self.delimiter = delimiter
        self.headers: List[str] = []
    
    def read_data(self) -> List[List[str]]:
        """读取CSV数据"""
        logging.info(f"读取CSV文件: {self.source_path}")
        try:
            with open(self.source_path, 'r', encoding='utf-8') as file:
                reader = csv.reader(file, delimiter=self.delimiter)
                data = list(reader)
            
            if not data:
                raise ValueError("CSV文件为空")
            
            self.headers = data[0]  # 第一行作为标题
            logging.info(f"成功读取CSV数据,包含 {len(data)-1} 条记录,{len(self.headers)} 个字段")
            return data[1:]  # 返回数据行(不包括标题)
            
        except FileNotFoundError:
            raise FileNotFoundError(f"文件不存在: {self.source_path}")
        except Exception as e:
            raise ValueError(f"CSV读取错误: {str(e)}")
    
    def transform_data(self) -> List[Dict[str, Any]]:
        """转换CSV数据"""
        logging.info("转换CSV数据")
        
        result = []
        for row in self.raw_data:
            if len(row) != len(self.headers):
                logging.warning(f"跳过不完整的行: {row}")
                continue
            
            record = {}
            for i, value in enumerate(row):
                header = self.headers[i]
                # 尝试转换数据类型
                record[header] = self._convert_value(value)
            
            result.append(record)
        
        return result
    
    def _convert_value(self, value: str) -> Any:
        """转换值的数据类型"""
        # 尝试转换为数字
        try:
            if '.' in value:
                return float(value)
            else:
                return int(value)
        except ValueError:
            # 如果不是数字,返回字符串
            return value.strip()
    
    def save_result(self) -> None:
        """保存CSV结果"""
        logging.info(f"保存CSV结果到: {self.output_path}")
        
        if not self.processed_data:
            return
        
        # 获取所有字段名
        all_fields = set()
        for record in self.processed_data:
            all_fields.update(record.keys())
        
        fieldnames = sorted(all_fields)
        
        with open(self.output_path, 'w', newline='', encoding='utf-8') as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames, delimiter=self.delimiter)
            writer.writeheader()
            writer.writerows(self.processed_data)
        
        logging.info("CSV结果保存完成")
    
    def generate_statistics(self) -> None:
        """生成CSV特定的统计信息"""
        super().generate_statistics()
        self.statistics.update({
            "headers": self.headers,
            "delimiter": self.delimiter,
            "field_count": len(self.headers)
        })

# 具体类:XML数据处理器
class XMLDataProcessor(DataProcessor):
    """XML数据处理器"""
    
    def __init__(self, source_path: str, output_path: str, root_element: str = 'data'):
        super().__init__(source_path, output_path)
        self.root_element = root_element
        self.xml_tree: Optional[ET.ElementTree] = None
    
    def read_data(self) -> ET.Element:
        """读取XML数据"""
        logging.info(f"读取XML文件: {self.source_path}")
        try:
            self.xml_tree = ET.parse(self.source_path)
            root = self.xml_tree.getroot()
            logging.info(f"成功读取XML数据,根元素: {root.tag}")
            return root
        except FileNotFoundError:
            raise FileNotFoundError(f"文件不存在: {self.source_path}")
        except ET.ParseError as e:
            raise ValueError(f"XML格式错误: {str(e)}")
    
    def transform_data(self) -> List[Dict[str, Any]]:
        """转换XML数据"""
        logging.info("转换XML数据")
        
        result = []
        
        # 遍历所有子元素
        for element in self.raw_data:
            record = self._element_to_dict(element)
            result.append(record)
        
        return result
    
    def _element_to_dict(self, element: ET.Element) -> Dict[str, Any]:
        """将XML元素转换为字典"""
        result = {}
        
        # 添加属性
        if element.attrib:
            result.update(element.attrib)
        
        # 添加文本内容
        if element.text and element.text.strip():
            result['text'] = element.text.strip()
        
        # 添加子元素
        for child in element:
            if child.tag in result:
                # 如果已存在,转换为列表
                if not isinstance(result[child.tag], list):
                    result[child.tag] = [result[child.tag]]
                result[child.tag].append(self._element_to_dict(child))
            else:
                result[child.tag] = self._element_to_dict(child)
        
        return result
    
    def save_result(self) -> None:
        """保存XML结果"""
        logging.info(f"保存XML结果到: {self.output_path}")
        
        # 创建根元素
        root = ET.Element(self.root_element)
        
        # 添加数据
        for i, record in enumerate(self.processed_data):
            item_element = ET.SubElement(root, f'item_{i+1}')
            self._dict_to_element(record, item_element)
        
        # 创建树并保存
        tree = ET.ElementTree(root)
        tree.write(self.output_path, encoding='utf-8', xml_declaration=True)
        
        logging.info("XML结果保存完成")
    
    def _dict_to_element(self, data: Dict[str, Any], parent: ET.Element) -> None:
        """将字典转换为XML元素"""
        for key, value in data.items():
            if isinstance(value, dict):
                child = ET.SubElement(parent, key)
                self._dict_to_element(value, child)
            elif isinstance(value, list):
                for item in value:
                    child = ET.SubElement(parent, key)
                    if isinstance(item, dict):
                        self._dict_to_element(item, child)
                    else:
                        child.text = str(item)
            else:
                child = ET.SubElement(parent, key)
                child.text = str(value)
    
    def cleanup(self) -> None:
        """清理XML特定资源"""
        super().cleanup()
        self.xml_tree = None
        logging.info("清理XML资源")

# 数据处理工厂
class DataProcessorFactory:
    """数据处理器工厂"""
    
    @staticmethod
    def create_processor(file_type: str, source_path: str, output_path: str, **kwargs) -> DataProcessor:
        """创建数据处理器"""
        if file_type.lower() == 'json':
            return JSONDataProcessor(source_path, output_path)
        elif file_type.lower() == 'csv':
            delimiter = kwargs.get('delimiter', ',')
            return CSVDataProcessor(source_path, output_path, delimiter)
        elif file_type.lower() == 'xml':
            root_element = kwargs.get('root_element', 'data')
            return XMLDataProcessor(source_path, output_path, root_element)
        else:
            raise ValueError(f"不支持的文件类型: {file_type}")

# 使用示例
def demonstrate_template_method_pattern():
    """演示模板方法模式"""
    print("=== 模板方法模式演示:数据处理流水线 ===")
    
    # 创建测试数据文件
    create_test_files()
    
    # 测试不同类型的数据处理器
    test_cases = [
        {
            "type": "json",
            "source": "test_data.json",
            "output": "output_data.json"
        },
        {
            "type": "csv",
            "source": "test_data.csv",
            "output": "output_data.csv",
            "delimiter": ","
        },
        {
            "type": "xml",
            "source": "test_data.xml",
            "output": "output_data.xml",
            "root_element": "processed_data"
        }
    ]
    
    for test_case in test_cases:
        print(f"\n--- 测试 {test_case['type'].upper()} 数据处理 ---")
        
        try:
            # 创建处理器
            processor = DataProcessorFactory.create_processor(
                test_case['type'],
                test_case['source'],
                test_case['output'],
                **{k: v for k, v in test_case.items() if k not in ['type', 'source', 'output']}
            )
            
            # 执行处理
            result = processor.process_data()
            
            # 显示结果
            if result['success']:
                print(f"✅ 处理成功")
                print(f"处理时间: {result['processing_time']:.2f} 秒")
                print(f"统计信息: {result['statistics']}")
            else:
                print(f"❌ 处理失败: {result['message']}")
                
        except Exception as e:
            print(f"❌ 创建处理器失败: {str(e)}")

def create_test_files():
    """创建测试数据文件"""
    # 创建JSON测试文件
    json_data = [
        {"id": 1, "name": "张三", "age": 25, "city": "北京"},
        {"id": 2, "name": "李四", "age": 30, "city": "上海"},
        {"id": 3, "name": "王五", "age": 28, "city": "广州"}
    ]
    
    with open("test_data.json", 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)
    
    # 创建CSV测试文件
    csv_data = [
        ["id", "name", "age", "city"],
        ["1", "张三", "25", "北京"],
        ["2", "李四", "30", "上海"],
        ["3", "王五", "28", "广州"]
    ]
    
    with open("test_data.csv", 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerows(csv_data)
    
    # 创建XML测试文件
    xml_content = '''<?xml version="1.0" encoding="UTF-8"?>
<users>
    <user id="1">
        <name>张三</name>
        <age>25</age>
        <city>北京</city>
    </user>
    <user id="2">
        <name>李四</name>
        <age>30</age>
        <city>上海</city>
    </user>
    <user id="3">
        <name>王五</name>
        <age>28</age>
        <city>广州</city>
    </user>
</users>'''
    
    with open("test_data.xml", 'w', encoding='utf-8') as f:
        f.write(xml_content)

if __name__ == "__main__":
    demonstrate_template_method_pattern()

8.3.4 模板方法模式的优缺点

优点: 1. 代码复用:提取公共行为到父类,避免重复代码 2. 控制扩展点:父类控制算法结构,子类只需实现特定步骤 3. 符合开闭原则:增加新的实现不需要修改现有代码 4. 实现反向控制:父类调用子类方法,而不是相反

缺点: 1. 继承关系复杂:需要为每个不同的实现创建子类 2. 调试困难:算法流程分散在多个类中 3. 违反里氏替换原则:子类可能改变父类的行为

8.3.5 适用场景

  1. 算法骨架固定:多个类有相同的算法结构,但具体实现不同
  2. 控制子类扩展:需要控制子类在哪些点可以扩展
  3. 提取公共行为:多个类有重复的代码需要提取
  4. 实现框架:需要定义一个框架,让用户扩展特定部分

8.4 策略模式与模板方法模式对比

8.4.1 相同点

  1. 都是行为型模式:关注对象间的通信和职责分配
  2. 都支持算法变化:允许算法的实现发生变化
  3. 都符合开闭原则:对扩展开放,对修改封闭
  4. 都避免条件判断:减少复杂的if-else语句

8.4.2 不同点

对比维度 策略模式 模板方法模式
实现方式 组合关系 继承关系
算法选择 运行时动态选择 编译时确定
算法结构 完全独立的算法 相同结构的算法
扩展方式 增加新策略类 增加新子类
客户端职责 需要知道具体策略 只需要知道抽象类
控制反转 客户端控制 父类控制

8.4.3 选择指南

选择策略模式的情况: - 需要在运行时动态切换算法 - 算法之间完全独立,没有公共结构 - 希望避免继承关系 - 客户端需要了解不同算法的特点

选择模板方法模式的情况: - 算法有固定的执行步骤 - 多个算法有相同的结构但实现不同 - 需要控制子类的扩展点 - 希望提取公共代码到父类

8.4.4 组合使用

策略模式和模板方法模式可以组合使用:

from abc import ABC, abstractmethod

# 模板方法模式:定义处理流程
class DataProcessorTemplate(ABC):
    def process(self, data):
        # 固定的处理流程
        validated_data = self.validate(data)
        processed_data = self.transform(validated_data)
        return self.output(processed_data)
    
    @abstractmethod
    def validate(self, data):
        pass
    
    @abstractmethod
    def transform(self, data):
        pass
    
    @abstractmethod
    def output(self, data):
        pass

# 策略模式:定义转换策略
class TransformStrategy(ABC):
    @abstractmethod
    def transform(self, data):
        pass

class UpperCaseStrategy(TransformStrategy):
    def transform(self, data):
        return data.upper()

class LowerCaseStrategy(TransformStrategy):
    def transform(self, data):
        return data.lower()

# 组合使用
class FlexibleDataProcessor(DataProcessorTemplate):
    def __init__(self, transform_strategy: TransformStrategy):
        self.transform_strategy = transform_strategy
    
    def validate(self, data):
        if not isinstance(data, str):
            raise ValueError("数据必须是字符串")
        return data
    
    def transform(self, data):
        # 使用策略模式进行转换
        return self.transform_strategy.transform(data)
    
    def output(self, data):
        return f"处理结果: {data}"

8.5 本章总结

8.5.1 核心概念回顾

  1. 行为型模式:关注对象间的通信和职责分配
  2. 策略模式:定义算法族,使它们可以互相替换
  3. 模板方法模式:定义算法骨架,延迟具体实现到子类

8.5.2 最佳实践

  1. 策略模式最佳实践

    • 使用工厂模式创建策略对象
    • 为策略提供统一的接口
    • 考虑策略的状态管理
    • 提供策略的元数据信息
  2. 模板方法模式最佳实践

    • 明确区分抽象方法和钩子方法
    • 保持模板方法的稳定性
    • 提供清晰的文档说明
    • 考虑使用final关键字保护模板方法

8.5.3 实际应用建议

  1. 框架设计:使用模板方法模式定义框架结构
  2. 算法库:使用策略模式实现可插拔的算法
  3. 数据处理:结合两种模式实现灵活的数据处理流水线
  4. 业务规则:使用策略模式实现可配置的业务规则

8.5.4 注意事项

  1. 避免过度设计:不要为了使用模式而使用模式
  2. 性能考虑:策略切换和继承都有性能开销
  3. 维护成本:考虑模式带来的复杂性
  4. 团队理解:确保团队成员理解模式的使用

8.6 练习题

8.6.1 基础练习

  1. 策略模式练习

    • 实现一个计算器,支持不同的运算策略(加、减、乘、除)
    • 为计算器添加历史记录功能
    • 实现运算结果的不同格式化策略
  2. 模板方法模式练习

    • 设计一个文件下载器,支持HTTP、FTP、SFTP协议
    • 实现统一的下载流程:验证URL → 建立连接 → 下载文件 → 验证完整性
    • 为不同协议实现具体的连接和下载逻辑

8.6.2 进阶练习

  1. 组合模式练习

    • 设计一个数据导入系统,结合策略模式和模板方法模式
    • 支持多种数据源(数据库、API、文件)
    • 支持多种数据格式(JSON、XML、CSV)
    • 实现数据验证、转换、存储的统一流程
  2. 扩展练习

    • 为策略模式添加策略注册机制
    • 实现策略的动态加载和卸载
    • 为模板方法模式添加步骤的条件执行
    • 实现模板方法的并行执行

8.6.3 思考题

  1. 在什么情况下策略模式比简单的if-else语句更有优势?
  2. 模板方法模式如何体现”好莱坞原则”?
  3. 如何在保持灵活性的同时避免策略类数量过多?
  4. 模板方法模式中的钩子方法应该如何设计?
  5. 两种模式在微服务架构中如何应用?

下一章预告:第九章将学习观察者模式和命令模式,探讨对象间的通信机制和请求的封装处理。