当前位置: 首页 > news >正文

pytorch实训题

代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

1. 数据预处理与加载

transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 数据增强:随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10标准归一化参数
])

加载数据集

trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=100, shuffle=False, num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

2. 定义卷积神经网络模型

class CIFAR10CNN(nn.Module):
def init(self):
super(CIFAR10CNN, self).init()
# 卷积层部分
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), # 3输入通道,64输出通道,3x3卷积核
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # 2x2池化,步长2

        nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.Conv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))# 全连接层部分self.fc_layers = nn.Sequential(nn.Dropout(0.5),nn.Linear(256 * 4 * 4, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 10)  # 10个分类)def forward(self, x):x = self.conv_layers(x)x = x.view(-1, 256 * 4 * 4)  # 展平特征图x = self.fc_layers(x)return x

3. 初始化模型、损失函数和优化器

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

net = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

4. 训练模型

def train(epochs=20):
train_losses = []
test_losses = []
best_acc = 0.0

print("开始训练...")
start_time = time.time()for epoch in range(epochs):net.train()  # 训练模式running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 清零梯度optimizer.zero_grad()# 前向传播、计算损失、反向传播、参数更新outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计损失running_loss += loss.item()if i % 100 == 99:  # 每100个batch打印一次print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')running_loss = 0.0# 每个epoch结束后测试test_loss, acc = test()train_losses.append(running_loss / len(trainloader))test_losses.append(test_loss)# 学习率调整scheduler.step(test_loss)# 保存最佳模型if acc > best_acc:best_acc = acctorch.save(net.state_dict(), 'best_model.pth')print(f'Epoch {epoch+1} 测试准确率: {acc:.2f}%')print(f'训练完成,耗时: {time.time() - start_time:.2f}秒')
print(f'最佳测试准确率: {best_acc:.2f}%')# 绘制损失曲线
plt.plot(train_losses, label='训练损失')
plt.plot(test_losses, label='测试损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_curve.png')
plt.close()return train_losses, test_losses

5. 测试模型

def test():
net.eval() # 评估模式
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():  # 不计算梯度for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / total
avg_loss = test_loss / len(testloader)
return avg_loss, acc

6. 测试每个类别的准确率

def test_class_accuracy():
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'类别 {classes[i]} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

7. 显示一些测试图像和预测结果

def show_predictions(num_images=5):
dataiter = iter(testloader)
images, labels = next(dataiter)

# 打印原始图像
imshow(torchvision.utils.make_grid(images))
print('真实标签: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 预测
outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)
print('预测标签: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.figure(figsize=(10, 4))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.savefig('predictions.png')
plt.close()

主程序

if name == 'main':
# 训练模型(20个epochs)
train_losses, test_losses = train(epochs=20)

# 加载最佳模型
net.load_state_dict(torch.load('best_model.pth'))# 评估模型
print("\n每个类别的准确率:")
test_class_accuracy()# 显示预测结果
show_predictions()

运行结果
使用设备: cuda:0
开始训练...
[1, 100] loss: 1.762
[1, 200] loss: 1.421
[1, 300] loss: 1.285
Epoch 1 测试准确率: 57.32%
[2, 100] loss: 1.135
[2, 200] loss: 1.052
...
训练完成,耗时: 456.23秒
最佳测试准确率: 85.67%

每个类别的准确率:
类别 plane 的准确率: 89.20%
类别 car 的准确率: 92.50%
类别 bird 的准确率: 78.30%
类别 cat 的准确率: 72.10%
类别 deer 的准确率: 84.50%
类别 dog 的准确率: 79.80%
类别 frog 的准确率: 88.70%
类别 horse 的准确率: 87.60%
类别 ship 的准确率: 91.20%
类别 truck 的准确率: 89.40%

http://www.hskmm.com/?act=detail&tid=31803

相关文章:

  • 数据库基础知识1
  • 近期模拟赛汇总
  • 实用指南:部署Tomcat11.0.11(Kylinv10sp3、Ubuntu2204、Rocky9.3)
  • Hbase的安装与配置
  • 【Azure App Service】App Service是否支持PHP的版本选择呢?
  • OAuth/OpenID Connect 渗透测试完全指南
  • Problem K. 置换环(The ICPC online 2025)思路解析 - tsunchi
  • Go 语言和 Tesseract OCR 识别英文数字验证码
  • Markdown转换为Word:Pandoc模板使用指南 - 实践
  • 2025年10月小程序开发公司最新推荐排行榜,小程序定制开发,电商小程序开发,预订服务小程序开发,活动报名小程序开发!
  • 复习CSharp
  • Rust 和 Tesseract OCR 实现英文数字验证码识别
  • 数据结构-循环队列
  • C语言学习——键盘录入
  • 2025年10月软件开发公司最新推荐,软件定制开发,crm系统定制软件开发,管理系统软件开发,物联网软件开发公司推荐!
  • C语言学习——运算符的学习
  • 第十五篇
  • 数据结构-双向循环链表
  • 数据结构-顺序栈
  • Erlang 的英文数字验证码识别系统设计与实现
  • 使用Django从零开始构建一个个人博客系统 - 实践
  • 2025年磨床厂家TOP企业品牌推荐排行榜,平面磨床,外圆磨床,数控平面磨床,数控外圆磨床,7163平面磨床推荐这十家公司!
  • cifar10
  • [LangChain] 02. 模型接口
  • 摄像头调试
  • C语言学习——字符串数据类型
  • 感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示
  • WebGL学习及项目实战(第02期:绘制一个点)
  • 2025 年 10 月国内加工中心制造商最新推荐排行榜:涵盖立式、卧式、龙门及多规格型号!
  • display ip routing-table protocol ospf 概念及题目 - 详解