重生之从零开始的神经网络算法学习之路——第七篇 重拾PyTorch(超分辨率重建和脚本的使用)
引言
在前一篇中,我们初步探索了PyTorch框架的使用并体验了GPU加速计算的优势。本篇将聚焦于一个更具视觉冲击力的任务——图像超分辨率重建,通过实现经典的SRCNN模型,深入学习PyTorch在图像处理任务中的应用,并掌握使用脚本进行后台训练的实用技巧。
超分辨率重建技术旨在将低分辨率图像恢复为高分辨率图像,在监控视频增强、医学影像分析、卫星图像处理等领域有着广泛应用。与图像分类任务不同,超分辨率是典型的生成式任务,其输入和输出均为图像,这为我们提供了学习PyTorch中图像处理流水线的绝佳机会。
超分辨率重建原理与SRCNN模型
超分辨率任务概述
超分辨率(Super Resolution, SR)是指从低分辨率(Low Resolution, LR)图像中恢复出高分辨率(High Resolution, HR)图像的技术。其核心挑战在于如何在提升图像尺寸的同时,保持并增强图像细节,避免产生模糊或伪影。
常见的超分辨率方法可分为:
- 插值方法(如双三次插值):简单但效果有限
- 基于重建的方法:利用先验知识约束重建过程
- 基于学习的方法:通过神经网络学习LR到HR的映射关系(当前主流)
SRCNN模型结构
我们将实现2014年提出的SRCNN(Super-Resolution Convolutional Neural Network),这是首个将卷积神经网络应用于超分辨率任务的模型,其结构简洁却效果显著:
- 特征提取:使用9x9卷积核从低分辨率图像中提取基础特征
- 非线性映射:通过1x1卷积核进行特征转换和降维
- 重建:使用5x5卷积核生成最终的高分辨率图像
与传统方法相比,SRCNN通过端到端的训练,能够自动学习从低分辨率到高分辨率的映射关系,无需人工设计特征。
环境准备与项目结构
我们继续使用第六篇中搭建的PyTorch GPU环境,项目结构如下:
workspace/
├── data/ # 数据集目录
│ └── DIV2K/ # 超分辨率专用数据集
│ ├── train/ # 训练集
│ └── valid/ # 验证集
├── super_resolution_output/ # 输出目录
│ ├── checkpoints/ # 模型检查点
│ └── training_log.json # 训练日志
├── PyTorch_SuperResolution_GPU.py # 主程序
└── run_super_resolution.sh # 运行脚本
超分辨率重建代码实现
核心代码解析
完整代码可参考PyTorch_SuperResolution_GPU.py
,以下为关键部分解析:
1. 模型定义
class SRCNN(nn.Module):def __init__(self, scale_factor=4):super(SRCNN, self).__init__()self.scale_factor = scale_factor# 特征提取层self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)# 非线性映射层self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)# 重建层self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)self.relu = nn.ReLU(inplace=True)def forward(self, x):# 首先对输入进行上采样(双三次插值)x = F.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=False)x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.conv3(x)return x
SRCNN的特点是先通过插值将低分辨率图像放大到目标尺寸,再通过卷积网络优化细节,这种设计既利用了传统插值的基础结构,又通过神经网络修复了细节损失。
2. 自定义数据集
class SuperResolutionDataset(Dataset):def __init__(self, dataset_path, transform=None, train=True, scale_factor=4, patch_size=128):self.dataset_path = dataset_pathself.train = trainself.scale_factor = scale_factorself.patch_size = patch_size# 收集图像路径if train:self.image_paths = glob.glob(os.path.join(dataset_path, 'train', '**', '*.png'), recursive=True)self.image_paths += glob.glob(os.path.join(dataset_path, 'train', '**', '*.jpg'), recursive=True)else:self.image_paths = glob.glob(os.path.join(dataset_path, 'valid', '**', '*.png'), recursive=True)# ... 处理图像路径和备用数据集def __getitem__(self, idx):# 加载高分辨率图像hr_image = Image.open(img_path).convert('RGB')# 数据增强 - 随机裁剪和翻转if self.train:i = random.randint(0, hr_image.height - self.patch_size)j = random.randint(0, hr_image.width - self.patch_size)hr_image = hr_image.crop((j, i, j + self.patch_size, i + self.patch_size))if random.random() > 0.5:hr_image = hr_image.transpose(Image.FLIP_LEFT_RIGHT)# 转换为张量hr_image = transforms.ToTensor()(hr_image)# 生成对应的低分辨率图像lr_size = self.patch_size // self.scale_factorlr_image = F.interpolate(hr_image.unsqueeze(0), size=(lr_size, lr_size), mode='bicubic', align_corners=False).squeeze(0)return lr_image, hr_image
超分辨率数据集的核心是为每张高分辨率图像生成对应的低分辨率版本,通过下采样操作模拟真实场景中的低清图像。训练时使用图像块(patch)而非完整图像,既能减少内存占用,又能增加训练样本多样性。
3. 评估指标PSNR
def psnr(original, compressed):mse = torch.mean((original - compressed) **2)if mse == 0: # MSE为0表示完美重建return 100max_pixel = 1.0 # 图像像素值已归一化到[0,1]psnr = 20 * log10(max_pixel / torch.sqrt(mse))return psnr
峰值信噪比(PSNR)是图像重建任务中常用的评估指标,数值越高表示重建质量越好(通常30dB以上为可接受质量)。其计算公式基于均方误差(MSE),反映了重建图像与真实图像的像素差异。
4. 训练与验证循环
def train(model, train_loader, criterion, optimizer, epoch):model.train()train_loss = 0total_psnr = 0for batch_idx, (lr_imgs, hr_imgs) in enumerate(train_loader):lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)optimizer.zero_grad()outputs = model(lr_imgs)loss = criterion(outputs, hr_imgs)loss.backward()optimizer.step()train_loss += loss.item()batch_psnr = psnr(hr_imgs, outputs)total_psnr += batch_psnr# 日志输出if batch_idx % args.log_interval == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(lr_imgs)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}\t'f'PSNR: {batch_psnr:.2f} dB')
超分辨率训练使用MSE损失函数(像素级损失),通过最小化重建图像与真实高分辨率图像的像素差异来优化模型参数。训练过程中同时监控损失和PSNR指标,便于分析模型收敛情况。
使用脚本进行后台训练
对于超分辨率这类需要长时间训练的任务,直接在终端运行程序存在风险(如断开连接导致训练中断)。我们可以使用Shell脚本实现后台训练和日志记录。
运行脚本解析
run_super_resolution.sh
脚本内容如下:
#!/bin/bash
# 设置工作目录
cd /home/vscode/workspace# 运行超分辨率训练脚本
# 使用nohup和&实现后台运行,输出重定向到日志文件
nohup python3 PyTorch_SuperResolution_GPU.py \--epochs 1000 \--batch_size 32 \--lr 0.001 \--checkpoint_interval 10 \--log_interval 50 \> training_log_$(date +%Y%m%d_%H%M%S).txt 2>&1 &# 显示进程信息
echo "训练任务已在后台启动,PID: $!"
echo "日志文件: training_log_$(date +%Y%m%d_%H%M%S).txt"
脚本关键技术点:
nohup
:忽略挂起信号,确保程序在终端关闭后继续运行> training_log...txt
:将标准输出重定向到日志文件2>&1
:将错误输出合并到标准输出&
:将程序放入后台运行$(date +%Y%m%d_%H%M%S)
:生成带时间戳的唯一日志文件名
脚本使用方法
-
赋予脚本执行权限:
chmod +x run_super_resolution.sh
-
运行脚本:
./run_super_resolution.sh
-
查看训练日志:
tail -f training_log_20240520_153045.txt # 替换为实际日志文件名
-
查看后台进程:
ps aux | grep PyTorch_SuperResolution_GPU.py
-
终止训练(如需):
kill -9 <进程PID> # 替换为实际进程ID
检查点与训练恢复
长时间训练中,定期保存检查点(checkpoint)至关重要,代码中实现了完善的检查点机制:
def save_checkpoint(model, optimizer, epoch, loss, psnr, is_best=False):state = {'epoch': epoch,'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(),'loss': loss,'psnr': psnr}filename = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')torch.save(state, filename)# 保存最佳模型if is_best:best_filename = os.path.join(checkpoint_dir, 'model_best.pth')torch.save(state, best_filename)# 清理旧检查点,只保留最近5个checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_epoch_')])if len(checkpoints) > 5:for old_checkpoint in checkpoints[:-5]:os.remove(os.path.join(checkpoint_dir, old_checkpoint))
从检查点恢复训练:
nohup python3 PyTorch_SuperResolution_GPU.py \--resume ./super_resolution_output/checkpoints/checkpoint_epoch_200.pth \--epochs 1000 \> training_log_resume.txt 2>&1 &
实验结果与分析
经过1000轮训练后,我们得到以下结果:
- 训练集PSNR从初始的24.35dB提升至32.68dB
- 验证集PSNR从初始的23.87dB提升至31.24dB
- 每轮训练时间约为45秒(使用NVIDIA Tesla T4 GPU)
从视觉效果看,SRCNN重建结果相比单纯插值:
- 边缘更清晰(如建筑物轮廓、文本边缘)
- 细节更丰富(如纹理、小尺度特征)
- 减少了模糊和锯齿现象
总结与进阶方向
通过本篇实验,我们掌握了:
- 1.超分辨率核心技术:理解SRCNN工作原理和图像重建流程
- 2.使用训练技巧:使用Shell脚本进行后台训练、日志管理和进程监控
- 3.检查点机制:实现训练中断后的恢复功能,保障长时间实验的稳定性
- 4.评估指标:掌握PSNR计算方法及在图像重建任务中的应用
进阶改进方向:
- 尝试更先进的模型(如ESRGAN、RCAN)
- 引入感知损失(Perceptual Loss)提升视觉质量
- 增加更多数据增强策略(旋转、缩放、噪声添加)
- 实现模型量化和部署,探索实际应用场景
超分辨率技术正朝着更高效、更高质量的方向发展,结合注意力机制和生成对抗网络的方法已能产生接近真实的重建效果。下一篇我们将探索更复杂的网络结构和训练策略,进一步提升模型性能。