謝賽寧新作:表徵學習有多重要?一個操作刷新SOTA,DiT訓練速度暴漲18倍
新智元報道
編輯:喬楊
【新智元導讀】在NLP領域,研究者們已經充分認識並認可了表徵學習的重要性,那麼視覺領域的生成模型呢?最近,謝賽寧團隊發表的一篇研究就拿出了非常有力的證據:Representation matters!
擴散模型如何突破瓶頸? 成本高又難訓練的DiT/SiT模型如何提升效率?
對於這個問題,紐約大學謝賽寧團隊最近發表的一篇論文找到了一個全新的切入點:提升表徵(representation)的質量。
論文的核心或許就可以用一句話概括:「表徵很重要!」
用謝賽寧的話來說,即使只是想讓生成模型重建出好看的圖像,仍然需要先學習強大的表徵,然後再去渲染高頻的、使圖像看起來更美觀的細節。
這個觀點,Yann LeCun之前也多次強調過。
有網友還在線幫謝賽寧想標題:你這篇論文不如就叫「Representation is all you need」(手動狗頭)
由於觀點一致,這篇研究也獲得了同在紐約大學的Yann LeCun的轉發。
REPA的核心思想非常簡單,就是讓擴散模型中的表徵與外部更強大的視覺表徵進行對齊,但提升效果非常顯著,頗有「他山之石,可以攻玉」的意味。
僅僅是在損失函數添加一項相似度最大化,就能將SiT/DiT的訓練速度提升將近18倍,還刷新了模型的SOTA性能,在ImageNet 256x256上實現了最先進的FID=1.42。
謝賽寧表示,剛看到實驗結果時,他自己也被震驚到了,因爲感覺並沒有發明什麼全新的東西,而只是意識到了,我們幾乎完全不理解擴散模型和SSL方法學習到的表示。
論文簡介
論文地址:https://arxiv.org/abs/2410.06940
項目地址:https://sihyun.me/REPA/
在生成高維的視覺數據方面,基於去噪方法(如擴散模型)或基於流的生成模型,已經成爲了一種可擴展的途徑,並在有挑戰性的的零樣本文生圖/文生視頻任務上取得了非常成功的結果。
最近的研究表明,生成擴散模型中的去噪過程可以在模型內部的隱藏狀態中引入有意義的表示,但這些表示的質量目前仍落後於自監督學習方法,例如DINOv2。
作者認爲,訓練大規模擴散模型的一個主要瓶頸,就在於無法有效學習到高質量的內部表示。
如果能夠結合高質量的外部視覺表示,而不是僅僅依靠擴散模型來獨立學習,就可以使訓練過程變得更容易。
爲了實現這一點,論文基於經典的擴散Transformer架構,引入了一種簡單的正則化方法REPA(REPresentation Alignment)。
簡單來說,就是將去噪網絡中從噪聲輸入 得到的隱藏狀態的投影,與外部自監督預訓練的視覺編碼器從乾淨圖像獲得的視覺表示*進行對齊。
這樣一個非常直給的策略,卻獲得了驚人的結果:應用於流行的SiT或DiT時,模型的訓練效率和生成質量都得到了顯著提高。
具體來說,REPA可以將SiT的訓練速度加快17.5×以上,以不到40萬步的訓練量匹配有700萬步訓練的SiT-XL模型的性能,同時實現了FID=1.42的SOTA結果。
REPA:使用表徵對齊的正則化
統一視角的擴散模型+流模型
由於論文希望同時優化基於流的模型SiT和基於去噪的擴散模型DiT,因此首先從統一的隨機插值視角,對這兩種模型進行簡要的回顧。
考慮在t∈[0,T]的連續時間步中,對數據*~p()使用高斯分佈ε~(0,)添加隨機噪音:
其中,αt和σt分別表示t的遞減和遞增函數。在公式(1)給定的過程中,存在一個帶有速度場(velocity field)的概率流常微分方程:
其中t步時的分佈就等於邊際概率pt()。
速度(,t)可以表示爲如下兩個條件期望之和:
這個值可以通過最小化如下訓練目標得到近似值θ(,t):
同時,還存在一個反向的隨機微分方程(SDE),帶有擴散係數wt,其中的邊際概率pt()與公式(2)相符:
其中,(t,t)是一個條件期望值,定義爲:
對任意t>0,都可以通過速度(,t)計算出(,t)的值:
這表明,數據t也可以通過求解公式(5)的SDE來以另一種方式生成。
以上定義對類似的擴散模型變體,例如DDPM,同樣適用,只是需要將連續的時間步離散化。
方法概述
令p()爲數據∈的未知目標分佈,我們的訓練目標就是通過模型對數據的學習得到p()的近似。
爲了降低計算成本,最近流行的「潛在擴散」方法(latent diffusion)提出學習潛在變量=E()的分佈p(),其中E表示來自預訓練自編碼器(例如KL-VAE)中的編碼部分。
要學習到分佈p(),就需要訓練擴散模型θ(t,t),訓練目標是進行速度預測,具體方法如上一節所述。
放在自監督表示學習的背景中,可以將擴散模型看成編碼器fθ:⭢和解碼器gθ:⭢的組合,其中編碼器負責隱式地學習到表示t以重建目標t。
然而,作者提出,用於生成的大型擴散模型並不擅長表徵學習,因此REPA引入了外部的語義豐富的表示,從而顯著提升生成性能。
REPA方法概述
模型觀察
擴散模型是否真的不擅長表徵學習?這需要更進一步地觀察模型才能確定,爲此,研究人員測量並比對了diffusion transformer和當前的SOTA自監督模型DINOv2之間的表徵差距,包括語義差距和特徵對齊兩種角度。
語義差距
從圖2a可知,預訓練SiT的隱藏層表示在第20層達到最佳狀態,這與之前的研究結果相符,但仍遠遠落後於DINOv2。
特徵對齊
如圖2b和2c所示,使用CKNNA值測量SiT和DINOv2之間的表徵對齊程度後發現,SiT的對齊效果會隨着模型增大和訓練迭代步數增加而逐漸改善,但即使增加到7M次迭代,和DINOv2之間的對齊程度仍然不足。
事實上,這種差距不僅在SiT中存在,根據附錄C.2的實驗結果,DiT等其他基於去噪的生成式Transformer模型也存在類似的問題。
縮小表徵差距
那麼,REPA方法究竟如何縮小這種表徵差距,讓diffusion transformer在噪聲輸入中也能學到有用的語義特徵?
定義N,D分別表示patch數量預訓練編碼器f的嵌入維度,編碼器輸入爲無噪聲的圖像*,輸出爲*=f(*)∈ℝN×D。
Diffusion transformer將編碼器輸出t=fθ(t)通過一個可訓練的投影頭hφ(MLP)投影爲hφ(t)∈ℝN×D。
之後,REPA負責將hφ(t)與*進行對齊,通過最大化兩者間的patch間相似度:
在實際實現中,將這一項添加到公式(4)定義的基於擴散的訓練目標中,就得到總體的訓練目標:
其中超參數λ>0用於控制模型在去噪目標和表徵對齊間的權衡。
從圖3結果可知,REPA減少了表示中的語義差距。
有趣的是,使用REPA後,僅對齊前幾個Transformer塊就能實現足夠程度的表示對齊,從而讓diffusion transformer的靠後層專注於捕獲高頻細節,從而進一步提高生成性能。
實驗結果
爲了驗證REPA方法的有效性,實驗在兩種流行的擴散模型訓練目標(即velocity)上進行了實驗,包括DiT中改進後的DDPM和SiT中的線性隨機插值,但實際中也同樣可以考慮其他的訓練目標。
所用模型默認嚴格遵循SiT和DiT的原始結構(除非有特別說明),包括B/2、L/2、XL/2三種參數設置,如表1所示。
以下實驗旨在回答3個問題:
- REPA能否顯著提升diffusion transformer的訓練?
- REPA在模型規模和表徵質量方面是否具有可擴展性?
- 擴散模型的表徵能否和多種視覺表徵進行對齊?
REPA提升視覺縮放
首先比較兩個SiT-XL/2模型在前400K次迭代期間生成的圖像,它們共享相同的噪聲、採樣器和採樣步數,但其中使用REPA訓練的模型顯示出更好的進展。
REPA在各個方面都展現出了強大的可擴展性
研究人員還改變了預訓練編碼器和Diffusion Transformer的模型大小來檢驗REPA的可擴展性。
圖5a結果表明,與更好的視覺表示相結合可以改善生成效果和線性探測的結果。
此外,如圖5b和c所示,增加模型大小可以在生成和線性評估方面帶來更快的收益,也就是說,模型規模越大,REPA的加速效果越明顯,表現出了強大的可擴展性。
REPA顯著提高訓練效率和生成質量
最後,論文比較了普通DiT或SiT模型在訓練中使用REPA前後的FID值。
在沒有指導的情況下,REPA在400K次迭代時實現了FID=7.9,優於普通模型在7M次迭代後的性能。
此外,使用無分類器引導時,帶有REPA的SiT-XL/2的性能優於SOTA性能(FID=1.42),同時迭代次數減少了7倍。
作者介紹
Sihyun Yu
本文一作Sihyun Yu是KAIST(韓國科學技術院)人工智能專業最後一年的博士生,此前他同樣在KAIST獲得了數學和計算機科學的雙專業學士學位。
他的研究主要集中在減少大型生成模型訓練(和採樣)的內存和計算負擔,其中,對大規模且高效的視頻生成特別感興趣;博士期間,他還曾在英偉達和谷歌研究院擔任實習生。
參考資料:
https://x.com/sainingxie/statdus/1845510163152687242