SRCNN超分辨率实战在Colab上用PyTorch训练自己的图像修复模型当你在社交媒体上看到一张模糊的老照片或是从监控视频中截取的关键帧分辨率太低时是否想过用AI技术让它们重获新生超分辨率技术正是为解决这类问题而生。本文将带你从零开始在Google Colab的免费GPU环境下用PyTorch实现经典的SRCNN模型并处理你自己的图片数据集。1. 环境准备与数据管理在Colab中运行深度学习项目首先要解决的是数据存储问题。与本地开发不同Colab的临时存储空间会在会话结束后清空因此我们需要合理利用Google Drive进行持久化存储。from google.colab import drive drive.mount(/content/drive)挂载成功后建议在Drive中创建如下目录结构SRCNN_Project/ ├── data/ │ ├── raw/ # 存放原始图像 │ ├── processed/ # 存放处理后的h5文件 ├── outputs/ # 训练输出 ├── logs/ # TensorBoard日志对于数据集选择除了论文中提到的91-image和Set5/Set14我们还可以使用以下更适合初学者的替代方案DIV2K包含800张训练图像和100张验证图像BSD500伯克利分割数据集含500张自然图像Flickr2K2650张高分辨率图像提示使用小规模数据集时建议将图像裁剪为256x256或128x128的patch这样可以增加样本数量并减少显存消耗。2. 高效数据预处理技巧原始论文要求将图像转换为h5格式这对Colab环境尤为重要——频繁读取小文件会显著降低IO性能。我们改进的prepare.py脚本增加了以下功能def create_h5_file(image_paths, output_path, patch_size33, stride14, scale3): h5_file h5py.File(output_path, w) lr_patches [] hr_patches [] for image_path in tqdm(image_paths): hr cv2.imread(image_path, cv2.IMREAD_COLOR) hr cv2.cvtColor(hr, cv2.COLOR_BGR2RGB) lr cv2.resize(hr, (hr.shape[1]//scale, hr.shape[0]//scale), interpolationcv2.INTER_CUBIC) # 生成patch for i in range(0, hr.shape[0]-patch_size1, stride): for j in range(0, hr.shape[1]-patch_size1, stride): hr_patch hr[i:ipatch_size, j:jpatch_size] lr_patch lr[i//scale:(ipatch_size)//scale, j//scale:(jpatch_size)//scale] lr_patch cv2.resize(lr_patch, (patch_size, patch_size), interpolationcv2.INTER_CUBIC) lr_patches.append(lr_patch.transpose(2,0,1)) hr_patches.append(hr_patch.transpose(2,0,1)) # 转换为numpy数组并保存 h5_file.create_dataset(lr, datanp.array(lr_patches, dtypenp.float32)/255.) h5_file.create_dataset(hr, datanp.array(hr_patches, dtypenp.float32)/255.) h5_file.close()关键参数说明参数推荐值作用patch_size33训练patch的大小stride14滑动窗口步长scale3超分辨率放大倍数3. 模型训练与优化SRCNN的PyTorch实现虽然简单但在Colab环境中训练时仍有多个优化点需要注意class SRCNN(nn.Module): def __init__(self): super(SRCNN, self).__init__() self.conv1 nn.Conv2d(3, 64, kernel_size9, padding4) self.conv2 nn.Conv2d(64, 32, kernel_size1, padding0) self.conv3 nn.Conv2d(32, 3, kernel_size5, padding2) self.relu nn.ReLU(inplaceTrue) def forward(self, x): x self.relu(self.conv1(x)) x self.relu(self.conv2(x)) x self.conv3(x) return x训练时的实用技巧学习率策略scheduler torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, modemax, factor0.5, patience5, verboseTrue)混合精度训练减少显存占用scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): preds model(inputs) loss criterion(preds, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()早停机制if epoch - best_epoch 20: # 连续20轮未提升 print(Early stopping triggered) break4. 自定义图像处理实战训练完成后我们需要一个灵活的测试脚本处理各种来源的图像def process_custom_image(model, image_path, scale3, devicecuda): # 支持多种图像格式 img Image.open(image_path).convert(RGB) original_size img.size # 调整尺寸为scale的整数倍 new_width (img.width // scale) * scale new_height (img.height // scale) * scale if new_width ! img.width or new_height ! img.height: img img.resize((new_width, new_height), Image.BICUBIC) # 生成低分辨率版本 lr img.resize((new_width//scale, new_height//scale), Image.BICUBIC) lr lr.resize((new_width, new_height), Image.BICUBIC) # 上采样 # 转换到YCbCr色彩空间 ycbcr lr.convert(YCbCr) y, cb, cr ycbcr.split() # 处理Y通道 y_tensor torch.from_numpy(np.array(y, dtypenp.float32)/255.) y_tensor y_tensor.unsqueeze(0).unsqueeze(0).to(device) with torch.no_grad(): pred_y model(y_tensor).clamp(0, 1) # 合并通道 pred_y pred_y[0,0].cpu().numpy() * 255. pred_y Image.fromarray(pred_y.astype(np.uint8), modeL) cb cb.resize(pred_y.size, Image.BICUBIC) cr cr.resize(pred_y.size, Image.BICUBIC) result Image.merge(YCbCr, [pred_y, cb, cr]).convert(RGB) return result.resize(original_size, Image.BICUBIC)常见问题解决方案边缘伪影在测试时对图像进行镜像padding色彩失真确保在YCbCr空间只增强Y通道大图像内存不足使用滑动窗口分块处理5. 模型部署与性能提升虽然SRCNN结构简单但我们仍可以通过以下方式提升其实用性模型量化减小模型体积quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Conv2d}, dtypetorch.qint8) torch.jit.save(torch.jit.script(quantized_model), srcnn_quantized.pt)ONNX导出跨平台部署dummy_input torch.randn(1, 3, 256, 256).to(device) torch.onnx.export(model, dummy_input, srcnn.onnx, input_names[input], output_names[output], dynamic_axes{input: {0: batch}, output: {0: batch}})对于希望进一步提升效果的用户可以考虑这些改进方向使用ESRGAN的感知损失替代MSE添加通道注意力机制采用渐进式超分辨率策略