当前位置: 首页 > news >正文

从零实现 VGG-16

博客地址:https://www.cnblogs.com/zylyehuo/

参考视频:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

需要用到的库

  • torch

安装有问题可参考网上教程

pip install torch
  • protobuf
pip install protobuf

model.py

import torch
from torch import nnclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model1 = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model1(x)return xif __name__ == '__main__':tudui = Tudui()print(tudui)input = torch.ones((64, 3, 32, 32))output = tudui(input)print(output.shape)

image

train.py

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import timefrom model import *device = torch.device("cuda" if torch.cuda.is_available() else "cpu")train_data = torchvision.datasets.CIFAR10("./dataset", train=True, download=True,transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, download=True,transform=torchvision.transforms.ToTensor())train_data_size = len(train_data)
test_data_size = len(test_data)print(f"训练数据集的长度为:{train_data_size}")
print(f"测试数据集的长度为:{test_data_size}")train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)tudui = Tudui()
tudui.to(device)loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)total_train_step = 0
total_test_step = 0
epoch = 20writer = SummaryWriter("./logs_train")for i in range(epoch):start_time = time.time()tudui.train()print(f"--------第 {i + 1} 轮训练开始--------")for data in train_dataloader:optimizer.zero_grad()imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = tudui(imgs)loss = loss_fn(outputs, targets)loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(f"训练次数:{total_train_step}, Loss:{loss.item()}, 训练所花时间:{end_time - start_time}")writer.add_scalar("train_loss", loss.item(), total_train_step)tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracytotal_test_step += 1print(f"整体测试集上的 Loss:{total_test_loss}")print(f"整体测试集上的正确率:{total_accuracy / test_data_size}")writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)if epoch % 10 == 0:# torch.save(tudui, f"tudui_{i}.pth")torch.save(tudui.state_dict(), f"tudui_{i}.pth")print("模型已保存")writer.close()

CPU 效果

74ea633712986de3764c5599f00f3474

GPU 效果

7d4f24f1077c80dff3eaae4c68711b34
5c30a95b2fdfeb41742a434b145b8bbe

test.py

dog

import torch
import torchvision
from model import *
from PIL import Imagetest_data = torchvision.datasets.CIFAR10("./dataset", train=False, download=True,transform=torchvision.transforms.ToTensor())
print(test_data.classes)image_path = "./imgs/dog.png"
image = Image.open(image_path)
image = image.convert("RGB")
# print(image)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor()])
image = transform(image)
image = torch.reshape(image, (1, 3, 32, 32))model = Tudui()
model.load_state_dict(torch.load("./tudui_19.pth"))
# print(model)model.eval()
with torch.no_grad():output = model(image)print(output)
print(output.argmax(1))

db45ce06e1fe1806d24f6c71a9cda3c1

补充

如果电脑没有 GPU,可以借助第三方平台使用 GPU

Google-Colab

e0c6102a7bc511026b60713257110227

http://www.hskmm.com/?act=detail&tid=29234

相关文章:

  • WPF上位机入门教程
  • 潘院士高瞻远瞩:三大趋势勾勒中国AI发展路径,元人文构想恰逢其时
  • 2025家居MES厂家最新权威推荐榜:智能制造与高效管理深度
  • 开源 C# 快速构建(七)通讯--串口
  • 2025新能源冲压件厂家权威推荐榜:技术革新与品质保障深度解
  • 浮点数的相等性判断
  • ubuntu18
  • 2025国庆dp
  • 2025数控锯床厂家权威推荐榜:精密加工与高效生产口碑之选
  • FFmpeg开发笔记(八十二)使用国产直播服务器smart_rtmpd执行推流操作
  • 实验室装修厂家最新权威推荐榜:专业设计与施工品质深度解析
  • 生成式AI在红队测试中的应用:构建自动化工具
  • 杂题 10月份
  • 2025年UV LED点光源厂家权威推荐榜:精准固化与高效能
  • NVR软件快速对比表
  • 20232410 2025-2026-1 《网络与系统攻防技术》 实验一实验报告
  • 在Windows系统打造基于ConEmu的命令行工具环境
  • 2025工矿灯厂家最新权威推荐榜:工业照明技术革新与品质保障
  • ZR 2025 十一集训 Day 1
  • 2025广东粉末厂家最新权威推荐榜:技术实力与市场口碑深度解
  • [KaibaMath]1007 关于数列极限存在的唯一性证明
  • 20232418-郭俊廷-实验一-逆向及Bof基础实践
  • 十月模拟赛
  • 2025年成都软件开发机构最新推荐排行榜,涵CRM,物联网,运维,仓储,人力多系统,技术实力与市场口碑深度解析
  • 2025硅藻土定制厂家权威推荐榜:专业生产与深度定制实力解析
  • 变量、函数命名方式
  • 汉文博士 0.7 版:支持统一码 17.0,新增字体分析器,优化词典编译器
  • 2025燃气采暖锅炉厂家权威推荐榜:高效节能与品质保障口碑之
  • 【python】python进阶——Redis模块 - 教程
  • 2025 年 10 月桥架厂家最新推荐:专业制造与品牌保障口碑之选!