当前位置: 首页 > news >正文

多层感知机

Fashion-MNIST 分类任务代码笔记

一、整体概述

本代码基于 PyTorch 实现了一个简单的全连接神经网络,用于解决 Fashion-MNIST 图像分类任务(10个类别)。核心流程包括:网络定义、权重初始化、超参数设置、数据加载、训练循环实现及模型评估。

二、代码分块解析

(一)导入依赖库

import torch
from torch import nn
from d2l import torch as d2l
  • 核心库说明
    • torch:PyTorch 核心库,提供张量操作、自动求导等基础功能。
    • torch.nn:PyTorch 神经网络模块,包含层、损失函数等组件。
    • d2l.torch:《动手学深度学习》工具库,提供数据加载等便捷功能。

(二)定义神经网络结构

net = nn.Sequential(nn.Flatten(),          # 将28x28图像展平为784维向量nn.Linear(784, 256),   # 隐藏层:784→256nn.ReLU(),             # 激活函数nn.Linear(256, 10)     # 输出层:256→10(10个类别)
)
  • 网络组件详解
    1. nn.Flatten():图像预处理层,将输入的 2D 图像张量(28×28)转换为 1D 向量(784 维),适配全连接层输入格式。
    2. nn.Linear(784, 256):全连接隐藏层,接收 784 维输入,输出 256 维特征,通过矩阵乘法实现特征映射。
    3. nn.ReLU():激活函数,引入非线性,公式为 $ReLU(x) = max(0, x)$,解决线性模型表达能力不足的问题。
    4. nn.Linear(256, 10):全连接输出层,将 256 维特征映射到 10 维,对应 10 个类别的原始得分(logits)。

(三)权重初始化函数

# 初始化权重
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)  # 正态分布初始化权重net.apply(init_weights);
  • 功能说明
    • 定义 init_weights 函数,仅对全连接层(nn.Linear)进行权重初始化。
    • 使用 nn.init.normal_ 按正态分布 $N(0, 0.01^2)$ 初始化权重,避免初始权重过大/过小导致训练不稳定(如梯度消失/爆炸)。
    • net.apply(init_weights):递归遍历网络所有层,对符合条件的层执行初始化操作。

(四)超参数与训练组件设置

# 超参数
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss()  # 交叉熵损失(含Softmax)
trainer = torch.optim.SGD(net.parameters(), lr=lr)  # 随机梯度下降优化器
  • 关键组件说明
    1. 超参数
      • batch_size=256:每次训练迭代的样本数量,平衡训练速度与稳定性。
      • lr=0.1:学习率,控制参数更新幅度。
      • num_epochs=10:训练轮次,即遍历整个训练集的次数。
    2. 损失函数nn.CrossEntropyLoss() 适用于多分类任务,内部集成了 Softmax 函数,直接接收网络输出的 logits 计算损失。
    3. 优化器torch.optim.SGD 为随机梯度下降优化器,接收网络参数和学习率,负责更新权重以最小化损失。

(五)数据加载

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
  • 功能说明:通过 d2l 工具库加载 Fashion-MNIST 数据集,返回训练集迭代器 train_iter 和测试集迭代器 test_iter
  • 数据特性:每张图像为 28×28 灰度图,训练集 60000 样本,测试集 10000 样本,共 10 个服装类别。

(六)训练与评估函数实现

1. 单轮训练函数

def train_epoch(net, train_iter, loss, trainer):"""训练一个epoch"""net.train()  # 切换到训练模式(启用Dropout、BatchNorm等训练特性)total_loss = 0.0total_correct = 0total_samples = 0for X, y in train_iter:# 前向传播:计算预测值和损失y_hat = net(X)l = loss(y_hat, y)# 反向传播 + 参数更新trainer.zero_grad()  # 清空上一轮梯度(避免累积)l.backward()         # 自动计算梯度(基于计算图)trainer.step()       # 根据梯度更新参数# 统计训练指标total_loss += l.item() * X.shape[0]  # 累计总损失(乘以批量大小还原真实损失)# 计算正确预测数:argmax(dim=1)取预测概率最大的类别索引total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0]  # 累计处理样本数# 返回平均损失和训练准确率return total_loss / total_samples, total_correct / total_samples

2. 测试集评估函数

def evaluate_accuracy(net, test_iter):"""评估测试集准确率"""net.eval()  # 切换到评估模式(禁用Dropout、固定BatchNorm统计量)total_correct = 0total_samples = 0with torch.no_grad():  # 禁用梯度计算,减少内存占用并加速for X, y in test_iter:y_hat = net(X)total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0]return total_correct / total_samples  # 返回测试准确率

(七)执行训练流程

for epoch in range(num_epochs):train_loss, train_acc = train_epoch(net, train_iter, loss, trainer)test_acc = evaluate_accuracy(net, test_iter)print(f" epoch {epoch+1}:")print(f"  训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.4f}")print(f"  测试准确率: {test_acc:.4f}")
  • 流程说明
    1. 遍历 num_epochs 个训练轮次。
    2. 每轮调用 train_epoch 完成训练集的一次遍历,获取训练损失和训练准确率。
    3. 调用 evaluate_accuracy 评估模型在测试集上的性能。
    4. 打印当前轮次的训练损失、训练准确率和测试准确率,监控模型训练进度。

三、核心知识点总结

  1. 全连接网络结构:通过 nn.Sequential 堆叠层,Flatten 层适配图像输入,Linear 层实现特征映射,ReLU 引入非线性。
  2. 权重初始化:合理的初始化(如小方差正态分布)是训练稳定的关键。
  3. 训练三要素:交叉熵损失(多分类任务)、SGD 优化器(参数更新)、批量训练(效率与稳定性平衡)。
  4. 训练/评估模式net.train()net.eval() 切换网络状态,torch.no_grad() 优化评估过程。
  5. 指标统计:训练损失(平均损失)、准确率(正确预测数/总样本数)用于监控模型性能。
http://www.hskmm.com/?act=detail&tid=35157

相关文章:

  • Sql查询优化方案
  • 实用指南:深入解析HarmonyOS ArkTS:从语法特性到实战应用
  • 2025 防水背衬板厂家最新推荐榜:剖析质量与口碑,优选品牌助您精准采购
  • 如何安装fluentd 和fluentd-mongo的插件?然后收集nginx的 json格式的数据写到mongodb
  • 2025年气柱袋厂家推荐排行榜,防震/防摔/食品级气柱袋,奶瓶/奶粉/电子产品/化妆品气柱袋,缓冲包装与物流运输优选方案
  • 2025 年防火涂料厂家最新推荐排行榜:精选优质企业,涵盖钢结构各类型涂料,助您精准选品
  • PHP码农的微信业务开发利器
  • 词向量:从 One-Hot 到 BERT Embedding,NLP 文本表示的核心技术 - 实践
  • 2025年深圳网站建设/外贸独立站推广/阿里巴巴代运营/1688店铺代运营/短视频运营推广/微信小程序开发服务商权威推荐榜
  • Android studio build报错 - show
  • 2025 年蜂窝大板厂家最新推荐榜单:覆盖云南 / 昆明吊顶、铝门、别墅等场景,优质企业助力选对产品
  • 生产环境RAG系统失效原因与解决方案
  • 2025 彩石瓦厂家最新推荐排行榜:权威解析金属瓦 / 屋顶瓦优质厂商,金属/屋顶/凉亭/昆明/云南彩石瓦厂家推荐
  • 文明元代码:价值原语、共识具身与关系语法
  • 2025年扒胎机厂家推荐排行榜,液压无损扒胎机,全自动扒胎机,汽保扒胎机,轮胎扒胎机,汽车扒胎机,大轮胎扒胎机,无损扒胎机,辽南扒胎机,小车扒胎机,立式扒胎机公司推荐
  • springboot集成echarts显示图表
  • 2025年储罐厂家权威推荐榜:钢衬塑储罐/钢塑复合储罐/化工储罐/防腐储罐/PE储罐/盐酸储罐/硫酸储罐/聚丙烯储罐/不锈钢储罐/次氯酸钠储罐专业选购指南
  • Avalonia使用代码更改滑动条的颜色
  • 【SPIE出版】第四届云计算、性能计算与深度学习国际学术会议 (CCPCDL 2025)
  • 【IC原厂】VKD104CB 内建稳压电路低电流4路触摸检测IC
  • 2025年氧化镁厂家最新权威推荐榜:活性氧化镁,肥料级氧化镁,高纯度氧化镁源头厂家深度解析及选购指南
  • 上班摸鱼新姿势!抖音爆火的线稿涂鸦也太治愈了~
  • n8n错误处理全攻略:构建稳定可靠的自动化工作流
  • 2025年通风天窗厂家最新权威推荐榜:通风天窗,排烟天窗,通风气楼,屋顶通风器,顺坡气楼,10A通风天窗,1型通风天窗,TC5A通风天窗,TC12B通风天窗,屋脊通风天窗专业制造与高效通风解决方案
  • DAO模式代码阅读及应用
  • 2025年智能照明系统/模块厂家推荐排行榜,工厂/车间/改建/高亮/高光效/泛光/免维护/投光/大功率智能照明系统及模块公司精选
  • DxO Nik Collection 8.0:7 款专业摄影插件套装,一站式图像后期解决方案
  • [随笔11] 最近的心情 - 枝-致
  • 三款AI平台部署实战体验:Dify、扣子与BuildingAI深度对比
  • #OO之接口-DAO模式代码阅读及应用