擴散模型版CS: GO!世界模型+強化學習:2小時訓練登頂Atari 100K

新智元報道

編輯:LRS

【新智元導讀】DIAMOND是一種新型的強化學習智能體,在一個由擴散模型構建的虛擬世界中進行訓練,能夠以更高效率學習和掌握各種任務。在Atari 100k基準測試中,DIAMOND的平均得分超越了人類玩家,證明了其在模擬複雜環境中處理細節和進行決策的能力。

環境生成模型(generative models of environments),也可以叫世界模型(world model),在「通用智能體規劃」和「推理環境」中的關鍵組成部分,相比傳統強化學習採樣效率更高。

但世界模型主要操作一系列離散潛在變量(discrete latent variables)以模擬環境動態,但這種壓縮緊湊的離散表徵有可能會忽略那些在強化學習中很重要的視覺細節。

日內瓦大學、愛丁堡大學的研究人員提出了一個在擴散世界模型中訓練的強化學習智能體DIAMOND(DIffusion As a Model Of eNvironment Dreams),文中分析了使擴散模型適應於世界建模(world modeling)所需的設計要素,並展示瞭如何通過改善視覺細節來提高智能體的性能。

論文鏈接:https://arxiv.org/pdf/2405.12399

代碼鏈接:https://github.com/eloialonso/diamond

項目鏈接:https://diamond-wm.github.io

DIAMOND在Atari 100k基準測試中達到了1.46的平均人類標準化分數(mean human

normalized score),也是完全在世界模型內訓練智能體的最佳成績。

此外,在圖像空間中操作還有一個好處是,擴散世界模型能夠成爲環境的即插即用替代品,更方便地深入分析世界模型和智能體行爲。

在項目主頁,研究人員還展示了智能體玩CS: GO的畫面,先收集了87小時人類玩家的視頻;然後用兩階段管道(two-stage pipeline:)以低分辨率執行動態預測,降低訓練成本;將擴散模型從Atari的4.4M參數擴展(scaling)到 CS: GO 的381M;最後對上採樣器使用隨機採樣(stochastic sampling)來提高視覺生成質量。

模型在RTX 4090上訓練了12天,並且可以在RTX 3090上以約10 FPS的速度運行。

不過該方法在模擬世界模型時,在部分場景下仍然會失效。

強化學習和世界模型

我們可以把環境看作是一個複雜的系統,智能體在這個系統中通過執行動作來探索並接收反饋(獎勵)。

智能體不能直接知道環境的具體狀態,只能通過圖像觀測來理解環境,最終的目標是教會智能體一個策略,使其能夠根據所看到的圖像來決定最佳的行動方式,以獲得最大的長期獎勵。

爲此,研究人員構建了一個世界模型來模擬環境的行爲,讓智能體在模擬環境中進行訓練,這樣可以更高效地利用數據,提高學習速度。

整個訓練過程包括收集真實世界中的數據,用這些數據來訓練世界模型,然後讓智能體在世界模型中進行訓練,類似於在一個虛擬的環境中進行練習一樣,也可以稱之爲「想象中的訓練」(imagination)。

基於評分的擴散模型

擴散模型是一類受非平衡熱力學啓發的生成模型,通過逆轉加噪過程來生成樣本。

假設有一個由連續時間變量τ索引的擴散過程,其中τ的取值範圍是0到T,然後有一系列的分佈,以及邊界條件:在τ=0時,分佈是數據的真實分佈,而在τ=T時,分佈是一個易於處理的無結構先驗分佈,比如高斯分佈。

爲了逆轉正向的加噪過程,需要定義漂移係數和擴散係數的函數,以及估計與過程相關的未知得分函數;在實踐中,可以使用一個單一的時間依賴得分模型來估計這些得分函數。

不過在任意時間點估計得分函數並不簡單,現有的方法使用得分匹配作爲目標,可以在不知道潛在得分函數的情況下,從數據樣本中訓練得分模型。

爲了獲得邊際分佈的樣本,需要模擬從時間0到時間τ的正向過程,然後通過一個高斯擾動核到清潔數據樣本,在一步之內解析地到達正向過程的任何時間τ;由於核是可微的,得分匹配簡化爲一個去噪得分匹配目標(denoising score matching),這時目標變成了一個簡單的L2重建損失,其中包含了一個時間依賴的重參數化項。

用於世界建模的擴散模型

世界模型需要一個條件生成模型來模擬環境的動態,即給定過去的狀態和動作,預測下一個狀態的概率分佈,可以看作是部分可觀察馬爾可夫決策過程(POMDP),通過在歷史數據上訓練一個條件生成模型,來預測環境的下一個狀態,雖然理論上可以採用任意常微分方程(ODE)或隨機微分方程(SDE)求解器,但在生成新的觀察結果時,需要在採樣質量和計算成本之間做出權衡。

DIAMOND

DIAMOND模型有兩個重要的參數,一個是漂移係數,決定了系統隨時間變化的趨勢;另一個是擴散係數,決定了噪聲的強度,兩個係數共同調節可以使模型更好地模擬真實世界的變化。

模型的核心是預測環境的下一個狀態,爲了訓練該網絡,需要提供一系列的數據,包括過去的觀察結果和動作,網絡的目標是從當前的狀態和動作中預測出下一個狀態。

在訓練過程中,會逐漸向數據中加入噪聲,模擬環境的不確定性;然後,網絡需要學會從這些帶有噪聲的數據中恢復出原始的、清晰的下一個狀態,整個過程就像是在一堆雜亂無章的信息中找到規律,預測出接下來可能發生的事情。

爲了幫助網絡更好地學習和預測,DIAMOND使用了一種叫做U-Net的神經網絡結構。這種結構特別適合處理圖像數據,因爲它可以捕捉到圖像中的複雜模式。我們還使用了一種特殊的技術,叫做自適應組歸一化,這有助於網絡在處理不同噪聲水平的數據時保持穩定。

最後使用歐拉方法來生成預測結果,不需要複雜的計算,在大多數情況下都可以提供足夠準確的預測。

在想象中強化學習

比如說,我們正在訓練一個智能體如何在一個虛擬世界中行動:智能體需要「獎勵模型」告訴它做得好不好,需要「終止模型」告訴他什麼時候遊戲結束。

智能體有兩個部分:一個部分告訴它該怎麼做(actor),用REINFORCE方法來訓練;另一個部分告訴它做得怎麼樣(critic ),用λ-回報的貝爾曼誤差的方法來訓練。

讓智能體在一個完全由計算機生成的世界中進行訓練,這樣就可以在不真實接觸環境的情況下學習和成長。

只需要在真實環境中收集一些數據;每次收集完數據後,都會更新智能體的虛擬世界,然後讓模型在這個更新後的世界中繼續訓練;整個過程不斷重複,直到智能體學會如何在虛擬世界中更好地行動。

Atari 100k基準結果

Atari 100k包括了26個不同的電子遊戲,每個遊戲都要求模型具有不同的能力。

在測試中,智能體在開始真正玩遊戲之前,只能在遊戲中嘗試100,000次動作,大概相當於人類玩2個小時的遊戲時間,而其他無限嘗試的遊戲智能體通常會嘗試5億次動作,多了500倍。

爲了更容易與人類玩家的表現進行比較,使用人類歸一化得分(HNS)指標,結果顯示,DIAMOND的表現非常出色,在11個遊戲中超過了人類玩家的表現,基本實現了超越人類的水平,平均得分爲1.46,在所有世界模型訓練的智能體中是最高的。

DIAMOND在某些遊戲中的表現尤其好,要求智能體能夠捕捉到細節,比如《阿斯特里克斯》、《打磚塊》和《公路賽跑者》。

參考資料:

https://diamond-wm.github.io/

https://x.com/op7418/status/1845152731901853970

https://the-decoder.com/ai-model-simulates-counter-strike-with-10-fps-on-a-single-rtx-3090/