Authors
- Youngju Kim (@fjvbn20031)
Overview
Since the landmark "Attention is All You Need" paper in 2017, Transformers have dominated virtually every sequence modeling task — from natural language processing to images, code, and audio. However, Transformers have a fundamental weakness: the quadratic complexity O(n^2) of self-attention.
In December 2023, Mamba, published by Albert Gu and Tri Dao, elegantly solved this problem and sent shockwaves through the deep learning community. Mamba adapts State Space Models (SSMs) — a concept from classical control theory — to modern deep learning, enabling linear-time processing of long sequences.
This guide covers everything from the mathematical foundations of SSMs to Mamba's core innovations and practical implementations.
1. The Limitations of Transformers and the Rise of SSMs
The O(n^2) Complexity Problem of Self-Attention
The core of Transformers is the Self-Attention mechanism. For a sequence of length n, computing the attention matrix requires computing relationships between all pairs of tokens, resulting in both time and memory complexity of O(n^2).
Attention matrix size: n × n
- 1,000 tokens → 1,000,000 elements
- 10,000 tokens → 100,000,000 elements
- 100,000 tokens → 10,000,000,000 elements (infeasible!)
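The numbers above are easy to reproduce, and translating element counts into memory makes the problem concrete (a minimal sketch; `attn_matrix_bytes` is an illustrative helper assuming one fp16 attention matrix, 2 bytes per element):

```python
# Memory for a single fp16 attention matrix (n x n, 2 bytes per element).
def attn_matrix_bytes(n_tokens: int, bytes_per_elem: int = 2) -> int:
    return n_tokens * n_tokens * bytes_per_elem

for n in (1_000, 10_000, 100_000):
    gib = attn_matrix_bytes(n) / 2**30
    print(f"{n:>7} tokens -> {n * n:>14,} elements ({gib:.2f} GiB)")
```

At 100,000 tokens a single fp16 attention matrix alone needs about 18.6 GiB, before counting activations, gradients, or multiple heads and layers.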
Because of this, real-world LLMs face hard constraints on context window size. GPT-4 Turbo supports a 128K-token context, but only at enormous computational cost and with heavy optimization.
The Difficulty of Processing Long Sequences
Real applications, such as long-document summarization, whole-codebase understanding, and long-running conversations, require very long context, but the Transformer's quadratic complexity makes them impractical at scale.
For example, processing an entire book (approximately 100,000 words = 130,000 tokens) at once would require memory that is virtually infeasible on current hardware with standard Transformers.
The Limitations of Recurrent Models (RNN, LSTM)
Recurrent neural networks (RNNs) and LSTMs theoretically achieve O(n) complexity because they carry a fixed-size hidden state across timesteps. However, they suffer from the following problems:
- Vanishing/Exploding Gradients: Gradients disappear or explode during backpropagation on long sequences
- Non-parallelizable: Each timestep depends on the previous state, making GPU parallelization difficult
- Difficulty capturing long-range dependencies: Hard to connect information far apart in the sequence
The Solution SSMs Offer
State Space Models (SSMs) combine the advantages of both approaches:
- During training: Full parallelization via convolution → training efficiency comparable to Transformers
- During inference: Recurrence with O(1) memory and O(n) computation → efficient inference like RNNs
- Long sequences: Linear complexity enables processing of very long sequences
2. Mathematical Foundations of State Space Models
Continuous-Time SSM
The origins of SSMs trace back to Rudolf Kalman's control theory in the 1960s. A continuous-time linear system is expressed as:
x'(t) = A·x(t) + B·u(t)
y(t) = C·x(t) + D·u(t)
Where:
- u(t): input signal
- x(t): state vector (latent state), size N
- y(t): output signal
- A: N×N state transition matrix
- B: N×1 input projection matrix
- C: 1×N output projection matrix
- D: direct feedthrough term (skip connection, usually set to 0)
This system receives input u(t), updates the internal state x(t), and produces output y(t). The state x(t) can be thought of as a "summary" of past information.
Discretization
To use continuous-time SSMs on digital computers, discretization is required. With a sampling interval Delta, there are two main discretization methods.
Zero-Order Hold (ZOH):
A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B
Bilinear (Tustin) Transform:
A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B
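Both rules can be written in a few lines of NumPy. This is a minimal sketch assuming a diagonal A (the structure S4D and Mamba actually use), so exp(Δ·A) reduces to an elementwise exponential; `zoh` and `bilinear` are illustrative helper names:

```python
import numpy as np

def zoh(A_diag, B, delta):
    """Zero-order hold for diagonal A.
    A_bar = exp(delta*A); B_bar = (delta*A)^-1 (exp(delta*A) - I) delta*B,
    which simplifies elementwise to (exp(delta*A) - 1) / A * B."""
    dA = np.exp(delta * A_diag)
    dB = (dA - 1.0) / A_diag * B
    return dA, dB

def bilinear(A_diag, B, delta):
    """Bilinear (Tustin) transform for diagonal A."""
    denom = 1.0 - delta / 2.0 * A_diag
    dA = (1.0 + delta / 2.0 * A_diag) / denom
    dB = delta * B / denom
    return dA, dB

A = np.array([-1.0, -2.0])   # stable (negative) eigenvalues
B = np.array([1.0, 1.0])
print("ZOH:     ", zoh(A, B, 0.1))
print("bilinear:", bilinear(A, B, 0.1))
```

For small Δ the two discretizations agree closely; they diverge as Δ grows.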
After discretization, the system becomes:
x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]
This form is a Linear Recurrence, updating the state at each timestep.
SSM as a Convolution Kernel
The powerful property of SSMs is that they can be computed as convolutions during training. Starting with initial state x[0] = 0:
y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...
Expressing this as convolution kernel K:
K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u (convolution)
This kernel K can be computed in parallel, and using FFT, it can be computed in O(n log n).
This is the fundamental duality of SSMs:
- Inference: Recurrence with O(1) memory
- Training: Convolution with O(n log n) parallel computation
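This duality can be verified numerically: running the recurrence step by step and convolving the input with the materialized kernel K produce identical outputs. Below is a toy sketch with a small random system (it materializes K naively rather than using the O(n log n) FFT path):

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16
A_bar = np.diag(rng.uniform(0.1, 0.9, N))   # discretized diagonal A
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

# Recurrent view: x[t] = A_bar x[t-1] + B_bar u[t], y[t] = C x[t]
x = np.zeros((N, 1))
y_rec = []
for t in range(L):
    x = A_bar @ x + B_bar * u[t]
    y_rec.append((C @ x).item())

# Convolutional view: K[k] = C A_bar^k B_bar, y = K * u (causal convolution)
K = [(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item() for k in range(L)]
y_conv = [sum(K[k] * u[t - k] for k in range(t + 1)) for t in range(L)]

print(np.allclose(y_rec, y_conv))  # True
```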
3. S4 (Structured State Space Sequence Model)
S4 was published by Albert Gu et al. in 2021 and is the first important work that practically applied SSMs to deep learning.
HiPPO Matrix Initialization
One of S4's key contributions is the HiPPO (High-order Polynomial Projection Operators) matrix initialization. Simply initializing A randomly leads to gradient vanishing problems. HiPPO provides a specially designed A matrix that approximates past inputs using polynomials.
HiPPO-LegS (Legendre polynomial-based):
A[n,k] = -sqrt((2n+1)(2k+1)) if n > k
A[n,k] = -(n+1) if n == k
A[n,k] = 0 if n < k
This initialization ensures that state x[t] maintains an optimal polynomial approximation of all past inputs u[0..t].
Structured Matrix A (DPLR)
To efficiently compute the convolution kernel of the A matrix, S4 expresses A in Diagonal Plus Low-Rank (DPLR) form:
A = Λ - P·Q^T
Where Λ is a diagonal matrix and P, Q are low-rank vectors. Using this structure, computing the length-L convolution kernel drops from naive O(N²L) to roughly O(N + L) operations (up to log factors).
Efficient Computation
import torch
import torch.nn as nn
import numpy as np
class S4Layer(nn.Module):
"""Simplified implementation of S4 layer"""
def __init__(self, d_model, d_state=64, dropout=0.0):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# HiPPO-LegS initialization
A = self._make_hippo(d_state)
# DPLR decomposition
self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
self.log_delta = nn.Parameter(torch.zeros(d_model))
self.D = nn.Parameter(torch.ones(d_model))
def _make_hippo(self, N):
"""Generate HiPPO-LegS matrix"""
P = np.sqrt(1 + 2 * np.arange(N))
A = P[:, np.newaxis] * P[np.newaxis, :]
A = np.tril(A) - np.diag(np.arange(N))
return -A
    def discretize(self, A, B, C, delta):
        """ZOH discretization (batched over the d_model channels)"""
        # delta: (d_model,) -> (d_model, 1, 1) so exp(delta * A) broadcasts per channel
        dA = torch.matrix_exp(delta.view(-1, 1, 1) * A)   # (d_model, N, N)
        I = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        dB = torch.linalg.solve(A, (dA - I) @ B)          # (d_model, N, 1)
        return dA, dB, C
def forward(self, u):
# u: (batch, seq_len, d_model)
delta = torch.exp(self.log_delta)
# Discretization and convolution computation
# Actual implementation is more complex, simplified here for concept
return u
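The `_make_hippo` construction above can be checked against the piecewise HiPPO-LegS definition from the previous subsection. This standalone NumPy check uses illustrative helper names (`make_hippo`, `hippo_piecewise`) and simply confirms the two formulations agree:

```python
import numpy as np

def make_hippo(N):
    # Vectorized construction, same as in the S4Layer above
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A

def hippo_piecewise(N):
    # Direct transcription of the piecewise HiPPO-LegS formula
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = -(n + 1)
    return A

print(np.allclose(make_hippo(8), hippo_piecewise(8)))  # True
```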
4. H3 (Hungry Hungry Hippos)
Published in 2022, H3 is an improvement of S4 better suited for language modeling. The title is a humorous expansion of the HiPPO acronym.
Differences from S4
S4 captures long-range dependencies well, but lacks the token-to-token interactions important in language modeling. Attention directly computes "how related is this word to that word?", but S4 finds this difficult to capture with pure recurrence.
Gating Mechanism
H3 uses two SSMs and adds gating between them:
class H3Layer(nn.Module):
"""H3 Layer"""
def __init__(self, d_model, d_state=64):
super().__init__()
# Two SSMs (shift SSM + diagonal SSM)
self.shift_ssm = S4Layer(d_model, d_state)
self.diag_ssm = S4Layer(d_model, d_state)
# Projections
self.Q_proj = nn.Linear(d_model, d_model)
self.K_proj = nn.Linear(d_model, d_model)
self.V_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
# Q, K, V projections
q = self.Q_proj(x)
k = self.K_proj(x)
v = self.V_proj(x)
# SSM 1: Apply Shift SSM to K
k_ssm = self.shift_ssm(k)
# Multiplicative gating: Q element-wise K_ssm
gated = q * k_ssm
# SSM 2: Apply Diagonal SSM to V, then multiply with gated
v_ssm = self.diag_ssm(v)
output = gated * v_ssm
return self.out_proj(output)
Language Modeling Improvements
H3 achieved performance close to Transformers in GPT-2 sized language models while showing significantly faster inference speed. This was an important milestone demonstrating the practicality of SSMs for language modeling.
5. Mamba (Selective State Space Model)
Mamba, published by Albert Gu and Tri Dao in December 2023, represents the most important advance in SSM research to date. Its core innovation is the Selective Mechanism.
Core Innovation: The Selective Mechanism
The fundamental limitation of S4 and H3 was that the matrices A, B, C were input-independent: convolution-based training requires a fixed kernel, so the parameters could not adapt to the content of the sequence.
Mamba breaks this constraint and introduces S6 (Selective SSM): making B, C, and Delta matrices functions of input u(t).
B(t) = Linear_B(u(t)) # B that varies with input
C(t) = Linear_C(u(t)) # C that varies with input
Delta(t) = softplus(Linear_Delta(u(t))) # Delta that varies with input
This allows the model to dynamically decide, based on input, which information to store in the state and which to discard.
Intuitively:
- Large Delta → current input strongly influences state (remember information)
- Small Delta → state changes little (ignore information)
This plays a role similar to LSTM's gates (input gate, forget gate), but within the continuous-time SSM framework.
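The effect of Δ is easy to see with a scalar state (a toy illustration, not Mamba's actual code path): with A < 0, the ZOH-discretized gate Ā = exp(Δ·A) shrinks toward 0 as Δ grows, so a large Δ discards the old state and admits the input, while a small Δ leaves the state almost untouched.

```python
import math

A = -1.0  # fixed negative eigenvalue (stability)
B = 1.0
for delta in (0.01, 0.1, 1.0, 10.0):
    A_bar = math.exp(delta * A)        # weight kept on the previous state
    B_bar = (A_bar - 1.0) / A * B      # ZOH weight on the current input
    print(f"delta={delta:5}: keep {A_bar:.4f} of state, admit {B_bar:.4f} of input")
```

At Δ = 0.01 the state is almost frozen (keep ≈ 0.99), while at Δ = 10 the old state is essentially erased and the input passes through at full strength, mirroring the input/forget-gate analogy above.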
Hardware-Aware Algorithm
Because the selective mechanism makes B and C depend on the input, the kernel is no longer fixed and convolution can no longer be used; a naive per-timestep recurrence would be severely inefficient on GPUs.
Mamba solves this with a Hardware-Aware Algorithm — a kernel fusion technique leveraging GPU memory hierarchy (HBM vs SRAM):
Problem: Storing intermediate states in HBM (slow memory) during per-timestep recurrence creates memory bandwidth bottlenecks.
Solution:
- Perform all intermediate computations in fast SRAM (on-chip memory)
- Use Parallel Scan algorithm instead of full convolution
- Store only the final output in HBM
# Parallel Scan concept (the actual implementation is a fused CUDA kernel)
import math

def parallel_scan(gates, tokens):
    """
    Compute the linear recurrence in parallel:
        x[t] = gates[t] * x[t-1] + tokens[t]
    A binary-tree structure achieves O(log n) parallel depth.
    Assumes len(tokens) is a power of two.
    """
    n = len(tokens)
    log_n = int(math.log2(n))
    # Up-sweep phase: fold step j = i - 2**d into step i.
    # Composing (g_j, t_j) followed by (g_i, t_i) yields
    # gate = g_i * g_j and token = g_i * t_j + t_i,
    # so tokens must be updated before gates are overwritten.
    for d in range(log_n):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            tokens[i] = gates[i] * tokens[i - 2 ** d] + tokens[i]
            gates[i] = gates[i] * gates[i - 2 ** d]
    # Down-sweep phase (omitted)
    return tokens
Mamba Block Structure
Input x (B, L, D)
│
├──────────────────────┐
│ │
Linear(D→ED) Linear(D→ED)
+ SiLU activation │
│ SSM (S6)
│ │
└─────── ⊙ ────────────┘
(element-wise multiply)
│
Linear(ED→D)
│
Output y (B, L, D)
Where E is the expansion ratio (usually 2) and D is the model dimension.
Complete Mamba Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class MambaBlock(nn.Module):
"""
Mamba Block implementation
Paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
"""
def __init__(
self,
d_model, # Model dimension D
d_state=16, # SSM state dimension N
d_conv=4, # Convolution kernel size
expand=2, # Expansion ratio E
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
bias=False,
):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
if dt_rank == "auto":
self.dt_rank = max(1, int(d_model / 16))
else:
self.dt_rank = dt_rank
# Input projection (D → 2*ED, compute both branches at once)
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)
# Local convolution (depthwise conv)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=True,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1, # causal padding
)
# SSM parameter projections
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A initialization (S4D-real style: A = -(1, 2, ..., N) per channel, stored as log)
A = repeat(
torch.arange(1, self.d_state + 1),
"n -> d n",
d=self.d_inner
)
self.A_log = nn.Parameter(torch.log(A))
self.A_log._no_weight_decay = True
# D (skip connection)
self.D = nn.Parameter(torch.ones(self.d_inner))
self.D._no_weight_decay = True
# Output projection
self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: (B, L, D)
"""
batch, seqlen, dim = hidden_states.shape
# Input projection
xz = self.in_proj(hidden_states)
x, z = xz.chunk(2, dim=-1) # each (B, L, ED)
# Convolution (causal 1D conv)
x = rearrange(x, "b l d -> b d l")
x = self.conv1d(x)[:, :, :seqlen] # causal trimming
x = rearrange(x, "b d l -> b l d")
x = F.silu(x)
# SSM
y = self.ssm(x)
# Gating (element-wise multiply with z branch)
y = y * F.silu(z)
# Output projection
output = self.out_proj(y)
return output
def ssm(self, x):
"""Selective SSM (S6) computation"""
d_in, n = self.d_inner, self.d_state
        # A matrix (A = -exp(A_log), guaranteed negative = stability)
A = -torch.exp(self.A_log.float()) # (ED, N)
# x_proj: compute Delta, B, C
x_dbl = self.x_proj(x) # (B, L, dt_rank + 2N)
delta, B, C = x_dbl.split(
[self.dt_rank, n, n], dim=-1
)
# Delta: softplus ensures positive values
delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
# ZOH discretization and selective scan
y = self.selective_scan(x, delta, A, B, C, self.D)
return y
def selective_scan(self, u, delta, A, B, C, D):
"""
Selective scan algorithm
In practice, uses CUDA kernels from mamba_ssm
Here shown in pure PyTorch for conceptual understanding
"""
b, l, d_in = u.shape
n = A.shape[1]
# Discretization
# delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
deltaA = torch.exp(
torch.einsum("bld,dn->bldn", delta, A)
)
# delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
# → deltaB_u: (B, L, ED, N)
deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)
# Recurrent scan
x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
ys.append(y)
y = torch.stack(ys, dim=1) # (B, L, ED)
# D (skip connection)
y = y + u * D
return y
class MambaModel(nn.Module):
"""Sequence model composed of stacked Mamba blocks"""
def __init__(self, d_model, n_layers, d_state=16, expand=2):
super().__init__()
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state=d_state, expand=expand)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
"""x: (B, L, D)"""
for layer in self.layers:
x = x + layer(self.norm(x)) # Pre-norm residual
return x
6. Mamba 2
In May 2024, Tri Dao and Albert Gu published Mamba 2, further strengthening the theoretical foundations.
Differences from Mamba 1
Mamba 1's selective scan required sequential computation at each timestep. Mamba 2 discovered State Space Duality (SSD) to make this even more efficient.
State Space Duality (SSD)
The core insight of Mamba 2: SSMs with a specific structure can be expressed as Semi-Separable Matrices, which are mathematically equivalent to a specific form of attention.
This mathematical duality enables:
- Expressing SSM computation as matrix multiplication form
- Leveraging highly optimized BLAS libraries
- Maximizing GPU efficiency with Tensor Core utilization
# SSD operation concept
# SSM recurrence: x[t] = A[t] * x[t-1] + B[t] * u[t]
# y[t] = C[t]^T * x[t]
#
# Process in chunks:
# - Within chunk: parallel computation via matrix multiplication
# - Between chunks: recurrence for state propagation
class Mamba2Block(nn.Module):
"""Mamba 2 Block"""
def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.n_heads = n_heads
self.chunk_size = chunk_size
self.d_head = d_model // n_heads
# Multi-head structure
self.norm = nn.LayerNorm(d_model)
self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
self.out_proj = nn.Linear(d_model, d_model)
# A parameters (per head)
self.A_log = nn.Parameter(torch.randn(n_heads))
def forward(self, x):
"""x: (B, L, D)"""
# Actual implementation uses CUDA kernels from mamba_ssm package
return x
Multi-Head Structure
Mamba 2 introduces multi-head SSM, making it structurally similar to Transformer's Multi-Head Attention. This provides:
- Better expressiveness
- Theoretical connection to attention
- Easier hybrid architecture design
Integration Potential with Transformers
SSD theory shows that certain SSMs are mathematically equivalent to certain attention mechanisms. This provides the theoretical basis for hybrid architectures mixing Mamba and Transformers.
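The equivalence can be illustrated for a scalar-state SSM: materializing the lower-triangular matrix M[i, j] = C[i] · (A[j+1]···A[i]) · B[j] and multiplying it by u reproduces the scan output exactly, i.e. the SSM acts like a masked, attention-style matrix multiply. This is a toy sketch under a scalar-state assumption, not the chunked SSD algorithm itself:

```python
import numpy as np

rng = np.random.default_rng(1)
L = 8
A = rng.uniform(0.5, 0.95, L)     # per-step scalar gates (selective A_bar[t])
B = rng.standard_normal(L)
C = rng.standard_normal(L)
u = rng.standard_normal(L)

# Recurrent (scan) form: x[t] = A[t] x[t-1] + B[t] u[t]; y[t] = C[t] x[t]
x, y_scan = 0.0, []
for t in range(L):
    x = A[t] * x + B[t] * u[t]
    y_scan.append(C[t] * x)

# Matrix (attention-like) form: y = M u with a 1-semiseparable lower-triangular M
M = np.zeros((L, L))
for i in range(L):
    for j in range(i + 1):
        # np.prod over an empty slice is 1.0, handling the j == i diagonal
        M[i, j] = C[i] * np.prod(A[j + 1:i + 1]) * B[j]

print(np.allclose(y_scan, M @ u))  # True
```

SSD's chunked algorithm exploits exactly this: dense matrix multiplies (Tensor-Core friendly) within chunks, plus a cheap recurrence to carry state between chunks.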
7. Hybrid Architectures
MambaFormer
MambaFormer interleaves Mamba blocks and Transformer blocks:
Layer 1: Mamba Block (capture local patterns)
Layer 2: Attention (capture global dependencies)
Layer 3: Mamba Block
Layer 4: Attention
...
This structure leverages the strengths of each component:
- Mamba: efficient sequence processing, local patterns
- Attention: selective information retrieval, global dependencies
Jamba (SSM + Transformer + MoE)
Jamba, released by AI21 Labs in 2024, is a more complex hybrid:
Jamba = Mamba + Transformer + Mixture of Experts
Architecture:
- 52B parameters (active: 12B)
- Layer ratio: Attention 1 : Mamba 7
- MoE applied to some layers
- Supports 256K context window
Compared to same-size Transformers, Jamba achieves:
- 3x improvement in inference throughput
- Dramatically improved memory efficiency for long contexts
# Jamba-style hybrid block (concept)
class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8):
        super().__init__()
        self.use_attention = (layer_idx % attn_every_n == 0)
        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
        else:
            self.mixer = MambaBlock(d_model)
        # MoE (only for some layers); MixtureOfExperts is a placeholder, not defined here
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model)
            )
    def forward(self, x):
        # nn.MultiheadAttention takes (query, key, value) and returns (output, weights)
        if self.use_attention:
            mixed, _ = self.mixer(x, x, x, need_weights=False)
        else:
            mixed = self.mixer(x)
        x = x + mixed
        x = x + self.ffn(x)
        return x
RWKV (Linear Attention)
RWKV, short for "Receptance Weighted Key Value", is a hybrid of Transformer and RNN:
- Training: Parallelized like Transformer (matrix form)
- Inference: Recurrent like RNN (O(1) state)
- Token mixing without attention
Latest versions (RWKV-6, RWKV-7) show performance competitive with Mamba.
RetNet (Retentive Network)
Microsoft's RetNet aims to simultaneously achieve "Training Parallelism, Inference Efficiency, Competitive Performance":
Three computational paradigms:
1. Parallel: O(n^2) parallel computation during training (lower constant than Transformer)
2. Recurrent: O(1) memory during inference
3. Chunkwise Recurrent: balanced intermediate approach
8. Mamba Performance Comparison
Inference Cost Scaling (Linear vs Quadratic)
Mamba's greatest strength is its linear scaling with sequence length. The table below shows relative inference cost (higher = slower), normalized so that a 2K-token sequence costs 1.0x; figures are illustrative:
Sequence Length    Transformer    Mamba
1K                 0.5x           0.5x
2K                 1.0x           1.0x
4K                 3.5x           2.0x    (1.75x faster than Transformer)
8K                 13x            4.0x    (3.25x faster)
16K                50x            8.0x    (6.25x faster)
100K               ~1800x         ~50x    (~36x faster)
Memory Efficiency
State size comparison during inference:
Model          Inference Memory (growth with sequence length n)
Transformer    O(n) KV cache (grows with every generated token)
Mamba          O(1) SSM state (fixed size)
Example: 130M parameter model, 1M token sequence
- Transformer: ~16GB KV Cache
- Mamba: ~1MB state (constant!)
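The gap can be estimated with back-of-the-envelope formulas. This is a rough sketch: the factor 2 accounts for K and V, fp16 (2 bytes) is assumed, the Transformer shape is hypothetical, and the small Mamba convolution buffer is ignored; actual figures depend on the exact model configuration.

```python
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per=2):
    # K and V tensors, per layer, per head, per token
    return 2 * n_layers * n_heads * head_dim * seq_len * bytes_per

def mamba_state_bytes(n_layers, d_inner, d_state, bytes_per=2):
    # One fixed (d_inner x d_state) SSM state per layer, independent of seq_len
    return n_layers * d_inner * d_state * bytes_per

# Hypothetical small-model shapes for illustration
print(f"KV cache @1M tokens: {kv_cache_bytes(1_000_000, 12, 12, 64) / 2**30:.1f} GiB")
print(f"Mamba state:         {mamba_state_bytes(24, 1536, 16) / 2**20:.2f} MiB")
```

The key point is structural: the first quantity is proportional to `seq_len`, the second is a constant, so the ratio grows without bound as the context lengthens.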
Long Sequence Tasks
Long Range Arena (LRA) benchmark (sequence lengths 1K-16K):
Model ListOps Text Retrieval Image Path-X Avg
Transformer 36.4 65.0 57.5 42.4 0.0 40.3
LSTM 35.9 63.7 65.0 43.3 0.0 41.6
S4 59.6 86.8 90.9 88.7 86.1 82.4
Mamba achieves roughly S4-level accuracy on LRA while processing sequences faster.
Mamba particularly excels over S4 on synthetic benchmarks like Selective Copying and Induction Heads, which are more relevant to real language modeling ability.
9. Applications of Mamba
Natural Language Processing
Mamba-based language models are emerging rapidly:
- MambaChat: Conversational AI assistant
- Falcon Mamba: Open-source 7B Mamba model released by TII
- CodeMamba: Specialized for code generation
Mamba is efficient for long document processing, summarization, and translation compared to Transformers.
Bioinformatics
Genomic sequence analysis deals with very long sequences (millions of base pairs), making Mamba particularly advantageous:
- Caduceus: Long DNA sequence modeling
- HyenaDNA: long-range genomic sequence modeling (a related long-convolution approach)
- Potential applications in protein structure prediction
# Example of Mamba for biological sequence modeling
from mamba_ssm import MambaLMHeadModel
import torch
# DNA sequence modeling (A, T, G, C, N tokens)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)
# Process very long sequences efficiently
model = MambaLMHeadModel.from_pretrained(
"state-spaces/mamba-130m",
device="cuda",
dtype=torch.float16
)
Time Series Analysis
For financial, meteorological, IoT sensor data, and other long time series, Mamba shows its strengths:
- TimeMamba: Long time series forecasting
- MambaMixer: Multivariate time series modeling
- Advancements in S4/S5-based time series models
Image Processing (VMamba)
VMamba extends Mamba to 2D image processing:
# VMamba core: 2D selective scan
# Scan image in 4 directions to capture 2D structure
# Directions:
# 1. Left→Right, Top→Bottom (standard raster scan)
# 2. Right→Left, Bottom→Top (reverse)
# 3. Top→Bottom, Left→Right (column-first)
# 4. Bottom→Top, Right→Left (reverse)
class VMambaBlock(nn.Module):
"""VMamba: Visual Mamba Block"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.norm = nn.LayerNorm(d_model)
# 4-directional SSMs
self.ssms = nn.ModuleList([
MambaBlock(d_model, d_state) for _ in range(4)
])
self.out_proj = nn.Linear(d_model * 4, d_model)
    def forward(self, x):
        """x: (B, H, W, D) - image patch embeddings"""
        b, h, w, d = x.shape
        x_flat = x.view(b, h * w, d)                        # row-major order
        x_col = x.permute(0, 2, 1, 3).reshape(b, h * w, d)  # column-major order
        outputs = []
        # 4-directional scan
        for i, ssm in enumerate(self.ssms):
            if i == 0:      # row-major, forward
                seq = x_flat
            elif i == 1:    # row-major, reverse
                seq = x_flat.flip(1)
            elif i == 2:    # column-major, forward
                seq = x_col
            else:           # column-major, reverse
                seq = x_col.flip(1)
            out = ssm(seq)
            if i % 2 == 1:  # undo the reversal
                out = out.flip(1)
            if i >= 2:      # map column-major order back to row-major
                out = out.view(b, w, h, d).permute(0, 2, 1, 3).reshape(b, h * w, d)
            outputs.append(out)
        # Combine the 4 directional results
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)
10. Practical Usage
Installing mamba-ssm Package
# Required packages
pip install torch torchvision torchaudio
pip install "causal-conv1d>=1.2.0"
pip install mamba-ssm
# Or install from source (latest features)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"
Using MambaLMHeadModel
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer
# Load pre-trained model
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
model_name,
device="cuda",
dtype=torch.bfloat16
)
model.eval()
# Text generation
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        # mamba_ssm's generate takes max_length (prompt + new tokens);
        # sampling is controlled via temperature/top_p rather than a do_sample flag
        output = model.generate(
            input_ids=input_ids,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        )
    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
# Example
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)
Fine-tuning Example
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
class TextDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=512):
self.encodings = tokenizer(
texts,
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt"
)
def __len__(self):
return len(self.encodings['input_ids'])
def __getitem__(self, idx):
return {
'input_ids': self.encodings['input_ids'][idx],
'attention_mask': self.encodings['attention_mask'][idx],
}
def finetune_mamba(
model_name="state-spaces/mamba-130m",
texts=None,
num_epochs=3,
learning_rate=1e-4,
batch_size=8,
):
"""Fine-tune Mamba model"""
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token
# Load model
model = MambaLMHeadModel.from_pretrained(
model_name,
device=device,
dtype=torch.bfloat16
)
# Dataset
dataset = TextDataset(texts, tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Optimizer (AdamW recommended for Mamba)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=0.1
)
# Scheduler
total_steps = len(dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=total_steps // 10,
num_training_steps=total_steps
)
# Training loop
model.train()
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
input_ids = batch['input_ids'].to(device)
# Language modeling: next token prediction
outputs = model(input_ids)
logits = outputs.logits
# Shift for next token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
if batch_idx % 100 == 0:
print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1} complete: Avg Loss = {avg_loss:.4f}")
return model
Custom Mamba Model Configuration
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Custom configuration: per-layer SSM hyperparameters go in ssm_cfg
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    ssm_cfg=dict(
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
    ),
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# Create model (roughly 370M parameters at this shape, matching mamba-370m)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")
Conclusion: The Future of Mamba
Mamba has brought revolutionary changes to the field of deep learning sequence modeling. Summarizing the key contributions:
- Selective Mechanism: SSM parameters that dynamically change based on input
- Hardware-Aware Design: Efficient computation leveraging GPU memory hierarchy
- Dual Representation: Parallelized during training, recurrent during inference
- Linear Complexity: Linear computation and memory complexity with sequence length
There are still challenges to address:
- Somewhat weaker in-context learning compared to Transformers
- Needs validation at very large scales
- Clear demonstration of superiority over attention-based models
However, Mamba and SSM-family models will increasingly play an important role in areas where Transformers struggle — long sequence processing, real-time inference, and edge device deployment.
References
- Mamba paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023) — https://arxiv.org/abs/2312.00752
- Mamba 2 paper: "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Dao & Gu, 2024) — https://arxiv.org/abs/2405.21060
- S4 paper: "Efficiently Modeling Long Sequences with Structured State Spaces" (Gu et al., 2021) — https://arxiv.org/abs/2111.00396
- H3 paper: "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (Fu et al., 2022) — https://arxiv.org/abs/2208.04933
- Mamba GitHub: https://github.com/state-spaces/mamba
- Jamba paper: "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21 Labs, 2024)