Diffusion Transformer (DiT) Architecture Analysis: The Shift from U-Net to Transformer

Introduction

Published in late 2022, "Scalable Diffusion Models with Transformers" (William Peebles & Saining Xie, ICCV 2023) marked a pivotal turning point in shifting the image generation backbone from U-Net to Transformer. The DiT (Diffusion Transformer) architecture proposed in this paper has since become the foundation for OpenAI's SORA, Stability AI's Stable Diffusion 3, and Dynamic DiT presented at ICLR 2025.

Background: Why Move Away from U-Net?

Limitations of U-Net

Traditional DDPM and LDM (Latent Diffusion Models) use U-Net as the backbone. While U-Net performs well in image generation, it has fundamental limitations:

  • Unclear scaling behavior: how quality changes as a U-Net grows has not been characterized as systematically as for Transformers
  • Architectural complexity: Managing skip connections and feature maps at various resolutions
  • Computational inefficiency: Memory and compute costs surge at high resolutions
  • Difficult multi-modal integration: Unnatural integration with text, video, and other modalities

Advantages of Transformers

In contrast, Transformers (ViT) offer:

  • Proven scaling laws: loss improves predictably, roughly as a power law, with parameter count and compute (demonstrated in LLMs)
  • Simple architecture: Repetition of Self-attention + FFN blocks
  • Modality agnostic: Text, images, and video can all be processed with the same structure
  • Hardware optimization: Efficient parallel processing on GPUs/TPUs

DiT Architecture in Detail

Overall Pipeline

DiT operates on top of the Latent Diffusion framework:

  1. Encode images into latent space using a VAE encoder
  2. Split the latent representation into patches (same as ViT)
  3. Predict noise using DiT blocks (Transformer)
  4. Reconstruct images with the VAE decoder

Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Convert latent representations into a patch sequence"""
    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, N, D)
        # e.g.: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
        x = self.proj(x)  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x

For a 256x256 image, the VAE (8x spatial downsampling) produces a 4x32x32 latent. Splitting it with patch_size=2 yields 16x16 = 256 tokens.
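The token arithmetic above can be checked with a short sketch. It assumes the 8x spatial downsampling of the Stable Diffusion VAE used by DiT; the function name `dit_token_count` is hypothetical.

```python
def dit_token_count(image_size, patch_size, vae_downsample=8):
    """Number of tokens a DiT processes for a square image (sketch)."""
    latent = image_size // vae_downsample  # e.g. 256 -> 32
    assert latent % patch_size == 0, "latent size must be divisible by patch size"
    side = latent // patch_size            # patches per side
    return side * side

for p in (8, 4, 2):
    print(f"256x256, patch_size={p}: {dit_token_count(256, p)} tokens")
```

Halving the patch size quadruples the token count (16, 64, 256), which is why patch size is the main compute knob in DiT.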

Conditioning Mechanisms — Four Variants

DiT experiments with four approaches for injecting the timestep (t) and class label (c):

1. In-context Conditioning

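# TimestepEmbedder / LabelEmbedder come from the official DiT repo (models.py)
from models import TimestepEmbedder, LabelEmbedder

class InContextConditioning(nn.Module):
    """Concatenate t and c as two additional tokens to the sequence"""
    def __init__(self, embed_dim, num_classes=1000, class_dropout_prob=0.1):
        super().__init__()
        self.t_embed = TimestepEmbedder(embed_dim)
        # The official LabelEmbedder takes a dropout probability (label dropout for CFG)
        self.c_embed = LabelEmbedder(num_classes, embed_dim, class_dropout_prob)

    def forward(self, x, t, c):
        t_token = self.t_embed(t).unsqueeze(1)                 # (B, 1, D)
        c_token = self.c_embed(c, self.training).unsqueeze(1)  # (B, 1, D)
        # The two extra tokens are removed again after the final block
        x = torch.cat([t_token, c_token, x], dim=1)            # (B, N+2, D)
        return x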

2. Cross-Attention

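class CrossAttentionBlock(nn.Module):
    """Inject t and c via cross-attention"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        # batch_first=True so inputs are (B, N, D) like the rest of the model
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x, cond):
        # x: (B, N, D); cond: (B, 2, D), the t and c embeddings as a length-two sequence
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h)[0]
        h = self.norm2(x)
        x = x + self.cross_attn(h, cond, cond)[0]  # queries come from image tokens
        x = x + self.mlp(self.norm3(x))
        return x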

3. Adaptive Layer Norm (adaLN)

class AdaLNBlock(nn.Module):
    """Regress LayerNorm scale/shift from the condition (attention branch only, for brevity)"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 2 * embed_dim)
        )

    def forward(self, x, cond):
        # cond: (B, D); shift/scale are broadcast over all N tokens of a sample
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x + self.attn(h, h, h)[0]  # residual is added to the un-normalized x
        return x

4. adaLN-Zero (Final Choice)

class DiTBlock(nn.Module):
    """adaLN-Zero: Initialize residual connections to start at zero"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # 6 modulation vectors, in chunk order: shift (beta1), scale (gamma1),
        # gate (alpha1) for attention, then beta2, gamma2, alpha2 for the MLP
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim)
        )
        # Zero-init the whole modulation layer: shift = scale = gate = 0 at the
        # start, so both residual branches are gated off (block = identity)
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: (B, D) - sum of timestep + class embeddings
        shift_msa, scale_msa, gate_msa, \
        shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(cond).chunk(6, dim=-1)

        # Self-Attention with adaLN
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * h  # gate starts at 0

        # FFN with adaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h  # gate starts at 0

        return x

The key insight of adaLN-Zero: By initializing gate parameters to zero, each DiT block acts as an identity function at the start of training. This significantly improves training stability for deep networks.
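The identity-at-initialization property can be checked numerically. The sketch below is a NumPy toy, not the DiT code: it uses a single residual branch with shift/scale/gate instead of six vectors, and a hypothetical ReLU projection (`toy_branch`) stands in for attention or the MLP.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-6):
    # LayerNorm without learnable affine parameters (elementwise_affine=False)
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_zero_branch(x, cond, W_mod, b_mod, branch):
    """One adaLN-Zero residual branch: x + gate * branch(modulate(LN(x)))."""
    shift, scale, gate = np.split(silu(cond) @ W_mod + b_mod, 3, axis=-1)  # each (B, D)
    h = layer_norm(x) * (1 + scale[:, None]) + shift[:, None]
    return x + gate[:, None] * branch(h)

rng = np.random.default_rng(0)
B, N, D = 2, 5, 8
x = rng.standard_normal((B, N, D))
cond = rng.standard_normal((B, D))
W_branch = rng.standard_normal((D, D))

def toy_branch(h):
    # Hypothetical stand-in for attention / MLP
    return np.maximum(h @ W_branch, 0.0)

# Zero-initialized modulation layer, as in adaLN-Zero
W_mod, b_mod = np.zeros((D, 3 * D)), np.zeros(3 * D)
out_init = adaln_zero_branch(x, cond, W_mod, b_mod, toy_branch)
print(np.array_equal(out_init, x))  # True: the branch is switched off

# Once the gate parameters move away from zero, the branch contributes
b_mod[2 * D:] = 0.1
out_trained = adaln_zero_branch(x, cond, W_mod, b_mod, toy_branch)
print(np.allclose(out_trained, x))  # False
```

Because the gate multiplies the entire branch output, the block starts as an exact identity no matter what the branch computes. Training then opens each gate gradually.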

Full DiT Model

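# TimestepEmbedder, LabelEmbedder and FinalLayer come from the official DiT
# repository (models.py); PatchEmbed and DiTBlock are the classes defined above.
from models import TimestepEmbedder, LabelEmbedder, FinalLayer

class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        num_classes=1000,
        class_dropout_prob=0.1,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = in_channels * 2  # noise + variance for each channel
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        num_patches = (input_size // patch_size) ** 2  # 256

        # Positional embedding (learnable here; the official code uses fixed 2D sin-cos values)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )

        # Timestep & class embeddings (LabelEmbedder also handles label dropout for CFG)
        self.t_embedder = TimestepEmbedder(embed_dim)
        self.y_embedder = LabelEmbedder(num_classes, embed_dim, class_dropout_prob)

        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # Final layer: adaLN + linear projection to p*p*out_channels per token
        self.final_layer = FinalLayer(embed_dim, patch_size, self.out_channels)

    def unpatchify(self, x):
        # (B, N, p*p*C_out) -> (B, C_out, H, W)
        c, p = self.out_channels, self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x, t, y):
        # Patch embedding + positional embedding
        x = self.patch_embed(x) + self.pos_embed  # (B, N, D)

        # Condition embedding; class labels are dropped only in training mode
        t_emb = self.t_embedder(t)                 # (B, D)
        y_emb = self.y_embedder(y, self.training)  # (B, D)
        cond = t_emb + y_emb                       # (B, D)

        # Pass through DiT blocks
        for block in self.blocks:
            x = block(x, cond)

        # Predict noise and variance, then reassemble the latent grid
        x = self.final_layer(x, cond)  # (B, N, p*p*2C)
        return self.unpatchify(x)      # (B, 2C, H, W)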
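The final layer emits one vector of p·p·2C values per token, and those vectors must be reassembled into a latent-shaped grid ("unpatchify"). The NumPy sketch below shows the round trip on a single unbatched latent. It assumes each token vector is ordered (p, p, C), which is the layout the einsum `nhwpqc->nchpwq` in DiT-style code expects. `patchify` and `unpatchify` here are illustrative helpers, not the library functions.

```python
import numpy as np

def patchify(latent, p):
    """(C, H, W) -> (N, p*p*C); patches in row-major order, each ordered (p, p, C)."""
    C, H, W = latent.shape
    h, w = H // p, W // p
    x = latent.reshape(C, h, p, w, p)
    x = x.transpose(1, 3, 2, 4, 0)  # (h, w, p, p, C)
    return x.reshape(h * w, p * p * C)

def unpatchify(tokens, p, C):
    """(N, p*p*C) -> (C, H, W); same index shuffle as 'nhwpqc->nchpwq' without the batch axis."""
    N = tokens.shape[0]
    h = w = int(round(N ** 0.5))
    x = tokens.reshape(h, w, p, p, C)
    x = x.transpose(4, 0, 2, 1, 3)  # (C, h, p, w, p)
    return x.reshape(C, h * p, w * p)

latent = np.arange(4 * 32 * 32, dtype=np.float32).reshape(4, 32, 32)
tokens = patchify(latent, p=2)
print(tokens.shape)  # (256, 16)
assert np.array_equal(unpatchify(tokens, p=2, C=4), latent)
```

For the model output, C is 2 x 4 = 8 (noise plus variance), so each token carries 2 x 2 x 8 = 32 values.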

Model Variants and Scaling

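Model | Depth | Hidden dim | Heads | Parameters | FID-50K (400K steps, no guidance)
--- | --- | --- | --- | --- | ---
DiT-S/2 | 12 | 384 | 6 | 33M | 68.4
DiT-B/2 | 12 | 768 | 12 | 130M | 43.5
DiT-L/2 | 24 | 1024 | 16 | 458M | 23.3
DiT-XL/2 | 28 | 1152 | 16 | 675M | 19.5

Trained for 7M steps, DiT-XL/2 reaches FID 9.62 without guidance and 2.27 with classifier-free guidance (scale 1.5).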

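Key finding: FID improves consistently as model compute (Gflops) grows, whether through a deeper and wider model or through more tokens. This scaling trend resembles the one seen in LLMs, and the paper identifies Gflops, rather than parameter count alone, as the key predictor of sample quality.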
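The parameter column can be reproduced approximately from depth and hidden size alone. This sketch assumes the official DiT layout: qkv and output projections with biases, a 4x MLP, a 6D adaLN-Zero modulation per block, a 256-dim sinusoidal timestep input to a two-layer MLP, a label table with one extra "null" class for CFG, a final adaLN layer, and a positional embedding counted as a parameter. `dit_param_count` is a hypothetical helper.

```python
def dit_param_count(depth, d, patch_size=2, in_channels=4, num_classes=1000,
                    input_size=32, freq_dim=256):
    """Approximate DiT parameter count under the layout assumptions above."""
    per_block = (
        (3 * d * d + 3 * d) + (d * d + d)        # attention: qkv + output projection
        + (4 * d * d + 4 * d) + (4 * d * d + d)  # MLP: d -> 4d -> d
        + (6 * d * d + 6 * d)                    # adaLN-Zero modulation: d -> 6d
    )
    out_ch = 2 * in_channels                     # noise + variance
    n_tokens = (input_size // patch_size) ** 2
    extra = (
        in_channels * patch_size**2 * d + d      # patch embedding (conv)
        + n_tokens * d                           # positional embedding
        + (freq_dim * d + d) + (d * d + d)       # timestep embedder MLP
        + (num_classes + 1) * d                  # label table (+1 null class)
        + (2 * d * d + 2 * d)                    # final layer adaLN: d -> 2d
        + d * patch_size**2 * out_ch + patch_size**2 * out_ch  # final linear
    )
    return depth * per_block + extra

for name, depth, d in [("S/2", 12, 384), ("B/2", 12, 768),
                       ("L/2", 24, 1024), ("XL/2", 28, 1152)]:
    print(f"DiT-{name}: {dit_param_count(depth, d) / 1e6:.1f}M")
```

Almost all parameters sit in the 18d² per block, so parameter count is driven by depth × width. Patch size barely changes it.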

Impact of Patch Size

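Model | Patch size | Tokens | Gflops
--- | --- | --- | ---
DiT-XL/8 | 8 | 16 | ~7.4
DiT-XL/4 | 4 | 64 | ~29
DiT-XL/2 | 2 | 256 | ~118.6

Smaller patch sizes (more tokens) yield better FID at essentially the same parameter count. However, halving the patch size quadruples the token count and roughly quadruples compute.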
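A back-of-the-envelope estimate shows why compute tracks the token count. This sketch assumes that the paper's Gflops figures count multiply-accumulates, and it ignores LayerNorm, softmax, activations and biases. `dit_gmacs` is a hypothetical helper.

```python
def dit_gmacs(depth, d, patch_size, input_size=32, in_channels=4, freq_dim=256):
    """Rough forward-pass cost of a DiT in billions of multiply-accumulates."""
    n = (input_size // patch_size) ** 2          # number of tokens
    per_token = 12 * d * d                       # qkv (3d^2) + out proj (d^2) + MLP (8d^2)
    attention = 2 * n * n * d                    # Q @ K^T and weights @ V
    adaln = 6 * d * d                            # modulation MLP: once per sample
    blocks = depth * (n * per_token + attention + adaln)
    embed = n * in_channels * patch_size**2 * d  # patch embedding
    final = n * d * patch_size**2 * 2 * in_channels + 2 * d * d  # final linear + adaLN
    cond = freq_dim * d + d * d                  # timestep embedder MLP
    return (blocks + embed + final + cond) / 1e9

for p in (8, 4, 2):
    print(f"DiT-XL/{p}: ~{dit_gmacs(28, 1152, p):.1f} G")
```

Under these assumptions the estimate lands close to the ~118.6 Gflops the paper reports for DiT-XL/2. The per-token linear layers dominate, and at 256 tokens attention adds only a few percent. That is why cost scales almost linearly with the token count.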

Classifier-Free Guidance

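import torch

@torch.no_grad()
def sample_with_cfg(model, scheduler, z, class_labels, cfg_scale=4.0,
                    num_steps=250, null_class=1000):
    """Sampling with Classifier-Free Guidance.

    Assumes `scheduler` is a diffusers DDPMScheduler with a fixed variance
    and that `model(z, t, y)` returns (B, 2C, H, W): noise + variance channels.
    """
    C = z.shape[1]
    scheduler.set_timesteps(num_steps)
    # Conditional labels first, then the "null" label (= num_classes) for the unconditional half
    y_combined = torch.cat([class_labels, torch.full_like(class_labels, null_class)])

    for t in scheduler.timesteps:
        # Rebuild the doubled batch each step so both halves see the current z
        z_combined = torch.cat([z, z], dim=0)
        t_batch = torch.full((z_combined.shape[0],), int(t), device=z.device)
        eps = model(z_combined, t_batch, y_combined)[:, :C]  # drop variance channels

        # Apply CFG: extrapolate from the unconditional toward the conditional prediction
        cond_eps, uncond_eps = eps.chunk(2)
        guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

        # One reverse DDPM step on the un-doubled batch
        z = scheduler.step(guided_eps, t, z).prev_sample

    return z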

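The reported FID-50K of 2.27 for DiT-XL/2, state of the art on class-conditional ImageNet 256x256 at the time, uses a CFG scale of 1.5. Larger scales such as the 4.0 default above trade sample diversity for fidelity and are typically used for visualization.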
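CFG needs two ingredients: label dropout during training, so that the model also learns an unconditional prediction, and the guidance combination at sampling time. The NumPy sketch below shows both. The null-class index 1000 follows the code above. The dropout probability of 0.1 and the helper names are assumptions for illustration.

```python
import numpy as np

NULL_CLASS = 1000  # index of the learned "no class" embedding (= num_classes)

def drop_labels(labels, drop_prob, rng):
    """Training: replace each label with NULL_CLASS with probability drop_prob."""
    drop = rng.random(labels.shape) < drop_prob
    return np.where(drop, NULL_CLASS, labels)

def cfg_combine(cond_eps, uncond_eps, scale):
    """Sampling: extrapolate from the unconditional toward the conditional prediction."""
    return uncond_eps + scale * (cond_eps - uncond_eps)

rng = np.random.default_rng(0)
labels = rng.integers(0, 1000, size=10_000)
dropped = drop_labels(labels, 0.1, rng)
print((dropped == NULL_CLASS).mean())  # close to 0.1

cond, uncond = np.array([1.0, 2.0]), np.array([0.0, 1.0])
print(cfg_combine(cond, uncond, 1.5))  # [1.5 2.5]
```

A scale of 1 recovers the plain conditional prediction and 0 the unconditional one. Scales above 1 push samples further toward the class.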

Downstream Impact of DiT

SORA (OpenAI, 2024)

  • Applies the diffusion transformer approach to video
  • Spacetime patches of a compressed video latent (spatial + temporal) replace DiT's 2D image patches
  • Variable resolution, duration, and aspect ratio support

Stable Diffusion 3 (Stability AI, 2024)

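  • DiT-style Transformer trained with a rectified flow (flow matching) objective
  • MM-DiT: text and image tokens use separate weights but are processed in the same Transformer
  • Joint attention over the concatenated text and image tokens; text embeddings still come from pretrained encoders (two CLIP models and T5)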

Dynamic DiT (ICLR 2025)

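  • Adjusts computation dynamically instead of running the full network at every step and for every token
  • Timestep-wise Dynamic Width: the active model width depends on the diffusion timestep
  • Spatial-wise Dynamic Token: image tokens that are easy to predict skip part of the computation
  • Reports roughly halving DiT-XL's FLOPs while keeping FID competitive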

PyTorch Implementation Walkthrough

# Clone the official DiT code
git clone https://github.com/facebookresearch/DiT.git
cd DiT

# Set up environment (the repo also ships an environment.yml for conda)
pip install torch torchvision timm diffusers accelerate

# Generate samples. Without --ckpt, sample.py downloads the pretrained
# DiT-XL/2 checkpoint automatically; the class labels
# (207 360 387 974 88 979 417 279) are hard-coded in the script.
python sample.py \
  --model DiT-XL/2 \
  --image-size 256 \
  --num-sampling-steps 250 \
  --seed 42

Custom Training

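import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler
from torch.utils.data import DataLoader

accelerator = Accelerator()

# Placeholders: `train_dataset` yields (image, label) pairs with 256x256 images
# scaled to [-1, 1], and `epochs` is chosen by you.
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

# Pretrained Stable Diffusion VAE (8x downsampling, 4 latent channels), kept frozen
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(accelerator.device)
vae.requires_grad_(False)
# Linear beta schedule over 1000 steps (1e-4 to 0.02), as in DiT
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Initialize model (assumed to return (B, 2C, H, W): noise + variance channels)
model = DiT(
    input_size=32,  # VAE latent size
    patch_size=2,
    in_channels=4,
    embed_dim=1152,
    depth=28,
    num_heads=16,
    num_classes=1000,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0
)

model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

# Training loop
for epoch in range(epochs):
    for images, labels in dataloader:
        # VAE encoding; 0.18215 rescales SD-VAE latents to roughly unit variance
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # Random timestep per sample
        t = torch.randint(0, scheduler.config.num_train_timesteps,
                          (latents.shape[0],), device=latents.device)

        # Forward noising: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        noise = torch.randn_like(latents)
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # The model predicts noise and variance; keep only the noise channels
        noise_pred, _ = model(noisy_latents, t, labels).chunk(2, dim=1)
        # Simplified loss: the official DiT also adds a VLB term that trains the
        # variance channels, and keeps an EMA copy of the weights
        loss = F.mse_loss(noise_pred, noise)

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()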
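The forward-noising step that `scheduler.add_noise` performs can be written out directly. The NumPy sketch below assumes the linear schedule with 1000 steps from 1e-4 to 0.02 (the DDPM setting that DiT also uses). `add_noise` here is a standalone illustration, not the diffusers method.

```python
import numpy as np

# Linear beta schedule: 1000 steps from 1e-4 to 0.02
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)  # fraction of signal variance kept at step t

def add_noise(x0, noise, t):
    """q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a = alpha_bar[t].reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast per sample
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 4, 32, 32))    # a batch of two latents
noise = rng.standard_normal(x0.shape)
xt = add_noise(x0, noise, np.array([0, 999]))
print(alpha_bar[0], alpha_bar[-1])          # almost all signal vs. almost none
```

At t = 0 the latent is nearly clean. By t = 999 it is essentially pure Gaussian noise, which is the distribution sampling starts from.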

Summary

DiT led an architectural paradigm shift in image generation:

  • Replaced U-Net's complex structure with a pure Transformer
  • Showed that sample quality improves steadily with Transformer compute (Gflops), a scaling trend similar to the one seen in LLMs
  • Enabled stable training of deep Transformers via adaLN-Zero
  • Established itself as the foundational architecture for subsequent models like SORA and SD3

Quiz: DiT Architecture Comprehension Check (8 Questions)

Q1. What is the main reason DiT uses Transformer instead of U-Net?

Proven scaling laws, simpler architecture, and easier integration with other modalities.

Q2. How are images converted into a token sequence in DiT?

Images are first converted to latent representations via VAE, then split into patches and linearly projected.

Q3. What does the "Zero" in adaLN-Zero mean?

Gate parameters are initialized to zero so each block acts as an identity function at the start of training.

Q4. Which of the four conditioning methods tested in DiT was ultimately chosen?

adaLN-Zero. Among In-context, Cross-Attention, adaLN, and adaLN-Zero, it achieved the best FID.

Q5. What is the trade-off with smaller patch sizes?

The token count grows as 1/p², so halving the patch size quadruples the number of tokens. Image quality (FID) improves, but compute (Gflops) roughly quadruples.

Q6. How is Classifier-Free Guidance implemented in DiT?

Training uses class label dropout, and inference samples from a weighted sum of conditional and unconditional predictions.

Q7. What is DiT-XL/2's ImageNet 256x256 FID-50K score?

2.27, which set the class-conditional generation SOTA at the time.

Q8. What is the core idea behind Dynamic DiT (ICLR 2025)?

It adapts computation to the diffusion timestep (dynamic width) and to spatial regions (dynamic tokens), so easier steps and image regions use less compute. This substantially reduces inference cost.
