Skip to content

필사 모드: Diffusion Transformer(DiT) 아키텍처 분석: U-Net에서 Transformer로의 전환

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

들어가며

2022년 말 발표된 **"Scalable Diffusion Models with Transformers"** (William Peebles & Saining Xie, 2023 ICCV)는 이미지 생성의 백본을 U-Net에서 Transformer로 전환하는 시발점이 되었습니다. 이 논문에서 제안한 DiT(Diffusion Transformer)는 이후 OpenAI의 SORA, Stability AI의 Stable Diffusion 3, 그리고 2025년 ICLR에서 발표된 Dynamic DiT까지 이어지는 핵심 아키텍처가 되었습니다.

배경: 왜 U-Net을 버리는가?

U-Net의 한계

기존 DDPM, LDM(Latent Diffusion Model)은 U-Net을 백본으로 사용합니다. U-Net은 이미지 생성에서 좋은 성능을 보이지만 근본적인 한계가 있습니다:

- **스케일링 법칙 부재**: 모델을 키워도 성능 향상이 예측 불가능

- **아키텍처 복잡성**: skip connection, 다양한 해상도의 feature map 관리

- **연산 비효율**: 고해상도에서 메모리/연산 비용이 급증

- **다른 모달리티와의 통합 어려움**: 텍스트, 비디오 등과의 통합이 비자연적

Transformer의 장점

반면 Transformer(ViT)는:

- **검증된 스케일링 법칙**: 파라미터 수에 비례한 성능 향상 (LLM에서 입증)

- **단순한 아키텍처**: Self-attention + FFN의 반복

- **모달리티 불문**: 텍스트, 이미지, 비디오 모두 동일 구조로 처리 가능

- **하드웨어 최적화**: GPU/TPU에서의 효율적인 병렬 처리

DiT 아키텍처 상세

전체 파이프라인

DiT는 Latent Diffusion 프레임워크 위에서 동작합니다:

1. 이미지를 VAE 인코더로 잠재 공간(latent space)으로 변환

2. 잠재 표현을 패치로 분할 (ViT와 동일)

3. **DiT 블록**(Transformer)으로 노이즈 예측

4. VAE 디코더로 이미지 복원

패치 임베딩

class PatchEmbed(nn.Module):

"""잠재 표현을 패치 시퀀스로 변환"""

def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):

super().__init__()

self.proj = nn.Conv2d(

in_channels, embed_dim,

kernel_size=patch_size, stride=patch_size

)

def forward(self, x):

x: (B, C, H, W) -> (B, N, D)

예: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)

x = self.proj(x) # (B, D, H/p, W/p)

x = x.flatten(2).transpose(1, 2) # (B, N, D)

return x

입력은 256x256 이미지를 VAE로 32x32 잠재 표현으로 변환한 후, patch_size=2로 분할하면 16x16=256개의 토큰이 됩니다.

조건화 메커니즘 — 4가지 변형

DiT는 타임스텝(t)과 클래스 레이블(c)을 주입하는 4가지 방식을 실험합니다:

1. In-context Conditioning

class InContextConditioning(nn.Module):

"""t와 c를 추가 토큰으로 시퀀스에 연결"""

def __init__(self, embed_dim):

super().__init__()

self.t_embed = TimestepEmbedder(embed_dim)

self.c_embed = LabelEmbedder(1000, embed_dim)

def forward(self, x, t, c):

t_token = self.t_embed(t).unsqueeze(1) # (B, 1, D)

c_token = self.c_embed(c).unsqueeze(1) # (B, 1, D)

x = torch.cat([t_token, c_token, x], dim=1) # (B, N+2, D)

return x

2. Cross-Attention

class CrossAttentionBlock(nn.Module):

"""t와 c를 cross-attention으로 주입"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)

self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)

self.mlp = MLP(embed_dim)

def forward(self, x, cond):

x = x + self.self_attn(x, x, x)[0]

x = x + self.cross_attn(x, cond, cond)[0]

x = x + self.mlp(x)

return x

3. Adaptive Layer Norm (adaLN)

class AdaLNBlock(nn.Module):

"""조건에 따라 LayerNorm의 scale/shift를 조절"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.attn = nn.MultiheadAttention(embed_dim, num_heads)

self.adaLN_modulation = nn.Sequential(

nn.SiLU(),

nn.Linear(embed_dim, 2 * embed_dim)

)

def forward(self, x, cond):

shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)

x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = x + self.attn(x, x, x)[0]

return x

4. adaLN-Zero (최종 선택)

class DiTBlock(nn.Module):

"""adaLN-Zero: 초기화 시 잔차 연결을 0으로 시작"""

def __init__(self, embed_dim, num_heads):

super().__init__()

self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)

self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

self.mlp = nn.Sequential(

nn.Linear(embed_dim, 4 * embed_dim),

nn.GELU(approximate="tanh"),

nn.Linear(4 * embed_dim, embed_dim),

)

6개 파라미터: gamma1, beta1, alpha1, gamma2, beta2, alpha2

self.adaLN_modulation = nn.Sequential(

nn.SiLU(),

nn.Linear(embed_dim, 6 * embed_dim)

)

alpha를 0으로 초기화 -> 학습 초기에 identity function

nn.init.zeros_(self.adaLN_modulation[-1].weight)

nn.init.zeros_(self.adaLN_modulation[-1].bias)

def forward(self, x, cond):

cond: (B, D) - 타임스텝 + 클래스 임베딩의 합

shift_msa, scale_msa, gate_msa, \

shift_mlp, scale_mlp, gate_mlp = \

self.adaLN_modulation(cond).chunk(6, dim=-1)

Self-Attention with adaLN

h = self.norm1(x)

h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)

h = self.attn(h, h, h)[0]

x = x + gate_msa.unsqueeze(1) * h # gate=0으로 시작

FFN with adaLN

h = self.norm2(x)

h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

h = self.mlp(h)

x = x + gate_mlp.unsqueeze(1) * h # gate=0으로 시작

return x

**adaLN-Zero의 핵심**: gate 파라미터를 0으로 초기화하여 학습 초기에 각 DiT 블록이 identity function으로 동작합니다. 이는 깊은 네트워크의 학습 안정성을 크게 향상시킵니다.

전체 DiT 모델

class DiT(nn.Module):

def __init__(

self,

input_size=32,

patch_size=2,

in_channels=4,

embed_dim=1152,

depth=28,

num_heads=16,

num_classes=1000,

):

super().__init__()

self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)

num_patches = (input_size // patch_size) ** 2 # 256

위치 임베딩

self.pos_embed = nn.Parameter(

torch.zeros(1, num_patches, embed_dim)

)

타임스텝 & 클래스 임베딩

self.t_embedder = TimestepEmbedder(embed_dim)

self.y_embedder = LabelEmbedder(num_classes, embed_dim)

DiT 블록

self.blocks = nn.ModuleList([

DiTBlock(embed_dim, num_heads) for _ in range(depth)

])

최종 레이어

self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)

def forward(self, x, t, y):

패치 임베딩 + 위치 임베딩

x = self.patch_embed(x) + self.pos_embed

조건 임베딩

t_emb = self.t_embedder(t) # (B, D)

y_emb = self.y_embedder(y) # (B, D)

cond = t_emb + y_emb # (B, D)

DiT 블록 통과

for block in self.blocks:

x = block(x, cond)

노이즈와 분산 예측

x = self.final_layer(x, cond) # (B, N, 2*C*p*p)

return x

모델 변형과 스케일링

| 모델 | depth | embed_dim | heads | 파라미터 | FID-50K |

| -------- | ----- | --------- | ----- | -------- | -------- |

| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |

| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |

| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |

| DiT-XL/2 | 28 | 1152 | 16 | 675M | **2.27** |

핵심 발견: **DiT는 LLM과 동일한 스케일링 법칙을 따릅니다.** 모델 크기를 키울수록 FID가 일관되게 개선됩니다.

패치 크기의 영향

| 모델 | patch_size | 토큰 수 | Gflops | FID |

| -------- | ---------- | ------- | ------ | ---- |

| DiT-XL/8 | 8 | 16 | 119 | 14.2 |

| DiT-XL/4 | 4 | 64 | 525 | 5.11 |

| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |

패치 크기가 작을수록(토큰 수가 많을수록) 성능이 좋지만 연산량이 크게 증가합니다.

Classifier-Free Guidance 적용

def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):

"""Classifier-Free Guidance로 샘플링"""

조건부 + 무조건부 예측을 동시에

z_combined = torch.cat([z, z], dim=0)

y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])

for t in reversed(range(1000)):

t_batch = torch.full((z_combined.shape[0],), t, device=z.device)

noise_pred = model(z_combined, t_batch, y_combined)

CFG 적용

cond_pred, uncond_pred = noise_pred.chunk(2)

guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)

DDPM step

z = ddpm_step(z, guided_pred, t)

return z

CFG scale=1.5에서 DiT-XL/2는 FID **2.27**을 달성하여 당시 ImageNet 256x256에서 SOTA를 기록했습니다.

DiT의 후속 영향

SORA (OpenAI, 2024)

- DiT를 비디오로 확장한 **Spatial-Temporal DiT**

- 3D 패치 임베딩 (공간 + 시간)

- 가변 해상도/길이 지원

Stable Diffusion 3 (Stability AI, 2024)

- DiT + Flow Matching

- MM-DiT: 텍스트와 이미지를 동일 Transformer에서 처리

- 별도의 text encoder 없이 joint attention

Dynamic DiT (ICLR 2025)

- Early Exit 메커니즘으로 추론 효율화

- 노이즈가 큰 초기 스텝에서는 얕은 레이어, 디테일 생성 시 깊은 레이어 사용

- 기존 DiT 대비 40% 추론 시간 절감

PyTorch 구현 실습

DiT 공식 코드 클론

git clone https://github.com/facebookresearch/DiT.git

cd DiT

환경 설정

pip install torch torchvision diffusers accelerate

사전학습 모델 다운로드

python download.py DiT-XL/2

샘플 생성

python sample.py \

--model DiT-XL/2 \

--image-size 256 \

--num-sampling-steps 250 \

--seed 42 \

--class-labels 207 360 387 974 88 979 417 279

커스텀 학습

from accelerate import Accelerator

from torch.utils.data import DataLoader

accelerator = Accelerator()

모델 초기화

model = DiT(

input_size=32, # VAE latent size

patch_size=2,

in_channels=4,

embed_dim=1152,

depth=28,

num_heads=16,

num_classes=1000,

)

optimizer = torch.optim.AdamW(

model.parameters(),

lr=1e-4,

weight_decay=0.0

)

model, optimizer, dataloader = accelerator.prepare(

model, optimizer, dataloader

)

학습 루프

for epoch in range(epochs):

for batch in dataloader:

images, labels = batch

VAE 인코딩

with torch.no_grad():

latents = vae.encode(images).latent_dist.sample() * 0.18215

랜덤 타임스텝

t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)

노이즈 추가

noise = torch.randn_like(latents)

noisy_latents = scheduler.add_noise(latents, noise, t)

노이즈 예측

noise_pred = model(noisy_latents, t, labels)

loss = F.mse_loss(noise_pred, noise)

accelerator.backward(loss)

optimizer.step()

optimizer.zero_grad()

정리

DiT는 이미지 생성에서 **아키텍처 패러다임의 전환**을 이끌었습니다:

- U-Net의 복잡한 구조를 순수 Transformer로 대체

- LLM과 동일한 **스케일링 법칙**이 이미지 생성에도 적용됨을 입증

- adaLN-Zero로 깊은 Transformer의 안정적 학습 가능

- SORA, SD3 등 후속 모델의 기반 아키텍처로 자리잡음

**Q1. DiT가 U-Net 대신 Transformer를 사용하는 주된 이유는?**

검증된 스케일링 법칙, 단순한 아키텍처, 다른 모달리티와의 통합 용이성 때문입니다.

**Q2. DiT에서 이미지는 어떻게 토큰 시퀀스로 변환되나요?**

VAE로 잠재 표현(latent)으로 변환한 후, 패치 단위로 분할하여 선형 투영합니다.

**Q3. adaLN-Zero의 "Zero"가 의미하는 것은?**

gate 파라미터를 0으로 초기화하여 학습 초기에 각 블록이 identity function으로 동작하게 합니다.

**Q4. DiT에서 실험한 4가지 조건화 방식 중 최종 선택된 것은?**

adaLN-Zero입니다. In-context, Cross-Attention, adaLN, adaLN-Zero 중 가장 좋은 FID를 기록했습니다.

**Q5. 패치 크기가 작을수록 어떤 트레이드오프가 있나요?**

토큰 수가 증가하여 이미지 품질(FID)은 좋아지지만, 연산량(Gflops)이 제곱에 비례하여 증가합니다.

**Q6. Classifier-Free Guidance는 DiT에서 어떻게 구현되나요?**

클래스 레이블 드롭아웃으로 학습하고, 추론 시 조건부/무조건부 예측의 가중 합으로 샘플링합니다.

**Q7. DiT-XL/2의 ImageNet 256x256 FID-50K 점수는?**

2.27로, 당시 class-conditional 생성 SOTA를 달성했습니다.

**Q8. Dynamic DiT(ICLR 2025)의 핵심 아이디어는?**

Early Exit 메커니즘으로 노이즈 수준에 따라 사용하는 레이어 수를 동적으로 조절하여 추론 시간을 40% 절감합니다.

현재 단락 (1/230)

2022년 말 발표된 **"Scalable Diffusion Models with Transformers"** (William Peebles & Saining Xie, 2023 I...

작성 글자: 0원문 글자: 8,100작성 단락: 0/230