从NLP跨界CV手把手教你用PyTorch复现Vision Transformer (ViT) 图像分类当Transformer在自然语言处理领域大放异彩时计算机视觉研究者们开始思考这种基于自注意力机制的架构能否同样颠覆图像识别领域2020年Vision Transformer (ViT) 的出现给出了肯定答案。本文将带你从零开始用PyTorch实现这一开创性模型体验如何将图像转化为视觉词汇的奇妙过程。1. ViT核心原理与设计思路传统卷积神经网络(CNN)通过局部感受野逐层提取特征而ViT则采用全局视角处理图像——它将输入图片分割为16x16的视觉词汇块(patches)每个块经过线性投影后成为Transformer可处理的序列元素。这种设计带来了三大关键创新图像序列化将2D图像转换为1D令牌序列位置编码通过可学习的位置嵌入保留空间信息纯Transformer架构完全摒弃卷积操作注意ViT在中小型数据集上可能不如CNN表现优异但当训练数据超过1亿张图片时其性能开始显著超越传统方法。下表对比了ViT与典型CNN的核心差异特性ViTCNN特征提取方式全局自注意力局部卷积核空间信息处理显式位置编码隐式感受野累积数据依赖性需要大量训练数据中等规模数据即可计算复杂度O(n²)O(n)2. 环境准备与数据预处理2.1 安装必要依赖确保你的Python环境包含以下核心库pip install torch torchvision pytorch-lightning einops2.2 CIFAR-10数据集处理我们将使用CIFAR-10作为演示数据集。虽然原始ViT论文使用更大规模的ImageNet但CIFAR-10更适合快速验证from torchvision import datasets, transforms # 定义数据增强策略 train_transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_data datasets.CIFAR10(data, trainTrue, downloadTrue, transformtrain_transform) test_data datasets.CIFAR10(data, trainFalse, transformtrain_transform)3. ViT模型实现详解3.1 图像分块与线性嵌入ViT的第一步是将图像分割为固定大小的块并线性投影到特征空间import torch import torch.nn as nn from einops import rearrange class PatchEmbedding(nn.Module): def __init__(self, img_size32, patch_size4, in_channels3, embed_dim64): super().__init__() self.proj nn.Conv2d(in_channels, embed_dim, kernel_sizepatch_size, stridepatch_size) def forward(self, x): x self.proj(x) # [B, C, H, W] - [B, D, H/P, W/P] x rearrange(x, b d h w - b (h w) d) return x3.2 位置编码与分类令牌Transformer需要位置信息来理解图像的空间结构class ViTEncoder(nn.Module): def __init__(self, num_patches, embed_dim, num_heads, num_layers): super().__init__() self.cls_token nn.Parameter(torch.randn(1, 1, embed_dim)) self.pos_embed nn.Parameter(torch.randn(1, num_patches 1, embed_dim)) self.transformer nn.TransformerEncoder( nn.TransformerEncoderLayer(embed_dim, num_heads), num_layers ) def forward(self, x): cls_tokens self.cls_token.expand(x.shape[0], -1, -1) x torch.cat((cls_tokens, x), dim1) x self.pos_embed return self.transformer(x)4. 完整模型组装与训练4.1 构建端到端ViT模型整合所有组件形成完整架构class VisionTransformer(nn.Module): def __init__(self, img_size32, patch_size4, in_channels3, embed_dim64, num_heads4, num_layers4, num_classes10): super().__init__() self.patch_embed PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches (img_size // patch_size) ** 2 self.encoder ViTEncoder(num_patches, embed_dim, num_heads, num_layers) self.head nn.Linear(embed_dim, num_classes) def forward(self, x): x self.patch_embed(x) x self.encoder(x) return self.head(x[:, 0]) # 使用分类令牌输出4.2 训练策略与超参数设置使用PyTorch Lightning简化训练流程import pytorch_lightning as pl from torch.utils.data import DataLoader class ViTLightning(pl.LightningModule): def __init__(self, lr1e-3): super().__init__() self.model VisionTransformer() self.lr lr self.criterion nn.CrossEntropyLoss() def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): x, y batch preds self(x) loss self.criterion(preds, y) self.log(train_loss, loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lrself.lr) # 初始化训练器 trainer pl.Trainer(max_epochs50, gpus1 if torch.cuda.is_available() else 0) model ViTLightning() # 数据加载器 train_loader DataLoader(train_data, batch_size64, shuffleTrue) test_loader DataLoader(test_data, batch_size64) # 开始训练 trainer.fit(model, train_loader)5. 模型优化与调参技巧5.1 学习率调度策略ViT训练对学习率非常敏感推荐使用warmup策略def configure_optimizers(self): optimizer torch.optim.AdamW(self.parameters(), lrself.lr) scheduler torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lrself.lr, total_stepsself.trainer.estimated_stepping_batches ) return [optimizer], [scheduler]5.2 混合精度训练加速利用NVIDIA GPU的Tensor Core加速训练trainer pl.Trainer( max_epochs50, precision16, acceleratorgpu if torch.cuda.is_available() else cpu )5.3 关键超参数经验值基于CIFAR-10的实验验证以下配置表现良好参数推荐值说明patch_size4平衡计算量与局部信息保留embed_dim64-128特征维度num_heads4-8注意力头数num_layers6-12Transformer层数batch_size64-128根据GPU内存调整6. 模型评估与结果分析6.1 测试集性能评估def test_step(self, batch, batch_idx): x, y batch preds self(x) loss self.criterion(preds, y) acc (preds.argmax(1) y).float().mean() self.log(test_loss, loss) self.log(test_acc, acc) return {loss: loss, acc: acc}6.2 可视化注意力机制理解模型如何关注图像不同区域import matplotlib.pyplot as plt def visualize_attention(model, img): model.eval() with torch.no_grad(): patches model.patch_embed(img.unsqueeze(0)) attns model.encoder.transformer.layers[0].self_attn( patches, patches, patches )[1] plt.imshow(attns[0, 0, 1:].reshape(8, 8).cpu()) plt.colorbar() plt.show()在CIFAR-10上训练约50个epoch后预期可以达到75-80%的测试准确率。虽然这低于原始论文在更大数据集上的结果但足以验证ViT的基本原理。