1. The NiN Model
import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_channels, out_channels, kernel_size, padding, stride):
    # A NiN block: one ordinary convolution followed by two 1x1 convolutions, each with ReLU
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())
# The NiN network
net = nn.Sequential(
    NiN_block(1, 96, 11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    NiN_block(96, 256, kernel_size=5, padding=2, stride=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    NiN_block(256, 384, kernel_size=3, padding=1, stride=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # 10 output channels, one per class of the 10-class Fashion-MNIST dataset used below
    NiN_block(384, 10, kernel_size=3, padding=1, stride=1),
    nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling replaces the fully connected layer
    nn.Flatten())
X = torch.rand((1, 1, 224, 224))  # a dummy single-channel 224x224 input
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, X.shape)
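For reference, with a 1x1x224x224 input the loop should print output shapes along these lines (computed from each layer's kernel size, stride, and padding, with PyTorch's default floor-mode pooling; the NiN blocks show up as Sequential):

Sequential torch.Size([1, 96, 54, 54])
MaxPool2d torch.Size([1, 96, 26, 26])
Sequential torch.Size([1, 256, 26, 26])
MaxPool2d torch.Size([1, 256, 12, 12])
Sequential torch.Size([1, 384, 12, 12])
MaxPool2d torch.Size([1, 384, 5, 5])
Dropout torch.Size([1, 384, 5, 5])
Sequential torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d torch.Size([1, 10, 1, 1])
Flatten torch.Size([1, 10])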
Explanation:
Each NiN block pairs one ordinary convolution with two 1x1 convolutions, each followed by ReLU. A 1x1 convolution mixes the channels at every spatial position independently, so it behaves like a small fully connected layer applied per pixel. Instead of ending with large fully connected layers, NiN makes its last block emit 10 channels (one per class) and then uses global average pooling (nn.AdaptiveAvgPool2d((1, 1))) plus nn.Flatten() to reduce each channel to a single class score, which greatly cuts the number of parameters. The loop above pushes a random 1x1x224x224 tensor through the network and prints each layer's output shape so the architecture can be verified.
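To make the "per-pixel fully connected layer" interpretation concrete, here is a minimal sketch (not from the original post) showing that a 1x1 convolution and an nn.Linear layer sharing the same weights produce identical results at every spatial position:

import torch
from torch import nn

# A 1x1 convolution is a fully connected layer applied at every spatial position.
conv1x1 = nn.Conv2d(96, 256, kernel_size=1)
fc = nn.Linear(96, 256)
with torch.no_grad():
    fc.weight.copy_(conv1x1.weight.view(256, 96))  # reuse the conv weights
    fc.bias.copy_(conv1x1.bias)

X = torch.rand(1, 96, 5, 5)
out_conv = conv1x1(X)
# Move channels to the last dimension, apply the linear layer per pixel, move them back
out_fc = fc(X.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_fc, atol=1e-6))  # True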
2. Training the NiN Network
import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_channels, out_channels, kernel_size, padding, stride):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())
net = nn.Sequential(
    NiN_block(1, 96, 11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    NiN_block(96, 256, kernel_size=5, padding=2, stride=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    NiN_block(256, 384, kernel_size=3, padding=1, stride=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # 10 output channels, one per Fashion-MNIST class
    NiN_block(384, 10, kernel_size=3, padding=1, stride=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten())
batch_size = 128  # batch size
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)  # resize Fashion-MNIST images from 28x28 to 224x224
lr = 0.05  # learning rate
num_epochs = 10  # train for 10 epochs
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # training helper provided by the d2l package that accompanies the book
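After training, you may want a quick numeric sanity check of the learned model. A minimal sketch (assuming the net and test_iter defined above; not part of the original post) that measures accuracy on a single test batch:

# Evaluate the trained network on one test batch (a quick sanity check,
# not a full test-set evaluation).
device = d2l.try_gpu()
net.to(device)
net.eval()
with torch.no_grad():
    X, y = next(iter(test_iter))
    X, y = X.to(device), y.to(device)
    preds = net(X).argmax(dim=1)
    print('accuracy on one test batch:', (preds == y).float().mean().item())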