当前位置: 首页 > news >正文

66页作业

点击查看代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=2)# 定义类别
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 定义网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 4 * 4, 64)self.fc2 = nn.Linear(64, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = self.relu(self.conv3(x))x = x.view(-1, 64 * 4 * 4)x = self.relu(self.fc1(x))x = self.fc2(x)return xnet = Net().to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)# 训练网络
train_losses = []
train_accs = []
test_accs = []epochs = 10
for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()# 计算训练集准确率train_acc = 100. * correct / totaltrain_loss = running_loss / len(trainloader)train_losses.append(train_loss)train_accs.append(train_acc)# 在测试集上评估correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()test_acc = 100. * correct / totaltest_accs.append(test_acc)print(f'Epoch {epoch + 1}, Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')print('训练完成')# 绘制训练和测试准确率曲线
plt.figure(figsize=(10, 5))
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Test Accuracy')
plt.show()# 绘制训练损失曲线
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss')
plt.show()# 查看每个类别的准确率
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(len(labels)):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'类别 {classes[i]} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

微信图片_20251023222201_2647_35

http://www.hskmm.com/?act=detail&tid=37681

相关文章:

  • 写电商详情页不用挠头了:一个还算实用的AI指令模板
  • CF2153D
  • 20232417 2025-2026-1 《网络与系统攻防技术》实验二实验报告
  • iPhone口袋状态检测技术揭秘
  • 搜维尔科技:IROS 2025现场,触觉力反馈、数据手套遥操作机器人灵巧手平台系统解决方案
  • 一些题解
  • Node.js JSON import attributes All In One
  • DeepSeek的“认知提纯”能力解析
  • 梦熊知更鸟赛水题题解合集 (两个人的演唱会 使一颗心免于哀伤 空气蛹)
  • CF2154D
  • Plya 定理学习笔记 | ABC428G 题解
  • 第十八天
  • 第十七天
  • vue3+elementPlus el-date-picker 自定义禁用状态hook 建立结束时间不能小于开始时间
  • [优先队列] P3611 [USACO17JAN] Cow Dance Show S 题解
  • 搜维尔科技将携手Xsens|Haption|Tesollo|Manus亮相IROS 2025国际智能机器人与系统会议
  • 【做题记录】贪心--提高组
  • 如何炫酷地使用集合划分容斥
  • 简单云计算算法--20251023
  • 处理空输入踩的坑
  • latex输入公式
  • 【为美好CTF献上祝福】 New Star 2025 逆向笔记
  • XXL-JOB(5)
  • 蛋白表达原理与关键要素解析
  • Ramanujan Master Theorem
  • 顾雅南的声音美化课堂
  • C++案例 自定义数组
  • 【周记】2025.10.13~2025.10.19
  • 背包
  • 10.23《程序员修炼之道 从小工到专家》第二章 注重实效的途径 - GENGAR