当前位置: 首页 > news >正文

cifar10

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from multiprocessing import freeze_support
import sys

1. 加载和预处理数据

def load_data():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True,num_workers=2
)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2
)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')return trainloader, testloader, classes

2. 构建网络

class Net(nn.Module):
def init(self):
super().init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.flatten(x, 1)  # 展平x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x

3. 编译网络(定义损失函数和优化器)

def compile_model(net):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
return criterion, optimizer

4. 训练网络(已同步设备)

def train(net, trainloader, criterion, optimizer, device, epochs=2):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 核心:数据与模型设备同步
inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零optimizer.zero_grad()# 前向计算 + 反向传播 + 优化参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印训练日志running_loss += loss.item()if i % 2000 == 1999:  # 每2000个batch打印一次print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('训练完成')

5. 测试网络(已同步设备)

def test(net, testloader, classes, device):
correct = 0
total = 0
# 测试时不计算梯度,加快速度
with torch.no_grad():
for data in testloader:
images, labels = data
# 数据与模型设备同步
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'测试集整体准确率: {100 * correct // total} %')# 按类别统计准确率
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predictions = torch.max(outputs, 1)# 统计每个类别的预测结果for label, prediction in zip(labels, predictions):if label == prediction:correct_pred[classes[label]] += 1total_pred[classes[label]] += 1# 打印各类别准确率
for classname, correct_count in correct_pred.items():accuracy = 100 * float(correct_count) / total_pred[classname]print(f'类别: {classname:5s} 准确率: {accuracy:.1f} %')

if name == 'main':
freeze_support() # 解决Windows多进程问题
# 自动选择设备(有GPU用GPU,无则用CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")

# 加载数据、初始化模型和优化器
trainloader, testloader, classes = load_data()
net = Net().to(device)  # 模型放到指定设备
criterion, optimizer = compile_model(net)# 重定向输出到文件,同时保留控制台打印
original_stdout = sys.stdout
with open('cifar10_result.txt', 'w') as f:sys.stdout = fprint(f"当前使用设备: {device}")train(net, trainloader, criterion, optimizer, device)test(net, testloader, classes, device)sys.stdout = original_stdout  # 恢复控制台输出print("训练完成!结果已保存到 cifar10_result.txt ")

image

http://www.hskmm.com/?act=detail&tid=31778

相关文章:

  • [LangChain] 02. 模型接口
  • 摄像头调试
  • C语言学习——字符串数据类型
  • 感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示
  • WebGL学习及项目实战(第02期:绘制一个点)
  • 2025 年 10 月国内加工中心制造商最新推荐排行榜:涵盖立式、卧式、龙门及多规格型号!
  • display ip routing-table protocol ospf 概念及题目 - 详解
  • C语言学习——小数数据类型
  • 高敏感人应对焦虑
  • Palantir本体论以及对智能体建设的价值与意义
  • 2025 年执业兽医资格证备考服务机构推荐榜,执业兽医资格证培训机构/执兽考试机构/考试辅导机构获得行业推荐
  • [LangChain] 基本介绍
  • 题解:P6755 [BalticOI 2013] Pipes (Day1)
  • Palantir 的“本体工程”的核心思路、技术架构与实践示例
  • P14164 [ICPC 2022 Nanjing R] 命题作文
  • C语言学习——整数变量
  • 语音合成技术从1秒样本学习表达风格
  • 我的高敏感和家人
  • 对称多项式
  • usb储存之BOT/UAS内核驱动
  • 简述flux思想?
  • 风控评分卡
  • 20232428 2025-2026-1 《网络与系统攻防技术》实验一实验报告
  • JAVA对象内存布局
  • 20232409 2025-2026-1 《网络与系统攻防技术》实验二实验报告
  • 10月15号
  • 记录一次客户现场环境,银河麒麟V10操作系统重启后,进入登录页面后卡死,鼠标键盘无响应的解决过程
  • 图 生成树
  • DolphinScheduler 3.1.9 单机版重启后,项目、流程定义等数据全部丢失
  • ManySpeech.AliParaformerAsr 使用指南