本章概述
数据是深度学习的基础,高效的数据处理和加载对模型训练至关重要。本章将详细介绍PyTorch中的数据处理机制,包括Dataset和DataLoader的使用、数据预处理技术、批量处理策略等。
学习目标
通过本章学习,你将掌握:
- Dataset和DataLoader的核心概念和使用方法
- 数据预处理和增强技术
- 自定义数据集的实现
- 批量处理和并行加载策略
- 数据管道的优化技巧
- 常见数据格式的处理方法
4.1 数据处理基础
4.1.1 PyTorch数据处理架构
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
from pathlib import Path
# PyTorch数据处理流程示意
print("PyTorch数据处理流程:")
print("原始数据 -> Dataset -> DataLoader -> 模型训练")
print(" ↓")
print(" 数据预处理/增强")
4.1.2 数据类型和格式
# 常见数据类型示例
def demonstrate_data_types():
# 图像数据
image_tensor = torch.randn(3, 224, 224) # RGB图像
print(f"图像张量形状: {image_tensor.shape}")
# 文本数据(词汇索引)
text_indices = torch.tensor([1, 15, 23, 8, 45, 2]) # 词汇ID序列
print(f"文本索引形状: {text_indices.shape}")
# 表格数据
tabular_data = torch.randn(100, 10) # 100个样本,10个特征
print(f"表格数据形状: {tabular_data.shape}")
# 时间序列数据
time_series = torch.randn(50, 1) # 50个时间步,1个特征
print(f"时间序列形状: {time_series.shape}")
# 标签数据
labels = torch.randint(0, 10, (100,)) # 100个样本的分类标签
print(f"标签形状: {labels.shape}")
demonstrate_data_types()
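实际项目中,数据往往来自NumPy数组或pandas表格。下面的补充示例(非原文内容)演示如何把这些常见来源转换为张量,并注意dtype的处理:
# 补充示例:NumPy数组与pandas DataFrame转换为张量
np_features = np.random.rand(5, 3).astype(np.float32)
tensor_from_numpy = torch.from_numpy(np_features)              # 与NumPy数组共享内存
df = pd.DataFrame({"f1": [1.0, 2.0, 3.0], "f2": [4.0, 5.0, 6.0]})
tensor_from_df = torch.tensor(df.values, dtype=torch.float32)  # 从DataFrame取值并指定dtype
print(f"来自NumPy的张量: {tensor_from_numpy.shape}, dtype={tensor_from_numpy.dtype}")
print(f"来自DataFrame的张量: {tensor_from_df.shape}, dtype={tensor_from_df.dtype}")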
4.2 Dataset类详解
4.2.1 Dataset基础概念
# 最简单的Dataset实现
class SimpleDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 创建简单数据集
data = torch.randn(1000, 10) # 1000个样本,每个10维
labels = torch.randint(0, 3, (1000,)) # 3分类问题
simple_dataset = SimpleDataset(data, labels)
print(f"数据集大小: {len(simple_dataset)}")
print(f"第一个样本: {simple_dataset[0]}")
print(f"样本形状: {simple_dataset[0][0].shape}")
print(f"标签: {simple_dataset[0][1]}")
4.2.2 自定义图像数据集
# 图像数据集实现
class CustomImageDataset(Dataset):
def __init__(self, image_dir, annotations_file, transform=None):
self.image_dir = Path(image_dir)
self.annotations = pd.read_csv(annotations_file) if annotations_file else None
self.transform = transform
# 如果没有标注文件,则扫描目录
if self.annotations is None:
self.image_paths = list(self.image_dir.glob("*.jpg")) + \
list(self.image_dir.glob("*.png"))
else:
self.image_paths = [self.image_dir / fname for fname in self.annotations['filename']]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图像
image_path = self.image_paths[idx]
try:
image = Image.open(image_path).convert('RGB')
except Exception as e:
print(f"Error loading image {image_path}: {e}")
# 返回一个默认图像
image = Image.new('RGB', (224, 224), color='black')
# 获取标签
if self.annotations is not None:
label = self.annotations.iloc[idx]['label']
else:
# 没有标注文件时使用占位标签;实际应用中可从文件名或目录结构推断
label = 0
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
# 创建示例数据集(模拟)
def create_sample_images():
"""创建一些示例图像用于演示"""
sample_dir = Path("sample_images")
sample_dir.mkdir(exist_ok=True)
# 创建一些随机图像
for i in range(10):
# 创建随机RGB图像
img_array = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
img = Image.fromarray(img_array)
img.save(sample_dir / f"sample_{i}.jpg")
# 创建标注文件
annotations = pd.DataFrame({
'filename': [f"sample_{i}.jpg" for i in range(10)],
'label': np.random.randint(0, 3, 10)
})
annotations.to_csv(sample_dir / "annotations.csv", index=False)
return sample_dir
# 创建示例并测试
try:
sample_dir = create_sample_images()
# 定义图像变换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 创建数据集
image_dataset = CustomImageDataset(
image_dir=sample_dir,
annotations_file=sample_dir / "annotations.csv",
transform=transform
)
print(f"图像数据集大小: {len(image_dataset)}")
sample_image, sample_label = image_dataset[0]
print(f"图像形状: {sample_image.shape}")
print(f"标签: {sample_label}")
except Exception as e:
print(f"创建示例图像时出错: {e}")
print("将使用模拟数据继续演示")
4.2.3 文本数据集
# 文本数据集实现
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer=None, max_length=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
# 简单的词汇表构建(如果没有提供tokenizer)
if self.tokenizer is None:
self.vocab = self._build_vocab()
self.tokenizer = self._simple_tokenize
def _build_vocab(self):
"""构建简单词汇表"""
vocab = {'<PAD>': 0, '<UNK>': 1}
for text in self.texts:
for word in text.lower().split():
if word not in vocab:
vocab[word] = len(vocab)
return vocab
def _simple_tokenize(self, text):
"""简单分词和编码"""
words = text.lower().split()
indices = [self.vocab.get(word, self.vocab['<UNK>']) for word in words]
# 截断或填充到固定长度
if len(indices) > self.max_length:
indices = indices[:self.max_length]
else:
indices.extend([self.vocab['<PAD>']] * (self.max_length - len(indices)))
return torch.tensor(indices, dtype=torch.long)
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
# 文本编码
encoded_text = self.tokenizer(text)
return encoded_text, torch.tensor(label, dtype=torch.long)
# 创建文本数据集示例
sample_texts = [
"This is a positive example",
"This is a negative example",
"Another positive case here",
"Yet another negative case",
"Neutral statement without sentiment"
]
sample_labels = [1, 0, 1, 0, 2] # 正面、负面、中性
text_dataset = TextDataset(sample_texts, sample_labels, max_length=10)
print(f"文本数据集大小: {len(text_dataset)}")
print(f"词汇表大小: {len(text_dataset.vocab)}")
sample_text, sample_label = text_dataset[0]
print(f"编码后的文本: {sample_text}")
print(f"文本标签: {sample_label}")
4.2.4 时间序列数据集
# 时间序列数据集
class TimeSeriesDataset(Dataset):
def __init__(self, data, sequence_length, prediction_length=1):
self.data = torch.FloatTensor(data)
self.sequence_length = sequence_length
self.prediction_length = prediction_length
def __len__(self):
return len(self.data) - self.sequence_length - self.prediction_length + 1
def __getitem__(self, idx):
# 输入序列
x = self.data[idx:idx + self.sequence_length]
# 目标序列
y = self.data[idx + self.sequence_length:idx + self.sequence_length + self.prediction_length]
return x, y
# 生成示例时间序列数据
def generate_sine_wave(length=1000, frequency=0.1, noise=0.1):
t = np.linspace(0, length * frequency, length)
signal = np.sin(2 * np.pi * t) + noise * np.random.randn(length)
return signal
# 创建时间序列数据集
sine_data = generate_sine_wave(1000)
ts_dataset = TimeSeriesDataset(sine_data, sequence_length=50, prediction_length=10)
print(f"时间序列数据集大小: {len(ts_dataset)}")
sample_x, sample_y = ts_dataset[0]
print(f"输入序列形状: {sample_x.shape}")
print(f"目标序列形状: {sample_y.shape}")
# 可视化时间序列数据
plt.figure(figsize=(12, 4))
plt.plot(sine_data[:200])
plt.title("示例时间序列数据")
plt.xlabel("时间步")
plt.ylabel("值")
plt.grid(True)
plt.show()
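时间序列一般不宜随机划分训练/验证集,否则验证集中会混入“未来”的信息。下面的补充示例(非原文内容)按时间先后顺序切分:
# 补充示例:按时间顺序划分训练/验证集,避免信息泄漏
from torch.utils.data import Subset

split_point = int(len(ts_dataset) * 0.8)
ts_train = Subset(ts_dataset, range(0, split_point))
ts_val = Subset(ts_dataset, range(split_point, len(ts_dataset)))
print(f"时间序列训练集: {len(ts_train)}, 验证集: {len(ts_val)}")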
4.3 DataLoader详解
4.3.1 DataLoader基础使用
# DataLoader基础配置
def demonstrate_dataloader():
# 使用之前创建的简单数据集
data = torch.randn(1000, 10)
labels = torch.randint(0, 3, (1000,))
dataset = SimpleDataset(data, labels)
# 创建DataLoader
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=0, # Windows及Jupyter等交互式环境建议设为0
drop_last=False
)
print(f"数据集大小: {len(dataset)}")
print(f"批次数量: {len(dataloader)}")
print(f"批次大小: {dataloader.batch_size}")
# 遍历数据
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
print(f"批次 {batch_idx}: 数据形状 {batch_data.shape}, 标签形状 {batch_labels.shape}")
if batch_idx >= 2: # 只显示前3个批次
break
demonstrate_dataloader()
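当类别分布不均衡时,可以用WeightedRandomSampler代替shuffle=True,让稀有类别被更频繁地采样。以下是补充示例(非原文内容),其中数据为随机生成:
# 补充示例:使用WeightedRandomSampler进行重采样(使用sampler时不能再设置shuffle=True)
from torch.utils.data import WeightedRandomSampler

imbalanced_labels = torch.randint(0, 3, (1000,))
class_counts = torch.bincount(imbalanced_labels).float()
sample_weights = 1.0 / class_counts[imbalanced_labels]        # 样本权重与其类别频率成反比
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
balanced_loader = DataLoader(SimpleDataset(torch.randn(1000, 10), imbalanced_labels),
                             batch_size=32, sampler=sampler)
print(f"重采样后的批次数量: {len(balanced_loader)}")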
4.3.2 自定义collate函数
# 自定义批次整理函数
def custom_collate_fn(batch):
"""
自定义collate函数,用于处理变长序列
"""
# 分离数据和标签
data, labels = zip(*batch)
# 对于变长序列,需要填充到相同长度
max_length = max(len(seq) for seq in data)
# 填充序列
padded_data = []
for seq in data:
if len(seq) < max_length:
# 填充0到最大长度
padded_seq = torch.cat([seq, torch.zeros(max_length - len(seq))])
else:
padded_seq = seq
padded_data.append(padded_seq)
# 堆叠成批次
batch_data = torch.stack(padded_data)
batch_labels = torch.tensor(labels)
return batch_data, batch_labels
# 创建变长序列数据集
class VariableLengthDataset(Dataset):
def __init__(self, num_samples=100):
self.data = []
self.labels = []
for i in range(num_samples):
# 创建随机长度的序列
length = torch.randint(5, 20, (1,)).item()
sequence = torch.randn(length)
label = torch.randint(0, 2, (1,)).item()
self.data.append(sequence)
self.labels.append(label)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 测试自定义collate函数
var_dataset = VariableLengthDataset(50)
var_dataloader = DataLoader(
var_dataset,
batch_size=8,
shuffle=True,
collate_fn=custom_collate_fn
)
print("变长序列数据集测试:")
for batch_data, batch_labels in var_dataloader:
print(f"批次数据形状: {batch_data.shape}")
print(f"批次标签形状: {batch_labels.shape}")
break
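上面的填充逻辑也可以用PyTorch自带的pad_sequence实现,代码更简洁。以下为补充示例(非原文内容):
# 补充示例:用pad_sequence实现等价的变长序列填充
from torch.nn.utils.rnn import pad_sequence

def pad_collate_fn(batch):
    data, labels = zip(*batch)
    padded = pad_sequence(data, batch_first=True, padding_value=0.0)  # 填充到批内最大长度
    return padded, torch.tensor(labels)

alt_loader = DataLoader(var_dataset, batch_size=8, shuffle=True, collate_fn=pad_collate_fn)
alt_data, alt_labels = next(iter(alt_loader))
print(f"pad_sequence批次形状: {alt_data.shape}")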
4.3.3 数据加载性能优化
# 性能测试函数
import time
def benchmark_dataloader(dataset, configs):
"""测试不同DataLoader配置的性能"""
results = {}
for name, config in configs.items():
dataloader = DataLoader(dataset, **config)
start_time = time.time()
for batch_idx, (data, labels) in enumerate(dataloader):
# 模拟一些计算
_ = data.mean()
if batch_idx >= 49: # 只测试前50个批次(索引0~49)
break
end_time = time.time()
results[name] = end_time - start_time
print(f"{name}: {results[name]:.4f}秒")
return results
# 创建大一点的数据集用于性能测试
large_data = torch.randn(10000, 100)
large_labels = torch.randint(0, 10, (10000,))
large_dataset = SimpleDataset(large_data, large_labels)
# 不同的DataLoader配置
configs = {
"基础配置": {"batch_size": 64, "shuffle": False, "num_workers": 0},
"启用shuffle": {"batch_size": 64, "shuffle": True, "num_workers": 0},
"大批次": {"batch_size": 256, "shuffle": True, "num_workers": 0},
"小批次": {"batch_size": 16, "shuffle": True, "num_workers": 0},
}
print("DataLoader性能测试:")
benchmark_results = benchmark_dataloader(large_dataset, configs)
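在支持多进程数据加载的环境中,还可以比较num_workers和pin_memory等配置的影响。以下配置仅为补充示例(非原文内容),是否实际运行请视环境而定(Windows/交互式环境需要放在 if __name__ == "__main__" 保护下):
# 补充示例:多进程加载相关配置(实际调用默认注释掉,按需启用)
parallel_configs = {
    "2个worker": {"batch_size": 64, "shuffle": True, "num_workers": 2},
    "2个worker+pin_memory": {"batch_size": 64, "shuffle": True, "num_workers": 2, "pin_memory": True},
}
# benchmark_dataloader(large_dataset, parallel_configs)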
4.4 数据预处理和增强
4.4.1 图像预处理
# 图像预处理管道
class ImagePreprocessing:
def __init__(self):
# 基础预处理
self.basic_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 训练时的数据增强
self.train_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 验证时的预处理
self.val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def get_transform(self, mode='train'):
if mode == 'train':
return self.train_transform
elif mode == 'val':
return self.val_transform
else:
return self.basic_transform
# 自定义图像变换
class CustomImageTransform:
def __init__(self, noise_factor=0.1):
self.noise_factor = noise_factor
def __call__(self, tensor):
# 添加高斯噪声
noise = torch.randn_like(tensor) * self.noise_factor
return tensor + noise
# 组合自定义变换
custom_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
CustomImageTransform(noise_factor=0.05),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
print("图像预处理管道创建完成")
4.4.2 文本预处理
# 文本预处理工具
class TextPreprocessor:
def __init__(self, vocab_size=10000, max_length=128):
self.vocab_size = vocab_size
self.max_length = max_length
self.vocab = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
self.word_count = {}
def build_vocab(self, texts):
"""构建词汇表"""
# 统计词频
for text in texts:
for word in self.tokenize(text):
self.word_count[word] = self.word_count.get(word, 0) + 1
# 按频率排序,取前vocab_size个词
sorted_words = sorted(self.word_count.items(), key=lambda x: x[1], reverse=True)
for word, count in sorted_words[:self.vocab_size - len(self.vocab)]:
self.vocab[word] = len(self.vocab)
def tokenize(self, text):
"""简单分词"""
import re
# 简单的分词:小写化,去除标点,按空格分割
text = re.sub(r'[^\w\s]', '', text.lower())
return text.split()
def encode(self, text):
"""编码文本为数字序列"""
tokens = self.tokenize(text)
indices = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
# 截断或填充
if len(indices) > self.max_length:
indices = indices[:self.max_length]
else:
indices.extend([self.vocab['<PAD>']] * (self.max_length - len(indices)))
return torch.tensor(indices, dtype=torch.long)
def decode(self, indices):
"""解码数字序列为文本"""
# 创建反向词汇表
reverse_vocab = {v: k for k, v in self.vocab.items()}
tokens = [reverse_vocab.get(idx.item(), '<UNK>') for idx in indices]
# 移除填充符号
tokens = [token for token in tokens if token != '<PAD>']
return ' '.join(tokens)
# 测试文本预处理
sample_texts = [
"Hello world, this is a test sentence.",
"Another example with different words.",
"Machine learning is fascinating!",
"PyTorch makes deep learning accessible."
]
text_processor = TextPreprocessor(vocab_size=50, max_length=15)
text_processor.build_vocab(sample_texts)
print(f"词汇表大小: {len(text_processor.vocab)}")
print("词汇表示例:", list(text_processor.vocab.items())[:10])
# 编码和解码示例
test_text = "Hello world, this is a test."
encoded = text_processor.encode(test_text)
decoded = text_processor.decode(encoded)
print(f"原文: {test_text}")
print(f"编码: {encoded}")
print(f"解码: {decoded}")
4.4.3 数值数据预处理
# 数值数据预处理
class NumericalPreprocessor:
def __init__(self):
self.scalers = {}
self.fitted = False
def fit(self, data):
"""拟合预处理参数"""
self.mean = data.mean(dim=0)
self.std = data.std(dim=0)
self.min = data.min(dim=0)[0]
self.max = data.max(dim=0)[0]
self.fitted = True
def standardize(self, data):
"""标准化(Z-score)"""
if not self.fitted:
raise ValueError("必须先调用fit方法")
return (data - self.mean) / (self.std + 1e-8)
def normalize(self, data):
"""归一化到[0,1]"""
if not self.fitted:
raise ValueError("必须先调用fit方法")
return (data - self.min) / (self.max - self.min + 1e-8)
def robust_scale(self, data):
"""鲁棒缩放(使用中位数和四分位距)"""
median = data.median(dim=0)[0]
q75 = data.quantile(0.75, dim=0)
q25 = data.quantile(0.25, dim=0)
iqr = q75 - q25
return (data - median) / (iqr + 1e-8)
# 测试数值预处理
# 生成示例数据(不同特征有不同的尺度)
feature1 = torch.randn(1000) * 100 + 500 # 大尺度特征
feature2 = torch.randn(1000) * 0.1 + 0.5 # 小尺度特征
feature3 = torch.randint(0, 10, (1000,)).float() # 离散特征
numerical_data = torch.stack([feature1, feature2, feature3], dim=1)
print("原始数据统计:")
print(f"均值: {numerical_data.mean(dim=0)}")
print(f"标准差: {numerical_data.std(dim=0)}")
print(f"最小值: {numerical_data.min(dim=0)[0]}")
print(f"最大值: {numerical_data.max(dim=0)[0]}")
# 应用预处理
preprocessor = NumericalPreprocessor()
preprocessor.fit(numerical_data)
standardized_data = preprocessor.standardize(numerical_data)
normalized_data = preprocessor.normalize(numerical_data)
print("\n标准化后数据统计:")
print(f"均值: {standardized_data.mean(dim=0)}")
print(f"标准差: {standardized_data.std(dim=0)}")
print("\n归一化后数据统计:")
print(f"最小值: {normalized_data.min(dim=0)[0]}")
print(f"最大值: {normalized_data.max(dim=0)[0]}")
4.5 数据增强技术
4.5.1 图像数据增强
# 高级图像数据增强
class AdvancedImageAugmentation:
def __init__(self):
# 几何变换
self.geometric_transforms = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
transforms.RandomRotation(degrees=30),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1))
])
# 颜色变换
self.color_transforms = transforms.Compose([
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
transforms.RandomGrayscale(p=0.1),
transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.2)
])
# 噪声和模糊
self.noise_transforms = transforms.Compose([
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
])
def get_strong_augmentation(self):
"""强数据增强"""
return transforms.Compose([
self.geometric_transforms,
self.color_transforms,
self.noise_transforms,
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def get_weak_augmentation(self):
"""弱数据增强"""
return transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomCrop(224, padding=4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 自定义图像增强
class CutOut:
"""随机遮挡增强"""
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = torch.ones((h, w), dtype=torch.float32)
for _ in range(self.n_holes):
y = torch.randint(0, h, (1,)).item()
x = torch.randint(0, w, (1,)).item()
y1 = max(0, y - self.length // 2)
y2 = min(h, y + self.length // 2)
x1 = max(0, x - self.length // 2)
x2 = min(w, x + self.length // 2)
mask[y1:y2, x1:x2] = 0
mask = mask.expand_as(img)
img = img * mask
return img
class MixUp:
"""MixUp数据增强"""
def __init__(self, alpha=1.0):
self.alpha = alpha
def __call__(self, batch_data, batch_labels):
if self.alpha > 0:
lam = np.random.beta(self.alpha, self.alpha)
else:
lam = 1
batch_size = batch_data.size(0)
index = torch.randperm(batch_size)
mixed_data = lam * batch_data + (1 - lam) * batch_data[index, :]
y_a, y_b = batch_labels, batch_labels[index]
return mixed_data, y_a, y_b, lam
print("高级数据增强技术实现完成")
4.5.2 文本数据增强
# 文本数据增强
class TextAugmentation:
def __init__(self):
self.synonym_dict = {
'good': ['great', 'excellent', 'wonderful', 'fantastic'],
'bad': ['terrible', 'awful', 'horrible', 'poor'],
'big': ['large', 'huge', 'enormous', 'massive'],
'small': ['tiny', 'little', 'mini', 'compact']
}
def synonym_replacement(self, text, n=1):
"""同义词替换"""
words = text.split()
new_words = words.copy()
for _ in range(n):
# 随机选择一个词进行替换
if len(words) > 0:
idx = torch.randint(0, len(words), (1,)).item()
word = words[idx].lower()
if word in self.synonym_dict:
synonym = torch.randint(0, len(self.synonym_dict[word]), (1,)).item()
new_words[idx] = self.synonym_dict[word][synonym]
return ' '.join(new_words)
def random_insertion(self, text, n=1):
"""随机插入"""
words = text.split()
for _ in range(n):
if len(words) > 0:
# 随机选择一个同义词插入
word_to_insert = torch.randint(0, len(words), (1,)).item()
original_word = words[word_to_insert].lower()
if original_word in self.synonym_dict:
synonym_idx = torch.randint(0, len(self.synonym_dict[original_word]), (1,)).item()
synonym = self.synonym_dict[original_word][synonym_idx]
# 随机位置插入
insert_pos = torch.randint(0, len(words) + 1, (1,)).item()
words.insert(insert_pos, synonym)
return ' '.join(words)
def random_deletion(self, text, p=0.1):
"""随机删除"""
words = text.split()
if len(words) == 1:
return text
new_words = []
for word in words:
if torch.rand(1).item() > p:
new_words.append(word)
if len(new_words) == 0:
# 如果所有词都被删除,随机保留一个
idx = torch.randint(0, len(words), (1,)).item()
return words[idx]
return ' '.join(new_words)
def random_swap(self, text, n=1):
"""随机交换"""
words = text.split()
for _ in range(n):
if len(words) >= 2:
idx1 = torch.randint(0, len(words), (1,)).item()
idx2 = torch.randint(0, len(words), (1,)).item()
words[idx1], words[idx2] = words[idx2], words[idx1]
return ' '.join(words)
# 测试文本增强
text_aug = TextAugmentation()
original_text = "This is a good example of text augmentation"
print("原始文本:", original_text)
print("同义词替换:", text_aug.synonym_replacement(original_text, n=2))
print("随机插入:", text_aug.random_insertion(original_text, n=1))
print("随机删除:", text_aug.random_deletion(original_text, p=0.2))
print("随机交换:", text_aug.random_swap(original_text, n=2))
4.6 数据管道优化
4.6.1 内存优化
# 内存高效的数据集
class MemoryEfficientDataset(Dataset):
def __init__(self, data_paths, labels, transform=None, cache_size=1000):
self.data_paths = data_paths
self.labels = labels
self.transform = transform
self.cache_size = cache_size
self.cache = {}
self.cache_order = []
def __len__(self):
return len(self.data_paths)
def __getitem__(self, idx):
# 检查缓存
if idx in self.cache:
data = self.cache[idx]
else:
# 从磁盘加载数据
data = self._load_data(self.data_paths[idx])
# 更新缓存
if len(self.cache) >= self.cache_size:
# 移除最旧的缓存项
oldest_idx = self.cache_order.pop(0)
del self.cache[oldest_idx]
self.cache[idx] = data
self.cache_order.append(idx)
# 应用变换
if self.transform:
data = self.transform(data)
return data, self.labels[idx]
def _load_data(self, path):
# 模拟数据加载
return torch.randn(100) # 实际应用中这里会加载真实数据
# 预取数据加载器
class PrefetchDataLoader:
def __init__(self, dataloader, device):
self.dataloader = dataloader
self.device = device
self.stream = torch.cuda.Stream() if device.type == 'cuda' else None
def __iter__(self):
if self.stream is not None:
# GPU预取
first = True
for next_data, next_target in self.dataloader:
with torch.cuda.stream(self.stream):
next_data = next_data.to(self.device, non_blocking=True)
next_target = next_target.to(self.device, non_blocking=True)
if not first:
yield data, target
else:
first = False
torch.cuda.current_stream().wait_stream(self.stream)
data, target = next_data, next_target
yield data, target
else:
# CPU情况
for data, target in self.dataloader:
yield data.to(self.device), target.to(self.device)
print("内存优化数据管道实现完成")
4.6.2 并行处理
# 并行数据处理
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
class ParallelDataProcessor:
def __init__(self, num_workers=None):
self.num_workers = num_workers or mp.cpu_count()
def process_batch_parallel(self, batch_data, process_func):
"""并行处理批次数据"""
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
futures = [executor.submit(process_func, data) for data in batch_data]
results = [future.result() for future in futures]
return results
def preprocess_dataset_parallel(self, dataset, process_func, batch_size=32):
"""并行预处理整个数据集"""
processed_data = []
for i in range(0, len(dataset), batch_size):
batch = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
batch_results = self.process_batch_parallel(batch, process_func)
processed_data.extend(batch_results)
return processed_data
# 示例处理函数
def expensive_preprocessing(data):
"""模拟耗时的预处理操作"""
import time
time.sleep(0.01) # 模拟计算时间
return data * 2 + 1
# 测试并行处理
if __name__ == "__main__":
# 创建测试数据
test_data = [torch.randn(10) for _ in range(100)]
# 串行处理
start_time = time.time()
serial_results = [expensive_preprocessing(data) for data in test_data]
serial_time = time.time() - start_time
# 并行处理
processor = ParallelDataProcessor(num_workers=4)
start_time = time.time()
parallel_results = processor.process_batch_parallel(test_data, expensive_preprocessing)
parallel_time = time.time() - start_time
print(f"串行处理时间: {serial_time:.4f}秒")
print(f"并行处理时间: {parallel_time:.4f}秒")
print(f"加速比: {serial_time/parallel_time:.2f}x")
4.7 实际应用案例
4.7.1 图像分类数据管道
# 完整的图像分类数据管道
class ImageClassificationPipeline:
def __init__(self, data_dir, batch_size=32, num_workers=0, image_size=224):
self.data_dir = Path(data_dir)
self.batch_size = batch_size
self.num_workers = num_workers
self.image_size = image_size
# 定义变换
self.train_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop((image_size, image_size)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
self.val_transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def create_datasets(self):
"""创建训练和验证数据集"""
# 假设数据目录结构为:
# data_dir/
# train/
# class1/
# class2/
# val/
# class1/
# class2/
try:
train_dataset = torchvision.datasets.ImageFolder(
root=self.data_dir / 'train',
transform=self.train_transform
)
val_dataset = torchvision.datasets.ImageFolder(
root=self.data_dir / 'val',
transform=self.val_transform
)
return train_dataset, val_dataset
except FileNotFoundError:
print("数据目录不存在,创建模拟数据集")
return self._create_mock_datasets()
def _create_mock_datasets(self):
"""创建模拟数据集用于演示"""
# 创建模拟的图像数据
train_data = torch.randn(1000, 3, self.image_size, self.image_size)
train_labels = torch.randint(0, 10, (1000,))
val_data = torch.randn(200, 3, self.image_size, self.image_size)
val_labels = torch.randint(0, 10, (200,))
train_dataset = SimpleDataset(train_data, train_labels)
val_dataset = SimpleDataset(val_data, val_labels)
return train_dataset, val_dataset
def create_dataloaders(self):
"""创建数据加载器"""
train_dataset, val_dataset = self.create_datasets()
train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True
)
return train_loader, val_loader
# 测试图像分类管道
pipeline = ImageClassificationPipeline("./data", batch_size=16)
train_loader, val_loader = pipeline.create_dataloaders()
print(f"训练集批次数: {len(train_loader)}")
print(f"验证集批次数: {len(val_loader)}")
# 查看一个批次的数据
for batch_data, batch_labels in train_loader:
print(f"训练批次形状: {batch_data.shape}, {batch_labels.shape}")
break
4.7.2 文本分类数据管道
# 文本分类数据管道
class TextClassificationPipeline:
def __init__(self, vocab_size=10000, max_length=128, batch_size=32):
self.vocab_size = vocab_size
self.max_length = max_length
self.batch_size = batch_size
self.tokenizer = None
def prepare_data(self, texts, labels):
"""准备文本数据"""
# 构建词汇表
self.tokenizer = TextPreprocessor(self.vocab_size, self.max_length)
self.tokenizer.build_vocab(texts)
# 创建数据集
dataset = TextDataset(texts, labels, self.tokenizer.encode, self.max_length)
return dataset
def create_dataloaders(self, train_texts, train_labels, val_texts, val_labels):
"""创建数据加载器"""
# 准备数据集
train_dataset = self.prepare_data(train_texts, train_labels)
# 验证集使用相同的tokenizer
val_dataset = TextDataset(val_texts, val_labels, self.tokenizer.encode, self.max_length)
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=self._collate_fn
)
val_loader = DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=self._collate_fn
)
return train_loader, val_loader
def _collate_fn(self, batch):
"""批次整理函数"""
texts, labels = zip(*batch)
texts = torch.stack(texts)
labels = torch.tensor(labels, dtype=torch.long)
return texts, labels
# 测试文本分类管道
train_texts = [
"This movie is really good and entertaining",
"I hate this film, it's terrible",
"Amazing storyline and great acting",
"Boring and poorly made movie",
"Excellent cinematography and direction"
] * 20 # 重复以增加数据量
train_labels = [1, 0, 1, 0, 1] * 20 # 1=正面,0=负面
val_texts = [
"Great movie with excellent plot",
"Terrible acting and bad story",
"Wonderful film experience"
]
val_labels = [1, 0, 1]
text_pipeline = TextClassificationPipeline(vocab_size=1000, max_length=20, batch_size=8)
train_loader, val_loader = text_pipeline.create_dataloaders(
train_texts, train_labels, val_texts, val_labels
)
print(f"文本训练集批次数: {len(train_loader)}")
print(f"文本验证集批次数: {len(val_loader)}")
# 查看一个批次
for batch_texts, batch_labels in train_loader:
print(f"文本批次形状: {batch_texts.shape}, {batch_labels.shape}")
print(f"第一个样本: {batch_texts[0]}")
break
4.8 数据质量和监控
4.8.1 数据质量检查
# 数据质量检查工具
class DataQualityChecker:
def __init__(self):
self.stats = {}
def check_dataset(self, dataset, sample_size=100):
"""检查数据集质量"""
print("=== 数据质量检查报告 ===")
# 基本统计
dataset_size = len(dataset)
print(f"数据集大小: {dataset_size}")
# 采样检查
sample_indices = torch.randperm(dataset_size)[:min(sample_size, dataset_size)]
data_shapes = []
label_values = []
for idx in sample_indices:
try:
data, label = dataset[idx]
data_shapes.append(data.shape if hasattr(data, 'shape') else len(data))
label_values.append(label.item() if hasattr(label, 'item') else label)
except Exception as e:
print(f"样本 {idx} 加载失败: {e}")
# 形状一致性检查
unique_shapes = list(set(data_shapes))
print(f"数据形状: {unique_shapes}")
if len(unique_shapes) > 1:
print("⚠️ 警告: 数据形状不一致")
# 标签分布检查
unique_labels = list(set(label_values))
print(f"标签类别: {unique_labels}")
label_counts = {}
for label in label_values:
label_counts[label] = label_counts.get(label, 0) + 1
print("标签分布:")
for label, count in sorted(label_counts.items()):
percentage = count / len(label_values) * 100
print(f" 类别 {label}: {count} ({percentage:.1f}%)")
# 检查类别不平衡
max_count = max(label_counts.values())
min_count = min(label_counts.values())
imbalance_ratio = max_count / min_count
if imbalance_ratio > 3:
print(f"⚠️ 警告: 类别不平衡,比例为 {imbalance_ratio:.1f}:1")
return {
'dataset_size': dataset_size,
'unique_shapes': unique_shapes,
'unique_labels': unique_labels,
'label_distribution': label_counts,
'imbalance_ratio': imbalance_ratio
}
def check_dataloader(self, dataloader, num_batches=5):
"""检查数据加载器"""
print("\n=== DataLoader检查报告 ===")
batch_times = []
batch_shapes = []
start_time = time.time()
for batch_idx, (data, labels) in enumerate(dataloader):
    # 批次加载时间:从上一个计时点到当前批次取出为止
    batch_time = time.time() - start_time
    batch_times.append(batch_time)
    # 记录批次信息
    batch_shapes.append((data.shape, labels.shape))
    # 模拟一些计算
    _ = data.mean()
    if batch_idx >= num_batches - 1:
        break
    start_time = time.time()  # 重置计时,测量下一个批次的加载时间
avg_batch_time = np.mean(batch_times)
print(f"平均批次加载时间: {avg_batch_time:.4f}秒")
print(f"批次形状: {batch_shapes[0]}")
if avg_batch_time > 0.1:
print("⚠️ 警告: 批次加载时间较长,考虑优化数据管道")
return {
'avg_batch_time': avg_batch_time,
'batch_shapes': batch_shapes
}
# 测试数据质量检查
quality_checker = DataQualityChecker()
# 创建测试数据集
test_data = torch.randn(500, 32)
test_labels = torch.randint(0, 5, (500,))
test_dataset = SimpleDataset(test_data, test_labels)
# 检查数据集质量
dataset_stats = quality_checker.check_dataset(test_dataset)
# 检查数据加载器
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
loader_stats = quality_checker.check_dataloader(test_loader)
4.8.2 数据监控和可视化
# 数据监控工具
class DataMonitor:
def __init__(self):
self.batch_stats = []
def monitor_batch(self, data, labels):
"""监控单个批次"""
stats = {
'batch_size': data.size(0),
'data_mean': data.mean().item(),
'data_std': data.std().item(),
'data_min': data.min().item(),
'data_max': data.max().item(),
'label_distribution': self._get_label_distribution(labels)
}
self.batch_stats.append(stats)
return stats
def _get_label_distribution(self, labels):
"""获取标签分布"""
unique_labels, counts = torch.unique(labels, return_counts=True)
return dict(zip(unique_labels.tolist(), counts.tolist()))
def plot_monitoring_results(self):
"""可视化监控结果"""
if not self.batch_stats:
print("没有监控数据")
return
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# 数据均值变化
means = [stats['data_mean'] for stats in self.batch_stats]
axes[0, 0].plot(means)
axes[0, 0].set_title('批次数据均值')
axes[0, 0].set_xlabel('批次')
axes[0, 0].set_ylabel('均值')
# 数据标准差变化
stds = [stats['data_std'] for stats in self.batch_stats]
axes[0, 1].plot(stds)
axes[0, 1].set_title('批次数据标准差')
axes[0, 1].set_xlabel('批次')
axes[0, 1].set_ylabel('标准差')
# 批次大小
batch_sizes = [stats['batch_size'] for stats in self.batch_stats]
axes[1, 0].plot(batch_sizes)
axes[1, 0].set_title('批次大小')
axes[1, 0].set_xlabel('批次')
axes[1, 0].set_ylabel('样本数')
# 数据范围
mins = [stats['data_min'] for stats in self.batch_stats]
maxs = [stats['data_max'] for stats in self.batch_stats]
axes[1, 1].plot(mins, label='最小值')
axes[1, 1].plot(maxs, label='最大值')
axes[1, 1].set_title('数据范围')
axes[1, 1].set_xlabel('批次')
axes[1, 1].set_ylabel('值')
axes[1, 1].legend()
plt.tight_layout()
plt.show()
# 测试数据监控
monitor = DataMonitor()
# 模拟监控过程
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
print("开始数据监控...")
for batch_idx, (data, labels) in enumerate(test_loader):
stats = monitor.monitor_batch(data, labels)
if batch_idx == 0:
print(f"第一个批次统计: {stats}")
if batch_idx >= 9: # 只监控前10个批次(索引0~9)
break
# 可视化监控结果
monitor.plot_monitoring_results()
本章总结
在本章中,我们全面学习了PyTorch中的数据处理和加载技术:
核心概念
- Dataset类: 数据集的抽象基类,定义数据访问接口
- DataLoader类: 批量数据加载器,支持并行加载和批次处理
- 数据变换: transforms模块提供的数据预处理和增强功能
重要技术
- 自定义数据集: 针对不同数据类型实现专用数据集类
- 数据预处理: 标准化、归一化、编码等预处理技术
- 数据增强: 图像和文本的各种增强方法
- 性能优化: 内存管理、并行处理、预取等优化策略
实践技能
- 多模态数据处理: 图像、文本、数值、时间序列数据的处理
- 数据管道设计: 完整的数据处理流水线构建
- 质量监控: 数据质量检查和监控工具的使用
- 性能调优: 数据加载性能的分析和优化
学习成果
通过本章学习,你现在能够:
1. 实现各种类型的自定义数据集
2. 设计高效的数据预处理管道
3. 应用适当的数据增强技术
4. 优化数据加载性能
5. 监控和保证数据质量
下一章预告
在下一章《模型训练与优化》中,我们将学习:
- 损失函数的选择和使用
- 优化器的原理和配置
- 完整的训练循环实现
- 学习率调度和正则化技术
- 模型评估和验证策略
练习题
基础练习
- 实现一个CSV文件数据集类
- 创建图像数据增强管道
- 实现变长序列的collate函数
进阶练习
- 设计内存高效的大规模数据集
- 实现多进程数据预处理
- 创建数据质量自动检查工具
项目练习
- 构建完整的图像分类数据管道
- 实现文本情感分析数据处理系统
- 设计时间序列预测数据加载器
代码示例总结
本章包含了以下重要代码示例:
- 基础Dataset和DataLoader使用
- 自定义数据集实现(图像、文本、时间序列)
- 数据预处理和增强技术
- 性能优化策略(内存管理、并行处理)
- 数据质量监控工具
- 完整的数据管道实现
学习资源
第4章完成! 🎉
你已经掌握了PyTorch中数据处理与加载的核心技术。这些技能是深度学习项目成功的基础,将为后续的模型训练和优化提供坚实的数据支撑。
下一步: 进入第5章《模型训练与优化》,学习如何使用这些数据来训练强大的深度学习模型!