从IMDB电影推荐到DBLP学者分类:手把手教你用PyTorch和DGL复现异构图注意力网络HAN
从IMDB电影推荐到DBLP学者分类手把手教你用PyTorch和DGL复现异构图注意力网络HAN在现实世界中数据往往以复杂的网络结构存在——社交平台的好友关系、学术论文的引用网络、电商平台的用户-商品交互记录这些场景都蕴含着丰富的图结构信息。传统机器学习方法在处理这类非欧几里得数据时往往力不从心而图神经网络(GNN)的出现为这类问题提供了全新的解决方案。异构图注意力网络(HAN)作为GNN家族中的重要成员通过双层注意力机制巧妙捕捉了异构图中不同类型节点和边的语义关系在推荐系统、学术网络分析等领域展现出强大潜力。本文将带您深入HAN模型的核心实现细节使用PyTorch和DGL框架从零构建完整的模型架构。不同于单纯的理论讲解我们会以IMDB电影数据集和DBLP学术网络为具体案例详细剖析如何处理异构数据、设计元路径、实现注意力机制最终完成电影类型预测和学者领域分类的实战任务。无论您是希望将HAN应用于推荐系统的工程师还是研究图表示学习的研究者都能从中获得可直接复用的代码范例和工程实践技巧。1. 环境准备与数据加载1.1 基础环境配置在开始模型实现前需要确保开发环境已安装必要的软件包。推荐使用Python 3.8环境和conda进行依赖管理conda create -n han python3.8 conda activate han pip install torch1.12.0cu113 torchvision0.13.0cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install dgl-cu1130.9.0 pip install scikit-learn pandas tqdm注意根据您的CUDA版本选择对应的PyTorch和DGL安装包CPU版本可省略cu113后缀1.2 数据集解析与预处理HAN论文中使用了三个经典异构图数据集IMDB电影-演员-导演、DBLP论文-作者-会议和ACM论文-作者-主题。我们重点分析IMDB和DBLP的结构特点IMDB数据集特征节点类型电影(M)、演员(A)、导演(D)边类型演员-电影(出演)、导演-电影(执导)目标任务电影类型分类(动作/喜剧/话剧)原始特征电影剧情词袋表示(维度3066)DBLP数据集特征节点类型论文(P)、作者(A)、会议(C)、术语(T)边类型作者-论文(撰写)、论文-会议(发表)、论文-术语(包含)目标任务作者领域分类(DB/DM/ML/IR)原始特征作者关键词词袋表示(维度334)使用DGL加载预处理好的数据集import dgl from dgl.data import HANDataset # 加载IMDB数据集 imdb_data HANDataset(nameimdb) imdb_graph imdb_data[0] # 获取异构图对象 imdb_labels imdb_data.labels[movie] # 电影节点标签 imdb_train_mask imdb_data.train_mask[movie] # 训练集掩码 # 加载DBLP数据集 dblp_data HANDataset(namedblp) dblp_graph dblp_data[0] dblp_labels dblp_data.labels[author] dblp_train_mask dblp_data.train_mask[author]2. 元路径设计与邻居子图构建2.1 理解元路径语义元路径是HAN模型处理异构信息的关键设计它定义了节点间的复合关系路径。对于IMDB数据集我们设计以下典型元路径MAM电影-演员-电影表示同一演员出演的不同电影MDM电影-导演-电影表示同一导演执导的不同电影MYM电影-年份-电影表示同一年份上映的不同电影需年份节点这些元路径捕捉了电影之间不同的语义关系。例如MAM路径更适合发现电影类型相似性同一演员常出演同类电影而MDM路径可能反映电影风格的一致性。2.2 基于DGL的元路径子图生成DGL提供了便捷的metapath_reachable_graph函数来提取基于元路径的同构子图# 为IMDB构建MAM和MDM元路径子图 mam_graph dgl.metapath_reachable_graph(imdb_graph, [movie, actor, movie]) mdm_graph dgl.metapath_reachable_graph(imdb_graph, [movie, director, movie]) # DBLP的元路径示例APA(作者-论文-作者) apa_graph dgl.metapath_reachable_graph(dblp_graph, [author, paper, author])提示实际应用中应根据业务场景设计有意义的元路径。例如在电商场景用户-商品-用户(UBU)路径可发现相似用户偏好2.3 异构图特征工程异构节点通常具有不同类型的特征需要进行统一化处理import torch.nn as nn # 特征投影层将不同类型节点特征映射到相同维度 class TypeSpecificLinear(nn.Module): def __init__(self, in_features, out_features, ntypes): super().__init__() self.linears nn.ModuleDict({ ntype: nn.Linear(in_features[ntype], out_features) for ntype in ntypes }) def forward(self, graph, feat_dict): return { ntype: self.linears[ntype](feat_dict[ntype]) for ntype in feat_dict } # IMDB特征投影示例 imdb_ntypes [movie, actor, director] imdb_in_feats {movie: 3066, actor: 256, director: 256} projector TypeSpecificLinear(imdb_in_feats, 64, imdb_ntypes)3. HAN模型架构实现3.1 节点级注意力层节点级注意力学习同一元路径下邻居节点的重要性权重其实现借鉴了GAT的思路但扩展到了异构场景import torch.nn.functional as F class NodeLevelAttention(nn.Module): def __init__(self, in_feats, out_feats, num_heads): super().__init__() self.num_heads num_heads self.fc nn.Linear(in_feats, out_feats * num_heads, biasFalse) self.attn_fc nn.Linear(2 * out_feats, 1, biasFalse) self.reset_parameters() def reset_parameters(self): gain nn.init.calculate_gain(relu) nn.init.xavier_normal_(self.fc.weight, gaingain) nn.init.xavier_normal_(self.attn_fc.weight, gaingain) def edge_attention(self, edges): z2 torch.cat([edges.src[z], edges.dst[z]], dim1) a self.attn_fc(z2) return {e: F.leaky_relu(a)} def forward(self, g, h): z self.fc(h).view(-1, self.num_heads, self.out_feats) g.ndata[z] z g.apply_edges(self.edge_attention) g.edata[alpha] F.softmax(g.edata[e], dim1) h_out [] for head in range(self.num_heads): g.edata[alpha_h] g.edata[alpha][:, head:head1] g.update_all( fn.u_mul_e(z, alpha_h, m), fn.sum(m, h_out) ) h_out.append(g.ndata[h_out]) return torch.cat(h_out, dim1)3.2 语义级注意力层语义级注意力学习不同元路径的重要性权重聚合各元路径的节点表示class SemanticLevelAttention(nn.Module): def __init__(self, in_feats, hidden_size): super().__init__() self.project nn.Sequential( nn.Linear(in_feats, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 1, biasFalse) ) def forward(self, z): w self.project(z).mean(0) # (M, 1) beta torch.softmax(w, dim0) # (M, 1) beta beta.expand((z.shape[0],) beta.shape) # (N, M, 1) return (beta * z).sum(1) # (N, D*K)3.3 完整HAN模型集成将节点级和语义级注意力组合成完整HAN模型class HAN(nn.Module): def __init__(self, meta_paths, in_feats, hidden_size, out_feats, num_heads): super().__init__() self.layers nn.ModuleList() self.layers.append(NodeLevelAttention(in_feats, hidden_size, num_heads)) self.semantic_attention SemanticLevelAttention(hidden_size * num_heads, hidden_size) self.predict nn.Linear(hidden_size * num_heads, out_feats) self.meta_paths meta_paths def forward(self, g, h): semantic_embeddings [] for i, meta_path in enumerate(self.meta_paths): new_g dgl.metapath_reachable_graph(g, meta_path) semantic_embeddings.append(self.layers[0](new_g, h).flatten(1)) semantic_embeddings torch.stack(semantic_embeddings, dim1) # (N, M, D*K) h self.semantic_attention(semantic_embeddings) return self.predict(h)4. 模型训练与评估4.1 训练流程实现def train(model, g, features, labels, train_mask, val_mask, epochs100): optimizer torch.optim.Adam(model.parameters(), lr0.005, weight_decay0.001) loss_fcn nn.CrossEntropyLoss() for epoch in range(epochs): model.train() logits model(g, features) loss loss_fcn(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() acc evaluate(model, g, features, labels, val_mask) print(fEpoch {epoch:02d} | Loss: {loss.item():.4f} | Val Acc: {acc:.4f}) def evaluate(model, g, features, labels, mask): model.eval() with torch.no_grad(): logits model(g, features) logits logits[mask] labels labels[mask] _, indices torch.max(logits, dim1) correct torch.sum(indices labels) return correct.item() * 1.0 / len(labels)4.2 IMDB电影分类实战# 准备IMDB数据 imdb_meta_paths [[movie, actor, movie], [movie, director, movie]] imdb_model HAN(imdb_meta_paths, 64, 64, 3, 8) # 3个电影类别 # 训练模型 train(imdb_model, imdb_graph, imdb_features, imdb_labels, imdb_train_mask, imdb_val_mask, epochs50)4.3 DBLP学者分类实战# 准备DBLP数据 dblp_meta_paths [[author, paper, author], [author, paper, conference, paper, author]] dblp_model HAN(dblp_meta_paths, 64, 64, 4, 8) # 4个学者领域 # 训练模型 train(dblp_model, dblp_graph, dblp_features, dblp_labels, dblp_train_mask, dblp_val_mask, epochs50)5. 模型优化与生产部署5.1 注意力权重可视化理解模型决策过程对实际应用至关重要我们可以可视化注意力权重def plot_attention_weights(model, g, node_idx): model.eval() with torch.no_grad(): # 获取节点级注意力 node_attentions [] for meta_path in model.meta_paths: subgraph dgl.metapath_reachable_graph(g, meta_path) edge_weights model.layers[0](subgraph, g.ndata[h]).detach() node_attentions.append(edge_weights[node_idx]) # 获取语义级注意力 semantic_weights model.semantic_attention.beta.detach() # 绘制热力图...5.2 模型压缩与加速为满足生产环境需求可以考虑以下优化策略知识蒸馏训练小型学生模型模仿HAN的行为量化训练使用8位整数降低模型存储和计算开销元路径剪枝根据注意力权重剔除不重要的元路径# 量化示例 quantized_model torch.quantization.quantize_dynamic( model, {nn.Linear}, dtypetorch.qint8 )5.3 实际业务迁移建议将HAN应用于新场景时需注意元路径设计原则起点和终点应为目标预测节点类型路径长度通常不超过3跳结合业务语义设计有意义的关系路径特征工程技巧对稀疏特征使用Embedding层对数值特征进行标准化考虑加入节点度等图结构特征模型调优方向注意力头数通常4-8个效果最佳隐藏层维度64-256之间使用早停法防止过拟合在电商推荐场景中可以设计如下元路径用户-商品-用户(UBU)发现相似用户用户-商品-类别-商品(UBCB)基于类别的协同过滤用户-商品-品牌-商品(UBrB)基于品牌的推荐