들어가며
2022년 말 발표된 **"Scalable Diffusion Models with Transformers"** (William Peebles & Saining Xie, 2023 ICCV)는 이미지 생성의 백본을 U-Net에서 Transformer로 전환하는 시발점이 되었습니다. 이 논문에서 제안한 DiT(Diffusion Transformer)는 이후 OpenAI의 SORA, Stability AI의 Stable Diffusion 3, 그리고 2025년 ICLR에서 발표된 Dynamic DiT까지 이어지는 핵심 아키텍처가 되었습니다.
배경: 왜 U-Net을 버리는가?
U-Net의 한계
기존 DDPM, LDM(Latent Diffusion Model)은 U-Net을 백본으로 사용합니다. U-Net은 이미지 생성에서 좋은 성능을 보이지만 근본적인 한계가 있습니다:
- **스케일링 법칙 부재**: 모델을 키워도 성능 향상이 예측 불가능
- **아키텍처 복잡성**: skip connection, 다양한 해상도의 feature map 관리
- **연산 비효율**: 고해상도에서 메모리/연산 비용이 급증
- **다른 모달리티와의 통합 어려움**: 텍스트, 비디오 등과의 통합이 비자연적
Transformer의 장점
반면 Transformer(ViT)는:
- **검증된 스케일링 법칙**: 파라미터 수에 비례한 성능 향상 (LLM에서 입증)
- **단순한 아키텍처**: Self-attention + FFN의 반복
- **모달리티 불문**: 텍스트, 이미지, 비디오 모두 동일 구조로 처리 가능
- **하드웨어 최적화**: GPU/TPU에서의 효율적인 병렬 처리
DiT 아키텍처 상세
전체 파이프라인
DiT는 Latent Diffusion 프레임워크 위에서 동작합니다:
1. 이미지를 VAE 인코더로 잠재 공간(latent space)으로 변환
2. 잠재 표현을 패치로 분할 (ViT와 동일)
3. **DiT 블록**(Transformer)으로 노이즈 예측
4. VAE 디코더로 이미지 복원
패치 임베딩
class PatchEmbed(nn.Module):
"""잠재 표현을 패치 시퀀스로 변환"""
def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
super().__init__()
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
x: (B, C, H, W) -> (B, N, D)
예: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
x = self.proj(x) # (B, D, H/p, W/p)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
return x
입력은 256x256 이미지를 VAE로 32x32 잠재 표현으로 변환한 후, patch_size=2로 분할하면 16x16=256개의 토큰이 됩니다.
조건화 메커니즘 — 4가지 변형
DiT는 타임스텝(t)과 클래스 레이블(c)을 주입하는 4가지 방식을 실험합니다:
1. In-context Conditioning
class InContextConditioning(nn.Module):
"""t와 c를 추가 토큰으로 시퀀스에 연결"""
def __init__(self, embed_dim):
super().__init__()
self.t_embed = TimestepEmbedder(embed_dim)
self.c_embed = LabelEmbedder(1000, embed_dim)
def forward(self, x, t, c):
t_token = self.t_embed(t).unsqueeze(1) # (B, 1, D)
c_token = self.c_embed(c).unsqueeze(1) # (B, 1, D)
x = torch.cat([t_token, c_token, x], dim=1) # (B, N+2, D)
return x
2. Cross-Attention
class CrossAttentionBlock(nn.Module):
"""t와 c를 cross-attention으로 주입"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = MLP(embed_dim)
def forward(self, x, cond):
x = x + self.self_attn(x, x, x)[0]
x = x + self.cross_attn(x, cond, cond)[0]
x = x + self.mlp(x)
return x
3. Adaptive Layer Norm (adaLN)
class AdaLNBlock(nn.Module):
"""조건에 따라 LayerNorm의 scale/shift를 조절"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 2 * embed_dim)
)
def forward(self, x, cond):
shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = x + self.attn(x, x, x)[0]
return x
4. adaLN-Zero (최종 선택)
class DiTBlock(nn.Module):
"""adaLN-Zero: 초기화 시 잔차 연결을 0으로 시작"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(approximate="tanh"),
nn.Linear(4 * embed_dim, embed_dim),
)
6개 파라미터: gamma1, beta1, alpha1, gamma2, beta2, alpha2
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim, 6 * embed_dim)
)
alpha를 0으로 초기화 -> 학습 초기에 identity function
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, cond):
cond: (B, D) - 타임스텝 + 클래스 임베딩의 합
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(cond).chunk(6, dim=-1)
Self-Attention with adaLN
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.attn(h, h, h)[0]
x = x + gate_msa.unsqueeze(1) * h # gate=0으로 시작
FFN with adaLN
h = self.norm2(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
x = x + gate_mlp.unsqueeze(1) * h # gate=0으로 시작
return x
**adaLN-Zero의 핵심**: gate 파라미터를 0으로 초기화하여 학습 초기에 각 DiT 블록이 identity function으로 동작합니다. 이는 깊은 네트워크의 학습 안정성을 크게 향상시킵니다.
전체 DiT 모델
class DiT(nn.Module):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
):
super().__init__()
self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
num_patches = (input_size // patch_size) ** 2 # 256
위치 임베딩
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim)
)
타임스텝 & 클래스 임베딩
self.t_embedder = TimestepEmbedder(embed_dim)
self.y_embedder = LabelEmbedder(num_classes, embed_dim)
DiT 블록
self.blocks = nn.ModuleList([
DiTBlock(embed_dim, num_heads) for _ in range(depth)
])
최종 레이어
self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)
def forward(self, x, t, y):
패치 임베딩 + 위치 임베딩
x = self.patch_embed(x) + self.pos_embed
조건 임베딩
t_emb = self.t_embedder(t) # (B, D)
y_emb = self.y_embedder(y) # (B, D)
cond = t_emb + y_emb # (B, D)
DiT 블록 통과
for block in self.blocks:
x = block(x, cond)
노이즈와 분산 예측
x = self.final_layer(x, cond) # (B, N, 2*C*p*p)
return x
모델 변형과 스케일링
| 모델 | depth | embed_dim | heads | 파라미터 | FID-50K |
| -------- | ----- | --------- | ----- | -------- | -------- |
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | **2.27** |
핵심 발견: **DiT는 LLM과 동일한 스케일링 법칙을 따릅니다.** 모델 크기를 키울수록 FID가 일관되게 개선됩니다.
패치 크기의 영향
| 모델 | patch_size | 토큰 수 | Gflops | FID |
| -------- | ---------- | ------- | ------ | ---- |
| DiT-XL/8 | 8 | 16 | 119 | 14.2 |
| DiT-XL/4 | 4 | 64 | 525 | 5.11 |
| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |
패치 크기가 작을수록(토큰 수가 많을수록) 성능이 좋지만 연산량이 크게 증가합니다.
Classifier-Free Guidance 적용
def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
"""Classifier-Free Guidance로 샘플링"""
조건부 + 무조건부 예측을 동시에
z_combined = torch.cat([z, z], dim=0)
y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])
for t in reversed(range(1000)):
t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
noise_pred = model(z_combined, t_batch, y_combined)
CFG 적용
cond_pred, uncond_pred = noise_pred.chunk(2)
guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
DDPM step
z = ddpm_step(z, guided_pred, t)
return z
CFG scale=1.5에서 DiT-XL/2는 FID **2.27**을 달성하여 당시 ImageNet 256x256에서 SOTA를 기록했습니다.
DiT의 후속 영향
SORA (OpenAI, 2024)
- DiT를 비디오로 확장한 **Spatial-Temporal DiT**
- 3D 패치 임베딩 (공간 + 시간)
- 가변 해상도/길이 지원
Stable Diffusion 3 (Stability AI, 2024)
- DiT + Flow Matching
- MM-DiT: 텍스트와 이미지를 동일 Transformer에서 처리
- 별도의 text encoder 없이 joint attention
Dynamic DiT (ICLR 2025)
- Early Exit 메커니즘으로 추론 효율화
- 노이즈가 큰 초기 스텝에서는 얕은 레이어, 디테일 생성 시 깊은 레이어 사용
- 기존 DiT 대비 40% 추론 시간 절감
PyTorch 구현 실습
DiT 공식 코드 클론
git clone https://github.com/facebookresearch/DiT.git
cd DiT
환경 설정
pip install torch torchvision diffusers accelerate
사전학습 모델 다운로드
python download.py DiT-XL/2
샘플 생성
python sample.py \
--model DiT-XL/2 \
--image-size 256 \
--num-sampling-steps 250 \
--seed 42 \
--class-labels 207 360 387 974 88 979 417 279
커스텀 학습
from accelerate import Accelerator
from torch.utils.data import DataLoader
accelerator = Accelerator()
모델 초기화
model = DiT(
input_size=32, # VAE latent size
patch_size=2,
in_channels=4,
embed_dim=1152,
depth=28,
num_heads=16,
num_classes=1000,
)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.0
)
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
학습 루프
for epoch in range(epochs):
for batch in dataloader:
images, labels = batch
VAE 인코딩
with torch.no_grad():
latents = vae.encode(images).latent_dist.sample() * 0.18215
랜덤 타임스텝
t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
노이즈 추가
noise = torch.randn_like(latents)
noisy_latents = scheduler.add_noise(latents, noise, t)
노이즈 예측
noise_pred = model(noisy_latents, t, labels)
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
정리
DiT는 이미지 생성에서 **아키텍처 패러다임의 전환**을 이끌었습니다:
- U-Net의 복잡한 구조를 순수 Transformer로 대체
- LLM과 동일한 **스케일링 법칙**이 이미지 생성에도 적용됨을 입증
- adaLN-Zero로 깊은 Transformer의 안정적 학습 가능
- SORA, SD3 등 후속 모델의 기반 아키텍처로 자리잡음
**Q1. DiT가 U-Net 대신 Transformer를 사용하는 주된 이유는?**
검증된 스케일링 법칙, 단순한 아키텍처, 다른 모달리티와의 통합 용이성 때문입니다.
**Q2. DiT에서 이미지는 어떻게 토큰 시퀀스로 변환되나요?**
VAE로 잠재 표현(latent)으로 변환한 후, 패치 단위로 분할하여 선형 투영합니다.
**Q3. adaLN-Zero의 "Zero"가 의미하는 것은?**
gate 파라미터를 0으로 초기화하여 학습 초기에 각 블록이 identity function으로 동작하게 합니다.
**Q4. DiT에서 실험한 4가지 조건화 방식 중 최종 선택된 것은?**
adaLN-Zero입니다. In-context, Cross-Attention, adaLN, adaLN-Zero 중 가장 좋은 FID를 기록했습니다.
**Q5. 패치 크기가 작을수록 어떤 트레이드오프가 있나요?**
토큰 수가 증가하여 이미지 품질(FID)은 좋아지지만, 연산량(Gflops)이 제곱에 비례하여 증가합니다.
**Q6. Classifier-Free Guidance는 DiT에서 어떻게 구현되나요?**
클래스 레이블 드롭아웃으로 학습하고, 추론 시 조건부/무조건부 예측의 가중 합으로 샘플링합니다.
**Q7. DiT-XL/2의 ImageNet 256x256 FID-50K 점수는?**
2.27로, 당시 class-conditional 생성 SOTA를 달성했습니다.
**Q8. Dynamic DiT(ICLR 2025)의 핵심 아이디어는?**
Early Exit 메커니즘으로 노이즈 수준에 따라 사용하는 레이어 수를 동적으로 조절하여 추론 시간을 40% 절감합니다.
현재 단락 (1/230)
2022년 말 발표된 **"Scalable Diffusion Models with Transformers"** (William Peebles & Saining Xie, 2023 I...