Reborn: Learning Neural Network Algorithms from Scratch, Part 8: GPU Training Practice with Large Datasets and Complex Models
Introduction
In the previous post, we implemented a basic SRCNN super-resolution model and learned how to run training jobs in the background. This post scales the experiments up: we introduce larger datasets, implement a more complex network architecture, and optimize the GPU training strategy to tackle a more challenging image reconstruction task. Along the way, we will dig into the key techniques and engineering details behind large-scale deep learning experiments.
Project Directory Structure
A well-organized project structure makes the code easier to manage and to collaborate on. Here is the full directory layout of our super-resolution project:
```
esrgan-super-resolution/
│
├── src/                          # Source code
│   ├── __init__.py
│   ├── models/                   # Model definitions
│   │   ├── __init__.py
│   │   ├── esrgan.py             # ESRGAN generator
│   │   └── discriminator.py      # Discriminator
│   │
│   ├── data/                     # Data handling
│   │   ├── __init__.py
│   │   ├── datasets.py           # Dataset classes
│   │   ├── downloader.py         # Dataset download utilities
│   │   └── transforms.py         # Augmentation and transforms
│   │
│   ├── losses/                   # Loss functions
│   │   ├── __init__.py
│   │   ├── content_loss.py       # Content loss
│   │   └── gan_loss.py           # GAN loss
│   │
│   ├── utils/                    # Utilities
│   │   ├── __init__.py
│   │   ├── metrics.py            # Evaluation metrics (PSNR, etc.)
│   │   ├── logger.py             # Logging utilities
│   │   └── helpers.py            # Helper functions
│   │
│   └── training/                 # Training code
│       ├── __init__.py
│       ├── trainer.py            # Trainer class
│       └── validator.py          # Validator
│
├── configs/                      # Configuration files
│   ├── base_config.yaml          # Base configuration
│   └── esrgan_config.yaml        # ESRGAN-specific configuration
│
├── scripts/                      # Scripts
│   ├── train_esrgan.py           # Training script
│   ├── evaluate.py               # Evaluation script
│   └── predict.py                # Inference script
│
├── data/                         # Data
│   ├── raw/                      # Raw data
│   │   ├── DIV2K/
│   │   └── Flickr2K/
│   └── processed/                # Processed data
│
├── checkpoints/                  # Model checkpoints
│   ├── generator/
│   └── discriminator/
│
├── logs/                         # Log files
│   └── tensorboard/              # TensorBoard logs
│
├── results/                      # Outputs
│   ├── comparisons/              # Image comparisons
│   └── samples/                  # Generated samples
│
├── docs/                         # Documentation
│   ├── setup.md                  # Environment setup
│   └── usage.md                  # Usage guide
│
├── main.py                       # Main entry point
├── requirements.txt              # Dependencies
└── README.md                     # Project readme
```
Acquiring and Processing Large Datasets
Automatic Download and Extraction
To improve model performance, we train on two large datasets, DIV2K and Flickr2K. Below is the streamlined download-and-extract pipeline (corresponding to src/data/downloader.py):
```python
import os
import wget
import zipfile
import tarfile
from tqdm import tqdm

# Dataset configuration
DATASETS = {
    "DIV2K": {
        "train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
        "valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
    },
    "Flickr2K": {
        "url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"
    }
}

def progress_bar(current, total, width=80):
    """Custom progress bar callback for wget."""
    progress = int(width * current / total)
    bar = '=' * progress + '-' * (width - progress)
    print(f'[{bar}] {current / total * 100:.1f}%', end='\r')

def download_dataset(url, save_dir):
    """Download a dataset archive, showing a progress bar."""
    os.makedirs(save_dir, exist_ok=True)
    filename = url.split('/')[-1]
    file_path = os.path.join(save_dir, filename)
    if not os.path.exists(file_path):
        print(f"Downloading {filename}...")
        wget.download(url, file_path, bar=progress_bar)
        print("\nDownload complete")
    return file_path

def extract_archive(file_path, extract_dir):
    """Extract a downloaded archive."""
    print(f"Extracting {file_path} to {extract_dir}...")
    os.makedirs(extract_dir, exist_ok=True)
    if file_path.endswith('.zip'):
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            # List all members and show extraction progress with tqdm
            files = zip_ref.namelist()
            for file in tqdm(files, desc="Extracting"):
                zip_ref.extract(file, extract_dir)
    elif file_path.endswith(('.tar', '.tar.gz')):
        with tarfile.open(file_path, 'r') as tar_ref:
            members = tar_ref.getmembers()
            for member in tqdm(members, desc="Extracting"):
                tar_ref.extract(member, extract_dir)

def prepare_datasets(base_dir):
    """Download and extract all datasets."""
    # DIV2K: train and validation splits
    div2k_dir = os.path.join(base_dir, "DIV2K")
    for split, url in DATASETS["DIV2K"].items():
        file_path = download_dataset(url, div2k_dir)
        extract_archive(file_path, os.path.join(div2k_dir, split))
    # Flickr2K: a single tar archive
    flickr_dir = os.path.join(base_dir, "Flickr2K")
    flickr_url = DATASETS["Flickr2K"]["url"]
    file_path = download_dataset(flickr_url, flickr_dir)
    extract_archive(file_path, flickr_dir)
    print("All datasets are ready")

if __name__ == "__main__":
    # Run this file directly to download the data
    prepare_datasets(os.path.join(os.path.dirname(__file__), '../../data/raw'))
```
An Optimized Data Loader
Corresponding to src/data/datasets.py, a loader built to handle large datasets efficiently:
```python
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision.transforms as transforms

class SuperResolutionDataset(Dataset):
    """Base dataset class for super-resolution."""

    def __init__(self, root_dir, scale_factor=4, patch_size=128, train=True):
        self.root_dir = root_dir
        self.scale_factor = scale_factor
        self.patch_size = patch_size
        self.train = train
        # Collect all image paths recursively
        self.image_paths = []
        for dirpath, _, filenames in os.walk(root_dir):
            for fname in filenames:
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(dirpath, fname))
        # Tensor conversion and normalization to [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the high-resolution image
        img_path = self.image_paths[idx]
        hr_img = cv2.imread(img_path)
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
        # Generate the low-resolution counterpart by bicubic downsampling
        h, w = hr_img.shape[:2]
        lr_size = (w // self.scale_factor, h // self.scale_factor)
        lr_img = cv2.resize(hr_img, lr_size, interpolation=cv2.INTER_CUBIC)

        if self.train:
            # Sample the crop position on the LR grid, then scale it up,
            # so the HR and LR patches stay exactly aligned
            lr_patch_size = self.patch_size // self.scale_factor
            lr_h, lr_w = lr_img.shape[:2]
            lr_x = np.random.randint(0, lr_w - lr_patch_size + 1)
            lr_y = np.random.randint(0, lr_h - lr_patch_size + 1)
            hr_x = lr_x * self.scale_factor
            hr_y = lr_y * self.scale_factor
            lr_patch = lr_img[lr_y:lr_y + lr_patch_size, lr_x:lr_x + lr_patch_size]
            hr_patch = hr_img[hr_y:hr_y + self.patch_size, hr_x:hr_x + self.patch_size]
            # Random horizontal flip as data augmentation
            if np.random.random() > 0.5:
                hr_patch = cv2.flip(hr_patch, 1)
                lr_patch = cv2.flip(lr_patch, 1)
            return self.transform(lr_patch), self.transform(hr_patch)
        else:
            # Use the full image at validation time
            return self.transform(lr_img), self.transform(hr_img)

class CombinedDataset(ConcatDataset):
    """Wrapper that concatenates several datasets."""

    def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True):
        datasets = []
        for path in dataset_paths:
            datasets.append(SuperResolutionDataset(
                path,
                train=train,
                scale_factor=scale_factor,
                patch_size=patch_size
            ))
        super().__init__(datasets)

def create_optimized_dataloaders(batch_size, dataset_paths, scale_factor=4, patch_size=128,
                                 num_workers=8, pin_memory=True):
    """Create optimized train/validation data loaders."""
    # Training set: all datasets combined
    train_dataset = CombinedDataset(
        dataset_paths,
        scale_factor=scale_factor,
        patch_size=patch_size,
        train=True
    )
    # Validation set: the DIV2K validation split
    # (matches the layout created by downloader.py)
    div2k_root = [p for p in dataset_paths if 'DIV2K' in p][0]
    val_dataset = SuperResolutionDataset(
        os.path.join(div2k_root, 'valid'),
        train=False,
        scale_factor=scale_factor,
        patch_size=patch_size
    )
    # Multi-process loading with prefetching
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=2,        # prefetch upcoming batches per worker
        persistent_workers=True   # keep worker processes alive between epochs
    )
    # Full-size validation images vary in shape, so batch them one at a time
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    return train_loader, val_loader
```
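Before launching a long run, it is worth a quick smoke test of the loaders. A minimal sketch, assuming the datasets have already been downloaded to data/raw:

```python
from src.data.datasets import create_optimized_dataloaders

train_loader, val_loader = create_optimized_dataloaders(
    batch_size=4,
    dataset_paths=["data/raw/DIV2K", "data/raw/Flickr2K"],
    scale_factor=4,
    patch_size=128,
    num_workers=2
)
# Pull one batch and verify the LR patch is 1/4 the HR size in each dimension
lr_batch, hr_batch = next(iter(train_loader))
print(lr_batch.shape, hr_batch.shape)
# expected: torch.Size([4, 3, 32, 32]) torch.Size([4, 3, 128, 128])
```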
Implementing a Complex Model: ESRGAN
Corresponding to src/models/esrgan.py, the ESRGAN generator:
```python
import math

import torch
import torch.nn as nn

class ResidualDenseBlock(nn.Module):
    """Residual dense block, the core building block of ESRGAN."""

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming initialization matched to the LeakyReLU activation."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in',
                                        nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Scaled residual connection
        return x5 * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-residual dense block."""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Scaled residual connection
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """Backbone of the ESRGAN generator (the RRDB network)."""

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                 num_grow_ch=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Structure: first conv + RRDB trunk + upsampling + output conv
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.upsampler = self._make_upsampler(num_feat, scale)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

    def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):
        blocks = []
        for _ in range(num_block):
            blocks.append(RRDB(num_feat, num_grow_ch))
        return nn.Sequential(*blocks)

    def _make_upsampler(self, num_feat, scale):
        # Upsampling via PixelShuffle, one 2x stage per factor of two
        upsampler = []
        for _ in range(int(math.log2(scale))):
            upsampler.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))
            upsampler.append(nn.PixelShuffle(2))
        return nn.Sequential(*upsampler)

    def forward(self, x):
        feat = self.conv_first(x)
        body_feat = self.conv_body(self.body(feat))
        # Global residual connection around the RRDB trunk
        feat = feat + body_feat
        out = self.conv_last(self.upsampler(feat))
        return out

# The ESRGAN generator inherits the RRDB network to keep the interface uniform
class ESRGAN(RRDBNet):
    """ESRGAN generator (structurally identical to RRDBNet)."""

    def __init__(self, scale_factor=4, **kwargs):
        super(ESRGAN, self).__init__(scale=scale_factor, **kwargs)
```
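A quick sanity check of the generator (a minimal sketch): feed a dummy LR tensor through a 4x model and confirm the output is 4x larger in each spatial dimension.

```python
import torch
from src.models.esrgan import ESRGAN

model = ESRGAN(scale_factor=4)
with torch.no_grad():
    lr = torch.randn(1, 3, 48, 48)  # dummy low-resolution input
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 192, 192])
```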
Discriminator Implementation
Corresponding to src/models/discriminator.py:
```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """ESRGAN discriminator."""

    def __init__(self, num_in_ch=3, num_feat=64):
        super(Discriminator, self).__init__()
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Strided feature extraction stack
        self.features = nn.Sequential(
            nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feat),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feat * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feat * 4),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True)
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_last = nn.Conv2d(num_feat * 8, 1, 1, 1, 0)

    def forward(self, x):
        x = self.lrelu(self.conv_first(x))
        x = self.features(x)
        x = self.avg_pool(x)
        # Raw logits; the GAN loss applies the sigmoid internally
        x = self.conv_last(x)
        return x
```
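A matching shape check for the discriminator (a sketch): thanks to the adaptive pooling, any input size collapses to a single logit per image.

```python
import torch
from src.models.discriminator import Discriminator

d = Discriminator()
with torch.no_grad():
    logits = d(torch.randn(2, 3, 192, 192))  # two fake 192x192 HR patches
print(logits.shape)  # expected: torch.Size([2, 1, 1, 1])
```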
Adversarial Training Strategy
Loss Function Implementation
Corresponding to src/losses/content_loss.py:
```python
import torch
import torch.nn as nn
from torchvision import models, transforms

class ContentLoss(nn.Module):
    """Content loss based on a VGG feature extractor."""

    def __init__(self, device):
        super(ContentLoss, self).__init__()
        # Pretrained VGG19 truncated at conv5_4, before its activation
        # (the feature layer used by ESRGAN)
        vgg = models.vgg19(pretrained=True).features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.criterion = nn.L1Loss()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, sr, hr):
        # Our images are normalized to [-1, 1]; map them back to [0, 1]
        # before applying the ImageNet statistics VGG was trained with
        sr = (sr + 1.0) / 2.0
        hr = (hr + 1.0) / 2.0
        sr_norm = self.normalize(sr)
        hr_norm = self.normalize(hr)
        # Compare feature maps rather than raw pixels
        sr_feat = self.vgg(sr_norm)
        hr_feat = self.vgg(hr_norm)
        return self.criterion(sr_feat, hr_feat)
```
Corresponding to src/losses/gan_loss.py:
```python
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """GAN loss supporting vanilla and least-squares variants."""

    def __init__(self, gan_type='vanilla', real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        if self.gan_type == 'vanilla':
            # Expects raw logits from the discriminator
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError(f"GAN type {self.gan_type} is not implemented")

    def forward(self, pred, target_is_real):
        target_val = self.real_label_val if target_is_real else self.fake_label_val
        # Build a label tensor matching the prediction's shape, dtype, and device
        target = torch.full_like(pred, target_val)
        return self.loss(pred, target)
```
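A toy illustration of the `target_is_real` flag (values are random, purely for demonstration): the same logits are scored against all-ones labels for "real" and all-zeros labels for "fake".

```python
import torch
from src.losses.gan_loss import GANLoss

criterion = GANLoss(gan_type='vanilla')
logits = torch.randn(4, 1)            # raw discriminator outputs (logits)
loss_real = criterion(logits, True)   # labels filled with real_label_val (1.0)
loss_fake = criterion(logits, False)  # labels filled with fake_label_val (0.0)
print(loss_real.item(), loss_fake.item())
```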
GPU Training Optimization Tricks
Mixed-Precision Training
The training implementation, corresponding to src/training/trainer.py. Note that the generator is updated while the discriminator's parameters are frozen, so the gradients accumulated for the discriminator stay uncontaminated across the accumulation window, and losses are divided by the number of accumulation steps so the effective gradient is an average:
```python
import torch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from src.utils.metrics import psnr

class ESRGANTrainer:
    def __init__(self, generator, discriminator,
                 content_criterion, gan_criterion,
                 g_optimizer, d_optimizer,
                 device, log_interval=10):
        self.generator = generator
        self.discriminator = discriminator
        self.content_criterion = content_criterion
        self.gan_criterion = gan_criterion
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.device = device
        self.log_interval = log_interval
        # Shared gradient scaler for mixed-precision training
        self.scaler = GradScaler(enabled=True)

    def train_epoch(self, train_loader, epoch, grad_accum_steps=4):
        """Train for one epoch."""
        self.generator.train()
        self.discriminator.train()

        total_gen_loss = 0.0
        total_dis_loss = 0.0
        total_psnr = 0.0

        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
        for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):
            lr_imgs = lr_imgs.to(self.device)
            hr_imgs = hr_imgs.to(self.device)

            # ---- Generator step (freeze D so its gradients stay clean) ----
            for p in self.discriminator.parameters():
                p.requires_grad = False
            with autocast():
                sr_imgs = self.generator(lr_imgs)
                # Generator loss = content loss + adversarial loss,
                # with a heavier weight on the content term
                content_loss = self.content_criterion(sr_imgs, hr_imgs)
                fake_pred = self.discriminator(sr_imgs)
                gan_loss = self.gan_criterion(fake_pred, True)
                gen_loss = content_loss * 0.01 + gan_loss * 0.005
                # PSNR for monitoring
                batch_psnr = psnr(hr_imgs, sr_imgs)
            # Scale down so the accumulated gradient averages over the window
            self.scaler.scale(gen_loss / grad_accum_steps).backward()

            # ---- Discriminator step ----
            for p in self.discriminator.parameters():
                p.requires_grad = True
            with autocast():
                real_pred = self.discriminator(hr_imgs)
                # detach() keeps discriminator gradients out of the generator
                fake_pred = self.discriminator(sr_imgs.detach())
                real_loss = self.gan_criterion(real_pred, True)
                fake_loss = self.gan_criterion(fake_pred, False)
                dis_loss = (real_loss + fake_loss) * 0.5
            self.scaler.scale(dis_loss / grad_accum_steps).backward()

            # ---- Gradient accumulation: step every grad_accum_steps batches ----
            if (batch_idx + 1) % grad_accum_steps == 0:
                self.scaler.step(self.g_optimizer)
                self.scaler.step(self.d_optimizer)
                self.scaler.update()
                self.g_optimizer.zero_grad()
                self.d_optimizer.zero_grad()

            # Accumulate statistics
            total_gen_loss += gen_loss.item()
            total_dis_loss += dis_loss.item()
            total_psnr += float(batch_psnr)

            # Progress logging
            if batch_idx % self.log_interval == 0:
                avg_gen_loss = total_gen_loss / (batch_idx + 1)
                avg_dis_loss = total_dis_loss / (batch_idx + 1)
                avg_psnr = total_psnr / (batch_idx + 1)
                pbar.set_postfix({
                    'gen_loss': f'{avg_gen_loss:.4f}',
                    'dis_loss': f'{avg_dis_loss:.4f}',
                    'psnr': f'{avg_psnr:.2f}'
                })

        return (total_gen_loss / len(train_loader),
                total_dis_loss / len(train_loader),
                total_psnr / len(train_loader))
```
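The trainer imports `psnr` from src/utils/metrics.py, which this post does not list. A minimal sketch of what that metric could look like, assuming inputs normalized to [-1, 1] as in our dataset transform:

```python
import torch

def psnr(hr, sr, max_val=1.0):
    """Peak signal-to-noise ratio between two image batches.

    Assumes inputs in [-1, 1]; they are mapped back to [0, 1]
    before computing the mean squared error.
    """
    hr = (hr + 1.0) / 2.0
    sr = (sr + 1.0) / 2.0
    mse = torch.mean((hr - sr) ** 2)
    if mse == 0:
        return float('inf')
    return (10.0 * torch.log10(max_val ** 2 / mse)).item()
```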
Main Training Script
Corresponding to scripts/train_esrgan.py:
```python
import os
import time
import argparse

import torch
import torch.optim as optim
from torch.optim import lr_scheduler

from src.models.esrgan import ESRGAN
from src.models.discriminator import Discriminator
from src.losses.content_loss import ContentLoss
from src.losses.gan_loss import GANLoss
from src.data.datasets import create_optimized_dataloaders
from src.training.trainer import ESRGANTrainer
from src.training.validator import validate
from src.utils.logger import init_tensorboard, log_to_tensorboard
from src.utils.helpers import save_checkpoint, load_checkpoint

def parse_args():
    parser = argparse.ArgumentParser(description='Train ESRGAN model')
    parser.add_argument('--epochs', type=int, default=2000, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate')
    parser.add_argument('--lr_decay', type=float, default=0.5, help='Learning rate decay')
    parser.add_argument('--scale_factor', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--patch_size', type=int, default=192, help='Patch size for training')
    parser.add_argument('--start_epoch', type=int, default=0, help='Start epoch')
    parser.add_argument('--checkpoint_interval', type=int, default=50, help='Checkpoint interval')
    parser.add_argument('--log_interval', type=int, default=20, help='Log interval')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for dataloader')
    parser.add_argument('--pin_memory', action='store_true', default=True, help='Pin memory for dataloader')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device')
    parser.add_argument('--dataset_path', type=str, default='./data/raw', help='Dataset path')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Checkpoint directory')
    parser.add_argument('--log_dir', type=str, default='./logs/tensorboard', help='Log directory')
    parser.add_argument('--resume', type=str, default=None, help='Resume from checkpoint')
    return parser.parse_args()

def main():
    args = parse_args()

    # Select the device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Initialize the models
    generator = ESRGAN(scale_factor=args.scale_factor).to(device)
    discriminator = Discriminator().to(device)

    # Losses and optimizers; the discriminator uses a smaller learning rate
    content_criterion = ContentLoss(device)
    gan_criterion = GANLoss(gan_type='vanilla')
    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr * 0.1, betas=(0.9, 0.999))

    # Step-wise learning rate decay
    g_scheduler = lr_scheduler.StepLR(g_optimizer, step_size=10, gamma=args.lr_decay)
    d_scheduler = lr_scheduler.StepLR(d_optimizer, step_size=10, gamma=args.lr_decay)

    # Optionally resume from a checkpoint
    if args.resume:
        args.start_epoch = load_checkpoint(args.resume, generator, discriminator,
                                           g_optimizer, d_optimizer)

    # Dataset roots
    dataset_paths = [
        os.path.join(args.dataset_path, "DIV2K"),
        os.path.join(args.dataset_path, "Flickr2K")
    ]

    # Data loaders for the large combined dataset
    train_loader, val_loader = create_optimized_dataloaders(
        batch_size=args.batch_size,
        dataset_paths=dataset_paths,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )

    # Trainer and TensorBoard writer
    trainer = ESRGANTrainer(generator, discriminator,
                            content_criterion, gan_criterion,
                            g_optimizer, d_optimizer,
                            device, args.log_interval)
    writer = init_tensorboard(args.log_dir)

    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        # Accumulate gradients over 4 batches (effective batch = 4 * batch_size)
        grad_accum_steps = 4

        gen_loss, dis_loss, train_psnr = trainer.train_epoch(train_loader, epoch, grad_accum_steps)

        # Update learning rates
        g_scheduler.step()
        d_scheduler.step()

        # Validate
        val_psnr, val_images = validate(generator, val_loader, device)

        # Console logging
        print(f'Epoch {epoch}/{args.epochs}, '
              f'Gen Loss: {gen_loss:.4f}, Dis Loss: {dis_loss:.4f}, '
              f'Train PSNR: {train_psnr:.2f} dB, Val PSNR: {val_psnr:.2f} dB, '
              f'Time: {time.time() - start_time:.2f}s')

        # TensorBoard logging
        log_to_tensorboard(writer, epoch,
                           {'gen_loss': gen_loss,
                            'dis_loss': dis_loss,
                            'psnr': train_psnr,
                            'gen_lr': g_optimizer.param_groups[0]['lr']},
                           {'psnr': val_psnr},
                           val_images)

        # Save a checkpoint periodically
        if (epoch + 1) % args.checkpoint_interval == 0:
            save_checkpoint(epoch, generator, discriminator,
                            g_optimizer, d_optimizer, args.checkpoint_dir)

    writer.close()

if __name__ == "__main__":
    main()
```
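The script imports `validate`, `save_checkpoint`, and `load_checkpoint`, whose files are not listed in this post. Minimal sketches of what they could look like, assuming `validate` returns the mean PSNR plus one (lr, sr, hr) image triple for logging, and that checkpoints bundle both models and optimizers:

```python
# src/training/validator.py (sketch)
import torch
from src.utils.metrics import psnr

def validate(generator, val_loader, device):
    """Return mean validation PSNR and one (lr, sr, hr) triple for logging."""
    generator.eval()
    total_psnr, sample = 0.0, None
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            sr_imgs = generator(lr_imgs)
            total_psnr += float(psnr(hr_imgs, sr_imgs))
            if sample is None:  # keep the first triple for visualization
                sample = (lr_imgs[0], sr_imgs[0], hr_imgs[0])
    generator.train()
    return total_psnr / len(val_loader), sample
```

```python
# src/utils/helpers.py (sketch)
import os
import torch

def save_checkpoint(epoch, generator, discriminator, g_optimizer, d_optimizer, ckpt_dir):
    """Bundle model and optimizer states into a single file."""
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'g_optimizer': g_optimizer.state_dict(),
        'd_optimizer': d_optimizer.state_dict(),
    }, os.path.join(ckpt_dir, f'esrgan_epoch_{epoch}.pth'))

def load_checkpoint(path, generator, discriminator, g_optimizer, d_optimizer):
    """Restore states and return the epoch to resume from."""
    ckpt = torch.load(path, map_location='cpu')
    generator.load_state_dict(ckpt['generator'])
    discriminator.load_state_dict(ckpt['discriminator'])
    g_optimizer.load_state_dict(ckpt['g_optimizer'])
    d_optimizer.load_state_dict(ckpt['d_optimizer'])
    return ckpt['epoch'] + 1
```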
Extended Launch Script
Corresponding to run_esrgan.sh in the project root:
```bash
#!/bin/bash
# run_esrgan.sh

# Change to the working directory
cd /home/vscode/workspace

# Record the start time
start_time=$(date +%s)
echo "Experiment started at: $(date)"

# Check GPU status
nvidia-smi

# Create output directories
mkdir -p logs checkpoints results

# Compute the log file name once, so the launch command and the
# monitoring loop below write to the same file
log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"

# Launch training in the background
nohup python3 -u scripts/train_esrgan.py \
    --epochs 2000 \
    --batch_size 16 \
    --lr 0.0002 \
    --scale_factor 4 \
    --patch_size 192 \
    --checkpoint_interval 50 \
    --log_interval 20 \
    --dataset_path ./data/raw \
    > "$log_file" 2>&1 &

# Record the process ID and log file
echo "Training started in the background, PID: $!"
echo "Log file: $log_file"

# Append a GPU snapshot to the log every 5 minutes
# (remember to kill this loop once training finishes)
while true; do
    echo "GPU snapshot: $(date)" >> "$log_file"
    nvidia-smi >> "$log_file" 2>&1
    sleep 300
done &
```
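Typical usage, assuming the script lives in the project root:

```bash
chmod +x run_esrgan.sh
./run_esrgan.sh                         # launch training in the background
tail -f training_log_esrgan_*.txt       # follow the training log
tensorboard --logdir logs/tensorboard   # inspect curves in the browser
```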
Experiment Monitoring and Analysis
Corresponding to src/utils/logger.py:
```python
import os
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

def init_tensorboard(log_dir):
    """Initialize TensorBoard."""
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    return writer

def log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):
    """Write training metrics and images to TensorBoard."""
    # Scalar metrics
    writer.add_scalar('Loss/Generator', train_metrics['gen_loss'], epoch)
    writer.add_scalar('Loss/Discriminator', train_metrics['dis_loss'], epoch)
    writer.add_scalar('PSNR/Train', train_metrics['psnr'], epoch)
    writer.add_scalar('PSNR/Validation', val_metrics['psnr'], epoch)
    writer.add_scalar('LearningRate/Generator', train_metrics['gen_lr'], epoch)
    # Images, every 10 epochs
    if epoch % 10 == 0 and images is not None:
        lr_img, sr_img, hr_img = images
        # Map from [-1, 1] back to [0, 1] for display
        writer.add_image('Input/LowResolution', (lr_img * 0.5 + 0.5).clamp(0, 1), epoch)
        writer.add_image('Output/SuperResolution', (sr_img * 0.5 + 0.5).clamp(0, 1), epoch)
        writer.add_image('Target/HighResolution', (hr_img * 0.5 + 0.5).clamp(0, 1), epoch)
        # Also save a side-by-side comparison to disk
        save_comparison_plot(lr_img, sr_img, hr_img, epoch)

def save_comparison_plot(lr, sr, hr, epoch, save_dir='results/comparisons'):
    """Save a side-by-side image comparison."""
    os.makedirs(save_dir, exist_ok=True)
    # Convert CHW tensors to HWC numpy arrays
    lr = lr.permute(1, 2, 0).cpu().detach().numpy()
    sr = sr.permute(1, 2, 0).cpu().detach().numpy()
    hr = hr.permute(1, 2, 0).cpu().detach().numpy()
    # Undo the [-1, 1] normalization and clip to valid pixel values
    lr = ((lr * 0.5 + 0.5) * 255).clip(0, 255)
    sr = ((sr * 0.5 + 0.5) * 255).clip(0, 255)
    hr = ((hr * 0.5 + 0.5) * 255).clip(0, 255)
    # Plot the comparison
    plt.figure(figsize=(15, 5))
    plt.subplot(131)
    plt.title('Low Resolution')
    plt.imshow(lr.astype('uint8'))
    plt.axis('off')
    plt.subplot(132)
    plt.title('Super Resolution')
    plt.imshow(sr.astype('uint8'))
    plt.axis('off')
    plt.subplot(133)
    plt.title('High Resolution')
    plt.imshow(hr.astype('uint8'))
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'{save_dir}/comparison_epoch_{epoch}.png', dpi=300, bbox_inches='tight')
    plt.close()
```
Summary and Next Steps
Through the experiments in this post, we built a structurally complete super-resolution project, including:
- A standardized project structure: modularized code that separates data processing, model definitions, loss functions, and training logic
- Large-dataset management: automatic download, extraction, and combination of multiple large datasets, with an optimized loading pipeline
- A complex model: an ESRGAN generator built from residual dense blocks, which recovers richer detail than SRCNN
- Advanced training strategies: mixed-precision training, gradient accumulation, and step-wise learning rate scheduling to raise GPU utilization
- A complete monitoring setup: log files, GPU monitoring, and TensorBoard visualization to track the whole experiment
Directions worth exploring next:
- Try larger models (e.g., RCAN, SwinIR)
- Introduce perceptual losses and improved GAN variants (e.g., Relativistic GAN)
- Implement model and data parallelism to train on multiple GPUs
- Explore model compression and acceleration for real-time super-resolution
- Tackle video super-resolution, which requires temporal consistency
In the next post, we will explore applying cutting-edge vision Transformer models to super-resolution to push reconstruction quality further.