圖神經網絡三劍客:GCN、GAT與GraphSAGE

sandag 發佈 2020-02-28T19:24:18+00:00

隨著圖神經網絡研究的熱度不斷上升,我們也看到圖神經網絡的不同變種不斷湧現,此外,因為對於非歐空間數據良好的表達能力,圖神經網絡在交通、金融、社會科學等有大量相應數據積澱的交叉領域也面臨著廣闊的應用前景。

2019 年號稱圖神經網絡元年,在各個領域關於圖神經網絡的研究爆髮式增長。 本文主要介紹一下三種常見圖神經網絡: GCN、GAT 以及 GraphSAGE。 前兩者是目前應用比較廣泛的圖神經網絡,後者則為圖神經網絡的工程應用提供了基礎。

GCN

圖神經網絡基於巴拿赫不動點定理提出,但圖神經網絡領域的大發展是在 2013 年 Bruna 提出圖上的基於頻域和基於空域的卷積神經網絡後。

關於圖卷積神經網絡的理解與介紹,知乎上的回答已經講的非常透徹了。

如何理解 Graph Convolutional Network (GCN)?

https://www.zhihu.com/question/54504471/answer/332657604

這裡主要介紹一下 PyG 和 DGL 兩個主要的圖神經網絡庫實現所基於的文章 Semi-supervised Classification with Graph Convolutional Networks 。它基於對圖上頻域卷積的一階近似提出了一種高效的逐層傳播規則。

論文標題: Semi-supervised Classification with Graph Convolutional Networks

論文連結: https://arxiv.org/abs/1609.02907

在將定義在歐式空間上的拉普拉斯算子和傅立葉變換對應到圖上之後,圖上的頻域卷積操作可以基於卷積定理自然導出:

其中圖上的拉普拉斯矩陣(歸一化後)L 是一個半正定對稱矩陣,它具有一些良好的性質,可以進行譜分解:

,其中 U 是 L 的特徵向向量組成的矩陣,Λ 是 L 的特徵值組成的對角矩陣, 則是定義在圖上的對信號 的傅立葉變換。

而對角矩陣 則是卷積核,也是不同的卷積操作關注的焦點,對 不同的設計會影響卷積操作的效率,其編碼的信息也會影響最終任務的精度。

一開始的圖卷積神經網絡將視作 L 的特徵值的一個函數 。但這種定義存在兩個問題:

1. 對特徵向量矩陣 U 的乘法操作時間複雜度是;

2. 對大規模圖的拉普拉斯矩陣 L 的特徵分解是困難的。

之後的研究發現可以使用切比雪夫多項式來對進行近似:

其中

。是 L 的最大特徵值,是切比雪夫多項式的係數向量。切比雪夫多項式通過如下的遞推公式定義: ,起始值:

。 將其代入之前定義的卷積操作:

其中

,此時的時間複雜度為 。文章在此基礎上對卷積操作進行了進一步的簡化,首先固定 K=1,並且讓近似等於 2(注意之前對 L 的定義),則上式可以簡化為一個包含兩個自由參數 和 的公式:

我們進一步假定

,則可進一步對公式進行變形:

但是此時的

的特徵值取值在 [0, 2],對這一操作的堆疊會導致數值不穩定以及梯度爆炸(或消失)等問題。為了解決這一問題,引入一種稱為重歸一化(renormalization)的技術:

最後將計算進行向量化,得到最終的卷積計算公式為:

這一計算的時間複雜度為

。基於上式實現的 GCN 在三個數據集上取得了當時最好的結果。

GAT

PyG 與 DGL 的 GAT 模塊都是基於 Graph Attention Networks 實現的,它的思想非常簡單,就是將 transform 中大放異彩的注意力機制遷移到了圖神經網絡上。

論文標題: Graph Attention Networks

論文連結: https://arxiv.org/abs/1710.10903

整篇文章的內容可以用下面一張圖來概況。

首先回顧下注意力機制的定義,注意力機制實質上可以理解成一個加權求和的過程:對於一個給定的 query,有一系列的 value 和與之一一對應的 key,怎樣計算 query 的結果呢?

很簡單,對 query 和所有的 key 求相似度,然後根據相似度對所有的 value 加權求和就行了。這個相似度就是 attention coefficients,在文章中計算如下:

其中

是前饋神經網絡的權重係數,|| 代表拼接操作。

利用注意力機制對圖中結點特徵進行更新:

既然得到了上式,那麼多頭注意力的更新就不言而明了,用 k 個權重係數分別得到新的結點特徵之後再拼接就可以了:

最後就是大家喜聞樂見的暴打 benchmarks 的環節,GAT 在三個數據集上達到了當時的 SOTA。

GraphSAGE

GraphSAGE 由 Inductive Representation Learning on Large Graphs 提出,該方法提供了一種通用的歸納式框架,使用結點信息特徵為未出現過的(unseen)結點生成結點向量,這一方法為後來的 PinSage(GCN 在商業推薦系統首次成功應用)提供了基礎。

論文標題: Inductive Representation Learning on Large Graphs

論文連結: https://arxiv.org/abs/1706.02216

但 GraphSAGE 的思想卻非常簡單,也可以用一張圖表示。

算法的詳細過程如下:

1. 對圖上的每個結點 v,設置它的初始 embedding 為它的輸入特徵 ;

2. 之後進行 K次疊代,在每次疊代中,對每個結點 v,聚合它的鄰居結點(採樣後)的在上一輪疊代中生成的結點表示 生成當前結點的鄰居結點表示 ,之後連接 輸入一個前饋神經網絡得到結點的當前表示 ;

3. 最後得到每個結點的表示

這個算法有兩個關鍵點:一是鄰居結點採樣,二是聚合鄰居結點信息的聚合函數。

鄰居結點採樣方面,論文中在 K 輪疊代中,每輪採樣不同的樣本,採樣數量為 。在聚合函數方面,論文提出了三種聚合函數:

Mean aggregator:

LSTM aggregator: 使用 LSTM 對鄰居結點信息進行聚合。值得注意地是,因為 LSTM 的序列性,這個聚合函數不具備對稱性。文章中使用對鄰居結點隨機排列的方法來將其應用於無序集合。

Pooling aggregator:

論文在三個數據集上取得了對於 baseline 的 SOTA。

既然為工程應用提出的方法,對於實驗部分就不能一筆帶過了,這裡給出論文中兩個有意思的結論:

對於鄰居結點的採樣,設置 K=2 和

得到比較好的表現;

對於聚合函數的比較上,LSTM aggregator 和 Pooling aggregator 表現最好,但是前者比後者慢大約兩倍。

總結

本文對圖神經網絡中常用的三種方法進行了介紹。隨著圖神經網絡研究的熱度不斷上升,我們也看到圖神經網絡的不同變種不斷湧現,此外,因為對於非歐空間數據良好的表達能力,圖神經網絡在交通、金融、社會科學等有大量相應數據積澱的交叉領域也面臨著廣闊的應用前景。

關鍵字: