Diffusion Transformer (DiT) Architecture Analysis: The Shift from U-Net to Transformer

Introduction

Published in late 2022, "Scalable Diffusion Models with Transformers" (William Peebles & Saining Xie, ICCV 2023) marked a turning point: it shifted the image-generation backbone from U-Net to the Transformer. The DiT (Diffusion Transformer) architecture proposed in the paper has since become the foundation for OpenAI's SORA, Stability AI's Stable Diffusion 3, and the Dynamic DiT presented at ICLR 2025.

Background: Why Move Away from U-Net?

Limitations of U-Net

Traditional DDPM and LDM (Latent Diffusion Models) use U-Net as the backbone. While U-Net performs well in image generation, it has fundamental limitations:

  • No scaling laws: Performance gains are unpredictable when scaling up the model
  • Architectural complexity: Managing skip connections and feature maps at various resolutions
  • Computational inefficiency: Memory and compute costs surge at high resolutions
  • Difficult multi-modal integration: Unnatural integration with text, video, and other modalities

Advantages of Transformers

In contrast, Transformers (ViT) offer:

  • Proven scaling laws: Performance improves predictably with parameter count (as demonstrated in LLMs)
  • Simple architecture: Repetition of Self-attention + FFN blocks
  • Modality agnostic: Text, images, and video can all be processed with the same structure
  • Hardware optimization: Efficient parallel processing on GPUs/TPUs

DiT Architecture in Detail

Overall Pipeline

DiT operates on top of the Latent Diffusion framework:

  1. Encode images into latent space using a VAE encoder
  2. Split the latent representation into patches (same as ViT)
  3. Predict noise using DiT blocks (Transformer)
  4. Reconstruct images with the VAE decoder
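
The four stages above can be sketched as a single sampling function. This is a hedged sketch, not the official API: `vae`, `dit`, and `sampler` are hypothetical stand-ins for the real components, and at sampling time stage 1 is replaced by drawing Gaussian noise in latent space.

```python
import torch

# Hypothetical names: vae, dit, and sampler stand in for the real components.
def dit_generate(vae, dit, sampler, z, timesteps, y):
    """End-to-end sketch of the latent-diffusion pipeline above.

    z: initial latent noise, timesteps: denoising schedule, y: class labels.
    """
    for t in timesteps:                        # iterative denoising
        t_batch = torch.full((z.shape[0],), t, device=z.device)
        eps = dit(z, t_batch, y)               # stages 2-3: patchify + DiT blocks
        z = sampler.step(eps, t, z)            # one reverse-diffusion update
    return vae.decode(z)                       # stage 4: latent -> pixels
```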

Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Convert latent representations into a patch sequence"""
    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, N, D)
        # e.g.: (B, 4, 32, 32) -> (B, 256, 1152) (patch_size=2)
        x = self.proj(x)  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x

A 256x256 input image is encoded by the VAE (8x downsampling) into a 32x32x4 latent; with patch_size=2, this is split into 16x16 = 256 tokens.
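
A quick shape check of this patchify step, using the conv-as-patchify trick directly (a strided conv with kernel_size = stride = patch_size is equivalent to projecting non-overlapping patches):

```python
import torch
import torch.nn as nn

# kernel_size=2, stride=2 conv == non-overlapping 2x2 patch projection
proj = nn.Conv2d(4, 1152, kernel_size=2, stride=2)

latent = torch.randn(1, 4, 32, 32)             # VAE latent of a 256x256 image
tokens = proj(latent).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 256, 1152]) -> 16x16 = 256 tokens
```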

Conditioning Mechanisms — Four Variants

DiT experiments with four approaches for injecting the timestep (t) and class label (c):

1. In-context Conditioning

class InContextConditioning(nn.Module):
    """Concatenate t and c as additional tokens to the sequence"""
    def __init__(self, embed_dim):
        super().__init__()
        self.t_embed = TimestepEmbedder(embed_dim)
        self.c_embed = LabelEmbedder(1000, embed_dim)

    def forward(self, x, t, c):
        t_token = self.t_embed(t).unsqueeze(1)  # (B, 1, D)
        c_token = self.c_embed(c).unsqueeze(1)  # (B, 1, D)
        x = torch.cat([t_token, c_token, x], dim=1)  # (B, N+2, D)
        return x

2. Cross-Attention

class CrossAttentionBlock(nn.Module):
    """Inject t and c via cross-attention"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x, cond):
        # x: (B, N, D) patch tokens, cond: (B, M, D) condition tokens
        x = x + self.self_attn(x, x, x)[0]
        x = x + self.cross_attn(x, cond, cond)[0]
        x = x + self.mlp(x)
        return x

3. Adaptive Layer Norm (adaLN)

class AdaLNBlock(nn.Module):
    """Adjust LayerNorm scale/shift based on conditions"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 2 * embed_dim)
        )

    def forward(self, x, cond):
        # cond: (B, D) conditioning vector (timestep + class embedding)
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x + self.attn(x, x, x)[0]
        return x

4. adaLN-Zero (Final Choice)

class DiTBlock(nn.Module):
    """adaLN-Zero: Initialize residual connections to start at zero"""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * embed_dim, embed_dim),
        )
        # 6 modulation vectors: shift, scale, and gate for both attention and MLP
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, 6 * embed_dim)
        )
        # Zero-initialize -> all gates start at 0, so each block is the identity
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: (B, D) - sum of timestep + class embeddings
        mods = self.adaLN_modulation(cond).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        # Self-Attention with adaLN
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h, h, h)[0]
        x = x + gate_msa.unsqueeze(1) * h  # gate starts at 0

        # FFN with adaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h  # gate starts at 0

        return x

The key insight of adaLN-Zero: By initializing gate parameters to zero, each DiT block acts as an identity function at the start of training. This significantly improves training stability for deep networks.
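
This identity-at-init property is easy to verify directly. A minimal self-contained sketch (a toy gated residual block, not the full DiTBlock above) shows that with a zero-initialized modulation layer, the block returns its input unchanged:

```python
import torch
import torch.nn as nn

class TinyAdaLNZeroBlock(nn.Module):
    """Toy gated residual block illustrating the adaLN-Zero principle."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Linear(dim, dim)
        self.mod = nn.Linear(dim, 3 * dim)  # produces shift, scale, gate
        nn.init.zeros_(self.mod.weight)     # zero-init, as in adaLN-Zero
        nn.init.zeros_(self.mod.bias)

    def forward(self, x, cond):
        shift, scale, gate = self.mod(cond).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        h = self.mlp(h)
        return x + gate.unsqueeze(1) * h    # gate == 0 at init -> identity

block = TinyAdaLNZeroBlock(16)
x = torch.randn(2, 8, 16)
cond = torch.randn(2, 16)
out = block(x, cond)
print(torch.allclose(out, x))  # True: the block is the identity at init
```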

Full DiT Model

class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        num_classes=1000,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size, in_channels, embed_dim)
        num_patches = (input_size // patch_size) ** 2  # 256

        # Positional embedding (the official implementation uses a fixed
        # sin-cos embedding; a learnable one is shown here for simplicity)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )

        # Timestep & class embeddings
        self.t_embedder = TimestepEmbedder(embed_dim)
        self.y_embedder = LabelEmbedder(num_classes, embed_dim)

        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # Final layer
        self.final_layer = FinalLayer(embed_dim, patch_size, in_channels * 2)

    def forward(self, x, t, y):
        # Patch embedding + positional embedding
        x = self.patch_embed(x) + self.pos_embed

        # Condition embedding
        t_emb = self.t_embedder(t)  # (B, D)
        y_emb = self.y_embedder(y)  # (B, D)
        cond = t_emb + y_emb        # (B, D)

        # Pass through DiT blocks
        for block in self.blocks:
            x = block(x, cond)

        # Predict noise and variance
        x = self.final_layer(x, cond)  # (B, N, 2*C*p*p)
        return x
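
The model references TimestepEmbedder and LabelEmbedder, which are not shown above. A minimal sketch of both, following the common sinusoidal-features-plus-MLP design (exact layer sizes here are assumptions, not the official hyperparameters):

```python
import math
import torch
import torch.nn as nn

class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to (B, D) vectors via sinusoidal features + MLP."""
    def __init__(self, embed_dim, freq_dim=256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(
            nn.Linear(freq_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, t):
        half = self.freq_dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, dtype=torch.float32) / half
        ).to(t.device)
        args = t.float()[:, None] * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(emb)

class LabelEmbedder(nn.Module):
    """Class-label lookup table; the extra row is the null class used by CFG."""
    def __init__(self, num_classes, embed_dim):
        super().__init__()
        self.table = nn.Embedding(num_classes + 1, embed_dim)

    def forward(self, y):
        return self.table(y)

t_emb = TimestepEmbedder(1152)(torch.randint(0, 1000, (4,)))
y_emb = LabelEmbedder(1000, 1152)(torch.tensor([207, 1000]))  # 1000 = null class
print(t_emb.shape, y_emb.shape)
```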

Model Variants and Scaling

| Model    | depth | embed_dim | heads | Parameters | FID-50K |
|----------|-------|-----------|-------|------------|---------|
| DiT-S/2  | 12    | 384       | 6     | 33M        | 68.4    |
| DiT-B/2  | 12    | 768       | 12    | 130M       | 43.5    |
| DiT-L/2  | 24    | 1024      | 16    | 458M       | 9.62    |
| DiT-XL/2 | 28    | 1152      | 16    | 675M       | 2.27    |

Key finding: DiT follows the same scaling laws as LLMs. FID consistently improves as model size increases.

Impact of Patch Size

| Model    | patch_size | Tokens | Gflops | FID  |
|----------|------------|--------|--------|------|
| DiT-XL/8 | 8          | 16     | 119    | 14.2 |
| DiT-XL/4 | 4          | 64     | 525    | 5.11 |
| DiT-XL/2 | 2          | 256    | 1121   | 2.27 |

Smaller patch sizes (more tokens) yield better quality but significantly increase computation.
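
The token-count arithmetic behind this trade-off: halving the patch size quadruples the token count, and since self-attention cost scales with the square of the number of tokens, the attention cost grows 16x per halving.

```python
latent_size = 32  # 256x256 image after 8x VAE downsampling

for p in (8, 4, 2):
    tokens = (latent_size // p) ** 2
    # self-attention cost relative to patch_size=8, scaling as tokens^2
    rel_attn = (tokens / (latent_size // 8) ** 2) ** 2
    print(f"patch={p}: tokens={tokens}, relative attention cost={rel_attn:.0f}x")
```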

Classifier-Free Guidance

def sample_with_cfg(model, z, class_labels, cfg_scale=4.0):
    """Sampling with Classifier-Free Guidance"""
    # Index num_classes (1000) is the "null" class for unconditional prediction
    y_combined = torch.cat([class_labels, torch.full_like(class_labels, 1000)])

    for t in reversed(range(1000)):
        # Batch the conditional + unconditional predictions together
        # (rebuilt each step so it tracks the updated z)
        z_combined = torch.cat([z, z], dim=0)
        t_batch = torch.full((z_combined.shape[0],), t, device=z.device)
        noise_pred = model(z_combined, t_batch, y_combined)

        # Apply CFG
        cond_pred, uncond_pred = noise_pred.chunk(2)
        guided_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)

        # DDPM step (ddpm_step: one reverse-diffusion update, defined elsewhere)
        z = ddpm_step(z, guided_pred, t)

    return z

At CFG scale=1.5, DiT-XL/2 achieves an FID of 2.27, setting the SOTA on ImageNet 256x256 at the time.

Downstream Impact of DiT

SORA (OpenAI, 2024)

  • Extended DiT to video as Spatial-Temporal DiT
  • 3D patch embedding (spatial + temporal)
  • Variable resolution and duration support

Stable Diffusion 3 (Stability AI, 2024)

  • DiT + Flow Matching
  • MM-DiT: Processes text and images in the same Transformer
  • Joint attention without a separate text encoder

Dynamic DiT (ICLR 2025)

  • Early Exit mechanism for inference efficiency
  • Uses shallow layers for noisy early steps, deeper layers for detail generation
  • 40% inference time reduction compared to standard DiT

PyTorch Implementation Walkthrough

# Clone the official DiT code
git clone https://github.com/facebookresearch/DiT.git
cd DiT

# Set up environment
pip install torch torchvision diffusers accelerate

# Download pretrained weights (sample.py also fetches them automatically)
python download.py

# Generate samples (the class labels to sample are set inside sample.py)
python sample.py \
  --model DiT-XL/2 \
  --image-size 256 \
  --num-sampling-steps 250 \
  --seed 42 \
  --cfg-scale 4.0

Custom Training

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# Initialize model
model = DiT(
    input_size=32,  # VAE latent size
    patch_size=2,
    in_channels=4,
    embed_dim=1152,
    depth=28,
    num_heads=16,
    num_classes=1000,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.0
)

# dataloader, vae, and scheduler (e.g. a diffusers DDPMScheduler) are assumed
# to be constructed elsewhere
model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
)

# Training loop
for epoch in range(epochs):
    for batch in dataloader:
        images, labels = batch
        # VAE encoding
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # Random timestep
        t = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)

        # Add noise
        noise = torch.randn_like(latents)
        noisy_latents = scheduler.add_noise(latents, noise, t)

        # Predict noise (unpatchifying the output and splitting off the
        # learned-variance channels is omitted here for brevity)
        noise_pred = model(noisy_latents, t, labels)
        loss = F.mse_loss(noise_pred, noise)

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
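
The loop above is purely class-conditional; to enable classifier-free guidance at sampling time, labels must occasionally be replaced by the null class during training. A minimal sketch (the dropout probability and the helper name are assumptions, not from the source):

```python
import torch

def drop_labels(labels, null_class=1000, p=0.1):
    """Replace each label with the null class with probability p, so the
    model also learns the unconditional distribution needed for CFG."""
    mask = torch.rand(labels.shape[0], device=labels.device) < p
    return torch.where(mask, torch.full_like(labels, null_class), labels)

# usage inside the training loop, before the model call:
#   labels = drop_labels(labels)
```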

Summary

DiT led an architectural paradigm shift in image generation:

  • Replaced U-Net's complex structure with a pure Transformer
  • Demonstrated that the same scaling laws from LLMs apply to image generation
  • Enabled stable training of deep Transformers via adaLN-Zero
  • Established itself as the foundational architecture for subsequent models like SORA and SD3

Quiz: DiT Architecture Comprehension Check (8 Questions)

Q1. What is the main reason DiT uses Transformer instead of U-Net?

Proven scaling laws, simpler architecture, and easier integration with other modalities.

Q2. How are images converted into a token sequence in DiT?

Images are first converted to latent representations via VAE, then split into patches and linearly projected.

Q3. What does the "Zero" in adaLN-Zero mean?

Gate parameters are initialized to zero so each block acts as an identity function at the start of training.

Q4. Which of the four conditioning methods tested in DiT was ultimately chosen?

adaLN-Zero. Among In-context, Cross-Attention, adaLN, and adaLN-Zero, it achieved the best FID.

Q5. What is the trade-off with smaller patch sizes?

The number of tokens increases, improving image quality (FID), but computation (Gflops) grows quadratically.

Q6. How is Classifier-Free Guidance implemented in DiT?

Training uses class label dropout, and inference samples from a weighted sum of conditional and unconditional predictions.

Q7. What is DiT-XL/2's ImageNet 256x256 FID-50K score?

2.27, which set the class-conditional generation SOTA at the time.

Q8. What is the core idea behind Dynamic DiT (ICLR 2025)?

An Early Exit mechanism that dynamically adjusts the number of layers used based on noise level, reducing inference time by 40%.