告别K-Means!用PyTorch手把手实现GMVAE,搞定MNIST无监督聚类(附完整代码)
用PyTorch实战GMVAE从数学公式到MNIST聚类可视化全流程解析当我们在处理MNIST这样的手写数字数据集时传统K-Means算法往往力不从心。每个数字的书写风格千变万化单一聚类中心难以捕捉这种多样性。这就是为什么我们需要高斯混合变分自编码器(GMVAE)——它将VAE的表达能力与GMM的灵活性相结合通过多个高斯分布来建模数据的内在结构。1. 环境准备与核心概念梳理在开始编码前我们需要明确几个关键概念标准VAE的局限性传统变分自编码器假设隐变量服从单一高斯分布这在处理多模态数据时显得过于简化GMM的优势高斯混合模型可以用多个高斯分布的加权和来建模复杂分布GMVAE的创新点将VAE的隐空间划分为多个高斯成分每个成分对应数据的一个潜在类别安装所需环境pip install torch torchvision matplotlib numpy核心依赖库版本要求PyTorch ≥ 1.8.0Torchvision ≥ 0.9.0提示建议使用conda创建虚拟环境以避免依赖冲突。GPU加速可显著提升训练速度但本文代码也兼容CPU运行。2. 模型架构设计GMVAE包含几个关键组件我们需要在PyTorch中逐一实现2.1 编码器网络设计编码器需要输出两组参数隐变量x的分布参数(μ, σ)辅助变量w的分布参数class Encoder(nn.Module): def __init__(self, input_dim784, hidden_dim400, x_dim20, w_dim10): super().__init__() self.fc1 nn.Linear(input_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) # x的分布参数 self.fc_mu_x nn.Linear(hidden_dim, x_dim) self.fc_var_x nn.Linear(hidden_dim, x_dim) # w的分布参数 self.fc_mu_w nn.Linear(hidden_dim, w_dim) self.fc_var_w nn.Linear(hidden_dim, w_dim) def forward(self, x): h F.relu(self.fc1(x)) h F.relu(self.fc2(h)) mu_x self.fc_mu_x(h) var_x F.softplus(self.fc_var_x(h)) mu_w self.fc_mu_w(h) var_w F.softplus(self.fc_var_w(h)) return mu_x, var_x, mu_w, var_w2.2 解码器网络设计解码器相对简单只需从隐变量重构输入class Decoder(nn.Module): def __init__(self, output_dim784, hidden_dim400, x_dim20): super().__init__() self.fc1 nn.Linear(x_dim, hidden_dim) self.fc2 nn.Linear(hidden_dim, hidden_dim) self.fc3 nn.Linear(hidden_dim, output_dim) def forward(self, z): h F.relu(self.fc1(z)) h F.relu(self.fc2(h)) return torch.sigmoid(self.fc3(h))3. 核心算法实现GMVAE的核心在于其特殊的损失函数计算。我们需要实现四个关键部分3.1 重构损失计算这部分与传统VAE相同使用二元交叉熵def reconstruction_loss(recon_x, x): BCE F.binary_cross_entropy(recon_x, x.view(-1, 784), reductionsum) return BCE3.2 条件先验项实现这是GMVAE最复杂的部分涉及混合高斯分布的计算def conditional_prior_loss(qx_mu, qx_var, w, components): qx_mu, qx_var: 编码器输出的x分布参数 w: 从q(w|y)采样的辅助变量 components: 高斯混合成分数 # 计算各个高斯成分的参数 comp_mus, comp_vars compute_gmm_parameters(w, components) # 计算log q(x|y) log_qx torch.distributions.Normal(qx_mu, qx_var.sqrt()).log_prob(qx_mu) # 计算log p(x|w,z) log_pxz [] for k in range(components): dist torch.distributions.Normal(comp_mus[:,k], comp_vars[:,k].sqrt()) log_pxz.append(dist.log_prob(qx_mu)) log_pxz torch.stack(log_pxz, dim1) # 计算z的后验概率 z_probs compute_z_posterior(qx_mu, w, components) loss (log_qx - (z_probs * log_pxz).sum(1)).mean() return loss3.3 W先验与Z先验项这两个正则项确保辅助变量的分布接近先验def w_prior_loss(qw_mu, qw_var): # KL(q(w|y) || p(w)), p(w)为标准正态 kl -0.5 * torch.sum(1 torch.log(qw_var) - qw_mu.pow(2) - qw_var) return kl def z_prior_loss(z_probs, components): # KL(p(z|x,w) || p(z)), p(z)为均匀分布 uniform torch.ones_like(z_probs) / components kl F.kl_div(z_probs.log(), uniform, reductionbatchmean) return kl4. 训练流程与技巧完整的训练循环需要注意以下几个关键点4.1 重参数化技巧的扩展GMVAE需要在两个变量上应用重参数化def reparameterize(mu, var): std torch.sqrt(var) eps torch.randn_like(std) return mu eps * std # 在训练循环中 x_mu, x_var, w_mu, w_var encoder(input) x_sample reparameterize(x_mu, x_var) w_sample reparameterize(w_mu, w_var)4.2 学习率调度策略由于损失函数复杂建议使用学习率预热optimizer torch.optim.Adam(model.parameters(), lr1e-3) scheduler torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambdalambda epoch: min(epoch/10, 1.0))4.3 训练监控指标除了损失值建议监控以下指标重构质量(PSNR)聚类纯度(NMI)隐空间可视化5. MNIST聚类可视化实战训练完成后我们可以进行聚类可视化分析5.1 隐空间可视化def plot_latent_space(encoder, test_loader, device): with torch.no_grad(): for data, _ in test_loader: data data.to(device) mu, _, _, _ encoder(data) mu mu.cpu().numpy() plt.scatter(mu[:,0], mu[:,1], alpha0.5) plt.show()5.2 聚类效果评估计算归一化互信息(NMI)分数from sklearn.metrics import normalized_mutual_info_score def evaluate_clustering(encoder, test_loader, device): true_labels [] pred_labels [] with torch.no_grad(): for data, label in test_loader: data data.to(device) mu, _, _, _ encoder(data) # 获取预测聚类标签 pred get_cluster_assignment(mu) true_labels.extend(label.numpy()) pred_labels.extend(pred.cpu().numpy()) return normalized_mutual_info_score(true_labels, pred_labels)5.3 生成样本展示通过不同高斯成分生成样本def generate_samples(decoder, components, device): with torch.no_grad(): fig, axes plt.subplots(1, components, figsize(15,3)) for k in range(components): # 从第k个高斯成分采样 z sample_from_component(k, components) sample decoder(z).cpu().numpy() axes[k].imshow(sample.reshape(28,28), cmapgray) plt.show()在实际项目中我发现调整高斯成分数量对结果影响显著。对于MNIST数据集10个成分通常能得到最佳效果这与数字类别数一致。另一个关键点是w变量的维度——太小会导致信息瓶颈太大则会增加训练难度。经过多次实验10-20维的w空间在大多数情况下表现良好。