SRGAN in Practice: Photo-Realistic Super-Resolution with Python + PyTorch (Code Included)

张开发
2026/4/16 1:32:42 · 15 min read


When you dig out photos from ten years ago, have you ever been frustrated by blurry pixels and distorted detail? Super-resolution reconstruction is quietly changing that. Among the many approaches, SRGAN stands out: its generative-adversarial architecture can recover strikingly convincing high-frequency detail from low-resolution images. This article walks through a complete SRGAN implementation from scratch, covering not only the core code but also tuning tips and pitfalls from real training runs.

1. Environment Setup and Data Preparation

Sharpen your tools before you start: we need a GPU-accelerated PyTorch environment. Python 3.8 with PyTorch 1.10 is recommended, as this combination has solid support for GAN training.

```bash
conda create -n srgan python=3.8
conda activate srgan
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python pillow matplotlib tqdm
```

The choice of dataset directly affects model quality. DIV2K is the standard super-resolution benchmark, with 800 training images and 100 validation images covering a rich variety of scenes. In practice you may also want to mix in your own domain data:

```python
from PIL import Image
from torchvision import transforms

# geometric augmentation, applied to the HR image only; the LR patch is
# derived from the augmented HR patch so the pair stays spatially aligned
train_transform = transforms.Compose([
    transforms.RandomCrop(96),          # random 96x96 HR patches
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

# shared tensor conversion: ToTensor, then scale to [-1, 1]
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# low-resolution inputs are obtained by bicubic downsampling of the HR patch
def get_lr_image(hr_img, scale=4):
    return hr_img.resize((hr_img.width // scale, hr_img.height // scale), Image.BICUBIC)
```

The data loader also needs to be memory-efficient. For large datasets, load images on demand through a Dataset class. Note that the random crop and flips are applied before the LR image is derived, so each LR/HR pair stays aligned:

```python
import torch

class SRDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        hr_img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            hr_img = self.transform(hr_img)   # crop/flip the HR image first
        lr_img = get_lr_image(hr_img)         # then derive the aligned LR patch
        return to_tensor(lr_img), to_tensor(hr_img)
```

2. Model Architecture Design

The heart of SRGAN is the adversarial pairing of a generator and a discriminator. The generator uses a deep residual structure, while the discriminator borrows its design from VGG-style networks.

2.1 Generator (SRResNet)

The generator is built from stacked residual blocks followed by sub-pixel convolution layers:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return out + residual

class Generator(nn.Module):
    def __init__(self, scale_factor=4, num_residual=16):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.prelu = nn.PReLU()
        # stacked residual blocks
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residual)])
        # upsampling: each PixelShuffle(2) stage doubles the resolution,
        # so two stages give the x4 upscale
        upsampling = []
        for _ in range(scale_factor // 2):
            upsampling += [
                nn.Conv2d(64, 256, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ]
        self.upsampling = nn.Sequential(*upsampling)
        self.conv2 = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        x = self.prelu(self.conv1(x))
        residual = x
        x = self.res_blocks(x)
        x = x + residual
        x = self.upsampling(x)
        x = self.conv2(x)
        return torch.tanh(x)
```

2.2 Discriminator

The discriminator stacks strided VGG-style convolution blocks and ends with a global pooling head, producing a single real/fake score per image:

```python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # repeated conv blocks with increasing channel counts
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        return self.net(x)
```

3. Loss Functions and Training Strategy

Much of SRGAN's success comes from its carefully designed perceptual loss, which combines a content loss with an adversarial loss to balance pixel-level accuracy against perceptual quality.

3.1 Perceptual Loss

A frozen VGG feature extractor computes the content loss:

```python
import torch
import torch.nn as nn
import torchvision

class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg19(pretrained=True)
        # keep layers up to conv5_4
        self.features = nn.Sequential(*list(vgg.features.children())[:35])
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        # map network outputs from [-1, 1] back to [0, 1], then apply the
        # ImageNet normalization that VGG was trained with
        x = (x + 1) / 2
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
        x = (x - mean) / std
        return self.features(x)

def perceptual_loss(hr, sr, feature_extractor):
    mse_loss = nn.MSELoss()
    hr_features = feature_extractor(hr)
    sr_features = feature_extractor(sr)
    return mse_loss(hr_features, sr_features)
```

3.2 Adversarial Loss and Optimizer Setup

GAN training has to keep the learning progress of the generator and the discriminator in balance:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model initialization
generator = Generator().to(device)
discriminator = Discriminator().to(device)
feature_extractor = VGGFeatureExtractor().to(device).eval()

# optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.9, 0.999))

# loss functions
adversarial_criterion = nn.BCEWithLogitsLoss()
pixel_criterion = nn.L1Loss()

def train_step(lr, hr):
    # --- generator update ---
    sr = generator(lr)
    real_label = torch.ones(hr.size(0), 1, 1, 1, device=device)

    # content loss: pixel term plus weighted perceptual term
    content_loss = pixel_criterion(sr, hr) + 0.006 * perceptual_loss(hr, sr, feature_extractor)
    # adversarial loss
    g_loss = adversarial_criterion(discriminator(sr), real_label)
    total_loss = content_loss + 1e-3 * g_loss

    g_optimizer.zero_grad()
    total_loss.backward()
    g_optimizer.step()

    # --- discriminator update ---
    d_loss_real = adversarial_criterion(discriminator(hr), real_label)
    d_loss_fake = adversarial_criterion(discriminator(sr.detach()),
                                        torch.zeros_like(real_label))
    d_loss = (d_loss_real + d_loss_fake) / 2

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    return total_loss.item(), d_loss.item()
```

4. Training Tricks and Optimization

GAN training is notoriously unstable; the following techniques noticeably improve SRGAN's training stability.

4.1 Two-Stage Training

- Pre-training: train the generator alone with MSE loss for 20-30 epochs.
- Joint training: add the discriminator and switch to the full adversarial objective.

```python
# generator pre-training (MSE only)
def pretrain_generator(generator, dataloader, epochs=20):
    optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        for lr, hr in dataloader:
            lr, hr = lr.to(device), hr.to(device)
            sr = generator(lr)
            loss = criterion(sr, hr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```

4.2 Learning-Rate Scheduling and Gradient Clipping

```python
# step-decay schedulers (call scheduler.step() once per step_size interval)
g_scheduler = torch.optim.lr_scheduler.StepLR(g_optimizer, step_size=1000, gamma=0.1)
d_scheduler = torch.optim.lr_scheduler.StepLR(d_optimizer, step_size=1000, gamma=0.1)

# gradient clipping: call these inside the training loop,
# between backward() and optimizer.step()
torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
```

4.3 Monitoring and Visualization

Watching training in real time helps you adjust strategy before things go wrong:

```python
import os
import matplotlib.pyplot as plt

def save_sample(lr, sr, hr, epoch, path="samples"):
    os.makedirs(path, exist_ok=True)
    lr = lr[0].cpu().detach().numpy().transpose(1, 2, 0)
    sr = sr[0].cpu().detach().numpy().transpose(1, 2, 0)
    hr = hr[0].cpu().detach().numpy().transpose(1, 2, 0)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow((lr + 1) / 2)   # undo the [-1, 1] normalization
    axes[0].set_title("Low Resolution")
    axes[1].imshow((sr + 1) / 2)
    axes[1].set_title("Super Resolution")
    axes[2].imshow((hr + 1) / 2)
    axes[2].set_title("High Resolution")
    plt.savefig(f"{path}/epoch_{epoch}.png")
    plt.close(fig)
```

5. Model Evaluation and Deployment

Once training is done, evaluate the model thoroughly.

5.1 Quantitative Metrics

```python
def calculate_psnr(sr, hr, max_val=1.0):
    mse = torch.mean((sr - hr) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

def calculate_ssim(sr, hr, window_size=11):
    # SSIM computation
    pass
```

5.2 Applying the Model to Real Images

Applying the trained model to a real photo requires the same normalization used during training:

```python
def enhance_image(image_path, generator, device):
    lr_img = Image.open(image_path).convert("RGB")
    lr_tensor = transforms.ToTensor()(lr_img).unsqueeze(0).to(device)
    lr_tensor = lr_tensor * 2 - 1                     # match the [-1, 1] training range
    with torch.no_grad():
        sr_tensor = generator(lr_tensor)
    sr_tensor = (sr_tensor.squeeze().cpu() + 1) / 2   # map tanh output back to [0, 1]
    return transforms.ToPILImage()(sr_tensor.clamp(0, 1))
```

5.3 Export and Deployment

```python
# TorchScript export
traced_generator = torch.jit.trace(generator, torch.rand(1, 3, 96, 96).to(device))
traced_generator.save("srgan_generator.pt")

# ONNX export
torch.onnx.export(generator, torch.rand(1, 3, 96, 96).to(device), "srgan.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
```

In real projects I have found that more residual blocks in the generator is not always better: beyond about 20 blocks, training tends to become unstable. With Adam, keeping beta2 at 0.999 (rather than lowering it to 0.99, as some GAN recipes suggest) also gave a noticeably more stable run. And for super-resolving 4K images, process the image in tiles and stitch the results back together; this sharply reduces GPU memory consumption.
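The `calculate_ssim` stub in section 5.1 was left unimplemented. Below is a minimal sketch of SSIM for (N, C, H, W) tensors. It uses a uniform box window instead of the Gaussian window of the original SSIM paper, and the constants assume inputs in [0, max_val]; treat it as a rough metric for monitoring, not a reference implementation.

```python
import torch
import torch.nn.functional as F

def calculate_ssim(sr, hr, window_size=11, max_val=1.0):
    # Simplified SSIM over (N, C, H, W) tensors in [0, max_val],
    # using a uniform box window (a simplification of the standard Gaussian window).
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2
    channels = sr.size(1)
    # box filter applied per channel via a grouped convolution
    window = torch.ones(channels, 1, window_size, window_size,
                        device=sr.device) / window_size ** 2
    pad = window_size // 2

    # local means, variances, and covariance
    mu_x = F.conv2d(sr, window, padding=pad, groups=channels)
    mu_y = F.conv2d(hr, window, padding=pad, groups=channels)
    var_x = F.conv2d(sr * sr, window, padding=pad, groups=channels) - mu_x ** 2
    var_y = F.conv2d(hr * hr, window, padding=pad, groups=channels) - mu_y ** 2
    cov_xy = F.conv2d(sr * hr, window, padding=pad, groups=channels) - mu_x * mu_y

    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2))
    return ssim_map.mean()
```

Identical images score 1.0; independent random images score well below that, which is enough for tracking progress across epochs.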
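The closing advice about tiling large images can be sketched as follows. This minimal version splits the LR input into non-overlapping tiles, super-resolves each, and stitches the upscaled outputs back together; the tile size of 96 and the x4 scale are assumptions matching the generator in this article, and a production version would use overlapping tiles with blending to hide seam artifacts.

```python
import torch

def enhance_tiled(model, lr_tensor, tile=96, scale=4):
    # Super-resolve a (1, C, H, W) tensor tile by tile to bound GPU memory.
    # Assumes H and W are multiples of `tile`; pad the input otherwise.
    _, c, h, w = lr_tensor.shape
    out = torch.zeros(1, c, h * scale, w * scale, device=lr_tensor.device)
    with torch.no_grad():
        for y in range(0, h, tile):
            for x in range(0, w, tile):
                patch = lr_tensor[:, :, y:y + tile, x:x + tile]
                sr_patch = model(patch)  # each tile fits in memory on its own
                out[:, :, y * scale:(y + tile) * scale,
                          x * scale:(x + tile) * scale] = sr_patch
    return out
```

Because each tile is processed independently, peak memory is set by the tile size rather than the full 4K frame, at the cost of running the generator once per tile.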
