0
点赞
收藏
分享

微信扫一扫

图神经网络 torch_geometric 库的 MessagePassing 运行机制学习

小_北_爸 2022-03-17 阅读 185

torch_geometric.nn.conv. MessagePassing( )
继承这个类,可以自定义节点信息传播机制

例子

import torch
from torch_geometric.nn.conv import MessagePassing

class GCNConv(MessagePassing):
    def __init__(self):
    	#选择相加的方式进行邻居节点信息聚合
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
    	#给图添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        print(edge_index)
        out = self.propagate(edge_index, x=x)
        print(out)

    def message(self, x_j):
        print(x_j)
        return x_j




edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [6], [1]], dtype=torch.float)
edge_index = edge_index.permute(1, 0)
model = GCNConv(1, 3)
out = model(x, edge_index)

运行截图
在这里插入图片描述
MessagePassing的运行机制就是用行坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,然后用列坐标[1,0,2,1,0,1,2]进行 add 聚合信息,x_j其实就是根据行坐标得来的,行坐标里面的每一个元素其实就是一个节点标号,它告诉我们当前聚合信息时,每一个节点的信息应该是怎么样,在这里我没有转换节点feature,直接就是初始feature进行聚合,然后列坐标的元素进行聚合,如:列坐标中0节点与行节点对应的元素为1,0,所以在x_j对应位置找到元素6,-1然后相加得5,同理,1节点为-1+1+6 =6,2节点为6+1=7;
需要注意的是def message(self, x_j)中x_j的参数名字不能随便改变,不然会出错;其实x_j可以变为x_i,x_i代表以列坐标[0,1,1,2,0,1,2]计算每个节点要汇聚的feature,但仍以列坐标[0,1,1,2,0,1,2]汇聚信息;最后得到的结果如下:
在这里插入图片描述
也可以def message(self, x_j,x_i) ,其中x_j,x_i同时返回,可以根据具体应用进行灵活操作;

举报

相关推荐

图神经网络

0 条评论