当前位置: 首页 > news >正文

pytorch第66页

点击查看代码
import torch
from torch import optim, nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report
from PIL import Image
import time
from matplotlib import pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#数据加载
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])# 加载CIFAR-10数据集
def load_data():train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)return train_loader, test_loader, test_dataset#定义MYVGG模型
class MYVGG(nn.Module):def __init__(self, num_classes=10):super(MYVGG, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),)self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x#训练函数
model = MYVGG().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, criterion, optimizer, epoch_num=50):model.train()train_loss = []train_acc = []for epoch in range(epoch_num):start_time = time.time()running_loss = 0.0current = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)current += (predicted == target).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100.0 * current / totaltrain_loss.append(epoch_loss)train_acc.append(epoch_acc)end_time = time.time()print(f"Epoch [{epoch+1}/{epoch_num}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%, Time: {end_time-start_time:.2f}s")return train_loss, train_acc#测试函数
def test(model, test_loader):model.eval()all_pred = []all_label = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)outputs = model(data)_, predicted = torch.max(outputs.data, 1)all_pred.extend(predicted.cpu().numpy())all_label.extend(target.cpu().numpy())all_pred = np.array(all_pred)all_label = np.array(all_label)accuracy = (all_pred == all_label).mean()accuracy = 100.0 * accuracyprint(f'测试准确率: {accuracy:.4f}%')print("分类效果评估:")target_names = [str(i) for i in range(10)]report = classification_report(all_label, all_pred, target_names=target_names)print(report)if __name__ == '__main__':print(f"24信计2班 佘婷婷 2024310143102")print(f"device:{device}")epoch_num = 20train_loader, test_loader, test_dataset = load_data()train_loss, train_acc = train(model, train_loader, criterion, optimizer, epoch_num)test(model, test_loader)#绘制结果plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epoch_num+1), train_loss)plt.title("Training Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(range(1, epoch_num+1), train_acc)plt.title("Training Accuracy")plt.xlabel("Epoch")plt.ylabel("Accuracy (%)")plt.tight_layout()plt.show()

image

http://www.hskmm.com/?act=detail&tid=36603

相关文章:

  • Navicat Premium 17 官方版下载安装教程|支持MySQL、PostgreSQL、MongoDB等数据库
  • 从埋点到用户行为分析:ClkLog 如何帮助企业读懂用户
  • 函数的高级
  • C#实现OPC客户端
  • Gitee:数字化转型浪潮中的项目管理利器
  • 有什么指标可以判断手机是否降频
  • 实用指南:Linux内核kallsyms符号压缩与解压机制
  • 5G企业应用的七大场景与商业机遇
  • 2025 水泥墩源头厂家最新推荐排行榜:光伏 / 围挡 / 交通 / 防撞水泥墩多品类优选,实力品牌权威榜单
  • 高效数据结构 - 循环队列
  • 2025 年国内活塞杆厂家最新推荐排行榜:聚焦精密 / 不锈钢 / 油缸 / 气缸 / 45# 镀铬类产品,助力企业精准挑选可靠合作方
  • Day16
  • 数据类型,二元运算符,自动类型提升规则,关系运算,取余模运算
  • 股票技术面分析平台QuantMatrix深度解析 - 实践
  • 迷宫问题
  • WPF使用MediaCapture开发相机应用(四、相机录视频)
  • 链队
  • Gitee本土化战略深度解析:中国开发者生态的合规与效率革命
  • 2025年10月上海装修公司口碑榜:十强对比评测
  • 02-GPIO-铁头山羊STM32标准库新版笔记
  • 【多校支持、EI检索】第六届大数据与社会科学国际学术会议(ICBDSS 2025)
  • IDC iPaaS市场报告解读:独立厂商与云巨头的“双轨竞速”
  • 2025年10月仓储管理系统推荐:鸿链云仓领衔五大方案对比评测榜
  • 2025年10月电动叉车销售公司排行榜:五家主流服务商对比评测
  • 2025年口罩机厂家权威推荐榜:全自动口罩机器,全自动KN95口罩机源头企业综合评测与采购指南
  • 2025年包装机厂家权威推荐榜单:全自动包装机/包装生产线/非标定制机器与生产线专业选购指南
  • Timing Signoff 技术精要
  • Oracle故障处理:10G RAC srvctl注册实例正常,但是crs切不能管理实例
  • 杂题选做-2
  • 读书笔记:白话解读Oracle范围分区