
Complete Guide to GPU Memory Optimization and Mixed Precision Training


1. GPU Memory Composition Analysis

When training deep learning models, GPU memory stores far more than just model parameters. The components occupying GPU memory during the entire training process can be broadly divided into four categories. Understanding these components precisely is essential for devising effective memory optimization strategies.

1.1 Model Parameters

Model parameters refer to the weights and biases of the neural network. They reside in GPU memory throughout the entire training process, and memory usage is determined by the number of parameters and the data type.

  • FP32 basis: 4 bytes per parameter
  • FP16/BF16 basis: 2 bytes per parameter

For example, storing a 7B (7 billion) parameter model in FP32 requires approximately 28GB, while FP16 requires approximately 14GB.

1.2 Gradients

Gradients are the partial derivatives of the loss function with respect to each parameter. They are computed during backpropagation, and since every parameter has exactly one gradient value, gradients occupy the same number of elements as the parameters.

  • FP32 training: number of parameters x 4 bytes
  • Mixed Precision training: number of parameters x 2 bytes (stored in FP16)

1.3 Optimizer States

Optimizer states often account for the largest portion of memory. SGD with momentum keeps a single momentum buffer per parameter, while the Adam/AdamW optimizer keeps two states per parameter: the First Moment (mean) and the Second Moment (uncentered variance).

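Optimizer state memory for the Adam optimizer under mixed precision training (in pure FP32 training the parameters themselves act as the master weights, so only momentum and variance are extra, i.e., 8Φ bytes):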

Let the number of parameters be Φ:
- Master Weights (FP32): 4Φ bytes
- Momentum (FP32): 4Φ bytes
- Variance (FP32): 4Φ bytes
- Total: 12Φ bytes

Using Adam with a 7B model requires approximately 84GB for optimizer states alone. This is why optimizer state optimization is critical in large model training.

1.4 Activations

Activations are the outputs of each layer during the forward pass and must be stored for computing gradients during backpropagation. Activation memory is proportional to the following factors:

  • Batch Size: larger sizes require more memory
  • Sequence Length: especially important in Transformer models (O(n^2) for attention)
  • Hidden Dimension: increases with larger models
  • Number of Layers: increases with deeper models

In large models, activation memory frequently exceeds parameter memory.

1.5 Overall Memory Summary

When using Mixed Precision Training with Adam Optimizer, the total memory requirement for Φ parameters is as follows:

| Component | Data Type | Memory |
|---|---|---|
| Model Parameters (FP16 copy) | FP16 | 2Φ bytes |
| Gradients | FP16 | 2Φ bytes |
| Master Weights | FP32 | 4Φ bytes |
| Optimizer Momentum | FP32 | 4Φ bytes |
| Optimizer Variance | FP32 | 4Φ bytes |
| Total | | 16Φ bytes |

In addition, activation memory, temporary buffers, and memory fragmentation cause approximately 20-30% additional overhead in practice.
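The per-parameter byte counts above can be turned into a quick estimator. The sketch below uses a hypothetical helper (`training_memory_gb` is not a library function) and assumes Adam/AdamW, decimal gigabytes, and no activation, buffer, or fragmentation overhead. It also computes the pure FP32 case for comparison, where the parameters serve as their own master weights.

```python
def training_memory_gb(num_params: float, mixed_precision: bool = True) -> dict:
    """Estimate memory (decimal GB) for parameters, gradients, and Adam states.

    Hypothetical helper; activations, buffers, and fragmentation are excluded.
    """
    if mixed_precision:
        bytes_per_param = {
            "params (FP16/BF16)": 2,
            "gradients (FP16/BF16)": 2,
            "master weights (FP32)": 4,
            "momentum (FP32)": 4,
            "variance (FP32)": 4,
        }
    else:
        # Pure FP32: the parameters are the master weights, so no extra copy
        bytes_per_param = {
            "params (FP32)": 4,
            "gradients (FP32)": 4,
            "momentum (FP32)": 4,
            "variance (FP32)": 4,
        }
    breakdown = {name: num_params * b / 1e9 for name, b in bytes_per_param.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown


for name, gb in training_memory_gb(7e9).items():
    print(f"{name:24s} {gb:6.1f} GB")
```

For a 7B model both variants come to 16 bytes per parameter (about 112 GB). With Adam, mixed precision does not shrink these components; its memory savings come mainly from smaller activations.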


2. FP32 vs FP16 vs BF16 vs FP8 Numerical Representation Comparison

The floating-point formats supported by NVIDIA GPUs each have unique characteristics. To understand the trade-offs between training stability and performance, we compare the bit composition and representation ranges of each format.

2.1 Bit Composition

All of these formats use the IEEE 754-style layout of three fields: Sign, Exponent, and Mantissa (BF16 and FP8 are not part of the IEEE 754 standard itself, but follow the same structure).

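| Format | Sign | Exponent | Mantissa | Total Bits | Memory |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | 32 | 4 bytes |
| FP16 | 1 | 5 | 10 | 16 | 2 bytes |
| BF16 | 1 | 8 | 7 | 16 | 2 bytes |
| FP8 (E4M3) | 1 | 4 | 3 | 8 | 1 byte |
| FP8 (E5M2) | 1 | 5 | 2 | 8 | 1 byte |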

2.2 Dynamic Range and Precision

Comparing the representable ranges and precision of each format:

FP32 (Single Precision)

  • Dynamic Range: ~1.2 x 10^-38 to 3.4 x 10^38
  • Precision: approximately 7 significant decimal digits
  • The default data type for deep learning, providing the highest precision.

FP16 (Half Precision)

  • Dynamic Range: ~6.1 x 10^-5 to 6.55 x 10^4
  • Precision: approximately 3 significant decimal digits
  • The narrow dynamic range makes gradient underflow/overflow likely, so FP16 training is almost always paired with Loss Scaling.

BF16 (Brain Floating Point 16)

  • Dynamic Range: ~1.2 x 10^-38 to 3.4 x 10^38 (same as FP32)
  • Precision: approximately 2 significant decimal digits
  • Has the same 8-bit exponent as FP32, so the dynamic range is identical. Although precision is lower, it can train stably without Loss Scaling, making it the emerging standard for large model training. Supported on NVIDIA Ampere (A100) and later.

FP8 E4M3

  • Dynamic Range: up to ±448 (largest finite magnitude)
  • Precision: lowest
  • Primarily used in the forward pass. Per-tensor scaling is essential to compensate for the narrow range.

FP8 E5M2

  • Dynamic Range: up to ±57,344 (largest finite magnitude)
  • Precision: lower than E4M3 but wider range
  • Primarily used in the backward pass (gradient computation). This accommodates the wide dynamic range of gradients.
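The limits quoted above follow directly from the bit layouts. The sketch below is a minimal calculator, assuming normal (non-subnormal) numbers; the `format_stats` helper is hypothetical. FP8 E4M3 is handled separately because, in the OCP FP8 definition used by NVIDIA hardware, it has no infinities and reuses the top exponent code for finite values, which is why its maximum is 448 rather than the 240 an IEEE-style layout would give.

```python
import math


def format_stats(exp_bits: int, man_bits: int, ieee_inf_nan: bool = True) -> dict:
    """Range and precision of a binary floating-point format (normal numbers only)."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_inf_nan:
        # IEEE-style: the all-ones exponent code is reserved for inf/NaN
        max_exp = (2 ** exp_bits - 2) - bias
        max_significand = 2.0 - 2.0 ** -man_bits
    else:
        # FP8 E4M3 (OCP): all-ones exponent holds finite values; only S.1111.111 is NaN
        max_exp = (2 ** exp_bits - 1) - bias
        max_significand = 2.0 - 2.0 ** (1 - man_bits)
    return {
        "max": max_significand * 2.0 ** max_exp,
        "min_normal": 2.0 ** (1 - bias),
        "epsilon": 2.0 ** -man_bits,                       # spacing just above 1.0
        "decimal_digits": (man_bits + 1) * math.log10(2),  # includes the implicit bit
    }


FORMATS = {  # name: (exponent bits, mantissa bits, IEEE-style inf/NaN)
    "FP32": (8, 23, True),
    "FP16": (5, 10, True),
    "BF16": (8, 7, True),
    "FP8 E4M3": (4, 3, False),
    "FP8 E5M2": (5, 2, True),
}

for name, (e, m, ieee) in FORMATS.items():
    s = format_stats(e, m, ieee)
    print(f"{name:9s} max={s['max']:.4g}  min_normal={s['min_normal']:.3g}  "
          f"eps={s['epsilon']:.3g}  digits~{s['decimal_digits']:.1f}")
```

The computed maxima (65504 for FP16, 448 for E4M3, 57,344 for E5M2) and the digit estimates (about 7.2, 3.3, and 2.4 for FP32, FP16, and BF16) line up with the figures listed above.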

2.3 Appropriate Use Scenarios for Each Format

| Format | Typical Use |
|---|---|
| FP32 | Optimizer States, Master Weights (requires precise cumulative operations) |
| BF16 | Forward/Backward operations (wide range, no Loss Scaling required) |
| FP16 | Forward/Backward operations (Tensor Core utilization, Loss Scaling required) |
| FP8 | Maximum performance on Hopper/Blackwell GPUs (using Transformer Engine) |

3. Mixed Precision Training Principles

Mixed Precision Training is a technique introduced by researchers at NVIDIA and Baidu in the 2017 paper "Mixed Precision Training" (Micikevicius et al., published at ICLR 2018), which uses a mix of FP16 (or BF16) and FP32 during training. The core objective is to speed up computation with Tensor Cores and reduce memory usage while maintaining FP32-level training accuracy.

3.1 Three Core Techniques

According to NVIDIA official documentation, Mixed Precision Training consists of three core techniques:

(1) Master Weights (Maintaining FP32 Copies)

Even when performing operations in FP16, a separate FP32 master copy of the model parameters is maintained. During weight updates, gradients are applied to the FP32 master weights, which are then cast back to FP16 for the next forward pass. This is necessary because of FP16's limited precision: when an update is roughly 2048 (2^11) times smaller than the weight it is added to, the FP16 sum rounds back to the original weight and the update is lost (the swamping problem).

Forward Pass:  FP16 Weights → FP16 Activations → FP16 Loss
Backward Pass: FP16 Gradients computed
Weight Update: FP32 Master Weights -= Learning Rate × FP32(Gradients)
               FP16 Weights = cast(FP32 Master Weights)
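The swamping problem is easy to reproduce without a GPU. The sketch below emulates FP16 and FP32 rounding with Python's `struct` module (format codes `e` and `f`) and applies 1,000 identical SGD updates of 1e-4 to a weight of 1.0. The helper names are illustrative, and rounding a float64 result after each operation is only an approximation of real hardware arithmetic.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE FP16 value (round-half-to-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def run_updates(num_steps: int = 1000, lr: float = 1e-3, grad: float = 0.1):
    w_fp16 = to_fp16(1.0)   # weight updated directly in FP16
    master = to_fp32(1.0)   # FP32 master copy of the same weight
    update = lr * grad      # 1e-4 per step
    for _ in range(num_steps):
        # FP16 spacing just below 1.0 is 2**-11 (~4.9e-4), so a 1e-4 step rounds away
        w_fp16 = to_fp16(w_fp16 - to_fp16(update))
        # FP32 has enough mantissa bits to register every step
        master = to_fp32(master - to_fp32(update))
    # The FP16 working copy for the next forward pass is cast from the master
    return w_fp16, master, to_fp16(master)


w_fp16, master, w_cast = run_updates()
print(w_fp16, master, w_cast)
```

The FP16-only weight stays at exactly 1.0, while the FP32 master weight moves to about 0.9, and its FP16 cast carries that progress into the next forward pass.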

(2) Loss Scaling

FP16's normal range (~6.1 x 10^-5 to 6.55 x 10^4) may be too narrow for gradient values. Subnormal numbers extend the lower end to approximately 6 x 10^-8 (2^-24), but in practice many gradient values, especially activation gradients, are even smaller. This causes gradient underflow: such small gradients are rounded to zero.

Loss Scaling solves this problem:

  1. Multiply the loss value computed in the forward pass by a scale factor.
  2. By the chain rule, all gradients computed via backpropagation are multiplied by the same scale factor.
  3. Before the weight update, divide the gradients by the scale factor to restore their original magnitude.
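The effect of the three steps above on a single tiny gradient can be emulated on the CPU. A minimal sketch, assuming FP16 storage is emulated with `struct`'s `e` format; as a shortcut it multiplies the gradient by the scale directly, which is what the chain rule produces when the loss is scaled. The values chosen are illustrative.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to FP16; values beyond FP16's range become +/-inf, as on hardware."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


true_grad = 1e-8   # smaller than FP16's smallest subnormal (~6e-8)
scale = 2.0 ** 16  # loss scale factor

stored_unscaled = to_fp16(true_grad)        # underflows to 0.0
stored_scaled = to_fp16(true_grad * scale)  # ~6.55e-4, comfortably representable
recovered = stored_scaled / scale           # unscale in higher precision before the update

# Too large a scale overflows instead, which is what dynamic loss scaling watches for
overflowed = to_fp16(10.0 * scale)          # 655360 > 65504 -> inf

print(stored_unscaled, recovered, overflowed)
```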

(3) Dynamic Loss Scaling

Dynamic Loss Scaling dynamically adjusts the scale factor during training instead of using a fixed one. NVIDIA's implementation works as follows:

  1. Set the initial scale factor large (e.g., 2^24 = 16,777,216; PyTorch's GradScaler starts at 2^16 by default).
  2. At each iteration, check if gradients contain inf/NaN.
  3. If no overflow occurs: keep the current scale factor, and after N consecutive overflow-free iterations, double it (PyTorch's GradScaler uses N=2000 by default).
  4. If overflow occurs: skip the weight update for that iteration and halve the scale factor.

This mechanism automatically adapts to changes in gradient distribution during training. As gradient magnitudes decrease in later stages of training, the scale factor naturally increases to prevent underflow.
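The update rule can be written as a small state machine. The sketch below is a simplified illustration of the four steps above, not PyTorch's GradScaler source; the class and helper names are hypothetical, and the default parameters simply mirror the numbers quoted in the list.

```python
import math


class DynamicLossScaler:
    """Simplified sketch of dynamic loss scaling (not PyTorch's GradScaler code)."""

    def __init__(self, init_scale=2.0 ** 24, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive iterations without inf/NaN

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should run."""
        if found_inf:
            # Overflow: skip this step and halve the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # N clean steps in a row: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


def has_inf_or_nan(grads) -> bool:
    return any(math.isinf(g) or math.isnan(g) for g in grads)


scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=3)
step_ok = scaler.update(has_inf_or_nan([0.1, float("inf")]))  # skipped, scale -> 512
for _ in range(3):
    scaler.update(has_inf_or_nan([0.1, -0.2]))                # 3 clean steps -> 1024
print(step_ok, scaler.scale)  # False 1024.0
```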

3.2 Advantages of BF16 Training

BF16 has the same dynamic range as FP32, so Loss Scaling is generally not needed. This greatly simplifies implementation and largely eliminates the overflow/underflow issues seen with FP16. It was first adopted on Google's TPUs and has been supported at the hardware level since NVIDIA Ampere (A100) GPUs.

However, BF16 has only 7 mantissa bits compared to FP16's 10, so its precision is lower. Because small updates are even more easily lost to rounding in BF16, some models converge poorly when weights are updated directly in BF16; for this reason the master weights (and optimizer states) are usually kept in FP32.
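The range-versus-precision trade-off becomes concrete when the same values are rounded to both formats. A minimal sketch, emulating FP16 with `struct`'s `e` format and BF16 by rounding the FP32 bit pattern to its upper 16 bits; the helper names are illustrative, and NaN/inf inputs are not handled by the BF16 helper.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to FP16; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round to BF16 by keeping the top 16 bits of the FP32 encoding
    (round-to-nearest-even; NaN/inf handling omitted for brevity)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the dropped bits
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


big = 70000.0              # above FP16's max (65504)
small_step = 1.0 + 2**-9   # finer than BF16's spacing (2**-7) at 1.0

print("FP16:", to_fp16(big), to_fp16(small_step))
print("BF16:", to_bf16(big), to_bf16(small_step))
```

FP16 overflows to inf on 70000 but keeps 1 + 2^-9 exactly; BF16 represents 70000 approximately (as 70144) but rounds 1 + 2^-9 back to 1.0.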


4. PyTorch AMP Official Documentation Feature Analysis and Code Examples

PyTorch officially supports Automatic Mixed Precision (AMP) through the torch.amp module. AMP consists of two core components: torch.amp.autocast and torch.amp.GradScaler.

Note: In recent PyTorch 2.x releases, torch.cuda.amp.autocast and torch.cuda.amp.GradScaler are deprecated. Use the device-generic torch.amp.autocast("cuda", ...) and torch.amp.GradScaler("cuda", ...) instead.

4.1 torch.amp.autocast

autocast is used as a context manager or decorator and runs each operation in its scope at a precision chosen per operation type from PyTorch's built-in op lists, rather than converting everything to FP16 uniformly.

  • Run in lower precision (FP16 by default on CUDA, or the dtype passed to autocast, e.g., torch.bfloat16): Conv, Linear, MatMul, etc. (operations that can utilize Tensor Cores)
  • Run in FP32: Softmax, LayerNorm, many loss functions, etc. (operations where numerical stability is critical)

4.2 torch.amp.GradScaler

GradScaler automatically manages Loss Scaling. It abstracts the entire Dynamic Loss Scaling process (scale application, overflow checking, scale factor adjustment, weight update skipping) so users don't need to manage it manually.

4.3 Basic Usage Pattern

import torch
from torch.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()

        # autocast region: Execute Forward Pass in Mixed Precision
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # GradScaler: Scale the loss then execute Backward Pass
        scaler.scale(loss).backward()

        # GradScaler: Gradient Unscale → Overflow check → Optimizer Step
        scaler.step(optimizer)

        # Update Scale Factor
        scaler.update()

4.4 Using with Gradient Clipping

When applying gradient clipping with Mixed Precision, you must explicitly call scaler.unscale_() before performing clipping:

scaler.scale(loss).backward()

# Perform Gradient Unscale first
scaler.unscale_(optimizer)

# Apply clipping to the restored original-magnitude gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Optimizer Step (internally checks for inf/NaN before proceeding)
scaler.step(optimizer)
scaler.update()

4.5 Usage with Multiple GPUs (DistributedDataParallel)

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
scaler = GradScaler("cuda")

with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

When using DDP and AMP together, GradScaler operates independently on each GPU. Since DDP's AllReduce is performed on scaled gradients, unscaling must occur after AllReduce. PyTorch's implementation handles this automatically.

4.6 Using BF16

When using BF16, Loss Scaling is unnecessary, so you only need autocast without GradScaler:

with autocast("cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()

This approach results in much more concise code, and since step skipping from Dynamic Loss Scaling doesn't occur, training can be more stable.


5. Gradient Checkpointing (Activation Checkpointing) Analysis

Gradient Checkpointing is a technique that exploits the trade-off between memory and computation time. Instead of storing all intermediate activations during the forward pass, it stores only some and recomputes the rest during the backward pass.

5.1 Principle

In a typical training process, all layer outputs (activations) from the forward pass are stored in memory because they're needed for gradient computation during backpropagation. Gradient Checkpointing modifies this process:

  1. Forward Pass: Only the inputs at designated checkpoint boundaries are saved; intermediate activations are not stored.
  2. Backward Pass: The forward pass for each segment is re-executed from the saved inputs to recompute activations, and then gradients are calculated.

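With the layers split into about sqrt(n) checkpointed segments, this reduces activation memory from O(n) to O(sqrt(n)) (where n is the number of layers). In exchange, each segment's forward pass runs twice; since the backward pass costs roughly twice as much as the forward pass, this one extra forward pass adds about 33% to total compute.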
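The sqrt(n) figure comes from balancing two costs: the boundary input stored for every segment, and the activations of the one segment being recomputed at a time. The sketch below is a simplified model under stated assumptions (every layer produces the same amount of activation memory, counted in "layer units", and only one segment is recomputed at a time); the function names are hypothetical.

```python
import math


def peak_activation_units(num_layers: int, num_segments: int) -> int:
    """Simplified peak activation memory, in units of one layer's activations."""
    segment_len = math.ceil(num_layers / num_segments)
    # One saved boundary input per segment, plus the full activations of
    # the single segment being recomputed during backward
    return num_segments + segment_len


def best_num_segments(num_layers: int) -> int:
    return min(range(1, num_layers + 1),
               key=lambda k: peak_activation_units(num_layers, k))


for n in (24, 100):
    k = best_num_segments(n)
    print(f"{n} layers: no checkpointing = {n} units, "
          f"{k} segments = {peak_activation_units(n, k)} units")
```

Both optima come out at about 2·sqrt(n) units (10 for 24 layers, 20 for 100 layers), which is where the O(sqrt(n)) bound comes from.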

5.2 PyTorch Implementation: torch.utils.checkpoint

PyTorch provides two APIs through the torch.utils.checkpoint module:

checkpoint function

Applies checkpointing to individual functions or layers:

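import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # Concrete stand-ins for the attention and FFN sublayers
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Don't store intermediate activations of this block
        # Recompute during backward
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x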

checkpoint_sequential function

Applies checkpointing to groups of layers in a sequential structure:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            *[TransformerBlock() for _ in range(24)]
        )

    def forward(self, x):
        # Divide 24 layers into 4 segments (6 each); all segments except
        # the last are checkpointed, and the last one runs normally
        segments = 4
        return checkpoint_sequential(self.layers, segments, x,
                                     use_reentrant=False)

5.3 Reentrant vs Non-reentrant Checkpointing

PyTorch provides two checkpointing modes:

  • Reentrant (use_reentrant=True): The legacy implementation, which re-runs the function inside a nested autograd call during backward. It has several limitations; for example, it does not support torch.autograd.grad and may not be compatible with other Autograd features.
  • Non-reentrant (use_reentrant=False): The newer implementation. It supports more Autograd features and can stop recomputation as soon as the needed intermediate activations have been regenerated. PyTorch's documentation recommends it, and recent versions warn when use_reentrant is not passed explicitly.

5.4 Practical Application Strategies

Applying checkpointing to all layers incurs significant computational overhead. Effective strategies include:

  • Apply only to Attention layers: Without memory-efficient attention kernels, self-attention stores O(n^2) attention scores, so it is often the largest activation consumer for long sequences.
  • Apply every N-th layer: For example, setting checkpoints every 2nd layer balances memory savings and computational overhead.
  • Hugging Face Transformers: Can be enabled with a single line: model.gradient_checkpointing_enable().

6. Gradient Accumulation Technique

Gradient Accumulation is a technique for overcoming physical batch size limitations. When GPU memory is limited and large batches cannot be used, multiple small micro-batches are processed sequentially, gradients are accumulated, and then a single weight update is performed.

6.1 Principle

Effective Batch Size = Micro-batch Size x Accumulation Steps

For example, if the micro-batch size is 8 and accumulation steps are 4, the effective batch size is 32.

The key insight is that in PyTorch, calling loss.backward() accumulates (sums) gradients in the .grad attribute. Since gradients are not reset until optimizer.zero_grad() is called, the results of multiple backward passes can naturally be accumulated.

6.2 Implementation

accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast("cuda"):
        outputs = model(inputs)
        # Divide loss by accumulation steps to compute the average
        loss = loss_fn(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    # Perform weight update every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

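Important: Do not forget to divide the loss by accumulation_steps. Since gradients are summed across micro-batches, failing to divide makes the accumulated gradient accumulation_steps times larger than the full-batch average; with plain SGD this is equivalent to multiplying the learning rate by accumulation_steps.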
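Why the division matters can be checked numerically. A minimal sketch, assuming a one-parameter linear model with a mean-squared-error loss and equally sized micro-batches (the data and the `mse_grad` helper are made up for illustration): summing the micro-batch gradients of loss / accumulation_steps reproduces the full-batch gradient, while summing the undivided gradients gives accumulation_steps times that value.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 0.5 for x in xs]
w = 0.3
accumulation_steps = 4
micro = len(xs) // accumulation_steps  # micro-batch size 2

full_batch_grad = mse_grad(w, xs, ys)

grad_divided = 0.0    # like backward() on loss / accumulation_steps
grad_undivided = 0.0  # like backward() on the raw micro-batch loss
for k in range(accumulation_steps):
    mx = xs[k * micro:(k + 1) * micro]
    my = ys[k * micro:(k + 1) * micro]
    g = mse_grad(w, mx, my)
    grad_divided += g / accumulation_steps
    grad_undivided += g

print(full_batch_grad, grad_divided, grad_undivided)
```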

6.3 Considerations

  • Batch Normalization: BatchNorm computes statistics per micro-batch, which may differ from statistics over the full effective batch. In such cases, using GroupNorm or LayerNorm is preferable.
  • Learning Rate Scheduling: When using step-based schedulers, the scheduler's call frequency should be adjusted to account for accumulation steps.
  • Combination with DDP: When using gradient accumulation with DDP, you can skip the gradient AllReduce on every micro-batch except the last one in each accumulation window, which reduces communication overhead. Use the model.no_sync() context manager:

from contextlib import nullcontext

for i, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()
    # Skip AllReduce unless it's the last accumulation step
    context = model.no_sync() if (i + 1) % accumulation_steps != 0 else nullcontext()
    with context:
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets) / accumulation_steps
        scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
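For step-based learning rate schedulers, the total number of scheduler steps should count optimizer updates, not micro-batches. A small hypothetical helper, assuming the update rule of the accumulation loop in Section 6.2 (an update only every accumulation_steps micro-batches, so a trailing partial window at the end of an epoch produces no update):

```python
def num_optimizer_steps(micro_batches_per_epoch: int,
                        accumulation_steps: int,
                        num_epochs: int) -> int:
    """Optimizer (and scheduler) steps for the accumulation loop in Section 6.2.

    The loop updates only when (i + 1) % accumulation_steps == 0, so a trailing
    partial window at the end of an epoch does not trigger an update.
    """
    return (micro_batches_per_epoch // accumulation_steps) * num_epochs


# e.g. size a scheduler's total steps for 1,000 micro-batches per epoch
total_steps = num_optimizer_steps(1000, accumulation_steps=4, num_epochs=3)
print(total_steps)  # 750
```

In this setup, scheduler.step() would be called inside the update branch, right after scaler.step(optimizer), so the schedule advances once per optimizer update.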

7. Memory Profiling with torch.cuda.memory_summary()

The first step in memory optimization is to accurately understand the current memory usage. PyTorch provides various CUDA memory profiling tools.

7.1 torch.cuda.memory_summary()

This is the most basic yet useful tool. It provides detailed information about GPU memory allocation status:

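import torch
from torch.amp import autocast

# MyLargeModel, loss_fn, and targets are placeholders for your own model, loss, and labels
model = MyLargeModel().cuda()
inputs = torch.randn(32, 3, 224, 224).cuda()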

# Check memory state after forward pass
with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

print(torch.cuda.memory_summary(device=0, abbreviated=False))

Example output (partial):

|===========================================================================|
|                  PyTorch CUDA memory summary                              |
|===========================================================================|
|            CUDA OOMs: 0                                                   |
|        cudaMallocs:   234                                                 |
|---------------------------------------------------------------------------+
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------+
| Allocated memory      |  4096 MiB  |  6234 MiB  | 12345 MiB  |  8249 MiB  |
| Active memory         |  3800 MiB  |  5900 MiB  | 11000 MiB  |  7200 MiB  |
| Requested memory      |  3750 MiB  |  5850 MiB  | 10800 MiB  |  7050 MiB  |
| GPU reserved memory   |  6400 MiB  |  6400 MiB  |  6400 MiB  |     0 MiB  |
|---------------------------------------------------------------------------+

Key metric interpretation:

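  • Allocated memory: Memory currently occupied by live tensors
  • Active memory: Memory in blocks currently in use by the caching allocator
  • Requested memory: Bytes actually requested by allocation calls, before the allocator rounds them up to its block sizes (comparing it with Allocated memory shows the rounding overhead)
  • GPU reserved memory: Total memory reserved by PyTorch's caching allocator from CUDA, including cached blocks not currently holding tensors
  • Peak Usage: The maximum value of each metric since program start or the last torch.cuda.reset_peak_memory_stats() call (critical for diagnosing OOM errors)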

7.2 Individual Memory Query Functions

Useful for tracking memory usage at specific points within code:

import torch

# Currently allocated memory (bytes)
allocated = torch.cuda.memory_allocated(device=0)

# Currently reserved memory (bytes)
reserved = torch.cuda.memory_reserved(device=0)

# Peak allocated memory since program start or the last reset_peak_memory_stats() call
max_allocated = torch.cuda.max_memory_allocated(device=0)

print(f"Allocated: {allocated / 1024**3:.2f} GB")
print(f"Reserved:  {reserved / 1024**3:.2f} GB")
print(f"Peak:      {max_allocated / 1024**3:.2f} GB")

# Reset peak statistics (for segment-by-segment measurement)
torch.cuda.reset_peak_memory_stats(device=0)

7.3 Detailed Analysis with Memory Snapshots

Memory Snapshots, available since PyTorch 2.x, record memory allocation/deallocation events in chronological order for visualization:

# Start memory recording
torch.cuda.memory._record_memory_history(max_entries=100000)

# Execute training code
train_one_epoch(model, dataloader, optimizer)

# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording
torch.cuda.memory._record_memory_history(enabled=None)

Saved snapshots can be uploaded to PyTorch's official Memory Visualizer (https://pytorch.org/memory_viz) for visual analysis.

7.4 OOM Debugging Strategy

When Out of Memory occurs, diagnose the cause in the following order:

  1. Get an overview with torch.cuda.memory_summary()
  2. Check peak memory with max_memory_allocated()
  3. Use Memory Snapshots to identify which operations cause memory spikes
  4. Analyze which factor dominates: batch size, sequence length, or hidden dimension
  5. Apply appropriate optimization techniques (Mixed Precision, Gradient Checkpointing, Gradient Accumulation)

8. NVIDIA Transformer Engine and FP8 Training

NVIDIA Transformer Engine (TE) is a library that accelerates training and inference of Transformer models using FP8 precision on Hopper (H100) and later GPUs. On H100, peak FP8 Tensor Core throughput is about twice that of FP16/BF16; end-to-end training speedups are usually smaller because not every operation runs in FP8.

8.1 Two FP8 Formats

The H100 GPU supports two FP8 formats at the hardware level:

  • E4M3 (4-bit Exponent, 3-bit Mantissa): Range up to ±448. Relatively high precision. Used for forward pass activations and weights.
  • E5M2 (5-bit Exponent, 2-bit Mantissa): Range up to ±57,344. Wide dynamic range. Used for backward pass gradients.

This separation strategy selects the format best suited to the numerical characteristics of each stage, minimizing accuracy loss.

8.2 Per-tensor Scaling

To compensate for FP8's narrow dynamic range, Transformer Engine applies per-tensor scaling. It maintains individual scale factors for each tensor, adjusting the tensor's value distribution to fit within FP8's representable range.

Transformer Engine offers several scaling strategies:

  • Delayed Scaling: Determines the scale factor based on statistics from the previous iteration (default)
  • Current Scaling (Just-in-time): Computes the scale factor immediately based on the current tensor's values
  • Block Scaling: Divides the tensor into blocks and applies individual scale factors to each block
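Delayed scaling can be summarized in a few lines. The following is a simplified sketch under stated assumptions, not Transformer Engine's implementation: it keeps a rolling window of per-iteration amax (maximum absolute value) readings, takes their maximum (mirroring amax_compute_algo="max"), and derives the next scale as fp8_max / amax / 2^margin. Values are only scaled and clipped to the E4M3 range here, not rounded to the actual FP8 grid, and the class name is hypothetical.

```python
from collections import deque

E4M3_MAX = 448.0


class DelayedScalingSketch:
    """Per-tensor delayed scaling: the scale used now comes from past amax values."""

    def __init__(self, fp8_max=E4M3_MAX, amax_history_len=16, margin=0):
        self.fp8_max = fp8_max
        self.margin = margin
        self.amax_history = deque(maxlen=amax_history_len)
        self.scale = 1.0  # scale for the current iteration

    def scale_and_clip(self, values):
        # Apply the scale derived from previous iterations, then saturate
        # to the FP8 range (rounding to the FP8 grid is omitted)
        scaled = [v * self.scale for v in values]
        out = [max(-self.fp8_max, min(self.fp8_max, v)) for v in scaled]
        # Record this tensor's amax and compute the scale for the next iteration
        self.amax_history.append(max(abs(v) for v in values))
        amax = max(self.amax_history)
        if amax > 0:
            self.scale = self.fp8_max / amax / (2 ** self.margin)
        return out


scaler = DelayedScalingSketch(amax_history_len=2)
first = scaler.scale_and_clip([0.5, -2.0])   # uses the initial scale 1.0
second = scaler.scale_and_clip([1.0, 0.25])  # uses 448 / 2.0 = 224.0
print(first, second, scaler.scale)
```

Because the scale always lags one iteration behind, a sudden jump in a tensor's magnitude is clipped once before the history catches up, which is the trade-off against the Current Scaling strategy.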

8.3 Transformer Engine Usage Example

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 training recipe configuration
fp8_recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,  # Forward: E4M3, Backward: E5M2
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Use Transformer Engine layers (parameters are created on the GPU by default)
model = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    layer_number=1,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# FP8 training loop
# inputs: [seq_len, batch, hidden_size] (TE's default layout); loss_fn and targets as in earlier examples
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()

8.4 Evolution in the Blackwell Generation

NVIDIA Blackwell (B200) GPUs add support for the following formats in addition to FP8:

  • MXFP8: Microscaling FP8. Supports finer block-level scaling
  • NVFP4: 4-bit Floating Point. Primarily used for inference, maximizing memory efficiency

Transformer Engine 2.x and later provides integrated support for these new formats through the Recipe module.


9. Practical Case: Training Large Models on Limited GPUs

Combining all the techniques covered so far, we outline practical strategies for training large models in limited GPU environments.

9.1 Scenario: Fine-tuning a 7B Model on a 24GB GPU

Assume fine-tuning a 7B parameter model on a single A10G (24GB VRAM).

Memory Requirement Analysis (FP32 basis):

ComponentMemory
Model Parameters (FP32)28 GB
Gradients (FP32)28 GB
Adam Optimizer States56 GB
Activations (batch=1)~2 GB
Total~114 GB

This cannot fit on a 24GB GPU in FP32. Let's apply optimizations step by step.

9.2 Step-by-Step Optimization

Step 1: Mixed Precision Training (BF16)

import torch
from transformers import AutoModelForCausalLM

# Load the weights directly in BF16
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.bfloat16)

| Component | Memory |
|---|---|
| Model Parameters (BF16) | 14 GB |
| Gradients (BF16) | 14 GB |
| Adam Master Weights (FP32) | 28 GB |
| Adam Momentum (FP32) | 28 GB |
| Adam Variance (FP32) | 28 GB |
| Total | ~112 GB |

Still insufficient due to optimizer states.

Step 2: Apply LoRA (Reduce Trainable Parameters)

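LoRA (Low-Rank Adaptation) freezes the base model and trains only small low-rank adapter matrices, typically about 0.1-1% of the total parameter count. Applying rank-16 LoRA to the four attention projections of a 7B Llama-style model adds roughly 17M trainable parameters.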

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)

| Component | Memory |
|---|---|
| Full Model (BF16, frozen) | 14 GB |
| LoRA Parameters (BF16) | ~0.04 GB |
| LoRA Gradients (BF16) | ~0.04 GB |
| Adam States (FP32, LoRA only) | ~0.24 GB |
| Total | ~14.3 GB |

This fits comfortably in a 24GB GPU. Now there's room to increase batch size.
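The trainable-parameter count depends on which modules LoRA targets and on the model's shapes. A back-of-the-envelope sketch, assuming Llama-2-7B's dimensions (hidden size 4096, 32 decoder layers, square 4096 x 4096 q/k/v/o projections) and the rank-16 config above; LoRA adds an r x d_in matrix A and a d_out x r matrix B next to each targeted weight. The helper name is hypothetical, and 12 bytes per parameter assumes FP32 master weights plus two FP32 Adam moments.

```python
def lora_params_per_module(d_in: int, d_out: int, rank: int) -> int:
    # LoRA adds A (rank x d_in) and B (d_out x rank) next to the frozen weight
    return rank * d_in + d_out * rank


hidden = 4096       # assumed Llama-2-7B hidden size
num_layers = 32     # assumed number of decoder layers
targets = ["q_proj", "v_proj", "k_proj", "o_proj"]  # all hidden x hidden here
rank = 16

trainable = num_layers * len(targets) * lora_params_per_module(hidden, hidden, rank)
bf16_gb = trainable * 2 / 1e9    # adapter weights in BF16
adam_gb = trainable * 12 / 1e9   # FP32 master weights + momentum + variance

print(f"{trainable:,} trainable params, {bf16_gb:.3f} GB weights, "
      f"{adam_gb:.2f} GB Adam states")
```

That is about 16.8M parameters and a few tenths of a GB in total, so the frozen 14 GB base model dominates the budget.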

Step 3: Add Gradient Checkpointing

Gradient checkpointing cuts activation memory so that the remaining ~10GB can hold a larger batch:

model.gradient_checkpointing_enable()

Step 4: Ensure Effective Batch Size with Gradient Accumulation

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # Effective Batch Size = 32
    bf16=True,
    gradient_checkpointing=True,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
)

9.3 Optimization Technique Combination Guide

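| GPU VRAM | Model Scale | Recommended Combination |
|---|---|---|
| 8 GB | ~1B | BF16 + LoRA(r=8) + GC + GA |
| 16 GB | ~3B | BF16 + LoRA(r=16) + GC + GA |
| 24 GB | ~7B | BF16 + LoRA(r=16) + GC + GA |
| 40 GB | ~13B | BF16 + LoRA(r=32) + GC + GA |
| 80 GB | ~7B Full FT | BF16 + GC + GA |
| 8x80 GB | ~70B | BF16 + FSDP/DeepSpeed ZeRO-3 + GC |

GC = Gradient Checkpointing, GA = Gradient Accumulation, Full FT = Full Fine-tuning

The two rows without LoRA are tight: by the 16Φ estimate from Section 1.5, parameters, gradients, and Adam states alone need about 112 GB for a 7B model and about 1.1 TB for a 70B model, so in practice these setups also need to reduce optimizer-state memory, for example with CPU offloading or lower-precision optimizer states.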

9.4 Integrated Memory Profiling Script

You can use the following script to pre-check memory usage before training:

import torch
from torch.amp import autocast

def profile_memory(model, dummy_input, dtype=torch.bfloat16):
    """Profile GPU memory for one forward/backward pass.

    Optimizer states are not measured because no optimizer step is taken.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    model = model.cuda()
    dummy_input = dummy_input.cuda()

    # Forward Pass
    with autocast("cuda", dtype=dtype):
        outputs = model(dummy_input)
        if isinstance(outputs, dict):
            loss = outputs["loss"]
        else:
            loss = outputs.sum()

    print("=== After Forward Pass ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

    # Backward Pass
    loss.backward()

    print("\n=== After Backward Pass ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Peak:      {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

    print("\n=== Full Memory Summary ===")
    print(torch.cuda.memory_summary(abbreviated=True))

# Usage example
# profile_memory(model, dummy_input)

10. Conclusion

GPU memory optimization is not about a single technique but rather a combination of multiple techniques. Here is a summary of the key points:

  1. Understanding memory composition comes first: You must understand how much memory each of Parameters, Gradients, Optimizer States, and Activations consumes to decide where to optimize.

  2. Mixed Precision is fundamental: BF16 roughly halves activation memory, speeds up computation on Tensor Cores, and trains stably without Loss Scaling. On Ampere and later GPUs, use BF16 as the default.

  3. Gradient Checkpointing targets activation memory: It significantly reduces activation memory at the cost of approximately 33% computational overhead. It is nearly essential for large model training.

  4. Gradient Accumulation solves the batch size problem: It increases the effective batch size without additional memory burden.

  5. FP8 is the next-generation standard: FP8 training is possible on Hopper/Blackwell GPUs through Transformer Engine, providing additional performance improvements over FP16/BF16.

  6. Make profiling a habit: Regularly monitoring memory usage through torch.cuda.memory_summary() and Memory Snapshots can prevent OOM issues proactively.


References