- Authors
- Name
- はじめに
- 背景:なぜU-Netから離れるのか?
- DiTアーキテクチャの詳細
- モデルバリエーションとスケーリング
- Classifier-Free Guidanceの適用
- DiTの後続への影響
- PyTorch実装チュートリアル
- まとめ
はじめに
2022年末に発表された**「Scalable Diffusion Models with Transformers」**(William Peebles & Saining Xie、2023 ICCV)は、画像生成のバックボーンをU-NetからTransformerへ転換する起点となりました。この論文で提案されたDiT(Diffusion Transformer)は、その後OpenAIのSORA、Stability AIのStable Diffusion 3、そして2025年のICLRで発表されたDynamic DiTへと続く核心アーキテクチャとなりました。
背景:なぜU-Netから離れるのか?
U-Netの限界
従来のDDPM、LDM(Latent Diffusion Model)はU-Netをバックボーンとして使用しています。U-Netは画像生成で良好な性能を示しますが、根本的な限界があります:
- スケーリング則の不在: モデルを大きくしても性能向上が予測不可能
- アーキテクチャの複雑さ: skip connection、多様な解像度のfeature map管理
- 演算の非効率: 高解像度でメモリ・演算コストが急増
- 他のモダリティとの統合の困難さ: テキスト、動画などとの統合が不自然
Transformerの利点
一方、Transformer(ViT)は:
- 実証済みのスケーリング則: パラメータ数に比例した性能向上(LLMで実証済み)
- シンプルなアーキテクチャ: Self-attention + FFNの繰り返し
- モダリティ不問: テキスト、画像、動画すべてを同一構造で処理可能
- ハードウェア最適化: GPU/TPUでの効率的な並列処理
DiTアーキテクチャの詳細
全体パイプライン
DiTはLatent Diffusionフレームワーク上で動作します:
- 画像をVAEエンコーダで潜在空間(latent space)に変換
- 潜在表現をパッチに分割(ViTと同様)
- DiTブロック(Transformer)でノイズを予測
- VAEデコーダで画像を復元
パッチ埋め込み
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
"""潜在表現をパッチシーケンスに変換"""
def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
super().__init__()
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W) -> (B, N, D)
# 例: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
x = self.proj(x) # (B, D, H/p, W/p)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
return x
入力は256x256の画像をVAEで32x32の潜在表現に変換した後、patch_size=2で分割すると16x16=256個のトークンになります。
条件付けメカニズム — 4つのバリエーション
DiTはタイムステップ(t)とクラスラベル(c)を注入する4つの方式を実験しています:
1. In-context Conditioning
class InContextConditioning(nn.Module):
"""tとcを追加トークンとしてシーケンスに連結"""
def __init__(self, embed_dim):
super().__init__()
self.t_embed = TimestepEmbedder(embed_dim)
self.c_embed = LabelEmbedder(1000, embed_dim)
def forward(self, x, t, c):
t_token = self.t_embed(t).unsqueeze(1) # (B, 1, D)
c_token = self.c_embed(c).unsqueeze(1) # (B, 1, D)
x = torch.cat([t_token, c_token, x], dim=1) # (B, N+2, D)
return x
2. Cross-Attention
class CrossAttentionBlock(nn.Module):
"""tとcをcross-attentionで注入"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = MLP(embed_dim)
def forward(self, x, cond):
x = x + self.self_attn(x, x, x)[0]
x = x + self.cross_attn(x, cond, cond)[0]
x = x + self.mlp(x)
return x
3. Adaptive Layer Norm(adaLN)
class AdaLNBlock(nn.Module):
"""条件に応じてLayerNormのscale/shiftを調整"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 2 * embed_dim)
)
def forward(self, x, cond):
shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = x + self.attn(x, x, x)[0]
return x
4. adaLN-Zero(最終選択)
class DiTBlock(nn.Module):
"""adaLN-Zero: 初期化時に残差接続をゼロから開始"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(approximate="tanh"),
nn.Linear(4 * embed_dim, embed_dim),
)
# 6つのパラメータ: gamma1, beta1, alpha1, gamma2, beta2, alpha2
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 6 * embed_dim)
)
# alphaを0に初期化 -> 学習初期にidentity functionとして動作
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, cond):
# cond: (B, D) - タイムステップ + クラス埋め込みの和
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(cond).chunk(6, dim=-1)
# Self-Attention with adaLN
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.attn(h, h, h)[0]
x = x + gate_msa.unsqueeze(1) * h # gateは0から開始
# FFN with adaLN
h = self.norm2(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
x = x + gate_mlp.unsqueeze(1) * h # gateは0から開始
return x
adaLN-Zeroの核心: gateパラメータを0に初期化することで、学習初期に各DiTブロックがidentity functionとして動作します。これにより深いネットワークの学習安定性が大幅に向上します。
DiTモデル全体
class DiT(nn.Module):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
):
super().__init__()
self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
num_patches = (input_size // patch_size) ** 2 # 256
# 位置埋め込み
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim)
)
# タイムステップ & クラス埋め込み
self.t_embedder = TimestepEmbedder(embed_dim)
self.y_embedder = LabelEmbedder(num_classes, embed_dim)
# DiTブロック
self.blocks = nn.ModuleList([
DiTBlock(embed_dim, num_heads) for _ in range(depth)
])
# 最終レイヤー
self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)
def forward(self, x, t, y):
# パッチ埋め込み + 位置埋め込み
x = self.patch_embed(x) + self.pos_embed
# 条件埋め込み
t_emb = self.t_embedder(t) # (B, D)
y_emb = self.y_embedder(y) # (B, D)
cond = t_emb + y_emb # (B, D)
# DiTブロックを通過
for block in self.blocks:
x = block(x, cond)
# ノイズと分散を予測
x = self.final_layer(x, cond) # (B, N, 2*C*p*p)
return x
モデルバリエーションとスケーリング
| モデル | depth | embed_dim | heads | パラメータ | FID-50K |
|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 2.27 |
核心的発見:DiTはLLMと同一のスケーリング則に従います。 モデルサイズを大きくするほどFIDが一貫して改善されます。
パッチサイズの影響
| モデル | patch_size | トークン数 | Gflops | FID |
|---|---|---|---|---|
| DiT-XL/8 | 8 | 16 | 119 | 14.2 |
| DiT-XL/4 | 4 | 64 | 525 | 5.11 |
| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |
パッチサイズが小さいほど(トークン数が多いほど)品質は良くなりますが、演算量が大幅に増加します。
Classifier-Free Guidanceの適用
def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
"""Classifier-Free Guidanceでサンプリング"""
# 条件付き + 無条件予測を同時に
z_combined = torch.cat([z, z], dim=0)
y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])
for t in reversed(range(1000)):
t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
noise_pred = model(z_combined, t_batch, y_combined)
# CFG適用
cond_pred, uncond_pred = noise_pred.chunk(2)
guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
# DDPM step
z = ddpm_step(z, guided_pred, t)
return z
CFG scale=1.5でDiT-XL/2はFID 2.27を達成し、当時のImageNet 256x256でSOTAを記録しました。
DiTの後続への影響
SORA(OpenAI、2024)
- DiTを動画に拡張したSpatial-Temporal DiT
- 3Dパッチ埋め込み(空間 + 時間)
- 可変解像度・長さに対応
Stable Diffusion 3(Stability AI、2024)
- DiT + Flow Matching
- MM-DiT:テキストと画像を同一Transformerで処理
- 別途のtext encoderなしでjoint attention
Dynamic DiT(ICLR 2025)
- Early Exitメカニズムで推論を効率化
- ノイズが大きい初期ステップでは浅いレイヤー、ディテール生成時に深いレイヤーを使用
- 従来DiT比40%の推論時間削減
PyTorch実装チュートリアル
# DiT公式コードをクローン
git clone https://github.com/facebookresearch/DiT.git
cd DiT
# 環境設定
pip install torch torchvision diffusers accelerate
# 事前学習モデルのダウンロード
python download.py DiT-XL/2
# サンプル生成
python sample.py \
--model DiT-XL/2 \
--image-size 256 \
--num-sampling-steps 250 \
--seed 42 \
--class-labels 207 360 387 974 88 979 417 279
カスタムトレーニング
from accelerate import Accelerator
from torch.utils.data import DataLoader
accelerator = Accelerator()
# モデル初期化
model = DiT(
input_size=32, # VAE latent size
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.0
)
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
# 学習ループ
for epoch in range(epochs):
for batch in dataloader:
images, labels = batch
# VAEエンコーディング
with torch.no_grad():
latents = vae.encode(images).latent_dist.sample() * 0.18215
# ランダムタイムステップ
t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
# ノイズ追加
noise = torch.randn_like(latents)
noisy_latents = scheduler.add_noise(latents, noise, t)
# ノイズ予測
noise_pred = model(noisy_latents, t, labels)
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
まとめ
DiTは画像生成におけるアーキテクチャパラダイムの転換を牽引しました:
- U-Netの複雑な構造を純粋なTransformerに置き換え
- LLMと同一のスケーリング則が画像生成にも適用されることを実証
- adaLN-Zeroにより深いTransformerの安定的な学習を実現
- SORA、SD3などの後続モデルの基盤アーキテクチャとして定着
クイズ:DiTアーキテクチャ理解度チェック(8問)
Q1. DiTがU-Netの代わりにTransformerを使用する主な理由は?
実証済みのスケーリング則、シンプルなアーキテクチャ、他のモダリティとの統合の容易さのためです。
Q2. DiTにおいて画像はどのようにトークンシーケンスに変換されますか?
VAEで潜在表現(latent)に変換した後、パッチ単位で分割して線形投影します。
Q3. adaLN-Zeroの「Zero」が意味するものは?
gateパラメータを0に初期化することで、学習初期に各ブロックがidentity functionとして動作するようにします。
Q4. DiTで実験された4つの条件付け方式のうち、最終的に選択されたものは?
adaLN-Zeroです。In-context、Cross-Attention、adaLN、adaLN-Zeroの中で最も良いFIDを記録しました。
Q5. パッチサイズが小さいほど、どのようなトレードオフがありますか?
トークン数が増加して画像品質(FID)は向上しますが、演算量(Gflops)が二乗に比例して増加します。
Q6. Classifier-Free GuidanceはDiTでどのように実装されていますか?
クラスラベルのドロップアウトで学習し、推論時に条件付き/無条件予測の重み付き和でサンプリングします。
Q7. DiT-XL/2のImageNet 256x256 FID-50Kスコアは?
2.27で、当時のclass-conditional生成SOTAを達成しました。
Q8. Dynamic DiT(ICLR 2025)の核心アイデアは?
Early Exitメカニズムにより、ノイズレベルに応じて使用するレイヤー数を動的に調整し、推論時間を40%削減します。