
Mamba and State Space Models Complete Guide: Beyond Transformers

Overview

Since the landmark "Attention is All You Need" paper in 2017, Transformers have dominated virtually every sequence modeling task — from natural language processing to images, code, and audio. However, Transformers have a fundamental weakness: the quadratic complexity O(n^2) of self-attention.

In December 2023, Mamba, published by Albert Gu and Tri Dao, elegantly solved this problem and sent shockwaves through the deep learning community. Mamba adapts State Space Models (SSMs) — a concept from classical control theory — to modern deep learning, enabling linear-time processing of long sequences.

This guide covers everything from the mathematical foundations of SSMs to Mamba's core innovations and practical implementations.


1. The Limitations of Transformers and the Rise of SSMs

The O(n^2) Complexity Problem of Self-Attention

The core of Transformers is the Self-Attention mechanism. For a sequence of length n, computing the attention matrix requires computing relationships between all pairs of tokens, resulting in both time and memory complexity of O(n^2).

Attention matrix size: n × n
= 1,000 tokens   → 1,000,000 elements
= 10,000 tokens  → 100,000,000 elements
= 100,000 tokens → 10,000,000,000 elements (infeasible!)

Due to this problem, real-world LLMs face constraints on context window size. GPT-4 supports 128K tokens, but this requires enormous computational cost and optimization techniques.

The Difficulty of Processing Long Sequences

Real applications — such as long document summarization, understanding entire codebases, and maintaining long-term conversations — require very long context. But Transformer's quadratic complexity makes practical use challenging.

For example, processing an entire book (approximately 100,000 words = 130,000 tokens) at once would require memory that is virtually infeasible on current hardware with standard Transformers.

The Limitations of Recurrent Models (RNN, LSTM)

Recurrent neural networks (RNNs) and LSTMs theoretically have O(n) complexity because they maintain a fixed-size hidden state at each timestep. In practice, however, they suffer from the following problems:

  • Vanishing/Exploding Gradients: Gradients disappear or explode during backpropagation on long sequences
  • Non-parallelizable: Each timestep depends on the previous state, making GPU parallelization difficult
  • Difficulty capturing long-range dependencies: Hard to connect information far apart in the sequence

The Solution SSMs Offer

State Space Models (SSMs) combine the advantages of both approaches:

  • During training: Full parallelization via convolution → training efficiency comparable to Transformers
  • During inference: Recurrence with O(1) memory and O(n) computation → efficient inference like RNNs
  • Long sequences: Linear complexity enables processing of very long sequences

2. Mathematical Foundations of State Space Models

Continuous-Time SSM

The origins of SSMs trace back to Rudolf Kalman's control theory in the 1960s. A continuous-time linear system is expressed as:

x'(t) = A·x(t) + B·u(t)
y(t)  = C·x(t) + D·u(t)

Where:

  • u(t): input signal
  • x(t): state vector (latent state), size N
  • y(t): output signal
  • A: N×N state transition matrix
  • B: N×1 input projection matrix
  • C: 1×N output projection matrix
  • D: direct feedthrough term (skip connection, usually set to 0)

This system receives input u(t), updates the internal state x(t), and produces output y(t). The state x(t) can be thought of as a "summary" of past information.

Discretization

To use continuous-time SSMs on digital computers, discretization is required. With a sampling interval Delta, there are two main discretization methods.

Zero-Order Hold (ZOH):

A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B

Bilinear (Tustin) Transform:

A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B

After discretization, the system becomes:

x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]

This form is a Linear Recurrence, updating the state at each timestep.
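As a concrete sketch of ZOH discretization followed by the recurrence (a toy system with illustrative sizes; assumes PyTorch):

```python
import torch

torch.manual_seed(0)
N = 4                                         # state size
A = -torch.eye(N) - 0.1 * torch.rand(N, N)    # eigenvalues have negative real part
B = torch.rand(N, 1)
C = torch.rand(1, N)
delta = 0.1                                   # sampling interval

# Zero-Order Hold: A_bar = exp(delta*A); the deltas in the B_bar formula cancel,
# leaving B_bar = A^(-1) (A_bar - I) B
A_bar = torch.matrix_exp(delta * A)
B_bar = torch.linalg.solve(A, (A_bar - torch.eye(N)) @ B)

# Discrete linear recurrence: x[t] = A_bar x[t-1] + B_bar u[t],  y[t] = C x[t]
u = torch.randn(10)
x = torch.zeros(N, 1)
ys = []
for t in range(len(u)):
    x = A_bar @ x + B_bar * u[t]
    ys.append((C @ x).item())
```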

SSM as a Convolution Kernel

The powerful property of SSMs is that they can be computed as convolutions during training. Starting with initial state x[0] = 0:

y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...

Expressing this as convolution kernel K:

K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u  (convolution)

This kernel K can be computed in parallel, and using FFT, it can be computed in O(n log n).

This is the fundamental duality of SSMs:

  • Inference: Recurrence with O(1) memory
  • Training: Convolution with O(n log n) parallel computation
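The duality is easy to verify numerically: running a small pre-discretized system as a recurrence and as a convolution with kernel K produces identical outputs. A sketch (names and sizes are illustrative):

```python
import torch

torch.manual_seed(0)
N, L = 4, 16
# Pre-discretized system (any A_bar with spectral radius < 1 works here)
A_bar = 0.9 * torch.eye(N) + 0.01 * torch.randn(N, N)
B_bar = torch.randn(N, 1)
C = torch.randn(1, N)
u = torch.randn(L)

# View 1: recurrence (inference mode)
x = torch.zeros(N, 1)
y_rec = []
for t in range(L):
    x = A_bar @ x + B_bar * u[t]
    y_rec.append((C @ x).item())
y_rec = torch.tensor(y_rec)

# View 2: convolution with K = (C B_bar, C A_bar B_bar, C A_bar^2 B_bar, ...)
K = []
Ak_B = B_bar.clone()
for _ in range(L):
    K.append((C @ Ak_B).item())
    Ak_B = A_bar @ Ak_B

# y[t] = sum_{s<=t} K[t-s] * u[s]
y_conv = torch.tensor([
    sum(K[t - s] * u[s].item() for s in range(t + 1)) for t in range(L)
])

assert torch.allclose(y_rec, y_conv, atol=1e-3)
```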

3. S4 (Structured State Space Sequence Model)

S4 was published by Albert Gu et al. in 2021 and is the first important work that practically applied SSMs to deep learning.

HiPPO Matrix Initialization

One of S4's key contributions is the HiPPO (High-order Polynomial Projection Operators) matrix initialization. Simply initializing A randomly leads to gradient vanishing problems. HiPPO provides a specially designed A matrix that approximates past inputs using polynomials.

HiPPO-LegS (Legendre polynomial-based):

A[n,k] = -sqrt((2n+1)(2k+1))  if n > k
A[n,k] = -(n+1)               if n == k
A[n,k] = 0                    if n < k

This initialization ensures that state x[t] maintains an optimal polynomial approximation of all past inputs u[0..t].
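The piecewise definition translates directly into code; a small NumPy sketch with spot-checks against the formula:

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS matrix per the piecewise formula above."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = -np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = -(n + 1)
            # n < k stays 0 (lower-triangular structure)
    return A

A = hippo_legs(4)
assert A[2, 2] == -3                      # diagonal: -(n+1)
assert A[0, 3] == 0                       # strict upper triangle: 0
assert np.isclose(A[3, 1], -np.sqrt(7 * 3))  # n > k: -sqrt((2n+1)(2k+1))
```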

Structured Matrix A (DPLR)

To efficiently compute the convolution kernel of the A matrix, S4 expresses A in Diagonal Plus Low-Rank (DPLR) form:

A = Λ - P·Q^T

Where Λ is a diagonal matrix and P, Q are low-rank vectors. Using this structure, together with the Woodbury identity and Cauchy kernel computations, kernel computation can be reduced to near-linear time, roughly Õ(N + L).

Efficient Computation

import torch
import torch.nn as nn
import numpy as np

class S4Layer(nn.Module):
    """Simplified implementation of S4 layer"""
    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # HiPPO-LegS initialization
        A = self._make_hippo(d_state)
        # DPLR decomposition
        self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
        self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
        self.log_delta = nn.Parameter(torch.zeros(d_model))
        self.D = nn.Parameter(torch.ones(d_model))

    def _make_hippo(self, N):
        """Generate HiPPO-LegS matrix"""
        P = np.sqrt(1 + 2 * np.arange(N))
        A = P[:, np.newaxis] * P[np.newaxis, :]
        A = np.tril(A) - np.diag(np.arange(N))
        return -A

    def discretize(self, A, B, delta):
        """ZOH discretization (batched over per-channel delta)"""
        # delta: (d_model,) -> one (N, N) transition matrix per channel
        dA = torch.matrix_exp(delta.view(-1, 1, 1) * A)               # (d_model, N, N)
        dB = torch.linalg.solve(A, (dA - torch.eye(A.shape[0])) @ B)  # (d_model, N, 1)
        return dA, dB

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        delta = torch.exp(self.log_delta)
        # Discretization and convolution computation
        # Actual implementation is more complex, simplified here for concept
        return u

4. H3 (Hungry Hungry Hippos)

Published in 2022, H3 is an improvement of S4 better suited for language modeling. The title is a humorous expansion of the HiPPO acronym.

Differences from S4

S4 captures long-range dependencies well, but lacks the token-to-token interactions important in language modeling. Attention directly computes "how related is this word to that word?", something S4's time-invariant recurrence struggles to express.

Gating Mechanism

H3 uses two SSMs and adds gating between them:

class H3Layer(nn.Module):
    """H3 Layer"""
    def __init__(self, d_model, d_state=64):
        super().__init__()
        # Two SSMs (shift SSM + diagonal SSM)
        self.shift_ssm = S4Layer(d_model, d_state)
        self.diag_ssm = S4Layer(d_model, d_state)

        # Projections
        self.Q_proj = nn.Linear(d_model, d_model)
        self.K_proj = nn.Linear(d_model, d_model)
        self.V_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Q, K, V projections
        q = self.Q_proj(x)
        k = self.K_proj(x)
        v = self.V_proj(x)

        # SSM 1: Apply Shift SSM to K
        k_ssm = self.shift_ssm(k)

        # Multiplicative gating: Q element-wise K_ssm
        gated = q * k_ssm

        # SSM 2: Apply Diagonal SSM to V, then multiply with gated
        v_ssm = self.diag_ssm(v)
        output = gated * v_ssm

        return self.out_proj(output)

Language Modeling Improvements

H3 achieved performance close to Transformers in GPT-2 sized language models while showing significantly faster inference speed. This was an important milestone demonstrating the practicality of SSMs for language modeling.


5. Mamba (Selective State Space Model)

Mamba, published by Albert Gu and Tri Dao in December 2023, represents the most important advance in SSM research to date. Its core innovation is the Selective Mechanism.

Core Innovation: The Selective Mechanism

The fundamental limitation of S4 and H3 was that A, B, and C were input-independent (time-invariant): the convolution kernel had to be fixed for convolution-based training to work.

Mamba breaks this constraint with S6 (Selective SSM): it makes B, C, and the step size Delta functions of the input u(t).

B(t) = Linear_B(u(t))   # B that varies with input
C(t) = Linear_C(u(t))   # C that varies with input
Delta(t) = softplus(Linear_Delta(u(t)))  # Delta that varies with input

This allows the model to dynamically decide, based on input, which information to store in the state and which to discard.

Intuitively:

  • Large Delta → current input strongly influences state (remember information)
  • Small Delta → state changes little (ignore information)

This plays a role similar to LSTM's gates (input gate, forget gate), but within the continuous-time SSM framework.
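This gating effect is easy to see on a toy scalar SSM (a = -1, b = 1, ZOH discretization as above; the values are illustrative):

```python
import math

def zoh_step(x_prev, u, delta, a=-1.0, b=1.0):
    """One ZOH-discretized update of the scalar SSM x'(t) = a*x(t) + b*u(t)."""
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b   # scalar form of A^(-1)(A_bar - I)B
    return a_bar * x_prev + b_bar * u

x_prev, u = 5.0, 1.0

x_small = zoh_step(x_prev, u, delta=0.01)   # small Delta: ignore the input
x_large = zoh_step(x_prev, u, delta=10.0)   # large Delta: reset toward the input

print(round(x_small, 2))  # 4.96  (state barely moved from 5.0)
print(round(x_large, 2))  # 1.0   (old state forgotten, state ≈ input)
```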

Hardware-Aware Algorithm

Introducing the selective mechanism means B, C depend on input, so convolution can no longer be used for computation. This can lead to severe computational inefficiency.

Mamba solves this with a Hardware-Aware Algorithm — a kernel fusion technique leveraging GPU memory hierarchy (HBM vs SRAM):

Problem: Storing intermediate states in HBM (slow memory) during per-timestep recurrence creates memory bandwidth bottlenecks.

Solution:

  • Perform all intermediate computations in fast SRAM (on-chip memory)
  • Use Parallel Scan algorithm instead of full convolution
  • Store only the final output in HBM

# Parallel Scan concept (actual implementation in CUDA)
import numpy as np

def parallel_scan(gates, tokens):
    """
    Compute the linear recurrence in parallel:
    x[t] = gates[t] * x[t-1] + tokens[t]

    Achieves O(log n) parallel depth with a binary tree structure
    (up-sweep of a Blelloch scan; n assumed to be a power of 2)
    """
    n = len(tokens)
    log_n = int(np.log2(n))
    # Up-sweep phase: combine adjacent segments with the associative
    # operator (gL, tL) o (gR, tR) = (gL*gR, gR*tL + tR)
    for d in range(log_n):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            # use gates[i] (the right segment's gate) BEFORE overwriting it
            tokens[i] = gates[i] * tokens[i - 2**d] + tokens[i]
            gates[i] = gates[i] * gates[i - 2**d]
    # Down-sweep phase (omitted)
    return tokens
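The up-sweep above computes only partial results (the down-sweep is omitted). A complete, if less work-efficient, alternative over the same associative operator is a Hillis-Steele inclusive scan; a pure-Python sketch, verified against the sequential recurrence:

```python
def scan_sequential(gates, tokens):
    """x[t] = gates[t] * x[t-1] + tokens[t], with x[-1] = 0."""
    x, out = 0.0, []
    for g, t in zip(gates, tokens):
        x = g * x + t
        out.append(x)
    return out

def scan_parallel(gates, tokens):
    """Hillis-Steele inclusive scan: O(log n) depth over the associative
    operator (gL, tL) o (gR, tR) = (gL*gR, gR*tL + tR)."""
    n = len(tokens)
    g, t = list(gates), list(tokens)
    offset = 1
    while offset < n:
        g_new, t_new = g[:], t[:]
        for i in range(offset, n):      # each step would be parallel on a GPU
            g_new[i] = g[i - offset] * g[i]
            t_new[i] = g[i] * t[i - offset] + t[i]
        g, t = g_new, t_new
        offset *= 2
    return t  # t[i] now holds x[i]

gates = [0.5, 0.9, 0.1, 0.8, 0.7, 0.3, 0.6, 0.2]
tokens = [1.0, -2.0, 0.5, 3.0, -1.0, 2.0, 0.0, 1.5]

seq = scan_sequential(gates, tokens)
par = scan_parallel(gates, tokens)
assert all(abs(a - b) < 1e-9 for a, b in zip(seq, par))
```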

Mamba Block Structure

Input x (B, L, D)
    ├───────────────────────────┐
    │                           │
  Linear(D→ED)              Linear(D→ED)
  Conv1d + SiLU               SiLU
  SSM (S6)                      │
    │                           │
    └──────────── ⊙ ────────────┘
       (element-wise multiply)
  Linear(ED→D)
  Output y (B, L, D)

Where E is the expansion ratio (usually 2) and D is the model dimension.

Complete Mamba Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Mamba Block implementation
    Paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    """
    def __init__(
        self,
        d_model,       # Model dimension D
        d_state=16,    # SSM state dimension N
        d_conv=4,      # Convolution kernel size
        expand=2,      # Expansion ratio E
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = max(1, int(d_model / 16))
        else:
            self.dt_rank = dt_rank

        # Input projection (D → 2*ED, compute both branches at once)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # Local convolution (depthwise conv)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,  # causal padding
        )

        # SSM parameter projections
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A initialization (HiPPO-based)
        A = repeat(
            torch.arange(1, self.d_state + 1),
            "n -> d n",
            d=self.d_inner
        )
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D (skip connection)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: (B, L, D)
        """
        batch, seqlen, dim = hidden_states.shape

        # Input projection
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)  # each (B, L, ED)

        # Convolution (causal 1D conv)
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :seqlen]  # causal trimming
        x = rearrange(x, "b d l -> b l d")
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # Gating (element-wise multiply with z branch)
        y = y * F.silu(z)

        # Output projection
        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """Selective SSM (S6) computation"""
        d_in, n = self.d_inner, self.d_state

        # A matrix: A = -exp(A_log), so entries stay negative (stable dynamics)
        A = -torch.exp(self.A_log.float())  # (ED, N)

        # x_proj: compute Delta, B, C
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2N)
        delta, B, C = x_dbl.split(
            [self.dt_rank, n, n], dim=-1
        )

        # Delta: softplus ensures positive values
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED)

        # ZOH discretization and selective scan
        y = self.selective_scan(x, delta, A, B, C, self.D)
        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """
        Selective scan algorithm
        In practice, uses CUDA kernels from mamba_ssm
        Here shown in pure PyTorch for conceptual understanding
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # Discretization
        # delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
        deltaA = torch.exp(
            torch.einsum("bld,dn->bldn", delta, A)
        )
        # delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
        # → deltaB_u: (B, L, ED, N)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # Recurrent scan
        x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
            ys.append(y)
        y = torch.stack(ys, dim=1)  # (B, L, ED)

        # D (skip connection)
        y = y + u * D

        return y


class MambaModel(nn.Module):
    """Sequence model composed of stacked Mamba blocks"""
    def __init__(self, d_model, n_layers, d_state=16, expand=2):
        super().__init__()
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, L, D)"""
        for layer in self.layers:
            x = x + layer(self.norm(x))  # Pre-norm residual
        return x

6. Mamba 2

In May 2024, Tri Dao and Albert Gu published Mamba 2, further strengthening the theoretical foundations.

Differences from Mamba 1

Mamba 1's selective scan required sequential computation at each timestep. Mamba 2 discovered State Space Duality (SSD) to make this even more efficient.

State Space Duality (SSD)

The core insight of Mamba 2: SSMs with a specific structure can be expressed as Semi-Separable Matrices, which are mathematically equivalent to a specific form of attention.

This mathematical duality enables:

  1. Expressing SSM computation as matrix multiplication form
  2. Leveraging highly optimized BLAS libraries
  3. Maximizing GPU efficiency with Tensor Core utilization

# SSD operation concept
# SSM recurrence: x[t] = A[t] * x[t-1] + B[t] * u[t]
#                 y[t] = C[t]^T * x[t]
#
# Process in chunks:
# - Within chunk: parallel computation via matrix multiplication
# - Between chunks: recurrence for state propagation

class Mamba2Block(nn.Module):
    """Mamba 2 Block"""
    def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.d_head = d_model // n_heads

        # Multi-head structure
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
        self.out_proj = nn.Linear(d_model, d_model)

        # A parameters (per head)
        self.A_log = nn.Parameter(torch.randn(n_heads))

    def forward(self, x):
        """x: (B, L, D)"""
        # Actual implementation uses CUDA kernels from mamba_ssm package
        return x
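The chunked processing idea can be sketched for the simplest case of scalar gates (the scalar-times-identity structure SSD exploits): within a chunk, all states fall out of one masked matrix multiply; between chunks, only a single carried state is handed off. A toy sketch, verified against the sequential scan:

```python
import torch

def sequential_scan(a, bu):
    """x[t] = a[t] * x[t-1] + bu[t] (scalar state, x[-1] = 0)."""
    x, out = torch.tensor(0.0), []
    for t in range(len(a)):
        x = a[t] * x + bu[t]
        out.append(x)
    return torch.stack(out)

def chunked_scan(a, bu, chunk=4):
    """Same recurrence, chunk by chunk: intra-chunk work is one masked
    matmul, inter-chunk work is a scalar state hand-off."""
    out, x_in = [], torch.tensor(0.0)
    for s in range(0, len(a), chunk):
        ac, buc = a[s:s + chunk], bu[s:s + chunk]
        cum = torch.cumprod(ac, dim=0)                 # prod of a up to t
        # decay[t, s] = prod of a over (s, t]; lower-triangular mask
        decay = torch.tril(cum.unsqueeze(1) / cum.unsqueeze(0))
        xs = decay @ buc + cum * x_in                  # all states in the chunk
        out.append(xs)
        x_in = xs[-1]                                  # carry state forward
    return torch.cat(out)

torch.manual_seed(0)
a = torch.rand(16) * 0.9 + 0.05    # positive gates in (0.05, 0.95)
bu = torch.randn(16)
assert torch.allclose(sequential_scan(a, bu), chunked_scan(a, bu), atol=1e-5)
```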

Multi-Head Structure

Mamba 2 introduces multi-head SSM, making it structurally similar to Transformer's Multi-Head Attention. This provides:

  • Better expressiveness
  • Theoretical connection to attention
  • Easier hybrid architecture design

Integration Potential with Transformers

SSD theory shows that certain SSMs are mathematically equivalent to certain attention mechanisms. This provides the theoretical basis for hybrid architectures mixing Mamba and Transformers.


7. Hybrid Architectures

MambaFormer

MambaFormer interleaves Mamba blocks and Transformer blocks:

Layer 1: Mamba Block   (capture local patterns)
Layer 2: Attention     (capture global dependencies)
Layer 3: Mamba Block
Layer 4: Attention
...

This structure leverages the strengths of each component:

  • Mamba: efficient sequence processing, local patterns
  • Attention: selective information retrieval, global dependencies

Jamba (SSM + Transformer + MoE)

Jamba, released by AI21 Labs in 2024, is a more complex hybrid:

Jamba = Mamba + Transformer + Mixture of Experts

Architecture:
- 52B parameters (active: 12B)
- Layer ratio: Attention 1 : Mamba 7
- MoE applied to some layers
- Supports 256K context window

Compared to same-size Transformers, Jamba achieves:

  • 3x improvement in inference throughput
  • Dramatically improved memory efficiency for long contexts

# Jamba-style hybrid block (concept)
class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8):
        super().__init__()
        self.use_attention = (layer_idx % attn_every_n == 0)

        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
        else:
            self.mixer = MambaBlock(d_model)

        # MoE (only for some layers; MixtureOfExperts stands in for an expert-routing FFN)
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model)
            )

    def forward(self, x):
        if self.use_attention:
            # nn.MultiheadAttention takes (query, key, value) and returns a tuple
            attn_out, _ = self.mixer(x, x, x, need_weights=False)
            x = x + attn_out
        else:
            x = x + self.mixer(x)
        x = x + self.ffn(x)
        return x

RWKV (Linear Attention)

RWKV is a hybrid of Transformer and RNN. Standing for "Receptance Weighted Key Value":

- Training: Parallelized like Transformer (matrix form)
- Inference: Recurrent like RNN (O(1) state)
- Token mixing without attention

Latest versions (RWKV-6, RWKV-7) show performance competitive with Mamba.

RetNet (Retentive Network)

Microsoft's RetNet aims to simultaneously achieve "Training Parallelism, Inference Efficiency, Competitive Performance":

Three computational paradigms:
1. Parallel: O(n^2) parallel computation during training (lower constant than Transformer)
2. Recurrent: O(1) memory during inference
3. Chunkwise Recurrent: balanced intermediate approach
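For a single head with scalar decay gamma, the equivalence of the parallel and recurrent paradigms can be verified directly (a simplified sketch; actual RetNet adds per-head decays, a rotation, and normalization):

```python
import torch

torch.manual_seed(0)
L, d = 8, 4
gamma = 0.9
Q, K, V = torch.randn(L, d), torch.randn(L, d), torch.randn(L, d)

# 1. Parallel form (training): decay-masked attention-like matmul
#    D[t, s] = gamma^(t-s) for t >= s, else 0
idx = torch.arange(L)
D = torch.where(idx[:, None] >= idx[None, :],
                gamma ** (idx[:, None] - idx[None, :]).float(),
                torch.tensor(0.0))
Y_parallel = (D * (Q @ K.T)) @ V

# 2. Recurrent form (inference): O(1) state S of size d x d
S = torch.zeros(d, d)
ys = []
for t in range(L):
    S = gamma * S + torch.outer(K[t], V[t])
    ys.append(Q[t] @ S)
Y_recurrent = torch.stack(ys)

assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-4)
```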

8. Mamba Performance Comparison

Inference Speed (Linear Scale)

Mamba's greatest strength is its linear scaling with sequence length.

Normalizing processing cost at sequence length 2K to 1x (lower is better):

Sequence Length   Transformer   Mamba
1K                0.5x          0.5x
2K                1.0x          1.0x
4K                3.5x          2.0x    ← 1.75x faster than Transformer
8K                13x           4.0x    ← 3.25x faster than Transformer
16K               50x           8.0x    ← 6.25x faster than Transformer
100K              ~1800x        ~50x    ← 36x faster than Transformer

Memory Efficiency

State size comparison during inference:

Model           Inference Memory (per 1K tokens)
Transformer     O(n) KV Cache
Mamba           O(1) SSM state (fixed size!)

Example: 130M parameter model, 1M token sequence
- Transformer: ~16GB KV Cache
- Mamba: ~1MB state (constant!)

Long Sequence Tasks

Long Range Arena (LRA) benchmark (sequence lengths 1K-16K):

Model        ListOps  Text  Retrieval  Image  Path-X  Avg
Transformer  36.4     65.0  57.5       42.4   0.0     40.3
LSTM         35.9     63.7  65.0       43.3   0.0     41.6
S4           59.6     86.8  90.9       88.7   86.1    82.4
Mamba        ~S4-level performance, faster processing

Mamba particularly excels over S4 on synthetic benchmarks like Selective Copying and Induction Heads, which are more relevant to real language modeling ability.


9. Applications of Mamba

Natural Language Processing

Mamba-based language models are emerging rapidly:

  • MambaChat: Conversational AI assistant
  • Falcon Mamba: Open-source 7B Mamba model released by TII
  • CodeMamba: Specialized for code generation

Mamba is efficient for long document processing, summarization, and translation compared to Transformers.

Bioinformatics

Genomic sequence analysis deals with very long sequences (millions of base pairs), making Mamba particularly advantageous:

  • Caduceus: Long DNA sequence modeling
  • HyenaDNA: Long-range DNA sequence analysis (a related long-convolution model)
  • Potential applications in protein structure prediction

# Example of Mamba for biological sequence modeling
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import torch

# DNA sequence modeling (A, T, G, C, N tokens)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)

# Process very long sequences efficiently
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.float16
)

Time Series Analysis

For financial, meteorological, IoT sensor data, and other long time series, Mamba shows its strengths:

  • TimeMamba: Long time series forecasting
  • MambaMixer: Multivariate time series modeling
  • Advancements in S4/S5-based time series models

Image Processing (VMamba)

VMamba extends Mamba to 2D image processing:

# VMamba core: 2D selective scan
# Scan image in 4 directions to capture 2D structure

# Directions:
# 1. Left→Right, Top→Bottom (standard raster scan)
# 2. Right→Left, Bottom→Top (reverse)
# 3. Top→Bottom, Left→Right (column-first)
# 4. Bottom→Top, Right→Left (reverse)

class VMambaBlock(nn.Module):
    """VMamba: Visual Mamba Block"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # 4-directional SSMs
        self.ssms = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(4)
        ])
        self.out_proj = nn.Linear(d_model * 4, d_model)

    def forward(self, x):
        """x: (B, H, W, D) - image patch embeddings"""
        b, h, w, d = x.shape
        x_flat = x.view(b, h*w, d)

        outputs = []
        # 4-directional scan
        for i, ssm in enumerate(self.ssms):
            if i == 0:   # forward (raster order)
                seq = x_flat
            elif i == 1: # reverse
                seq = x_flat.flip(1)
            elif i == 2: # column-first
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d)
            else:        # column-first reverse
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d).flip(1)

            out = ssm(seq)
            if i % 2 == 1:   # undo the reversal
                out = out.flip(1)
            if i >= 2:       # map column-first order back to raster order
                out = out.view(b, w, h, d).permute(0, 2, 1, 3).reshape(b, h*w, d)
            outputs.append(out)

        # Combine 4 directional results
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)

10. Practical Usage

Installing mamba-ssm Package

# Required packages
pip install torch torchvision torchaudio
pip install "causal-conv1d>=1.2.0"  # quotes keep the shell from treating >= as redirection
pip install mamba-ssm

# Or install from source (latest features)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"

Using MambaLMHeadModel

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Load pre-trained model
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = MambaLMHeadModel.from_pretrained(
    model_name,
    device="cuda",
    dtype=torch.bfloat16
)
model.eval()

# Text generation
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

    with torch.no_grad():
        # mamba_ssm's generate takes a total max_length (not HF-style
        # max_new_tokens/do_sample); sampling is controlled via top_k/top_p
        output = model.generate(
            input_ids=input_ids,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,
            top_k=10,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

# Example
prompt = "State Space Models are powerful because"
result = generate_text(prompt)
print(result)

Fine-tuning Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
        }

def finetune_mamba(
    model_name="state-spaces/mamba-130m",
    texts=None,
    num_epochs=3,
    learning_rate=1e-4,
    batch_size=8,
):
    """Fine-tune Mamba model"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model = MambaLMHeadModel.from_pretrained(
        model_name,
        device=device,
        dtype=torch.bfloat16
    )

    # Dataset
    dataset = TextDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer (AdamW recommended for Mamba)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.1
    )

    # Scheduler
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(device)

            # Language modeling: next token prediction
            outputs = model(input_ids)
            logits = outputs.logits

            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}, Step {batch_idx}: Loss = {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} complete: Avg Loss = {avg_loss:.4f}")

    return model

Custom Mamba Model Configuration

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Custom configuration (SSM-specific options are passed via ssm_cfg)
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    ssm_cfg={
        "d_state": 16,
        "d_conv": 4,
        "expand": 2,
        "dt_rank": "auto",
        "dt_min": 0.001,
        "dt_max": 0.1,
        "dt_init": "random",
        "dt_scale": 1.0,
        "dt_init_floor": 1e-4,
    },
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# Create model (~370M parameters: d_model=1024, n_layer=48 matches mamba-370m)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")

Conclusion: The Future of Mamba

Mamba has brought revolutionary changes to the field of deep learning sequence modeling. Summarizing the key contributions:

  1. Selective Mechanism: SSM parameters that dynamically change based on input
  2. Hardware-Aware Design: Efficient computation leveraging GPU memory hierarchy
  3. Dual Representation: Parallelized during training, recurrent during inference
  4. Linear Complexity: Linear computation and memory complexity with sequence length

There are still challenges to address:

  • In-context learning and precise recall remain somewhat weaker than in Transformers
  • Behavior at very large scales still needs validation
  • Clear superiority over attention-based models has yet to be demonstrated across all tasks

However, Mamba and SSM-family models will increasingly play an important role in areas where Transformers struggle — long sequence processing, real-time inference, and edge device deployment.
