Skip to content

필사 모드: Diffusion Transformer(DiT)アーキテクチャ分析:U-NetからTransformerへの転換

日本語
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

はじめに

2022年末に発表された**「Scalable Diffusion Models with Transformers」**(William Peebles & Saining Xie、2023 ICCV)は、画像生成のバックボーンをU-NetからTransformerへ転換する起点となりました。この論文で提案されたDiT(Diffusion Transformer)は、その後OpenAIのSORA、Stability AIのStable Diffusion 3、そして2025年のICLRで発表されたDynamic DiTへと続く核心アーキテクチャとなりました。

背景:なぜU-Netから離れるのか?

U-Netの限界

従来のDDPM、LDM(Latent Diffusion Model)はU-Netをバックボーンとして使用しています。U-Netは画像生成で良好な性能を示しますが、根本的な限界があります:

- **スケーリング則の不在**: モデルを大きくしても性能向上が予測不可能

- **アーキテクチャの複雑さ**: skip connection、多様な解像度のfeature map管理

- **演算の非効率**: 高解像度でメモリ・演算コストが急増

- **他のモダリティとの統合の困難さ**: テキスト、動画などとの統合が不自然

Transformerの利点

一方、Transformer(ViT)は:

- **実証済みのスケーリング則**: パラメータ数に比例した性能向上(LLMで実証済み)

- **シンプルなアーキテクチャ**: Self-attention + FFNの繰り返し

- **モダリティ不問**: テキスト、画像、動画すべてを同一構造で処理可能

- **ハードウェア最適化**: GPU/TPUでの効率的な並列処理

DiTアーキテクチャの詳細

全体パイプライン

DiTはLatent Diffusionフレームワーク上で動作します:

1. 画像をVAEエンコーダで潜在空間(latent space)に変換

2. 潜在表現をパッチに分割(ViTと同様)

3. **DiTブロック**(Transformer)でノイズを予測

4. VAEデコーダで画像を復元

パッチ埋め込み

class PatchEmbed(nn.Module):

"""潜在表現をパッチシーケンスに変換"""

def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):

super().__init__()

self.proj = nn.Conv2d(

in_channels, embed_dim,

kernel_size=patch_size, stride=patch_size

)

def forward(self, x):

x: (B, C, H, W) -> (B, N, D)

例: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)

x = self.proj(x) # (B, D, H/p, W/p)

x = x.flatten(2).transpose(1, 2) # (B, N, D)

return x

入力は256x256の画像をVAEで32x32の潜在表現に変換した後、patch_size=2で分割すると16x16=256個のトークンになります。

条件付けメカニズム — 4つのバリエーション

DiTはタイムステップ(t)とクラスラベル(c)を注入する4つの方式を実験しています:

1. In-context Conditioning

class InContextConditioning(nn.Module):

"""tとcを追加トークンとしてシーケンスに連結"""

def __init__(self, embed_dim):

super().__init__()

self.t_embed = TimestepEmbedder(embed_dim)

self.c_embed = LabelEmbedder(1000, embed_dim)

def forward(self, x, t, c):

t_token = self.t_embed(t).unsqueeze(1) # (B, 1, D)

c_token = self.c_embed(c).unsqueeze(1) # (B, 1, D)

x = torch.cat([t_token, c_token, x], dim=1) # (B, N+2, D)

return x

2. Cross-Attention

class CrossAttentionBlock(nn.Module):

"""tとcをcross-attentionで注入"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)

self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)

self.mlp = MLP(embed_dim)

def forward(self, x, cond):

x = x + self.self_attn(x, x, x)[0]

x = x + self.cross_attn(x, cond, cond)[0]

x = x + self.mlp(x)

return x

3. Adaptive Layer Norm(adaLN)

class AdaLNBlock(nn.Module):

"""条件に応じてLayerNormのscale/shiftを調整"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.attn = nn.MultiheadAttention(embed_dim, num_heads)

self.adaLN_modulation = nn.Sequential(

nn.SiLU(),

nn.Linear(embed_dim, 2 * embed_dim)

)

def forward(self, x, cond):

shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)

x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = x + self.attn(x, x, x)[0]

return x

4. adaLN-Zero(最終選択)

class DiTBlock(nn.Module):

"""adaLN-Zero: 初期化時に残差接続をゼロから開始"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

self.mlp = nn.Sequential(

nn.Linear(embed_dim, 4 * embed_dim),

nn.GELU(approximate="tanh"),

nn.Linear(4 * embed_dim, embed_dim),

)

6つのパラメータ: gamma1, beta1, alpha1, gamma2, beta2, alpha2

self.adaLN_modulation = nn.Sequential(

nn.SiLU(),

nn.Linear(embed_dim, 6 * embed_dim)

)

alphaを0に初期化 -> 学習初期にidentity functionとして動作

nn.init.zeros_(self.adaLN_modulation[-1].weight)

nn.init.zeros_(self.adaLN_modulation[-1].bias)

def forward(self, x, cond):

cond: (B, D) - タイムステップ + クラス埋め込みの和

shift_msa, scale_msa, gate_msa, \

shift_mlp, scale_mlp, gate_mlp = \

self.adaLN_modulation(cond).chunk(6, dim=-1)

Self-Attention with adaLN

h = self.norm1(x)

h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)

h = self.attn(h, h, h)[0]

x = x + gate_msa.unsqueeze(1) * h # gateは0から開始

FFN with adaLN

h = self.norm2(x)

h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

h = self.mlp(h)

x = x + gate_mlp.unsqueeze(1) * h # gateは0から開始

return x

**adaLN-Zeroの核心**: gateパラメータを0に初期化することで、学習初期に各DiTブロックがidentity functionとして動作します。これにより深いネットワークの学習安定性が大幅に向上します。

DiTモデル全体

class DiT(nn.Module):

def __init__(

self,

input_size=32,

patch_size=2,

in_channels=4,

embed_dim=1152,

depth=28,

num_heads=16,

num_classes=1000,

):

super().__init__()

self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)

num_patches = (input_size // patch_size) ** 2 # 256

位置埋め込み

self.pos_embed = nn.Parameter(

torch.zeros(1, num_patches, embed_dim)

)

タイムステップ & クラス埋め込み

self.t_embedder = TimestepEmbedder(embed_dim)

self.y_embedder = LabelEmbedder(num_classes, embed_dim)

DiTブロック

self.blocks = nn.ModuleList([

DiTBlock(embed_dim, num_heads) for _ in range(depth)

])

最終レイヤー

self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)

def forward(self, x, t, y):

パッチ埋め込み + 位置埋め込み

x = self.patch_embed(x) + self.pos_embed

条件埋め込み

t_emb = self.t_embedder(t) # (B, D)

y_emb = self.y_embedder(y) # (B, D)

cond = t_emb + y_emb # (B, D)

DiTブロックを通過

for block in self.blocks:

x = block(x, cond)

ノイズと分散を予測

x = self.final_layer(x, cond) # (B, N, 2*C*p*p)

return x

モデルバリエーションとスケーリング

| モデル | depth | embed_dim | heads | パラメータ | FID-50K |

| -------- | ----- | --------- | ----- | ---------- | -------- |

| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |

| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |

| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |

| DiT-XL/2 | 28 | 1152 | 16 | 675M | **2.27** |

核心的発見:**DiTはLLMと同一のスケーリング則に従います。** モデルサイズを大きくするほどFIDが一貫して改善されます。

パッチサイズの影響

| モデル | patch_size | トークン数 | Gflops | FID |

| -------- | ---------- | ---------- | ------ | ---- |

| DiT-XL/8 | 8 | 16 | 119 | 14.2 |

| DiT-XL/4 | 4 | 64 | 525 | 5.11 |

| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |

パッチサイズが小さいほど(トークン数が多いほど)品質は良くなりますが、演算量が大幅に増加します。

Classifier-Free Guidanceの適用

def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):

"""Classifier-Free Guidanceでサンプリング"""

条件付き + 無条件予測を同時に

z_combined = torch.cat([z, z], dim=0)

y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])

for t in reversed(range(1000)):

t_batch = torch.full((z_combined.shape[0],), t, device=z.device)

noise_pred = model(z_combined, t_batch, y_combined)

CFG適用

cond_pred, uncond_pred = noise_pred.chunk(2)

guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)

DDPM step

z = ddpm_step(z, guided_pred, t)

return z

CFG scale=1.5でDiT-XL/2はFID **2.27**を達成し、当時のImageNet 256x256でSOTAを記録しました。

DiTの後続への影響

SORA(OpenAI、2024)

- DiTを動画に拡張した**Spatial-Temporal DiT**

- 3Dパッチ埋め込み(空間 + 時間)

- 可変解像度・長さに対応

Stable Diffusion 3(Stability AI、2024)

- DiT + Flow Matching

- MM-DiT:テキストと画像を同一Transformerで処理

- 別途のtext encoderなしでjoint attention

Dynamic DiT(ICLR 2025)

- Early Exitメカニズムで推論を効率化

- ノイズが大きい初期ステップでは浅いレイヤー、ディテール生成時に深いレイヤーを使用

- 従来DiT比40%の推論時間削減

PyTorch実装チュートリアル

DiT公式コードをクローン

git clone https://github.com/facebookresearch/DiT.git

cd DiT

環境設定

pip install torch torchvision diffusers accelerate

事前学習モデルのダウンロード

python download.py DiT-XL/2

サンプル生成

python sample.py \

--model DiT-XL/2 \

--image-size 256 \

--num-sampling-steps 250 \

--seed 42 \

--class-labels 207 360 387 974 88 979 417 279

カスタムトレーニング

from accelerate import Accelerator

from torch.utils.data import DataLoader

accelerator = Accelerator()

モデル初期化

model = DiT(

input_size=32, # VAE latent size

patch_size=2,

in_channels=4,

embed_dim=1152,

depth=28,

num_heads=16,

num_classes=1000,

)

optimizer = torch.optim.AdamW(

model.parameters(),

lr=1e-4,

weight_decay=0.0

)

model, optimizer, dataloader = accelerator.prepare(

model, optimizer, dataloader

)

学習ループ

for epoch in range(epochs):

for batch in dataloader:

images, labels = batch

VAEエンコーディング

with torch.no_grad():

latents = vae.encode(images).latent_dist.sample() * 0.18215

ランダムタイムステップ

t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)

ノイズ追加

noise = torch.randn_like(latents)

noisy_latents = scheduler.add_noise(latents, noise, t)

ノイズ予測

noise_pred = model(noisy_latents, t, labels)

loss = F.mse_loss(noise_pred, noise)

accelerator.backward(loss)

optimizer.step()

optimizer.zero_grad()

まとめ

DiTは画像生成における**アーキテクチャパラダイムの転換**を牽引しました:

- U-Netの複雑な構造を純粋なTransformerに置き換え

- LLMと同一の**スケーリング則**が画像生成にも適用されることを実証

- adaLN-Zeroにより深いTransformerの安定的な学習を実現

- SORA、SD3などの後続モデルの基盤アーキテクチャとして定着

**Q1. DiTがU-Netの代わりにTransformerを使用する主な理由は?**

実証済みのスケーリング則、シンプルなアーキテクチャ、他のモダリティとの統合の容易さのためです。

**Q2. DiTにおいて画像はどのようにトークンシーケンスに変換されますか?**

VAEで潜在表現(latent)に変換した後、パッチ単位で分割して線形投影します。

**Q3. adaLN-Zeroの「Zero」が意味するものは?**

gateパラメータを0に初期化することで、学習初期に各ブロックがidentity functionとして動作するようにします。

**Q4. DiTで実験された4つの条件付け方式のうち、最終的に選択されたものは?**

adaLN-Zeroです。In-context、Cross-Attention、adaLN、adaLN-Zeroの中で最も良いFIDを記録しました。

**Q5. パッチサイズが小さいほど、どのようなトレードオフがありますか?**

トークン数が増加して画像品質(FID)は向上しますが、演算量(Gflops)が二乗に比例して増加します。

**Q6. Classifier-Free GuidanceはDiTでどのように実装されていますか?**

クラスラベルのドロップアウトで学習し、推論時に条件付き/無条件予測の重み付き和でサンプリングします。

**Q7. DiT-XL/2のImageNet 256x256 FID-50Kスコアは?**

2.27で、当時のclass-conditional生成SOTAを達成しました。

**Q8. Dynamic DiT(ICLR 2025)の核心アイデアは?**

Early Exitメカニズムにより、ノイズレベルに応じて使用するレイヤー数を動的に調整し、推論時間を40%削減します。

クイズ

Q1: 「Diffusion

Transformer(DiT)アーキテクチャ分析:U-NetからTransformerへの転換」の主なトピックは何ですか?

Scalable Diffusion Models with

Transformers(DiT)論文を分析します。U-Netベースの拡散モデルの限界を超えてTransformerバックボーンへ転換した背景、adaLN-Zero条件付け、スケーリング則、SORA/DALL-E

3への影響まで解説します。

U-Netの限界 従来のDDPM、LDM(Latent Diffusion

Model)はU-Netをバックボーンとして使用しています。U-Netは画像生成で良好な性能を示しますが、根本的な限界があります:

スケーリング則の不在: モデルを大きくしても性能向上が予測不可能 アーキテクチャの複雑さ: skip

connection、多様な解像度のfeature map管理 演算の非効率: 高解像度でメモリ・演算コストが急増

他のモダリティとの統合の困難さ: テキスト、動画などとの統合が不自然 Transformerの利点

一方、Transformer(Vi...

全体パイプライン DiTはLatent Diffusionフレームワーク上で動作します:

画像をVAEエンコーダで潜在空間(latent space)に変換 潜在表現をパッチに分割(ViTと同様)

DiTブロック(Transformer)でノイズを予測 VAEデコーダで画像を復元 パッチ埋め込み

入力は256x256の画像をVAEで32x32の潜在表現に変換した後、patch_size=2で分割すると16x16=256個のトークンになります。

核心的発見:DiTはLLMと同一のスケーリング則に従います。

モデルサイズを大きくするほどFIDが一貫して改善されます。 パッチサイズの影響

パッチサイズが小さいほど(トークン数が多いほど)品質は良くなりますが、演算量が大幅に増加します。

CFG scale=1.5でDiT-XL/2はFID 2.27を達成し、当時のImageNet 256x256でSOTAを記録しました。

현재 단락 (1/249)

2022年末に発表された**「Scalable Diffusion Models with Transformers」**(William Peebles & Saining Xie、2023 ICC...

작성 글자: 0원문 글자: 9,309작성 단락: 0/249