从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率提到99.5%以上的
从简单CNN到ResNet18我是如何一步步把MNIST手写数字识别准确率提到99.5%以上的当第一次接触MNIST数据集时我天真地以为用几层卷积神经网络就能轻松达到99%以上的准确率。现实很快给了我一记耳光——我的第一个简单CNN模型在测试集上只能达到97%左右的准确率。这促使我开启了一段持续优化的旅程最终将准确率提升到99.5%以上。在这个过程中我深刻体会到模型优化不是简单的堆叠层数而是需要系统性地思考数据、架构和训练策略的协同作用。1. 基础CNN模型搭建与初步优化我的起点是一个典型的LeNet风格架构包含两个卷积层和两个全连接层。这个基础版本在10个epoch后达到了97.11%的测试准确率但存在几个明显问题class BasicCNN(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(1, 10, kernel_size5) self.conv2 nn.Conv2d(10, 20, kernel_size5) self.fc1 nn.Linear(320, 50) self.fc2 nn.Linear(50, 10) def forward(self, x): x F.relu(F.max_pool2d(self.conv1(x), 2)) x F.relu(F.max_pool2d(self.conv2(x), 2)) x x.view(-1, 320) x F.relu(self.fc1(x)) return self.fc2(x)第一轮优化主要关注代码结构和训练效率使用nn.Sequential重构网络模块提升可读性和复用性添加批归一化层(BatchNorm)加速收敛采用nn.Flatten()替代手动展平操作设置ReLU的inplace参数为True减少内存占用优化后的模型结构如下class ImprovedCNN(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( nn.Conv2d(1, 10, 5), nn.MaxPool2d(2), nn.ReLU(True), nn.BatchNorm2d(10), nn.Conv2d(10, 20, 5), nn.MaxPool2d(2), nn.ReLU(True), nn.BatchNorm2d(20), nn.Flatten() ) self.classifier nn.Linear(320, 10)这些改动看似简单却带来了显著提升优化项准确率提升训练时间变化BatchNorm0.8%-15%结构化代码-代码可维护性↑inplace ReLU无内存占用↓20%2. 训练策略的精细调整当模型架构达到一个平台期后我开始关注训练过程的优化。这一阶段的关键发现是好的模型需要匹配好的训练策略。2.1 学习率动态调整固定学习率就像用恒定的速度爬山——开始可能合适但随着地形变化就会变得低效。我实现了学习率动态调整scheduler torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, modemax, factor0.5, patience3, threshold0.0001 )配合验证集准确率监控当指标停滞时自动降低学习率。这种策略在第85个epoch帮助模型突破了99.5%的关键瓶颈。2.2 数据增强的艺术MNIST虽然是干净的数据集但适度的数据增强能显著提升模型鲁棒性。我采用了以下增强组合transform transforms.Compose([ transforms.RandomAffine(degrees0, translate(0.1, 0.1)), transforms.RandomRotation((-10, 10)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])增强策略对比实验增强方式测试准确率过拟合程度无增强98.9%中等仅平移99.2%低平移旋转99.5%很低过度增强98.1%极低(欠拟合)2.3 正则化技术组合Dropout与权重衰减的协同使用产生了意想不到的效果self.classifier nn.Sequential( nn.Linear(64*3*3, 256), nn.ReLU(), nn.Dropout(0.5), # 关键位置的高dropout率 nn.Linear(256, 10) )配合权重初始化策略def weights_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) model.apply(weights_init)3. 深度架构探索从CNN到ResNet当传统CNN的优化空间逐渐缩小我开始尝试更先进的架构。ResNet的残差连接设计特别适合解决深度网络中的梯度消失问题。3.1 残差块实现要点class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, stride1): super().__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse) self.bn2 nn.BatchNorm2d(out_channels) self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels: self.shortcut nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size1, stridestride, biasFalse), nn.BatchNorm2d(out_channels) ) def forward(self, x): out F.relu(self.bn1(self.conv1(x))) out self.bn2(self.conv2(out)) out self.shortcut(x) return F.relu(out)3.2 自定义ResNet18架构针对MNIST的28x28小尺寸特点我对标准ResNet18做了适配调整class ResNetMNIST(nn.Module): def __init__(self, block, layers, num_classes10): super().__init__() self.in_channels 16 self.conv1 nn.Conv2d(1, 16, kernel_size3, stride1, padding1, biasFalse) self.bn1 nn.BatchNorm2d(16) self.layer1 self._make_layer(block, 16, layers[0], stride1) self.layer2 self._make_layer(block, 32, layers[1], stride2) self.layer3 self._make_layer(block, 64, layers[2], stride2) self.avgpool nn.AdaptiveAvgPool2d((1,1)) self.fc nn.Linear(64, num_classes)3.3 预训练模型适配直接使用torchvision的ResNet需要处理通道数不匹配问题model torchvision.models.resnet18(pretrainedFalse) model.conv1 nn.Conv2d(1, 64, kernel_size7, stride2, padding3, biasFalse)架构对比实验结果模型类型参数量测试准确率训练时间(每epoch)基础CNN50K97.1%12s优化CNN55K99.1%15s自定义ResNet181.1M99.3%45storchvision ResNet1811M98.4%60s4. 工程实践与性能优化在实际部署中我发现几个影响模型效用的关键因素4.1 GPU加速技巧# 数据加载优化 train_loader DataLoader( dataset, batch_size512, shuffleTrue, num_workers4, pin_memoryTrue # 减少CPU-GPU传输延迟 ) # 混合精度训练 scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4.2 训练监控与分析使用TensorBoard记录关键指标writer SummaryWriter() writer.add_scalar(Loss/train, loss.item(), global_step) writer.add_scalar(Accuracy/test, accuracy, global_step) writer.add_histogram(conv1/weights, model.conv1.weight, global_step)4.3 模型压缩与部署达到目标准确率后我尝试了模型量化quantized_model torch.quantization.quantize_dynamic( model, {nn.Linear}, dtypetorch.qint8 )量化前后对比指标原始模型量化模型模型大小4.7MB1.2MB推理延迟8.2ms3.1ms准确率99.5%99.4%这段优化之旅让我明白在深度学习中没有银弹式的解决方案。每个百分点的提升都需要数据、模型和训练策略的精心配合。当我在第85个epoch看到99.51%的测试准确率时所有的调试和等待都变得值得。