現代 LLM 基本技術整理

0 開始之前

本文從 Llama 3 報告出發,基本整理一些現代 LLM 的技術。'基本',是說對一些具體細節不會過於詳盡,而是希望得到一篇相對全面,包括預訓練,後訓練,推理,又能介紹清楚一些具體技術,例如 RM,DPO,KV Cache,GQA,PagedAttention,Data Parallelism 等等的索引向文章。由於東西比較多,且無法詳盡細節,所以推薦大家二次整理爲自己的筆記。

本文的主要參考是 Llama Team 的 The Llama 3 Herd of Models 報告原文,以及沐神迴歸 B 站新出的論文精讀系列。同時也包括一些知乎的優秀文章。

1 Intro

Illustration of the overall architecture and training of Llama 3Overview of the Llama 3 Herd of models.

1.1 現代基礎模型訓練的主要階段

(a)預訓練階段(pre-training stage):算法相對直接,一般是用大量的數據去做下一個詞的預測(next-word prediction)。

(b)後訓練階段(post-training stage):算法比較豐富,包括 SFT,RLHF,DPO 等等。任務上看,包括讓模型做一些指令跟隨的任務(instruction following),將模型偏好對齊到人類喜好上(align with human preferences),或者提高模型在特定任務的能力,例如 code,math,roleplay 等等。

從過去的模型看,基本上可以認爲 GPT1,2,3 都是在做 pre-training,而 InstructGPT 和 RLHF 則是在做 post-training。以上是較爲籠統的介紹。

1.2 現代基礎模型訓練的關鍵

Meta:We believe there are three key levers in the development of high-quality foundation models: data, scale, and managing complexity.

meta 認爲現代基礎模型訓練的關鍵是:data, scale, and managing complexity。

(a)關於 data ,Llama 系列有堆數據的傳統:相較於 Llama 2 1.8T 的預訓練語料,Llama 3 的預訓練語料堆到了 15T 的 multilingual tokens。

沐神:15 個 T 可能是目前在公有的網絡上面,能夠抓到的文本數據的一個大概的上限,這個'上限'的意思是指,與其再找一些增量的數據,不如去調整現有的數據的質量。

(b)關於 scale,Llama 3.1 提供了 8B,70B,405B 三個規模。每個規模的性能差異可參考下面的 benchmark。

(c)關於 managing complexity,複雜度管理,說白了即 Llama 3 的算法相對簡單。Llama 3 選擇了一個標準的稠密 Transformer 模型架構,只進行了少量調整,而沒有選擇 MOE。後訓練方面,Llama 3 採用了 SFT、RS 和 DPO,即一套'相對簡單'的過程,而不是更復雜的 RLHF 算法,因爲後者往往穩定性較差且更難以擴展。這些都屬於 design choice。2,3 章會詳細介紹相關技術。

1.3 benchmark 表現

Llama 3 各規格模型的 benchmark 表現如下。簡要介紹其中的 MMLU 和 IFEval。

Performance of finetuned Llama 3 models on key benchmark evaluations.

(a)MMLU 系列 :類似於各種考試裏面的選擇題,只是主要考察模型的知識面(背答案)。

Question: Glucose is transported into the muscle cell:

Choices:
A. via protein transporters called GLUT4.
B. only in the presence of insulin.
C. via hexokinase.
D. via monocarbylic acid transporters.

Correct answer: A

原版 MMLU 是比較老的 benchmark,存在大家 overfit 的可能性。MMLU-Pro 相對更新一些,可以看到在 MMLU-Pro 上,8B,70B,405B 的差距相當大,說明參數規模和內化到權重中的知識量還是非常相關的。

(b)IFEval :IF 即 Instruction Following,考察模型對指令的理解和遵循能力。原文見:IFEval Dataset | Papers With Code[1]。

IFEval 示例

在 IFEVAL 上,8B 和 70B 的差距還是很明顯的(80.4/87.5),而 70B 和 405B 的差距已經不明顯了(87.5/88.6)。說明參數規模到達一定程度後,再想通過擴大規模來提 IF 能力,可能會逐漸不顯著。

(c)剩下的 benchmark 則偏垂直一些,分別包含了 Code,Math,Reasoning,Tool use,Long context,Multilingual,可參見報告原文。

補充:上述評估集既然都有 overfit 和 leaking 的風險,那還有沒有其他的 benchmark 呢?當然,比如 LiveBench 這種 monthly 更新的 benchmark,LiveBench[2]。不過,天底下是沒有完美的 benchmark 的,尤其是對於具體業務而言。

總體上看,8B 和 70B 在各方面差距都還是比較明顯,但 70B 和 405B 在以上的評估集中,則差異相對小一些。405B 的推理和訓練都比較慢,一般情況下,70B 算是複雜應用的首選。如果特別複雜,再考慮 405B,畢竟性價比還是會差一些。值得一提的是,Llama 3.1 70B 在 IFEval 上接近 Claude3.5 sonnet 的水準。

2 Pre-Training

Meta:Language model pre-training involves: (1) the curation and filtering of a large-scale training corpus, (2) the development of a model architecture and corresponding scaling laws for determining model size, (3) the development of techniques for efficient pre-training at large scale, and (4) the development of a pre-training recipe. We present each of these components separately below.

上文比較籠統地說明了 Pre-Training 的要點。

2.1 Pre-Training Data

預訓練數據處理的要點包括 de-duplication methods and data cleaning mechanisms,即去重和清洗,如果做得不好,質量會很差。具體報告中的 Web Data Curation 章節提到了以下內容:

a)PII and safety filtering:報告提到預訓練數據中移除了包含 PII(personally identifiable information,關於人的身份信息,隱私信息)和成人內容的域名。但具體是什麼一個標準來錨定該數據是否屬於 PII 和成人內容,未給出示例一類的說明,所以大概率是混了一些進去的。

b)Text extraction and cleaning:由於 web data 是 raw HTML content,所以 Llama 構建了一個 parser 來解析各類文檔。有趣的觀點是,報告認爲 Markdown 對模型的性能有害,因此刪除了所有 Markdown marker。但挪掉之後具體怎麼做的,未加說明。

(c)De-duplication:Llama 使用了三個級別的去重,URL,document, and line level。具體來說,URL 去重即保留每個 URL 對應頁面的最新版本。document 級別則在整個數據集上採用了 global MinHash 來去除近似重複的文檔。line level 的具體做法則是按照每 30M 的 documents 進行搜索,去除其中出現超過 6 次的文本行。

(d)Heuristic filtering:啓發式的過濾。包括 n-gram 的過濾,如果 n 比較長,重複較多,則把該行去掉,典型的例子是 logging 文本。也包括危險詞的過濾,如果一個網頁的 dirty word 太多,則去掉。報告還提到使用了基於 token-distribution Kullback-Leibler divergence(KL 散度)的方法來過濾過於奇葩的數據。即如果一個文檔和其他文檔算 KL 的距離差太遠的話,就把該文檔標記爲奇怪的文檔去掉。

KL 散度的概念比較常用,是用於衡量兩個概率分佈之間的差異程度。定義爲:

(e)Model-based quality filtering:基於模型的分類。比如 fasttext 和基於 Llama 2 訓練的 Roberta-based classifiers,分類包括分高質量 or 低質量,也可以是打領域 tag 等等。

(f)Code and reasoning data and Multilingual data:也是一些特定數據的抽取 pipeline,花錢花人力做的一些工作。

數據配比確實相當重要,且是實驗性較強的工作(煉丹),燒錢燒時間出成果。報告中提到了 Knowledge classification 和 scaling law 的一些實驗。

(a)Knowledge classification. 即使用一個分類器劃分數據的類別,例如客觀知識類,娛樂八卦類,成人內容類...... 娛樂八卦類的數據對模型就不太好,分類後就可以讓這類數據少來一些。

**(b)Scaling laws for data mix. ** 即多做不同配比的實驗,看指標變化。稍詳細一點說,是在不同的小模型上做不同的配比實驗,然後用來預測更大 scale 的最優配比。

總結,最後的預訓練數據大概是 50% 的 general knowledge,25% 的 mathematical and reasoning 數據,17% 的 code 數據,8% 的多語言數據。

報告發現,在少量高質量的 code 和 math 的數據上做一下學習率的退火,能夠提升預訓練模型的 benchmark performance。這很符合直覺,即'考前多背一下題目考的會更好一些'。(?)

具體來說,是在大量通用數據的訓練完成後,用一小撮高質量的特定領域數據繼續訓練,同時將學習率慢慢降低。Llama 3 在預訓練的最後 40M token 採取了將 LR 線性退火到 0 的方法,同時配合數據配比調整。最後 8B 模型在 GSM8k 和 MATH 驗證集上提升不錯,但對 405B 的模型提升卻可以忽略不計,說明該參數規模的模型也許不需要 specific in-domain 的訓練樣本來提升性能。

同時,報告提到可以使用退火來評估 domain-specific 的小數據集的質量,比做 Scaling Law 的相關實驗效率更高。

2.2 Model Architecture

總體上看,Llama 3 相較於 2 做了以下改動:GQA,面向一個 sequence 內部的不同文檔的 attention mask,128K tokens 的詞表,RoPE 的調整。

Llama 3 使用標準的 Dense Transformer 架構,性能的提高主要來自於數據質量和多樣性的改進,以及訓練規模的增加(很喜歡說一些實話)。當然,和 Llama 2 相比還算有一些改變:

例如上述提到的 Grouped Query Attention:GQA 用於加速推理,節省解碼的內存。對於 70B 及以上的模型,幾乎是必須用的技術。GQA 涉及到 KV Cache,KV Cache 涉及到基本的推理過程,因此從推理開始寫。

(a)基本推理過程

LLM 推理過程

1、輸入的 Text,根據詞表被切分成 n 個 token/token ids,n 個 token ids 被映射爲 n 個 embedding 向量,即 1 個 embedding 矩陣;

2、embedding 矩陣通過 L 個 transformer block(內部有各種注意力計算和 FFN 層),在最後一層輸出一個與輸入形狀相同的 embedding 矩陣;

3、輸出的 n 個 embedding 再過一個線性層 lm_head,該線性層的輸出形狀和詞表大小一致。線性層輸出再接一個 softmax,就得到了 next token 的概率分;

4、隨後再根據解碼策略採樣即可。Next token 被算出來後,加入輸入的 token 序列(長度爲 n+1),繼續計算第 n+2 個 token,這就是自迴歸。

(b)KV Cache

由於在計算第 n+1 個 token 時,L 個 Transformer block 的中間結果是可以被保存下來的,所以也許可以複用它們。我們把第  層,第  個 token 的輸出記爲  。不難發現,需要計算第 n+2 個 token 時,有很大一部分中間結果和計算 n+1 時相同。可表示爲:

輸入 token 序列:  與輸入 token 序列爲 的中間結果 一致,所以我們利用緩存來可以減少大量的計算。

因此,LLM 推理過程分爲 Prefill 和 Decode 兩個階段,Prefill 階段會對 Prompt 中所有的 token 做並行計算,得到 Prompt 中所有 Tokens 的 KV Cache 以及計算得到首 Token。Prompt Tokens 計算得到的 KV Cache 會保存下來,留給 Decode 階段複用;

Decode 階段是一個自迴歸過程,每解碼一個新的 Token,都需要用到所有之前計算得到的 KV Cache 來計算當前 query token 的 Attention。因此,當輸出長度越來越大或者 context 很長時,KV Cache 將會佔用大量的顯存。

本段內容以及下圖引用自:[KV Cache 優化] MQA/GQA/YOCO/CLA/MLKV 筆記: 層內和層間 KV Cache 共享 [3]。

所以現在也存在 prefix caching 的概念,簡單地說,就是把特定前綴的 KV Cache 緩存起來保留備用。對於指令複雜,prompt 較長的任務,或者多輪對話場景非常有效。vllm 已經可以很方便地開啓 prefix caching,對長輸入短輸出的固定任務優化較好。KV Cache 有大量的方向可以做,是 LLM 推理優化的核心之一。

(c)GQA,Grouped Query Attention

GQA 是從模型層面降低 KV Cache 大小的手段之一。聊 GQA 之前的慣例是聊 MHA 和 MQA。

MHA,即 Multi Head Attention,多頭注意力,Transformer 原文的 attention 形式。如下圖所示,MHA 中每個 Query 向量都會對應一個 Key,Value,其輸出會把每個注意力頭的輸出拼接起來。因此也會存較多的 KV Cache。

MQA,即 Multi Query Attention。如下圖所示,MQA 的思路比較直接,就是讓每個注意力頭共用一個 KV,很顯然,相較於 MHA,KV Cache 的佔用直接減少到了 1/head_num。不過,由於結構的修改和 Attention 部分的參數量降低,模型效果也必然受到影響。MQA 似乎還是有些暴力。

因此出現了平衡的版本,即 GQA,Grouped Query Attention。和圖中一致,即將 Queries 進行分組,每組對應一個 KV,用一種折中的方法實現了減少計算量和 KV Cache 大小。

首先應該聊聊經典的正弦編碼。上文在 LM 的一次推理過程中提到,token 會映射爲 embedding 向量,在經典 transformer 的結構中,這個 embedding 向量是詞嵌入向量(實體的'孤立'語義)和位置編碼(實體間的'關聯'語義)的疊加。如何表徵 token 的位置,則是位置編碼研究的問題。

《動手學深度學習 PyTorch 版》:全要點筆記 [4],經典 transformer 架構的位置編碼是正弦編碼。

正弦編碼存在一些可能的問題,比如對相對位置的表示較弱。RoPE 則嘗試在解決這些問題。

2.3 Scaling Laws

最初的形式

簡單來說,就是可以用小模型的一些實驗結果來預測更大模型的結果。Scaling Law 由 OpenAI 提出,有兩個大家熟知的結論:

1、對於 Decoder-only 的 LM,計算量  ,模型參數量  ,數據大小  ,三者滿足  。其中  的單位是 Flops,  是 token 數;

2、模型的最終性能主要與 ,, 相關,與模型的具體結構(高矮胖瘦)相關性不高。

-** Llama 報告的內容?**

之前的 Scaling Law 的預測方法主要是從 next-token prediction loss(訓練時的 validation loss)出發的,但這個 loss 和具體的任務表現不一定是絕對相關的。因爲 next-token prediction loss 並不和具體任務表現(例如數學)絕對掛鉤。所以 Llama 3 在做 Scaling Law 的實驗時,做了一個 two-stage 的方法:

step1:預測模型在具體下游任務上的 NLL loss,這個 NLL loss 還是和 compute(FLOPs)掛鉤,成函數關係;

step2:利用 Scaling Law 將 step1 中的 loss 和具體的 task accuracy 關聯起來。例如 1.4 的 NLL loss 對應 0.25 的 accuracy,1.2 的誤差對應 0.95 的 accuracy,所以這個規律和具體也可以解耦,得到對於一個具體 benchmark 的 Scaling Law 曲線,x,y 軸分別爲 loss 和 accuracy。

具體可見下圖。ARC Challenge benchmark 是一個做推理的多選題任務集。發現 Scaling Law 的預測還是挺準的。不過要注意,不同任務的 benchmark 曲線可能也長得不一樣。

2.4 Training Recipe

Llama 3 的預訓練策略主要由三步構成,分別爲:(1) initial pre-training, (2) long-context pre-training, and (3) annealing.

Initial Pre-Training

主要是一些細節。簡單翻譯下。我們使用 AdamW 對 Llama 3 405B 進行預訓練,peak learning rate 爲  ,linear warm up 爲 8000 步,以及 cosine learning rate(預計在 1,200,000 步中衰減到  )。爲了提高訓練穩定性,我們在訓練初期使用了較小的批次大小,並隨後增加了批次大小以提高效率。具體來說,我們使用的 initial batch size 爲 4M 的 tokens,長度爲 4096 的序列,在訓練了 252M tokens 後後將這些值加倍,8M sequences of 8,192 tokens。在訓練了 2.87 T token 後,再次將加倍到 16M。我們發現這種訓練配方非常穩定:我們觀察到的損失峯值(loss spikes)很少,並且不需要進行干預來糾正模型訓練的偏差。

同時也做了一些 data mix 的調整。比如多拿非英語數據,數學數據,更多的最新網絡數據等等。

Long Context Pre-Training

簡單翻譯下。在預訓練的最後階段,我們對 long sequences 進行訓練,以支持最多 128K tokens 的 context 窗口。我們之前沒有對 long sequences 進行訓練,因爲在 self-attention layers 中的計算量隨 sequence length 呈平方增長。我們逐步增加支持的 context length,進行 pre-training,直到模型成功適應了增加的 context length。

我們通過以下兩點評估成功的適應性:(1) 模型在 short-context evaluations 中的表現是否完全恢復,具體來說可能就是 MMLU 這些評測集;(2) 模型是否能完美解決長度達到該值的'needle in a haystack' 任務(大海撈針任務)。

在 Llama 3 405B 的 pre-training 中,我們逐步在六個階段增加了 context length,從最初的 8K context 窗口開始,最終達到 128K context 窗口。這個 long-context pre-training 階段使用了大約 0.8T tokens。

Annealing

見 2.1 Pre-Training Data,同退火數據(Annealing Data)一節的內容。

3 Post-Training

下圖很清晰地概括了 Llama 3 的後訓練思路,要素包括 RM,SFT,RS,DPO。本章會一一介紹。後訓練是業內絕大多數 NLPer 做的事情。

Illustration of the overall post-training approach for Llama 3.

Llama 3 後訓練策略的 backbone 是一個 Reward Model 和一個 Language Model。首先利用人類標註的偏好數據,在 pre-trained checkpoint 之上訓練一個 RM。然後,對 pre-trained checkpoint 做 SFT,之後用 DPO 做對齊,作爲本輪的最佳模型,進入下輪迭代,參與 Rejection Sampling 過程。

注意到,訓練是迭代式的,即有多輪方法相同的訓練。具體來說,Llama 3 進行了 6 輪的循環。在每個週期中,收集新的偏好標註和 SFT 數據,並從最新的模型中採樣合成數據。

3.1 Reward Model

紅框部分是 RM 的訓練路徑

首先應該簡介一下 Reward Model(RM)。Reward Model 是一種通過” 偏好排序數據 “(A>> B > C = D)訓練得到的模型,能夠給一段文本一個偏好性(例如安全性,擬人性,或者某種綜合性的偏好)的分數。這個分數是一個標量,體現了人類的某種偏好。

而且,A > B 可能不僅是 A > B,也可能是遠好於,稍好於,這個其實也能在損失函數里體現出來(margin loss),即 Llama 2 論文中  的部分:

Preference Data 構建

Llama 詳細講解了 Preference Data 的構建過程。大概是這樣幾個 step:

step 1. 使用不同的數據配比和訓練策略訓練出多個 for annotation 的模型。部署多個不同的模型,針對一個具體的 user prompt 採樣出兩個來自不同模型的 response。

step 2. 標註同學會按照 “好多少” 的標準,對 response 對進行打分,包括四個等級:significantly better, better, slightly better, or marginally better。

step 3. 偏好標註好後,鼓勵標註同學去 “edit”chosen response,即他們上一步已經選擇了更好的那個答案,改的更好。既可以直接修改 chosen response 本身,也可以修改 prompt 來 refine 這些數據。

所以,最後有一部分偏好數據是有三個 ranked response 的,即 edited > chosen > rejected。最後,得到了這樣的數據構成。

訓練

訓練和 Llama 2 類似。但是 Llama 3 反而在損失函數中去掉了 margin loss,即上文的  ,因爲觀察到在數據規模擴大後,margin 的改進效果逐漸減弱,不如簡化。

3.2 SFT

SFT 大概是大多數同學接觸 LLM 訓練的首選。SFT,即使用標準的交叉熵損失(standard cross entropy loss),同時 mask prompt 部分的 loss,訓練 target tokens 的過程。

SFT Data 構建

SFT 數據有很多個來源:Rejection Sampling 的數據,針對特定能力的合成數據,少量的人工標註數據。

Rejection Sampling

Rejection Sampling 的過程,就是固定模型和 prompt,讓 LM 採樣出 K 個不同的答案,根據 RM 的 K 個不同的分數,選出最優答案。然後將該最優答案作爲 SFT 數據,做迭代式的訓練。其中,模型一般是前一輪訓練中表現最好的 checkpoint,K 則可以調整,一般是 10-30。採樣也有很多細節,涉及到 preference pair 構造,比如 rejected 可能不能無腦選最差的,這些需要實驗。

爲了提高拒絕採樣的效率,Llama 3 採用了 PagedAttention。在 PagedAttention 中,內存浪費只會發生在序列的最後一個塊中,可以很好地提升吞吐量。PagedAttention 的內存共享也是很好的優化,在 Rejection Sampling 中,多個 response 是由同一個 prompt 生成的。在這種情況下,prompt 的計算和內存可以在輸出序列中共享。這裏做一些簡單介紹。

PagedAttention

think of blocks as pages, tokens as bytes and requests as processes。

PagedAttention 也是主流推理加速框架 vLLM 之選。大家應該都學過 OS 課,瞭解虛擬內存,內存分頁管理,內存碎片的概念。PagedAttention 也是受到 OS 的啓發,認爲 KV Cache 沒有必要存儲在連續的內存中,而是像操作系統一樣,把塊的概念引入爲 “page”,byte 的概念引入爲 “token”,進程的概念引入爲 “request”。

2.2 節中我們提到,由於在計算第 n+1 個 token 時,L 個 Transformer block 的中間結果是可以被保存下來的,所以也許可以複用它們。這被稱作 KV Cache。

但是 KV Cache 非常大,需要一塊連續內存來存儲。並且,我們在接收到 sequence 之前,並不知道需要預留多少連續內存,所以只能預先分配一個最大可能長度的 cache,導致了很多浪費,這被稱爲 “內部碎片”。而由於我們給多個 sequence 分配了內存,所以剩下的內存不足以分配給新的 sequence,這一部分內存實際上也沒用了,所以也造成了浪費,這被稱爲 “外部碎片”。

PagedAttention 允許在非連續的內存空間中存儲連續的 key 和 value 。具體來說, 它將每個序列的 KV cache 劃分爲塊,每個塊包含固定數量 token 的鍵和值。因此,對於 1 個 sequence,最多會有 1 個 page 是有內存碎片的。由於按塊分配,外部碎片則徹底沒有了。這和 OS 中的分頁存儲解決的問題一致。

回到 SFT Data,最後,得到了這樣的數據構成。

訓練細節上,Llama 3 對 405B 進行微調時,學習率爲 10⁻⁵,訓練步數在 8.5K 到 9K 之間。

3.3 Rejection Sampling

見 3.2 SFT 中的 Rejection Sampling。

3.4 Direct Preference Optimization

DPO 在 SFT 之後進行,目的是對齊人類的偏好。DPO 是 RLHF 的簡化,目的是跳過複雜的 RM 訓練等過程,RLHF 是先用標註的偏好數據去訓練 RM,然後再指導 RL 的過程,而 DPO 則這把上述兩個步驟的 loss 融合到一起。

因此,DPO 的訓練數據也是人類偏好數據,格式類似於 chosen-rejected 對。DPO 的損失如下

# DPO的數據格式
{    
    'prompt''',
    'chosen''',
    'rejected'''
}

DPO 訓練細節

在訓練過程中,Llama 3 主要使用最新一批的偏好數據,這些數據是通過前幾輪對齊中表現最好的模型收集的,需要用到 RM。好處是,這些數據更好地符合每輪正在優化的 Policy Model 的分佈。所以這種 DPO 也是 Iterative 的,屬於 on-policy。

(a)第一個細節是,由於 DPO 損失函數的特點,chosen response 和 rejected response 中如果出現了一些共同的 token,則會導致相互衝突的學習目標,因爲模型需要同時增加和減少這些 token 的生成概率。所以 Llama 3 Mask 了 formatting tokens 的 loss,實驗發現這些 token 如果算 loss,可能會導致 tail repetition 和突然生成終止的 token。

(b)第二個細節是,Llama 3 給 chosen sequence 加上了一個 negative log-likelihood(NLL) loss,從 NLL loss 和標準交叉熵損失的差別上看,可以簡單把 NLL loss 理解爲 SFT loss:

加上 NLL loss 的好處是,防止 chosen response 的 log probability 下降。壞處是,chosen response 如果本身不夠好,加這個 SFT loss 可能也不太好,需要具體問題具體分析。

3.5 Data Processing and Quality Control

數據質量始終是最關鍵的。由於 Llama 3 的大部分訓練數據是模型生成的,因此需要仔細進行清洗和質量控制。這和絕大多數垂直業務模型也一致。

數據清洗(Data cleaning)

首先,數據中往往存在一些不理想的模式,Llama 3 就有過度使用表情符號或感嘆號的問題。一些非常經典的 AI 味語風也需要注意,例如 “過於喜歡滑跪” 的語氣問題,遇事不決就 “對不起” 或“我道歉”,這種樣本應該不能在數據集中太多。

數據修剪(Data pruning)

Llama 3 還應用了一些基於模型的技術來去除低質量的訓練樣本,來提升模型整體性能:

1、主題分類(Topic classification):首先,對一個小模型(如 Llama 3 8B)進行微調,使其成爲 topic classifier,例如專門用一大堆分類文本的任務數據去 SFT 一下。然後對所有訓練數據進行分類,將其分類爲粗粒度類別(如 “數學推理”)和細粒度類別(如 “幾何和三角學”)。

2、質量評分(Quality scoring):使用 Reward model 和基於 Llama 的信號爲每個樣本的質量打分。對於基於 RM 的評分,我們將得分處於 RM 評分前四分之一的數據視爲高質量數據。對於基於 Llama 的評分,就是在 Llama 3 設計了一些打分的 prompt,一般英語數據使用三個維度的評分(準確性、指令遵循性和語氣 / 表達),coding 數據則使用兩個維度的評分(錯誤識別和用戶意圖),並將獲得最高分的樣本視爲高質量數據。

最後發現 RM 評分和 Llama 評分的分歧率較高,但發現結合這兩種機制能在 meta 內部測試集中取得最佳的召回率。最終,選擇被 RM OR Llama 3 分類模型標記爲高質量的樣本。

3、難度評分(Difficulty scoring):由於還希望優先處理對模型來說更復雜的樣本,因此報告提到兩種難度評估方法對數據進行評分:Instag 和基於 Llama 的評分。對於 Instag,我們提示 Llama 3 70B 對 SFT 提示進行意圖標註,意圖越多,複雜性越高。基於 Llama 的思路和 Quality scoring 相似,給了 Llama 3 一些 prompt,基於三個維度去打分。

4、語義去重(Semantic deduplication):最後,進行語義去重。Llama 3 首先使用 RoBERTa 對完整對話進行聚類,然後在每個聚類內按質量分數 × 難度分數對其進行排序。接着,遍歷所有排序的樣本進行貪婪選擇,僅保留與當前聚類中已見樣本的餘弦相似度小於閾值的樣本。

4 Inference

首先請參考 2.2 Model Architecture 中,關於基本推理過程,KV Cache,GQA 部分的內容,同時請參考 3.2 SFT 中關於 PagedAttention 的介紹。

4.1 Parallelism

Parallelism,LLM 分佈式訓練推理的一部分,包括 Data Parallelism 和 Model Parallelism,本節做一些介紹。同樣涉及到 OS 的一些概念。

Data Parallelism

Data Parallelism,數據並行,在每個設備上,獨立接收到不同的輸入數據批次(可稱 mini-batch)並執行前向傳播,以計算該批次上的損失。在反向傳播過程中,每個設備會計算梯度,並與所有其他設備交換這些梯度。然後,使用這些梯度的平均值來更新每個設備上的模型權重,確保在下一次訓練步驟開始時,所有設備都具有相同的模型權重。

好處是加快了 batch 的訓練速度,並且能夠放下更大 batch size 的數據。壞處是,每張卡也都使用了完整的模型權重,得保證單卡能裝得下。

Data Parallelism

Model Parallelism

Model Parallelism。模型並行,包括 Tensor Parallelism 和 Pipeline Parallelism。Model Parallelism 解決的是單張卡放不下一個完整模型權重的問題,每張顯卡只放部分參數。一般來說,會按照層進行劃分參數,按層劃分一般叫 Pipeline Parallelism。如果模型的一層如果都裝不下了,同一個模型層內拆分開訓練,是 Tensor Parallelism。

好處是能放下更大的權重了,壞處是後面層的卡需要等待前面層的計算結果,所以 GPU 會有空閒狀態。反向傳播時也一樣,前面層的卡要等後面層的卡。

Llama 3 中的 Pipeline Parallelism

使用 BF16 數值表示模型參數時,Llama 3 405B 模型無法在一臺配備 8 個 Nvidia H100 GPU 的單機內完全加載到 GPU 內存中。爲了解決這一問題,Llama 3 team 使用兩臺機器(node)上的 16 個 GPU 並行進行 BF16 精度的模型推理。

在每個 node 內部,利用 NVLink 的 high bandwidth 來啓用 tensor parallelism。而在 node 之間,連接的帶寬較低,延遲較高,因此採用 pipeline parallelism(Gpipe)。

在使用 pipeline parallelism 進行訓練時,bubble 是一個主要的效率問題(詳見論文 Gpipe)。然而,在推理過程中,這並不是一個問題,因爲推理不涉及反向傳遞。因此,Llama 3 使用 micro-batching 來提高推理的吞吐量(throughput)。

Gpipe

在前向傳播過程中,GPipe 首先將每個大小爲 N 的 mini-batch 劃分爲 M 個相等的 micro-batch,並將它們通過 K 個 GPU 進行流水線處理。在反向傳播過程中,每個 micro-batch 的梯度是基於前向傳播時使用的相同模型參數計算的。在每個 mini-batch 結束時,所有 M 個 micro-batch 的梯度會被累積,並應用於所有 GPU 以更新模型參數。

micro-batching 效果

報告在 key-value cache pre-fill stage 和 decoding stage 兩個階段(見 2.2 Model Architecture 的講解)都評估了 micro-batches 的效果。在 4096 個輸入 tokens 和 256 個輸出 tokens 的情況下,報告發現,在相同的 local batch size 下,micro-batches 提高了推理的吞吐量,如下圖所示。

這些改進歸因於 micro-batches 在這兩個階段中實現了併發執行。由於 micro-batches 帶來了額外的同步點(synchronization points),導致延遲增加,但總體而言,micro-batches 仍然帶來了更好的吞吐量 - 延遲平衡(throughput-latency trade-off)。

4.2 Quantization

Quantization,量化,也是當前熱門的話題,核心手段是通過降低模型參數的精度來減少 GPU 佔用,並減少計算量。和 PagedAttention 類似,同樣可以從 OS 中找到很多相關的東西。一些常見的精度表示如下:

INT8 量化

INT 8 量化相對簡單。如圖所示的是 absmax 的 INT 8 量化,輸入是一個 FP16 的向量。假設用 absmax 對向量 [1.2, -0.5, -4.3, 1.2, -3.1, 0.8, 2.4, 5.4] 進行量化。首先需要計算該向量的最大絕對值,在本例中爲 5.4。Int8 的範圍爲[-127, 127],因此我們將 127 除以 5.4,得到縮放因子(scaling factor)23.5。最後,將原始向量乘以縮放因子得到最終的量化向量[28, -12, -101, 28, -73, 19, 56, 127]。

要恢復原向量,可以將 int8 量化值除以縮放因子,但由於上面的過程是 “四捨五入” 的,我們將丟失一些精度。

FP8 量化

Llama 3 利用 H100 GPU 的原生 FP8 支持來執行低精度推理。爲了啓用低精度推理,Llama 3 對模型內部的大多數矩陣乘法應用 FP8 量化。實現細節見下面的兩篇參考文章。特別是,對模型中前饋網絡層的大多數參數和激活值進行量化,這些部分約佔推理計算時間的 50%。其中還有一些細節:

Llama 3 沒有對模型的自注意力層中的參數進行量化。也沒有在第一個和最後一個 Transformer 層中執行量化。並且,採用了按行量化的方式,對參數和激活矩陣的每一行計算縮放因子(Scaling Factor)。如下圖所示。

量化結果

量化結果主要是兩個方面,一個是好處,即 efficiency 的提升;一個是壞處,即 accuracy 的下降。

對於 efficiency,Llama 3 針對於 4,096 input tokens and 256 output tokens 做了定量實驗,在 prefill 階段(2.2 Model Architecture 中有詳細介紹),使用 FP8 推理可將吞吐量提高多達 50%(4k->9k);在 decode 階段,也能更好地 trade off throughput-latency。

對於 accuracy,在標準 benchmark 上,即使不做上文所說的細節,FP8 推理的表現也與 BF16 推理相當。但是當 Scaling Factor 沒有上限時,模型有時會生成錯誤的響應,所以 benchmark 無法正確和充分地反映 FP8 量化的影響。於是 Llama 3 使用 FP8 和 BF16 生成了 100,000 個響應,選擇用獎勵模型的分佈來分析。從下圖可以看到,FP8 的得分幾乎沒有影響 RM 的得分分佈。

Throughput-latency trade-off in FP8 inference with Llama 3 405BReward score distribution for Llama 3 405B using BF16 and FP8 inference.

5 寫在最後

最近平時工作可以說是把腦子想 “幹” 了,所以花大概三個週末完成了這篇接近 2w 字的文章。寫完感覺有很多不足,但還是隨便找個時間發了吧。其一是,本來是打算從 Llama 3 這種優質開源模型和報告出發,進行一些知識上的梳理,結果行文時幾乎保留了論文原來的結構,導致前一個知識點到下一個知識點不夠絲滑;

其二是,由於水平不夠和 “綜合性” 考量的限制,所以對很多需要深入的知識沒有詳盡。後面幾個週末也許還會持續迭代一下本文,主要是繼續細化技術點。所以也懇請諸位指出錯誤或不足,盡情提出需要補充內容的部分。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/eCNk7HKjZw7u7-B3moTk7Q