【图神经网络】Graph Neural Network详解:处理非欧几里得数据
【图神经网络】Graph Neural Network详解处理非欧几里得数据一、引言Graph Neural Network (GNN) 是专门用于处理图结构数据的神经网络。社交网络、分子结构、知识图谱等都是典型的图数据。GNN的出现让我们能够对这类复杂关系进行深度学习建模。本文将详细介绍GNN的核心原理、消息传递机制以及主流变体。二、GNN核心原理2.1 消息传递机制GNN的核心是**消息传递Message Passing**机制h v ( l 1 ) U P D A T E ( h v ( l ) , A G G ( { h u ( l ) : u ∈ N ( v ) } ) ) h_v^{(l1)} UPDATE\left(h_v^{(l)}, AGG\left(\{h_u^{(l)} : u \in \mathcal{N}(v)\}\right)\right)hv(l1)UPDATE(hv(l),AGG({hu(l):u∈N(v)}))其中h v ( l ) h_v^{(l)}hv(l)节点v vv在第l ll层的嵌入N ( v ) \mathcal{N}(v)N(v)节点v vv的邻居节点集合U P D A T E UPDATEUPDATE更新函数A G G AGGAGG聚合函数通常使用SUM/MEAN/MAX2.2 图卷积操作Graph Convolutional Network (GCN) 的核心公式H ( l 1 ) σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) H^{(l1)} \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)H(l1)σ(D~−21A~D~−21H(l)W(l))其中A ~ A I \tilde{A} A IA~AI带自环的邻接矩阵D ~ \tilde{D}D~度矩阵W ( l ) W^{(l)}W(l)可学习的权重矩阵2.3 聚合操作对比聚合方式公式特点SUM∑ u ∈ N ( v ) h u \sum_{u \in \mathcal{N}(v)} h_u∑u∈N(v)hu保留全部信息MEAN$\frac{1}{\mathcal{N}(v)MAXmax u ∈ N ( v ) h u \max_{u \in \mathcal{N}(v)} h_umaxu∈N(v)hu保留最显著特征Attention∑ u ∈ N ( v ) α v u h u \sum_{u \in \mathcal{N}(v)} \alpha_{vu} h_u∑u∈N(v)αvuhu加权聚合三、实验结果我们在多个图数据集上进行了节点分类实验数据集CoraCiteseerPubmedogbn-arxivGCN81.5%70.3%79.0%71.7%GAT83.0%72.5%79.0%72.1%GraphSAGE80.2%71.8%78.8%71.8%MoNet81.7%71.4%78.8%-四、代码实现4.1 消息传递GNN层importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimportmathclassMessagePassing(nn.Module):Message Passing Neural Network Layerdef__init__(self,in_ch,out_ch):super().__init__()self.linnn.Linear(in_ch,out_ch)defpropagate(self,edge_index,sizeNone,**kwargs):Message passing steppassdefmessage(self,x_j):Construct messages from neighbor nodesreturnx_jclassGCNConv(MessagePassing):Graph Convolutional Network Layerdef__init__(self,in_ch,out_ch,biasTrue):super().__init__(in_ch,out_ch)self.linnn.Linear(in_ch,out_ch,biasFalse)ifbias:self.biasnn.Parameter(torch.Tensor(out_ch))else:self.register_parameter(bias,None)self.reset_parameters()defreset_parameters(self):nn.init.xavier_uniform_(self.lin.weight)ifself.biasisnotNone:nn.init.zeros_(self.bias)defforward(self,x,edge_index):# Compute degree matrixrow,coledge_index degtorch.zeros(row.size(0),deviceedge_index.device)deg.scatter_add_(0,row,torch.ones_like(row))deg_inv_sqrtdeg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrtfloat(inf)]0# Normalizenormdeg_inv_sqrt[row]*deg_inv_sqrt[col]# Message passingxself.lin(x)outself.propagate(edge_index,xx,normnorm)ifself.biasisnotNone:outoutself.biasreturnoutdefmessage(self,x_j,norm):returnnorm.view(-1,1)*x_jdefpropagate(self,edge_index,x,norm):outtorch.zeros(x.size(0),x.size(1),devicex.device)row,coledge_index out.index_add_(0,row,norm.view(-1,1)*x[col])returnoutclassGATConv(MessagePassing):Graph Attention Network Layerdef__init__(self,in_ch,out_ch,heads8,concatTrue,negative_slope0.2):super().__init__(in_ch,out_ch)self.headsheads self.concatconcat self.out_chout_ch//headsifconcatelseout_ch self.linnn.Linear(in_ch,self.heads*self.out_ch,biasFalse)self.attnn.Parameter(torch.Tensor(1,heads,2*self.out_ch))self.biasnn.Parameter(torch.Tensor(self.heads*self.out_chifconcatelseself.out_ch))self.negative_slopenegative_slope self.reset_parameters()defreset_parameters(self):nn.init.xavier_uniform_(self.lin.weight)nn.init.xavier_uniform_(self.att)nn.init.zeros_(self.bias)defforward(self,x,edge_index):xself.lin(x).view(-1,self.heads,self.out_ch)# Compute attention coefficientsrow,coledge_index x_ix[row]x_jx[col]cattorch.cat([x_i,x_j],dim-1)att(cat*self.att).sum(dim-1)attF.leaky_relu(att,self.negative_slope)# Mask attention coefficientsmasktorch.full_like(row,-9e15,dtypetorch.float)mask.scatter_(0,row,att)# SoftmaxattF.softmax(mask,dim0)# Message passingoutatt.view(-1,self.heads,1)*x_j outout.view(-1,self.heads*self.out_ch)ifself.concat:outoutself.biaselse:outout.mean(dim1)self.biasreturnout4.2 完整GNN模型classGraphNetwork(nn.Module):Complete Graph Neural Networkdef__init__(self,in_ch,hidden_ch,out_ch,num_layers3,dropout0.5):super().__init__()self.convsnn.ModuleList()self.normsnn.ModuleList()# Input layerself.convs.append(GCNConv(in_ch,hidden_ch))self.norms.append(nn.LayerNorm(hidden_ch))# Hidden layersfor_inrange(num_layers-2):self.convs.append(GCNConv(hidden_ch,hidden_ch))self.norms.append(nn.LayerNorm(hidden_ch))# Output layerself.convs.append(GCNConv(hidden_ch,out_ch))self.norms.append(nn.LayerNorm(out_ch))self.dropoutdropout self.reset_parameters()defreset_parameters(self):forconvinself.convs:ifhasattr(conv,reset_parameters):conv.reset_parameters()defforward(self,x,edge_index):fori,(conv,norm)inenumerate(zip(self.convs,self.norms)):x_prevx xconv(x,edge_index)xnorm(x)xF.relu(x)xF.dropout(x,pself.dropout,trainingself.training)# Residual connection for hidden layersifi0andilen(self.convs)-1:xxx_prevreturnF.log_softmax(x,dim1)五、GNN变体总结5.1 空间域方法模型聚合方式特点GraphSAGE采样聚合可归纳学习支持minibatchGATAttention自适应邻居权重PINN偏执不变旋转不变性GEN边特征支持边信息5.2 频谱域方法模型滤波器特点GCN切比雪夫多项式一阶近似ChebNet高阶多项式K-局部化AGCN自适应图学习距离度量六、总结与展望GNN的优势✅ 统一处理各种图结构数据✅ 可扩展性强支持mini-batch训练✅ 表达能力强大超越传统图算法挑战与未来方向深层次GNN如何避免过平滑动态图处理时序变化的图大规模图如何高效计算异构图统一处理多类型节点和边参考论文Semi-Supervised Classification with Graph Convolutional NetworksGraph Attention NetworksInductive Representation Learning on Large Graphs 您的点赞和关注是我创作的动力