一文搞懂 LLM 推理加速的關鍵,從零實現 KV 緩存!

KV 緩存(KV cache)是讓大模型在生產環境中實現高效推理的關鍵技術之一。本文將通過通俗易懂的方式,從概念到代碼,手把手教你從零實現 KV 緩存。

Sebastian Raschka 此前已推出多篇關於大模型構建的深度教程,廣受讀者歡迎。本篇內容原計劃收錄於其著作《從零構建大模型》,因篇幅所限未能納入,此次借作者養傷期間整理推出,以迴應衆多讀者的來信請求,也作爲其下一篇研究型文章發佈前的精彩預熱。快來一起了解一下吧!

什麼是 KV 緩存?

想象一下,一個大模型(LLM)正在生成文本。比如說,模型接收到的提示詞是 “Time”。你可能已經知道,LLM 是一次生成一個詞(或 token)的,如下圖所示,它可能經歷如下兩個生成步驟:

圖示展示了 LLM 是如何逐步生成文本的,每次僅生成一個 token。從 “Time” 開始,生成 “flies”;接着模型會重新處理整個序列 “Time flies”,再生成 “fast”。

但你也許注意到了,模型每次都要重新處理完整的上下文信息(如 “Time flies”),這就帶來了重複計算的問題。如下圖所示:

在這張圖中可以看到,每次生成新 token(比如 “fast”)時,模型都重新對上下文 “Time flies” 進行編碼。由於沒有緩存中間的鍵和值向量的狀態,模型每次都必須重新處理整個序列。

在我們實現文本生成函數時,我們通常只使用每個步驟中最後生成的 token。但上述可視化揭示了一個概念層面上的主要低效之處:重複計算。這個問題在深入關注注意力機制本身時會更明顯。

如果你對注意力機制感興趣,可以參考我寫的《從零構建大模型》一書中的第三章。

接下來這張圖展示了注意力機制中的一部分計算過程,這是大模型的核心之一。圖中,輸入的 token(比如 “Time” 和 “flies”)被編碼爲三維向量(真實情況中維度會更高,這裏爲了圖示簡潔而簡化了)。矩陣 W 是注意力機制的權重矩陣,它們將這些輸入轉換爲鍵、值和查詢向量。

下圖展示了帶有突出顯示的鍵和值向量的基本注意力分數計算的一個摘錄:

這張圖展示了模型是如何通過學習到的 W_k 和 W_v 矩陣,將每個 token(例如 “Time” 和 “flies”)的嵌入映射爲對應的鍵和值向量的。

如前所述,LLM 每次生成一個 token。比如在生成了 “fast” 之後,下一個提示詞就變成了 “Time flies fast”。如下圖所示:

這張圖展示了每次生成新 token(比如 “fast”)時,模型會重新計算先前 token(“Time” 和 “flies”)的鍵和值向量,而不是複用它們。這種重複計算清晰地揭示了在自迴歸解碼過程中不使用 KV 緩存的低效。

通過比較前兩張圖可以發現,對於前兩個 token,其鍵和值向量在每一輪生成中都是完全相同的。每次都重新計算這些內容顯然是沒有必要的,純屬浪費計算資源。

因此,KV 緩存的理念是實現一個緩存機制,把前面已經算好的鍵和值向量存儲下來,供之後的生成步驟重複使用,從而避免這些無意義的重複計算。

LLM 如何生成文本(有無 KV 緩存的區別)

在前一節介紹了 KV 緩存的基本概念後,我們來稍微深入一點,在講具體代碼實現前,先看看實際生成過程中出現的差異。

假設我們要生成 “Time flies fast” 這段文本,如果沒有 KV 緩存,大致流程是這樣的:

每生成一個新詞,模型都會重新處理前面的所有詞,比如每次都要重新計算 “Time” 和 “flies” 的信息。這就造成了明顯的重複計算

KV 緩存的作用就是解決這個問題——把之前已經計算過的鍵和值向量存下來,以後就不用再算了:

下表總結了不同階段的計算與緩存過程:

這裏的好處是,“Time” 只計算了一次,但複用了兩次;“flies” 也只計算了一次,複用了一次。(這個例子用的是很短的文本,爲了方便說明。但直觀來看,文本越長,能複用的鍵和值向量就越多,生成速度也會提升得越明顯。)

下圖展示了在第 3 步生成時,使用和不使用 KV 緩存兩種情況下的對比效果。

比較有和沒有 KV 緩存的文本生成。在上圖(沒有緩存):每次生成都重新計算所有 token 的鍵和值向量,效率低;下圖(有緩存):只計算當前新 token 的信息,其他的都直接從緩存中取出來,速度快了不少。

所以,如果你想在代碼中實現 KV 緩存,核心思路其實很簡單:正常計算值和 鍵向量後,把它們存儲起來,下一次生成時直接拿來用就行。接下來的部分就會用代碼例子具體演示這個過程。

從零開始實現 KV 緩存

實現 KV 緩存的方法有很多,主要思想在文本生成的每一步中,我們只對**新生成的 token **計算鍵和值,而不是把所有的 token 都重新計算一遍。

在這裏,我選擇了一種簡單的方法,強調代碼的可讀性。我認爲直接瀏覽代碼更改以瞭解其實現方式是最簡單的。

我在 GitHub 上分享了兩個文件,它們都是獨立的 Python 腳本,從零實現了一個 LLM 的簡化版——一個帶 KV 緩存,一個不帶:

如果你想查看跟 KV 緩存相關的代碼修改,有兩種方式你可以選擇:

a. 打開 gpt_with_kv_cache.py 文件,查找標註爲 # NEW 的部分,那裏標記了新增或改動的代碼段;

b. 你也可以用任意一款文件對比工具,對這兩個代碼文件進行差異比較,直觀查看具體修改了哪些地方。

另外,下面幾個小節會對實現細節做一個簡要梳理和說明。

1. Registering the Cache Buffers

在 MultiHeadAttention 的構造函數中,我們添加了兩個非持久性的緩存變量:cache_k 和 cache_v,用於在多步生成中保存連接起來的鍵和值。

self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)

2. 前向傳遞中使用 use_cache 標誌

接下來,我們擴展 MultiHeadAttention 類的 forward 方法,讓它接受一個名爲 use_cache 的參數:

def forward(self, x, use_cache=False):
    b, num_tokens, d_in = x.shape
    keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
    values_new = self.W_value(x)
    queries = self.W_query(x)
    #...
    if use_cache:
        if self.cache_k is None:
            self.cache_k, self.cache_v = keys_new, values_new
        else:
            self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
            self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
        keys, values = self.cache_k, self.cache_v
    else:
        keys, values = keys_new, values_new

這段代碼存儲和檢索鍵和值實現了 KV 緩存的核心思想。

存儲

具體來說,在通過 self.cache_k is None: ``..., 初始化緩存之後,我們分別通過 self.cache_k = torch.cat(...) 和 self.cache_v = torch.cat(...) 將新生成的鍵和值添加到緩存中。

檢索

當緩存中已經存好了前面幾步的鍵和值,就可以直接通過 keys, values = self.cache_k, self.cache_v 取出使用。

這就是 KV 緩存最核心的存儲和檢索機制。接下來的第 3 和第 4 節會補充一些實現上的細節。

3. 清空緩存

在生成文本時,我們必須記得在兩次獨立的文本生成調用之間,重置鍵和值的緩存。否則,新輸入的查詢會關注到上一次序列遺留的過時緩存,導致模型依賴無關的上下文,輸出混亂無意義的內容。爲避免這種情況,我們在 MultiHeadAttention 類中添加了一個 reset_kv_cache 方法,以便在稍後的文本生成調用之間使用:

def reset_cache(self):
    self.cache_k, self.cache_v = None, None

4. 在完整模型中傳播 use_cache

在前面爲 MultiHeadAttention 添加完緩存功能後,接下來我們要修改整個  GPTModel 類,確保緩存機制貫穿整個模型。

首先,我們在模型中添加一個用於記錄標記索引位置的計數器:

self.current_pos = 0

這是一個簡單的計數器,用來記錄當前生成過程中,已經緩存了多少個 token。

然後,我們將一行代碼的塊調用替換爲一個顯式的循環,並在每個 TransformerBlock 中傳遞 use_cache:

def forward(self, in_idx, use_cache=False):
    # ...
    if use_cache:
        pos_ids = torch.arange(
            self.current_pos, self.current_pos + seq_len,            
            device=in_idx.device, dtype=torch.long
        )
        self.current_pos += seq_len
    else:
        pos_ids = torch.arange(
            0, seq_len, device=in_idx.device, dtype=torch.long
        )
    pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
    x = tok_embeds + pos_embeds
    # ...
    for blk in self.trf_blocks:
        x = blk(x, use_cache=use_cache)

如果我們將 use_cache=True,上面會發生什麼?我們從  self.current_pos 開始並計數 seq_len 步。然後,增加計數器,以便下次生成時繼續接着上次的位置。

self.current_pos 跟蹤的原因是新查詢必須直接跟在已經存儲的鍵和值之後。如果不使用計數器,每個新步驟都會再次從位置 0 開始,因此模型會將新 token 視爲與之前的 token 重疊。(或者,我們也可以通過 offset = block.att.cache_k.shape[1] 來跟蹤。)

爲了讓 TransformerBlock 支持這個邏輯,我們還要對它稍作修改,以接收 use_cache 參數:

def forward(self, x, use_cache=False):
    # ...
    self.att(x, use_cache=use_cache)

最後,爲了方便,我們還給 GPTModel 添加了一個模型級別的重置,以便一次性清除所有塊緩存,方便我們使用:

def reset_kv_cache(self):
    for blk in self.trf_blocks:
        blk.att.reset_cache()
    self.current_pos = 0

5. 在生成中使用 KV 緩存

在完成了對 GPTModel、TransformerBlock 和 ``MultiHeadAttention 的修改之後,下面是在文本生成函數中實際使用 KV 緩存的方法:

def generate_text_simple_cached(
        model, idx, max_new_tokens, use_cache=True
    ):
    model.eval()
    ctx_len = model.pos_emb.num_embeddings  # max sup. len., e.g. 1024
    if use_cache:
        # Init cache with full prompt
        model.reset_kv_cache()
        with torch.no_grad():
            logits = model(idx[:, -ctx_len:], use_cache=True)
        for _ in range(max_new_tokens):
            # a) pick the token with the highest log-probability 
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            # b) append it to the running sequence
            idx = torch.cat([idx, next_idx], dim=1)
            # c) feed model only the new token
            with torch.no_grad():
                logits = model(next_idx, use_cache=True)
    else:
        for _ in range(max_new_tokens):
            with torch.no_grad():
                logits = model(idx[:, -ctx_len:], use_cache=False)
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
    return idx

需要特別注意的是:在帶緩存的情況下,我們通過:logits = model(next_idx, use_cache=True) 將最新生成的 token 傳入模型。

而如果沒有緩存,就需要在每輪都重新輸入整個序列 logits = model(idx[:, -ctx_len:], use_cache=False) 因爲模型此時沒有任何中間狀態需要複用。這個區別正是 KV 緩存帶來的核心性能優勢。

簡單的性能對比

在瞭解了 KV 緩存的原理後,接下來你自然要問:它在實際中到底有多大用?

爲了驗證,我們可以運行前面提到的兩個 Python 腳本,分別測試不帶緩存和帶緩存的實現。這兩個腳本會使用一個參數量爲 124M 的小型 LLM 以生成 200 個新 token(給定一個 4 個 token 的提示  "Hello, I am" 以開始)。

運行步驟如下:

pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
python gpt_ch04.py
python gpt_with_kv_cache.py

在一臺搭載 M4 芯片的 Mac Mini(CPU) 上,結果如下:

所以我們可以看到,即使是一個小型的 124 M 參數模型和一個簡短的 200 token 序列長度,我們也已經獲得了大約 5 倍的速度提升。(注意,這個實現優先考慮了代碼的可讀性,並沒有針對 CUDA 或 MPS 等運行時速度環境進行優化——如果要進一步提速,需要預分配張量,而不是在每一步都重新創建和連接它們)

注意:無論是否使用緩存,模型目前生成的文本都是 “胡言亂語”,輸出文本示例:

Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous bore ITVEGIN ministriesysics Kle functional recountrictionchangingVirgin embarrassedgl ...

這段輸出是模型生成的 “胡言亂語”(gibberish),也就是說,看起來像英文,但並沒有真實的語義或邏輯。

這是因爲我們還沒有對模型進行訓練。下一章會講訓練模型,訓練好後你可以在推理階段使用 KV 緩存來生成連貫的文本(不過 KV 緩存只適合用於推理階段)。這裏我們用的是未經訓練的模型,目的是讓代碼更簡單。

更重要的是,gpt_ch04.py 和 gpt_with_kv_cache.py 的實現產生了完全相同的文本。這說明 KV 緩存的實現是正確的 —— 要做到這一點並不容易,因爲索引處理稍有差錯,就會導致生成結果出現偏差。

KV 緩存的優缺點

隨着序列長度的增加,KV 緩存的優勢和劣勢也會變得更加明顯:

優勢:計算效率大幅提升。 如果沒有緩存,步驟 t 中的注意力必須將新查詢與 t 個之前的鍵進行比較,因此累積工作量呈二次方增長,O(n²)。有了緩存,每個鍵和值只計算一次,然後重複使用,將每步的總複雜度降低到線性,O(n)。

劣勢:內存使用呈線性增長。 每個新標記都會附加到 KV 緩存中。對於長序列和更大的 LLM,累積的 KV 緩存會變得更大,這可能會消耗大量的(GPU)內存,甚至達到不可接受的程度。作爲一種解決方法,我們可以截斷 KV 緩存,但這會增加更多的複雜性(但 again, it may well be worth it when deploying LLMs.)

一種常見的做法是截斷緩存,丟棄最早的部分,但這又會增加額外的實現複雜度。(不過在生產環境中,這種取捨通常是值得的。)

優化 KV 緩存的實現

上文中介紹的 KV 緩存實現方式,主要側重概念清晰和代碼可讀性,非常適合教學用途。

但如果你想在實際項目中部署(尤其是模型更大、文本更長的情況下),就需要針對運行效率、顯存使用等方面進行更加細緻的優化。

提示 1:預分配內存

與其在每一步都反覆連接張量,不如根據預期的最大序列長度,提前分配好足夠大的張量空間。這樣可以穩定內存使用,減少開銷。僞代碼如下:

# Example pre-allocation for keys and values
max_seq_len = 1024  # maximum expected sequence length
cache_k = torch.zeros(
    (batch_size, num_heads, max_seq_len, head_dim), device=device
)
cache_v = torch.zeros(
    (batch_size, num_heads, max_seq_len, head_dim), device=device
)

在推理過程中,我們隨後可以直接寫入這些預先分配的張量的切片。

提示 2:通過滑動窗口截斷緩存

爲了防止 GPU 內存爆炸,我們可以實現一個帶有動態截斷的滑動窗口方法。通過滑動窗口,我們只在緩存中保留最後的 window_size 個標記:

# Sliding window cache implementation
window_size = 512
cache_k = cache_k[:, :, -window_size:, :]
cache_v = cache_v[:, :, -window_size:, :]

實際優化效果,可以在 GitHub 中 gpt_with_kv_cache_optimized.py 文件中看到。

在配備 M4 芯片(CPU)的 Mac Mini 上,對於 200 個 token 的生成和等於 LLM 上下文長度的窗口大小(以保證相同的結果,從而進行公平比較),代碼運行時間如下:

不太幸運的是,在 CUDA 設備上這些提速優勢會消失。由於這個模型體積很小,設備之間的數據傳輸和通信開銷反而抵消了 KV 緩存帶來的性能提升。

總結

儘管緩存引入了額外的複雜性和內存考慮因素,但在生產環境中,效率的顯著提升通常值得這些權衡。

需要注意的是,本文的重點在於講清楚原理,因此優先考慮了代碼的清晰度和可讀性,而非運行效率。而在真實項目中,爲了更高效地利用資源,往往還需要進行一些實用的優化,比如預分配內存、應用滑動窗口緩存來有效控制內存增長等。

希望這篇文章對你有所啓發。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/pRnmc4lB7PW54HPvQXWk7w