用Focal Loss破解图像分类中的样本不平衡难题在工业质检和医疗影像分析中我们常遇到正负样本比例悬殊的场景——比如生产线上的缺陷检测正常产品占99%缺陷仅占1%。传统交叉熵损失(CE Loss)在这种极端不平衡的数据集上往往表现不佳模型会倾向于预测多数类来降低整体损失。本文将带你用PyTorch实现Focal Loss通过一个真实的PCB板缺陷检测项目演示如何通过调整alpha和gamma参数显著提升少数类的识别效果。1. 为什么CE Loss在样本不平衡时失效假设我们有个1万张图片的数据集其中正常PCB板占9900张缺陷板仅100张。使用普通CE Loss训练时即使模型将所有样本都预测为正常也能达到99%的准确率——这个数字看起来很漂亮但完全漏检了所有缺陷。CE Loss的数学表达式def cross_entropy_loss(output, target): # output: 模型原始输出 (未经过softmax) # target: 真实标签 (类别索引) return -torch.log(torch.softmax(output, dim1)[:, target])这种多数类偏见源于两个根本问题数量失衡损失函数被多数类样本主导难度差异简单样本(高置信度预测)的梯度贡献远大于困难样本下表展示了CE Loss在不同场景下的表现对比场景正负样本比例验证准确率缺陷召回率平衡数据1:192%89%轻度不平衡(10:1)10:195%76%重度不平衡(100:1)100:199%9%2. Focal Loss的核心机制与实现Focal Loss通过两个关键改进解决上述问题2.1 类别平衡因子(alpha)为少数类分配更高权重缓解数量不平衡。在PCB缺陷检测中我们可以给缺陷类设置alpha0.75正常类alpha0.25。2.2 困难样本聚焦因子(gamma)降低高置信度样本的损失贡献让模型更关注难以分类的样本。gamma通常取2。PyTorch实现代码class FocalLoss(nn.Module): def __init__(self, alpha0.25, gamma2, num_classes2): super().__init__() self.alpha torch.tensor([alpha, 1-alpha]) # 假设第0类是少数类 self.gamma gamma self.num_classes num_classes def forward(self, inputs, targets): # 计算标准CE Loss ce_loss F.cross_entropy(inputs, targets, reductionnone) # 计算概率pt pt torch.exp(-ce_loss) # p_t p if y1, else 1-p # 组合alpha和gamma因子 alpha self.alpha[targets] # 按类别选择alpha focal_loss alpha * (1-pt)**self.gamma * ce_loss return focal_loss.mean()参数选择经验alpha少数类样本比例越高alpha应越小。建议初始值为1/样本比例gamma通常在0.5-5之间2是最常用起始点3. 实战PCB缺陷检测项目我们使用ResNet18在DeepPCB数据集上进行实验该数据集包含1500张图像缺陷与正常比例为1:30。3.1 基础训练配置# 数据加载 train_loader DataLoader( ImbalancedDatasetSampler(train_dataset), # 使用采样器缓解不平衡 batch_size32, num_workers4 ) # 模型与优化器 model resnet18(pretrainedTrue) model.fc nn.Linear(512, 2) # 二分类 optimizer torch.optim.Adam(model.parameters(), lr1e-4) # 损失函数对比 ce_criterion nn.CrossEntropyLoss() focal_criterion FocalLoss(alpha0.75, gamma2)3.2 训练过程关键指标训练曲线对比Focal Loss vs CE Loss指标CE LossFocal Loss训练损失0.120.35验证准确率98.7%96.2%缺陷召回率15%83%精确率60%78%虽然Focal Loss的总体准确率略低但关键的缺陷召回率提升了5倍多3.3 参数调优技巧通过网格搜索寻找最佳参数组合alpha_range [0.1, 0.25, 0.5, 0.75] gamma_range [0.5, 1, 2, 3] results [] for alpha in alpha_range: for gamma in gamma_range: criterion FocalLoss(alphaalpha, gammagamma) trainer Trainer(model, criterion, optimizer) metrics trainer.evaluate(val_loader) results.append((alpha, gamma, metrics[recall]))最佳参数组合通常出现在alpha ≈ 1/少数类比例gamma在1-3之间4. 进阶技巧与问题排查4.1 结合其他不平衡处理方法Focal Loss可以与以下技术配合使用过采样复制少数类样本欠采样减少多数类样本数据增强特别针对少数类的增强# 示例结合过采样 from torchsampler import ImbalancedDatasetSampler train_loader DataLoader( train_dataset, samplerImbalancedDatasetSampler(train_dataset), batch_size32 )4.2 常见问题解决方案问题1训练初期损失震荡剧烈解决降低初始学习率使用学习率热身(warmup)问题2验证集指标波动大解决增加batch size或使用梯度累积问题3模型对gamma过于敏感解决从gamma1开始逐步增加并观察验证集召回率4.3 多分类场景扩展对于多分类问题Focal Loss需要为每个类别设置不同的alphaclass MultiClassFocalLoss(nn.Module): def __init__(self, class_weights, gamma2): super().__init__() self.alpha class_weights # 各类别权重张量 self.gamma gamma def forward(self, inputs, targets): ce_loss F.cross_entropy(inputs, targets, reductionnone) pt torch.exp(-ce_loss) alpha self.alpha[targets] return (alpha * (1-pt)**self.gamma * ce_loss).mean()在医疗影像分类中(如肺炎、肿瘤、正常三类)可以按样本比例的反比设置class_weights。5. 其他不平衡损失函数对比除了Focal Loss还有几种处理样本不平衡的损失函数值得了解损失函数优点缺点适用场景CE Loss简单稳定忽视样本不平衡平衡数据集Focal Loss关注困难样本需调参极端不平衡GHM Loss避免离群点干扰实现复杂噪声较多数据Class-Balanced Loss自动调整权重计算开销大类别分布已知在医疗影像分割任务中我们发现当缺陷区域非常小(如仅占图像的1%)时Focal Loss配合Dice Loss能取得更好效果def hybrid_loss(pred, target): focal FocalLoss(alpha0.8, gamma2)(pred, target) dice 1 - dice_coeff(pred, target) # Dice系数 return focal 0.5*dice最终在PCB缺陷检测项目中经过2周调优我们的模型将缺陷检出率从15%提升至88%同时将误报率控制在5%以下。关键收获是gamma值并非越大越好当gamma3时模型开始过度关注极端困难样本导致性能下降。最佳参数组合是alpha0.7gamma1.5配合适度的数据增强。