這篇論文非常火!差分Transformer能消除注意力噪聲,如降噪耳機

機器之心報道

編輯:Panda

Transformer 的強大實力已經在諸多大型語言模型(LLM)上得到了證明,但該架構遠非完美,也有很多研究者致力於改進這一架構,比如機器之心曾報道過的 Reformer 和 Infini-Transformer。

今天我們又將介紹另一種新型 Transformer 架構:Differential Transformer(差分 Transformer,簡稱 Diff Transformer)。該架構來自微軟研究院和清華大學,有四位共一作者:Tianzhu Ye、Li Dong、Yuqing Xia、Yutao Sun。

在 Hacker News 及 Twitter 等社交網絡上,該論文都反響熱烈,有網友表示差分 Transformer 提出的改進簡單又美麗,而帶來的提升又非常顯著。

甚至已有開發者做出了差分 Transformer 的輕量實現!

差分 Transformer 的輕量實現,https://github.com/Jaykef/ai-algorithms/blob/main/DIFF_Transformer.ipynb

那麼差分 Transformer 彌補了原生 Transformer 的哪些問題呢?如下圖所示,Transformer 往往會過度關注不相關的上下文,該團隊將此稱爲注意力噪聲(attention noise)。而差分 Transformer 則能放大對答案範圍的注意力並消除噪音,從而增強上下文建模的能力。這就要用到該團隊新提出的差分注意力機制(differential attention mechanism)了。

差分注意力機制可以消除注意力噪聲,鼓勵模型重點關注關鍵信息。該方法有些類似於電氣工程中的降噪耳機和差分放大器。

下面我們就來詳細瞭解一下差分 Transformer 的設計思路。

差分 Transformer

差分 Transformer 是一種用於序列建模的基礎模型架構。爲了方便說明,他們使用了僅解碼器(decoder-only)模型作爲示例來描述該架構。

該模型堆疊了 L 個 Diff Transformer 層。給定一個輸入序列 x,將輸入嵌入打包成 X^0。輸入會被進一步上下文化來獲得輸出 X^L。每一層都由兩個模塊組成:一個差分注意力模塊和之後的前向網絡模塊。

相比於 Transformer,差分 Transformer 的主要差別在於使用差分注意力替換了傳統的 softmax 注意力,同時保持整體宏觀佈局不變。此外,他們也參考 LLaMA 採用了 pre-RMSNorm 和 SwiGLU 這兩項改進措施。

差分注意力

差分注意力機制的作用是將查詢、鍵和值向量映射成輸出。這裡使用查詢和鍵向量來計算注意力分數,然後計算值向量的加權和。

此處的關鍵設計是使用一對 softmax 函數來消除注意力分數的噪聲。具體來說,給定輸入 X,首先將它們投射成查詢、鍵和值 Q_1、Q_2、K_1、K_2、V。然後差分注意力算子 DiffAttn (・) 通過以下方式計算輸出:

其中 W^Q、W^K 、W^V 是參數,λ 是可學習的標量。爲了同步學習動態,將標量 λ 重新參數化爲:

其中 λ_q1、λ_k1、λ_q2、λ_k2 是可學習的向量,λ_init ∈ (0, 1) 是用於初始化 λ 的常數。該團隊通過經驗發現,設置 λ_init = 0.8 − 0.6 × exp (−0.3・(l − 1)) 在實踐中效果很好,其中 l ∈ [1, L] 表示層索引。它在實驗中被用作默認策略。

他們也探索了另一種初始化策略:對所有層使用相同的 λ_init(例如 0.8)。如後面消融研究所示,使用不同的初始化策略時,性能相對穩健。

差分注意力利用兩個 softmax 注意力函數之間的差來消除注意力噪聲。這個想法類似於電氣工程中提出的差分放大器,其中兩個信號之間的差用作輸出,這樣就可以消除輸入的共模噪聲。此外,降噪耳機的設計也基於類似的想法。

該團隊也爲差分注意力使用了多頭機制。令 h 表示注意力頭的數量。他們對各個頭使用不同的投影矩陣 W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。標量 λ 在同一層內的頭之間共享。然後對頭輸出執行歸一化,並投射成最終結果,如下所示:

其中 λ_init 是 (2) 式中的常數標量,W^O 是可學習的投影矩陣,LN (・) 是對每個頭使用 RMSNorm,Concat (・) 的作用是沿通道維度將頭連接在一起。這裡使用一個固定乘數(1 − λ_init)作爲 LN (・) 的縮放尺度,以使梯度與 Transformer 對齊。

圖 2 使用了 GroupNorm (・) 來強調 LN (・) 獨立應用於每個 head。由於差分注意力往往具有更稀疏的模式,因此頭之間的統計信息更加多樣化。爲了改進梯度的統計情況,LN (・) 算子會在連接操作之前對每個頭進行歸一化。

整體架構

其整體架構會堆疊 L 層,其中每層包含一個多頭差分注意力模塊和一個前向網絡模塊。如此,便可將差分 Transformer 層描述爲:

其中 LN (・) 是 RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且 W^G、W_1、W_2 是可學習的矩陣。

實驗

該團隊從以下角度評估了差分 Transformer 在 LLM 中的應用,包括對比評估、應用評估和消融研究。這裡我們僅關注實驗結果,更多實驗過程請訪問原論文。

語言建模評估

該團隊評估了差分 Transformer 的語言建模能力。爲此,他們使用 1T token 訓練了一個 3B 大小的差分 Transformer 語言模型,並與之前的 Transformer 語言模型做了比較。

結果見表 1,其中報告的是在 LM Eval Harness 基準上的零樣本結果。

可以看到,3B 規模下,差分 Transformer 語言模型的表現優於之前的 Transformer 語言模型。此外,實驗也表明差分 Transformer 在多種任務上都勝過 Transformer,詳見原論文附錄。

與 Transformer 的可擴展性比較

該團隊也比較了新舊 Transformer 的可擴展性。結果見圖 3,其中 a 比較了模型規模方面的可擴展性,而 b 則是訓練 token 數量方面的可擴展性。

可以看到,在這兩個方面,差分 Transformer 的可擴展性均優於常規 Transformer:僅需後者 65% 左右的模型大小或訓練 token 數量就能達到相媲美的性能。

長上下文評估

當 3B 模型上下文長度增長至 64K,模型的表現又如何呢?又使用另外 1.5B token 訓練了 3B 版本的檢查點模型之後,該團隊發現隨着上下文長度的增加,累積平均負對數似然(NLL)持續下降。差分 Transformer 得到的 NLL 值低於常規 Transformer。見圖 4,這樣的結果表明,差分 Transformer 可以有效地利用不斷增加的上下文。

關鍵信息檢索

爲了檢驗差分 Transformer 檢索關鍵信息的能力,該團隊執行了 Needle-In-A-Haystack(草堆找針)測試。

表 2 給出了 4K 上下文長度的情況,其中 N 是針的數量,R 是查詢引用的數量。可以看到,差分 Transformer 的多針檢索準確度高於常規 Transformer,尤其是當針數量較多時,差分 Transformer 的優勢會更加明顯。

那麼當上下文長度提升至 64K 時,又會如何呢?結果見圖 5,這裡使用的上下文長度在 8K 到 64K 之間,使用了 N = 8 和 R = 1 的設置。

可以看到,在不同的上下文長度下,差分 Transformer 能夠保持相對穩定的性能。而當上下文長度越來越大時,常規 Transformer 的性能會逐漸下降。

另外,表 3 展示了分配給關鍵信息檢索任務的答案範圍和噪聲上下文的注意力分數。該分數可代表模型保留有用信息、抵抗注意力噪聲的能力。

可以看到,相比於常規 Transformer,差分 Transformer 能爲答案範圍分配更高的注意力分數,同時爲注意力噪聲分配更低的注意力分數。

上下文學習能力評估

該團隊從兩個角度評估模型的上下文學習能力,包括多樣本分類和上下文學習的穩健性。

圖 6 展示了新舊 Transformer 模型的多樣本分類結果。結果表明,在不同的數據集和不同的演示樣本數量上,差分 Transformer 均穩定地優於 Transformer。此外,差分 Transformer 的平均準確度優勢也很明顯,從 5.2% 到 21.6% 不等。

圖 7 則展示了兩種模型的上下文學習穩健性結果。該分析基於 TREC 數據集,並且採用了兩種提示詞格式:示例隨機排列(圖 7a)和按類別交替排列(圖 7b)。

在這兩種設置下,差分 Transformer 的性能方差要小得多。結果表明,新方法在上下文學習任務中更爲穩健。相比之下,Transformer 容易受到順序排列的影響,導致最佳結果與最差結果之間差距巨大。

上下文幻覺評估

該團隊基於文本摘要和問答任務評估了模型的上下文幻覺現象。結果見表 4。

可以看到,相比於常規 Transformer,差分 Transformer 在摘要和問答任務上的上下文幻覺更低。該團隊表示,原因可能是差分 Transformer 能更好地關注任務所需的基本信息,而不是無關上下文。

激活異常值分析

在 LLM 中,一部分激活值明顯大於大多數激活值的現象被稱爲激活異常值(activation outliers)。異常值導致訓練和推理過程中模型量化困難。實驗表明差分 Transformer 可以降低激活異常值的幅度,從而可能實現更低的量化位寬。

表 5 展示了兩個訓練得到 Transformer 和差分 Transformer 模型的激活值統計情況。這裡分析了兩種類型的激活,包括注意力 logit(即 pre-softmax 激活)和隱藏狀態(即層輸出)。可以看到,儘管中位數相似,但與 Transformer 相比,差分 Transformer 的較大激活值要低得多。這表明新方法產生的激活異常值較少。

圖 8 則展示了將注意力 logit 量化到更低位的情況。這裡使用的方案是:使用 absmax 量化的動態後訓練量化。其中,16 位配置表示未經量化的原始結果。模型逐步量化爲 8 位、6 位和 4 位。這裡報告的是在 HellaSwag 上的零樣本準確度,但該團隊也指出在其它數據集上也有類似表現。

從圖中可知,即使降低位寬,差分 Transformer 也能保持較高性能。相較之下,常規 Transformer 的準確度在 6 位和 4 位量化時會顯著下降。這一結果表明,差分 Transformer 本身就能緩解注意力分數中的激活異常值問題,從而可爲低位 FlashAttention 的實現提供新機會。

最後,該團隊也進行了消融實驗,證明了各個新設計的有效性。