- Published on
Diffusion Transformer(DiT) 아키텍처 분석: U-Net에서 Transformer로의 전환
- Authors
- Name
들어가며
2022년 말 발표된 "Scalable Diffusion Models with Transformers" (William Peebles & Saining Xie, 2023 ICCV)는 이미지 생성의 백본을 U-Net에서 Transformer로 전환하는 시발점이 되었습니다. 이 논문에서 제안한 DiT(Diffusion Transformer)는 이후 OpenAI의 SORA, Stability AI의 Stable Diffusion 3, 그리고 2025년 ICLR에서 발표된 Dynamic DiT까지 이어지는 핵심 아키텍처가 되었습니다.
배경: 왜 U-Net을 버리는가?
U-Net의 한계
기존 DDPM, LDM(Latent Diffusion Model)은 U-Net을 백본으로 사용합니다. U-Net은 이미지 생성에서 좋은 성능을 보이지만 근본적인 한계가 있습니다:
- 스케일링 법칙 부재: 모델을 키워도 성능 향상이 예측 불가능
- 아키텍처 복잡성: skip connection, 다양한 해상도의 feature map 관리
- 연산 비효율: 고해상도에서 메모리/연산 비용이 급증
- 다른 모달리티와의 통합 어려움: 텍스트, 비디오 등과의 통합이 비자연적
Transformer의 장점
반면 Transformer(ViT)는:
- 검증된 스케일링 법칙: 파라미터 수에 비례한 성능 향상 (LLM에서 입증)
- 단순한 아키텍처: Self-attention + FFN의 반복
- 모달리티 불문: 텍스트, 이미지, 비디오 모두 동일 구조로 처리 가능
- 하드웨어 최적화: GPU/TPU에서의 효율적인 병렬 처리
DiT 아키텍처 상세
전체 파이프라인
DiT는 Latent Diffusion 프레임워크 위에서 동작합니다:
- 이미지를 VAE 인코더로 잠재 공간(latent space)으로 변환
- 잠재 표현을 패치로 분할 (ViT와 동일)
- DiT 블록(Transformer)으로 노이즈 예측
- VAE 디코더로 이미지 복원
패치 임베딩
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
"""잠재 표현을 패치 시퀀스로 변환"""
def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
super().__init__()
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W) -> (B, N, D)
# 예: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
x = self.proj(x) # (B, D, H/p, W/p)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
return x
입력은 256x256 이미지를 VAE로 32x32 잠재 표현으로 변환한 후, patch_size=2로 분할하면 16x16=256개의 토큰이 됩니다.
조건화 메커니즘 — 4가지 변형
DiT는 타임스텝(t)과 클래스 레이블(c)을 주입하는 4가지 방식을 실험합니다:
1. In-context Conditioning
class InContextConditioning(nn.Module):
"""t와 c를 추가 토큰으로 시퀀스에 연결"""
def __init__(self, embed_dim):
super().__init__()
self.t_embed = TimestepEmbedder(embed_dim)
self.c_embed = LabelEmbedder(1000, embed_dim)
def forward(self, x, t, c):
t_token = self.t_embed(t).unsqueeze(1) # (B, 1, D)
c_token = self.c_embed(c).unsqueeze(1) # (B, 1, D)
x = torch.cat([t_token, c_token, x], dim=1) # (B, N+2, D)
return x
2. Cross-Attention
class CrossAttentionBlock(nn.Module):
"""t와 c를 cross-attention으로 주입"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = MLP(embed_dim)
def forward(self, x, cond):
x = x + self.self_attn(x, x, x)[0]
x = x + self.cross_attn(x, cond, cond)[0]
x = x + self.mlp(x)
return x
3. Adaptive Layer Norm (adaLN)
class AdaLNBlock(nn.Module):
"""조건에 따라 LayerNorm의 scale/shift를 조절"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 2 * embed_dim)
)
def forward(self, x, cond):
shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = x + self.attn(x, x, x)[0]
return x
4. adaLN-Zero (최종 선택)
class DiTBlock(nn.Module):
"""adaLN-Zero: 초기화 시 잔차 연결을 0으로 시작"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(approximate="tanh"),
nn.Linear(4 * embed_dim, embed_dim),
)
# 6개 파라미터: gamma1, beta1, alpha1, gamma2, beta2, alpha2
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 6 * embed_dim)
)
# alpha를 0으로 초기화 -> 학습 초기에 identity function
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, cond):
# cond: (B, D) - 타임스텝 + 클래스 임베딩의 합
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(cond).chunk(6, dim=-1)
# Self-Attention with adaLN
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.attn(h, h, h)[0]
x = x + gate_msa.unsqueeze(1) * h # gate=0으로 시작
# FFN with adaLN
h = self.norm2(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
x = x + gate_mlp.unsqueeze(1) * h # gate=0으로 시작
return x
adaLN-Zero의 핵심: gate 파라미터를 0으로 초기화하여 학습 초기에 각 DiT 블록이 identity function으로 동작합니다. 이는 깊은 네트워크의 학습 안정성을 크게 향상시킵니다.
전체 DiT 모델
class DiT(nn.Module):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
):
super().__init__()
self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
num_patches = (input_size // patch_size) ** 2 # 256
# 위치 임베딩
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim)
)
# 타임스텝 & 클래스 임베딩
self.t_embedder = TimestepEmbedder(embed_dim)
self.y_embedder = LabelEmbedder(num_classes, embed_dim)
# DiT 블록
self.blocks = nn.ModuleList([
DiTBlock(embed_dim, num_heads) for _ in range(depth)
])
# 최종 레이어
self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)
def forward(self, x, t, y):
# 패치 임베딩 + 위치 임베딩
x = self.patch_embed(x) + self.pos_embed
# 조건 임베딩
t_emb = self.t_embedder(t) # (B, D)
y_emb = self.y_embedder(y) # (B, D)
cond = t_emb + y_emb # (B, D)
# DiT 블록 통과
for block in self.blocks:
x = block(x, cond)
# 노이즈와 분산 예측
x = self.final_layer(x, cond) # (B, N, 2*C*p*p)
return x
모델 변형과 스케일링
| 모델 | depth | embed_dim | heads | 파라미터 | FID-50K |
|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 2.27 |
핵심 발견: DiT는 LLM과 동일한 스케일링 법칙을 따릅니다. 모델 크기를 키울수록 FID가 일관되게 개선됩니다.
패치 크기의 영향
| 모델 | patch_size | 토큰 수 | Gflops | FID |
|---|---|---|---|---|
| DiT-XL/8 | 8 | 16 | 119 | 14.2 |
| DiT-XL/4 | 4 | 64 | 525 | 5.11 |
| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |
패치 크기가 작을수록(토큰 수가 많을수록) 성능이 좋지만 연산량이 크게 증가합니다.
Classifier-Free Guidance 적용
def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
"""Classifier-Free Guidance로 샘플링"""
# 조건부 + 무조건부 예측을 동시에
z_combined = torch.cat([z, z], dim=0)
y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])
for t in reversed(range(1000)):
t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
noise_pred = model(z_combined, t_batch, y_combined)
# CFG 적용
cond_pred, uncond_pred = noise_pred.chunk(2)
guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
# DDPM step
z = ddpm_step(z, guided_pred, t)
return z
CFG scale=1.5에서 DiT-XL/2는 FID 2.27을 달성하여 당시 ImageNet 256x256에서 SOTA를 기록했습니다.
DiT의 후속 영향
SORA (OpenAI, 2024)
- DiT를 비디오로 확장한 Spatial-Temporal DiT
- 3D 패치 임베딩 (공간 + 시간)
- 가변 해상도/길이 지원
Stable Diffusion 3 (Stability AI, 2024)
- DiT + Flow Matching
- MM-DiT: 텍스트와 이미지를 동일 Transformer에서 처리
- 별도의 text encoder 없이 joint attention
Dynamic DiT (ICLR 2025)
- Early Exit 메커니즘으로 추론 효율화
- 노이즈가 큰 초기 스텝에서는 얕은 레이어, 디테일 생성 시 깊은 레이어 사용
- 기존 DiT 대비 40% 추론 시간 절감
PyTorch 구현 실습
# DiT 공식 코드 클론
git clone https://github.com/facebookresearch/DiT.git
cd DiT
# 환경 설정
pip install torch torchvision diffusers accelerate
# 사전학습 모델 다운로드
python download.py DiT-XL/2
# 샘플 생성
python sample.py \
--model DiT-XL/2 \
--image-size 256 \
--num-sampling-steps 250 \
--seed 42 \
--class-labels 207 360 387 974 88 979 417 279
커스텀 학습
from accelerate import Accelerator
from torch.utils.data import DataLoader
accelerator = Accelerator()
# 모델 초기화
model = DiT(
input_size=32, # VAE latent size
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.0
)
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
# 학습 루프
for epoch in range(epochs):
for batch in dataloader:
images, labels = batch
# VAE 인코딩
with torch.no_grad():
latents = vae.encode(images).latent_dist.sample() * 0.18215
# 랜덤 타임스텝
t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
# 노이즈 추가
noise = torch.randn_like(latents)
noisy_latents = scheduler.add_noise(latents, noise, t)
# 노이즈 예측
noise_pred = model(noisy_latents, t, labels)
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
정리
DiT는 이미지 생성에서 아키텍처 패러다임의 전환을 이끌었습니다:
- U-Net의 복잡한 구조를 순수 Transformer로 대체
- LLM과 동일한 스케일링 법칙이 이미지 생성에도 적용됨을 입증
- adaLN-Zero로 깊은 Transformer의 안정적 학습 가능
- SORA, SD3 등 후속 모델의 기반 아키텍처로 자리잡음
✅ 퀴즈: DiT 아키텍처 이해도 점검 (8문제)
Q1. DiT가 U-Net 대신 Transformer를 사용하는 주된 이유는?
검증된 스케일링 법칙, 단순한 아키텍처, 다른 모달리티와의 통합 용이성 때문입니다.
Q2. DiT에서 이미지는 어떻게 토큰 시퀀스로 변환되나요?
VAE로 잠재 표현(latent)으로 변환한 후, 패치 단위로 분할하여 선형 투영합니다.
Q3. adaLN-Zero의 "Zero"가 의미하는 것은?
gate 파라미터를 0으로 초기화하여 학습 초기에 각 블록이 identity function으로 동작하게 합니다.
Q4. DiT에서 실험한 4가지 조건화 방식 중 최종 선택된 것은?
adaLN-Zero입니다. In-context, Cross-Attention, adaLN, adaLN-Zero 중 가장 좋은 FID를 기록했습니다.
Q5. 패치 크기가 작을수록 어떤 트레이드오프가 있나요?
토큰 수가 증가하여 이미지 품질(FID)은 좋아지지만, 연산량(Gflops)이 제곱에 비례하여 증가합니다.
Q6. Classifier-Free Guidance는 DiT에서 어떻게 구현되나요?
클래스 레이블 드롭아웃으로 학습하고, 추론 시 조건부/무조건부 예측의 가중 합으로 샘플링합니다.
Q7. DiT-XL/2의 ImageNet 256x256 FID-50K 점수는?
2.27로, 당시 class-conditional 생성 SOTA를 달성했습니다.
Q8. Dynamic DiT(ICLR 2025)의 핵심 아이디어는?
Early Exit 메커니즘으로 노이즈 수준에 따라 사용하는 레이어 수를 동적으로 조절하여 추론 시간을 40% 절감합니다.