Fashion-MNIST 分类任务代码笔记
一、整体概述
本代码基于 PyTorch 实现了一个简单的全连接神经网络,用于解决 Fashion-MNIST 图像分类任务(10个类别)。核心流程包括:网络定义、权重初始化、超参数设置、数据加载、训练循环实现及模型评估。
二、代码分块解析
(一)导入依赖库
import torch
from torch import nn
from d2l import torch as d2l
- 核心库说明:
torch
:PyTorch 核心库,提供张量操作、自动求导等基础功能。torch.nn
:PyTorch 神经网络模块,包含层、损失函数等组件。d2l.torch
:《动手学深度学习》工具库,提供数据加载等便捷功能。
(二)定义神经网络结构
net = nn.Sequential(nn.Flatten(), # 将28x28图像展平为784维向量nn.Linear(784, 256), # 隐藏层:784→256nn.ReLU(), # 激活函数nn.Linear(256, 10) # 输出层:256→10(10个类别)
)
- 网络组件详解:
nn.Flatten()
:图像预处理层,将输入的 2D 图像张量(28×28)转换为 1D 向量(784 维),适配全连接层输入格式。nn.Linear(784, 256)
:全连接隐藏层,接收 784 维输入,输出 256 维特征,通过矩阵乘法实现特征映射。nn.ReLU()
:激活函数,引入非线性,公式为 ReLU(x) = max(0, x),解决线性模型表达能力不足的问题。nn.Linear(256, 10)
:全连接输出层,将 256 维特征映射到 10 维,对应 10 个类别的原始得分(logits)。
(三)权重初始化函数
# 初始化权重
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01) # 正态分布初始化权重net.apply(init_weights);
- 功能说明:
- 定义
init_weights
函数,仅对全连接层(nn.Linear
)进行权重初始化。 - 使用
nn.init.normal_
按正态分布 N(0, 0.01²) 初始化权重,避免初始权重过大/过小导致训练不稳定(如梯度消失/爆炸)。 net.apply(init_weights)
:递归遍历网络所有层,对符合条件的层执行初始化操作。
- 定义
(四)超参数与训练组件设置
# 超参数
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss() # 交叉熵损失(含Softmax)
trainer = torch.optim.SGD(net.parameters(), lr=lr) # 随机梯度下降优化器
- 关键组件说明:
- 超参数:
batch_size=256
:每次训练迭代的样本数量,平衡训练速度与稳定性。lr=0.1
:学习率,控制参数更新幅度。num_epochs=10
:训练轮次,即遍历整个训练集的次数。
- 损失函数:
nn.CrossEntropyLoss()
适用于多分类任务,内部集成了 Softmax 函数,直接接收网络输出的 logits 计算损失。 - 优化器:
torch.optim.SGD
为随机梯度下降优化器,接收网络参数和学习率,负责更新权重以最小化损失。
- 超参数:
(五)数据加载
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
- 功能说明:通过
d2l
工具库加载 Fashion-MNIST 数据集,返回训练集迭代器train_iter
和测试集迭代器test_iter
。 - 数据特性:每张图像为 28×28 灰度图,训练集 60000 样本,测试集 10000 样本,共 10 个服装类别。
(六)训练与评估函数实现
1. 单轮训练函数
def train_epoch(net, train_iter, loss, trainer):"""训练一个epoch"""net.train() # 切换到训练模式(启用Dropout、BatchNorm等训练特性)total_loss = 0.0total_correct = 0total_samples = 0for X, y in train_iter:# 前向传播:计算预测值和损失y_hat = net(X)l = loss(y_hat, y)# 反向传播 + 参数更新trainer.zero_grad() # 清空上一轮梯度(避免累积)l.backward() # 自动计算梯度(基于计算图)trainer.step() # 根据梯度更新参数# 统计训练指标total_loss += l.item() * X.shape[0] # 累计总损失(乘以批量大小还原真实损失)# 计算正确预测数:argmax(dim=1)取预测概率最大的类别索引total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0] # 累计处理样本数# 返回平均损失和训练准确率return total_loss / total_samples, total_correct / total_samples
2. 测试集评估函数
def evaluate_accuracy(net, test_iter):"""评估测试集准确率"""net.eval() # 切换到评估模式(禁用Dropout、固定BatchNorm统计量)total_correct = 0total_samples = 0with torch.no_grad(): # 禁用梯度计算,减少内存占用并加速for X, y in test_iter:y_hat = net(X)total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0]return total_correct / total_samples # 返回测试准确率
(七)执行训练流程
for epoch in range(num_epochs):train_loss, train_acc = train_epoch(net, train_iter, loss, trainer)test_acc = evaluate_accuracy(net, test_iter)print(f" epoch {epoch+1}:")print(f" 训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.4f}")print(f" 测试准确率: {test_acc:.4f}")
- 流程说明:
- 遍历
num_epochs
个训练轮次。 - 每轮调用
train_epoch
完成训练集的一次遍历,获取训练损失和训练准确率。 - 调用
evaluate_accuracy
评估模型在测试集上的性能。 - 打印当前轮次的训练损失、训练准确率和测试准确率,监控模型训练进度。
- 遍历
三、核心知识点总结
- 全连接网络结构:通过
nn.Sequential
堆叠层,Flatten 层适配图像输入,Linear 层实现特征映射,ReLU 引入非线性。 - 权重初始化:合理的初始化(如小方差正态分布)是训练稳定的关键。
- 训练三要素:交叉熵损失(多分类任务)、SGD 优化器(参数更新)、批量训练(效率与稳定性平衡)。
- 训练/评估模式:
net.train()
和net.eval()
切换网络状态,torch.no_grad()
优化评估过程。 - 指标统计:训练损失(平均损失)、准确率(正确预测数/总样本数)用于监控模型性能。