- Introduction
- Background: Why Move Away from U-Net?
- DiT Architecture in Detail
- Model Variants and Scaling
- Classifier-Free Guidance
- Downstream Impact of DiT
- PyTorch Implementation Walkthrough
- Summary
Introduction
Published in late 2022, "Scalable Diffusion Models with Transformers" (William Peebles & Saining Xie, ICCV 2023) marked a pivotal turning point, shifting the image generation backbone from U-Net to Transformer. The DiT (Diffusion Transformer) architecture proposed in the paper has since become the foundation for OpenAI's Sora, Stability AI's Stable Diffusion 3, and Dynamic DiT presented at ICLR 2025.
Background: Why Move Away from U-Net?
Limitations of U-Net
Traditional DDPM and LDM (Latent Diffusion Models) use U-Net as the backbone. While U-Net performs well in image generation, it has fundamental limitations:
- No scaling laws: Performance gains are unpredictable when scaling up the model
- Architectural complexity: Managing skip connections and feature maps at various resolutions
- Computational inefficiency: Memory and compute costs surge at high resolutions
- Difficult multi-modal integration: Unnatural integration with text, video, and other modalities
Advantages of Transformers
In contrast, Transformers (ViT) offer:
- Proven scaling laws: Performance improves predictably with parameter count and compute (demonstrated in LLMs)
- Simple architecture: Repetition of Self-attention + FFN blocks
- Modality agnostic: Text, images, and video can all be processed with the same structure
- Hardware optimization: Efficient parallel processing on GPUs/TPUs
DiT Architecture in Detail
Overall Pipeline
DiT operates on top of the Latent Diffusion framework:
- Encode images into latent space using a VAE encoder
- Split the latent representation into patches (same as ViT)
- Predict noise using DiT blocks (Transformer)
- Reconstruct images with the VAE decoder
Patch Embedding
```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Convert latent representations into a patch sequence."""
    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, N, D)
        # e.g.: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
        x = self.proj(x)                  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x
```
The input is a 256x256 image converted to a 32x32 latent representation via VAE, then split with patch_size=2 to yield 16x16=256 tokens.
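The arithmetic behind that token count can be checked directly with the standard convolution output-size formula; a quick pure-Python sketch:

```python
# Patch-count check for the patchify step above (pure Python, no torch needed).
# A Conv2d with kernel_size == stride == p tiles the input into
# non-overlapping p x p patches; the conv output size is (H - k) // s + 1.
H, p = 32, 2                 # 32x32 latent, patch_size = 2
side = (H - p) // p + 1      # 16 patches per spatial dimension
num_tokens = side * side
print(side, num_tokens)      # -> 16 256
```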
Conditioning Mechanisms — Four Variants
DiT experiments with four approaches for injecting the timestep (t) and class label (c):
1. In-context Conditioning
```python
class InContextConditioning(nn.Module):
    """Concatenate t and c as additional tokens to the sequence."""
    def __init__(self, embed_dim):
        super().__init__()
        self.t_embed = TimestepEmbedder(embed_dim)
        self.c_embed = LabelEmbedder(1000, embed_dim)

    def forward(self, x, t, c):
        t_token = self.t_embed(t).unsqueeze(1)       # (B, 1, D)
        c_token = self.c_embed(c).unsqueeze(1)       # (B, 1, D)
        x = torch.cat([t_token, c_token, x], dim=1)  # (B, N+2, D)
        return x
```
2. Cross-Attention
```python
class CrossAttentionBlock(nn.Module):
    """Inject t and c via cross-attention."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # batch_first=True so inputs are (B, N, D), consistent with DiTBlock below
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = MLP(embed_dim)

    def forward(self, x, cond):
        x = x + self.self_attn(x, x, x)[0]
        x = x + self.cross_attn(x, cond, cond)[0]
        x = x + self.mlp(x)
        return x
```
3. Adaptive Layer Norm (adaLN)
```python
class AdaLNBlock(nn.Module):
    """Adjust LayerNorm scale/shift based on the conditioning vector."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 2 * embed_dim)
        )

    def forward(self, x, cond):
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x + self.attn(x, x, x)[0]
        return x
```
4. adaLN-Zero (Final Choice)
```python
class DiTBlock(nn.Module):
    """adaLN-Zero: initialize residual branches to start at zero."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # 6 modulation parameters: gamma1, beta1, alpha1, gamma2, beta2, alpha2
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim)
        )
        # Initialize to zero -> each block is the identity at the start of training
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: (B, D) - sum of timestep + class embeddings
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(cond).chunk(6, dim=-1)
        # Self-attention with adaLN
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * h   # gate starts at 0
        # FFN with adaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h   # gate starts at 0
        return x
```
The key insight of adaLN-Zero: By initializing gate parameters to zero, each DiT block acts as an identity function at the start of training. This significantly improves training stability for deep networks.
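The identity-at-init property is easy to verify numerically. A minimal NumPy sketch (helper names are illustrative; the attention/MLP sub-layer is dropped, which is harmless here since the zero gate nullifies it anyway):

```python
import numpy as np

def layernorm(x, eps=1e-6):
    # LayerNorm over the last axis, no affine parameters
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_zero_branch(x, shift, scale, gate):
    # One adaLN-Zero residual branch; the sub-layer output is gated before
    # being added back, so gate == 0 makes the branch a no-op
    h = layernorm(x) * (1 + scale) + shift
    return x + gate * h

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
# Zero-initialized modulation MLP: shift = scale = gate = 0
out = adaln_zero_branch(x, shift=0.0, scale=0.0, gate=0.0)
print(np.allclose(out, x))  # -> True: the block is exactly the identity
```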
Full DiT Model
```python
class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        num_classes=1000,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        num_patches = (input_size // patch_size) ** 2  # 256
        # Positional embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )
        # Timestep & class embeddings
        self.t_embedder = TimestepEmbedder(embed_dim)
        self.y_embedder = LabelEmbedder(num_classes, embed_dim)
        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        # Final layer predicts noise and variance (hence in_channels * 2)
        self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)

    def forward(self, x, t, y):
        # Patch embedding + positional embedding
        x = self.patch_embed(x) + self.pos_embed
        # Condition embedding
        t_emb = self.t_embedder(t)  # (B, D)
        y_emb = self.y_embedder(y)  # (B, D)
        cond = t_emb + y_emb        # (B, D)
        # Pass through DiT blocks
        for block in self.blocks:
            x = block(x, cond)
        # Predict noise and variance
        x = self.final_layer(x, cond)  # (B, N, 2*C*p*p)
        return x
```
Model Variants and Scaling
| Model | depth | embed_dim | heads | Parameters | FID-50K |
|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 9.62 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 2.27 |
Key finding: DiT exhibits LLM-style scaling behavior — FID improves consistently as model size and compute increase. (Note that the 2.27 figure for DiT-XL/2 is measured with classifier-free guidance at scale 1.5.)
Impact of Patch Size
| Model | patch_size | Tokens | Gflops | FID |
|---|---|---|---|---|
| DiT-XL/8 | 8 | 16 | 119 | 14.2 |
| DiT-XL/4 | 4 | 64 | 525 | 5.11 |
| DiT-XL/2 | 2 | 256 | 1121 | 2.27 |
Smaller patch sizes (more tokens) yield better quality but significantly increase computation.
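The token-count growth behind this trade-off is simple arithmetic, and the quadratic cost of self-attention compounds it. A pure-Python check (attention cost modeled as N² only, ignoring the FFN term):

```python
# Tokens and relative self-attention cost for a 32x32 latent
# at the three patch sizes in the table above.
latent = 32
for p in (8, 4, 2):
    n_tokens = (latent // p) ** 2
    # Self-attention scales as O(N^2); normalize to the p=8 case (N=16)
    rel_attn_cost = n_tokens ** 2 / 16 ** 2
    print(p, n_tokens, rel_attn_cost)
# -> 8 16 1.0
# -> 4 64 16.0
# -> 2 256 256.0
# Halving the patch size quadruples N and raises attention cost ~16x
```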
Classifier-Free Guidance
```python
def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
    """Sampling with Classifier-Free Guidance."""
    # Class index 1000 (= num_classes) is the "null" / unconditional label
    y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])
    for t in reversed(range(1000)):
        # Re-duplicate the current latent each step so the conditional and
        # unconditional branches always see the same z
        z_combined = torch.cat([z, z], dim=0)
        t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
        noise_pred = model(z_combined, t_batch, y_combined)
        # Apply CFG
        cond_pred, uncond_pred = noise_pred.chunk(2)
        guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
        # DDPM step
        z = ddpm_step(z, guided_pred, t)
    return z
```
At CFG scale=1.5, DiT-XL/2 achieves an FID of 2.27, setting the SOTA on ImageNet 256x256 at the time.
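The guidance formula itself is just a weighted extrapolation, easy to sanity-check with scalars: at scale 1 it reduces to the conditional prediction, and larger scales push further away from the unconditional one.

```python
def cfg(uncond, cond, scale):
    # guided = uncond + scale * (cond - uncond)
    return uncond + scale * (cond - uncond)

# Scalar sanity checks (illustrative values, not model outputs):
print(cfg(0.0, 1.0, 1.0))  # -> 1.0  scale 1 recovers the conditional prediction
print(cfg(0.0, 1.0, 1.5))  # -> 1.5  scale > 1 extrapolates past it
print(cfg(0.0, 1.0, 0.0))  # -> 0.0  scale 0 is purely unconditional
```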
Downstream Impact of DiT
Sora (OpenAI, 2024)
- Extended DiT to video as Spatial-Temporal DiT
- 3D patch embedding (spatial + temporal)
- Variable resolution and duration support
Stable Diffusion 3 (Stability AI, 2024)
- DiT + Flow Matching
- MM-DiT: Processes text and images in the same Transformer
- Joint attention over concatenated text and image token sequences, with separate weights per modality
Dynamic DiT (ICLR 2025)
- Early Exit mechanism for inference efficiency
- Uses shallow layers for noisy early steps, deeper layers for detail generation
- 40% inference time reduction compared to standard DiT
PyTorch Implementation Walkthrough
```shell
# Clone the official DiT code
git clone https://github.com/facebookresearch/DiT.git
cd DiT

# Set up environment
pip install torch torchvision diffusers accelerate

# Download pretrained model
python download.py DiT-XL/2

# Generate samples
python sample.py \
    --model DiT-XL/2 \
    --image-size 256 \
    --num-sampling-steps 250 \
    --seed 42 \
    --class-labels 207 360 387 974 88 979 417 279
```
Custom Training
```python
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# Initialize model (vae, scheduler, dataloader, and epochs are assumed to be
# set up elsewhere, e.g. with diffusers' AutoencoderKL and DDPMScheduler)
model = DiT(
    input_size=32,   # VAE latent size
    patch_size=2,
    in_channels=4,
    embed_dim=1152,
    depth=28,
    num_heads=16,
    num_classes=1000,
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0
)
model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

# Training loop
for epoch in range(epochs):
    for batch in dataloader:
        images, labels = batch
        # VAE encoding (0.18215 is the SD VAE scaling factor)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215
        # Random timestep
        t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
        # Add noise
        noise = torch.randn_like(latents)
        noisy_latents = scheduler.add_noise(latents, noise, t)
        # Predict noise: the model outputs patch tokens of shape (B, N, 2*C*p*p),
        # so unpatchify back to image layout and split off the variance channels
        model_out = model(noisy_latents, t, labels)
        B, N, _ = model_out.shape
        p, C = 2, 4
        hw = int(N ** 0.5)
        out = model_out.reshape(B, hw, hw, p, p, 2 * C)
        out = out.permute(0, 5, 1, 3, 2, 4).reshape(B, 2 * C, hw * p, hw * p)
        noise_pred, _ = out.chunk(2, dim=1)   # keep noise, drop variance
        loss = F.mse_loss(noise_pred, noise)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```
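For readers without a `scheduler` object at hand, the `add_noise` step implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A dependency-free sketch with a linear beta schedule (the 1000-step range and beta endpoints follow the standard DDPM setup):

```python
import math

T = 1000
# Linear beta schedule from 1e-4 to 0.02, as in DDPM
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# Cumulative product of (1 - beta): alpha_bar_t
alpha_bars = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bars.append(prod)

def add_noise(x0, eps, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    ab = alpha_bars[t]
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps

# The squared weights sum to 1 (variance-preserving), so a unit-variance
# input stays unit-variance at every t
x0, eps = 1.0, 0.5
print(add_noise(x0, eps, 0))    # mostly signal at t=0
print(add_noise(x0, eps, 999))  # mostly noise at t=999
```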
Summary
DiT led an architectural paradigm shift in image generation:
- Replaced U-Net's complex structure with a pure Transformer
- Demonstrated that the same scaling laws from LLMs apply to image generation
- Enabled stable training of deep Transformers via adaLN-Zero
- Established itself as the foundational architecture for subsequent models like Sora and SD3
Quiz: DiT Architecture Comprehension Check (8 Questions)
Q1. What is the main reason DiT uses Transformer instead of U-Net?
Proven scaling laws, simpler architecture, and easier integration with other modalities.
Q2. How are images converted into a token sequence in DiT?
Images are first converted to latent representations via VAE, then split into patches and linearly projected.
Q3. What does the "Zero" in adaLN-Zero mean?
Gate parameters are initialized to zero so each block acts as an identity function at the start of training.
Q4. Which of the four conditioning methods tested in DiT was ultimately chosen?
adaLN-Zero. Among In-context, Cross-Attention, adaLN, and adaLN-Zero, it achieved the best FID.
Q5. What is the trade-off with smaller patch sizes?
The number of tokens increases, improving image quality (FID), but computation (Gflops) grows quadratically.
Q6. How is Classifier-Free Guidance implemented in DiT?
Training uses class label dropout, and inference samples from a weighted sum of conditional and unconditional predictions.
Q7. What is DiT-XL/2's ImageNet 256x256 FID-50K score?
2.27, which set the class-conditional generation SOTA at the time.
Q8. What is the core idea behind Dynamic DiT (ICLR 2025)?
An Early Exit mechanism that dynamically adjusts the number of layers used based on noise level, reducing inference time by 40%.