挑戰 Transformer!Mamba 的架構及實現(Pytorch)

**Mamba 一經出現就在人工智能界掀起波瀾,被吹捧爲 Transformer 的競爭對手。**到底是什麼讓 Mamba 在擁擠的序列建模中脫穎而出?  今天我們來詳細研究這篇論文《Mamba: 具有選擇性狀態空間的線性時間序列建模》

在介紹之前先簡要回顧一下現有的模型

Transformer: 以其注意力機制而聞名,其中序列的任何部分都可以動態地與任何其他部分相互作用,特別是具有因果注意力機制的的 Transformer,擅長處理序列中的單個元素。但是它們帶來了顯著的計算和內存成本,與序列長度的平方 (L²) 成比例。

循環神經網絡 (rnn): rnn 只考慮當前輸入和最後一個隱藏狀態,按順序更新隱藏狀態。這種方法允許它們潛在地處理無限序列長度和恆定的內存需求。但是 rnn 的簡單性是一個缺點,限制了它們記住長期依賴關係的能力。此外,rnn 中的時間反向傳播 (BPTT) 是內存密集型的,並且可能遭受梯度消失或爆炸的影響,儘管有 LSTM 等創新部分結解決了這個問題。

State Space Models(S4): 這些模型已經顯示出很好的特性。它們提供了一種平衡,比 rnn 更有效地捕獲遠程依賴關係,同時比 transformer 更高效地使用內存。

接下來 Manba 登場!

Mamba

選擇性狀態空間: Mamba 建立在狀態空間模型的概念之上,但引入了一個新的變化。它利用選擇性狀態空間,支持跨長序列更高效和有效地捕獲相關信息。

線性時間複雜度: 與 Transformer 不同,Mamba 在序列長度方面以線性時間運行。這個屬性使得它特別適合涉及非常長的序列的任務,而傳統模型在這方面會遇到困難。

Mamba 以其選擇性狀態空間的概念引入了傳統狀態空間模型的一個有趣的改進。這種方法稍微放鬆了標準狀態空間模型的嚴格狀態轉換,使其更具適應性和靈活性(有點類似於 lstm)。並且 Mamba 保留了狀態空間模型的高效計算特性,使其能夠在一次掃描中執行整個序列的前向傳遞 - 這一特性更讓人想起 Transformer。

在訓練期間,Mamba 的行爲類似於 Transformer,同時處理整個序列。而 lstm 必須一步一步地計算前向傳遞,即使所有輸入都是已知的。在推理中,Mamba 的行爲更符合傳統的循環模型,提供有效的序列處理。

先驗狀態空間模型 (ssm) 的一個關鍵限制是其剛性的、輸入不變的結構。這些模型爲整個序列使用一組固定參數(我們稱它們爲 a 和 B)。這種結構甚至比 lstm 等模型更具限制性,在 lstm 中,信號的轉換可能依賴於先前的隱藏狀態和輸入。

Mamba 則一種範式轉換,即如何計算向下一個隱藏狀態的過渡?在 Mamba 的體系結構中,轉換依賴於當前輸入,這種方法在傳統 ssm 的固定計算和循環神經網絡的輸入依賴動態性之間取得了平衡。

主要組成如下:

固定主幹: 從一個隱藏狀態到下一個隱藏狀態的轉換仍然是一個固定的計算 (由 a 矩陣定義),允許跨序列的預計算。

輸入相關轉換: 輸入影響下一個隱藏狀態 (由 B 矩陣定義) 的方式取決於當前輸入,而不是之前的隱藏狀態。與傳統 ssm 相比,這種輸入依賴性提供了更大的靈活性。

爲了滿足這種方法的計算需求,Mamba 使用了一種硬件感知算法。該算法使用掃描操作而不是卷積來循環執行計算,這樣在 gpu 上非常高效的。儘管輸入依賴轉換帶來了算法複雜性,但這種效率對於保持高性能至關重要。

Mamba 和選擇性狀態空間模型不是同義詞。Mamba 是一個使用選擇性狀態空間概念的實現。這種區別是至關重要的,因爲它突出了 Mamba 的獨特貢獻: 在保持計算效率的同時,使 SSM 框架更加靈活和響應輸入。

SRAM 和 HBM

gpu 包含兩種主要類型的內存: HBM (High Bandwidth memory) 和 SRAM (Static Random-Access memory)。HBM 雖然帶寬很高,但與更快但更小的 SRAM 相比,它的訪問時間相對較慢。Mamba 則使用 SRAM 在矩陣乘法期間進行快速訪問,這是其計算的關鍵。

計算中的主要瓶頸通常不是計算本身,而是數據在內存類型之間的移動。Mamba 通過顯著減少傳輸大量數據的需求來解決這個問題。它通過直接在 SRAM 中執行算法的關鍵部分 (如離散化和遞歸計算) 來實現,從而減少延遲。

還引入了一個融合選擇掃描層,使其內存需求與使用 flash attention 的優化 Transformer 實現相當。這一層對於保持效率至關重要,尤其是在處理模型中依賴於輸入的元素時。

結果

Mamba 代表了序列建模的重大進步,特別是在其高效使用 GPU 內存和計算策略方面。它具有高效率處理長序列的能力,使其成爲各種應用的有前途的模型,我們下面來使用 Pytorch 代碼來對其進復現。

Pytorch 復現

導入基本庫

 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader, Dataset
 from torch.nn import functional as F
 from einops import rearrange
 from tqdm import tqdm
 
 import math
 import os
 import urllib.request
 from zipfile import ZipFile
 
 from transformers import AutoTokenizer
 
 torch.autograd.set_detect_anomaly(True)

設置標誌和超參數

 # Configuration flags and hyperparameters
 USE_MAMBA = 1
 DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定義超參數和初始化

 d_model = 8
 state_size = 128 # Example state size
 seq_len = 100 # Example sequence length
 batch_size = 256 # Example batch size
 last_batch_size = 81 # only for the very last batch of the dataset
 current_batch_size = batch_size
 different_batch_size = False
 h_new = None
 temp_buffer = None

這裏的超參數,如模型維度 (d_model)、狀態大小、序列長度和批大小。

S6 模塊是 Mamba 架構中的一個複雜組件,負責通過一系列線性變換和離散化過程處理輸入序列。它在捕獲序列的時間動態方面起着關鍵作用,這是序列建模任務 (如語言建模) 的一個關鍵方面。這裏包括張量運算和自定義離散化方法來處理序列數據的複雜需求。

 class S6(nn.Module):
     def __init__(self, seq_len, d_model, state_size, device):
         super(S6, self).__init__()
 
         self.fc1 = nn.Linear(d_model, d_model, device=device)
         self.fc2 = nn.Linear(d_model, state_size, device=device)
         self.fc3 = nn.Linear(d_model, state_size, device=device)
 
         self.seq_len = seq_len
         self.d_model = d_model
         self.state_size = state_size
 
 
         self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
         nn.init.xavier_uniform_(self.A)
 
         self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
         self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
 
         self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
         self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
         self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
 
         # h [batch_size, seq_len, d_model, state_size]
         self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
         self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
 
 
     def discretization(self):
 
         self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
 
         self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
 
 
         return self.dA, self.dB
 
     def forward(self, x):
         # Algorithm 2 MAMBA paper
         self.B = self.fc2(x)
         self.C = self.fc3(x)
         self.delta = F.softplus(self.fc1(x))
 
         self.discretization()
 
         if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:  
           
             global current_batch_size
             current_batch_size = x.shape[0]
 
             if self.h.shape[0] != current_batch_size:
                 different_batch_size = True
 
                 h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB
 
             else:
                 different_batch_size = False
                 h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB
 
             # y [batch_size, seq_len, d_model]
             self.y = torch.einsum('bln,bldn->bld', self.C, h_new)
 
             global temp_buffer
             temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
   
             return self.y
 
         else:  
             # h [batch_size, seq_len, d_model, state_size]
             h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
             y = torch.zeros_like(x)
 
             h =  torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB
 
             # y [batch_size, seq_len, d_model]
             y = torch.einsum('bln,bldn->bld', self.C, h)
 
             return y

這個 S6 的模塊,可以處理離散化過程和正向傳播。

MambaBlock 類是一個定製的神經網絡模塊,被設計爲 Mamba 模型的關鍵構建塊。它封裝了幾個層和操作來處理輸入數據。

包括線性投影、卷積、激活函數、自定義 S6 模塊和殘差連接。該塊是 Mamba 模型的基本組件,負責通過一系列轉換處理輸入序列,以捕獲數據中的相關模式和特徵。這些不同層和操作的組合允許 MambaBlock 有效地處理複雜的序列建模任務。MambaBlock 是 Mamba 核心功能。

 class MambaBlock(nn.Module):
     def __init__(self, seq_len, d_model, state_size, device):
         super(MambaBlock, self).__init__()
 
         self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)
         self.out_proj = nn.Linear(2*d_model, d_model, device=device)
 
         # For residual skip connection
         self.D = nn.Linear(d_model, 2*d_model, device=device)
 
         # Set _no_weight_decay attribute on bias
         self.out_proj.bias._no_weight_decay = True
 
         # Initialize bias to a small constant value
         nn.init.constant_(self.out_proj.bias, 1.0)
 
         self.S6 = S6(seq_len, 2*d_model, state_size, device)
 
         # Add 1D convolution with kernel size 3
         self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)
 
         # Add linear layer for conv output
         self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)
 
         # rmsnorm
         self.norm = RMSNorm(d_model, device=device)
 
     def forward(self, x):
         """
        x_proj.shape = torch.Size([batch_size, seq_len, 2*d_model])
        x_conv.shape = torch.Size([batch_size, seq_len, 2*d_model])
        x_conv_act.shape = torch.Size([batch_size, seq_len, 2*d_model])
        """
         # Refer to Figure 3 in the MAMBA paper
 
         x = self.norm(x)
 
         x_proj = self.inp_proj(x)
 
         # Add 1D convolution with kernel size 3
         x_conv = self.conv(x_proj)
 
         x_conv_act = F.silu(x_conv)
 
         # Add linear layer for conv output
         x_conv_out = self.conv_linear(x_conv_act)
 
         x_ssm = self.S6(x_conv_out)
         x_act = F.silu(x_ssm)  # Swish activation can be implemented as x * sigmoid(x)
 
         # residual skip connection with nonlinearity introduced by multiplication
         x_residual = F.silu(self.D(x))
 
         x_combined = x_act * x_residual
 
         x_out = self.out_proj(x_combined)
 
         return x_out

Mamba 模型

包括一系列 MambaBlock 模塊。每個塊都順序處理輸入數據,一個塊的輸出作爲下一個塊的輸入。這種順序處理允許模型捕獲輸入數據中的複雜模式和關係,使其對涉及順序建模的任務有效。多個塊的堆疊是深度學習架構中的常見設計,因爲它使模型能夠學習數據的分層表示。

 class Mamba(nn.Module):
     def __init__(self, seq_len, d_model, state_size, device):
         super(Mamba, self).__init__()
         self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
         self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
         self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)
 
     def forward(self, x):
         x = self.mamba_block1(x)
         x = self.mamba_block2(x)
         x = self.mamba_block3(x)
         return x

RMSNorm 是一個自定義規範化層,這一層用於規範神經網絡的激活,這可以幫助穩定和加快訓練。

 class RMSNorm(nn.Module):
     def __init__(self,
                  d_model: int,
                  eps: float = 1e-5,
                  device: str ='cuda'):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(d_model, device=device))
 
 
     def forward(self, x):
         output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
 
         return output

這一層的用法:

 x = torch.rand(batch_size, seq_len, d_model, device=device)
 # Create the Mamba model
 mamba = Mamba(seq_len, d_model, state_size, device)
 
 # rmsnorm
 norm = RMSNorm(d_model)
 x = norm(x)
 
 # Forward pass
 test_output = mamba(x)
 print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]

上面就是模型的全部基本代碼,下面就可以進行數據準備和訓練

我們自定義一個 Enwiki8Dataset

 class Enwiki8Dataset(Dataset):
     def __init__(self, data):
         self.data = data
 
     def __len__(self):
         return len(self.data['input_ids'])
 
     def __getitem__(self, idx):
         item = {key: val[idx].clone().detach() for key, val in self.data.items()}
         return item

pad_sequences_3d 用於將一批序列填充到統一的長度,確保批中的每個序列具有相同數量的元素 (或時間步長)。這在許多機器學習任務中尤其重要,因爲輸入數據必須具有一致的形狀。

 # Define a function for padding
 def pad_sequences_3d(sequences, max_len=None, pad_value=0):
     # Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)
     batch_size, seq_len, feature_size = sequences.shape
 
     if max_len is None:
         max_len = seq_len + 1
 
 
     # Initialize padded_sequences with the pad_value
     padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
     # Pad each sequence to the max_len
     padded_sequences[:, :seq_len, :] = sequences
 
     return padded_sequences

訓練過程:

 def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
     model.train()
     total_loss = 0
     for batch in data_loader:
         optimizer.zero_grad()
 
         input_data = batch['input_ids'].clone().to(device)
         attention_mask = batch['attention_mask'].clone().to(device)
 
         target = input_data[:, 1:]
         input_data = input_data[:, :-1]
 
         # Pad all the sequences in the batch:
         input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
         target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)
 
         if USE_MAMBA:
             output = model(input_data)
             loss = criterion(output, target)
 
         loss.backward(retain_graph=True)
 
         for name, param in model.named_parameters():
            if 'out_proj.bias' not in name:
                # clip weights but not bias for out_proj
                torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)
 
         if DEBUGGING_IS_ON:
             for name, parameter in model.named_parameters():
                 if parameter.grad is not None:
                     print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                 else:
                     print(f"{name} has no gradient")
 
         if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
             model.S6.h[:current_batch_size, ...].copy_(temp_buffer)
 
         optimizer.step()
 
         total_loss += loss.item()
     return total_loss / len(data_loader)

評估函數:

 def evaluate(model, data_loader, criterion, device):
     model.eval()
     total_loss = 0
     with torch.no_grad():
         for batch in data_loader:
             input_data = batch['input_ids'].clone().detach().to(device)
             attention_mask = batch['attention_mask'].clone().detach().to(device)
 
             target = input_data[:, 1:]
             input_data = input_data[:, :-1]
 
             # Pad all the sequences in the batch:
             input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
             target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)
 
             if USE_MAMBA:
                 output = model(input_data)
                 loss = criterion(output, target)
             total_loss += loss.item()
     return total_loss / len(data_loader)

最後,calculate_perplexity 用於評估語言模型 (如 Mamba) 的性能。

 def calculate_perplexity(loss):
     return math.exp(loss)

load_enwiki8_dataset 函數用於下載和提取 enwiki8 數據集,該數據集通常用於對語言模型進行基準測試。

 def load_enwiki8_dataset():
     print(f"Download and extract enwiki8 data")
     url = "http://mattmahoney.net/dc/enwik8.zip"
     urllib.request.urlretrieve(url, "enwik8.zip")
 
     with ZipFile("enwik8.zip") as f:
         data = f.read("enwik8").decode("utf-8")
 
     return data

encode_dataset 函數設計用於標記和編碼數據集,爲神經網絡模型 (如 Mamba) 處理數據集做準備。

 # Tokenize and encode the dataset
 def encode_dataset(tokenizer, text_data):
     def batch_encode(tokenizer, text_data, batch_size=1000):
         # Tokenize in batches
         batched_input_ids = []
         for i in range(0, len(text_data), batch_size):
             batch = text_data[i:i+batch_size]
             inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                                padding='max_length', max_length=seq_len,
                                return_tensors='pt')
             batched_input_ids.append(inputs['input_ids'])
         return torch.cat(batched_input_ids)
 
     # Assuming enwiki8_data is a list of sentences
     input_ids = batch_encode(tokenizer, enwiki8_data)
 
     # vocab_size is the number of unique tokens in the tokenizer's vocabulary
     global vocab_size
     vocab_size = len(tokenizer.vocab)  # Note that for some tokenizers, we might access the vocab directly
     print(f"vocab_size = {vocab_size}")
 
     # Create an embedding layer
     # embedding_dim is the size of the embedding vectors (MAMBA model's D)
     embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
 
     # Pass `input_ids` through the embedding layer
     # This will change `input_ids` from shape [B, L] to [B, L, D]
     def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
         # Check if input_ids is already a tensor, if not convert it
         if not isinstance(input_ids, torch.Tensor):
             input_ids = torch.tensor(input_ids, dtype=torch.long)
 
         # Calculate the number of batches needed
         num_batches = math.ceil(input_ids.size(0) / batch_size)
 
         # List to hold the output embeddings
         output_embeddings = []
 
         # Process each batch
         for i in range(num_batches):
             # Calculate start and end indices for the current batch
             start_idx = i * batch_size
             end_idx = start_idx + batch_size
 
             # Get the batch
             input_id_batch = input_ids[start_idx:end_idx]
 
             # Call the embedding layer
             with torch.no_grad():  # No need gradients for this operation
                 batch_embeddings = embedding_layer(input_id_batch)
 
             # Append the result to the list
             output_embeddings.append(batch_embeddings)
 
         # Concatenate the embeddings from each batch into a single tensor
         all_embeddings = torch.cat(output_embeddings, dim=0)
 
         return all_embeddings
 
     # `input_ids` is a list or tensor of the input IDs and `embedding_layer` is model's embedding layer
     if USE_MAMBA:
         # Set `batch_size` to a value that works for memory constraints
         encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=1).float()
 
     attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)
 
     return encoded_inputs, attention_mask

下面就可以進行訓練了

 # Load a pretrained tokenizer
 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 
 # Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
 encoded_inputs_file = 'encoded_inputs_mamba.pt'
 
 
 if os.path.exists(encoded_inputs_file):
     print("Loading pre-tokenized data...")
     encoded_inputs = torch.load(encoded_inputs_file)
 else:
     print("Tokenizing raw data...")
     enwiki8_data = load_enwiki8_dataset()
     encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
     torch.save(encoded_inputs, encoded_inputs_file)
     print(f"finished tokenizing data")
 
 
 # Combine into a single dictionary
 data = {
     'input_ids': encoded_inputs,
     'attention_mask': attention_mask
 }
 
 # Split the data into train and validation sets
 total_size = len(data['input_ids'])
 train_size = int(total_size * 0.8)
 
 train_data = {key: val[:train_size] for key, val in data.items()}
 val_data = {key: val[train_size:] for key, val in data.items()}
 
 train_dataset = Enwiki8Dataset(train_data)
 val_dataset = Enwiki8Dataset(val_data)
 
 
 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
 
 
 # Initialize the model
 
 model = Mamba(seq_len, d_model, state_size, device).to(device)
 
 # Define the loss function and optimizer
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.AdamW(model.parameters(), lr=5e-6)
 
 # Training loop
 num_epochs = 25  # Number of epochs to train for
 
 for epoch in tqdm(range(num_epochs)):  # loop over the dataset multiple times
     train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)
     val_loss = evaluate(model, val_loader, criterion, device)
     val_perplexity = calculate_perplexity(val_loss)
     print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

以上就是訓練的完整代碼

總結

我們介紹了 Mamba 的概念和架構,並且從頭開始構建 Mamba 復現,這樣可以將理論轉化爲實踐。通過這種動手的方法,可以看到 Mamba 序列建模方法和效率。如果你想直接使用,可以看論文提供的代碼。

論文地址:

https://arxiv.org/abs/2312.00752

論文提供的源代碼:

https://github.com/state-spaces/mamba

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/fjyrcu7sxtyCgSYyT1pldg