Distillation
Fine-tuning is a form of transfer learning, although in engineering usage "transfer learning" is roughly shorthand for "take a pretrained model and fine-tune it". Fine-tuning means that all, or part, of the model's parameters participate in the backpropagation updates. Freezing the backbone and training only the classification head is transfer learning, but it is not fine-tuning. Transfer learning is the broader concept: a model trained for one task on one kind of dataset is reused on a different task, which is why a small learning rate is used to adapt it to the new task. The first sketch below shows both regimes.

Pruning is a "hard" adjustment of the network structure: before pruning, traverse the network and measure the importance of each convolutional layer, then cut away the unimportant filters according to that measure (the second sketch below shows one common importance score). Distillation is a "soft" transfer of knowledge; distillation is usually applied to the model after it has been pruned.

For the distillation loss, KLDivLoss is the most recommended, MSELoss is barely usable, and CrossEntropyLoss is not usable: F.cross_entropy in the PyTorch generation this script targets accepts only hard class indices, not a soft teacher distribution. The third sketch below compares the two usable losses. Distillation itself means that, under the teacher's guidance, the student is trained from scratch on a portion of the data.

T is the distillation temperature: the larger T is, the smaller the differences between the output probabilities become, and the distribution is said to be softer, or smoother. For example, with logits [10, 2]: at T=1, softmax → [0.999, 0.001]; at T=5, softmax → [0.83, 0.17]; at T=8 it softens further to [0.73, 0.27]. The fourth sketch below reproduces these numbers.

The script below is fairly simple and is referred to as the basic distillation method. Many distillation frameworks have appeared since, but so far there is no data showing that they outperform the basic method with well-tuned parameters.

The principle behind the loss: loss_clc is the loss between the student's prediction and the ground truth, and loss_kd is the loss between the student's and the teacher's predictions. loss_clc keeps the student from drifting away from the true labels during training, while loss_kd teaches the student the teacher's output structure, so that it learns to produce outputs the way the teacher does. For example, tiger: 0.8, cat: 0.15, dog: 0.05: this structure is teaching the student that cats and tigers have some similarity while dogs look nothing alike, which lets the student "think" more like the teacher.

If T is too large, the distribution becomes too soft: the probabilities are barely distinguishable from one another, and the student learns no useful information. If T is too small, the distribution is too hard, nearly the same as the ground truth: the student only learns which class is correct and that every other class is wrong, so again it learns no structural information. In distillation T is generally taken greater than 1; 2–5 is common, and T = 4 is the most stable choice in industry. The best T is one that tells the student which class is probably correct, which other class is fairly similar to it, and which class is completely unlike it.

The combined loss is

loss = (1 - args.lamda) * loss_clc + args.lamda * args.T * args.T * loss_kd

where the T * T factor compensates for the gradients of the softened KD term scaling as 1/T², keeping the two terms on a comparable footing (as in Hinton et al.'s original formulation).

As described so far, distillation is just training on a portion of the data under a teacher's guidance; how do we prove that this is better than training the same student without the teacher? A simple ablation is noted after the script.

A few short sketches of the points above follow, and after them the full training script.
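To make the fine-tuning versus frozen-backbone distinction concrete, here is a minimal sketch of both regimes; the torchvision ResNet-18, the 4-class head, and the learning rates are illustrative assumptions, not taken from the script below.

```python
import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.resnet18(pretrained=True)

# Transfer learning without fine-tuning: freeze the backbone and
# train only a newly attached classification head.
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 4)  # new head for 4 classes
head_optimizer = optim.SGD(model.fc.parameters(), lr=0.01)

# Fine-tuning: all (or some) parameters receive gradient updates;
# the small learning rate keeps the pretrained weights close to where they started.
for param in model.parameters():
    param.requires_grad = True
ft_optimizer = optim.SGD(model.parameters(), lr=0.001)
```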
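The text does not say which importance measure the pruning step uses; a common choice, assumed here purely for illustration, is the L1 norm of each convolution filter, with the lowest-norm filters pruned first.

```python
import torch
import torch.nn as nn

def filter_importance(model):
    """Score every filter of every conv layer by its L1 norm;
    low-scoring filters are candidates for pruning."""
    scores = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # weight shape: (out_channels, in_channels, kH, kW);
            # sum absolute values over everything but the filter axis
            scores[name] = module.weight.detach().abs().sum(dim=(1, 2, 3))
    return scores

# Example usage on the student network from the script below:
# for name, s in filter_importance(smodel).items():
#     print(name, torch.argsort(s)[:2])  # the two least important filters
```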
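A small comparison of the two usable losses on dummy logits; the shapes and the modern reduction='batchmean' argument are assumptions for this sketch (the script below uses the older reduce=True API).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 4.0
output_s = torch.randn(8, 4)  # student logits: batch of 8, 4 classes
output_t = torch.randn(8, 4)  # teacher logits

# Most recommended: KL divergence between temperature-softened distributions.
# The first argument must be log-probabilities, the second probabilities.
loss_kl = nn.KLDivLoss(reduction='batchmean')(
    F.log_softmax(output_s / T, dim=1),
    F.softmax(output_t / T, dim=1))

# Barely usable: mean squared error directly between the raw logits.
loss_mse = F.mse_loss(output_s, output_t)

print(loss_kl.item(), loss_mse.item())
```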
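The softening effect of T is easy to verify directly; this snippet reproduces the numbers quoted above.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([10.0, 2.0])
for T in [1.0, 5.0, 8.0]:
    print(T, F.softmax(logits / T, dim=0))
# T=1 -> [0.9997, 0.0003]  nearly one-hot: "hard"
# T=5 -> [0.83, 0.17]      softer
# T=8 -> [0.73, 0.27]      softer still
```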
The full training script:

```python
#coding:utf8
from __future__ import print_function
import os
import argparse
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import models
from models import *

## Training configuration arguments
parser = argparse.ArgumentParser(description='PyTorch KD simpleconv3 training')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--lamda', type=float, default=0.5,
                    help='KL loss weight (default: 0.5)')  ## loss weight coefficient
parser.add_argument('--T', type=float, default=5.0,
                    help='knowledge distillation temperature (default: 5)')  ## distillation temperature
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--tmodelpath', type=str, default='models/tmodel.pth.tar',
                    help='teacher model path')
parser.add_argument('--save', default='./checkpoint', type=str, metavar='PATH',
                    help='path to save student model (default: current directory)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
if not os.path.exists(args.save):
    os.makedirs(args.save)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

## Training setup
image_size = 60  ## image resize size
crop_size = 48   ## image crop size
nclass = 4       ## number of classes
tmodel = simpleconv3(nclass)       ## teacher network
tmodel.eval()
smodel = simpleconv3small(nclass)  ## smaller student network
if args.cuda:
    tmodel.cuda()
    smodel.cuda()

## Data loading and preprocessing
data_dir = './data'
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(crop_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1,
                               saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'val': transforms.Compose([
        transforms.Scale(image_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
}
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=64,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
train_loader = dataloaders['train']
test_loader = dataloaders['val']

## Optimizer (only the student's parameters are updated)
optimizer = optim.SGD(smodel.parameters(), lr=args.lr, momentum=args.momentum,
                      weight_decay=args.weight_decay)
kd_fun = nn.KLDivLoss(reduce=True)

## Load the trained teacher model
print('loading teacher model checkpoint {}'.format(args.tmodelpath))
tcheckpoint = torch.load(args.tmodelpath, map_location=lambda storage, loc: storage)
tmodel.load_state_dict(tcheckpoint['state_dict'])
print('loaded checkpoint {} (epoch {}) Prec1: {:f}'.format(
    args.tmodelpath, tcheckpoint['epoch'], tcheckpoint['best_prec1']))
print('opti=' + str(tcheckpoint['optimizer']['param_groups']))

## Training function
from tensorboardX import SummaryWriter
writer = SummaryWriter(args.save)

def process(epoch, data_loader, istrain=True):
    if istrain:
        smodel.train()
    else:
        smodel.eval()
    running_loss_clc = 0.0
    running_loss_kd = 0.0
    running_loss = 0.0
    running_acc = 0.0
    num_batch = 0
    for data, target in data_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        if istrain:
            optimizer.zero_grad()
        output_s = smodel(data)
        output_t = tmodel(data)
        _, preds = torch.max(output_s.data, 1)
        s_max = F.log_softmax(output_s / args.T, dim=1)
        t_max = F.softmax(output_t / args.T, dim=1)
        batch_size = target.shape[0]
        ## KL divergence; PyTorch computes y*(log y - x), so the first argument must be log-probabilities
        loss_kd = kd_fun(s_max, t_max)
        # loss_kd = kd_fun(s_max, t_max) / batch_size  ## alternative: normalize by batch size
        loss_clc = F.cross_entropy(output_s, target)  ## classification loss
        loss = (1 - args.lamda) * loss_clc + args.lamda * args.T * args.T * loss_kd
        running_acc += torch.sum(preds == target).item()
        running_loss += loss.data.item()
        running_loss_clc += loss_clc.data.item()
        running_loss_kd += loss_kd.data.item()
        num_batch += 1
        if istrain:
            loss.backward()
            optimizer.step()
    epoch_loss = running_loss / num_batch
    epoch_loss_clc = running_loss_clc / num_batch
    epoch_loss_kd = running_loss_kd / num_batch
    epoch_acc = running_acc / len(data_loader.dataset)
    if istrain:
        writer.add_scalar('data/trainloss', epoch_loss, epoch)
        writer.add_scalar('data/trainloss_clc', epoch_loss_clc, epoch)
        writer.add_scalar('data/trainloss_kd', epoch_loss_kd, epoch)
        writer.add_scalar('data/trainacc', epoch_acc, epoch)
        print('\nTrain set: Epoch: {}, Average loss: {:.4f}'.format(epoch, epoch_loss))
    else:
        writer.add_scalar('data/testloss', epoch_loss, epoch)
        writer.add_scalar('data/testloss_clc', epoch_loss_clc, epoch)
        writer.add_scalar('data/testloss_kd', epoch_loss_kd, epoch)
        writer.add_scalar('data/testacc', epoch_acc, epoch)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f})\n'.format(
            epoch_loss, running_acc, len(data_loader.dataset), epoch_acc))
    return epoch_acc

def save_checkpoint(state, is_best, filepath):
    torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'),
                        os.path.join(filepath, 'model_best.pth.tar'))

## Training loop
best_prec1 = 0.
print('args.start_epoch=' + str(args.start_epoch))
for epoch in range(args.start_epoch, args.epochs):
    if epoch in [args.epochs * 0.5, args.epochs * 0.75]:
        ## LR schedule: at 50% and 75% of the epochs, multiply the learning rate by 0.1.
        ## param_groups is a list like [{params, lr, momentum, dampening, weight_decay, nesterov}, ...]
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    prec_train = process(epoch, train_loader, istrain=True)
    prec1 = process(epoch, test_loader, istrain=False)
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': smodel.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, is_best, filepath=args.save)
print('Best accuracy: ' + str(best_prec1))
```
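On the closing question of how to prove the teacher actually helps: the script already supports the natural ablation, because --lamda 0 zeroes the KD term and reduces training to plain cross-entropy on the same data and schedule. Comparing best_prec1 between the two runs isolates the teacher's contribution. The filename train_kd.py is an assumption; substitute whatever this script is saved as.

```
python train_kd.py --lamda 0.5 --T 4.0   # student trained under the teacher
python train_kd.py --lamda 0.0           # baseline: same student, no teacher signal
```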