3天把Llama訓成Mamba,性能不降,推理更快!

新智元報道

編輯:alan

【新智元導讀】近日,Mamba方面又搞出了有意思的研究:來自康奈爾、普林斯頓等機構的研究人員成功將Llama提煉成了Mamba模型,並且設計了新的推測解碼算法,加速了模型的推理。

先來看一張其樂融融的圖片(一眼AI):

右邊的小羊駝代表Llama,而左邊的蛇(Mamba)也是我們的老熟人了。

至於到底能不能其樂融融,咱就不管了,之所以有此場景,是因爲Mamba方面又搞出了有意思的研究:

——如何把Llama變成Mamba?

論文地址:https://arxiv.org/pdf/2408.15237

代碼地址:https://github.com/jxiw/MambaInLlama

近日,來自康奈爾、普林斯頓等機構的研究人員推出了上面這篇工作,將Llama這樣的大型Transformer提煉成了Mamba模型,

並且成功在Mamba架構上應用了帶有硬件感知的推測解碼算法,提高了整個模型的推理速度。

爲什麼要把Llama變成Mamba?

因爲從頭開始訓練一個大模型太貴了。

Mamba也火了這麼長時間了,相關的研究每天都有,但自己訓練大尺寸Mamba模型的卻很少。

目前比較有名的是AI21的Jamba(進化到了1.5版本,最大398B,MoE),以及NVIDIA的Hybrid Mamba2模型(8B)。

不過世界上有那麼多成功的Transformer大模型,而知識就包含在這些模型參數裡。

如果能夠鎖住知識,同時把Transformer微調成Mamba,不就解決問題了?

在本文中,研究人員結合漸進式蒸餾、監督微調(SFT)和定向偏好優化(DPO)等方法達成了這一目標。

光是變大還不夠,

在性能匹配Transformer的前提下,速度也要夠快才行。

Mamba憑藉固定的推理開銷,在長序列中的優勢明顯,但Transformer這邊也是有推理加速方案的,比如推測解碼。

而由於Mamba本身的結構特性,不能直接應用這種方案,所以作者設計了全新的算法,並結合硬件的性質來實現基於Mamba的推測解碼。

最終,研究人員將Zephyr-7B、Llama-3 8B提煉爲了線性RNN模型(混合Mamba和Mamba2),且性能與蒸餾之前的標準模型相當。

整個訓練過程只使用了20B的token,效果卻能夠與使用1.2T個token從頭開始訓練的Mamba 7B模型,以及使用3.5T個token訓練的NVIDIA Hybrid Mamba2模型相媲美。

從 Transformer 到 Mamba

在介紹Mamba 2的時候我們講過,線性RNN(或SSM)跟線性注意力是一回事。

所以可以根據x,B,C與V,K,Q的對應關係直接複用注意力中的投影矩陣。

額外的參數包括SSM需要的A矩陣和Δt(由x投影得到),這就完成了基本的參數初始化。

之後就是SSM的運算過程,再通過投影和累加得到輸出。

模型架構和訓練

下圖給出了模型的架構,因爲Transformer的知識存在於MLP層,所以凍結這部分參數。

除了用線性RNN層(Mamba)替換掉注意力頭,還有一些組件需要處理,比如跨頭共享鍵和值的分組查詢注意力(GQA)。

知識蒸餾(Knowledge distillation,KD)是一種常用的壓縮技術,用來訓練模仿較大模型(teacher)行爲的較小網絡(student)。

根據經驗,這裡採用逐步替換Attention層的策略,先是每2層進行蒸餾,然後每4層繼續蒸餾......

監督微調

有兩種常見的蒸餾方法。一種方法是使用word-level的KL散度,此時訓練student模型去匹配teacher模型輸出的完整概率分佈。

第二種方法是序列級知識蒸餾(SeqKD),直接使用teacher模型的輸出作爲ground truth來訓練student模型(也稱爲僞標籤)。

這裡θ是student模型的可訓練參數,α和β分別控制序列和詞的loss項的權重。

偏好優化

LLM指令調優的第二階段是使其符合用戶偏好。這個階段,使用一組期望的偏好對來改進模型的輸出。

優化的目標是使獎勵模型最大化,同時保持產生的輸出接近參考模型。

通常,參考模型使用上一步監督微調後的模型。這裡因爲是蒸餾,直接可以用teacher模型:

偏好模型的獎勵函數定義取決於所使用的方法,本文采用直接偏好優化(DPO),通過直接梯度更新有效地到達優化目標。

DPO表明,對於給定的提示x ,如果我們能夠獲得preferred和dispreferred兩種輸出,就可以將這個優化問題重新表述爲:

這種優化可以在序列級別上執行,讓teacher模型和student模型一起對preferred和dispreferred輸出進行評分,然後反向傳播給student模型。

推測解碼

經過上面的一套小連招,模型轉換就搞定了,下面開始想辦法應用Transformer那邊的推測解碼。

推測解碼(Speculative Decoding)可以簡單理解爲下面這張圖。

Transformer做推理的時候,除了要處理不斷變長的KV cache之外,計算效率也是個問題。

因爲顯卡的設計是計算高於訪存的,具體到計算單元就是做矩陣乘法。

而推理的時候每次只能進入一個詞向量,顯卡的很多計算就被浪費了。

推測解碼給出的解決方案是,使用一個小模型做生成,然後拿顯卡多餘的計算做驗證。

小模型跑得快,可以一口氣生成很多輸出向量,但是可能效果差一點。這時候用大模型作爲驗證,一次計算之前生成的很多個向量。

所以小模型串行跑得快,大模型可以並行計算跑得也快,遇到驗證不通過的就直接回滾,整體上提高了推理的速度。

Transformer可以方便地回滾,因爲KV cache跟時間是一一對應的,但Mamba這邊只有一個當前的中間狀態ht,你總不能把所有中間狀態都存起來吧。

爲了解決這個問題,研究人員設計了下面的算法:

簡單來說就是每次使用小模型(draft model)生成一組輸出,然後大模型(verification model)驗證這一組輸出,根據驗證匹配的位置來更新需要保存的中間狀態。

我們可以從下面的僞代碼瞭解詳細的過程:

每次生成K個草稿輸出,驗證模型通過MultiStep函數返回K個真正的輸出,以及上一次校驗成功位置的cache(中間狀態hj)和本次最後位置的cache(hk)。

Multi-Step內核的性能特徵

通過FirstConflict函數找到最後匹配(校驗成功)的位置,如果所有都匹配,則cache可以更新到最後的hk,否則就只更新到上一次的hj。

兵馬後動,糧草先行,不耽誤輸出和校驗,同時只需要多存儲一箇中間狀態。

當然,如果草稿模型也用Mamba的話,算法的推測部分會變得複雜一些,因爲草稿模型需要重新計算上一次迭代中驗證成功位置的狀態。

硬件特定優化

下面使用Mamba 7B和 Mamba 2.8B作爲目標模型進行推測實驗。

最初,作者搞了一版簡單的算法實現,結果在Ampere架構的GPU(3090)上面效果顯著,Mamba 2.8B獲得了1.5倍的推理加速, 同時有60%的接受率。

但是這種實現方式在H100 GPU上不太好使,主要是因爲GEMM操作的速度更快了,使得緩存和重新計算產生的開銷更加明顯。

所以,作者通過融合內核以及調整實現方式來優化算法。

對於驗證模型,首先從緩存中重新計算之前的步驟,然後對新的草稿token序列進行多步解碼,最後在單個內核中進行緩存。

對於草稿模型,重新計算、解碼和緩存也融合在單個內核中。最終實現了上表中的加速效果。

實驗

研究人員使用兩個LLM聊天模型進行實驗:Zephyr-7B和Llama-3 Instruct 8B。

採用三階段蒸餾。在第一階段,使用UltraChat和UltraFeedback作爲種子提示,並使用teacher模型生成僞標籤。

使用AdamW優化器訓練模型,β=(0.9,0.98) ,批量大小64。先使用線性學習率預熱,然後進行餘弦退火。

第二階段,在一個epoch中使用SFT在GenQA、InfinityInstruct和OpenHermes 2.5數據集上對模型進行監督微調,採用與Zephyr相同的超參數。

最後一個階段,對於從Zephyr中提取的模型,在UltraFeedback數據集上使用DPO與標準模型進行蒸餾對齊。

過程中只在第一階段凍結MLP層,後兩個階段所有參數都進行訓練。

作者表示,通常只需要在8卡80G A100上運行3到4天,即可重現本文的結果。

參考資料:

https://arxiv.org/abs/2408.15237