当前位置: 首页 > news >正文

浅层 CNN 的瓶颈:用 LeNet 实测不同数据集

本实验旨在评估 LeNet 卷积神经网络 在不同复杂度数据集上的表现,以验证其泛化能力与局限性。我们选择了三个典型数据集:
• MNIST:28×28 灰度手写数字,任务简单、类内差异小;
• Fashion-MNIST:28×28 灰度服饰图像,较 MNIST 更复杂,类间区分难度提升;
• CIFAR-10:32×32 彩色自然图像,背景多样、类间相似度高,任务难度显著增加。

在相同网络结构与超参数条件下进行训练:
• 在 MNIST 上,LeNet 能快速收敛并在测试集上达到约 98% 的准确率,表明其足以胜任简单灰度图像分类任务。
• 在 Fashion-MNIST 上,准确率下降至 ≈90%,显示出模型对更复杂特征提取能力的不足。
• 在 CIFAR-10 上,测试准确率仅能维持在 ≈55%–65%,难以进一步提升,暴露出 LeNet 过浅、特征提取有限 的弊端。

实验结果表明,虽然 LeNet 是卷积神经网络的开山之作,但其浅层结构在面对复杂任务时表现出明显的局限性。这一现象也印证了深层网络(如 AlexNet、ResNet)的必要性:通过更深层数、更多卷积核以及正则化手段,才能在复杂数据集上获得显著性能提升。

# ======================== 0) Imports & Setup ========================
import os, time, random, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as pltdef seed_all(seed=42):random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False
seed_all(42)device = (torch.device("cuda") if torch.cuda.is_available()else torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()else torch.device("cpu")
)
print("Device:", device)# ======================== 1) LeNet (自适应不同输入大小/通道) ========================
class LeNet(nn.Module):"""经典 LeNet 结构:Conv(5)->AvgPool->Conv(5)->AvgPool->FC*3为了兼容不同输入尺寸(28/32)与通道数(1/3),在 __init__ 中通过一次前向推断动态计算扁平化后的维度。"""def __init__(self, in_channels=1, num_classes=10, img_size=28):super().__init__()self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)   # -> (H-4)x(W-4)self.pool1 = nn.AvgPool2d(2)                            # /2self.conv2 = nn.Conv2d(6, 16, kernel_size=5)            # -> (H-8)x(W-8)self.pool2 = nn.AvgPool2d(2)                            # /2# 动态计算 flatten 维度with torch.no_grad():dummy = torch.zeros(1, in_channels, img_size, img_size)x = F.relu(self.conv1(dummy))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)flat_dim = x.numel()self.fc1 = nn.Linear(flat_dim, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, num_classes)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))return self.fc3(x)# ======================== 2) 数据集定义(仅归一化,不做增强) ========================
# 均值/方差
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
FASHION_MEAN, FASHION_STD = (0.2860,), (0.3530,)
CIFAR_MEAN, CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)def build_loader(dataset_name, batch_size=128, root="./data"):if dataset_name == "MNIST":tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])train_ds = datasets.MNIST(root=root, train=True,  download=True, transform=tf)test_ds  = datasets.MNIST(root=root, train=False, download=True, transform=tf)in_channels, img_size = 1, 28elif dataset_name == "FashionMNIST":tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(FASHION_MEAN, FASHION_STD)])train_ds = datasets.FashionMNIST(root=root, train=True,  download=True, transform=tf)test_ds  = datasets.FashionMNIST(root=root, train=False, download=True, transform=tf)in_channels, img_size = 1, 28elif dataset_name == "CIFAR10":tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])train_ds = datasets.CIFAR10(root=root, train=True,  download=True, transform=tf)test_ds  = datasets.CIFAR10(root=root, train=False, download=True, transform=tf)in_channels, img_size = 3, 32else:raise ValueError("Unsupported dataset.")train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,num_workers=2, pin_memory=(device.type=="cuda"))test_loader  = DataLoader(test_ds, batch_size=256, shuffle=False,num_workers=2, pin_memory=(device.type=="cuda"))return train_loader, test_loader, in_channels, img_size# ======================== 3) 训练/评估函数 ========================
def train_one_epoch(model, loader, optimizer, criterion):model.train()total, correct, loss_sum = 0, 0, 0.0for xb, yb in loader:xb, yb = xb.to(device), yb.to(device)logits = model(xb)loss = criterion(logits, yb)optimizer.zero_grad(); loss.backward(); optimizer.step()loss_sum += loss.item() * xb.size(0)pred = logits.argmax(1)correct += (pred == yb).sum().item()total += xb.size(0)return loss_sum/total, correct/total@torch.no_grad()
def evaluate(model, loader, criterion):model.eval()total, correct, loss_sum = 0, 0, 0.0for xb, yb in loader:xb, yb = xb.to(device), yb.to(device)logits = model(xb)loss = criterion(logits, yb)loss_sum += loss.item() * xb.size(0)pred = logits.argmax(1)correct += (pred == yb).sum().item()total += xb.size(0)return loss_sum/total, correct/total# ======================== 4) 在三个数据集上跑同一 LeNet ========================
datasets_to_run = ["MNIST", "FashionMNIST", "CIFAR10"]
epochs = 5                     # 演示用;想更稳可设 10~15
lr = 1e-3summary = []                   # 记录最终结果
histories = {}                 # 记录曲线t_all = time.time()
for name in datasets_to_run:print(f"\n===== Training on {name} =====")train_loader, test_loader, C, S = build_loader(name)model = LeNet(in_channels=C, num_classes=10, img_size=S).to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)best_acc = 0.0history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}t0 = time.time()for ep in range(1, epochs+1):tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)te_loss, te_acc = evaluate(model, test_loader, criterion)history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)history["test_loss"].append(te_loss);  history["test_acc"].append(te_acc)best_acc = max(best_acc, te_acc)print(f"[{name}][{ep:02d}/{epochs}] train_loss={tr_loss:.4f} acc={tr_acc:.4f} | test_loss={te_loss:.4f} acc={te_acc:.4f} | best={best_acc:.4f}")dt = time.time() - t0print(f"[{name}] Done in {dt:.1f}s | Best Test Acc={best_acc:.4f}")histories[name] = historysummary.append((name, best_acc, dt))print(f"\nAll done in {time.time()-t_all:.1f}s")# ======================== 5) 可视化:不同数据集的学习曲线 ========================
cols = len(datasets_to_run)
plt.figure(figsize=(5*cols, 4))
for i, name in enumerate(datasets_to_run, 1):plt.subplot(1, cols, i)plt.plot(histories[name]["test_acc"], label=f"{name} test_acc")plt.plot(histories[name]["train_acc"], label=f"{name} train_acc", linestyle="--")plt.title(f"LeNet on {name}"); plt.xlabel("epoch"); plt.ylabel("accuracy"); plt.legend()
plt.tight_layout(); plt.show()# ======================== 6) 总结表格 ========================
print("\n=== Summary (LeNet, same hyperparams, no augmentation) ===")
for name, acc, dt in summary:print(f"{name:12s} | Best Test Acc: {acc:.4f} | Time: {dt:.1f}s")

image

🔹 实验结论

通过在 MNIST、Fashion-MNIST、CIFAR-10 三个数据集上的对比训练实验,我们观察到以下现象:
1. MNIST:LeNet 表现优异,能快速收敛并在短时间内达到 ≈98% 的准确率。这说明对于低分辨率、单通道、模式单一的任务,LeNet 的浅层卷积结构足以胜任。
2. Fashion-MNIST:准确率下降到 ≈90%,相比 MNIST 难度提升明显。说明在任务复杂度稍高的情况下,LeNet 的浅层网络开始显现出特征提取能力不足的问题。
3. CIFAR-10:测试准确率仅能维持在 55%–65% 之间,表现远低于现代深层 CNN。原因在于:
• 数据更复杂(彩色图像、类间相似度高),浅层卷积难以学习足够丰富的特征;
• 平均池化导致信息丢失,缺乏更强的特征保持能力;
• 没有 ReLU/Dropout/BatchNorm 等优化手段,容易过拟合或收敛停滞。

结合历史发展与实验表现,下一步我们将探索 CNN 在以下几个关键方向上的优化:
1. 更深的卷积层
• 从 LeNet 的 2 层卷积 → AlexNet 的 5 层卷积开始,逐步增加网络深度以提升特征提取能力。
• 引入更多卷积核,使模型能够捕捉更丰富的空间特征。
2. 更强的激活与正则化
• ReLU 激活函数:替代 Sigmoid/Tanh,解决梯度消失问题,加速收敛。
• Dropout:缓解过拟合,提高泛化性能。
• Batch Normalization(稍后出现):稳定训练,加快收敛。
3. 更合理的池化与卷积组合
• 用 Max Pooling 替代 Average Pooling,保留关键特征。
• 在深层结构中适当减少池化层,引入更小的卷积核(3×3),提高特征提取的细粒度。

http://www.hskmm.com/?act=detail&tid=29083

相关文章:

  • 文本派 - 停服公告 2025
  • lCode题库
  • Arista cEOS 4.35.0F 发布 - 针对云原生环境设计的容器化网络操作系统
  • Arista vEOS 4.35.0F 发布 - 虚拟化的数据中心和云网络可扩展操作系统
  • 因果机器学习的技术发展与挑战
  • CSP-S 考前集训
  • Arista EOS 4.35.0F 发布 - 适用于下一代数据中心和云网络的可扩展操作系统
  • 20251011 总结
  • 上课讲的部分 qoj 题记录
  • var与let
  • CSP-S 第二轮集训资料 **总结 + 专题细分精讲**_from_黄老师
  • AI元人文:迈向正负价值统一的文明架构
  • CSP-S 第二轮集训资料 **总结 + 专题细分精讲**。
  • 对抗训练提升产品搜索技术解析
  • Ubuntu Linux双网口主机实现在校园网环境下的网络共享
  • C# Avalonia 16- Animation- ExpandElement
  • DshanPI-A1 RK3576 armbian远程桌面
  • Docker安装MQTT
  • Ubuntu Linux双网卡实现在校园网环境下的网络共享
  • PVE8.x仅克隆虚拟机配置
  • 常用的sql语句
  • SQL常用语句分类及示例
  • 台式机主板上的电池要更换啦
  • 微信小程序 app.js中onLaunch中方法执行完毕后再执行index首页数据请求
  • 轻量服务器Lighthouse + 1Panel 部署.NET 8 Web应用
  • bash alias 多引号问题
  • 关于近期调研各类游戏开发引擎的一些感想
  • Electron38-Vue3OS客户端OS系统|vite7+electron38+arco桌面os后台管理
  • 终于在vim中用上了molokai的炫酷色彩配置了(゚∀゚)
  • 我是如何在Vim8.1中安装好的NERDTree插件的