JK-Net:殘差連結增加圖神經網路深度

今天學習的是 MIT 同學 2018 年的論文《Representation Learning on Graphs with Jumping Knowledge Networks》,發表於 ICML,目前共有 140 多次引用。

目前的圖表示學習都遵循著領域聚合的方式,但這種方式的層數無法增加,kipf 的 GCN 使用了兩層模型,隨著深度增加會出現 over-smooth 的問題,導致效能下降。

為了更好的學習鄰居的結合和屬性,作者提出了一種叫跳躍知識的網路(Jumping Knowledge Networks)架構,並在諸多資料集中取得了 SOTA 的成績。

此外,JK 架構可以與現有的卷積網路(如 GraphSAGE、GAT 等)模型相結合,可以用於改善這些模型的效能。

1。Introduction

目前基於聚合方式的 GCN 最好的效能是 2 層,更深的層數會降低模型效能。在計算機視覺中,殘差連線可以解決類似的學習能力退化的問題,並且極大的幫助了深度模型的訓練。但是即使使用了殘差連結,GCN 也沒法增加層數,相關的工作有:citation network。

為此,作者研究了目前基於領域聚合方式的性質和侷限性。

作者針對隨機遊走進行研究,並發現:除節點特徵以外,節點的子圖結構(也可以理解為所處位置)會極大的影響領域聚合的效果。下圖展示了 GooglePlus 的社交網路,從正方形節點開始進行 n-step 的隨機遊走:

JK-Net:殘差連結增加圖神經網路深度

我們可以從(a)中看到,處在中心節點位置的正方形節點經過 4-step 就可以涵蓋整個圖;而(b)中,處在邊緣節點位置的正方形節點,經過 4-step 僅僅擴充套件了一小部分,經過 5-step 達到核心後才迅速蔓延。

這表明:即使在同一張圖中,相同步數也會導致不同的效果。在實際應用中,我們應該透過組合不同形式的 n-step 來控制不同節點的擴散速度。

隨後,作者又證明了 k 層 GCN 的影響和隨機遊走 k 步的影響近似相同。(證明方式見論文)

下圖展示 k 層 GCN 和 k 步隨機遊走的結果,顏色深表示影響機率越高。

JK-Net:殘差連結增加圖神經網路深度

下圖展示了帶有殘差的 GCN 的影響分佈,與惰性隨機遊走更加相似。

JK-Net:殘差連結增加圖神經網路深度

可以看到,帶有殘差的 GCN 導致每一步都有更高的機率停留在當前節點,這與節點多樣化需求相違背。

回顧我們看到的第一場圖,如果 GCN 使用相同的層數,其與施加固定 step 的隨機遊走會有相同的效果。相同的層數可能會導致中心區域的節點表示失去區域性資訊,但卻會讓邊緣節點探索到其周圍的區域性資訊。

也就是說,如果用如果 GCN 具有具有固定層數,並不能讓為所有節點帶來最佳的向量表示。

2。JKnet

透過上面的分析,作者得出結論:目前通用的聚合方法引起的固定但與結構相關的影響半徑大小並不能實現所有節點和人物的最佳向量表示。較大的半徑可能會導致 over-smooth,而較小的半徑可能會導致不穩定和資訊聚集不足的問題。

為此,作者提出了兩個簡單而又強大的架構:跳躍連線(jump connection)和自適應選擇性聚集機制。

下圖闡述了作者的想法:

JK-Net:殘差連結增加圖神經網路深度

和普通的 GCN 一樣,每一層都會聚合來自上一層的領域來增加節點的影響大小。但在最後一層中,每個節點都會從之前的中間表示中篩選一些進行合併。這一步是針對每個節點獨立完成的,所以模型可以根據需要為每個節點調整有效的領域大小,從而完成自適應選擇性聚集。

作者給出了三種融合方法:

Concatenation:直接將各層的表達串聯合,送入到 Linear 層進行分類。這種方法不支援節點的自適應選擇,而是找到最適合整個資料集的方式來組合子圖特徵。這種方法適用於較規則的小型圖,並且一定機率上透過(Linear 層的)權重共享來減少過擬合;

Max-pooling:為每個節點進行基於 element-wise 的 max-pooling 操作。這樣的操作可以讓節點選擇從底層學習區域性領域的資訊,而是去從高層學習全域性領域的資訊。max-pooling 是自適應的,其優點在於不會引入額外的引數;

LSTM-attention:Attention 機制是一個有效的可以學習節點資訊的方式。每一個節點都會算出一個基於層的 attention 分數,其表示對於節點來說,不同層的重要性,然後基於此進行加權求和。LSTM-attention 是將各層的表達送入到雙向的 LSTM 中,這樣每個層都會有一個前向表達 和後向表達 ,然後將這個表達串聯拼接送入到 Linear 層來擬合出一個 score,對每一層的 score 進行歸一化後邊得到 attention score:,最後加權求和得到最終的表達。

這種設計的關鍵思想在於:在檢視所有層的學習特徵之後,不同位置的節點可以確定其子圖特徵的重要性,而不是為所有節點固定相同的權重。

下圖展示了利用 Max-pooling 聚合的 6 層 JK-Net,不同子圖結構的視覺化展示:

JK-Net:殘差連結增加圖神經網路深度

a 和 b 為邊緣節點,其影響的節點停留在小社群中;而和中心節點有關(c,d)的節點或者中心節點(e),其影響分散在一個合理範圍內的相鄰節點上。

3。Experiments

簡單看一下實驗。

下圖為不同模型在不同資料集中的表現,其中 JK-Net 基於 GCN 模型。LSTM 效果不好主要是因為資料集太小了:

JK-Net:殘差連結增加圖神經網路深度

下圖為基於 GraphSAGE 的 JK-Net 在 Reddit 資料集中的表現,層數皆為 2 層,評價指標為 F1:

JK-Net:殘差連結增加圖神經網路深度

下圖為在 PPI 資料集中表現,LSTM 的效果就顯示出來了

JK-Net:殘差連結增加圖神經網路深度

4。Conclusion

總結:本篇論文分析了 GCN 隨著層數增大而導致效能下降的原因,並受分析結果啟發提出了一個網路架構——JK-Net。JK-Net 透過自適應學習處在不同位置的節點聚合不同領域,從而可以改善節點的表示形式。JK-Net 可與現有的模型架構相結合,並在多個數據集中取得了 SOTA 的成績。

5。Reference

《Representation Learning on Graphs with Jumping Knowledge Networks》