当前位置: 首页 > news >正文

NiN模型

NiN模型

import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_chanels,out_chanels,kernel_size,padding,stride):#NiN块return nn.Sequential(nn.Conv2d(in_chanels,out_chanels,kernel_size,padding=padding,stride=stride),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU())#NiN网络
net=nn.Sequential(NiN_block(1,96,11,stride=4,padding=0),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(96,256,kernel_size=5,padding=2,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(256,384,kernel_size=3,padding=1,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),nn.Dropout(p=0.5),NiN_block(384,10,3,1,1),#输出通道最终为10,因为等会要用于数字0~9分类nn.AdaptiveAvgPool2d((1,1)),nn.Flatten()
)
X=torch.rand((1,1,224,224))
for layer in net:X=layer(X)print(layer.__class__.__name__,X.shape)

 解释分析:

nn.AdaptiveAvgPool2d((1, 1)) 是 PyTorch 中的自适应平均池化层,它的作用是将输入的任意尺寸的特征图,通过平均池化操作,固定输出为 (1, 1) 大小的特征图(即高和宽都为 1)。

具体解释:

  1. 自适应(Adaptive)
     
    与普通的 nn.AvgPool2d 不同,它不需要手动指定池化核的大小(kernel_size)和步长(stride),而是直接指定输出特征图的尺寸。PyTorch 会自动计算所需的池化核大小和步长,以确保输出符合指定尺寸。
  2. 参数 (1, 1)
     
    表示输出特征图的高和宽都为 1。例如:
    • 如果输入是形状为 (N, C, H, W) 的特征图(N 是批量大小,C 是通道数,H 是高,W 是宽),
    • 经过 nn.AdaptiveAvgPool2d((1, 1)) 后,输出形状会变为 (N, C, 1, 1)
  3. 在你的代码中的作用
     
    在 NiN 网络中,最后一个 NiN_block 的输出通道数是 10(对应 10 个类别),假设此时特征图形状为 (N, 10, H, W)(例如经过前面的层后,H 和 W 可能是 7 左右)。
     
    通过 nn.AdaptiveAvgPool2d((1, 1)) 后,特征图会被压缩为 (N, 10, 1, 1),再经过 nn.Flatten() 展平为 (N, 10),正好对应 10 个类别的输出,可直接用于分类任务(如计算交叉熵损失)。
 
简单说,这个层的核心作用是 **“压缩空间维度,保留通道信息”**,方便后续将特征图转换为分类所需的向量形式。

二.训练NiN网络

import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_chanels,out_chanels,kernel_size,padding,stride):return nn.Sequential(nn.Conv2d(in_chanels,out_chanels,kernel_size,padding=padding,stride=stride),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU())
net=nn.Sequential(NiN_block(1,96,11,stride=4,padding=0),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(96,256,kernel_size=5,padding=2,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(256,384,kernel_size=3,padding=1,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),nn.Dropout(p=0.5),NiN_block(384,10,3,1,1),#输出通道最终为10,因为等会要用于数字0~9分类nn.AdaptiveAvgPool2d((1,1)),nn.Flatten()
)
batch_size=128#批量数
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size,resize=224)
lr=0.05#学习率
nums_epochs=10#学习10代
d2l.train_ch6(net,train_iter,test_iter,nums_epochs,lr,d2l.try_gpu())#这个函数封装在d2l(本书的一个包)
 
http://www.hskmm.com/?act=detail&tid=32645

相关文章:

  • 2025秋_13
  • 2023 ICPC Hefei
  • 斑马日记2025.10.16
  • 可能是 ICPC2025 西安站游记
  • Active Directory用户账户安全配置与漏洞防范指南
  • 实验一 现代C++编程初体验
  • day013
  • Git SSH 推送完整流程总结
  • 运筹学奖学金项目促进科研多元化发展
  • 非托管内存怎么计算?
  • ubuntu配置镜像源和配置containerd安装源
  • dotnet集合类型性能优化的两个小儿科的知识点
  • ABC420 AtCoder Beginner Contest 420 游记(VP)
  • 【题解】CF2086C Disappearing Permutation
  • Windows 事件ID + 登录类型 + 服务对应表大全
  • 5-互评-OO之接口-DAO模式代码阅读及应用
  • [Paper Reading] VLM2VEC: TRAINING VISION-LANGUAGE MODELS FOR MASSIVE MULTIMODAL EMBEDDING TASKS
  • Index of /ubuntu-cdimage/ubuntukylin/releases/
  • ubuntu安装和设置为图形界面或命令行界面
  • 10.16日学习笔记
  • day 3
  • PWN手的成长之路-18-ciscn_2019_ne_5-rettext
  • 技术人不用当“兼职运营”:2025微信编辑器实用指南,让产品更新日志/API教程产出效率提升3倍
  • 站位1
  • ubuntu2204系统ip地址配置
  • 10.16 —— 2021ccpc桂林D,B
  • day 2
  • 日志分析-windows日志分析base
  • 2025/10/16 模拟赛笔记 - sb
  • 课后作业3