Skip to content
Published on

Mamba & 状態空間モデル完全ガイド:Transformerを超える線形時間シーケンスモデリング

Authors

概要

2017年の画期的な「Attention is All You Need」論文以来、Transformerは自然言語処理から画像、コード、音声まで、ほぼすべてのシーケンスモデリングタスクを支配してきました。しかし、Transformerには根本的な弱点があります:自己注意機構のO(n^2)二乗複雑性です。

2023年12月、Albert GuとTri Daoが発表したMambaはこの問題をエレガントに解決し、ディープラーニングコミュニティに衝撃を与えました。Mambaは古典的な制御理論の概念であるState Space Models(SSM)を現代のディープラーニングに適応させ、長いシーケンスを線形時間で処理できるようにしました。

このガイドでは、SSMの数学的基礎からMambaのコアイノベーション、実践的な実装まですべてをカバーします。


1. Transformerの限界とSSMの台頭

Self-AttentionのO(n^2)複雑性問題

TransformerのコアはSelf-Attention機構です。長さnのシーケンスで、アテンション行列の計算はすべてのトークンペア間の関係を計算する必要があり、時間とメモリの両方の複雑性が**O(n^2)**になります。

アテンション行列サイズ: n × n
= 1,000トークン   → 1,000,000要素
= 10,000トークン  → 100,000,000要素
= 100,000トークン → 10,000,000,000要素(実行不可能!)

この問題のため、実際のLLMはコンテキストウィンドウサイズに制約があります。GPT-4は128Kトークンをサポートしていますが、これには膨大な計算コストと最適化技術が必要です。

長いシーケンス処理の難しさ

長い文書の要約、コードベース全体の理解、長期的な会話の維持などの実際のアプリケーションは非常に長いコンテキストを必要とします。しかしTransformerの二乗複雑性により実際の使用が困難になっています。

例えば、本全体(約100,000語 = 130,000トークン)を一度に処理することは、標準的なTransformerを使った現在のハードウェアでは事実上実行不可能なメモリを必要とします。

再帰モデル(RNN、LSTM)の限界

再帰ニューラルネットワーク(RNN)とLSTMは、各タイムステップで固定サイズの隠れ状態を維持するため、理論的にはO(n)複雑性を持ちます。しかしRNN/LSTMには次の問題があります:

  • 勾配消失/爆発:長いシーケンスでの誤差逆伝播中に勾配が消失または爆発する
  • 並列化不可能:各タイムステップが前の状態に依存するため、GPU並列化が難しい
  • 長距離依存関係の捕捉困難:シーケンス内で離れた情報を結びつけることが難しい

SSMが提供する解決策

State Space Models(SSM)は両方のアプローチの利点を組み合わせます:

  • 学習時:畳み込みによる完全な並列化 → Transformerに匹敵する学習効率
  • 推論時:O(1)メモリとO(n)計算での再帰 → RNNのような効率的な推論
  • 長いシーケンス:線形複雑性により非常に長いシーケンスの処理が可能

2. State Space Modelsの数学的基礎

連続時間SSM

SSMの起源は1960年代のRudolf Kalmanの制御理論に遡ります。連続時間線形システムは次のように表されます:

x'(t) = A·x(t) + B·u(t)
y(t)  = C·x(t) + D·u(t)

ここで:

  • u(t):入力信号
  • x(t):状態ベクトル(潜在状態)、サイズN
  • y(t):出力信号
  • A:N×N状態遷移行列
  • B:N×1入力射影行列
  • C:1×N出力射影行列
  • D:直達項(スキップ接続、通常0に設定)

このシステムは入力u(t)を受け取り、内部状態x(t)を更新し、出力y(t)を生成します。状態x(t)は過去の情報の「サマリー」と考えることができます。

離散化

連続時間SSMをデジタルコンピュータで使用するには、離散化が必要です。サンプリング間隔Deltaを使って、2つの主要な離散化方法があります。

ゼロ次ホールド(ZOH)

A_bar = exp(Delta · A)
B_bar = (Delta · A)^(-1) · (exp(Delta · A) - I) · Delta · B

双一次(Tustin)変換

A_bar = (I - Delta/2 · A)^(-1) · (I + Delta/2 · A)
B_bar = (I - Delta/2 · A)^(-1) · Delta · B

離散化後、システムは次のようになります:

x[t] = A_bar · x[t-1] + B_bar · u[t]
y[t] = C · x[t]

この形式は線形再帰であり、各タイムステップで状態を更新します。

畳み込みカーネルとしてのSSM

SSMの強力な特性は、学習時に畳み込みとして計算できるということです。初期状態x[0] = 0から始めると:

y[0] = C · B_bar · u[0]
y[1] = C · A_bar · B_bar · u[0] + C · B_bar · u[1]
y[2] = C · A_bar^2 · B_bar · u[0] + C · A_bar · B_bar · u[1] + C · B_bar · u[2]
...

これを畳み込みカーネルKとして表現すると:

K = (C·B_bar, C·A_bar·B_bar, C·A_bar^2·B_bar, ...)
y = K * u  (畳み込み)

このカーネルKは並列に計算でき、FFTを使えばO(n log n)で計算できます。

これがSSMの根本的な二重性です:

  • 推論:O(1)メモリでの再帰
  • 学習:O(n log n)並列計算での畳み込み

3. S4(構造化状態空間シーケンスモデル)

S4は2021年にAlbert Guらによって発表され、SSMをディープラーニングに実践的に適用した最初の重要な成果です。

HiPPO行列初期化

S4の主要な貢献の1つはHiPPO(High-order Polynomial Projection Operators)行列初期化です。Aをランダムに初期化するだけでは勾配消失問題が発生します。HiPPOは多項式を使って過去の入力を近似する特別に設計されたA行列を提供します。

HiPPO-LegS(Legendre多項式ベース):

A[n,k] = -sqrt((2n+1)(2k+1))  n > k の場合
A[n,k] = -(n+1)               n == k の場合
A[n,k] = 0                    n < k の場合

この初期化により、状態x[t]はすべての過去の入力u[0..t]の最適な多項式近似を維持します。

構造化行列A(DPLR)

A行列の畳み込みカーネルを効率的に計算するため、S4はAを**対角プラス低ランク(DPLR)**形式で表現します:

A = Λ - P·Q^T

ここでΛは対角行列で、P、Qは低ランクベクトルです。この構造を使用することで、カーネル計算をO(N)に削減できます。

効率的な計算

import torch
import torch.nn as nn
import numpy as np

class S4Layer(nn.Module):
    """S4レイヤーの簡略化した実装"""
    def __init__(self, d_model, d_state=64, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # HiPPO-LegS初期化
        A = self._make_hippo(d_state)
        # DPLR分解
        self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32))
        self.B = nn.Parameter(torch.randn(d_state, 1) * 0.01)
        self.C = nn.Parameter(torch.randn(1, d_state) * 0.01)
        self.log_delta = nn.Parameter(torch.zeros(d_model))
        self.D = nn.Parameter(torch.ones(d_model))

    def _make_hippo(self, N):
        """HiPPO-LegS行列を生成"""
        P = np.sqrt(1 + 2 * np.arange(N))
        A = P[:, np.newaxis] * P[np.newaxis, :]
        A = np.tril(A) - np.diag(np.arange(N))
        return -A

    def discretize(self, A, B, C, delta):
        """ZOH離散化"""
        dA = torch.matrix_exp(delta.unsqueeze(-1) * A)
        dB = torch.linalg.solve(A, (dA - torch.eye(A.shape[0])) @ B)
        return dA, dB, C

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        delta = torch.exp(self.log_delta)
        # 離散化と畳み込み計算
        # 実際の実装はより複雑で、ここでは概念的に簡略化
        return u

4. H3(Hungry Hungry Hippos)

2022年に発表されたH3は、言語モデリングによりよく適合したS4の改良版です。タイトルはHiPPOの頭字語のユーモラスな拡張です。

S4との違い

S4は長距離依存関係をうまく捉えますが、言語モデリングで重要なトークン間インタラクションが欠けています。Attentionは「この単語とあの単語はどれだけ関連しているか?」を直接計算しますが、S4は純粋な再帰でこれを捉えるのが難しいです。

ゲーティングメカニズム

H3は2つのSSMを使用し、それらの間にゲーティングを追加します:

class H3Layer(nn.Module):
    """H3レイヤー"""
    def __init__(self, d_model, d_state=64):
        super().__init__()
        # 2つのSSM(シフトSSM + 対角SSM)
        self.shift_ssm = S4Layer(d_model, d_state)
        self.diag_ssm = S4Layer(d_model, d_state)

        # 射影
        self.Q_proj = nn.Linear(d_model, d_model)
        self.K_proj = nn.Linear(d_model, d_model)
        self.V_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Q, K, V射影
        q = self.Q_proj(x)
        k = self.K_proj(x)
        v = self.V_proj(x)

        # SSM 1: KにシフトSSMを適用
        k_ssm = self.shift_ssm(k)

        # 乗算ゲーティング: Q要素ごとのK_ssm
        gated = q * k_ssm

        # SSM 2: Vに対角SSMを適用し、gatedと乗算
        v_ssm = self.diag_ssm(v)
        output = gated * v_ssm

        return self.out_proj(output)

言語モデリングの改善

H3はGPT-2サイズの言語モデルでTransformerに近いパフォーマンスを達成しながら、大幅に高速な推論速度を示しました。これは言語モデリングにおけるSSMの実用性を示す重要なマイルストーンでした。


5. Mamba(選択的状態空間モデル)

2023年12月、Albert GuとTri Daoが発表したMambaはSSM研究で最も重要な進歩を表しています。Mambaのコアイノベーションは選択的メカニズムです。

コアイノベーション:選択的メカニズム

S4とH3の根本的な限界は、行列A、B、Cが入力非依存であることでした。畳み込みベースの学習のためにカーネルを固定する必要がありました。

Mambaはこの制約を破り、**S6(選択的SSM)**を導入:B、C、Deltaの各行列を入力u(t)の関数にします。

B(t) = Linear_B(u(t))   # 入力によって変わるB
C(t) = Linear_C(u(t))   # 入力によって変わるC
Delta(t) = softplus(Linear_Delta(u(t)))  # 入力によって変わるDelta

これにより、モデルは入力に基づいて、どの情報を状態に保存し、どれを破棄するかを動的に決定できます。

直感的には:

  • 大きなDelta → 現在の入力が状態に強く影響する(情報を記憶)
  • 小さなDelta → 状態がほとんど変化しない(情報を無視)

これはLSTMのゲート(入力ゲート、忘却ゲート)に似た役割を果たしますが、連続時間SSMフレームワーク内でです。

ハードウェア対応アルゴリズム

選択的メカニズムの導入は、B、Cが入力に依存するため、畳み込みを計算に使えなくなることを意味します。これは深刻な計算効率の低下につながる可能性があります。

MambaはGPUメモリ階層(HBMとSRAM)を活用したハードウェア対応アルゴリズム(カーネル融合技術)でこれを解決します:

問題:タイムステップごとの再帰中に中間状態をHBM(低速メモリ)に保存することでメモリ帯域幅のボトルネックが生じる。

解決策

  • すべての中間計算を高速SRAM(オンチップメモリ)で実行
  • 完全な畳み込みの代わりに並列スキャンアルゴリズムを使用
  • 最終出力のみをHBMに保存
# 並列スキャンの概念(実際の実装はCUDA)
def parallel_scan(gates, tokens):
    """
    線形再帰を並列に計算
    x[t] = gates[t] * x[t-1] + tokens[t]

    二分木構造でO(log n)の並列深さを達成
    """
    n = len(tokens)
    # アップスウィープフェーズ
    log_n = int(np.log2(n))
    for d in range(log_n):
        step = 2 ** (d + 1)
        for i in range(step - 1, n, step):
            gates[i] = gates[i] * gates[i - 2**d]
            tokens[i] = gates[i - 2**d] * tokens[i - 2**d] + tokens[i]
    # ダウンスウィープフェーズ(省略)
    return tokens

Mambaブロック構造

入力 x (B, L, D)
    ├──────────────────────┐
    │                      │
  Linear(DED)          Linear(DED)
  + SiLU活性化SSM (S6)
    │                      │
    └─────── ⊙ ────────────┘
    (要素ごとの乗算)
  Linear(EDD)
  出力 y (B, L, D)

ここでEは拡張率(通常2)、Dはモデル次元です。

Mambaの完全な実装

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class MambaBlock(nn.Module):
    """
    Mambaブロック実装
    論文: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
    """
    def __init__(
        self,
        d_model,       # モデル次元D
        d_state=16,    # SSM状態次元N
        d_conv=4,      # 畳み込みカーネルサイズ
        expand=2,      # 拡張率E
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        bias=False,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            self.dt_rank = max(1, int(d_model / 16))
        else:
            self.dt_rank = dt_rank

        # 入力射影 (D → 2*ED、両方のブランチを一度に計算)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)

        # ローカル畳み込み(depthwise conv)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,  # 因果パディング
        )

        # SSMパラメータ射影
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A初期化(HiPPOベース)
        A = repeat(
            torch.arange(1, self.d_state + 1),
            "n -> d n",
            d=self.d_inner
        )
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D(スキップ接続)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # 出力射影
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        戻り値: (B, L, D)
        """
        batch, seqlen, dim = hidden_states.shape

        # 入力射影
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)  # それぞれ (B, L, ED)

        # 畳み込み(因果的1D conv)
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :seqlen]  # 因果トリミング
        x = rearrange(x, "b d l -> b l d")
        x = F.silu(x)

        # SSM
        y = self.ssm(x)

        # ゲーティング(zブランチと要素ごとに乗算)
        y = y * F.silu(z)

        # 出力射影
        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """選択的SSM(S6)計算"""
        d_in, n = self.d_inner, self.d_state

        # A行列(-exp(-A_log)は常に負を保証 = 安定性)
        A = -torch.exp(self.A_log.float())  # (ED, N)

        # x_proj: Delta、B、Cを計算
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2N)
        delta, B, C = x_dbl.split(
            [self.dt_rank, n, n], dim=-1
        )

        # Delta: softplusで正の値を保証
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED)

        # ZOH離散化と選択的スキャン
        y = self.selective_scan(x, delta, A, B, C, self.D)
        return y

    def selective_scan(self, u, delta, A, B, C, D):
        """
        選択的スキャンアルゴリズム
        実際には、mamba_ssmのCUDAカーネルを使用
        ここでは概念的理解のためにPure PyTorchで示す
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        # 離散化
        # delta: (B, L, ED), A: (ED, N) → dA: (B, L, ED, N)
        deltaA = torch.exp(
            torch.einsum("bld,dn->bldn", delta, A)
        )
        # delta: (B, L, ED), B: (B, L, N), u: (B, L, ED)
        # → deltaB_u: (B, L, ED, N)
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # 再帰スキャン
        x = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = torch.einsum("bdn,bn->bd", x, C[:, i, :])
            ys.append(y)
        y = torch.stack(ys, dim=1)  # (B, L, ED)

        # D(スキップ接続)
        y = y + u * D

        return y


class MambaModel(nn.Module):
    """積み重ねられたMambaブロックで構成されるシーケンスモデル"""
    def __init__(self, d_model, n_layers, d_state=16, expand=2):
        super().__init__()
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state=d_state, expand=expand)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, L, D)"""
        for layer in self.layers:
            x = x + layer(self.norm(x))  # 前段正規化残差
        return x

6. Mamba 2

2024年5月、Tri DaoとAlbert GuがMamba 2を発表し、理論的基盤をさらに強化しました。

Mamba 1との違い

Mamba 1の選択的スキャンは各タイムステップで逐次計算が必要でした。Mamba 2は**状態空間双対性(SSD)**を発見して、さらに効率的にしました。

状態空間双対性(SSD)

Mamba 2のコア洞察:特定の構造を持つSSMは半分離可能行列として表現でき、これは特定の形式のアテンションと数学的に等価です。

この数学的双対性により:

  1. SSM計算を行列乗算形式として表現
  2. 高度に最適化されたBLASライブラリを活用
  3. Tensor Core利用でGPU効率を最大化
# SSD演算の概念
# SSM再帰: x[t] = A[t] * x[t-1] + B[t] * u[t]
#           y[t] = C[t]^T * x[t]
#
# チャンクで処理:
# - チャンク内: 行列乗算による並列計算
# - チャンク間: 状態伝播のための再帰

class Mamba2Block(nn.Module):
    """Mamba 2ブロック"""
    def __init__(self, d_model, d_state=64, n_heads=8, chunk_size=64):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.n_heads = n_heads
        self.chunk_size = chunk_size
        self.d_head = d_model // n_heads

        # マルチヘッド構造
        self.norm = nn.LayerNorm(d_model)
        self.in_proj = nn.Linear(d_model, d_model * 2 + d_state * 2 + n_heads)
        self.out_proj = nn.Linear(d_model, d_model)

        # Aパラメータ(ヘッドごと)
        self.A_log = nn.Parameter(torch.randn(n_heads))

    def forward(self, x):
        """x: (B, L, D)"""
        # 実際の実装ではmamba_ssmパッケージのCUDAカーネルを使用
        return x

マルチヘッド構造

Mamba 2はマルチヘッドSSMを導入し、TransformerのMulti-Head Attentionと構造的に似ています。これにより:

  • 表現力の向上
  • アテンションとの理論的接続
  • ハイブリッドアーキテクチャの設計が容易

TransformerとのIntegration可能性

SSD理論は特定のSSMが特定のアテンション機構と数学的に等価であることを示しています。これにより、MambaとTransformerを混合したハイブリッドアーキテクチャの理論的基盤が提供されます。


7. ハイブリッドアーキテクチャ

MambaFormer

MambaFormerはMambaブロックとTransformerブロックを交互に並べます:

レイヤー1: Mambaブロック   (局所パターンを捕捉)
レイヤー2: Attention     (グローバル依存関係を捕捉)
レイヤー3: Mambaブロック
レイヤー4: Attention
...

この構造は各コンポーネントの強みを活かします:

  • Mamba:効率的なシーケンス処理、局所パターン
  • Attention:選択的な情報検索、グローバル依存関係

Jamba(SSM + Transformer + MoE)

AI21 Labsが2024年にリリースしたJambaは、より複雑なハイブリッドです:

Jamba = Mamba + Transformer + Mixture of Experts

アーキテクチャ:
- 52Bパラメータ(アクティブ: 12B)
- レイヤー比率: Attention 1 : Mamba 7
- 一部のレイヤーにMoE適用
- 256Kコンテキストウィンドウをサポート

同サイズのTransformerと比較して、Jambaは:

  • 推論スループットが3倍向上
  • 長いコンテキストでメモリ効率が大幅に改善
# Jambaスタイルのハイブリッドブロック(概念)
class JambaLayer(nn.Module):
    def __init__(self, d_model, layer_idx, attn_every_n=8):
        super().__init__()
        self.use_attention = (layer_idx % attn_every_n == 0)

        if self.use_attention:
            self.mixer = nn.MultiheadAttention(d_model, num_heads=8)
        else:
            self.mixer = MambaBlock(d_model)

        # MoE(一部のレイヤーのみ)
        self.use_moe = (layer_idx % 2 == 0)
        if self.use_moe:
            self.ffn = MixtureOfExperts(d_model, n_experts=16)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model)
            )

    def forward(self, x):
        x = x + self.mixer(x)
        x = x + self.ffn(x)
        return x

RWKV(線形Attention)

RWKVはTransformerとRNNのハイブリッドです。「Receptance Weighted Key Value」の略:

- 学習: Transformerのように並列化(行列形式)
- 推論: RNNのように再帰的(O(1)状態)
- Attentionなしのトークンミキシング

最新バージョン(RWKV-6、RWKV-7)はMambaと競合するパフォーマンスを示しています。

RetNet(Retentive Network)

MicrosoftのRetNetは「学習並列性、推論効率、競争力のあるパフォーマンス」を同時に達成することを目指します:

3つの計算パラダイム:
1. 並列: 学習中のO(n^2)並列計算(Transformerより低い定数)
2. 再帰: 推論中のO(1)メモリ
3. チャンクワイズ再帰: バランスの取れた中間アプローチ

8. Mambaのパフォーマンス比較

推論速度(線形スケール)

Mambaの最大の強みはシーケンス長に対する線形スケーリングです。

シーケンス長2Kでのパフォーマンスを1xに正規化:

シーケンス長   Transformer   Mamba
1K             0.5x          0.5x
2K             1.0x          1.0x
4K             3.5x          2.0x    ← Transformerより1.75倍高速
8K             13x           4.0x    ← Transformerより3.25倍高速
16K            50x           8.0x    ← Transformerより6.25倍高速
100K           ~1800x        ~50x    ← Transformerより36倍高速

メモリ効率

推論時の状態サイズ比較:

モデル           推論メモリ(1Kトークンあたり)
Transformer     O(n) KVキャッシュ
Mamba           O(1) SSM状態(固定サイズ!)

: 130Mパラメータモデル、1Mトークンシーケンス
- Transformer: ~16GB KVキャッシュ
- Mamba: ~1MB状態(一定!)

長いシーケンスタスク

Long Range Arena(LRA)ベンチマーク(シーケンス長1K-16K):

モデル        ListOps  Text  Retrieval  Image  Path-X  平均
Transformer  36.4     65.0  57.5       42.4   0.0     40.3
LSTM         35.9     63.7  65.0       43.3   0.0     41.6
S4           59.6     86.8  90.9       88.7   86.1    82.4
Mamba        ~S4レベルのパフォーマンス、高速処理

Mambaは特に選択的コピー誘導ヘッドなどの合成ベンチマークでS4を上回り、これらは実際の言語モデリング能力により関連しています。


9. Mambaのアプリケーション

自然言語処理

Mambaベースの言語モデルが急速に登場しています:

  • MambaChat:会話AIアシスタント
  • Falcon Mamba:TIIがリリースしたオープンソース7B Mambaモデル
  • CodeMamba:コード生成に特化

Mambaはランダムな文書処理、要約、翻訳においてTransformerと比較して効率的です。

バイオインフォマティクス

ゲノム配列分析は非常に長いシーケンス(数百万の塩基対)を扱い、Mambaが特に有利です:

  • Caduceus:長いDNA配列のモデリング
  • Hyena:長配列のDNA/タンパク質分析
  • タンパク質構造予測への潜在的応用
# 生物学的配列モデリングのためのMamba例
from mamba_ssm import MambaLMHeadModel
import torch

# DNA配列モデリング(A, T, G, C, Nトークン)
DNA_VOCAB = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}
VOCAB_SIZE = len(DNA_VOCAB)

# 非常に長いシーケンスを効率的に処理
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device="cuda",
    dtype=torch.float16
)

時系列分析

金融、気象、IoTセンサーデータなどの長い時系列データに対して、Mambaはその強みを発揮します:

  • TimeMamba:長い時系列予測
  • MambaMixer:多変量時系列モデリング
  • S4/S5ベースの時系列モデルの進歩

画像処理(VMamba)

VMambaはMambaを2D画像処理に拡張します:

# VMambaのコア: 2D選択的スキャン
# 4方向に画像をスキャンして2D構造を捕捉

# 方向:
# 1. 左→右、上→下(標準ラスタスキャン)
# 2. 右→左、下→上(逆)
# 3. 上→下、左→右(列優先)
# 4. 下→上、右→左(逆)

class VMambaBlock(nn.Module):
    """VMamba: Visual Mambaブロック"""
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # 4方向SSM
        self.ssms = nn.ModuleList([
            MambaBlock(d_model, d_state) for _ in range(4)
        ])
        self.out_proj = nn.Linear(d_model * 4, d_model)

    def forward(self, x):
        """x: (B, H, W, D) - 画像パッチ埋め込み"""
        b, h, w, d = x.shape
        x_flat = x.view(b, h*w, d)

        outputs = []
        # 4方向スキャン
        for i, ssm in enumerate(self.ssms):
            if i == 0:   # 順方向
                seq = x_flat
            elif i == 1: # 逆方向
                seq = x_flat.flip(1)
            elif i == 2: # 列優先
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d)
            else:        # 列優先逆
                seq = x.permute(0, 2, 1, 3).reshape(b, h*w, d).flip(1)

            out = ssm(seq)
            if i % 2 == 1:
                out = out.flip(1)
            outputs.append(out)

        # 4方向の結果を結合
        combined = torch.cat(outputs, dim=-1)
        return self.out_proj(combined).view(b, h, w, d)

10. 実践的な使用方法

mamba-ssmパッケージのインストール

# 必要なパッケージ
pip install torch torchvision torchaudio
pip install causal-conv1d>=1.2.0
pip install mamba-ssm

# またはソースからインストール(最新機能)
git clone https://github.com/state-spaces/mamba
cd mamba
pip install -e ".[dev]"

MambaLMHeadModelの使用

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# 事前学習済みモデルの読み込み
model_name = "state-spaces/mamba-2.8b-slimpj"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

model = MambaLMHeadModel.from_pretrained(
    model_name,
    device="cuda",
    dtype=torch.bfloat16
)
model.eval()

# テキスト生成
def generate_text(prompt, max_new_tokens=200, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = output[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

# 例
prompt = "State Space Modelsが強力な理由は"
result = generate_text(prompt)
print(result)

ファインチューニング例

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
        }

def finetune_mamba(
    model_name="state-spaces/mamba-130m",
    texts=None,
    num_epochs=3,
    learning_rate=1e-4,
    batch_size=8,
):
    """Mambaモデルをファインチューニング"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token = tokenizer.eos_token

    # モデルの読み込み
    model = MambaLMHeadModel.from_pretrained(
        model_name,
        device=device,
        dtype=torch.bfloat16
    )

    # データセット
    dataset = TextDataset(texts, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # オプティマイザー(MambaにはAdamWを推奨)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.1
    )

    # スケジューラー
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps
    )

    # 学習ループ
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(device)

            # 言語モデリング: 次トークン予測
            outputs = model(input_ids)
            logits = outputs.logits

            # 次トークン予測のためのシフト
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()

            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"エポック {epoch+1}, ステップ {batch_idx}: 損失 = {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"エポック {epoch+1} 完了: 平均損失 = {avg_loss:.4f}")

    return model

カスタムMambaモデルの設定

from mamba_ssm import Mamba
from mamba_ssm.models.config_mamba import MambaConfig

# カスタム設定
config = MambaConfig(
    d_model=1024,
    n_layer=48,
    vocab_size=50280,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
)

# モデルの作成(約14億パラメータ)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
print(f"パラメータ数: {sum(p.numel() for p in model.parameters()):,}")

まとめ:Mambaの未来

Mambaはディープラーニングのシーケンスモデリング分野に革命的な変化をもたらしました。主な貢献をまとめると:

  1. 選択的メカニズム:入力に基づいて動的に変化するSSMパラメータ
  2. ハードウェア対応設計:GPUメモリ階層を活用した効率的な計算
  3. 二重表現:学習時に並列化、推論時に再帰的
  4. 線形複雑性:シーケンス長に対して線形な計算とメモリ複雑性

まだ対処すべき課題があります:

  • Transformerと比較してやや弱いコンテキスト内学習
  • 非常に大きなスケールでの検証が必要
  • アテンションベースモデルに対する明確な優位性の実証

しかし、MambaとSSMファミリーのモデルは、Transformerが苦手とする分野、つまり長いシーケンス処理、リアルタイム推論、エッジデバイス展開において重要な役割を果たしていくでしょう。

参考文献