Diffusion Transformer (DiT) Architecture Analysis: The Shift from U-Net to Transformer

Introduction

Published in late 2022, **"Scalable Diffusion Models with Transformers"** (William Peebles & Saining Xie, ICCV 2023) marked a turning point in moving the image-generation backbone from U-Net to Transformer. The DiT (Diffusion Transformer) architecture proposed in this paper has since served as the basis for OpenAI's SORA, Stability AI's Stable Diffusion 3, and follow-up work such as Dynamic DiT (ICLR 2025).

Background: Why Move Away from U-Net?

Limitations of U-Net

Traditional DDPM and LDM (Latent Diffusion Models) use U-Net as the backbone. While U-Net performs well in image generation, it has fundamental limitations:

- **No established scaling laws**: It was unclear how reliably performance improves as the model is scaled up

- **Architectural complexity**: Skip connections and feature maps at several resolutions must be designed and balanced together

- **Computational inefficiency**: Memory and compute costs surge at high resolutions

- **Difficult multi-modal integration**: Text, video, and other modalities must be attached through add-on modules such as cross-attention rather than handled in a single token sequence

Advantages of Transformers

In contrast, Transformers (ViT) offer:

- **Proven scaling laws**: Performance improves predictably, roughly as a power law, with parameter count and compute (as demonstrated in LLMs)

- **Simple architecture**: A stack of repeated Self-attention + FFN blocks

- **Modality agnostic**: Once converted to tokens, text, images, and video can all be processed with the same structure

- **Hardware optimization**: Efficient parallel processing on GPUs/TPUs

DiT Architecture in Detail

Overall Pipeline

DiT operates on top of the Latent Diffusion framework:

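1. Encode images into latent space using a VAE encoder

2. Split the latent representation into patches and embed them as tokens (as in ViT)

3. Iteratively denoise the tokens, with **DiT blocks** (a Transformer) predicting the noise (and variance) at each step

4. Decode the final denoised latent back into an image with the VAE decoder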
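The four steps above boil down to a chain of tensor shapes. Here is a minimal sketch of that bookkeeping. It assumes the Stable Diffusion VAE used in the paper (8x spatial downsampling to 4 latent channels) and a model that predicts both noise and variance. `dit_shapes` is a hypothetical helper, not part of the DiT codebase.

```python
def dit_shapes(image_size=256, vae_factor=8, latent_channels=4,
               patch_size=2, embed_dim=1152, learn_sigma=True):
    """Trace tensor shapes through the latent DiT pipeline (batch dim omitted)."""
    assert image_size % vae_factor == 0, "image size must be divisible by the VAE factor"
    latent = image_size // vae_factor                 # spatial size of the VAE latent
    assert latent % patch_size == 0, "latent size must be divisible by the patch size"
    grid = latent // patch_size                       # patches per side
    num_tokens = grid * grid                          # sequence length N
    out_channels = latent_channels * (2 if learn_sigma else 1)  # noise (+ variance)
    return {
        "image": (3, image_size, image_size),
        "latent": (latent_channels, latent, latent),
        "tokens": (num_tokens, embed_dim),
        "model_output": (out_channels, latent, latent),
        "decoded": (3, image_size, image_size),
    }


for stage, shape in dit_shapes().items():
    print(f"{stage:>12}: {shape}")
```

With `image_size=512`, the latent grows to 64x64 and the sequence to 1,024 tokens at patch size 2, which is why higher resolutions get expensive quickly.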

Patch Embedding

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Convert latent representations into a patch sequence"""

    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        # kernel_size == stride == patch_size: each non-overlapping patch
        # is embedded independently
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, N, D)
        # e.g.: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
        x = self.proj(x)  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x
```
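Because `kernel_size` and `stride` both equal `patch_size`, every output position of the convolution sees exactly one non-overlapping patch. `PatchEmbed` is therefore the same as flattening each patch and applying one shared `nn.Linear`, which is ViT's "linear projection of flattened patches". The following quick numerical check of that equivalence uses small sizes for speed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, p, D = 2, 4, 32, 32, 2, 8  # small embed_dim for a quick check

conv = nn.Conv2d(C, D, kernel_size=p, stride=p)
x = torch.randn(B, C, H, W)

# Path 1: the Conv2d patch embedding used by PatchEmbed
tokens_conv = conv(x).flatten(2).transpose(1, 2)                       # (B, N, D)

# Path 2: cut non-overlapping p x p patches, flatten each, apply a Linear layer
patches = x.unfold(2, p, p).unfold(3, p, p)                            # (B, C, H/p, W/p, p, p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)  # (B, N, C*p*p)
linear = nn.Linear(C * p * p, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))  # same weights, flattened as (C, p, p)
    linear.bias.copy_(conv.bias)
tokens_linear = linear(patches)

print(tokens_conv.shape)  # torch.Size([2, 256, 8])
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))
```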

The input is a 256x256 image, which the VAE encodes into a 4x32x32 latent (8x downsampling). Splitting that latent with patch_size=2 yields 16x16=256 tokens.
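These 256 tokens come from a 16x16 grid, so each one needs a position. The official DiT does not learn these positions. It fills its positional embedding with fixed 2D sine-cosine values, where half of the channels encode the row and half encode the column. The following is a minimal NumPy sketch of that idea. The helper names are hypothetical, and the official helper may order the two halves differently.

```python
import numpy as np


def sincos_1d(dim, positions):
    """1D sin/cos embedding: (M,) positions -> (M, dim)."""
    assert dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(dim // 2, dtype=np.float64) / (dim // 2))
    angles = np.outer(positions, omega)                   # (M, dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(dim, grid_size):
    """2D embedding for a grid_size x grid_size token grid in row-major order: (N, dim)."""
    assert dim % 4 == 0
    rows, cols = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    emb_rows = sincos_1d(dim // 2, rows.reshape(-1))      # first half encodes the row
    emb_cols = sincos_1d(dim // 2, cols.reshape(-1))      # second half encodes the column
    return np.concatenate([emb_rows, emb_cols], axis=1)


pos = sincos_2d(1152, 16)  # 16 x 16 patches from a 32 x 32 latent with patch_size=2
print(pos.shape)           # (256, 1152)
```

Row-major order matches `x.flatten(2)` in `PatchEmbed`, so row `r * 16 + c` of `pos` belongs to the patch in grid row `r`, column `c`.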

Conditioning Mechanisms — Four Variants

The DiT paper compares four ways of injecting the timestep (t) and class label (c):
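All four variants start by turning t and c into D-dimensional vectors. The code below calls `TimestepEmbedder` and `LabelEmbedder` without defining them. Here is a minimal sketch modeled on the helpers of the same names in the official facebookresearch/DiT `models.py`: sinusoidal timestep features followed by an MLP, and a class table with one extra "null" row for classifier-free guidance. The exact defaults, including the 256 frequency channels and the 10% label-dropout rate, are assumptions.

```python
import math
import torch
import torch.nn as nn


def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of timesteps: (B,) -> (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :].to(t.device)  # (B, half)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad with a zero column for odd dims
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


class TimestepEmbedder(nn.Module):
    """Sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, embed_dim, freq_dim=256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(
            nn.Linear(freq_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, t):
        return self.mlp(timestep_embedding(t, self.freq_dim))


class LabelEmbedder(nn.Module):
    """Class embedding with an extra 'null' row (index num_classes) for CFG."""

    def __init__(self, num_classes, embed_dim, dropout_prob=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob
        self.table = nn.Embedding(num_classes + 1, embed_dim)

    def forward(self, labels):
        if self.training and self.dropout_prob > 0:
            # Randomly swap labels for the null class so the same network also
            # learns the unconditional prediction that CFG needs at sampling time
            drop = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            labels = torch.where(drop, torch.full_like(labels, self.num_classes), labels)
        return self.table(labels)


t_emb = TimestepEmbedder(1152)(torch.tensor([0, 500, 999]))
y_emb = LabelEmbedder(1000, 1152).eval()(torch.tensor([207, 360, 1000]))  # 1000 = null label
print(t_emb.shape, y_emb.shape)  # torch.Size([3, 1152]) torch.Size([3, 1152])
```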

1. In-context Conditioning

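```python
import torch
import torch.nn as nn

# TimestepEmbedder and LabelEmbedder are assumed to map t and c to (B, D)
# vectors, like the helpers of the same names in the official DiT models.py


class InContextConditioning(nn.Module):
    """Concatenate t and c as additional tokens to the sequence"""

    def __init__(self, embed_dim):
        super().__init__()
        self.t_embed = TimestepEmbedder(embed_dim)
        self.c_embed = LabelEmbedder(1000, embed_dim)

    def forward(self, x, t, c):
        t_token = self.t_embed(t).unsqueeze(1)  # (B, 1, D)
        c_token = self.c_embed(c).unsqueeze(1)  # (B, 1, D)
        # Standard ViT blocks then process all N+2 tokens; the two
        # conditioning tokens are removed again after the final block
        x = torch.cat([t_token, c_token, x], dim=1)  # (B, N+2, D)
        return x
```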

2. Cross-Attention

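```python
import torch.nn as nn


class CrossAttentionBlock(nn.Module):
    """Inject t and c via cross-attention"""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        # batch_first=True so inputs are (B, seq, D), matching PatchEmbed
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim)
        )

    def forward(self, x, cond):
        # x: (B, N, D) image tokens; cond: (B, 2, D) = stacked t and c embeddings
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h)[0]
        x = x + self.cross_attn(self.norm2(x), cond, cond)[0]  # queries come from image tokens
        x = x + self.mlp(self.norm3(x))
        return x
```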

3. Adaptive Layer Norm (adaLN)

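```python
import torch.nn as nn


class AdaLNBlock(nn.Module):
    """Adjust LayerNorm scale/shift based on conditions"""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # No learnable affine: scale/shift are regressed from the condition instead
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim)
        )
        # One (shift, scale) pair for each of the two LayerNorms
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 4 * embed_dim)
        )

    def forward(self, x, cond):
        # cond: (B, D) = timestep embedding + class embedding
        shift1, scale1, shift2, scale2 = self.adaLN_modulation(cond).chunk(4, dim=-1)
        # Modulate the normalized input, but keep the residual path on x itself
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        x = x + self.attn(h, h, h)[0]
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        x = x + self.mlp(h)
        return x
```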

4. adaLN-Zero (Final Choice)

```python
import torch.nn as nn


class DiTBlock(nn.Module):
    """adaLN-Zero: Initialize residual connections to start at zero"""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # 6 outputs: shift (beta1), scale (gamma1), gate (alpha1) for attention
        # and shift (beta2), scale (gamma2), gate (alpha2) for the MLP
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim)
        )
        # Zero-init the last linear layer: all shifts, scales and gates start at 0,
        # so alpha = 0 and the block is the identity at the start of training
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: (B, D) - sum of timestep + class embeddings
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(cond).chunk(6, dim=-1)

        # Self-Attention with adaLN
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h, need_weights=False)[0]
        x = x + gate_msa.unsqueeze(1) * h  # gate starts at 0

        # FFN with adaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h  # gate starts at 0
        return x
```

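**The key insight of adaLN-Zero**: The layer that regresses the gates (α) is zero-initialized, so every residual branch is scaled by zero and each DiT block acts as an identity function at the start of training. This builds on earlier findings that identity-initialized residual blocks make deep networks easier to train. In the paper's ablation, adaLN-Zero gives the best FID of the four conditioning schemes.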
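The identity-at-initialization property is easy to check numerically. The following minimal sketch implements a single adaLN-Zero residual sub-layer. `GatedResidual` is a hypothetical name; the `DiTBlock` above contains two such sub-layers, one for attention and one for the MLP. The sketch also shows that the zero-initialized gates do not freeze training. At the first step only the gate rows of the modulation layer receive gradient. Once the gates move away from zero, the shift and scale rows start learning too.

```python
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    """adaLN modulation: per-sample shift/scale broadcast over the token axis."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class GatedResidual(nn.Module):
    """One adaLN-Zero sub-layer: x + gate * f(modulate(LN(x), shift, scale))."""

    def __init__(self, dim, f):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.f = f
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
        nn.init.zeros_(self.ada[-1].weight)  # zero-init -> shift = scale = gate = 0
        nn.init.zeros_(self.ada[-1].bias)

    def forward(self, x, cond):
        shift, scale, gate = self.ada(cond).chunk(3, dim=-1)
        return x + gate.unsqueeze(1) * self.f(modulate(self.norm(x), shift, scale))


torch.manual_seed(0)
dim = 16
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
layer = GatedResidual(dim, mlp)
x, cond = torch.randn(2, 10, dim), torch.randn(2, dim)

print(torch.equal(layer(x, cond), x))        # True: exactly the identity at init

loss = layer(x, cond).pow(2).sum()
loss.backward()
g = layer.ada[-1].weight.grad
print(g[: 2 * dim].abs().max().item())       # 0.0: shift/scale rows get no gradient yet
print(g[2 * dim :].abs().max().item() > 0)   # True: gate rows do, so training can start
```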

Full DiT Model

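```python
import torch
import torch.nn as nn

# TimestepEmbedder, LabelEmbedder and FinalLayer follow the official DiT
# models.py; PatchEmbed and DiTBlock are defined above.


class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        num_classes=1000,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = in_channels * 2  # noise + learned variance
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        num_patches = (input_size // patch_size) ** 2  # 256

        # Positional embedding (the official code freezes fixed 2D sin-cos
        # values here; a zero-initialized learnable tensor keeps this short)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )

        # Timestep & class embeddings
        self.t_embedder = TimestepEmbedder(embed_dim)
        self.y_embedder = LabelEmbedder(num_classes, embed_dim)

        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # Final layer: each token -> p * p * out_channels values
        self.final_layer = FinalLayer(embed_dim, patch_size, self.out_channels)

    def unpatchify(self, x):
        # (B, N, p*p*C) -> (B, C, H, W), undoing the row-major patch order
        p, c = self.patch_size, self.out_channels
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x, t, y):
        # Patch embedding + positional embedding
        x = self.patch_embed(x) + self.pos_embed  # (B, N, D)

        # Condition embedding
        t_emb = self.t_embedder(t)  # (B, D)
        y_emb = self.y_embedder(y)  # (B, D)
        cond = t_emb + y_emb  # (B, D)

        # Pass through DiT blocks
        for block in self.blocks:
            x = block(x, cond)

        # Predict noise and variance
        x = self.final_layer(x, cond)  # (B, N, p*p*2C)
        return self.unpatchify(x)  # (B, 2C, H, W)
```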
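The model uses a `FinalLayer` that is not defined in the snippet above. Here is a minimal sketch of it, modeled on the official repo's class of the same name. It applies one more adaLN-modulated LayerNorm, followed by a linear layer that maps each token to a p x p x C_out patch. The zero initialization of both linear layers mirrors what that repo does, but treat the exact details as assumptions.

```python
import torch
import torch.nn as nn


class FinalLayer(nn.Module):
    """adaLN-modulated LayerNorm + Linear: each token -> one p x p x C_out patch."""

    def __init__(self, embed_dim, patch_size, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(embed_dim, 2 * embed_dim))
        # Zero-init both linears so the model's initial prediction is exactly 0
        for layer in (self.linear, self.adaLN_modulation[-1]):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, cond):
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(x)  # (B, N, p * p * C_out)


final = FinalLayer(embed_dim=64, patch_size=2, out_channels=8)  # 8 = 4 noise + 4 variance
tokens, cond = torch.randn(2, 256, 64), torch.randn(2, 64)
out = final(tokens, cond)
print(out.shape)  # torch.Size([2, 256, 32])
```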

Model Variants and Scaling

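| Model    | depth | embed_dim | heads | Parameters | FID-50K (400K iters, no CFG) |
| -------- | ----- | --------- | ----- | ---------- | ---------------------------- |
| DiT-S/2  | 12    | 384       | 6     | 33M        | 68.4                         |
| DiT-B/2  | 12    | 768       | 12    | 130M       | 43.5                         |
| DiT-L/2  | 24    | 1024      | 16    | 458M       | 23.3                         |
| DiT-XL/2 | 28    | 1152      | 16    | 675M       | 19.5                         |

Trained for longer (7M iterations), DiT-XL/2 reaches an FID of 9.62 without guidance and **2.27** with classifier-free guidance at scale 1.5.

Key finding: **as with LLMs, DiT's sample quality scales predictably with compute.** FID improves consistently as model Gflops increase, whether through a deeper or wider model or a smaller patch size.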

Impact of Patch Size

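| Model    | patch_size | Tokens | Gflops | FID-50K (400K iters, no CFG) |
| -------- | ---------- | ------ | ------ | ---------------------------- |
| DiT-XL/8 | 8          | 16     | 7.4    | > 100                        |
| DiT-XL/4 | 4          | 64     | 29.1   | 43.0                         |
| DiT-XL/2 | 2          | 256    | 118.6  | 19.5                         |

Halving the patch size quadruples the token count, (32/p)², and roughly quadruples the Gflops, while the parameter count stays almost unchanged. In return, image quality improves substantially.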
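Where does the roughly 4x cost per halving come from? A back-of-the-envelope estimate makes it concrete. This is a sketch under simplifying assumptions: it counts only the transformer blocks (QKV and output projections, the 4x MLP, and the two attention matmuls) in multiply-accumulates. It ignores the embedding, adaLN, and final layers. The function names are hypothetical.

```python
def dit_tokens(latent_size, patch_size):
    """Sequence length for a square latent split into non-overlapping patches."""
    return (latent_size // patch_size) ** 2


def approx_gmacs(depth, dim, tokens, mlp_ratio=4):
    """Rough per-image forward cost of the DiT blocks, in billions of MACs.

    Per token and block: QKV + output projections (4*D^2) plus the MLP
    (2*mlp_ratio*D^2). Per block: QK^T and attention-times-V (2*N^2*D).
    """
    per_token = (4 + 2 * mlp_ratio) * dim * dim
    attention = 2 * tokens * tokens * dim
    return depth * (tokens * per_token + attention) / 1e9


for p in (8, 4, 2):
    n = dit_tokens(32, p)
    print(f"DiT-XL/{p}: {n:4d} tokens, ~{approx_gmacs(28, 1152, n):6.1f} GMACs")
```

For DiT-XL/2 this gives about 118 GMACs, close to the roughly 119 Gflops the paper reports. Each halving of the patch size multiplies the estimate by a little over 4, because the per-token terms scale with N and the attention term with N².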

Classifier-Free Guidance

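```python
import torch


@torch.no_grad()
def sample_with_cfg(model, z, class_labels, cfg_scale=4.0, num_classes=1000):
    """Sampling with Classifier-Free Guidance"""
    # Label `num_classes` is the "null" label used for the unconditional branch
    y_combined = torch.cat([class_labels, torch.full_like(class_labels, num_classes)])
    C = z.shape[1]

    # 1000 steps for simplicity; the paper samples with 250 respaced DDPM steps
    for t in reversed(range(1000)):
        # Conditional + unconditional predictions in one batch, rebuilt
        # from the current z at every step
        z_combined = torch.cat([z, z], dim=0)
        t_batch = torch.full((z_combined.shape[0],), t, device=z.device, dtype=torch.long)
        model_out = model(z_combined, t_batch, y_combined)  # (2B, 2C, H, W)

        # Apply CFG to the noise channels; keep the conditional variance
        cond_pred, uncond_pred = model_out[:, :C].chunk(2)
        guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        cond_var = model_out[:, C:].chunk(2)[0]

        # DDPM step: ddpm_step is a placeholder for the sampler update
        z = ddpm_step(z, guided_pred, cond_var, t)

    return z
```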

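With a CFG scale of 1.5, DiT-XL/2 achieves an FID-50K of **2.27**, state of the art for class-conditional ImageNet 256x256 at the time. Larger scales, such as the 4.0 default above, trade sample diversity for sharper, more class-typical images.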

Downstream Impact of DiT

SORA (OpenAI, 2024)

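- Applies a diffusion transformer to video, operating on **spacetime patches** of a compressed video latent

- Patches span both space and time, a 3D counterpart of DiT's 2D patch embedding

- Trains and samples at variable resolutions, durations, and aspect ratios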
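The spacetime-patch idea generalizes `PatchEmbed` from 2D to 3D. OpenAI has not released Sora's architecture, so the following is only a generic sketch of 3D patchification, not Sora's code. `SpacetimePatchEmbed` and its default sizes are hypothetical. It also shows why variable resolution and duration come almost for free: a different latent size simply yields a different number of tokens.

```python
import torch
import torch.nn as nn


class SpacetimePatchEmbed(nn.Module):
    """Hypothetical 3D patch embedding: (B, C, T, H, W) -> (B, N, D).

    Each token covers pt frames x p x p latent pixels, so
    N = (T / pt) * (H / p) * (W / p).
    """

    def __init__(self, in_channels=4, embed_dim=1152, patch_size=2, temporal_patch=2):
        super().__init__()
        kernel = (temporal_patch, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel)

    def forward(self, x):
        x = self.proj(x)                     # (B, D, T/pt, H/p, W/p)
        return x.flatten(2).transpose(1, 2)  # (B, N, D)


embed = SpacetimePatchEmbed(embed_dim=64)
video_latent = torch.randn(1, 4, 16, 32, 32)  # 16 latent frames of 32 x 32
print(embed(video_latent).shape)              # torch.Size([1, 2048, 64])
```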

Stable Diffusion 3 (Stability AI, 2024)

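- DiT-style backbone trained with rectified flow, a Flow Matching formulation

- MM-DiT: text and image tokens are processed in one Transformer, with separate weights for each modality

- Joint attention over the concatenated text and image sequence; text is still encoded by pretrained text encoders (two CLIP models and T5)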
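Joint attention can be sketched in a few lines. This is a simplified illustration, not Stability AI's implementation. The full MM-DiT block also has per-modality MLPs and adaLN modulation, which are omitted here. `JointAttention` is a hypothetical name, and the code assumes PyTorch 2.x for `F.scaled_dot_product_attention`. Each modality keeps its own QKV and output projections. The attention itself runs over the concatenated sequence, so image tokens can read from text tokens and vice versa.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JointAttention(nn.Module):
    """Simplified MM-DiT-style joint attention (illustrative only)."""

    def __init__(self, dim, num_heads):
        super().__init__()
        self.h = num_heads
        self.qkv_img, self.qkv_txt = nn.Linear(dim, 3 * dim), nn.Linear(dim, 3 * dim)
        self.out_img, self.out_txt = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def _heads(self, x):  # (B, L, D) -> (B, heads, L, D / heads)
        B, L, D = x.shape
        return x.reshape(B, L, self.h, D // self.h).transpose(1, 2)

    def forward(self, img, txt):
        n_img = img.shape[1]
        # Modality-specific projections, then one concatenated sequence
        qkv = torch.cat([self.qkv_img(img), self.qkv_txt(txt)], dim=1)  # (B, N_img + N_txt, 3D)
        q, k, v = (self._heads(part) for part in qkv.chunk(3, dim=-1))
        out = F.scaled_dot_product_attention(q, k, v)  # every token attends to both modalities
        out = out.transpose(1, 2).flatten(2)           # (B, L, D)
        return self.out_img(out[:, :n_img]), self.out_txt(out[:, n_img:])


attn = JointAttention(dim=64, num_heads=4)
img_tokens, txt_tokens = torch.randn(2, 256, 64), torch.randn(2, 77, 64)
img_out, txt_out = attn(img_tokens, txt_tokens)
print(img_out.shape, txt_out.shape)  # torch.Size([2, 256, 64]) torch.Size([2, 77, 64])
```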

Dynamic DiT (ICLR 2025)

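- Adapts the amount of computation to the denoising step and to image regions instead of running the full network everywhere

- Timestep-wise dynamic width: activates fewer attention heads and MLP channels at timesteps where prediction is easier

- Spatial-wise dynamic token: skips part of the computation for tokens in simple regions such as backgrounds

- Reports substantially lower FLOPs and faster generation than a standard DiT at comparable FID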

PyTorch Implementation Walkthrough

```bash
# Clone the official DiT code
git clone https://github.com/facebookresearch/DiT.git
cd DiT

# Set up environment (the repo also ships an environment.yml for conda)
pip install torch torchvision timm diffusers accelerate

# Generate samples. sample.py downloads the pretrained DiT-XL/2 checkpoint
# automatically when --ckpt is not given; the class labels to sample
# (207 360 387 974 88 979 417 279) are set inside the script.
python sample.py \
    --model DiT-XL/2 \
    --image-size 256 \
    --num-sampling-steps 250 \
    --seed 42
```

Custom Training

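```python
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler
from torch.utils.data import DataLoader

accelerator = Accelerator()
device = accelerator.device

# Frozen VAE and a linear-beta DDPM noise schedule with 1000 steps
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval().requires_grad_(False)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# train_dataset is a placeholder: any dataset yielding (image in [-1, 1], label)
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
epochs = 1  # placeholder

# Initialize model (DiT as defined above)
model = DiT(
    input_size=32,  # VAE latent size
    patch_size=2,
    in_channels=4,
    embed_dim=1152,
    depth=28,
    num_heads=16,
    num_classes=1000,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0
)

model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

# Training loop
for epoch in range(epochs):
    for images, labels in dataloader:
        # VAE encoding (0.18215 rescales latents to roughly unit variance)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # Random timestep
        t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)

        # Add noise
        noise = torch.randn_like(latents)
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # Predict noise: the model returns (B, 2C, H, W) and the first C channels
        # are the noise prediction (the official code also trains the variance
        # channels with a VLB term, omitted here)
        noise_pred = model(noisy_latents, t, labels)[:, :latents.shape[1]]
        loss = F.mse_loss(noise_pred, noise)

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```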
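`scheduler.add_noise` in the loop implements the closed-form forward process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε. The following minimal sketch shows what it computes. It assumes the linear β schedule from 1e-4 to 0.02 over 1000 steps, which is the DDPM schedule DiT uses and the default of diffusers' `DDPMScheduler`. The function names are hypothetical.

```python
import torch


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s <= t} (1 - beta_s) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)


def add_noise(x0, noise, t, alphas_cumprod):
    """q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    # Gather alpha_bar per sample and reshape to (B, 1, 1, 1) for broadcasting
    a = alphas_cumprod[t].to(x0.dtype).view(-1, *([1] * (x0.dim() - 1)))
    return a.sqrt() * x0 + (1 - a).sqrt() * noise


ac = linear_alphas_cumprod()
x0 = torch.randn(8, 4, 32, 32)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (8,))
xt = add_noise(x0, noise, t, ac)
print(xt.shape, ac[0].item(), ac[-1].item())
```

At t = 0, ᾱ is 0.9999, so the latent is almost untouched. By t = 999, ᾱ has fallen below 1e-4 and x_t is essentially pure noise, which is why sampling can start from a standard Gaussian.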

Summary

DiT led an **architectural paradigm shift** in image generation:

- Replaced U-Net's complex structure with a pure Transformer

- Showed that, as with LLMs, sample quality improves predictably as model compute (Gflops) grows

- Enabled stable training of deep Transformers via adaLN-Zero

- Established itself as the foundational architecture for subsequent models like SORA and SD3

**Q1. What is the main reason DiT uses Transformer instead of U-Net?**

Proven scaling laws, simpler architecture, and easier integration with other modalities.

**Q2. How are images converted into a token sequence in DiT?**

Images are first converted to latent representations via VAE, then split into patches and linearly projected.

**Q3. What does the "Zero" in adaLN-Zero mean?**

The layer that produces the gates (α) is zero-initialized, so every gate starts at 0 and each block acts as an identity function at the start of training.

**Q4. Which of the four conditioning methods tested in DiT was ultimately chosen?**

adaLN-Zero. Among In-context, Cross-Attention, adaLN, and adaLN-Zero, it achieved the best FID.

**Q5. What is the trade-off with smaller patch sizes?**

The token count grows as (I/p)², so halving p quadruples the tokens and roughly quadruples the Gflops, in exchange for better image quality (lower FID).

**Q6. How is Classifier-Free Guidance implemented in DiT?**

Training uses class label dropout, and inference samples from a weighted sum of conditional and unconditional predictions.

**Q7. What is DiT-XL/2's ImageNet 256x256 FID-50K score?**

2.27 with classifier-free guidance (scale 1.5), which set the class-conditional generation SOTA at the time.

**Q8. What is the core idea behind Dynamic DiT (ICLR 2025)?**

Adapting computation to the denoising step and to image regions (timestep-wise dynamic width plus spatial-wise dynamic tokens), which substantially reduces inference cost compared with a static DiT.
