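[Graph Computing] DGL: The Message Passing Mechanism
I. Message Passing
1.1 The Message Passing Paradigm
In the figure below, the green boxes hold node attributes and the blue boxes hold edge attributes. To update its state, a node aggregates the messages delivered along its incoming edges.
The message passing paradigm is expressed formally as follows: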
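In the notation of the DGL User Guide listed in the references, the paradigm can be written as below. Here x_v is the feature of node v, w_e is the feature of edge e = (u, v), φ is the message function, ρ is the reduce function, and ψ is the update function. The symbols are taken from that guide and are an assumption about the formula intended here.

```latex
\begin{aligned}
\text{edge-wise:}\quad & m_e^{(t+1)} = \phi\left(x_v^{(t)},\, x_u^{(t)},\, w_e^{(t)}\right), \quad (u, v, e) \in \mathcal{E} \\
\text{node-wise:}\quad & x_v^{(t+1)} = \psi\left(x_v^{(t)},\, \rho\left(\left\{ m_e^{(t+1)} : (u, v, e) \in \mathcal{E} \right\}\right)\right)
\end{aligned}
```

The first line computes one message per edge. The second line reduces the messages arriving at each node and then updates that node.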
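1.2 Built-in Message Passing APIs
For details, see: dgl.function - DGL 0.6.1 documentation
Message functions
A message function takes a single argument, edges (an EdgeBatch instance), which DGL uses internally to represent a batch of edges during message passing. Its src, dst, and data attributes give access to the source node features, the destination node features, and the edge features, respectively.
For example, suppose we want to add the hu feature of the source nodes to the hv feature of the destination nodes and store the result in the he feature of the edges. There are two ways to do this.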
import dgl
# Method 1: use DGL's built-in message function
dgl.function.u_add_v('hu', 'hv', 'he')
# Method 2: a user-defined message function
def message_func(edges):
    return {'he': edges.src['hu'] + edges.dst['hv']}
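To make the semantics concrete, here is a minimal plain-Python sketch of what such an edge-wise message function computes. It is an emulation, not DGL code: the edge lists, the scalar features, and the helper name u_add_v are assumptions chosen for illustration.

```python
# Plain-Python emulation of an edge-wise message function (not DGL code).
# Edge i goes from src[i] to dst[i]; the toy graph and values are assumptions.
src = [0, 0, 1]
dst = [1, 2, 2]
hu = [1.0, 2.0, 3.0]     # one scalar 'hu' feature per node
hv = [10.0, 20.0, 30.0]  # one scalar 'hv' feature per node


def u_add_v(src, dst, hu, hv):
    """Return one value per edge: hu[source] + hv[destination]."""
    return [hu[u] + hv[v] for u, v in zip(src, dst)]


he = u_add_v(src, dst, hu, hv)
print(he)  # [21.0, 31.0, 32.0]
```

The result has one entry per edge, which is why it is stored as an edge feature.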
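Reduce functions
A reduce function takes a single argument, nodes (a NodeBatch instance), which DGL uses internally to represent a batch of nodes during message passing. The mailbox attribute of nodes holds the messages those nodes have received.
For example, suppose we want to sum the received messages m and assign the result to the node feature h.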
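import dgl
import torch
# Method 1: use DGL's built-in reduce function
dgl.function.sum('m', 'h')
# Method 2: a user-defined reduce function
def reduce_func(nodes):
    # dim=1 is the message dimension of the mailbox
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}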
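The mailbox idea can be sketched in plain Python as follows. This is an emulation under stated assumptions, not DGL's implementation: the toy edges and messages are made up, and nodes with no incoming edge are simply given 0.0 here.

```python
from collections import defaultdict

# Plain-Python emulation of a reduce step (not DGL code); toy values are assumptions.
dst = [1, 2, 2]          # destination node of each edge
m = [5.0, 7.0, 11.0]     # message carried by each edge


def build_mailbox(dst, messages):
    """Group the per-edge messages by the node that receives them."""
    mailbox = defaultdict(list)
    for v, msg in zip(dst, messages):
        mailbox[v].append(msg)
    return dict(mailbox)


def reduce_sum(mailbox, num_nodes):
    """Sum each node's mailbox; nodes with no incoming edge get 0.0 in this sketch."""
    return [sum(mailbox.get(v, []), 0.0) for v in range(num_nodes)]


mailbox = build_mailbox(dst, m)
h = reduce_sum(mailbox, num_nodes=3)
print(mailbox)  # {1: [5.0], 2: [7.0, 11.0]}
print(h)        # [0.0, 5.0, 18.0]
```

Summing over each node's list of messages plays the role of the dim=1 sum over nodes.mailbox['m'] in the user-defined function above.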
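apply_edges and update_all
When no message passing is involved, edge-wise computation can be invoked on its own through apply_edges(). Its argument is a message function, and by default it updates all edges. For example: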
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
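For message passing, consider update_all(). It takes a message function, a reduce function, and an update function. The update function is optional; DGL recommends leaving it out and implementing the update outside update_all with ordinary tensor operations.
For example, multiply the source node feature ft by the edge feature a to produce the message m, sum all messages to update the node feature ft, and then multiply ft by 2 to obtain the final result final_ft. The formal expression and the implementation are as follows:
final_ft_i = 2 * \sum_{j \in \mathcal{N}(i)} (ft_j * a_{ij}), where \mathcal{N}(i) is the set of in-neighbors of node i and a_{ij} is the feature of the edge from j to i.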
import dgl.function as fn

def update_all_example(graph):
    # The result is stored in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # Apply the update function outside update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft
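As a cross-check of what this computes, the sketch below emulates the same message, reduce, and update steps in plain Python on a hypothetical four-edge graph. It is not DGL code; the graph, the scalar features, and the helper name are assumptions.

```python
# Plain-Python emulation of the update_all example above (not DGL code).
# The toy graph, features, and edge weights are assumptions.
src = [0, 1, 2, 0]
dst = [1, 2, 1, 2]
ft = [1.0, 2.0, 3.0]          # node feature 'ft'
a = [0.5, 1.0, 2.0, 4.0]      # edge feature 'a'


def update_all_u_mul_e_sum(src, dst, ft, a):
    """Message: ft[source] * a[edge]; reduce: sum over incoming edges."""
    new_ft = [0.0] * len(ft)
    for e, (u, v) in enumerate(zip(src, dst)):
        new_ft[v] += ft[u] * a[e]
    return new_ft


ft = update_all_u_mul_e_sum(src, dst, ft, a)   # result stored back in 'ft'
final_ft = [2 * x for x in ft]                 # update applied outside update_all
print(ft)        # [0.0, 6.5, 6.0]
print(final_ft)  # [0.0, 13.0, 12.0]
```

All messages are computed from the old ft values before any node is overwritten, which mirrors the fact that the message step runs before the reduce step.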
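1.3 Using Edge Weights
Sometimes edge weights must be applied before the messages are aggregated, as in GAT and some GCN variants. The approach is:
- Store the weights as an edge feature.
- In the message function, multiply the edge feature by the source node feature.
In the code below, eweight is used as the edge weight: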
import dgl.function as fn
# Assume eweight is a tensor of shape (E, *), where E is the number of edges.
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
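One way to see why edge weights are useful: with every edge weighted by 1 / in-degree of its destination, the weighted sum becomes a mean over the neighbors. The plain-Python sketch below illustrates this on a hypothetical graph. It is not DGL code, and this particular choice of weights is an assumption for illustration, not something the article prescribes.

```python
# Plain-Python emulation of weighted aggregation (not DGL code).
# The toy graph and features are assumptions.
src = [0, 1, 2, 0]
dst = [1, 2, 1, 2]
ft = [1.0, 2.0, 3.0]


def weighted_sum(src, dst, ft, eweight):
    """Sum ft[source] * eweight[edge] over each node's incoming edges."""
    out = [0.0] * len(ft)
    for e, (u, v) in enumerate(zip(src, dst)):
        out[v] += ft[u] * eweight[e]
    return out


# Weight every edge by 1 / in-degree of its destination node.
in_deg = [dst.count(v) for v in range(len(ft))]
eweight = [1.0 / in_deg[v] for v in dst]
result = weighted_sum(src, dst, ft, eweight)

print(eweight)  # [0.5, 0.5, 0.5, 0.5]
print(result)   # [0.0, 2.0, 1.5]
```

Node 1 receives features 1.0 and 3.0 and ends up with their mean 2.0; node 2 receives 2.0 and 1.0 and ends up with 1.5.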
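1.4 Message Passing on Heterogeneous Graphs
Message passing on a heterogeneous graph can be split into two parts: 1) for each relation, compute and aggregate the messages; 2) for each node, aggregate the results coming from the different relations. In DGL, the interface for message passing on heterogeneous graphs is multi_update_all(), which takes two arguments:
- Argument 1: a dictionary whose keys are relations and whose values are the update_all() arguments for the corresponding relation.
- Argument 2: a string that specifies how the aggregation results of the different relations are combined.
The following example illustrates both arguments: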
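import dgl.function as fn

# Excerpt from a layer's forward method. The enclosing signature and self.weight
# (assumed to map each edge type to a linear layer) are not shown in the original.
def forward(self, G, feat_dict):
    funcs = {}
    for c_etype in G.canonical_etypes:
        srctype, etype, dsttype = c_etype
        Wh = self.weight[etype](feat_dict[srctype])
        # Store it in the graph for message passing
        G.nodes[srctype].data['Wh_%s' % etype] = Wh
        # Specify the message passing functions of each relation: (message_func, reduce_func).
        # Note that the results are all saved to the same destination feature 'h',
        # which indicates that aggregation is done type by type.
        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
    # Sum the aggregated results of all message types.
    G.multi_update_all(funcs, 'sum')
    # Return the dictionary of updated node features
    return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}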
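The two-stage aggregation can be sketched in plain Python: a mean within each relation, followed by a sum across relations. This is an emulation, not DGL code. The node types, relation names, edges, and scalar features are all hypothetical.

```python
# Plain-Python emulation of a per-relation reduce plus a cross-relation 'sum' (not DGL code).
# Node types, relation names, edges, and features are hypothetical.
features = {
    "user": [1.0, 2.0],
    "item": [10.0, 20.0, 30.0],
}
# relation name -> (source type, destination type, list of (src, dst) edges)
relations = {
    "follows": ("user", "user", [(0, 1), (1, 0)]),
    "bought_by": ("item", "user", [(0, 0), (1, 0), (2, 1)]),
}


def relation_mean(src_feat, edges, num_dst):
    """Stage 1: within one relation, average the messages each destination node receives."""
    total = [0.0] * num_dst
    count = [0] * num_dst
    for u, v in edges:
        total[v] += src_feat[u]
        count[v] += 1
    return [t / c if c else 0.0 for t, c in zip(total, count)]


def multi_update_all_sum(features, relations):
    """Stage 2: for each destination type, add up the per-relation results."""
    h = {}
    for srctype, dsttype, edges in relations.values():
        per_rel = relation_mean(features[srctype], edges, len(features[dsttype]))
        if dsttype not in h:
            h[dsttype] = [0.0] * len(per_rel)
        h[dsttype] = [x + y for x, y in zip(h[dsttype], per_rel)]
    return h


h = multi_update_all_sum(features, relations)
print(h)  # {'user': [17.0, 31.0]}
```

For user 0, the follows relation contributes 2.0 and the bought_by relation contributes the mean of 10.0 and 20.0, giving 17.0 in total. In this sketch only node types that receive messages appear in the result, so item has no entry.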
II. Message Passing in SAGEConv
With the message passing mechanism above in mind, we can go straight to the source of SAGEConv and its forward method:
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F
from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Description
        -----------
        Reinitialize learnable parameters.
        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is using xavier initialization method for its weights.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""
        Description
        -----------
        Compute GraphSAGE layer.
        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.
        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
            h_self = feat_dst
            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                # .to(feat_dst) casts the zeros to the same dtype and device as feat_dst
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)  # in_degrees() returns a 1-D tensor
                # unsqueeze(-1) makes degs 2-D so the division broadcasts over the
                # feature dimension; the +1 counts the node itself
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
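The difference between the mean and gcn branches of forward is easy to miss, so here is a minimal plain-Python sketch of just the neighbor aggregation step. It is an emulation, not DGL code: the toy graph and scalar features are assumptions, and the linear layers, dropout, activation, and normalization are left out.

```python
# Plain-Python emulation of the 'mean' and 'gcn' neighbor aggregation in forward (not DGL code).
# Scalar features and the toy graph are assumptions; linear layers, dropout,
# activation, and normalization are omitted.
src = [0, 1, 2, 0]
dst = [1, 2, 1, 2]
feat = [1.0, 2.0, 3.0]


def neighbor_sum_and_degree(src, dst, feat):
    """Sum source features over incoming edges and count in-degrees."""
    total = [0.0] * len(feat)
    deg = [0] * len(feat)
    for u, v in zip(src, dst):
        total[v] += feat[u]
        deg[v] += 1
    return total, deg


def aggregate_mean(src, dst, feat):
    """'mean': average of the neighbors only; 0.0 for isolated nodes in this sketch."""
    total, deg = neighbor_sum_and_degree(src, dst, feat)
    return [t / d if d else 0.0 for t, d in zip(total, deg)]


def aggregate_gcn(src, dst, feat):
    """'gcn': (sum of neighbors + the node's own feature) / (in-degree + 1)."""
    total, deg = neighbor_sum_and_degree(src, dst, feat)
    return [(t + x) / (d + 1) for t, x, d in zip(total, feat, deg)]


print(aggregate_mean(src, dst, feat))  # [0.0, 2.0, 1.5]
print(aggregate_gcn(src, dst, feat))   # [1.0, 2.0, 2.0]
```

In the mean branch the node's own feature enters later through fc_self, whereas the gcn branch folds it into the average, which is why that branch does not need fc_self.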
References
DGL official documentation: User Guide - DGL 0.6.1 documentation
https://docs.dgl.ai/guide/index.html
Source: https://mp.weixin.qq.com/s/Szzot9MeZSGgayK8YTmv7w