終於有人總結了圖神經網絡!

作者:yyHaker,編輯:極市平臺

來源丨 https://zhuanlan.zhihu.com/p/136521625

導讀

 

本文從一個更直觀的角度對當前經典流行的 GNN 網絡,包括 GCN、GraphSAGE、GAT、GAE 以及 graph pooling 策略 DiffPool 等等做一個簡單的小結。 

近年來,深度學習領域關於圖神經網絡(Graph Neural Networks,GNN)的研究熱情日益高漲,圖神經網絡已經成爲各大深度學習頂會的研究熱點。GNN 處理非結構化數據時的出色能力使其在網絡數據分析、推薦系統、物理建模、自然語言處理和圖上的組合優化問題方面都取得了新的突破。

圖神經網絡有很多比較好的綜述 [1][2][3] 可以參考,更多的論文可以參考清華大學整理的 GNN paper list[4] 。

本篇文章將從一個更直觀的角度對當前經典流行的 GNN 網絡,包括 GCN、GraphSAGE、GAT、GAE 以及 graph pooling 策略 DiffPool 等等做一個簡單的小結。

筆者注:行文如有錯誤或者表述不當之處,還望批評指正!

一、爲什麼需要圖神經網絡?

隨着機器學習、深度學習的發展,語音、圖像、自然語言處理逐漸取得了很大的突破,然而語音、圖像、文本都是很簡單的序列或者網格數據,是很結構化的數據,深度學習很善於處理該種類型的數據(圖 1)。

圖 1

然而現實世界中並不是所有的事物都可以表示成一個序列或者一個網格,例如社交網絡、知識圖譜、複雜的文件系統等(圖 2),也就是說很多事物都是非結構化的。

圖 2

相比於簡單的文本和圖像,這種網絡類型的非結構化的數據非常複雜,處理它的難點包括:

  1. 圖的大小是任意的,圖的拓撲結構複雜,沒有像圖像一樣的空間局部性

  2. 圖沒有固定的節點順序,或者說沒有一個參考節點

  3. 圖經常是動態圖,而且包含多模態的特徵

那麼對於這類數據我們該如何建模呢?能否將深度學習進行擴展使得能夠建模該類數據呢?這些問題促使了圖神經網絡的出現與發展。

二. 圖神經網絡是什麼樣子的?

相比較於神經網絡最基本的網絡結構全連接層(MLP),特徵矩陣乘以權重矩陣,圖神經網絡多了一個鄰接矩陣。計算形式很簡單,三個矩陣相乘再加上一個非線性變換(圖 3)。

圖 3

因此一個比較常見的圖神經網絡的應用模式如下圖(圖 4),輸入是一個圖,經過多層圖卷積等各種操作以及激活函數,最終得到各個節點的表示,以便於進行節點分類、鏈接預測、圖與子圖的生成等等任務。

圖 4

上面是一個對圖神經網絡比較簡單直觀的感受與理解,實際其背後的原理邏輯還是比較複雜的,這個後面再慢慢細說,接下來將以幾個經典的 GNN models 爲線來介紹圖神經網絡的發展歷程。

三、圖神經網絡的幾個經典模型與發展

1 . Graph Convolution Networks(GCN)[5]

GCN 可謂是圖神經網絡的 “開山之作”,它首次將圖像處理中的卷積操作簡單的用到圖結構數據處理中來,並且給出了具體的推導,這裏面涉及到複雜的譜圖理論,具體推到可以參考 [6][7]。推導過程還是比較複雜的,然而最後的結果卻非常簡單 ( 圖 5)。

圖 5

我們來看一下這個式子,天吶,這不就是聚合鄰居節點的特徵然後做一個線性變換嗎?沒錯,確實是這樣,同時爲了使得 GCN 能夠捕捉到 K-hop 的鄰居節點的信息,作者還堆疊多層 GCN layers,如堆疊 K 層有:

上述式子還可以使用矩陣形式表示如下,

其中 是歸一化之後的鄰接矩陣, 相當於給 層的所有節點的 embedding 做了一次線性變換,左乘以鄰接矩陣表示對每個節點來說,該節點的特徵表示爲鄰居節點特徵相加之後的結果。(注意將 換成矩陣 就是圖 3 所說的三矩陣相乘)

那麼 GCN 的效果如何呢?作者將 GCN 放到節點分類任務上,分別在 Citeseer、Cora、Pubmed、NELL 等數據集上進行實驗,相比於傳統方法提升還是很顯著的,這很有可能是得益於 GCN 善於編碼圖的結構信息,能夠學習到更好的節點表示。

圖 6

當然,其實 GCN 的缺點也是很顯然易見的,第一,GCN 需要將整個圖放到內存和顯存,這將非常耗內存和顯存,處理不了大圖;第二,GCN 在訓練時需要知道整個圖的結構信息 (包括待預測的節點), 這在現實某些任務中也不能實現 (比如用今天訓練的圖模型預測明天的數據,那麼明天的節點是拿不到的)。

2. Graph Sample and Aggregate(GraphSAGE)[8]

爲了解決 GCN 的兩個缺點問題,GraphSAGE 被提了出來。在介紹 GraphSAGE 之前,先介紹一下 Inductive learning 和 Transductive learning。注意到圖數據和其他類型數據的不同,圖數據中的每一個節點可以通過邊的關係利用其他節點的信息。這就導致一個問題,GCN 輸入了整個圖,訓練節點收集鄰居節點信息的時候,用到了測試和驗證集的樣本,我們把這個稱爲 Transductive learning。然而,我們所處理的大多數的機器學習問題都是 Inductive learning,因爲我們刻意的將樣本集分爲訓練 / 驗證 / 測試,並且訓練的時候只用訓練樣本。這樣對圖來說有個好處,可以處理圖中新來的節點,可以利用已知節點的信息爲未知節點生成 embedding,GraphSAGE 就是這麼幹的。

GraphSAGE 是一個 Inductive Learning 框架,具體實現中,訓練時它僅僅保留訓練樣本到訓練樣本的邊,然後包含 Sample 和 Aggregate 兩大步驟,Sample 是指如何對鄰居的個數進行採樣,Aggregate 是指拿到鄰居節點的 embedding 之後如何匯聚這些 embedding 以更新自己的 embedding 信息。下圖展示了 GraphSAGE 學習的一個過程,

圖 7

第一步,對鄰居採樣

第二步,採樣後的鄰居 embedding 傳到節點上來,並使用一個聚合函數聚合這些鄰居信息以更新節點的 embedding

第三步,根據更新後的 embedding 預測節點的標籤

接下來,我們詳細的說明一個訓練好的 GrpahSAGE 是如何給一個新的節點生成 embedding 的(即一個前向傳播的過程),如下算法圖:

首先,(line1) 算法首先初始化輸入的圖中所有節點的特徵向量,(line3) 對於每個節點 ,拿到它採樣後的鄰居節點 後,(line4) 利用聚合函數聚合鄰居節點的信息,(line5) 並結合自身 embedding 通過一個非線性變換更新自身的 embedding 表示。

注意到算法裏面的 ,它是指聚合器的數量,也是指權重矩陣的數量,還是網絡的層數,這是因爲每一層網絡中聚合器和權重矩陣是共享的。網絡的層數可以理解爲需要最大訪問的鄰居的跳數 (hops),比如在圖 7 中,紅色節點的更新拿到了它一、二跳鄰居的信息,那麼網絡層數就是 2。爲了更新紅色節點,首先在第一層(k=1),我們會將藍色節點的信息聚合到紅色解節點上,將綠色節點的信息聚合到藍色節點上。在第二層(k=2) 紅色節點的 embedding 被再次更新,不過這次用到的是更新後的藍色節點 embedding,這樣就保證了紅色節點更新後的 embedding 包括藍色和綠色節點的信息,也就是兩跳信息。

爲了看的更清晰,我們將更新某個節點的過程展開來看,如圖 8 分別爲更新節點 A 和更新節點 B 的過程,可以看到更新不同的節點過程每一層網絡中聚合器和權重矩陣都是共享的。

圖 8

那麼 GraphSAGE Sample 是怎麼做的呢?GraphSAGE 是採用定長抽樣的方法,具體來說,定義需要的鄰居個數 ,然後採用有放回的重採樣 / 負採樣方法達到 。保證每個節點(採樣後的)鄰居個數一致,這樣是爲了把多個節點以及它們的鄰居拼接成 Tensor 送到 GPU 中進行批訓練。

那麼 GraphSAGE 有哪些聚合器呢?主要有三個,

這裏說明的一點是 Mean Aggregator 和 GCN 的做法基本是一致的(GCN 實際上是求和)。

到此爲止,整個模型的架構就講完了,那麼 GraphSAGE 是如何學習聚合器的參數以及權重矩陣 呢?如果是有監督的情況下,可以使用每個節點的預測 lable 和真實 lable 的交叉熵作爲損失函數。如果是在無監督的情況下,可以假設相鄰的節點的 embedding 表示儘可能相近,因此可以設計出如下的損失函數,

那麼 GrpahSAGE 的實際實驗效果如何呢?作者在 Citation、Reddit、PPI 數據集上分別給出了無監督和完全有監督的結果,相比於傳統方法提升還是很明顯。

至此,GraphSAGE 介紹完畢。我們來總結一下,GraphSAGE 的一些優點,

(1)利用採樣機制,很好的解決了 GCN 必須要知道全部圖的信息問題,克服了 GCN 訓練時內存和顯存的限制,即使對於未知的新節點,也能得到其表示

(2)聚合器和權重矩陣的參數對於所有的節點是共享的

(3)模型的參數的數量與圖的節點個數無關,這使得 GraphSAGE 能夠處理更大的圖

(4)既能處理有監督任務也能處理無監督任務

(就喜歡這樣解決了問題,方法又簡潔,效果還好的 idea!!!)

當然,GraphSAGE 也有一些缺點,每個節點那麼多鄰居,GraphSAGE 的採樣沒有考慮到不同鄰居節點的重要性不同,而且聚合計算的時候鄰居節點的重要性和當前節點也是不同的。

3. Graph Attention Networks(GAT)[9]

爲了解決 GNN 聚合鄰居節點的時候沒有考慮到不同的鄰居節點重要性不同的問題,GAT 借鑑了 Transformer 的 idea,引入 masked self-attention 機制,在計算圖中的每個節點的表示的時候,會根據鄰居節點特徵的不同來爲其分配不同的權值。

具體的,對於輸入的圖,一個 graph attention layer 如圖 9 所示,

圖 9

其中 採用了單層的前饋神經網絡實現,計算過程如下(注意權重矩陣 對於所有的節點是共享的):

計算完 attention 之後,就可以得到某個節點聚合其鄰居節點信息的新的表示,計算過程如下:

爲了提高模型的擬合能力,還引入了多頭的 self-attention 機制,即同時使用多個 計算 self-attention,然後將計算的結果合併(連接或者求和):

此外,由於 GAT 結構的特性,GAT 無需使用預先構建好的圖,因此 GAT 既適用於 Transductive Learning,又適用於 Inductive Learning。

那麼 GAT 的具體效果如何呢?作者分別在三個 Transductive Learning 和一個 Inductive Learning 任務上進行實驗,實驗結果如下:

無論是在 Transductive Learning 還是在 Inductive Learning 的任務上,GAT 的效果都要優於傳統方法的結果。

至此,GAT 的介紹完畢,我們來總結一下,GAT 的一些優點,

(1)訓練 GCN 無需瞭解整個圖結構,只需知道每個節點的鄰居節點即可

(2)計算速度快,可以在不同的節點上進行並行計算

(3)既可以用於 Transductive Learning,又可以用於 Inductive Learning,可以對未見過的圖結構進行處理

(仍然是簡單的 idea,解決了問題,效果還好!!!)

到此,我們就介紹完了 GNN 中最經典的幾個模型 GCN、GraphSAGE、GAT,接下來我們將針對具體的任務類別來介紹一些流行的 GNN 模型與方法。

四、無監督的節點表示學習(Unsupervised Node Representation)

由於標註數據的成本非常高,如果能夠利用無監督的方法很好的學習到節點的表示,將會有巨大的價值和意義,例如找到相同興趣的社區、發現大規模的圖中有趣的結構等等。

圖 10

這其中比較經典的模型有 GraphSAGE、Graph Auto-Encoder(GAE)等,GraphSAGE 就是一種很好的無監督表示學習的方法,前面已經介紹了,這裏就不贅述,接下來將詳細講解後面兩個。

  1. Graph Auto-Encoder(GAE)[10]

在介紹 Graph Auto-Encoder 之前,需要先了解**自編碼器 (Auto-Encoder)、變分自編碼器 (Variational Auto-Encoder),**具體可以參考 [11],這裏就不贅述。

理解了自編碼器之後,再來理解變分圖的自編碼器就容易多了。如圖 11 輸入圖的鄰接矩陣和節點的特徵矩陣,通過編碼器(圖卷積網絡)學習節點低維向量表示的均值μ和方差σ,然後用解碼器(鏈路預測)生成圖。

圖 11

編碼器(Encoder)採用簡單的兩層 GCN 網絡,解碼器(Encoder)計算兩點之間存在邊的概率來重構圖,損失函數包括生成圖和原始圖之間的距離度量,以及節點表示向量分佈和正態分佈的 KL - 散度兩部分。具體公式如圖 12 所示:

圖 12

另外爲了做比較,作者還提出了圖自編碼器 (Graph Auto-Encoder),相比於變分圖的自編碼器,圖自編碼器就簡單多了,Encoder 是兩層 GCN,Loss 只包含 Reconstruction Loss。

那麼兩種圖自編碼器的效果如何呢?作者分別在 Cora、Citeseer、Pubmed 數據集上做 Link prediction 任務,實驗結果如下表,圖自編碼器(GAE)和變分圖自編碼器(VGAE)效果普遍優於傳統方法,而且變分圖自編碼器的效果更好;當然,Pumed 上 GAE 得到了最佳結果。可能是因爲 Pumed 網絡較大,在 VGAE 比 GAE 模型複雜,所以更難調參。

五、Graph Pooling

Graph pooling 是 GNN 中很流行的一種操作,目的是爲了獲取一整個圖的表示,主要用於處理圖級別的分類任務,例如在有監督的圖分類、文檔分類等等。

圖 13

Graph pooling 的方法有很多,如簡單的 max pooling 和 mean pooling,然而這兩種 pooling 不高效而且忽視了節點的順序信息;這裏介紹一種方法:Differentiable Pooling (DiffPool)。

1.DiffPool[12]

在圖級別的任務當中,當前的很多方法是將所有的節點嵌入進行全局池化,忽略了圖中可能存在的任何層級結構,這對於圖的分類任務來說尤其成問題,因爲其目標是預測整個圖的標籤。針對這個問題,斯坦福大學團隊提出了一個用於圖分類的可微池化操作模塊——DiffPool,可以生成圖的層級表示,並且可以以端到端的方式被各種圖神經網絡整合。

DiffPool 的核心思想是通過一個可微池化操作模塊去分層的聚合圖節點,具體的,這個可微池化操作模塊基於 GNN 上一層生成的節點嵌入 以及分配矩陣 ,以端到端的方式分配給下一層的簇,然後將這些簇輸入到 GNN 下一層,進而實現用分層的方式堆疊多個 GNN 層的想法。(圖 14)

圖 14

那麼這個節點嵌入和分配矩陣是怎麼算的?計算完之後又是怎麼分配給下一層的?這裏就涉及到兩部分內容,一個是分配矩陣的學習,一個是池化分配矩陣

這裏使用兩個分開的 GNN 來生成分配矩陣 和每一個簇節點新的嵌入 ,這兩個 GNN 都是用簇節點特徵矩陣 和粗化鄰接矩陣 作爲輸入,

計算得到分配矩陣 和每一個簇節點新的嵌入 之後,DiffPool 層根據分配矩陣,對於圖中的每個節點 / 簇生成一個新的粗化的鄰接矩陣 與新的嵌入矩陣 ,

總的來看,每層的 DiffPool 其實就是更新每一個簇節點的嵌入和簇節點的特徵矩陣,如下公式:

至此,DiffPool 的基本思想就講完了。那麼效果如何呢?作者在多種圖分類的基準數據集上進行實驗,如蛋白質數據集(ENZYMES,PROTEINS,D&D),社交網絡數據集(REDDIT-MULTI-12K),科研合作數據集(COLLAB),實驗結果如下:

其中,GraphSAGE 是採用全局平均池化;DiffPool-DET 是一種 DiffPool 變體,使用確定性圖聚類算法生成分配矩陣;DiffPool-NOLPDiffPool 的變體,取消了鏈接預測目標部分。總的來說,DiffPool 方法在 GNN 的所有池化方法中獲得最高的平均性能。

爲了更好的證明 DiffPool 對於圖分類十分有效,論文還使用了其他 GNN 體系結構(Structure2Vec(s2v)),並且構造兩個變體,進行對比實驗,如下表:

可以看到 DiffPool 的顯著改善了 S2V 在 ENZYMES 和 D&D 數據集上的性能。

而且 DiffPool 可以自動的學習到恰當的簇的數量。

至此,我們來總結一下 DiffPool 的優點,

(1)可以學習層次化的 pooling 策略

(2)可以學習到圖的層次化表示

(3)可以以端到端的方式被各種圖神經網絡整合

然而,注意到,DiffPool 也有其侷限性,分配矩陣需要很大的空間去存儲,空間複雜度爲 , 爲池化層的層數,所以無法處理很大的圖。

參考

重點總結

1.GCN 的缺點也是很顯然易見的,第一,GCN 需要將整個圖放到內存和顯存,這將非常耗內存和顯存,處理不了大圖;第二,GCN 在訓練時需要知道整個圖的結構信息 (包括待預測的節點)

2.GraphSAGE 的優點:

(1)利用採樣機制,很好的解決了 GCN 必須要知道全部圖的信息問題,克服了 GCN 訓練時內存和顯存的限制,即使對於未知的新節點,也能得到其表示

(2)聚合器和權重矩陣的參數對於所有的節點是共享的

(3)模型的參數的數量與圖的節點個數無關,這使得 GraphSAGE 能夠處理更大的圖

(4)既能處理有監督任務也能處理無監督任務

3.GAT 的優點:

(1)訓練 GCN 無需瞭解整個圖結構,只需知道每個節點的鄰居節點即可

(2)計算速度快,可以在不同的節點上進行並行計算

(3)既可以用於 Transductive Learning,又可以用於 Inductive Learning,可以對未見過的圖結構進行處理


本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/hK1duUWomCt6bkXasn9bZg