面向超網絡的連續學習:新算法讓人工智慧不再「災難性遺忘」

ai科技評論 發佈 2020-01-14T15:53:19+00:00

作者| 十、年編輯 | Camel人腦顯然是人工智慧追求的最高標準。想讓傳統深度學習系統獲得連續學習能力,最重要的是克服人工神經網絡會出現的「災難性遺忘」問題,即一旦使用新的數據集去訓練已有的模型,該模型將會失去對原數據集識別的能力。

作者 | 十、年

編輯 | Camel

人腦顯然是人工智慧追求的最高標準。

畢竟人腦使得人類擁有了連續學習的能力以及情境依賴學習的能力。

這種可以在新的環境中不斷吸收新的知識和根據不同的環境靈活調整自己的行為的能力,也正是深度學習系統與人腦相差甚遠的重要原因。

想讓傳統深度學習系統獲得連續學習能力,最重要的是克服人工神經網絡會出現的「災難性遺忘」問題,即一旦使用新的數據集去訓練已有的模型,該模型將會失去對原數據集識別的能力。

換句話說就是:讓神經網絡在學習新知識的同時保留舊知識。

前段時間,來自蘇黎世聯邦理工學院以及蘇黎世大學的研究團隊發表了一篇名為《超網絡的連續學習》(Continual learning with hypernetworks)的研究。提出了任務條件化的超網絡(基於任務屬性生成目標模型權重的網絡)。該方法能夠有效克服災難性的遺忘問題。

具體來說,該方法能夠幫助在針對多個任務訓練網絡時,有效處理災難性的遺忘問題。除了在標準持續學習基準測試中獲得最先進的性能外,長期的附加實驗任務序列顯示,任務條件超網絡(task-conditioned hypernetworks )表現出非常大的保留先前記憶的能力。

hypernetworks

在蘇黎世聯邦理工學院以及蘇黎世大學的這項工作中,最重要的是對超網絡(hypernetworks)的應用,在介紹超網絡的連續學習之前,我們先對超網絡做一下介紹。

hyperNetwork是一個非常有名的網絡,簡單說就是用一個網絡來生成另外一個網絡的參數工作原理是:用一個hypernetwork輸入訓練集數據,然後輸出對應模型的參數,最好的輸出是這些參數能夠使得在測試數據集上取得好的效果。簡單來說hypernetwork其實就是一個meta network。

傳統的做法是用訓練集直接訓練這個模型,但是如果使用hypernetwork則不用訓練,拋棄反向傳播與梯度下降,直接輸出參數,這等價於hypernetwork學會了如何學習圖像識別。

論文下載見文末

在《hypernetwork》這篇論文中,作者使用 hyperNetwork 生成 RNN 的權重,發現能為 LSTM 生成非共享權重,並在字符級語言建模、手寫字符生成和神經機器翻譯等序列建模任務上實現最先進的結果。超網絡採用一組包含有關權重結構的信息的輸入,並生成該層的權重,如下圖所示。

超網絡生成前饋網絡的權重:黑色連接和參數與主網絡相關聯,而橙色連接和參數與超網絡相關聯。

超網絡的連續學習模型

在整個工作中,首先假設輸入的數據{X(1),......X(T)}是可以被儲存的,並能夠使用輸入的數據計算Θ(T −1)。另外,可以將未使用的數據和已經使用過數據進行混合來避免遺忘。假設F(X,Θ)是模型,那麼混合後的數據集為{(X(1),Yˆ(1)),。。。,(X(T−1),Yˆ(T−1)),(X(T),Yˆ(T))},其中其中Yˆ(T)是由模型f(.,Θ(t−1))生成的一組合成目標。

然而存儲數據顯然違背了連續學習的原則,所以在在論文中,作者提出了一種新的元模型fh(e(t)h)做為解決方案,新的解決方案能夠將關注點從單個的數據輸入輸出轉向參數集{Θ(T)},並實現非儲存的要求。這個元模型稱為任務條件超網絡,主要思想是建立任務e(t)和權重Θ的映射關係,能夠降維處理數據集的存儲,大大節省內存。

在《超網絡的連續學習》這篇論文中,模型部分主要有3個部分,第一部分是任務條件超網絡。首先,超網絡會將目標模型參數化,即不是直接學習特定模型的參數,而是學習元模型的參數,從而元模型會輸出超網絡的權重,也就是說超網絡只是權重生成器。

圖a:正則化後的超網絡生成目標網絡權重參數;圖b:疊代地使用較小的組塊超網絡產生目標網絡權重。

然後利用帶有超網絡的連續學習輸出正則化。在論文中,作者使用兩步優化過程來引入記憶保持型超網絡輸出約束。首先,計算∆Θh(∆Θh的計算原則基於優化器的選擇,本文中作者使用Adam),即找到能夠最小化損失函數的參數。損失函數表達式如下圖所示:

註:Θh是模型學習之前的超網絡的參數;∆Θh為外生變量;βoutput是用來控制正則化強度的參數。

然後考慮模型的e(t),它就像Θh一樣。在算法的每一個學習步驟中,需要及時更新,並使損失函數最小化。在學習任務之後,保存最終e(t)並將其添加到集合{e(T)}。

模型的第二部分是用分塊的超網絡進行模型壓縮。超網絡產生目標神經網絡的整個權重集。然而,超網絡可以疊代調用,在每一步只需分塊填充目標模型中的一部分。這表明允許應用較小的可重複使用的超網絡。有趣的是,利用分塊超網絡可以在壓縮狀態下解決任務,其中學習參數(超網絡的那些)的數量實際上小於目標網絡參數的數量。

為了避免在目標網絡的各個分區之間引入權重共享,作者引入塊嵌入的集合{C} 作為超網絡的附加輸入。因此,目標網絡參數的全集Θ_trgt=[fh(e,c1),,,fh(e,CNc)]是通過在C上疊代而產生的,在這過程中保持e不變。這樣,超網絡可以每個塊上產生截然不同的權重。另外,為了簡化訓練過程,作者對所有任務使用一組共享的塊嵌入。

模型的第三部分:上下文無關推理:未知任務標識(context-free inference: unknown task identity)。從輸入數據的角度確定要解決的任務。超網絡需要任務嵌入輸入來生成目標模型權重。在某些連續學習的應用中,由於任務標識是明確的,或者可以容易地從上下文線索中推斷,因此可以立即選擇合適的嵌入。在其他情況下,選擇合適的嵌入則不是那麼容易。

作者在論文中討論了連續學習中利用任務條件超網絡的兩種不同策略。

策略一:依賴於任務的預測不確定性。神經網絡模型在處理分布外的數據方面越來越可靠。對於分類目標分布,理想情況下為不可見數據產生平坦的高熵輸出,反之,為分布內數據產生峰值的低熵響應。這提出了第一種簡單的任務推理方法(HNET+ENT),即給定任務標識未知的輸入模式,選擇預測不確定性最小的任務嵌入,並用輸出分布熵量化。

策略二:當生成模型可用時,可以通過將當前任務數據與過去合成的數據混合來規避災難性遺忘。除了保護生成模型本身,合成數據還可以保護另一模型。這種策略實際上往往是連續學習中最優的解決方案。受這些成功經驗的啟發,作者探索用回放網絡(replay network)來增強深度學習系統。

合成回放(Synthetic replay)是一種強大但並不完美的連續學習機制,因為生成模式容易漂移,錯誤往往會隨著時間的推移而積累和放大。作者在一系列關鍵觀察的基礎上決定:就像目標網絡一樣,重放模型可以由超網絡指定,並允許使用輸出正則化公式。而不是使用模型自己的回放數據。因此,在這種結合的方法中,合成重放和任務條件元建模同時起作用,避免災難性遺忘。

基準測試

作者使用MNIST、CIFAR10和CIFAR-100公共數據集對論文中的方法進行了評估。評估主要在兩個方面:(1)研究任務條件超網絡在三種連續學習環境下的記憶保持能力,(2)研究順序學習任務之間的信息傳遞。

具體的在評估實驗中,作者根據任務標識是否明確出了三種連續學習場景:CL1,任務標識明確;CL2,任務標識不明確,並不需明確推斷;CL3,任務標識可以明確推斷出來。另外作者在MNIST數據集上構建了一個全連通的網絡,其中超參的設定參考了van de Ven & Tolias (2019)論文中的方法。在CIFAR實驗中選擇了ResNet-32作為目標神經網絡。

van de Ven & Tolias (2019):

Gido M. van de Ven and Andreas S. Tolias. Three scenarios for continual learning. arXiv preprint arXiv:1904.07734, 2019.

為了進一步說明論文中的方法,作者考慮了四個連續學習分類問題中的基準測試:非線性回歸,PermutedMNIST,Split-MNIST,Split CIFAR-10/100。

非線性回歸的結果如下:

註:圖a:有輸出正則化的任務條件超網絡可以很容易地對遞增次數的多項式序列建模,同時能夠達到連續學習的效果。圖b:和多任務直接訓練的目標網絡找到的解決方案類似。圖c:循序漸進地學習會導致遺忘。

在PermutedMNIST中,作者並對輸入的圖像數據的像素進行隨機排列。發現在CL1中,任務條件超網絡在長度為T=10的任務序列中表現最佳。在PermutedMNIST上任務條件超網絡的表現非常好,對比來看突觸智能(Synaptic Intelligence) ,online EWC,以及深度生成回放( deep generative replay)方法有差別,具體來說突觸智能和DGR+distill會發生退化,online EWC不會達到非常高的精度,如下圖a所示。綜合考慮壓縮比率與任務平均測試集準確性,超網絡允許的壓縮模型,即使目標網絡的參數數量超過超網絡模型的參數數量,精度依然保持恆定,如下圖b所示。

Split-MNIST作為另一個比較流行的連續學習的基準測試,在Split-MNIST中將各個數字有序配對,並形成五個二進位分類任務,結果發現任務條件超網絡整體性能表現最好。另外在split MNIST問題上任務重疊,能夠跨任務傳遞信息,並發現該算法收斂到可以產生同時解決舊任務和新任務的目標模型參數的超網絡配置。如下圖所示

圖a:即使在低維度空間下仍然有著高分類性能,同時沒有發生遺忘。圖b:即使最後一個任務占據著高性能區域,並在遠離嵌入向量的情況下退化情況仍然可接受,其性能仍然較高。

在CIFAR實驗中,作者選擇了ResNet-32作為目標神經網絡,在實驗過程中,作者發現運用任務條件超網絡基本完全消除了遺忘,另外還會發生前向信息反饋,這也就是說與從初始條件單獨學習每個任務相比,來自以前任務的知識可以讓網絡表現更好。

綜上,在論文中作者提出了一種新的連續學習的神經網絡應用模型--任務條件超網絡,該方法具有可靈活性和通用性,作為獨立的連續學習方法可以和生成式回放結合使用。該方法能夠實現較長的記憶壽命,並能將信息傳輸到未來的任務,能夠滿足連續學習的兩個基本特性。

參考文獻:

HYPERNETWORKS:

https://arxiv.org/pdf/1609.09106.pdf

CONTINUAL LEARNING WITH HYPERNETWORKS

https://arxiv.org/pdf/1906.00695.pdf

https://mp.weixin.qq.com/s/hZcVRraZUe9xA63CaV54Yg

關鍵字: