transfroms
在 PyTorch 中,torchvision.transforms 是用于数据预处理和数据增强的工具集,主要作用是将原始数据(如图像、文本)转换为适合模型输入的格式,并通过随机变换增加数据多样性,从而提升模型的泛化能力。
transforms的作用
1.数据标准化
将原始数据转换为模型要求的格式,例如:
- 将图像从 PIL 格式转为 Tensor(ToTensor)
- 对像素值进行归一化(Normalize),使数据分布更稳定
- 调整图像尺寸(Resize),确保输入尺寸一致
2.数据增强
通过随机变换生成更多样化的训练样本,减少过拟合,例如:
- 随机裁剪(RandomCrop)、翻转(RandomHorizontalFlip)
- 随机调整亮度、对比度(ColorJitter)
- 随机旋转(RandomRotation)
还是上一下示例代码:
from torchvision import transforms
from PIL import Image# 定义变换流水线
transform = transforms.Compose([transforms.Resize((224, 224)), # 调整图像大小为 224x224transforms.RandomHorizontalFlip(p=0.5), # 50% 概率水平翻转(数据增强)transforms.ToTensor(), # 转为 Tensortransforms.Normalize( # 归一化(使用 ImageNet 数据集的均值和标准差)mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])# 加载图像并应用变换
img = Image.open("test.jpg") # 原始 PIL 图像
transformed_img = transform(img) # 经过变换后的 Tensorprint(transformed_img.shape) # 输出: torch.Size([3, 224, 224])(通道数×高×宽)
训练和测试中transforms的使用区别
- 训练集:通常加入数据增强(如随机翻转、裁剪),增加数据多样性
- 测试集:仅进行标准化处理(如 Resize、ToTensor、Normalize),不使用随机变换,确保结果可复现
例如:
# 训练集变换(含数据增强)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 测试集变换(无随机操作)
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224), # 中心裁剪,而非随机transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
总结:
transforms 是连接原始数据与模型输入的关键环节,其核心价值在于:
1.统一数据格式,使原始数据符合模型输入要求
2.通过数据增强扩展训练样本多样性,提升模型泛化能力
3.简化数据预处理流程,与 Dataset、DataLoader 无缝配合
在实际使用中,需根据数据集特点和模型需求选择合适的变换组合,平衡数据增强效果与计算开销。