[Graph Computing] DGL: The Message-Passing Mechanism

1. Message Passing

1.1 The Message-Passing Paradigm

As shown in the figure from the original post (omitted here), green boxes are the attributes of the corresponding nodes and blue boxes the attributes of the corresponding edges. Updating a node's state requires aggregating the messages passed in along each of its incoming edges.

Formally, the message-passing paradigm reads as follows, with a message function $\phi$, a reduce (aggregation) function $\rho$, and an update function $\psi$:

$$m_e^{(t+1)} = \phi\big(x_v^{(t)}, x_u^{(t)}, w_e^{(t)}\big), \quad (u, v, e) \in \mathcal{E}$$

$$m_v^{(t+1)} = \rho\big(\{\, m_e^{(t+1)} : (u, v, e) \in \mathcal{E} \,\}\big)$$

$$x_v^{(t+1)} = \psi\big(x_v^{(t)}, m_v^{(t+1)}\big)$$

Here $x_v$ is the feature of node $v$, $w_e$ the feature of edge $e$, and $(u, v, e)$ ranges over the edges $e$ from $u$ into $v$.
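To make the three steps concrete, here is a minimal sketch on a toy graph (the graph, the feature name x, and the choice of copy/sum/add for φ, ρ, ψ are our own):

import dgl
import dgl.function as fn
import torch

# Toy graph with edges 0->1, 1->2, 2->0.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['x'] = torch.ones(3, 4)                      # node feature x_v
# phi: copy the source feature onto each edge as message 'm'
# rho: sum the incoming messages into 'h'
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
# psi: applied outside update_all as a plain tensor operation
g.ndata['x'] = g.ndata['x'] + g.ndata['h']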

1.2 Built-in Message-Passing APIs

For details, see: dgl.function - DGL 0.6.1 documentation

Message functions

A message function takes a single argument, edges (an EdgeBatch instance), which DGL uses internally to represent a batch of edges during message passing. Its src, dst, and data attributes give access to the source-node features, the destination-node features, and the edge features, respectively.

For example, suppose we want to add the source node's hu feature to the destination node's hv feature and store the result in the edge's he feature. There are two ways to do this.

# Option 1: use DGL's built-in message function
dgl.function.u_add_v('hu', 'hv', 'he')

# Option 2: an equivalent user-defined message function
def message_func(edges):
    return {'he': edges.src['hu'] + edges.dst['hv']}
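Either version can be handed to apply_edges(), introduced below. A minimal sketch checking that the two agree, on an assumed toy graph with random features:

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['hu'] = torch.randn(3, 4)    # read as edges.src['hu']
g.ndata['hv'] = torch.randn(3, 4)    # read as edges.dst['hv']

g.apply_edges(fn.u_add_v('hu', 'hv', 'he'))   # built-in version
he_builtin = g.edata['he'].clone()
g.apply_edges(message_func)                   # user-defined version
assert torch.allclose(he_builtin, g.edata['he'])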

Reduce (aggregation) functions

A reduce function takes a single argument, nodes (a NodeBatch instance), which DGL uses internally to represent a batch of nodes during message passing. Its mailbox attribute holds the messages the nodes have received.

For example, to sum-aggregate the messages m received by each node and assign the result to the node's h feature:

# Option 1: use DGL's built-in reduce function
dgl.function.sum('m', 'h')

# Option 2: an equivalent user-defined reduce function
import torch
def reduce_func(nodes):
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
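Both variants plug into update_all(), described next, as the reduce step. A minimal sketch, assuming fn.copy_u as the message function and a toy graph where every node has an incoming edge:

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['x'] = torch.randn(3, 4)

g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))   # built-in reduce
h_builtin = g.ndata['h'].clone()
g.update_all(fn.copy_u('x', 'm'), reduce_func)        # user-defined reduce
assert torch.allclose(h_builtin, g.ndata['h'])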

apply_edges and update_all

When no message passing is involved, per-edge computation can be invoked on its own through apply_edges(). It takes a message function as its argument and, by default, updates all edges. For example:

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

For message passing proper, use update_all(). Its arguments are a message function, a reduce function, and an update function. The update function is optional; DGL recommends not passing it here and instead applying it outside update_all() with plain tensor operations.

For example: multiply the source-node feature ft by the edge feature a to produce the message m, sum all incoming messages to update the node feature ft, then multiply ft by 2 for the final result final_ft. Formally,

$$\text{final\_ft}_v = 2 \sum_{u \in \mathcal{N}(v)} \big(\text{ft}_u \cdot a_{uv}\big)$$

The implementation is:

def update_all_example(graph):
    # The aggregation result is stored in graph.ndata['ft'].
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # Apply the update function outside update_all.
    final_ft = graph.ndata['ft'] * 2
    return final_ft
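Driving this function on a small graph might look as follows (the toy graph and random features are assumptions; note that the per-edge feature a broadcasts against ft):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['ft'] = torch.randn(3, 4)                    # node feature ft
g.edata['a'] = torch.randn(g.number_of_edges(), 1)   # per-edge scalar a
final_ft = update_all_example(g)                     # shape (3, 4)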

1.3 Using Edge Weights

Sometimes edge weights must be applied before messages are aggregated, as in GAT and several GCN variants. The standard recipe is:

  1. Store the weight as an edge feature.

  2. Multiply the source-node feature by the edge feature inside the message function. In the code below, eweight serves as the edge weight:

import dgl.function as fn
 
# Suppose eweight is a tensor of shape (E, *), where E is the number of edges.
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
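As one concrete (assumed) choice of weights, a GCN-style 1/deg normalization per destination node can be precomputed and stored as eweight:

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['ft'] = torch.randn(3, 4)

# One scalar weight per edge: 1 / in-degree of the destination node.
deg = g.in_degrees().float().clamp(min=1)
src, dst = g.edges()
eweight = (1.0 / deg[dst]).unsqueeze(-1)       # shape (E, 1)

g.edata['a'] = eweight
g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))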

1.4 Message Passing on Heterogeneous Graphs

Message passing on a heterogeneous graph splits into two parts: (1) computing and aggregating messages for each relation, and (2) for each node, combining the aggregation results from the different relations. The DGL interface for this is multi_update_all(), which takes two arguments: a dictionary mapping each relation to its (message_func, reduce_func) pair, and a cross-type reducer (e.g. 'sum') that combines the per-relation results on each node. For example:

import dgl.function as fn

# Inside a module's forward(); feat_dict maps node types to input features
# and self.weight maps edge types to linear layers.
funcs = {}
for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # Store it in the graph for message passing.
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # Specify the per-relation functions: (message_func, reduce_func).
    # Note that the results all land in the same destination feature 'h',
    # i.e. aggregation is done per relation type.
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# Sum up the per-relation aggregation results.
G.multi_update_all(funcs, 'sum')
# Return the updated node feature dictionary.
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
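Wrapped into a self-contained module, the fragment above might read as follows. This is a sketch loosely following the hetero-graph pattern in the DGL guide; the two-relation toy graph, feature sizes, and the nn.ModuleDict layout are our own assumptions:

import dgl
import dgl.function as fn
import torch
import torch.nn as nn

class HeteroLayer(nn.Module):
    def __init__(self, in_feats, out_feats, etypes):
        super().__init__()
        # One linear projection per edge type.
        self.weight = nn.ModuleDict(
            {etype: nn.Linear(in_feats, out_feats) for etype in etypes})

    def forward(self, G, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in G.canonical_etypes:
            Wh = self.weight[etype](feat_dict[srctype])
            G.nodes[srctype].data['Wh_%s' % etype] = Wh
            funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
        G.multi_update_all(funcs, 'sum')
        return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}

G = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'):   (torch.tensor([0, 2]), torch.tensor([0, 1])),
})
layer = HeteroLayer(4, 8, [et for _, et, _ in G.canonical_etypes])
out = layer(G, {'user': torch.randn(3, 4), 'game': torch.randn(2, 4)})
# out['user']: shape (3, 8); out['game']: shape (2, 8)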

2. Message Passing in SAGEConv

Building on the walkthrough of the message-passing machinery above, we can go straight to SAGEConv's forward method. The full layer source reads:

"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F
 
from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
 
 
class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
 
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
 
    def reset_parameters(self):
        r"""
 
        Description
        -----------
        Reinitialize learnable parameters.
 
        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is using xavier initialization method for its weights.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
 
    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}
 
    def forward(self, graph, feat, edge_weight=None):
        r"""
 
        Description
        -----------
        Compute GraphSAGE layer.
 
        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.
 
        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
 
            h_self = feat_dst
 
            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)    # cast dtype/device to match feat_dst
 
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)                   # 1-D tensor of in-degrees
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)  # unsqueeze to 2-D so the division broadcasts; +1 accounts for the self feature
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
 
            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
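As a quick usage check, the layer can be driven as follows (a minimal sketch; the toy graph and feature sizes are assumptions):

import dgl
import torch
from dgl.nn import SAGEConv

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
feat = torch.randn(3, 10)

conv = SAGEConv(10, 5, aggregator_type='mean')
h = conv(g, feat)                      # shape (3, 5)

# Per-edge weights enter the message function via edge_weight.
ew = torch.rand(g.number_of_edges())
h_w = conv(g, feat, edge_weight=ew)    # also (3, 5)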

References

DGL official documentation: User Guide - DGL 0.6.1 documentation

https://docs.dgl.ai/guide/index.html

Source: https://mp.weixin.qq.com/s/Szzot9MeZSGgayK8YTmv7w