14.1 本章概述

14.1.1 学习目标

在本章中,我们将深入学习两个重要的设计模式:

  1. 模板方法模式(Template Method Pattern):定义算法骨架,让子类实现具体步骤
  2. 工厂方法模式(Factory Method Pattern):创建对象的高级应用和最佳实践

14.1.2 应用场景

  • 模板方法模式:算法框架、数据处理流水线、测试框架
  • 工厂方法模式:对象创建、插件系统、依赖注入

14.2 模板方法模式深入解析

14.2.1 模式定义

模板方法模式定义一个操作中算法的骨架,而将一些步骤延迟到子类中。模板方法使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。

14.2.2 模式动机

在软件开发中,我们经常遇到这样的情况: - 多个类有相似的处理流程 - 流程的整体结构相同,但某些步骤的实现不同 - 需要控制子类的扩展点

14.2.3 模式结构

┌─────────────────────┐
│   AbstractClass     │
├─────────────────────┤
│ + templateMethod()  │  ← 模板方法
│ # primitiveOp1()    │  ← 抽象操作
│ # primitiveOp2()    │  ← 抽象操作
│ # hook()            │  ← 钩子方法
└─────────────────────┘
           △
           │
┌─────────────────────┐
│   ConcreteClass     │
├─────────────────────┤
│ # primitiveOp1()    │  ← 具体实现
│ # primitiveOp2()    │  ← 具体实现
│ # hook()            │  ← 可选重写
└─────────────────────┘

14.2.4 Python实现示例:数据处理框架

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import time
import json
import csv
import logging

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DataProcessor(ABC):
    """数据处理器抽象基类(模板方法模式)"""
    
    def process_data(self, input_data: Any) -> Any:
        """模板方法:定义数据处理的完整流程"""
        logger.info(f"开始处理数据,处理器类型: {self.__class__.__name__}")
        start_time = time.time()
        
        try:
            # 1. 数据验证
            if not self.validate_input(input_data):
                raise ValueError("输入数据验证失败")
            
            # 2. 预处理
            preprocessed_data = self.preprocess(input_data)
            logger.info("数据预处理完成")
            
            # 3. 核心处理(抽象方法,子类必须实现)
            processed_data = self.process_core(preprocessed_data)
            logger.info("核心处理完成")
            
            # 4. 后处理
            postprocessed_data = self.postprocess(processed_data)
            logger.info("数据后处理完成")
            
            # 5. 结果验证
            if not self.validate_output(postprocessed_data):
                raise ValueError("输出数据验证失败")
            
            # 6. 清理工作(钩子方法)
            self.cleanup()
            
            processing_time = time.time() - start_time
            logger.info(f"数据处理完成,耗时: {processing_time:.2f}秒")
            
            return postprocessed_data
            
        except Exception as e:
            logger.error(f"数据处理失败: {e}")
            self.handle_error(e)
            raise
    
    def validate_input(self, input_data: Any) -> bool:
        """输入验证(具体方法,可被子类重写)"""
        return input_data is not None
    
    def preprocess(self, input_data: Any) -> Any:
        """预处理(具体方法,可被子类重写)"""
        return input_data
    
    @abstractmethod
    def process_core(self, data: Any) -> Any:
        """核心处理逻辑(抽象方法,子类必须实现)"""
        pass
    
    def postprocess(self, processed_data: Any) -> Any:
        """后处理(具体方法,可被子类重写)"""
        return processed_data
    
    def validate_output(self, output_data: Any) -> bool:
        """输出验证(具体方法,可被子类重写)"""
        return output_data is not None
    
    def cleanup(self):
        """清理工作(钩子方法,子类可选择性重写)"""
        pass
    
    def handle_error(self, error: Exception):
        """错误处理(钩子方法,子类可选择性重写)"""
        logger.error(f"处理错误: {error}")

class TextDataProcessor(DataProcessor):
    """文本数据处理器"""
    
    def __init__(self, min_length: int = 1, max_length: int = 10000):
        self.min_length = min_length
        self.max_length = max_length
        self.processed_count = 0
    
    def validate_input(self, input_data: Any) -> bool:
        """验证输入是否为有效文本"""
        if not isinstance(input_data, str):
            logger.warning("输入数据不是字符串类型")
            return False
        
        if not (self.min_length <= len(input_data) <= self.max_length):
            logger.warning(f"文本长度不在有效范围内: {len(input_data)}")
            return False
        
        return True
    
    def preprocess(self, input_data: str) -> str:
        """文本预处理:清理和标准化"""
        # 去除多余空白
        text = ' '.join(input_data.split())
        
        # 转换为小写
        text = text.lower()
        
        # 移除特殊字符(保留字母、数字、空格和基本标点)
        import re
        text = re.sub(r'[^a-zA-Z0-9\s.,!?;:]', '', text)
        
        logger.info(f"文本预处理:原长度 {len(input_data)} -> 处理后长度 {len(text)}")
        return text
    
    def process_core(self, data: str) -> Dict[str, Any]:
        """核心处理:文本分析"""
        words = data.split()
        
        # 统计信息
        stats = {
            'original_text': data,
            'word_count': len(words),
            'char_count': len(data),
            'sentence_count': len([s for s in data.split('.') if s.strip()]),
            'avg_word_length': sum(len(word) for word in words) / len(words) if words else 0,
            'unique_words': len(set(words)),
            'word_frequency': {}
        }
        
        # 词频统计
        for word in words:
            stats['word_frequency'][word] = stats['word_frequency'].get(word, 0) + 1
        
        # 排序词频
        stats['word_frequency'] = dict(sorted(
            stats['word_frequency'].items(), 
            key=lambda x: x[1], 
            reverse=True
        ))
        
        self.processed_count += 1
        logger.info(f"文本分析完成,词数: {stats['word_count']}, 唯一词数: {stats['unique_words']}")
        
        return stats
    
    def postprocess(self, processed_data: Dict[str, Any]) -> Dict[str, Any]:
        """后处理:添加元数据和格式化"""
        processed_data['processing_metadata'] = {
            'processor_type': 'TextDataProcessor',
            'processing_time': time.time(),
            'processed_count': self.processed_count,
            'version': '1.0.0'
        }
        
        # 添加可读性指标
        word_count = processed_data['word_count']
        sentence_count = processed_data['sentence_count']
        
        if sentence_count > 0:
            processed_data['readability'] = {
                'avg_words_per_sentence': word_count / sentence_count,
                'complexity_score': self._calculate_complexity(processed_data)
            }
        
        return processed_data
    
    def _calculate_complexity(self, stats: Dict[str, Any]) -> float:
        """计算文本复杂度分数"""
        avg_word_length = stats['avg_word_length']
        unique_ratio = stats['unique_words'] / stats['word_count'] if stats['word_count'] > 0 else 0
        
        # 简单的复杂度计算公式
        complexity = (avg_word_length * 0.5) + (unique_ratio * 0.3) + (stats['sentence_count'] * 0.2)
        return round(complexity, 2)
    
    def validate_output(self, output_data: Dict[str, Any]) -> bool:
        """验证输出数据的完整性"""
        required_fields = ['word_count', 'char_count', 'unique_words', 'word_frequency']
        
        for field in required_fields:
            if field not in output_data:
                logger.error(f"输出数据缺少必需字段: {field}")
                return False
        
        if output_data['word_count'] < 0 or output_data['char_count'] < 0:
            logger.error("统计数据不能为负数")
            return False
        
        return True
    
    def cleanup(self):
        """清理工作"""
        logger.info(f"文本处理器清理完成,总共处理了 {self.processed_count} 个文本")

class NumericDataProcessor(DataProcessor):
    """数值数据处理器"""
    
    def __init__(self, outlier_threshold: float = 2.0):
        self.outlier_threshold = outlier_threshold
        self.outliers_removed = 0
    
    def validate_input(self, input_data: Any) -> bool:
        """验证输入是否为有效数值列表"""
        if not isinstance(input_data, (list, tuple)):
            logger.warning("输入数据不是列表或元组类型")
            return False
        
        if len(input_data) == 0:
            logger.warning("输入数据为空")
            return False
        
        # 检查是否所有元素都是数值
        for item in input_data:
            if not isinstance(item, (int, float)):
                logger.warning(f"发现非数值元素: {item}")
                return False
        
        return True
    
    def preprocess(self, input_data: List[float]) -> List[float]:
        """数值预处理:排序和基本清理"""
        # 转换为浮点数列表
        data = [float(x) for x in input_data]
        
        # 移除NaN值
        import math
        data = [x for x in data if not math.isnan(x) and not math.isinf(x)]
        
        logger.info(f"数值预处理:原数量 {len(input_data)} -> 处理后数量 {len(data)}")
        return data
    
    def process_core(self, data: List[float]) -> Dict[str, Any]:
        """核心处理:统计分析"""
        import statistics
        import math
        
        if not data:
            return {'error': '没有有效数据进行处理'}
        
        # 基本统计
        stats = {
            'count': len(data),
            'sum': sum(data),
            'mean': statistics.mean(data),
            'median': statistics.median(data),
            'min': min(data),
            'max': max(data),
            'range': max(data) - min(data)
        }
        
        # 高级统计(如果数据量足够)
        if len(data) > 1:
            stats['std_dev'] = statistics.stdev(data)
            stats['variance'] = statistics.variance(data)
            
            # 四分位数
            sorted_data = sorted(data)
            n = len(sorted_data)
            stats['q1'] = sorted_data[n // 4]
            stats['q3'] = sorted_data[3 * n // 4]
            stats['iqr'] = stats['q3'] - stats['q1']
            
            # 偏度和峰度的简单估算
            mean = stats['mean']
            std_dev = stats['std_dev']
            
            if std_dev > 0:
                # 偏度(skewness)
                skewness = sum((x - mean) ** 3 for x in data) / (len(data) * std_dev ** 3)
                stats['skewness'] = skewness
                
                # 峰度(kurtosis)
                kurtosis = sum((x - mean) ** 4 for x in data) / (len(data) * std_dev ** 4) - 3
                stats['kurtosis'] = kurtosis
        
        # 异常值检测
        outliers = self._detect_outliers(data)
        stats['outliers'] = outliers
        stats['outlier_count'] = len(outliers)
        
        logger.info(f"数值分析完成,均值: {stats['mean']:.2f}, 标准差: {stats.get('std_dev', 0):.2f}")
        
        return stats
    
    def _detect_outliers(self, data: List[float]) -> List[float]:
        """使用IQR方法检测异常值"""
        if len(data) < 4:
            return []
        
        import statistics
        
        q1 = statistics.quantiles(data, n=4)[0]
        q3 = statistics.quantiles(data, n=4)[2]
        iqr = q3 - q1
        
        lower_bound = q1 - self.outlier_threshold * iqr
        upper_bound = q3 + self.outlier_threshold * iqr
        
        outliers = [x for x in data if x < lower_bound or x > upper_bound]
        return outliers
    
    def postprocess(self, processed_data: Dict[str, Any]) -> Dict[str, Any]:
        """后处理:添加解释和建议"""
        if 'error' in processed_data:
            return processed_data
        
        # 添加数据质量评估
        processed_data['data_quality'] = self._assess_data_quality(processed_data)
        
        # 添加处理元数据
        processed_data['processing_metadata'] = {
            'processor_type': 'NumericDataProcessor',
            'processing_time': time.time(),
            'outlier_threshold': self.outlier_threshold,
            'outliers_removed': self.outliers_removed,
            'version': '1.0.0'
        }
        
        return processed_data
    
    def _assess_data_quality(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """评估数据质量"""
        quality = {
            'sample_size': 'adequate' if stats['count'] >= 30 else 'small',
            'distribution': 'normal' if abs(stats.get('skewness', 0)) < 0.5 else 'skewed',
            'outlier_ratio': stats['outlier_count'] / stats['count'],
            'variability': 'high' if stats.get('std_dev', 0) > stats['mean'] else 'normal'
        }
        
        # 总体质量评分
        score = 100
        if quality['sample_size'] == 'small':
            score -= 20
        if quality['distribution'] == 'skewed':
            score -= 15
        if quality['outlier_ratio'] > 0.1:
            score -= 25
        if quality['variability'] == 'high':
            score -= 10
        
        quality['overall_score'] = max(0, score)
        
        return quality
    
    def cleanup(self):
        """清理工作"""
        logger.info(f"数值处理器清理完成,移除了 {self.outliers_removed} 个异常值")

class JSONDataProcessor(DataProcessor):
    """JSON数据处理器"""
    
    def __init__(self, required_fields: Optional[List[str]] = None):
        self.required_fields = required_fields or []
        self.processed_records = 0
    
    def validate_input(self, input_data: Any) -> bool:
        """验证输入是否为有效JSON数据"""
        if isinstance(input_data, str):
            try:
                json.loads(input_data)
                return True
            except json.JSONDecodeError:
                logger.warning("输入字符串不是有效的JSON格式")
                return False
        elif isinstance(input_data, (dict, list)):
            return True
        else:
            logger.warning("输入数据不是有效的JSON类型")
            return False
    
    def preprocess(self, input_data: Any) -> Dict[str, Any]:
        """JSON预处理:解析和标准化"""
        if isinstance(input_data, str):
            data = json.loads(input_data)
        else:
            data = input_data
        
        # 如果是列表,包装成字典
        if isinstance(data, list):
            data = {'items': data, 'count': len(data)}
        
        logger.info(f"JSON预处理完成,数据类型: {type(data).__name__}")
        return data
    
    def process_core(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """核心处理:JSON结构分析"""
        analysis = {
            'structure': self._analyze_structure(data),
            'field_types': self._analyze_field_types(data),
            'statistics': self._calculate_statistics(data),
            'validation': self._validate_required_fields(data)
        }
        
        self.processed_records += 1
        logger.info(f"JSON分析完成,字段数: {len(analysis['field_types'])}")
        
        return analysis
    
    def _analyze_structure(self, data: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
        """分析JSON结构"""
        structure = {
            'depth': 0,
            'total_fields': 0,
            'nested_objects': 0,
            'arrays': 0,
            'field_paths': []
        }
        
        def traverse(obj, path, depth):
            structure['depth'] = max(structure['depth'], depth)
            
            if isinstance(obj, dict):
                if depth > 0:
                    structure['nested_objects'] += 1
                
                for key, value in obj.items():
                    current_path = f"{path}.{key}" if path else key
                    structure['field_paths'].append(current_path)
                    structure['total_fields'] += 1
                    traverse(value, current_path, depth + 1)
            
            elif isinstance(obj, list):
                structure['arrays'] += 1
                for i, item in enumerate(obj[:5]):  # 只分析前5个元素
                    traverse(item, f"{path}[{i}]", depth + 1)
        
        traverse(data, prefix, 0)
        return structure
    
    def _analyze_field_types(self, data: Dict[str, Any]) -> Dict[str, str]:
        """分析字段类型"""
        field_types = {}
        
        def get_type_info(obj, path):
            if isinstance(obj, dict):
                for key, value in obj.items():
                    current_path = f"{path}.{key}" if path else key
                    field_types[current_path] = type(value).__name__
                    
                    if isinstance(value, (dict, list)):
                        get_type_info(value, current_path)
            
            elif isinstance(obj, list) and obj:
                # 分析列表中第一个元素的类型
                first_item = obj[0]
                field_types[f"{path}[0]"] = type(first_item).__name__
                
                if isinstance(first_item, (dict, list)):
                    get_type_info(first_item, f"{path}[0]")
        
        get_type_info(data, '')
        return field_types
    
    def _calculate_statistics(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """计算JSON数据统计信息"""
        stats = {
            'total_size': len(json.dumps(data)),
            'null_fields': 0,
            'empty_strings': 0,
            'empty_arrays': 0,
            'empty_objects': 0
        }
        
        def count_values(obj):
            if isinstance(obj, dict):
                if not obj:
                    stats['empty_objects'] += 1
                
                for value in obj.values():
                    if value is None:
                        stats['null_fields'] += 1
                    elif value == '':
                        stats['empty_strings'] += 1
                    else:
                        count_values(value)
            
            elif isinstance(obj, list):
                if not obj:
                    stats['empty_arrays'] += 1
                else:
                    for item in obj:
                        count_values(item)
        
        count_values(data)
        return stats
    
    def _validate_required_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """验证必需字段"""
        validation = {
            'missing_fields': [],
            'present_fields': [],
            'validation_passed': True
        }
        
        for field in self.required_fields:
            if self._field_exists(data, field):
                validation['present_fields'].append(field)
            else:
                validation['missing_fields'].append(field)
                validation['validation_passed'] = False
        
        return validation
    
    def _field_exists(self, data: Dict[str, Any], field_path: str) -> bool:
        """检查字段路径是否存在"""
        parts = field_path.split('.')
        current = data
        
        try:
            for part in parts:
                if isinstance(current, dict) and part in current:
                    current = current[part]
                else:
                    return False
            return True
        except (KeyError, TypeError):
            return False
    
    def postprocess(self, processed_data: Dict[str, Any]) -> Dict[str, Any]:
        """后处理:添加建议和优化提示"""
        # 添加数据质量建议
        processed_data['recommendations'] = self._generate_recommendations(processed_data)
        
        # 添加处理元数据
        processed_data['processing_metadata'] = {
            'processor_type': 'JSONDataProcessor',
            'processing_time': time.time(),
            'required_fields': self.required_fields,
            'processed_records': self.processed_records,
            'version': '1.0.0'
        }
        
        return processed_data
    
    def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
        """生成数据质量改进建议"""
        recommendations = []
        
        structure = analysis['structure']
        statistics = analysis['statistics']
        validation = analysis['validation']
        
        # 结构建议
        if structure['depth'] > 5:
            recommendations.append("JSON结构层次过深,建议简化嵌套结构")
        
        if structure['total_fields'] > 50:
            recommendations.append("字段数量较多,建议考虑拆分为多个对象")
        
        # 数据质量建议
        if statistics['null_fields'] > 0:
            recommendations.append(f"发现 {statistics['null_fields']} 个空值字段,建议检查数据完整性")
        
        if statistics['empty_strings'] > 0:
            recommendations.append(f"发现 {statistics['empty_strings']} 个空字符串,建议使用null或提供默认值")
        
        # 验证建议
        if not validation['validation_passed']:
            missing = ', '.join(validation['missing_fields'])
            recommendations.append(f"缺少必需字段: {missing}")
        
        if not recommendations:
            recommendations.append("数据结构良好,无需特别优化")
        
        return recommendations
    
    def cleanup(self):
        """清理工作"""
        logger.info(f"JSON处理器清理完成,处理了 {self.processed_records} 条记录")

# 演示模板方法模式
def demonstrate_template_method_pattern():
    print("=== 模板方法模式演示 ===")
    
    # 文本数据处理
    print("\n1. 文本数据处理")
    text_processor = TextDataProcessor(min_length=10, max_length=1000)
    
    sample_text = """
    The Template Method Pattern defines the skeleton of an algorithm in a method, 
    deferring some steps to subclasses. Template Method lets subclasses redefine 
    certain steps of an algorithm without changing the algorithm's structure.
    This is a very useful pattern for creating frameworks and libraries!!!
    """
    
    try:
        result = text_processor.process_data(sample_text)
        print(f"处理结果: 词数={result['word_count']}, 唯一词数={result['unique_words']}")
        print(f"复杂度分数: {result.get('readability', {}).get('complexity_score', 'N/A')}")
    except Exception as e:
        print(f"处理失败: {e}")
    
    # 数值数据处理
    print("\n2. 数值数据处理")
    numeric_processor = NumericDataProcessor(outlier_threshold=1.5)
    
    sample_numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 2, 3, 4, 5]  # 100是异常值
    
    try:
        result = numeric_processor.process_data(sample_numbers)
        print(f"处理结果: 均值={result['mean']:.2f}, 标准差={result.get('std_dev', 0):.2f}")
        print(f"异常值数量: {result['outlier_count']}, 数据质量评分: {result['data_quality']['overall_score']}")
    except Exception as e:
        print(f"处理失败: {e}")
    
    # JSON数据处理
    print("\n3. JSON数据处理")
    json_processor = JSONDataProcessor(required_fields=['id', 'name', 'email'])
    
    sample_json = {
        "id": 1,
        "name": "John Doe",
        "email": "john@example.com",
        "profile": {
            "age": 30,
            "city": "New York",
            "hobbies": ["reading", "swimming"]
        },
        "settings": {
            "notifications": True,
            "theme": "dark"
        }
    }
    
    try:
        result = json_processor.process_data(sample_json)
        print(f"处理结果: 字段数={result['structure']['total_fields']}, 深度={result['structure']['depth']}")
        print(f"验证通过: {result['validation']['validation_passed']}")
        print(f"建议数量: {len(result['recommendations'])}")
    except Exception as e:
        print(f"处理失败: {e}")

if __name__ == "__main__":
    demonstrate_template_method_pattern()

14.2.6 模板方法模式的优缺点

优点: 1. 代码复用:公共算法逻辑在父类中实现,避免重复代码 2. 控制扩展点:通过抽象方法和钩子方法控制子类的扩展 3. 符合开闭原则:对扩展开放,对修改封闭 4. 算法结构稳定:算法骨架不变,只改变特定步骤 5. 易于维护:算法逻辑集中管理

缺点: 1. 继承依赖:必须通过继承实现,限制了灵活性 2. 类数量增加:每个变体都需要一个子类 3. 调试困难:算法流程分散在多个类中 4. 违反里氏替换原则:子类可能改变父类行为

14.2.7 适用场景

  • 算法框架:定义算法骨架,让子类实现具体步骤
  • 数据处理流水线:ETL过程、数据转换等
  • 测试框架:测试用例的执行流程
  • 生命周期管理:组件初始化、销毁等
  • 工作流引擎:业务流程的标准化处理

14.3 工厂方法模式高级应用

14.3.1 模式定义

工厂方法模式定义一个用于创建对象的接口,让子类决定实例化哪一个类。工厂方法使一个类的实例化延迟到其子类。

14.3.2 模式动机

在软件开发中,我们经常遇到这样的情况: - 需要创建不同类型的对象,但创建逻辑复杂 - 对象创建过程需要根据条件动态决定 - 希望将对象创建与使用分离 - 需要支持插件化的对象创建

14.3.3 模式结构

┌─────────────────────┐
│     Creator         │
├─────────────────────┤
│ + factoryMethod()   │  ← 工厂方法
│ + someOperation()   │
└─────────────────────┘
           △
           │
┌─────────────────────┐
│  ConcreteCreator    │
├─────────────────────┤
│ + factoryMethod()   │  ← 具体实现
└─────────────────────┘
           │ creates
           ▼
┌─────────────────────┐
│  ConcreteProduct    │
├─────────────────────┤
│ + operation()       │
└─────────────────────┘
           △
           │
┌─────────────────────┐
│     Product         │
├─────────────────────┤
│ + operation()       │
└─────────────────────┘

14.3.4 Python实现示例:插件化日志系统

from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union
from enum import Enum
import json
import time
import threading
from datetime import datetime
import os
import gzip
import sqlite3

# 日志级别枚举
class LogLevel(Enum):
    DEBUG = 10
    INFO = 20
    WARNING = 30
    ERROR = 40
    CRITICAL = 50

# 日志记录接口
class LogRecord:
    def __init__(self, level: LogLevel, message: str, 
                 timestamp: Optional[datetime] = None, 
                 metadata: Optional[Dict[str, Any]] = None):
        self.level = level
        self.message = message
        self.timestamp = timestamp or datetime.now()
        self.metadata = metadata or {}
        self.thread_id = threading.get_ident()
        self.process_id = os.getpid()
    
    def to_dict(self) -> Dict[str, Any]:
        return {
            'level': self.level.name,
            'message': self.message,
            'timestamp': self.timestamp.isoformat(),
            'metadata': self.metadata,
            'thread_id': self.thread_id,
            'process_id': self.process_id
        }
    
    def __str__(self) -> str:
        return f"[{self.timestamp}] {self.level.name}: {self.message}"

# 抽象日志处理器
class LogHandler(ABC):
    def __init__(self, level: LogLevel = LogLevel.INFO):
        self.level = level
        self.filters = []
        self.formatters = []
    
    def add_filter(self, filter_func):
        """添加过滤器"""
        self.filters.append(filter_func)
    
    def add_formatter(self, formatter_func):
        """添加格式化器"""
        self.formatters.append(formatter_func)
    
    def should_handle(self, record: LogRecord) -> bool:
        """判断是否应该处理该日志记录"""
        if record.level.value < self.level.value:
            return False
        
        for filter_func in self.filters:
            if not filter_func(record):
                return False
        
        return True
    
    def format_record(self, record: LogRecord) -> str:
        """格式化日志记录"""
        formatted = str(record)
        
        for formatter in self.formatters:
            formatted = formatter(formatted, record)
        
        return formatted
    
    def handle(self, record: LogRecord):
        """处理日志记录"""
        if self.should_handle(record):
            formatted_message = self.format_record(record)
            self.emit(formatted_message, record)
    
    @abstractmethod
    def emit(self, formatted_message: str, record: LogRecord):
        """输出日志记录(抽象方法)"""
        pass
    
    @abstractmethod
    def close(self):
        """关闭处理器(抽象方法)"""
        pass

# 控制台日志处理器
class ConsoleLogHandler(LogHandler):
    def __init__(self, level: LogLevel = LogLevel.INFO, use_colors: bool = True):
        super().__init__(level)
        self.use_colors = use_colors
        self.color_map = {
            LogLevel.DEBUG: '\033[36m',    # 青色
            LogLevel.INFO: '\033[32m',     # 绿色
            LogLevel.WARNING: '\033[33m',  # 黄色
            LogLevel.ERROR: '\033[31m',    # 红色
            LogLevel.CRITICAL: '\033[35m'  # 紫色
        }
        self.reset_color = '\033[0m'
    
    def emit(self, formatted_message: str, record: LogRecord):
        if self.use_colors and record.level in self.color_map:
            color = self.color_map[record.level]
            message = f"{color}{formatted_message}{self.reset_color}"
        else:
            message = formatted_message
        
        print(message)
    
    def close(self):
        pass  # 控制台不需要关闭

# 文件日志处理器
class FileLogHandler(LogHandler):
    def __init__(self, filename: str, level: LogLevel = LogLevel.INFO, 
                 max_size: int = 10 * 1024 * 1024, backup_count: int = 5,
                 encoding: str = 'utf-8'):
        super().__init__(level)
        self.filename = filename
        self.max_size = max_size
        self.backup_count = backup_count
        self.encoding = encoding
        self.current_size = 0
        self.file_handle = None
        self.lock = threading.Lock()
        self._open_file()
    
    def _open_file(self):
        """打开日志文件"""
        try:
            self.file_handle = open(self.filename, 'a', encoding=self.encoding)
            if os.path.exists(self.filename):
                self.current_size = os.path.getsize(self.filename)
        except IOError as e:
            print(f"无法打开日志文件 {self.filename}: {e}")
    
    def _rotate_file(self):
        """轮转日志文件"""
        if self.file_handle:
            self.file_handle.close()
        
        # 轮转备份文件
        for i in range(self.backup_count - 1, 0, -1):
            old_name = f"{self.filename}.{i}"
            new_name = f"{self.filename}.{i + 1}"
            
            if os.path.exists(old_name):
                if os.path.exists(new_name):
                    os.remove(new_name)
                os.rename(old_name, new_name)
        
        # 压缩当前文件并重命名
        backup_name = f"{self.filename}.1.gz"
        with open(self.filename, 'rb') as f_in:
            with gzip.open(backup_name, 'wb') as f_out:
                f_out.writelines(f_in)
        
        # 重新创建日志文件
        self.current_size = 0
        self._open_file()
    
    def emit(self, formatted_message: str, record: LogRecord):
        with self.lock:
            if not self.file_handle:
                return
            
            message = formatted_message + '\n'
            message_size = len(message.encode(self.encoding))
            
            # 检查是否需要轮转
            if self.current_size + message_size > self.max_size:
                self._rotate_file()
            
            try:
                self.file_handle.write(message)
                self.file_handle.flush()
                self.current_size += message_size
            except IOError as e:
                print(f"写入日志文件失败: {e}")
    
    def close(self):
        with self.lock:
            if self.file_handle:
                self.file_handle.close()
                self.file_handle = None

# 数据库日志处理器
class DatabaseLogHandler(LogHandler):
    def __init__(self, db_path: str, level: LogLevel = LogLevel.INFO,
                 table_name: str = 'logs', batch_size: int = 100):
        super().__init__(level)
        self.db_path = db_path
        self.table_name = table_name
        self.batch_size = batch_size
        self.batch_records = []
        self.lock = threading.Lock()
        self._init_database()
    
    def _init_database(self):
        """初始化数据库表"""
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            
            cursor.execute(f'''
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    level TEXT NOT NULL,
                    message TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    thread_id INTEGER,
                    process_id INTEGER,
                    metadata TEXT
                )
            ''')
            
            # 创建索引
            cursor.execute(f'''
                CREATE INDEX IF NOT EXISTS idx_{self.table_name}_timestamp 
                ON {self.table_name}(timestamp)
            ''')
            
            cursor.execute(f'''
                CREATE INDEX IF NOT EXISTS idx_{self.table_name}_level 
                ON {self.table_name}(level)
            ''')
            
            conn.commit()
            conn.close()
            
        except sqlite3.Error as e:
            print(f"初始化数据库失败: {e}")
    
    def emit(self, formatted_message: str, record: LogRecord):
        with self.lock:
            self.batch_records.append(record)
            
            if len(self.batch_records) >= self.batch_size:
                self._flush_batch()
    
    def _flush_batch(self):
        """批量写入数据库"""
        if not self.batch_records:
            return
        
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            
            records_data = []
            for record in self.batch_records:
                records_data.append((
                    record.level.name,
                    record.message,
                    record.timestamp.isoformat(),
                    record.thread_id,
                    record.process_id,
                    json.dumps(record.metadata)
                ))
            
            cursor.executemany(f'''
                INSERT INTO {self.table_name} 
                (level, message, timestamp, thread_id, process_id, metadata)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', records_data)
            
            conn.commit()
            conn.close()
            
            self.batch_records.clear()
            
        except sqlite3.Error as e:
            print(f"批量写入数据库失败: {e}")
    
    def close(self):
        with self.lock:
            self._flush_batch()

# 网络日志处理器
class NetworkLogHandler(LogHandler):
    def __init__(self, host: str, port: int, level: LogLevel = LogLevel.INFO,
                 protocol: str = 'tcp', timeout: float = 5.0):
        super().__init__(level)
        self.host = host
        self.port = port
        self.protocol = protocol.lower()
        self.timeout = timeout
        self.connection = None
        self.lock = threading.Lock()
    
    def _connect(self):
        """建立网络连接"""
        try:
            if self.protocol == 'tcp':
                import socket
                self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                self.connection.settimeout(self.timeout)
                self.connection.connect((self.host, self.port))
            elif self.protocol == 'udp':
                import socket
                self.connection = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                self.connection.settimeout(self.timeout)
            else:
                raise ValueError(f"不支持的协议: {self.protocol}")
                
        except Exception as e:
            print(f"网络连接失败: {e}")
            self.connection = None
    
    def emit(self, formatted_message: str, record: LogRecord):
        with self.lock:
            if not self.connection:
                self._connect()
            
            if not self.connection:
                return
            
            try:
                # 构造网络消息
                network_message = {
                    'timestamp': record.timestamp.isoformat(),
                    'level': record.level.name,
                    'message': record.message,
                    'metadata': record.metadata,
                    'source': {
                        'thread_id': record.thread_id,
                        'process_id': record.process_id
                    }
                }
                
                message_json = json.dumps(network_message) + '\n'
                message_bytes = message_json.encode('utf-8')
                
                if self.protocol == 'tcp':
                    self.connection.sendall(message_bytes)
                else:  # UDP
                    self.connection.sendto(message_bytes, (self.host, self.port))
                    
            except Exception as e:
                print(f"发送网络日志失败: {e}")
                self.connection = None
    
    def close(self):
        with self.lock:
            if self.connection:
                self.connection.close()
                self.connection = None

# 抽象日志工厂
class LogHandlerFactory(ABC):
    """日志处理器工厂抽象基类"""
    
    @abstractmethod
    def create_handler(self, config: Dict[str, Any]) -> LogHandler:
        """创建日志处理器(抽象方法)"""
        pass
    
    @abstractmethod
    def get_handler_type(self) -> str:
        """获取处理器类型(抽象方法)"""
        pass
    
    def validate_config(self, config: Dict[str, Any]) -> bool:
        """验证配置(可被子类重写)"""
        return True

# 控制台日志工厂
class ConsoleLogHandlerFactory(LogHandlerFactory):
    def create_handler(self, config: Dict[str, Any]) -> LogHandler:
        if not self.validate_config(config):
            raise ValueError("控制台日志处理器配置无效")
        
        level = LogLevel[config.get('level', 'INFO').upper()]
        use_colors = config.get('use_colors', True)
        
        handler = ConsoleLogHandler(level=level, use_colors=use_colors)
        
        # 添加过滤器
        filters = config.get('filters', [])
        for filter_config in filters:
            filter_func = self._create_filter(filter_config)
            if filter_func:
                handler.add_filter(filter_func)
        
        return handler
    
    def get_handler_type(self) -> str:
        return 'console'
    
    def validate_config(self, config: Dict[str, Any]) -> bool:
        level = config.get('level', 'INFO').upper()
        if level not in [l.name for l in LogLevel]:
            return False
        return True
    
    def _create_filter(self, filter_config: Dict[str, Any]):
        """创建过滤器函数"""
        filter_type = filter_config.get('type')
        
        if filter_type == 'keyword':
            keywords = filter_config.get('keywords', [])
            exclude = filter_config.get('exclude', False)
            
            def keyword_filter(record: LogRecord) -> bool:
                has_keyword = any(keyword in record.message for keyword in keywords)
                return not has_keyword if exclude else has_keyword
            
            return keyword_filter
        
        elif filter_type == 'thread':
            thread_ids = filter_config.get('thread_ids', [])
            
            def thread_filter(record: LogRecord) -> bool:
                return record.thread_id in thread_ids
            
            return thread_filter
        
        return None

# 文件日志工厂
class FileLogHandlerFactory(LogHandlerFactory):
    def create_handler(self, config: Dict[str, Any]) -> LogHandler:
        if not self.validate_config(config):
            raise ValueError("文件日志处理器配置无效")
        
        filename = config['filename']
        level = LogLevel[config.get('level', 'INFO').upper()]
        max_size = config.get('max_size', 10 * 1024 * 1024)
        backup_count = config.get('backup_count', 5)
        encoding = config.get('encoding', 'utf-8')
        
        # 确保目录存在
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        
        return FileLogHandler(
            filename=filename,
            level=level,
            max_size=max_size,
            backup_count=backup_count,
            encoding=encoding
        )
    
    def get_handler_type(self) -> str:
        return 'file'
    
    def validate_config(self, config: Dict[str, Any]) -> bool:
        if 'filename' not in config:
            return False
        
        level = config.get('level', 'INFO').upper()
        if level not in [l.name for l in LogLevel]:
            return False
        
        return True

# 数据库日志工厂
class DatabaseLogHandlerFactory(LogHandlerFactory):
    def create_handler(self, config: Dict[str, Any]) -> LogHandler:
        if not self.validate_config(config):
            raise ValueError("数据库日志处理器配置无效")
        
        db_path = config['db_path']
        level = LogLevel[config.get('level', 'INFO').upper()]
        table_name = config.get('table_name', 'logs')
        batch_size = config.get('batch_size', 100)
        
        # 确保数据库目录存在
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        
        return DatabaseLogHandler(
            db_path=db_path,
            level=level,
            table_name=table_name,
            batch_size=batch_size
        )
    
    def get_handler_type(self) -> str:
        return 'database'
    
    def validate_config(self, config: Dict[str, Any]) -> bool:
        if 'db_path' not in config:
            return False
        
        level = config.get('level', 'INFO').upper()
        if level not in [l.name for l in LogLevel]:
            return False
        
        return True

# 网络日志工厂
class NetworkLogHandlerFactory(LogHandlerFactory):
    def create_handler(self, config: Dict[str, Any]) -> LogHandler:
        if not self.validate_config(config):
            raise ValueError("网络日志处理器配置无效")
        
        host = config['host']
        port = config['port']
        level = LogLevel[config.get('level', 'INFO').upper()]
        protocol = config.get('protocol', 'tcp')
        timeout = config.get('timeout', 5.0)
        
        return NetworkLogHandler(
            host=host,
            port=port,
            level=level,
            protocol=protocol,
            timeout=timeout
        )
    
    def get_handler_type(self) -> str:
        return 'network'
    
    def validate_config(self, config: Dict[str, Any]) -> bool:
        required_fields = ['host', 'port']
        for field in required_fields:
            if field not in config:
                return False
        
        if not isinstance(config['port'], int) or not (1 <= config['port'] <= 65535):
            return False
        
        protocol = config.get('protocol', 'tcp').lower()
        if protocol not in ['tcp', 'udp']:
            return False
        
        level = config.get('level', 'INFO').upper()
        if level not in [l.name for l in LogLevel]:
            return False
        
        return True

# 日志工厂注册器
class LogHandlerFactoryRegistry:
    """日志处理器工厂注册器"""
    
    def __init__(self):
        self._factories: Dict[str, LogHandlerFactory] = {}
        self._register_default_factories()
    
    def _register_default_factories(self):
        """注册默认工厂"""
        self.register_factory(ConsoleLogHandlerFactory())
        self.register_factory(FileLogHandlerFactory())
        self.register_factory(DatabaseLogHandlerFactory())
        self.register_factory(NetworkLogHandlerFactory())
    
    def register_factory(self, factory: LogHandlerFactory):
        """注册工厂"""
        handler_type = factory.get_handler_type()
        self._factories[handler_type] = factory
        print(f"注册日志处理器工厂: {handler_type}")
    
    def unregister_factory(self, handler_type: str):
        """注销工厂"""
        if handler_type in self._factories:
            del self._factories[handler_type]
            print(f"注销日志处理器工厂: {handler_type}")
    
    def create_handler(self, handler_type: str, config: Dict[str, Any]) -> LogHandler:
        """创建日志处理器"""
        if handler_type not in self._factories:
            raise ValueError(f"未知的日志处理器类型: {handler_type}")
        
        factory = self._factories[handler_type]
        return factory.create_handler(config)
    
    def get_available_types(self) -> List[str]:
        """获取可用的处理器类型"""
        return list(self._factories.keys())

# 日志管理器
class Logger:
    """日志管理器"""
    
    def __init__(self, name: str):
        self.name = name
        self.handlers: List[LogHandler] = []
        self.level = LogLevel.INFO
        self.factory_registry = LogHandlerFactoryRegistry()
    
    def add_handler(self, handler: LogHandler):
        """添加日志处理器"""
        self.handlers.append(handler)
    
    def add_handler_from_config(self, handler_type: str, config: Dict[str, Any]):
        """从配置创建并添加日志处理器"""
        handler = self.factory_registry.create_handler(handler_type, config)
        self.add_handler(handler)
    
    def remove_handler(self, handler: LogHandler):
        """移除日志处理器"""
        if handler in self.handlers:
            handler.close()
            self.handlers.remove(handler)
    
    def set_level(self, level: LogLevel):
        """设置日志级别"""
        self.level = level
    
    def log(self, level: LogLevel, message: str, **metadata):
        """记录日志"""
        if level.value < self.level.value:
            return
        
        record = LogRecord(level, message, metadata=metadata)
        
        for handler in self.handlers:
            try:
                handler.handle(record)
            except Exception as e:
                print(f"日志处理器错误: {e}")
    
    def debug(self, message: str, **metadata):
        self.log(LogLevel.DEBUG, message, **metadata)
    
    def info(self, message: str, **metadata):
        self.log(LogLevel.INFO, message, **metadata)
    
    def warning(self, message: str, **metadata):
        self.log(LogLevel.WARNING, message, **metadata)
    
    def error(self, message: str, **metadata):
        self.log(LogLevel.ERROR, message, **metadata)
    
    def critical(self, message: str, **metadata):
        self.log(LogLevel.CRITICAL, message, **metadata)
    
    def close(self):
        """关闭所有处理器"""
        for handler in self.handlers:
            handler.close()
        self.handlers.clear()

# 演示工厂方法模式
def demonstrate_factory_method_pattern():
    print("=== 工厂方法模式演示 - 插件化日志系统 ===")
    
    # 创建日志管理器
    logger = Logger("demo_logger")
    logger.set_level(LogLevel.DEBUG)
    
    # 使用工厂方法创建不同类型的日志处理器
    print("\n1. 创建控制台日志处理器")
    console_config = {
        'level': 'INFO',
        'use_colors': True,
        'filters': [
            {
                'type': 'keyword',
                'keywords': ['error', 'critical'],
                'exclude': False
            }
        ]
    }
    logger.add_handler_from_config('console', console_config)
    
    print("\n2. 创建文件日志处理器")
    file_config = {
        'filename': './logs/application.log',
        'level': 'DEBUG',
        'max_size': 1024 * 1024,  # 1MB
        'backup_count': 3,
        'encoding': 'utf-8'
    }
    logger.add_handler_from_config('file', file_config)
    
    print("\n3. 创建数据库日志处理器")
    db_config = {
        'db_path': './logs/application.db',
        'level': 'WARNING',
        'table_name': 'application_logs',
        'batch_size': 10
    }
    logger.add_handler_from_config('database', db_config)
    
    # 测试日志记录
    print("\n4. 测试日志记录")
    logger.debug("这是一个调试消息", module="demo", function="test")
    logger.info("应用程序启动成功", version="1.0.0", startup_time=time.time())
    logger.warning("配置文件使用默认值", config_file="default.conf")
    logger.error("数据库连接失败", database="mysql", error_code=1045)
    logger.critical("系统内存不足", memory_usage="95%", available_memory="256MB")
    
    # 显示可用的处理器类型
    print("\n5. 可用的日志处理器类型:")
    available_types = logger.factory_registry.get_available_types()
    for handler_type in available_types:
        print(f"  - {handler_type}")
    
    # 清理资源
    print("\n6. 清理资源")
    logger.close()
    
    print("\n演示完成!")

if __name__ == "__main__":
    demonstrate_factory_method_pattern()

14.3.5 Java实现示例:文档处理系统

import java.util.*;
import java.io.*;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;

// 抽象文档接口
interface Document {
    void process();
    void save(String filename);
    String getContent();
    String getFormat();
}

// PDF文档实现
class PdfDocument implements Document {
    private String content;
    private Map<String, Object> metadata;
    
    public PdfDocument(String content) {
        this.content = content;
        this.metadata = new HashMap<>();
        this.metadata.put("format", "PDF");
        this.metadata.put("created", LocalDateTime.now());
    }
    
    @Override
    public void process() {
        System.out.println("Processing PDF document...");
        // 模拟PDF处理逻辑
        content = "[PDF] " + content;
        System.out.println("PDF processing completed.");
    }
    
    @Override
    public void save(String filename) {
        try (PrintWriter writer = new PrintWriter(filename + ".pdf")) {
            writer.println("PDF Document");
            writer.println("Created: " + metadata.get("created"));
            writer.println("Content: " + content);
            System.out.println("PDF saved to: " + filename + ".pdf");
        } catch (IOException e) {
            System.err.println("Error saving PDF: " + e.getMessage());
        }
    }
    
    @Override
    public String getContent() {
        return content;
    }
    
    @Override
    public String getFormat() {
        return "PDF";
    }
}

// Word文档实现
class WordDocument implements Document {
    private String content;
    private Map<String, Object> properties;
    
    public WordDocument(String content) {
        this.content = content;
        this.properties = new HashMap<>();
        this.properties.put("format", "DOCX");
        this.properties.put("author", "System");
        this.properties.put("created", LocalDateTime.now());
    }
    
    @Override
    public void process() {
        System.out.println("Processing Word document...");
        // 模拟Word处理逻辑
        content = "[WORD] " + content;
        System.out.println("Word processing completed.");
    }
    
    @Override
    public void save(String filename) {
        try (PrintWriter writer = new PrintWriter(filename + ".docx")) {
            writer.println("Word Document");
            writer.println("Author: " + properties.get("author"));
            writer.println("Created: " + properties.get("created"));
            writer.println("Content: " + content);
            System.out.println("Word document saved to: " + filename + ".docx");
        } catch (IOException e) {
            System.err.println("Error saving Word document: " + e.getMessage());
        }
    }
    
    @Override
    public String getContent() {
        return content;
    }
    
    @Override
    public String getFormat() {
        return "DOCX";
    }
}

// HTML文档实现
class HtmlDocument implements Document {
    private String content;
    private String title;
    private List<String> stylesheets;
    
    public HtmlDocument(String content) {
        this.content = content;
        this.title = "Generated Document";
        this.stylesheets = new ArrayList<>();
        this.stylesheets.add("default.css");
    }
    
    @Override
    public void process() {
        System.out.println("Processing HTML document...");
        // 模拟HTML处理逻辑
        content = String.format(
            "<html><head><title>%s</title></head><body>%s</body></html>",
            title, content
        );
        System.out.println("HTML processing completed.");
    }
    
    @Override
    public void save(String filename) {
        try (PrintWriter writer = new PrintWriter(filename + ".html")) {
            writer.println(content);
            System.out.println("HTML document saved to: " + filename + ".html");
        } catch (IOException e) {
            System.err.println("Error saving HTML document: " + e.getMessage());
        }
    }
    
    @Override
    public String getContent() {
        return content;
    }
    
    @Override
    public String getFormat() {
        return "HTML";
    }
    
    public void setTitle(String title) {
        this.title = title;
    }
    
    public void addStylesheet(String stylesheet) {
        this.stylesheets.add(stylesheet);
    }
}

// 抽象文档工厂
abstract class DocumentFactory {
    // 工厂方法 - 由子类实现
    public abstract Document createDocument(String content);
    
    // 模板方法 - 定义文档处理流程
    public final Document processDocument(String content, String filename) {
        System.out.println("Starting document processing workflow...");
        
        // 1. 创建文档
        Document document = createDocument(content);
        System.out.println("Document created: " + document.getFormat());
        
        // 2. 处理文档
        document.process();
        
        // 3. 验证文档
        if (validateDocument(document)) {
            System.out.println("Document validation passed.");
        } else {
            System.out.println("Document validation failed.");
            return null;
        }
        
        // 4. 保存文档
        if (filename != null && !filename.isEmpty()) {
            document.save(filename);
        }
        
        System.out.println("Document processing workflow completed.");
        return document;
    }
    
    // 钩子方法 - 子类可以重写
    protected boolean validateDocument(Document document) {
        // 默认验证逻辑
        return document.getContent() != null && !document.getContent().isEmpty();
    }
    
    // 获取支持的格式
    public abstract String getSupportedFormat();
    
    // 获取工厂描述
    public abstract String getDescription();
}

// PDF文档工厂
class PdfDocumentFactory extends DocumentFactory {
    @Override
    public Document createDocument(String content) {
        System.out.println("Creating PDF document with advanced features...");
        return new PdfDocument(content);
    }
    
    @Override
    protected boolean validateDocument(Document document) {
        // PDF特定的验证逻辑
        boolean basicValidation = super.validateDocument(document);
        if (!basicValidation) {
            return false;
        }
        
        // 检查PDF特定要求
        String content = document.getContent();
        if (content.length() > 10000) {
            System.out.println("Warning: PDF content is very long, consider splitting.");
        }
        
        return true;
    }
    
    @Override
    public String getSupportedFormat() {
        return "PDF";
    }
    
    @Override
    public String getDescription() {
        return "Factory for creating PDF documents with advanced formatting";
    }
}

// Word文档工厂
class WordDocumentFactory extends DocumentFactory {
    @Override
    public Document createDocument(String content) {
        System.out.println("Creating Word document with rich formatting...");
        return new WordDocument(content);
    }
    
    @Override
    protected boolean validateDocument(Document document) {
        // Word特定的验证逻辑
        boolean basicValidation = super.validateDocument(document);
        if (!basicValidation) {
            return false;
        }
        
        // 检查Word特定要求
        String content = document.getContent();
        if (!content.contains("[WORD]")) {
            System.out.println("Warning: Document may not be properly formatted for Word.");
        }
        
        return true;
    }
    
    @Override
    public String getSupportedFormat() {
        return "DOCX";
    }
    
    @Override
    public String getDescription() {
        return "Factory for creating Word documents with rich text formatting";
    }
}

// HTML文档工厂
class HtmlDocumentFactory extends DocumentFactory {
    @Override
    public Document createDocument(String content) {
        System.out.println("Creating HTML document with web features...");
        HtmlDocument htmlDoc = new HtmlDocument(content);
        htmlDoc.setTitle("Auto-generated HTML Document");
        htmlDoc.addStylesheet("bootstrap.css");
        return htmlDoc;
    }
    
    @Override
    protected boolean validateDocument(Document document) {
        // HTML特定的验证逻辑
        boolean basicValidation = super.validateDocument(document);
        if (!basicValidation) {
            return false;
        }
        
        // 检查HTML特定要求
        String content = document.getContent();
        if (!content.contains("<html>") || !content.contains("</html>")) {
            System.out.println("Error: Invalid HTML structure.");
            return false;
        }
        
        return true;
    }
    
    @Override
    public String getSupportedFormat() {
        return "HTML";
    }
    
    @Override
    public String getDescription() {
        return "Factory for creating HTML documents with web-ready formatting";
    }
}

// 文档工厂注册器
class DocumentFactoryRegistry {
    private Map<String, DocumentFactory> factories;
    
    public DocumentFactoryRegistry() {
        this.factories = new HashMap<>();
        registerDefaultFactories();
    }
    
    private void registerDefaultFactories() {
        registerFactory("pdf", new PdfDocumentFactory());
        registerFactory("word", new WordDocumentFactory());
        registerFactory("html", new HtmlDocumentFactory());
    }
    
    public void registerFactory(String type, DocumentFactory factory) {
        factories.put(type.toLowerCase(), factory);
        System.out.println("Registered factory for: " + type + " - " + factory.getDescription());
    }
    
    public void unregisterFactory(String type) {
        DocumentFactory removed = factories.remove(type.toLowerCase());
        if (removed != null) {
            System.out.println("Unregistered factory for: " + type);
        }
    }
    
    public Document createDocument(String type, String content, String filename) {
        DocumentFactory factory = factories.get(type.toLowerCase());
        if (factory == null) {
            throw new IllegalArgumentException("Unsupported document type: " + type);
        }
        
        return factory.processDocument(content, filename);
    }
    
    public Set<String> getSupportedTypes() {
        return new HashSet<>(factories.keySet());
    }
    
    public void listFactories() {
        System.out.println("Available document factories:");
        for (Map.Entry<String, DocumentFactory> entry : factories.entrySet()) {
            DocumentFactory factory = entry.getValue();
            System.out.printf("  %s (%s): %s%n", 
                entry.getKey().toUpperCase(), 
                factory.getSupportedFormat(),
                factory.getDescription());
        }
    }
}

// 文档处理服务
class DocumentProcessingService {
    private DocumentFactoryRegistry registry;
    
    public DocumentProcessingService() {
        this.registry = new DocumentFactoryRegistry();
    }
    
    public void processDocuments(Map<String, String> documents) {
        System.out.println("\n=== Batch Document Processing ===");
        
        for (Map.Entry<String, String> entry : documents.entrySet()) {
            String filename = entry.getKey();
            String content = entry.getValue();
            
            // 从文件名推断文档类型
            String type = inferDocumentType(filename);
            
            try {
                System.out.println("\nProcessing: " + filename + " (Type: " + type + ")");
                Document document = registry.createDocument(type, content, filename);
                
                if (document != null) {
                    System.out.println("Successfully processed: " + filename);
                } else {
                    System.out.println("Failed to process: " + filename);
                }
            } catch (Exception e) {
                System.err.println("Error processing " + filename + ": " + e.getMessage());
            }
        }
    }
    
    private String inferDocumentType(String filename) {
        if (filename.toLowerCase().contains("pdf")) {
            return "pdf";
        } else if (filename.toLowerCase().contains("word") || filename.toLowerCase().contains("doc")) {
            return "word";
        } else if (filename.toLowerCase().contains("html") || filename.toLowerCase().contains("web")) {
            return "html";
        } else {
            return "pdf"; // 默认类型
        }
    }
    
    public DocumentFactoryRegistry getRegistry() {
        return registry;
    }
}

// 演示类
public class FactoryMethodPatternDemo {
    public static void main(String[] args) {
        System.out.println("=== 工厂方法模式演示 - 文档处理系统 ===");
        
        // 创建文档处理服务
        DocumentProcessingService service = new DocumentProcessingService();
        
        // 显示可用的工厂
        System.out.println("\n1. 显示可用的文档工厂:");
        service.getRegistry().listFactories();
        
        // 准备测试文档
        Map<String, String> testDocuments = new HashMap<>();
        testDocuments.put("report_pdf", "这是一个重要的业务报告,包含详细的数据分析和图表。");
        testDocuments.put("letter_word", "尊敬的客户,感谢您对我们产品的关注和支持。");
        testDocuments.put("webpage_html", "欢迎访问我们的网站!这里有最新的产品信息和技术文档。");
        
        // 批量处理文档
        System.out.println("\n2. 批量处理文档:");
        service.processDocuments(testDocuments);
        
        // 演示单独创建文档
        System.out.println("\n3. 演示单独创建文档:");
        demonstrateIndividualFactories(service.getRegistry());
        
        // 演示自定义工厂注册
        System.out.println("\n4. 演示自定义工厂注册:");
        demonstrateCustomFactory(service.getRegistry());
        
        System.out.println("\n=== 演示完成 ===");
    }
    
    private static void demonstrateIndividualFactories(DocumentFactoryRegistry registry) {
        String[] types = {"pdf", "word", "html"};
        String content = "这是一个测试文档的内容。";
        
        for (String type : types) {
            try {
                System.out.println("\nCreating " + type.toUpperCase() + " document:");
                Document doc = registry.createDocument(type, content, "test_" + type);
                System.out.println("Document format: " + doc.getFormat());
                System.out.println("Content preview: " + 
                    (doc.getContent().length() > 50 ? 
                     doc.getContent().substring(0, 50) + "..." : 
                     doc.getContent()));
            } catch (Exception e) {
                System.err.println("Error creating " + type + " document: " + e.getMessage());
            }
        }
    }
    
    private static void demonstrateCustomFactory(DocumentFactoryRegistry registry) {
        // 创建自定义的Markdown工厂
        DocumentFactory markdownFactory = new DocumentFactory() {
            @Override
            public Document createDocument(String content) {
                return new Document() {
                    private String mdContent = "# Document\n\n" + content;
                    
                    @Override
                    public void process() {
                        System.out.println("Processing Markdown document...");
                        mdContent = mdContent.replace("\n", "\n\n");
                    }
                    
                    @Override
                    public void save(String filename) {
                        try (PrintWriter writer = new PrintWriter(filename + ".md")) {
                            writer.println(mdContent);
                            System.out.println("Markdown saved to: " + filename + ".md");
                        } catch (IOException e) {
                            System.err.println("Error saving Markdown: " + e.getMessage());
                        }
                    }
                    
                    @Override
                    public String getContent() {
                        return mdContent;
                    }
                    
                    @Override
                    public String getFormat() {
                        return "MARKDOWN";
                    }
                };
            }
            
            @Override
            public String getSupportedFormat() {
                return "MARKDOWN";
            }
            
            @Override
            public String getDescription() {
                return "Factory for creating Markdown documents";
            }
        };
        
        // 注册自定义工厂
        registry.registerFactory("markdown", markdownFactory);
        
        // 使用自定义工厂
        try {
            System.out.println("\nUsing custom Markdown factory:");
            Document mdDoc = registry.createDocument("markdown", "这是一个Markdown文档示例。", "custom_doc");
            System.out.println("Custom document format: " + mdDoc.getFormat());
        } catch (Exception e) {
            System.err.println("Error with custom factory: " + e.getMessage());
        }
        
        // 显示更新后的工厂列表
        System.out.println("\nUpdated factory list:");
        registry.listFactories();
    }
}

14.3.6 工厂方法模式的优缺点

优点: 1. 解耦创建和使用:客户端代码不依赖具体产品类,只依赖抽象接口 2. 符合开闭原则:添加新产品类型时无需修改现有代码 3. 符合单一职责原则:每个工厂只负责创建一种产品 4. 灵活性高:可以通过配置或参数动态选择工厂 5. 易于扩展:支持插件化架构,便于第三方扩展 6. 封装复杂创建逻辑:将复杂的对象创建过程封装在工厂中

缺点: 1. 类数量增加:每个产品都需要对应的工厂类 2. 增加系统复杂度:引入了额外的抽象层 3. 理解成本:新开发者需要理解工厂层次结构 4. 过度设计风险:简单场景下可能造成过度设计

14.3.7 适用场景

  • 插件系统:需要动态加载和创建不同类型的插件
  • 数据库驱动:根据配置创建不同数据库的连接
  • 文件处理:根据文件类型创建相应的处理器
  • UI组件:根据平台创建不同风格的UI组件
  • 序列化框架:根据格式创建不同的序列化器
  • 消息处理:根据消息类型创建相应的处理器
  • 测试框架:根据测试类型创建不同的测试执行器

14.4 模式对比与选择

14.4.1 相同点

  1. 都使用继承:两种模式都依赖继承机制
  2. 都有抽象基类:定义了算法骨架或创建接口
  3. 都延迟到子类:具体实现由子类决定
  4. 都符合开闭原则:对扩展开放,对修改封闭
  5. 都支持多态:通过基类引用调用子类实现

14.4.2 不同点

对比维度 模板方法模式 工厂方法模式
主要目的 定义算法骨架,延迟具体步骤 延迟对象创建到子类
关注点 算法流程控制 对象创建过程
模式类型 行为型模式 创建型模式
抽象方法作用 实现算法的具体步骤 创建具体产品对象
客户端使用 直接调用模板方法 通过工厂创建对象
扩展方式 重写抽象方法和钩子方法 实现工厂方法
控制反转 算法控制权在父类 创建控制权在子类
复用粒度 算法骨架复用 创建逻辑复用

14.4.3 选择指南

选择模板方法模式的场景: - 有固定的算法流程,但某些步骤需要变化 - 需要控制算法的执行顺序 - 希望避免代码重复,提取公共算法逻辑 - 需要在特定点插入自定义行为(钩子方法)

选择工厂方法模式的场景: - 需要创建不同类型的对象 - 对象创建逻辑复杂或经常变化 - 希望将对象创建与使用分离 - 需要支持插件化或可扩展的架构

14.4.4 组合使用

两种模式可以组合使用,形成更强大的设计:

# 组合使用示例:可配置的数据处理流水线
from abc import ABC, abstractmethod
from typing import Any, Dict, List

class DataProcessor(ABC):
    """模板方法模式:定义数据处理流程"""
    
    def process_data(self, data: Any) -> Any:
        """模板方法:定义处理流程"""
        print("开始数据处理流程...")
        
        # 1. 验证数据
        if not self.validate_data(data):
            raise ValueError("数据验证失败")
        
        # 2. 预处理
        preprocessed_data = self.preprocess(data)
        
        # 3. 核心处理(抽象方法)
        processed_data = self.core_process(preprocessed_data)
        
        # 4. 后处理
        result = self.postprocess(processed_data)
        
        # 5. 钩子方法:可选的清理工作
        self.cleanup()
        
        print("数据处理流程完成")
        return result
    
    def validate_data(self, data: Any) -> bool:
        """钩子方法:数据验证"""
        return data is not None
    
    def preprocess(self, data: Any) -> Any:
        """钩子方法:预处理"""
        return data
    
    @abstractmethod
    def core_process(self, data: Any) -> Any:
        """抽象方法:核心处理逻辑"""
        pass
    
    def postprocess(self, data: Any) -> Any:
        """钩子方法:后处理"""
        return data
    
    def cleanup(self):
        """钩子方法:清理工作"""
        pass

class ProcessorFactory(ABC):
    """工厂方法模式:创建数据处理器"""
    
    @abstractmethod
    def create_processor(self, config: Dict[str, Any]) -> DataProcessor:
        """工厂方法:创建处理器"""
        pass
    
    def get_processor_type(self) -> str:
        """获取处理器类型"""
        return self.__class__.__name__.replace('Factory', '')

# 具体的数据处理器实现
class TextProcessor(DataProcessor):
    def __init__(self, encoding: str = 'utf-8'):
        self.encoding = encoding
    
    def validate_data(self, data: Any) -> bool:
        return isinstance(data, str)
    
    def preprocess(self, data: str) -> str:
        return data.strip().lower()
    
    def core_process(self, data: str) -> str:
        # 文本处理逻辑
        words = data.split()
        return ' '.join(sorted(set(words)))
    
    def postprocess(self, data: str) -> str:
        return data.title()

class NumberProcessor(DataProcessor):
    def __init__(self, precision: int = 2):
        self.precision = precision
    
    def validate_data(self, data: Any) -> bool:
        return isinstance(data, (int, float, list))
    
    def preprocess(self, data: Any) -> List[float]:
        if isinstance(data, (int, float)):
            return [float(data)]
        return [float(x) for x in data if isinstance(x, (int, float))]
    
    def core_process(self, data: List[float]) -> Dict[str, float]:
        # 数值处理逻辑
        return {
            'sum': sum(data),
            'avg': sum(data) / len(data) if data else 0,
            'min': min(data) if data else 0,
            'max': max(data) if data else 0
        }
    
    def postprocess(self, data: Dict[str, float]) -> Dict[str, float]:
        return {k: round(v, self.precision) for k, v in data.items()}

# 具体的工厂实现
class TextProcessorFactory(ProcessorFactory):
    def create_processor(self, config: Dict[str, Any]) -> DataProcessor:
        encoding = config.get('encoding', 'utf-8')
        return TextProcessor(encoding)

class NumberProcessorFactory(ProcessorFactory):
    def create_processor(self, config: Dict[str, Any]) -> DataProcessor:
        precision = config.get('precision', 2)
        return NumberProcessor(precision)

# 处理器管理器
class ProcessorManager:
    def __init__(self):
        self.factories = {}
        self._register_default_factories()
    
    def _register_default_factories(self):
        self.register_factory('text', TextProcessorFactory())
        self.register_factory('number', NumberProcessorFactory())
    
    def register_factory(self, processor_type: str, factory: ProcessorFactory):
        self.factories[processor_type] = factory
    
    def create_processor(self, processor_type: str, config: Dict[str, Any] = None) -> DataProcessor:
        if processor_type not in self.factories:
            raise ValueError(f"未知的处理器类型: {processor_type}")
        
        factory = self.factories[processor_type]
        return factory.create_processor(config or {})
    
    def process_data(self, processor_type: str, data: Any, config: Dict[str, Any] = None) -> Any:
        processor = self.create_processor(processor_type, config)
        return processor.process_data(data)

# 演示组合使用
def demonstrate_combined_patterns():
    print("=== 模板方法模式与工厂方法模式组合使用 ===")
    
    manager = ProcessorManager()
    
    # 处理文本数据
    print("\n1. 处理文本数据:")
    text_data = "Hello World Python Programming Hello"
    result = manager.process_data('text', text_data)
    print(f"原始数据: {text_data}")
    print(f"处理结果: {result}")
    
    # 处理数值数据
    print("\n2. 处理数值数据:")
    number_data = [1, 2, 3, 4, 5, 2, 3]
    result = manager.process_data('number', number_data, {'precision': 3})
    print(f"原始数据: {number_data}")
    print(f"处理结果: {result}")
    
    print("\n组合模式演示完成!")

if __name__ == "__main__":
    demonstrate_combined_patterns()

14.5 总结

14.5.1 核心要点

  1. 模板方法模式专注于算法流程的标准化,通过继承实现算法骨架的复用
  2. 工厂方法模式专注于对象创建的灵活性,通过继承实现创建逻辑的扩展
  3. 两种模式都体现了”依赖倒置原则”,高层模块不依赖低层模块的具体实现
  4. 合理使用这两种模式可以提高代码的可维护性、可扩展性和可复用性

14.5.2 最佳实践

  1. 明确职责边界:模板方法控制流程,工厂方法控制创建
  2. 合理设计抽象:抽象类应该稳定,具体实现应该灵活
  3. 文档化设计意图:清楚地说明每个抽象方法和钩子方法的作用
  4. 考虑组合使用:在复杂系统中,两种模式往往需要配合使用
  5. 避免过度设计:简单场景下不要强行使用设计模式

14.5.3 实际应用建议

  • 框架设计:使用模板方法定义框架流程,使用工厂方法支持插件扩展
  • 业务系统:使用模板方法标准化业务流程,使用工厂方法创建业务对象
  • 测试系统:使用模板方法定义测试流程,使用工厂方法创建测试数据
  • 数据处理:使用模板方法定义处理流水线,使用工厂方法创建处理器

通过深入理解和灵活运用这两种设计模式,我们可以构建出更加优雅、可维护和可扩展的软件系统。

14.2.5 Java实现示例:测试框架

import java.util.*;
import java.time.LocalDateTime;
import java.time.Duration;
import java.time.Instant;

// 抽象测试类(模板方法模式)
abstract class TestCase {
    protected String testName;
    protected Map<String, Object> testData;
    protected List<String> logs;
    protected Instant startTime;
    protected Duration executionTime;
    
    public TestCase(String testName) {
        this.testName = testName;
        this.testData = new HashMap<>();
        this.logs = new ArrayList<>();
    }
    
    // 模板方法:定义测试执行的完整流程
    public final TestResult runTest() {
        log("开始执行测试: " + testName);
        startTime = Instant.now();
        
        try {
            // 1. 测试前置条件检查
            if (!checkPreconditions()) {
                return TestResult.failed(testName, "前置条件检查失败", logs);
            }
            
            // 2. 设置测试环境
            setUp();
            log("测试环境设置完成");
            
            // 3. 执行测试(抽象方法)
            executeTest();
            log("测试执行完成");
            
            // 4. 验证结果(抽象方法)
            boolean passed = verifyResult();
            
            // 5. 清理测试环境
            tearDown();
            log("测试环境清理完成");
            
            executionTime = Duration.between(startTime, Instant.now());
            
            if (passed) {
                log("测试通过");
                return TestResult.passed(testName, executionTime, logs);
            } else {
                log("测试失败:验证未通过");
                return TestResult.failed(testName, "验证未通过", logs);
            }
            
        } catch (Exception e) {
            log("测试异常: " + e.getMessage());
            tearDown(); // 确保清理
            executionTime = Duration.between(startTime, Instant.now());
            return TestResult.error(testName, e.getMessage(), logs);
        }
    }
    
    // 具体方法:前置条件检查(可被子类重写)
    protected boolean checkPreconditions() {
        return true;
    }
    
    // 具体方法:设置测试环境(可被子类重写)
    protected void setUp() {
        // 默认实现为空
    }
    
    // 抽象方法:执行测试逻辑(子类必须实现)
    protected abstract void executeTest() throws Exception;
    
    // 抽象方法:验证测试结果(子类必须实现)
    protected abstract boolean verifyResult();
    
    // 钩子方法:清理测试环境(子类可选择性重写)
    protected void tearDown() {
        // 默认实现为空
    }
    
    // 辅助方法
    protected void log(String message) {
        String timestamp = LocalDateTime.now().toString();
        logs.add(String.format("[%s] %s", timestamp, message));
        System.out.println(String.format("[%s] %s: %s", timestamp, testName, message));
    }
    
    protected void setTestData(String key, Object value) {
        testData.put(key, value);
    }
    
    protected Object getTestData(String key) {
        return testData.get(key);
    }
}

// 数据库连接测试
class DatabaseConnectionTest extends TestCase {
    private String connectionString;
    private Object connection;
    
    public DatabaseConnectionTest(String connectionString) {
        super("数据库连接测试");
        this.connectionString = connectionString;
    }
    
    @Override
    protected boolean checkPreconditions() {
        if (connectionString == null || connectionString.trim().isEmpty()) {
            log("连接字符串为空");
            return false;
        }
        
        if (!connectionString.startsWith("jdbc:")) {
            log("无效的JDBC连接字符串");
            return false;
        }
        
        return true;
    }
    
    @Override
    protected void setUp() {
        log("初始化数据库驱动");
        setTestData("driver_loaded", true);
        setTestData("connection_timeout", 30);
    }
    
    @Override
    protected void executeTest() throws Exception {
        log("尝试建立数据库连接: " + connectionString);
        
        // 模拟数据库连接
        Thread.sleep(100); // 模拟连接延迟
        
        if (connectionString.contains("invalid")) {
            throw new Exception("无法连接到数据库");
        }
        
        connection = new Object(); // 模拟连接对象
        setTestData("connection", connection);
        setTestData("connected_at", Instant.now());
        
        log("数据库连接建立成功");
    }
    
    @Override
    protected boolean verifyResult() {
        if (connection == null) {
            log("验证失败:连接对象为空");
            return false;
        }
        
        if (getTestData("connected_at") == null) {
            log("验证失败:连接时间未记录");
            return false;
        }
        
        log("验证通过:连接建立成功且数据完整");
        return true;
    }
    
    @Override
    protected void tearDown() {
        if (connection != null) {
            log("关闭数据库连接");
            connection = null;
        }
        
        log("清理连接资源");
    }
}

// API接口测试
class ApiEndpointTest extends TestCase {
    private String endpoint;
    private String method;
    private Map<String, String> headers;
    private String requestBody;
    private int expectedStatusCode;
    
    public ApiEndpointTest(String endpoint, String method, int expectedStatusCode) {
        super("API接口测试");
        this.endpoint = endpoint;
        this.method = method;
        this.expectedStatusCode = expectedStatusCode;
        this.headers = new HashMap<>();
    }
    
    public void addHeader(String key, String value) {
        headers.put(key, value);
    }
    
    public void setRequestBody(String body) {
        this.requestBody = body;
    }
    
    @Override
    protected boolean checkPreconditions() {
        if (endpoint == null || !endpoint.startsWith("http")) {
            log("无效的API端点URL");
            return false;
        }
        
        if (method == null || method.trim().isEmpty()) {
            log("HTTP方法不能为空");
            return false;
        }
        
        return true;
    }
    
    @Override
    protected void setUp() {
        log("设置HTTP客户端");
        
        // 添加默认头部
        if (!headers.containsKey("Content-Type")) {
            headers.put("Content-Type", "application/json");
        }
        
        if (!headers.containsKey("User-Agent")) {
            headers.put("User-Agent", "TestFramework/1.0");
        }
        
        setTestData("request_headers", new HashMap<>(headers));
        setTestData("request_method", method);
        setTestData("request_body", requestBody);
    }
    
    @Override
    protected void executeTest() throws Exception {
        log(String.format("发送 %s 请求到: %s", method, endpoint));
        
        // 模拟HTTP请求
        Thread.sleep(200); // 模拟网络延迟
        
        // 模拟不同的响应情况
        int statusCode;
        String responseBody;
        
        if (endpoint.contains("error")) {
            statusCode = 500;
            responseBody = "{\"error\": \"Internal Server Error\"}";
        } else if (endpoint.contains("notfound")) {
            statusCode = 404;
            responseBody = "{\"error\": \"Not Found\"}";
        } else {
            statusCode = 200;
            responseBody = "{\"status\": \"success\", \"data\": {\"id\": 1}}";
        }
        
        setTestData("response_status", statusCode);
        setTestData("response_body", responseBody);
        setTestData("response_time", Duration.between(startTime, Instant.now()).toMillis());
        
        log(String.format("收到响应: 状态码=%d, 响应时间=%dms", 
            statusCode, (Long)getTestData("response_time")));
    }
    
    @Override
    protected boolean verifyResult() {
        Integer actualStatus = (Integer) getTestData("response_status");
        
        if (actualStatus == null) {
            log("验证失败:未收到响应状态码");
            return false;
        }
        
        if (!actualStatus.equals(expectedStatusCode)) {
            log(String.format("验证失败:期望状态码 %d,实际状态码 %d", 
                expectedStatusCode, actualStatus));
            return false;
        }
        
        Long responseTime = (Long) getTestData("response_time");
        if (responseTime != null && responseTime > 5000) {
            log(String.format("警告:响应时间过长 %dms", responseTime));
        }
        
        log("验证通过:状态码匹配期望值");
        return true;
    }
    
    @Override
    protected void tearDown() {
        log("清理HTTP客户端资源");
        headers.clear();
    }
}

// 性能测试
class PerformanceTest extends TestCase {
    private Runnable testOperation;
    private int iterations;
    private long maxExecutionTimeMs;
    private List<Long> executionTimes;
    
    public PerformanceTest(String testName, Runnable testOperation, 
                          int iterations, long maxExecutionTimeMs) {
        super(testName);
        this.testOperation = testOperation;
        this.iterations = iterations;
        this.maxExecutionTimeMs = maxExecutionTimeMs;
        this.executionTimes = new ArrayList<>();
    }
    
    @Override
    protected boolean checkPreconditions() {
        if (testOperation == null) {
            log("测试操作不能为空");
            return false;
        }
        
        if (iterations <= 0) {
            log("迭代次数必须大于0");
            return false;
        }
        
        return true;
    }
    
    @Override
    protected void setUp() {
        log(String.format("准备执行性能测试,迭代次数: %d", iterations));
        executionTimes.clear();
        
        // 预热
        log("执行预热操作");
        for (int i = 0; i < Math.min(10, iterations / 10); i++) {
            testOperation.run();
        }
    }
    
    @Override
    protected void executeTest() throws Exception {
        log("开始性能测试");
        
        for (int i = 0; i < iterations; i++) {
            Instant start = Instant.now();
            
            try {
                testOperation.run();
            } catch (Exception e) {
                log(String.format("第 %d 次迭代执行失败: %s", i + 1, e.getMessage()));
                throw e;
            }
            
            long executionTime = Duration.between(start, Instant.now()).toMillis();
            executionTimes.add(executionTime);
            
            if ((i + 1) % (iterations / 10) == 0) {
                log(String.format("完成 %d/%d 次迭代", i + 1, iterations));
            }
        }
        
        // 计算统计信息
        calculateStatistics();
    }
    
    private void calculateStatistics() {
        if (executionTimes.isEmpty()) {
            return;
        }
        
        Collections.sort(executionTimes);
        
        long sum = executionTimes.stream().mapToLong(Long::longValue).sum();
        double average = (double) sum / executionTimes.size();
        long min = executionTimes.get(0);
        long max = executionTimes.get(executionTimes.size() - 1);
        long median = executionTimes.get(executionTimes.size() / 2);
        long p95 = executionTimes.get((int) (executionTimes.size() * 0.95));
        long p99 = executionTimes.get((int) (executionTimes.size() * 0.99));
        
        setTestData("average_time", average);
        setTestData("min_time", min);
        setTestData("max_time", max);
        setTestData("median_time", median);
        setTestData("p95_time", p95);
        setTestData("p99_time", p99);
        setTestData("total_iterations", iterations);
        
        log(String.format("性能统计 - 平均: %.2fms, 最小: %dms, 最大: %dms, 中位数: %dms, P95: %dms, P99: %dms",
            average, min, max, median, p95, p99));
    }
    
    @Override
    protected boolean verifyResult() {
        Double averageTime = (Double) getTestData("average_time");
        
        if (averageTime == null) {
            log("验证失败:未计算平均执行时间");
            return false;
        }
        
        if (averageTime > maxExecutionTimeMs) {
            log(String.format("验证失败:平均执行时间 %.2fms 超过限制 %dms", 
                averageTime, maxExecutionTimeMs));
            return false;
        }
        
        Long p99Time = (Long) getTestData("p99_time");
        if (p99Time != null && p99Time > maxExecutionTimeMs * 2) {
            log(String.format("警告:P99执行时间 %dms 过高", p99Time));
        }
        
        log(String.format("验证通过:平均执行时间 %.2fms 在可接受范围内", averageTime));
        return true;
    }
}

// 测试结果类
class TestResult {
    public enum Status {
        PASSED, FAILED, ERROR
    }
    
    private final String testName;
    private final Status status;
    private final String message;
    private final Duration executionTime;
    private final List<String> logs;
    
    private TestResult(String testName, Status status, String message, 
                      Duration executionTime, List<String> logs) {
        this.testName = testName;
        this.status = status;
        this.message = message;
        this.executionTime = executionTime;
        this.logs = new ArrayList<>(logs);
    }
    
    public static TestResult passed(String testName, Duration executionTime, List<String> logs) {
        return new TestResult(testName, Status.PASSED, "测试通过", executionTime, logs);
    }
    
    public static TestResult failed(String testName, String reason, List<String> logs) {
        return new TestResult(testName, Status.FAILED, reason, Duration.ZERO, logs);
    }
    
    public static TestResult error(String testName, String error, List<String> logs) {
        return new TestResult(testName, Status.ERROR, error, Duration.ZERO, logs);
    }
    
    // Getters
    public String getTestName() { return testName; }
    public Status getStatus() { return status; }
    public String getMessage() { return message; }
    public Duration getExecutionTime() { return executionTime; }
    public List<String> getLogs() { return logs; }
    
    @Override
    public String toString() {
        return String.format("TestResult{name='%s', status=%s, message='%s', time=%dms}",
            testName, status, message, executionTime.toMillis());
    }
}

// 测试套件
class TestSuite {
    private List<TestCase> testCases;
    private List<TestResult> results;
    
    public TestSuite() {
        this.testCases = new ArrayList<>();
        this.results = new ArrayList<>();
    }
    
    public void addTest(TestCase testCase) {
        testCases.add(testCase);
    }
    
    public void runAllTests() {
        System.out.println("=== 开始执行测试套件 ===");
        System.out.println(String.format("总共 %d 个测试用例\n", testCases.size()));
        
        results.clear();
        
        for (int i = 0; i < testCases.size(); i++) {
            TestCase testCase = testCases.get(i);
            System.out.println(String.format("[%d/%d] 执行测试: %s", 
                i + 1, testCases.size(), testCase.testName));
            
            TestResult result = testCase.runTest();
            results.add(result);
            
            System.out.println(String.format("结果: %s\n", result.getStatus()));
        }
        
        printSummary();
    }
    
    private void printSummary() {
        System.out.println("=== 测试执行总结 ===");
        
        long passed = results.stream().filter(r -> r.getStatus() == TestResult.Status.PASSED).count();
        long failed = results.stream().filter(r -> r.getStatus() == TestResult.Status.FAILED).count();
        long errors = results.stream().filter(r -> r.getStatus() == TestResult.Status.ERROR).count();
        
        System.out.println(String.format("总计: %d, 通过: %d, 失败: %d, 错误: %d", 
            results.size(), passed, failed, errors));
        
        if (failed > 0 || errors > 0) {
            System.out.println("\n失败的测试:");
            results.stream()
                .filter(r -> r.getStatus() != TestResult.Status.PASSED)
                .forEach(r -> System.out.println(String.format("  - %s: %s", 
                    r.getTestName(), r.getMessage())));
        }
        
        double successRate = (double) passed / results.size() * 100;
        System.out.println(String.format("\n成功率: %.1f%%", successRate));
    }
}

// 演示类
public class TemplateMethodDemo {
    public static void main(String[] args) {
        System.out.println("=== 模板方法模式 - 测试框架演示 ===");
        
        TestSuite suite = new TestSuite();
        
        // 添加数据库连接测试
        suite.addTest(new DatabaseConnectionTest("jdbc:mysql://localhost:3306/testdb"));
        suite.addTest(new DatabaseConnectionTest("jdbc:invalid://localhost:3306/testdb"));
        
        // 添加API测试
        ApiEndpointTest apiTest1 = new ApiEndpointTest("https://api.example.com/users", "GET", 200);
        apiTest1.addHeader("Authorization", "Bearer token123");
        suite.addTest(apiTest1);
        
        ApiEndpointTest apiTest2 = new ApiEndpointTest("https://api.example.com/error", "POST", 500);
        apiTest2.setRequestBody("{\"test\": \"data\"}");
        suite.addTest(apiTest2);
        
        // 添加性能测试
        Runnable simpleOperation = () -> {
            try {
                Thread.sleep(10); // 模拟10ms的操作
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        };
        
        suite.addTest(new PerformanceTest("简单操作性能测试", simpleOperation, 100, 50));
        
        // 运行所有测试
        suite.runAllTests();
    }
}