Skip to content
Published on

Diffusion Transformer(DiT)アーキテクチャ分析:U-NetからTransformerへの転換

Authors
  • Name
    Twitter

はじめに

2022年末に発表された**「Scalable Diffusion Models with Transformers」**(William Peebles & Saining Xie、2023 ICCV)は、画像生成のバックボーンをU-NetからTransformerへ転換する起点となりました。この論文で提案されたDiT(Diffusion Transformer)は、その後OpenAIのSORA、Stability AIのStable Diffusion 3、そして2025年のICLRで発表されたDynamic DiTへと続く核心アーキテクチャとなりました。

背景:なぜU-Netから離れるのか?

U-Netの限界

従来のDDPM、LDM(Latent Diffusion Model)はU-Netをバックボーンとして使用しています。U-Netは画像生成で良好な性能を示しますが、根本的な限界があります:

  • スケーリング則の不在: モデルを大きくしても性能向上が予測不可能
  • アーキテクチャの複雑さ: skip connection、多様な解像度のfeature map管理
  • 演算の非効率: 高解像度でメモリ・演算コストが急増
  • 他のモダリティとの統合の困難さ: テキスト、動画などとの統合が不自然

Transformerの利点

一方、Transformer(ViT)は:

  • 実証済みのスケーリング則: パラメータ数に比例した性能向上(LLMで実証済み)
  • シンプルなアーキテクチャ: Self-attention + FFNの繰り返し
  • モダリティ不問: テキスト、画像、動画すべてを同一構造で処理可能
  • ハードウェア最適化: GPU/TPUでの効率的な並列処理

DiTアーキテクチャの詳細

全体パイプライン

DiTはLatent Diffusionフレームワーク上で動作します:

  1. 画像をVAEエンコーダで潜在空間(latent space)に変換
  2. 潜在表現をパッチに分割(ViTと同様)
  3. DiTブロック(Transformer)でノイズを予測
  4. VAEデコーダで画像を復元

パッチ埋め込み

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """潜在表現をパッチシーケンスに変換"""
    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, N, D)
        # 例: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
        x = self.proj(x)  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x

入力は256x256の画像をVAEで32x32の潜在表現に変換した後、patch_size=2で分割すると16x16=256個のトークンになります。

条件付けメカニズム — 4つのバリエーション

DiTはタイムステップ(t)とクラスラベル(c)を注入する4つの方式を実験しています:

1. In-context Conditioning

class InContextConditioning(nn.Module):
    """tとcを追加トークンとしてシーケンスに連結"""
    def __init__(self, embed_dim):
        super().__init__()
        self.t_embed = TimestepEmbedder(embed_dim)
        self.c_embed = LabelEmbedder(1000, embed_dim)

    def forward(self, x, t, c):
        t_token = self.t_embed(t).unsqueeze(1)  # (B, 1, D)
        c_token = self.c_embed(c).unsqueeze(1)  # (B, 1, D)
        x = torch.cat([t_token, c_token, x], dim=1)  # (B, N+2, D)
        return x

2. Cross-Attention

class CrossAttentionBlock(nn.Module):
    """tとcをcross-attentionで注入"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = MLP(embed_dim)

    def forward(self, x, cond):
        x = x + self.self_attn(x, x, x)[0]
        x = x + self.cross_attn(x, cond, cond)[0]
        x = x + self.mlp(x)
        return x

3. Adaptive Layer Norm(adaLN)

class AdaLNBlock(nn.Module):
    """条件に応じてLayerNormのscale/shiftを調整"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 2 * embed_dim)
        )

    def forward(self, x, cond):
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x + self.attn(x, x, x)[0]
        return x

4. adaLN-Zero(最終選択)

class DiTBlock(nn.Module):
    """adaLN-Zero: 初期化時に残差接続をゼロから開始"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # 6つのパラメータ: gamma1, beta1, alpha1, gamma2, beta2, alpha2
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim)
        )
        # alphaを0に初期化 -> 学習初期にidentity functionとして動作
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: (B, D) - タイムステップ + クラス埋め込みの和
        shift_msa, scale_msa, gate_msa, \
        shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(cond).chunk(6, dim=-1)

        # Self-Attention with adaLN
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * h  # gateは0から開始

        # FFN with adaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h  # gateは0から開始

        return x

adaLN-Zeroの核心: gateパラメータを0に初期化することで、学習初期に各DiTブロックがidentity functionとして動作します。これにより深いネットワークの学習安定性が大幅に向上します。

DiTモデル全体

class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        num_classes=1000,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        num_patches = (input_size // patch_size) ** 2  # 256

        # 位置埋め込み
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )

        # タイムステップ & クラス埋め込み
        self.t_embedder = TimestepEmbedder(embed_dim)
        self.y_embedder = LabelEmbedder(num_classes, embed_dim)

        # DiTブロック
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # 最終レイヤー
        self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)

    def forward(self, x, t, y):
        # パッチ埋め込み + 位置埋め込み
        x = self.patch_embed(x) + self.pos_embed

        # 条件埋め込み
        t_emb = self.t_embedder(t)  # (B, D)
        y_emb = self.y_embedder(y)  # (B, D)
        cond = t_emb + y_emb        # (B, D)

        # DiTブロックを通過
        for block in self.blocks:
            x = block(x, cond)

        # ノイズと分散を予測
        x = self.final_layer(x, cond)  # (B, N, 2*C*p*p)
        return x

モデルバリエーションとスケーリング

モデルdepthembed_dimheadsパラメータFID-50K
DiT-S/212384633M68.4
DiT-B/21276812130M43.5
DiT-L/224102416458M9.62
DiT-XL/228115216675M2.27

核心的発見:DiTはLLMと同一のスケーリング則に従います。 モデルサイズを大きくするほどFIDが一貫して改善されます。

パッチサイズの影響

モデルpatch_sizeトークン数GflopsFID
DiT-XL/881611914.2
DiT-XL/44645255.11
DiT-XL/2225611212.27

パッチサイズが小さいほど(トークン数が多いほど)品質は良くなりますが、演算量が大幅に増加します。

Classifier-Free Guidanceの適用

def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
    """Classifier-Free Guidanceでサンプリング"""
    # 条件付き + 無条件予測を同時に
    z_combined = torch.cat([z, z], dim=0)
    y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])

    for t in reversed(range(1000)):
        t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
        noise_pred = model(z_combined, t_batch, y_combined)

        # CFG適用
        cond_pred, uncond_pred = noise_pred.chunk(2)
        guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)

        # DDPM step
        z = ddpm_step(z, guided_pred, t)

    return z

CFG scale=1.5でDiT-XL/2はFID 2.27を達成し、当時のImageNet 256x256でSOTAを記録しました。

DiTの後続への影響

SORA(OpenAI、2024)

  • DiTを動画に拡張したSpatial-Temporal DiT
  • 3Dパッチ埋め込み(空間 + 時間)
  • 可変解像度・長さに対応

Stable Diffusion 3(Stability AI、2024)

  • DiT + Flow Matching
  • MM-DiT:テキストと画像を同一Transformerで処理
  • 別途のtext encoderなしでjoint attention

Dynamic DiT(ICLR 2025)

  • Early Exitメカニズムで推論を効率化
  • ノイズが大きい初期ステップでは浅いレイヤー、ディテール生成時に深いレイヤーを使用
  • 従来DiT比40%の推論時間削減

PyTorch実装チュートリアル

# DiT公式コードをクローン
git clone https://github.com/facebookresearch/DiT.git
cd DiT

# 環境設定
pip install torch torchvision diffusers accelerate

# 事前学習モデルのダウンロード
python download.py DiT-XL/2

# サンプル生成
python sample.py \
  --model DiT-XL/2 \
  --image-size 256 \
  --num-sampling-steps 250 \
  --seed 42 \
  --class-labels 207 360 387 974 88 979 417 279

カスタムトレーニング

from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# モデル初期化
model = DiT(
    input_size=32,  # VAE latent size
    patch_size=2,
    in_channels=4,
    embed_dim=1152,
    depth=28,
    num_heads=16,
    num_classes=1000,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0
)

model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

# 学習ループ
for epoch in range(epochs):
    for batch in dataloader:
        images, labels = batch
        # VAEエンコーディング
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # ランダムタイムステップ
        t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)

        # ノイズ追加
        noise = torch.randn_like(latents)
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # ノイズ予測
        noise_pred = model(noisy_latents, t, labels)
        loss = F.mse_loss(noise_pred, noise)

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

まとめ

DiTは画像生成におけるアーキテクチャパラダイムの転換を牽引しました:

  • U-Netの複雑な構造を純粋なTransformerに置き換え
  • LLMと同一のスケーリング則が画像生成にも適用されることを実証
  • adaLN-Zeroにより深いTransformerの安定的な学習を実現
  • SORA、SD3などの後続モデルの基盤アーキテクチャとして定着

クイズ:DiTアーキテクチャ理解度チェック(8問)

Q1. DiTがU-Netの代わりにTransformerを使用する主な理由は?

実証済みのスケーリング則、シンプルなアーキテクチャ、他のモダリティとの統合の容易さのためです。

Q2. DiTにおいて画像はどのようにトークンシーケンスに変換されますか?

VAEで潜在表現(latent)に変換した後、パッチ単位で分割して線形投影します。

Q3. adaLN-Zeroの「Zero」が意味するものは?

gateパラメータを0に初期化することで、学習初期に各ブロックがidentity functionとして動作するようにします。

Q4. DiTで実験された4つの条件付け方式のうち、最終的に選択されたものは?

adaLN-Zeroです。In-context、Cross-Attention、adaLN、adaLN-Zeroの中で最も良いFIDを記録しました。

Q5. パッチサイズが小さいほど、どのようなトレードオフがありますか?

トークン数が増加して画像品質(FID)は向上しますが、演算量(Gflops)が二乗に比例して増加します。

Q6. Classifier-Free GuidanceはDiTでどのように実装されていますか?

クラスラベルのドロップアウトで学習し、推論時に条件付き/無条件予測の重み付き和でサンプリングします。

Q7. DiT-XL/2のImageNet 256x256 FID-50Kスコアは?

2.27で、当時のclass-conditional生成SOTAを達成しました。

Q8. Dynamic DiT(ICLR 2025)の核心アイデアは?

Early Exitメカニズムにより、ノイズレベルに応じて使用するレイヤー数を動的に調整し、推論時間を40%削減します。