本章概述

模型训练是深度学习的核心环节,涉及损失函数选择、优化器配置、训练循环设计等关键技术。本章将深入介绍PyTorch中的模型训练和优化技术,帮助你构建高效、稳定的训练流程。

学习目标

通过本章学习,你将掌握:

  • 损失函数的原理和选择策略
  • 优化器的工作机制和配置方法
  • 完整训练循环的设计和实现
  • 学习率调度和正则化技术
  • 模型评估和验证策略
  • 训练过程的监控和调试

5.1 损失函数详解

5.1.1 损失函数基础

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import time
from collections import defaultdict

# 损失函数基础概念演示
def demonstrate_loss_functions():
    """演示不同损失函数的特性"""
    
    # 生成示例数据
    batch_size = 32
    num_classes = 5
    
    # 模拟预测和真实标签
    predictions = torch.randn(batch_size, num_classes)
    true_labels = torch.randint(0, num_classes, (batch_size,))
    
    # 回归任务的连续目标
    regression_pred = torch.randn(batch_size, 1)
    regression_target = torch.randn(batch_size, 1)
    
    print("=== 损失函数对比 ===")
    
    # 分类损失函数
    ce_loss = nn.CrossEntropyLoss()
    nll_loss = nn.NLLLoss()
    
    # 回归损失函数
    mse_loss = nn.MSELoss()
    mae_loss = nn.L1Loss()
    smooth_l1_loss = nn.SmoothL1Loss()
    
    # 计算分类损失
    ce_value = ce_loss(predictions, true_labels)
    nll_value = nll_loss(F.log_softmax(predictions, dim=1), true_labels)
    
    print(f"交叉熵损失: {ce_value.item():.4f}")
    print(f"负对数似然损失: {nll_value.item():.4f}")
    
    # 计算回归损失
    mse_value = mse_loss(regression_pred, regression_target)
    mae_value = mae_loss(regression_pred, regression_target)
    smooth_l1_value = smooth_l1_loss(regression_pred, regression_target)
    
    print(f"均方误差损失: {mse_value.item():.4f}")
    print(f"平均绝对误差损失: {mae_value.item():.4f}")
    print(f"平滑L1损失: {smooth_l1_value.item():.4f}")

demonstrate_loss_functions()
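
上面的输出里,交叉熵损失与负对数似然损失的数值应当完全一致:nn.CrossEntropyLoss 在内部等价于先做 log_softmax 再计算 NLLLoss。下面给出一个最小的验证草稿(数据形状沿用上例,仅作演示):

# 验证 CrossEntropyLoss 等价于 log_softmax + NLLLoss
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

ce = nn.CrossEntropyLoss()(logits, labels)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)
print(torch.allclose(ce, nll))  # 预期输出 True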

5.1.2 分类任务损失函数

# 分类损失函数详解
class ClassificationLosses:
    def __init__(self):
        self.losses = {
            'CrossEntropy': nn.CrossEntropyLoss(),
            'NLLLoss': nn.NLLLoss(),
            'BCELoss': nn.BCELoss(),
            'BCEWithLogitsLoss': nn.BCEWithLogitsLoss(),
            'MultiMarginLoss': nn.MultiMarginLoss(),
            'HingeEmbeddingLoss': nn.HingeEmbeddingLoss()
        }
    
    def compare_classification_losses(self):
        """比较不同分类损失函数"""
        batch_size = 100
        num_classes = 10
        
        # 生成测试数据
        logits = torch.randn(batch_size, num_classes)
        labels = torch.randint(0, num_classes, (batch_size,))
        
        # 二分类数据
        binary_logits = torch.randn(batch_size, 1)
        binary_labels = torch.randint(0, 2, (batch_size,)).float()
        
        print("=== 分类损失函数对比 ===")
        
        # 多分类损失
        ce_loss = self.losses['CrossEntropy'](logits, labels)
        nll_loss = self.losses['NLLLoss'](F.log_softmax(logits, dim=1), labels)
        
        print(f"CrossEntropy Loss: {ce_loss.item():.4f}")
        print(f"NLL Loss: {nll_loss.item():.4f}")
        
        # 二分类损失
        bce_loss = self.losses['BCELoss'](torch.sigmoid(binary_logits.squeeze()), binary_labels)
        bce_logits_loss = self.losses['BCEWithLogitsLoss'](binary_logits.squeeze(), binary_labels)
        
        print(f"BCE Loss: {bce_loss.item():.4f}")
        print(f"BCE with Logits Loss: {bce_logits_loss.item():.4f}")
        
        return {
            'ce_loss': ce_loss.item(),
            'nll_loss': nll_loss.item(),
            'bce_loss': bce_loss.item(),
            'bce_logits_loss': bce_logits_loss.item()
        }

# 自定义损失函数
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class LabelSmoothingLoss(nn.Module):
    """Label Smoothing Loss"""
    def __init__(self, num_classes, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
    
    def forward(self, pred, target):
        pred = F.log_softmax(pred, dim=1)
        true_dist = torch.zeros_like(pred)
        true_dist.fill_(self.smoothing / (self.num_classes - 1))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=1))

# 测试自定义损失函数
def test_custom_losses():
    """测试自定义损失函数"""
    batch_size = 64
    num_classes = 10
    
    # 生成测试数据
    logits = torch.randn(batch_size, num_classes)
    labels = torch.randint(0, num_classes, (batch_size,))
    
    # 标准交叉熵
    ce_loss = nn.CrossEntropyLoss()
    ce_value = ce_loss(logits, labels)
    
    # Focal Loss
    focal_loss = FocalLoss(alpha=1, gamma=2)
    focal_value = focal_loss(logits, labels)
    
    # Label Smoothing
    ls_loss = LabelSmoothingLoss(num_classes, smoothing=0.1)
    ls_value = ls_loss(logits, labels)
    
    print("=== 自定义损失函数测试 ===")
    print(f"Cross Entropy Loss: {ce_value.item():.4f}")
    print(f"Focal Loss: {focal_value.item():.4f}")
    print(f"Label Smoothing Loss: {ls_value.item():.4f}")

# 运行测试
classification_losses = ClassificationLosses()
loss_comparison = classification_losses.compare_classification_losses()
test_custom_losses()
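
顺带一提,较新版本的 PyTorch(约 1.10 起)在 nn.CrossEntropyLoss 中内置了 label_smoothing 参数,可以替代上面的手写实现;两者对平滑概率的分配方式略有差别(内置版本把平滑量均匀分给包括目标类在内的所有类别),数值不会完全相同。下面是内置用法的一个简单示例(假设你的 PyTorch 版本支持该参数):

# 内置标签平滑(假设 PyTorch >= 1.10)
logits = torch.randn(64, 10)
labels = torch.randint(0, 10, (64,))

builtin_ls = nn.CrossEntropyLoss(label_smoothing=0.1)
print(f"Built-in Label Smoothing Loss: {builtin_ls(logits, labels).item():.4f}")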

5.1.3 回归任务损失函数

# 回归损失函数详解
class RegressionLosses:
    def __init__(self):
        self.losses = {
            'MSE': nn.MSELoss(),
            'MAE': nn.L1Loss(),
            'SmoothL1': nn.SmoothL1Loss(),
            'Huber': nn.HuberLoss(delta=1.0)
        }
    
    def compare_regression_losses(self):
        """比较不同回归损失函数的特性"""
        # 生成测试数据
        predictions = torch.randn(100, 1)
        targets = torch.randn(100, 1)
        
        # 添加一些异常值
        outlier_indices = torch.randint(0, 100, (10,))
        targets[outlier_indices] += torch.randn(10, 1) * 5  # 异常值
        
        print("=== 回归损失函数对比 ===")
        
        results = {}
        for name, loss_fn in self.losses.items():
            loss_value = loss_fn(predictions, targets)
            results[name] = loss_value.item()
            print(f"{name} Loss: {loss_value.item():.4f}")
        
        return results
    
    def visualize_loss_functions(self):
        """可视化不同损失函数的特性"""
        # 生成误差范围
        errors = torch.linspace(-3, 3, 100)
        
        # 计算不同损失函数的值
        mse_values = errors ** 2
        mae_values = torch.abs(errors)
        smooth_l1_values = torch.where(torch.abs(errors) < 1, 
                                     0.5 * errors ** 2, 
                                     torch.abs(errors) - 0.5)
        # Huber 损失在 delta=1 时与 Smooth L1(beta=1)完全一致,因此不再单独计算和绘制
        
        # 绘制损失函数曲线
        plt.figure(figsize=(12, 8))
        
        plt.subplot(2, 2, 1)
        plt.plot(errors.numpy(), mse_values.numpy(), label='MSE', linewidth=2)
        plt.title('Mean Squared Error (MSE)')
        plt.xlabel('Error')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        
        plt.subplot(2, 2, 2)
        plt.plot(errors.numpy(), mae_values.numpy(), label='MAE', linewidth=2, color='orange')
        plt.title('Mean Absolute Error (MAE)')
        plt.xlabel('Error')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        
        plt.subplot(2, 2, 3)
        plt.plot(errors.numpy(), smooth_l1_values.numpy(), label='Smooth L1', linewidth=2, color='green')
        plt.title('Smooth L1 Loss')
        plt.xlabel('Error')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        
        plt.subplot(2, 2, 4)
        plt.plot(errors.numpy(), mse_values.numpy(), label='MSE', alpha=0.7)
        plt.plot(errors.numpy(), mae_values.numpy(), label='MAE', alpha=0.7)
        plt.plot(errors.numpy(), smooth_l1_values.numpy(), label='Smooth L1', alpha=0.7)
        plt.title('Loss Functions Comparison')
        plt.xlabel('Error')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        
        plt.tight_layout()
        plt.show()

# 自定义回归损失函数
class QuantileLoss(nn.Module):
    """Quantile Loss for quantile regression"""
    def __init__(self, quantile=0.5):
        super(QuantileLoss, self).__init__()
        self.quantile = quantile
    
    def forward(self, pred, target):
        errors = target - pred
        loss = torch.max((self.quantile - 1) * errors, self.quantile * errors)
        return torch.mean(loss)

class LogCoshLoss(nn.Module):
    """Log-Cosh Loss"""
    def __init__(self):
        super(LogCoshLoss, self).__init__()
    
    def forward(self, pred, target):
        errors = pred - target
        return torch.mean(torch.log(torch.cosh(errors)))
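
需要注意,torch.cosh 在误差绝对值较大时会溢出为 inf。如果需要更稳健的实现,可以利用恒等式 log(cosh(x)) = x + softplus(-2x) - log 2。下面是基于该恒等式的一个示意性实现(仅作参考,类名为本文临时起的):

import math

class StableLogCoshLoss(nn.Module):
    """数值更稳定的 Log-Cosh 损失:log(cosh(x)) = x + softplus(-2x) - log2"""
    def forward(self, pred, target):
        errors = pred - target
        return torch.mean(errors + F.softplus(-2.0 * errors) - math.log(2.0))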

# 测试回归损失函数
regression_losses = RegressionLosses()
regression_results = regression_losses.compare_regression_losses()
regression_losses.visualize_loss_functions()

# 测试自定义回归损失
def test_custom_regression_losses():
    """测试自定义回归损失函数"""
    predictions = torch.randn(100, 1)
    targets = torch.randn(100, 1)
    
    # Quantile Loss
    quantile_loss = QuantileLoss(quantile=0.9)
    quantile_value = quantile_loss(predictions, targets)
    
    # Log-Cosh Loss
    logcosh_loss = LogCoshLoss()
    logcosh_value = logcosh_loss(predictions, targets)
    
    print("=== 自定义回归损失函数测试 ===")
    print(f"Quantile Loss (0.9): {quantile_value.item():.4f}")
    print(f"Log-Cosh Loss: {logcosh_value.item():.4f}")

test_custom_regression_losses()

5.2 优化器详解

5.2.1 优化器基础

# 优化器基础概念
class OptimizerComparison:
    def __init__(self):
        self.optimizers = {}
    
    def create_optimizers(self, model_params):
        """创建不同类型的优化器"""
        # model.parameters() 返回生成器,只能遍历一次;先转成 list,确保每个优化器都拿到全部参数
        model_params = list(model_params)
        self.optimizers = {
            'SGD': optim.SGD(model_params, lr=0.01, momentum=0.9),
            'Adam': optim.Adam(model_params, lr=0.001, betas=(0.9, 0.999)),
            'AdamW': optim.AdamW(model_params, lr=0.001, weight_decay=0.01),
            'RMSprop': optim.RMSprop(model_params, lr=0.001, alpha=0.99),
            'Adagrad': optim.Adagrad(model_params, lr=0.01),
            'Adadelta': optim.Adadelta(model_params, lr=1.0, rho=0.9),
            'LBFGS': optim.LBFGS(model_params, lr=1, max_iter=20)
        }
        return self.optimizers
    
    def demonstrate_optimizer_behavior(self):
        """演示不同优化器的行为"""
        # 创建简单的二次函数进行优化
        def quadratic_function(x, y):
            return (x - 2) ** 2 + (y - 1) ** 2
        
        # 初始化参数
        x = torch.tensor([0.0], requires_grad=True)
        y = torch.tensor([0.0], requires_grad=True)
        
        # 创建优化器
        optimizer = optim.Adam([x, y], lr=0.1)
        
        # 优化过程
        losses = []
        positions = []
        
        for i in range(100):
            optimizer.zero_grad()
            loss = quadratic_function(x, y)
            loss.backward()
            optimizer.step()
            
            losses.append(loss.item())
            positions.append((x.item(), y.item()))
            
            if i % 20 == 0:
                print(f"Step {i}: Loss = {loss.item():.6f}, x = {x.item():.4f}, y = {y.item():.4f}")
        
        print(f"Final position: x = {x.item():.4f}, y = {y.item():.4f}")
        print(f"Target position: x = 2.0000, y = 1.0000")
        
        return losses, positions

# 测试优化器行为
optimizer_demo = OptimizerComparison()
losses, positions = optimizer_demo.demonstrate_optimizer_behavior()
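
上面的演示只用了 Adam。如果想直观比较不同优化器在同一问题上的收敛情况,可以对每种优化器各自重复上述优化过程。下面是一个简单的对比草稿(compare_optimizers_on_quadratic 为演示用的函数名,学习率等配置仅作示例):

# 在同一个二次函数上比较几种优化器的最终损失(示意)
def compare_optimizers_on_quadratic(num_steps=100):
    results = {}
    for name, make_opt in [
        ('SGD', lambda params: optim.SGD(params, lr=0.1, momentum=0.9)),
        ('Adam', lambda params: optim.Adam(params, lr=0.1)),
        ('RMSprop', lambda params: optim.RMSprop(params, lr=0.1)),
    ]:
        x = torch.tensor([0.0], requires_grad=True)
        y = torch.tensor([0.0], requires_grad=True)
        optimizer = make_opt([x, y])
        for _ in range(num_steps):
            optimizer.zero_grad()
            loss = (x - 2) ** 2 + (y - 1) ** 2
            loss.backward()
            optimizer.step()
        results[name] = loss.item()
        print(f"{name}: 最终损失 = {loss.item():.6f}")
    return results

compare_optimizers_on_quadratic()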

5.2.2 高级优化器配置

# 高级优化器配置
class AdvancedOptimizerConfig:
    def __init__(self):
        pass
    
    def configure_optimizer_groups(self, model):
        """配置不同参数组的优化器"""
        # 分离不同类型的参数
        conv_params = []
        bn_params = []
        linear_params = []
        
        for name, param in model.named_parameters():
            if 'conv' in name:
                conv_params.append(param)
            elif 'bn' in name or 'norm' in name:
                bn_params.append(param)
            elif 'linear' in name or 'fc' in name:
                linear_params.append(param)
        
        # 为不同参数组设置不同的学习率和权重衰减
        optimizer = optim.Adam([
            {'params': conv_params, 'lr': 0.001, 'weight_decay': 1e-4},
            {'params': bn_params, 'lr': 0.002, 'weight_decay': 0},
            {'params': linear_params, 'lr': 0.01, 'weight_decay': 1e-3}
        ])
        
        return optimizer
    
    def gradient_clipping_example(self, model, optimizer):
        """梯度裁剪示例"""
        # 计算损失和梯度
        # loss.backward()  # 假设已经计算了梯度
        
        # 梯度裁剪
        max_grad_norm = 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        # 或者使用梯度值裁剪(实践中通常二选一,这里注释掉以免重复裁剪)
        # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
        
        optimizer.step()
    
    def accumulate_gradients(self, model, optimizer, accumulation_steps=4):
        """梯度累积示例"""
        optimizer.zero_grad()
        
        for step in range(accumulation_steps):
            # 假设这里有数据加载和前向传播
            # loss = model(data)
            # loss = loss / accumulation_steps  # 平均化损失
            # loss.backward()
            pass
        
        # 在累积指定步数后更新参数
        optimizer.step()

# 创建示例模型用于测试
class SimpleModel(nn.Module):
    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(9216, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.25)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 测试高级优化器配置
model = SimpleModel()
advanced_config = AdvancedOptimizerConfig()
optimizer = advanced_config.configure_optimizer_groups(model)

print("=== 优化器参数组配置 ===")
for i, group in enumerate(optimizer.param_groups):
    print(f"参数组 {i}: lr={group['lr']}, weight_decay={group['weight_decay']}")

5.2.3 自定义优化器

# 自定义优化器实现
class CustomSGD(optim.Optimizer):
    """自定义SGD优化器实现"""
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay)
        super(CustomSGD, self).__init__(params, defaults)
    
    def step(self, closure=None):
        """执行单步优化"""
        loss = None
        if closure is not None:
            loss = closure()
        
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                d_p = p.grad.data
                
                # 添加权重衰减
                if weight_decay != 0:
                    d_p = d_p.add(p.data, alpha=weight_decay)
                
                # 添加动量
                if momentum != 0:
                    param_state = self.state[p]
                    if len(param_state) == 0:
                        param_state['momentum_buffer'] = torch.zeros_like(p.data)
                    
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    d_p = buf
                
                # 更新参数
                p.data.add_(d_p, alpha=-group['lr'])
        
        return loss
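
可以把 CustomSGD 与官方的 optim.SGD 在相同初始条件、相同数据下各走几步,检查两者的参数更新是否一致(dampening=0 时两者的更新公式相同)。下面是一个对照草稿(函数名为演示用途,结果用 torch.allclose 验证):

# 对照验证:CustomSGD 与 optim.SGD 在相同条件下应得到几乎一致的参数
def compare_custom_sgd_with_official(steps=5):
    torch.manual_seed(0)
    model_a = nn.Linear(4, 1)
    model_b = nn.Linear(4, 1)
    model_b.load_state_dict(model_a.state_dict())  # 保证初始参数一致
    
    opt_a = CustomSGD(model_a.parameters(), lr=0.05, momentum=0.9)
    opt_b = optim.SGD(model_b.parameters(), lr=0.05, momentum=0.9)
    
    for _ in range(steps):
        x = torch.randn(8, 4)
        y = torch.randn(8, 1)
        for m, opt in [(model_a, opt_a), (model_b, opt_b)]:
            opt.zero_grad()
            F.mse_loss(m(x), y).backward()
            opt.step()
    
    same = all(torch.allclose(pa, pb, atol=1e-6)
               for pa, pb in zip(model_a.parameters(), model_b.parameters()))
    print(f"CustomSGD 与 optim.SGD 参数是否一致: {same}")

compare_custom_sgd_with_official()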

class WarmupScheduler:
    """学习率预热调度器"""
    def __init__(self, optimizer, warmup_steps, base_lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.current_step = 0
    
    def step(self):
        """更新学习率"""
        self.current_step += 1
        
        if self.current_step <= self.warmup_steps:
            # 线性预热
            lr = self.base_lr * (self.current_step / self.warmup_steps)
        else:
            # 使用基础学习率
            lr = self.base_lr
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_lr(self):
        """获取当前学习率"""
        return self.optimizer.param_groups[0]['lr']

# 测试自定义优化器
def test_custom_optimizer():
    """测试自定义优化器"""
    # 创建简单模型
    model = nn.Linear(10, 1)
    
    # 创建自定义优化器
    custom_optimizer = CustomSGD(model.parameters(), lr=0.01, momentum=0.9)
    
    # 创建预热调度器
    warmup_scheduler = WarmupScheduler(custom_optimizer, warmup_steps=100, base_lr=0.01)
    
    # 模拟训练过程
    print("=== 自定义优化器测试 ===")
    for step in range(10):
        # 模拟前向传播和反向传播
        x = torch.randn(32, 10)
        y = torch.randn(32, 1)
        
        pred = model(x)
        loss = F.mse_loss(pred, y)
        
        custom_optimizer.zero_grad()
        loss.backward()
        custom_optimizer.step()
        warmup_scheduler.step()
        
        if step % 2 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}, LR = {warmup_scheduler.get_lr():.6f}")

test_custom_optimizer()

5.3 完整训练循环

5.3.1 基础训练循环

# 完整的训练循环实现
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device='cpu'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        
        # 训练历史记录
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
    
    def train_epoch(self):
        """训练一个epoch"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            # 前向传播
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            
            # 反向传播
            loss.backward()
            self.optimizer.step()
            
            # 统计
            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            # 打印进度
            if batch_idx % 100 == 0:
                print(f'Train Batch: {batch_idx}/{len(self.train_loader)} '
                      f'Loss: {loss.item():.6f}')
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        
        return epoch_loss, epoch_acc
    
    def validate_epoch(self):
        """验证一个epoch"""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                
                output = self.model(data)
                loss = self.criterion(output, target)
                
                running_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        
        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total
        
        return epoch_loss, epoch_acc
    
    def train(self, num_epochs):
        """完整训练过程"""
        print("开始训练...")
        
        for epoch in range(num_epochs):
            start_time = time.time()
            
            # 训练
            train_loss, train_acc = self.train_epoch()
            
            # 验证
            val_loss, val_acc = self.validate_epoch()
            
            # 记录历史
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            
            epoch_time = time.time() - start_time
            
            print(f'Epoch [{epoch+1}/{num_epochs}] '
                  f'Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% '
                  f'Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}% '
                  f'Time: {epoch_time:.2f}s')
        
        print("训练完成!")
    
    def plot_training_history(self):
        """绘制训练历史"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
        
        # 损失曲线
        ax1.plot(self.train_losses, label='Train Loss')
        ax1.plot(self.val_losses, label='Val Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)
        
        # 准确率曲线
        ax2.plot(self.train_accuracies, label='Train Acc')
        ax2.plot(self.val_accuracies, label='Val Acc')
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.show()

# 创建示例数据和模型进行训练测试
def create_sample_data():
    """创建示例数据"""
    # 生成模拟的MNIST风格数据
    train_data = torch.randn(1000, 1, 28, 28)
    train_labels = torch.randint(0, 10, (1000,))
    
    val_data = torch.randn(200, 1, 28, 28)
    val_labels = torch.randint(0, 10, (200,))
    
    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    
    return train_loader, val_loader

# 测试训练循环
print("=== 训练循环测试 ===")
train_loader, val_loader = create_sample_data()

model = SimpleModel()
criterion = nn.NLLLoss()  # SimpleModel 的 forward 已输出 log_softmax,因此搭配 NLLLoss,而不是再叠加 CrossEntropyLoss
optimizer = optim.Adam(model.parameters(), lr=0.001)

trainer = Trainer(model, train_loader, val_loader, criterion, optimizer)
trainer.train(num_epochs=5)
trainer.plot_training_history()
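
上面的训练默认在 CPU 上进行。Trainer 的构造函数支持 device 参数,如果有可用的 GPU,可以按下面的方式切换(示意,gpu_model / gpu_trainer 为演示用的变量名):

# 如果有可用的 GPU,把训练放到 GPU 上
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpu_model = SimpleModel()
gpu_trainer = Trainer(gpu_model, train_loader, val_loader,
                      nn.NLLLoss(),
                      optim.Adam(gpu_model.parameters(), lr=0.001),
                      device=device)  # Trainer 内部会把模型移动到 device
# gpu_trainer.train(num_epochs=5)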

5.3.2 高级训练功能

# 高级训练功能
class AdvancedTrainer(Trainer):
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, 
                 device='cpu', scheduler=None, early_stopping_patience=None):
        super().__init__(model, train_loader, val_loader, criterion, optimizer, device)
        self.scheduler = scheduler
        self.early_stopping_patience = early_stopping_patience
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.best_model_state = None
    
    def save_checkpoint(self, epoch, filepath):
        """保存检查点"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accuracies': self.train_accuracies,
            'val_accuracies': self.val_accuracies
        }
        
        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        
        torch.save(checkpoint, filepath)
        print(f"检查点已保存到 {filepath}")
    
    def load_checkpoint(self, filepath):
        """加载检查点"""
        checkpoint = torch.load(filepath, map_location=self.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        if self.scheduler and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.train_accuracies = checkpoint['train_accuracies']
        self.val_accuracies = checkpoint['val_accuracies']
        
        return checkpoint['epoch']
    
    def early_stopping_check(self, val_loss):
        """早停检查"""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.patience_counter = 0
            # 深拷贝一份权重;state_dict().copy() 只是浅拷贝,后续训练会原地覆盖最佳状态
            self.best_model_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
            return False
        else:
            self.patience_counter += 1
            if self.patience_counter >= self.early_stopping_patience:
                print(f"早停触发!在第 {self.patience_counter} 个epoch后验证损失未改善")
                # 恢复最佳模型
                self.model.load_state_dict(self.best_model_state)
                return True
        return False
    
    def train_with_mixed_precision(self, num_epochs):
        """混合精度训练"""
        scaler = torch.cuda.amp.GradScaler()
        
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                
                self.optimizer.zero_grad()
                
                # 使用自动混合精度
                with torch.cuda.amp.autocast():
                    output = self.model(data)
                    loss = self.criterion(output, target)
                
                # 缩放损失并反向传播
                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()
                
                running_loss += loss.item()
            
            # 验证和其他逻辑...
            print(f"Epoch {epoch+1}: Loss = {running_loss/len(self.train_loader):.4f}")
    
    def train_advanced(self, num_epochs, save_every=5, checkpoint_dir="./checkpoints"):
        """高级训练循环"""
        import os
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        print("开始高级训练...")
        
        for epoch in range(num_epochs):
            start_time = time.time()
            
            # 训练
            train_loss, train_acc = self.train_epoch()
            
            # 验证
            val_loss, val_acc = self.validate_epoch()
            
            # 学习率调度
            if self.scheduler:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()
            
            # 记录历史
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            
            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']
            
            print(f'Epoch [{epoch+1}/{num_epochs}] '
                  f'Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% '
                  f'Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}% '
                  f'LR: {current_lr:.6f} Time: {epoch_time:.2f}s')
            
            # 早停检查
            if self.early_stopping_patience:
                if self.early_stopping_check(val_loss):
                    break
            
            # 保存检查点
            if (epoch + 1) % save_every == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
                self.save_checkpoint(epoch + 1, checkpoint_path)
        
        print("高级训练完成!")

# 测试高级训练功能
def test_advanced_training():
    """测试高级训练功能"""
    train_loader, val_loader = create_sample_data()
    
    model = SimpleModel()
    criterion = nn.NLLLoss()  # SimpleModel 输出 log_softmax,搭配 NLLLoss
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # 创建学习率调度器
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    
    # 创建高级训练器
    advanced_trainer = AdvancedTrainer(
        model, train_loader, val_loader, criterion, optimizer,
        scheduler=scheduler, early_stopping_patience=5
    )
    
    # 开始训练
    advanced_trainer.train_advanced(num_epochs=10, save_every=3)
    
    return advanced_trainer

print("=== 高级训练功能测试 ===")
advanced_trainer = test_advanced_training()

5.4 学习率调度

5.4.1 内置学习率调度器

# 学习率调度器详解
class LearningRateSchedulers:
    def __init__(self):
        self.schedulers = {}
    
    def create_schedulers(self, optimizer):
        """创建不同类型的学习率调度器"""
        self.schedulers = {
            'StepLR': optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1),
            'MultiStepLR': optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1),
            'ExponentialLR': optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95),
            'CosineAnnealingLR': optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50),
            'ReduceLROnPlateau': optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10),
            'CyclicLR': optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.001, max_lr=0.01),
            'OneCycleLR': optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=100, epochs=10)
        }
        return self.schedulers
    
    def visualize_schedulers(self, num_epochs=100):
        """可视化不同调度器的学习率变化"""
        # 创建模型
        model = nn.Linear(10, 1)
        
        # 每种调度器使用各自独立的优化器单独构造,
        # 避免 CyclicLR / OneCycleLR 在构造时覆盖其他调度器的初始学习率
        scheduler_factories = {
            'StepLR': lambda opt: optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1),
            'MultiStepLR': lambda opt: optim.lr_scheduler.MultiStepLR(opt, milestones=[30, 80], gamma=0.1),
            'ExponentialLR': lambda opt: optim.lr_scheduler.ExponentialLR(opt, gamma=0.95),
            'CosineAnnealingLR': lambda opt: optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50),
            'ReduceLROnPlateau': lambda opt: optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=10),
            'CyclicLR': lambda opt: optim.lr_scheduler.CyclicLR(opt, base_lr=0.001, max_lr=0.01, step_size_up=20),
            'OneCycleLR': lambda opt: optim.lr_scheduler.OneCycleLR(opt, max_lr=0.01, steps_per_epoch=1, epochs=num_epochs)
        }
        
        # 记录学习率变化
        lr_histories = {name: [] for name in scheduler_factories}
        
        for name, make_scheduler in scheduler_factories.items():
            optimizer = optim.SGD(model.parameters(), lr=0.01)
            scheduler = make_scheduler(optimizer)
            
            for epoch in range(num_epochs):
                lr_histories[name].append(optimizer.param_groups[0]['lr'])
                
                if name == 'ReduceLROnPlateau':
                    # 模拟逐渐下降但带噪声的验证损失
                    val_loss = 1.0 - epoch * 0.01 + 0.1 * np.sin(epoch * 0.1)
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
        
        # 绘制学习率变化
        plt.figure(figsize=(15, 10))
        
        for i, (name, lr_history) in enumerate(lr_histories.items()):
            plt.subplot(3, 3, i + 1)
            plt.plot(lr_history)
            plt.title(f'{name}')
            plt.xlabel('Epoch')
            plt.ylabel('Learning Rate')
            plt.grid(True)
        
        plt.tight_layout()
        plt.show()

# 自定义学习率调度器
class WarmupCosineScheduler:
    """预热+余弦退火调度器"""
    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, max_lr):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.current_epoch = 0
    
    def step(self):
        """更新学习率"""
        if self.current_epoch < self.warmup_epochs:
            # 预热阶段:线性增长
            lr = self.base_lr + (self.max_lr - self.base_lr) * (self.current_epoch / self.warmup_epochs)
        else:
            # 余弦退火阶段
            progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
            lr = self.base_lr + (self.max_lr - self.base_lr) * 0.5 * (1 + np.cos(np.pi * progress))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        self.current_epoch += 1
    
    def get_lr(self):
        """获取当前学习率"""
        return self.optimizer.param_groups[0]['lr']

class LinearWarmupScheduler:
    """线性预热调度器"""
    def __init__(self, optimizer, warmup_steps, base_lr):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.current_step = 0
    
    def step(self):
        """更新学习率"""
        self.current_step += 1
        
        if self.current_step <= self.warmup_steps:
            lr = self.base_lr * (self.current_step / self.warmup_steps)
        else:
            lr = self.base_lr
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
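
自定义的"预热 + 余弦退火"也可以用内置调度器组合出来:较新版本的 PyTorch(约 1.10 起)提供了 LinearLR 和 SequentialLR,可以把线性预热与 CosineAnnealingLR 串联。下面是一个等价思路的草稿(builtin_warmup_cosine 为演示用的函数名,参数仅作示例):

# 用内置调度器组合出"预热 + 余弦退火"(假设 PyTorch >= 1.10)
def builtin_warmup_cosine(optimizer, warmup_epochs=10, total_epochs=100):
    warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
    cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
    # 在 warmup_epochs 处从预热切换到余弦退火;每个 epoch 结束后调用 scheduler.step()
    return optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])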

# 测试学习率调度器
def test_lr_schedulers():
    """测试学习率调度器"""
    print("=== 学习率调度器测试 ===")
    
    lr_schedulers = LearningRateSchedulers()
    lr_schedulers.visualize_schedulers(num_epochs=100)
    
    # 测试自定义调度器
    model = nn.Linear(10, 1)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    custom_scheduler = WarmupCosineScheduler(
        optimizer, warmup_epochs=10, total_epochs=100, 
        base_lr=0.0001, max_lr=0.01
    )
    
    lr_history = []
    for epoch in range(100):
        lr_history.append(custom_scheduler.get_lr())
        custom_scheduler.step()
    
    plt.figure(figsize=(10, 6))
    plt.plot(lr_history)
    plt.title('Custom Warmup + Cosine Annealing Scheduler')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.show()

test_lr_schedulers()

本章总结

在本章中,我们深入学习了PyTorch中的模型训练与优化技术:

核心概念

  • 损失函数: 分类和回归任务的各种损失函数及其特性
  • 优化器: SGD、Adam、AdamW等优化算法的原理和配置
  • 训练循环: 完整的训练、验证和测试流程设计
  • 学习率调度: 各种学习率调度策略和自定义实现

重要技术

  • 自定义损失函数: Focal Loss、Label Smoothing等高级损失函数
  • 梯度处理: 梯度裁剪、梯度累积等技术
  • 混合精度训练: 提高训练效率的现代技术
  • 早停和检查点: 防止过拟合和保存训练状态

实践技能

  • 完整训练管道: 从数据加载到模型保存的完整流程
  • 性能监控: 训练过程的可视化和分析
  • 超参数调优: 学习率、批次大小等关键参数的优化
  • 调试技巧: 训练过程中常见问题的诊断和解决

通过本章学习,你现在能够设计和实现高效、稳定的深度学习训练流程,为构建实际应用奠定坚实基础。

下一章预告

在下一章《卷积神经网络》中,我们将学习:

  • CNN的基本原理和组件
  • 经典CNN架构(LeNet、AlexNet、VGG、ResNet等)
  • 图像分类和目标检测应用
  • CNN的可视化和解释技术


第5章完成! 🎉

你已经掌握了深度学习训练的核心技术,这些技能将帮助你训练出高性能的深度学习模型!

下一步: 进入第6章《卷积神经网络》,学习计算机视觉领域的强大工具!