- 1. GPU Memory Composition Analysis
- 2. FP32 vs FP16 vs BF16 vs FP8 Numerical Representation Comparison
- 3. Mixed Precision Training Principles
- 4. PyTorch AMP Official Documentation Feature Analysis and Code Examples
- 5. Gradient Checkpointing (Activation Checkpointing) Analysis
- 6. Gradient Accumulation Technique
- 7. Memory Profiling with torch.cuda.memory_summary()
- 8. NVIDIA Transformer Engine and FP8 Training
- 9. Practical Case: Training Large Models on Limited GPUs
- 10. Conclusion
- References
1. GPU Memory Composition Analysis
When training deep learning models, GPU memory stores far more than just model parameters. The components occupying GPU memory during the entire training process can be broadly divided into four categories. Understanding these components precisely is essential for devising effective memory optimization strategies.
1.1 Model Parameters
Model parameters refer to the weights and biases of the neural network. They reside in GPU memory throughout the entire training process, and memory usage is determined by the number of parameters and the data type.
- FP32 basis: 4 bytes per parameter
- FP16/BF16 basis: 2 bytes per parameter
For example, storing a 7B (7 billion) parameter model in FP32 requires approximately 28GB, while FP16 requires approximately 14GB.
1.2 Gradients
Gradients are the partial derivatives of each parameter with respect to the loss function. They are computed during backpropagation, and since each parameter has a corresponding gradient value, they occupy the same amount of memory as the parameters.
- FP32 training: number of parameters x 4 bytes
- Mixed Precision training: number of parameters x 2 bytes (stored in FP16)
1.3 Optimizer States
Optimizer states often account for the largest portion of memory. Plain SGD keeps no extra state (SGD with momentum keeps one buffer per parameter), but the Adam/AdamW optimizer maintains two additional states for every parameter: the first moment (mean) and the second moment (variance).
Optimizer state memory for Adam in mixed-precision training, where the optimizer additionally holds an FP32 master copy of the weights:
Let the number of parameters be Φ:
- Master Weights (FP32): 4Φ bytes
- Momentum (FP32): 4Φ bytes
- Variance (FP32): 4Φ bytes
- Total: 12Φ bytes
Using Adam with a 7B model requires approximately 84GB for optimizer states alone. This is why optimizer state optimization is critical in large model training.
1.4 Activations
Activations are the outputs of each layer during the forward pass and must be stored for computing gradients during backpropagation. Activation memory is proportional to the following factors:
- Batch Size: larger sizes require more memory
- Sequence Length: especially important in Transformer models (O(n^2) for attention)
- Hidden Dimension: increases with larger models
- Number of Layers: increases with deeper models
In large models, activation memory frequently exceeds parameter memory.
1.5 Overall Memory Summary
When using Mixed Precision Training with Adam Optimizer, the total memory requirement for Φ parameters is as follows:
| Component | Data Type | Memory |
|---|---|---|
| Model Parameters (FP16 copy) | FP16 | 2Φ bytes |
| Gradients (FP16) | FP16 | 2Φ bytes |
| Master Weights | FP32 | 4Φ bytes |
| Optimizer Momentum | FP32 | 4Φ bytes |
| Optimizer Variance | FP32 | 4Φ bytes |
| Total | | 16Φ bytes |
In addition, activation memory, temporary buffers, and memory fragmentation cause approximately 20-30% additional overhead in practice.
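The breakdown above is easy to turn into a back-of-the-envelope calculator. The sketch below is our own (the function name is hypothetical), uses decimal gigabytes to match the 7B → 28 GB figures in this section, and excludes activations and overhead:

```python
def training_memory_gb(num_params: float) -> dict:
    """Per-component memory (decimal GB) for mixed-precision Adam training,
    following the 16-bytes-per-parameter breakdown above."""
    GB = 1e9  # decimal gigabytes, matching the 7B -> 28 GB figures here
    return {
        "fp16_params":   2 * num_params / GB,
        "fp16_grads":    2 * num_params / GB,
        "fp32_master":   4 * num_params / GB,
        "fp32_momentum": 4 * num_params / GB,
        "fp32_variance": 4 * num_params / GB,
        "total":        16 * num_params / GB,
    }

print(training_memory_gb(7e9)["total"])  # 112.0 -> GB before activations/overhead
```

This reproduces the ~112 GB figure used in the 7B fine-tuning scenario of Section 9.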
2. FP32 vs FP16 vs BF16 vs FP8 Numerical Representation Comparison
The floating-point formats supported by NVIDIA GPUs each have unique characteristics. To understand the trade-offs between training stability and performance, we compare the bit composition and representation ranges of each format.
2.1 Bit Composition
All IEEE 754-based floating-point numbers consist of three parts: Sign, Exponent, and Mantissa.
| Format | Sign | Exponent | Mantissa | Total Bits | Memory |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | 32 | 4 bytes |
| FP16 | 1 | 5 | 10 | 16 | 2 bytes |
| BF16 | 1 | 8 | 7 | 16 | 2 bytes |
| FP8 (E4M3) | 1 | 4 | 3 | 8 | 1 byte |
| FP8 (E5M2) | 1 | 5 | 2 | 8 | 1 byte |
2.2 Dynamic Range and Precision
Comparing the representable ranges and precision of each format:
FP32 (Single Precision)
- Dynamic Range: ~1.2 x 10^-38 to 3.4 x 10^38
- Precision: approximately 7 decimal places
- The default data type for deep learning, providing the highest precision.
FP16 (Half Precision)
- Dynamic Range: ~6.1 x 10^-5 to 6.55 x 10^4
- Precision: approximately 3 decimal places
- The narrow dynamic range makes gradient underflow/overflow likely. Must always be used with Loss Scaling.
BF16 (Brain Floating Point 16)
- Dynamic Range: ~1.2 x 10^-38 to 3.4 x 10^38 (same as FP32)
- Precision: approximately 2 decimal places
- Has the same 8-bit exponent as FP32, so the dynamic range is identical. Although precision is lower, it can train stably without Loss Scaling, making it the emerging standard for large model training. Supported on NVIDIA Ampere (A100) and later.
FP8 E4M3
- Dynamic Range: ~plus/minus 448
- Precision: lowest
- Primarily used in the forward pass. Per-tensor scaling is essential to compensate for the narrow range.
FP8 E5M2
- Dynamic Range: ~plus/minus 57,344
- Precision: lower than E4M3 but wider range
- Primarily used in the backward pass (gradient computation). This accommodates the wide dynamic range of gradients.
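These range figures follow directly from the bit layouts: for IEEE-style formats the largest finite value is (2 − 2^−m) × 2^(E_max), with E_max = (2^e − 2) − bias and bias = 2^(e−1) − 1. A pure-Python check (note that FP8 E4M3 deliberately breaks this rule):

```python
def fp_max(exp_bits: int, mant_bits: int) -> float:
    """Largest finite value of an IEEE-style float with the given bit widths."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exponent = (2 ** exp_bits - 2) - bias  # top exponent pattern: inf/NaN
    return (2 - 2.0 ** -mant_bits) * 2.0 ** max_exponent

print(fp_max(5, 10))  # FP16     -> 65504.0
print(fp_max(8, 7))   # BF16     -> ~3.39e38
print(fp_max(5, 2))   # FP8 E5M2 -> 57344.0
# FP8 E4M3 (per the OCP FP8 spec) reuses the top exponent pattern for finite
# values, reserving only S.1111.111 for NaN, so its max is 448 rather than
# the fp_max(4, 3) == 240.0 that a strictly IEEE-style layout would give.
```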
2.3 Appropriate Use Scenarios for Each Format
FP32 → Optimizer States, Master Weights (requires precise cumulative operations)
BF16 → Forward/Backward operations (wide range, no Loss Scaling required)
FP16 → Forward/Backward operations (Tensor Core utilization, Loss Scaling required)
FP8 → Maximum performance on Hopper/Blackwell GPUs (using Transformer Engine)
3. Mixed Precision Training Principles
Mixed Precision Training is a technique proposed by NVIDIA in the 2017 paper "Mixed Precision Training" (Micikevicius et al.), which uses a mix of FP16 (or BF16) and FP32 during training. The core objective is to achieve computational speedup through Tensor Cores and memory savings while maintaining FP32-level training accuracy.
3.1 Three Core Techniques
According to NVIDIA official documentation, Mixed Precision Training consists of three core techniques:
(1) Master Weights (Maintaining FP32 Copies)
Even when performing operations in FP16, a separate FP32 master copy of the model parameters is maintained. During weight updates, gradients are applied to the FP32 master weights, which are then converted back to FP16 for the next forward pass. This is necessary because FP16's low precision causes small gradients to not be reflected in weight updates (the swamping problem).
```
Forward pass:   FP16 weights → FP16 activations → FP16 loss
Backward pass:  FP16 gradients computed
Weight update:  FP32 master weights -= learning rate × gradients (cast to FP32)
                FP16 weights = cast(FP32 master weights)
```
(2) Loss Scaling
FP16's dynamic range (~6.1 x 10^-5 to 6.55 x 10^4) may be insufficient to accommodate gradient values. In practice, gradient values smaller than 10^-10 are common, but FP16's minimum representable positive value is approximately 6 x 10^-8. This causes gradient underflow, where small gradients are treated as zero.
Loss Scaling solves this problem:
- Multiply the loss value computed in the forward pass by a scale factor.
- By the chain rule, all gradients computed via backpropagation are multiplied by the same scale factor.
- Before the weight update, divide the gradients by the scale factor to restore their original magnitude.
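The underflow problem and its fix can be demonstrated with nothing but the standard library, by round-tripping a value through IEEE half precision with `struct` (a sketch; the helper name is our own, and 2^24 matches the initial scale factor discussed below):

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

grad = 1e-10                    # a typical tiny gradient
print(to_fp16(grad))            # 0.0 -> underflow: the update is silently lost
print(to_fp16(grad * 2.0**24))  # ~0.00168 -> after scaling it survives in FP16
```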
(3) Dynamic Loss Scaling
Dynamic Loss Scaling dynamically adjusts the scale factor during training instead of using a fixed one. NVIDIA's implementation works as follows:
- Set the initial scale factor large (e.g., 2^24 = 16,777,216).
- At each iteration, check if gradients contain inf/NaN.
- If no overflow occurs: maintain the current scale factor, and after N consecutive successes, double the scale factor (default N=2000).
- If overflow occurs: skip the weight update for that iteration and halve the scale factor.
This mechanism automatically adapts to changes in gradient distribution during training. As gradient magnitudes decrease in later stages of training, the scale factor naturally increases to prevent underflow.
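The update rule above is small enough to sketch in plain Python. This is an illustration, not PyTorch's implementation; the growth/backoff factors of 2 and 0.5 and the interval of 2000 mirror GradScaler's documented defaults:

```python
class DynamicLossScaler:
    """Illustrative sketch of the dynamic loss scaling state machine."""

    def __init__(self, init_scale=2.0**24, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf_or_nan: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should run."""
        if found_inf_or_nan:
            self.scale /= 2            # overflow: halve the scale, skip the step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= 2            # N clean iterations: try a larger scale
            self._good_steps = 0
        return True
```

In a training loop, `update(found_inf)` would be called once per iteration after inspecting the unscaled gradients.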
3.2 Advantages of BF16 Training
BF16 has the same dynamic range as FP32, so Loss Scaling is not needed. This greatly simplifies implementation and fundamentally eliminates overflow/underflow-related issues. It was first adopted on Google's TPUs and has been supported at the hardware level since NVIDIA Ampere (A100) GPUs.
However, BF16 has only 7 mantissa bits compared to FP16's 10 bits, resulting in lower precision. In some models, BF16 training may cause convergence issues due to insufficient precision, so master weights must always be maintained in FP32.
4. PyTorch AMP Official Documentation Feature Analysis and Code Examples
PyTorch officially supports Automatic Mixed Precision (AMP) through the torch.amp module. AMP consists of two core components: torch.amp.autocast and torch.amp.GradScaler.
Note: Starting from PyTorch 2.x, `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` are deprecated. Use `torch.amp.autocast("cuda", ...)` and `torch.amp.GradScaler("cuda", ...)` instead.
4.1 torch.amp.autocast
autocast is used as a context manager or decorator and automatically executes operations within its scope at the appropriate precision. Rather than converting all operations to FP16 uniformly, it selects the optimal data type for each operation type.
- Executed in FP16: Conv, Linear, MatMul, etc. (operations that can utilize Tensor Cores)
- Maintained in FP32: Softmax, LayerNorm, Loss computation, etc. (operations where numerical stability is critical)
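This per-op dtype selection can be observed directly. A CPU-only sketch using BF16 (CUDA autocast behaves analogously, with FP16 as the default low-precision type):

```python
import torch
from torch.amp import autocast

a, b = torch.randn(4, 4), torch.randn(4, 4)

with autocast("cpu", dtype=torch.bfloat16):
    mm = a @ b            # matmul is on autocast's low-precision op list

print(a.dtype, mm.dtype)  # torch.float32 torch.bfloat16
```

The inputs stay FP32; autocast casts them on the fly only for the operations it runs in low precision.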
4.2 torch.amp.GradScaler
GradScaler automatically manages Loss Scaling. It abstracts the entire Dynamic Loss Scaling process (scale application, overflow checking, scale factor adjustment, weight update skipping) so users don't need to manage it manually.
4.3 Basic Usage Pattern
```python
import torch
from torch.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")

for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, targets = batch
        inputs = inputs.cuda()
        targets = targets.cuda()

        optimizer.zero_grad()

        # autocast region: run the forward pass in mixed precision
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # GradScaler: scale the loss, then run the backward pass
        scaler.scale(loss).backward()

        # GradScaler: unscale gradients -> overflow check -> optimizer step
        scaler.step(optimizer)

        # Update the scale factor
        scaler.update()
```
4.4 Using with Gradient Clipping
When applying gradient clipping with Mixed Precision, you must explicitly call scaler.unscale_() before performing clipping:
```python
scaler.scale(loss).backward()

# Unscale the gradients first
scaler.unscale_(optimizer)

# Apply clipping to the restored original-magnitude gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Optimizer step (internally checks for inf/NaN before proceeding)
scaler.step(optimizer)
scaler.update()
```
4.5 Usage with Multiple GPUs (DistributedDataParallel)
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model, device_ids=[local_rank])
scaler = GradScaler("cuda")

with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
When using DDP and AMP together, GradScaler operates independently on each GPU. Since DDP's AllReduce is performed on scaled gradients, unscaling must occur after AllReduce. PyTorch's implementation handles this automatically.
4.6 Using BF16
When using BF16, Loss Scaling is unnecessary, so you only need autocast without GradScaler:
```python
with autocast("cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```
This approach results in much more concise code, and since step skipping from Dynamic Loss Scaling doesn't occur, training can be more stable.
5. Gradient Checkpointing (Activation Checkpointing) Analysis
Gradient Checkpointing is a technique that exploits the trade-off between memory and computation time. Instead of storing all intermediate activations during the forward pass, it stores only some and recomputes the rest during the backward pass.
5.1 Principle
In a typical training process, all layer outputs (activations) from the forward pass are stored in memory because they're needed for gradient computation during backpropagation. Gradient Checkpointing modifies this process:
- Forward Pass: Only the inputs at designated checkpoint boundaries are saved; intermediate activations are not stored.
- Backward Pass: The forward pass for each segment is re-executed from the saved inputs to recompute activations, and then gradients are calculated.
This reduces activation memory from O(n) to O(sqrt(n)) (where n is the number of layers). In exchange, forward computation increases by approximately 33% (one additional forward pass).
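The O(sqrt(n)) bound comes from balancing two costs: s saved segment inputs versus the n/s recomputed activations of the one segment being backpropagated. A toy count (layer counts only, assuming uniform activation sizes; the function name is our own):

```python
def peak_checkpointed_activations(n_layers: int, segments: int) -> float:
    """Rough activation count: one saved input per segment, plus the
    recomputed activations of the single segment being backpropagated."""
    return segments + n_layers / segments

n = 64
best = min(range(1, n + 1), key=lambda s: peak_checkpointed_activations(n, s))
print(best, peak_checkpointed_activations(n, best))  # 8 16.0, i.e. s = sqrt(n)
```

With 64 layers, the peak drops from 64 stored activations to 16, at the chosen s = sqrt(64) = 8.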
5.2 PyTorch Implementation: torch.utils.checkpoint
PyTorch provides two APIs through the torch.utils.checkpoint module:
checkpoint function
Applies checkpointing to individual functions or layers:
```python
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.attention = MultiHeadAttention(...)
        self.ffn = FeedForward(...)
        self.norm1 = nn.LayerNorm(...)
        self.norm2 = nn.LayerNorm(...)

    def forward(self, x):
        # Don't store this block's intermediate activations;
        # recompute them during backward
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
```
checkpoint_sequential function
Applies checkpointing to groups of layers in a sequential structure:
```python
from torch.utils.checkpoint import checkpoint_sequential

class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            *[TransformerBlock(...) for _ in range(24)]
        )

    def forward(self, x):
        # Divide the 24 layers into 4 groups of 6 for checkpointing
        segments = 4
        return checkpoint_sequential(self.layers, segments, x,
                                     use_reentrant=False)
```
5.3 Reentrant vs Non-reentrant Checkpointing
PyTorch provides two checkpointing modes:
- Reentrant (use_reentrant=True): The legacy approach. Re-executes the entire function during backward. Has some limitations (may not be compatible with certain Autograd features).
- Non-reentrant (use_reentrant=False): The improved approach. Recomputes only the necessary intermediate activations. Recommended by PyTorch official documentation and will become the default in the future.
5.4 Practical Application Strategies
Applying checkpointing to all layers incurs significant computational overhead. Effective strategies include:
- Apply only to Attention layers: Self-attention has O(n^2) activation memory, making it the largest consumer.
- Apply every N-th layer: For example, setting checkpoints every 2nd layer balances memory savings and computational overhead.
- Hugging Face Transformers: Can be enabled with a single line: `model.gradient_checkpointing_enable()`.
6. Gradient Accumulation Technique
Gradient Accumulation is a technique for overcoming physical batch size limitations. When GPU memory is limited and large batches cannot be used, multiple small micro-batches are processed sequentially, gradients are accumulated, and then a single weight update is performed.
6.1 Principle
Effective Batch Size = Micro-batch Size x Accumulation Steps
For example, if the micro-batch size is 8 and accumulation steps are 4, the effective batch size is 32.
The key insight is that in PyTorch, calling loss.backward() accumulates (sums) gradients in the .grad attribute. Since gradients are not reset until optimizer.zero_grad() is called, the results of multiple backward passes can naturally be accumulated.
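This accumulate-on-backward behavior is easy to verify with a CPU-only sketch:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
(x * 3).backward()    # d/dx (3x) = 3
(x * 4).backward()    # d/dx (4x) = 4, added into the existing x.grad
accumulated = x.grad.item()
print(accumulated)    # 7.0 -> backward() sums into .grad, it does not overwrite
```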
6.2 Implementation
```python
accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast("cuda"):
        outputs = model(inputs)
        # Divide the loss by the accumulation steps to compute the average
        loss = loss_fn(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    # Perform a weight update every accumulation_steps iterations
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
Important: Do not forget to divide the loss by accumulation_steps. Since gradients are accumulated, failing to divide will effectively increase the learning rate by a factor of accumulation_steps.
6.3 Considerations
- Batch Normalization: BatchNorm computes statistics per micro-batch, which may differ from statistics over the full effective batch. In such cases, using GroupNorm or LayerNorm is preferable.
- Learning Rate Scheduling: When using step-based schedulers, the scheduler's call frequency should be adjusted to account for accumulation steps.
- Combination with DDP: When using gradient accumulation with DDP, AllReduce can be skipped during intermediate accumulation steps to reduce communication overhead, via the `model.no_sync()` context manager.
```python
from contextlib import nullcontext

for i, (inputs, targets) in enumerate(dataloader):
    # Skip AllReduce unless this is the last accumulation step
    context = model.no_sync() if (i + 1) % accumulation_steps != 0 else nullcontext()
    with context:
        with autocast("cuda"):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets) / accumulation_steps
        scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
7. Memory Profiling with torch.cuda.memory_summary()
The first step in memory optimization is to accurately understand the current memory usage. PyTorch provides various CUDA memory profiling tools.
7.1 torch.cuda.memory_summary()
This is the most basic yet useful tool. It provides detailed information about GPU memory allocation status:
```python
import torch

model = MyLargeModel().cuda()
inputs = torch.randn(32, 3, 224, 224).cuda()

# Check the memory state after a forward pass
with autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

print(torch.cuda.memory_summary(device=0, abbreviated=False))
```
Example output (partial):

```
|===========================================================================|
|                       PyTorch CUDA memory summary                         |
|===========================================================================|
| CUDA OOMs: 0                                                              |
| cudaMallocs: 234                                                          |
|---------------------------------------------------------------------------|
| Metric               | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed   |
|---------------------------------------------------------------------------|
| Allocated memory     | 4096 MiB   | 6234 MiB   | 12345 MiB  | 8249 MiB    |
| Active memory        | 3800 MiB   | 5900 MiB   | 11000 MiB  | 7200 MiB    |
| Requested memory     | 3750 MiB   | 5850 MiB   | 10800 MiB  | 7050 MiB    |
| GPU reserved memory  | 6400 MiB   | 6400 MiB   | 6400 MiB   | 0 MiB       |
|---------------------------------------------------------------------------|
```
Key metric interpretation:
- Allocated memory: Currently allocated memory (actual memory occupied by tensors)
- Active memory: Memory currently in use
- GPU reserved memory: Total memory reserved by PyTorch from CUDA (caching allocator)
- Peak Usage: Maximum memory usage during training (critical for determining OOM occurrence)
7.2 Individual Memory Query Functions
Useful for tracking memory usage at specific points within code:
```python
# Currently allocated memory (bytes)
allocated = torch.cuda.memory_allocated(device=0)

# Currently reserved memory (bytes)
reserved = torch.cuda.memory_reserved(device=0)

# Maximum allocated memory since training started
max_allocated = torch.cuda.max_memory_allocated(device=0)

print(f"Allocated: {allocated / 1024**3:.2f} GB")
print(f"Reserved:  {reserved / 1024**3:.2f} GB")
print(f"Peak:      {max_allocated / 1024**3:.2f} GB")

# Reset peak statistics (for segment-by-segment measurement)
torch.cuda.reset_peak_memory_stats(device=0)
```
7.3 Detailed Analysis with Memory Snapshots
Memory Snapshots, available since PyTorch 2.x, record memory allocation/deallocation events in chronological order for visualization:
```python
# Start recording memory events
torch.cuda.memory._record_memory_history(max_entries=100000)

# Run the training code
train_one_epoch(model, dataloader, optimizer)

# Save a snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording
torch.cuda.memory._record_memory_history(enabled=None)
```
Saved snapshots can be uploaded to PyTorch's official Memory Visualizer (https://pytorch.org/memory_viz) for visual analysis.
7.4 OOM Debugging Strategy
When Out of Memory occurs, diagnose the cause in the following order:
1. Get an overview with `torch.cuda.memory_summary()`
2. Check peak memory with `max_memory_allocated()`
3. Use Memory Snapshots to identify which operations cause memory spikes
4. Analyze which factor dominates: batch size, sequence length, or hidden dimension
5. Apply the appropriate optimization techniques (Mixed Precision, Gradient Checkpointing, Gradient Accumulation)
8. NVIDIA Transformer Engine and FP8 Training
NVIDIA Transformer Engine (TE) is a library that accelerates training and inference of Transformer models using FP8 precision on Hopper (H100) and later GPUs. FP8 training can achieve up to 2x throughput improvement compared to FP16/BF16.
8.1 Two FP8 Formats
The H100 GPU supports two FP8 formats at the hardware level:
- E4M3 (4-bit Exponent, 3-bit Mantissa): Range ~plus/minus 448. Relatively high precision. Used for forward pass activations and weights.
- E5M2 (5-bit Exponent, 2-bit Mantissa): Range ~plus/minus 57,344. Wide dynamic range. Used for backward pass gradients.
This separation strategy selects the format best suited to the numerical characteristics of each stage, minimizing accuracy loss.
8.2 Per-tensor Scaling
To compensate for FP8's narrow dynamic range, Transformer Engine applies per-tensor scaling. It maintains individual scale factors for each tensor, adjusting the tensor's value distribution to fit within FP8's representable range.
Transformer Engine offers several scaling strategies:
- Delayed Scaling: Determines the scale factor based on statistics from the previous iteration (default)
- Current Scaling (Just-in-time): Computes the scale factor immediately based on the current tensor's values
- Block Scaling: Divides the tensor into blocks and applies individual scale factors to each block
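The scaling arithmetic itself is a single multiply. A hypothetical sketch of the amax-based scale computation (the function names are our own; real Transformer Engine kernels additionally round to the FP8 grid and, for delayed scaling, take amax from a history of past iterations):

```python
E4M3_MAX = 448.0   # largest finite FP8 E4M3 value

def quantize_sketch(values, amax):
    """Illustrative per-tensor scaling: map a tensor whose absolute max is
    `amax` onto the FP8 E4M3 range; dequantize later with 1/scale."""
    scale = E4M3_MAX / amax
    return [v * scale for v in values], scale

scaled, scale = quantize_sketch([7.0, -2.0], amax=7.0)
print(scale, scaled)   # 64.0 [448.0, -128.0] -> amax lands on the format max
```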
8.3 Transformer Engine Usage Example
```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 training recipe configuration
fp8_recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,   # Forward: E4M3, Backward: E5M2
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Use Transformer Engine layers
model = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    layer_number=1,
)

# FP8 training loop
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
```
8.4 Evolution in the Blackwell Generation
NVIDIA Blackwell (B200) GPUs add support for the following formats in addition to FP8:
- MXFP8: Microscaling FP8. Supports finer block-level scaling
- NVFP4: 4-bit Floating Point. Primarily used for inference, maximizing memory efficiency
Transformer Engine 2.x and later provides integrated support for these new formats through the Recipe module.
9. Practical Case: Training Large Models on Limited GPUs
Combining all the techniques covered so far, we outline practical strategies for training large models in limited GPU environments.
9.1 Scenario: Fine-tuning a 7B Model on a 24GB GPU
Assume fine-tuning a 7B parameter model on a single A10G (24GB VRAM).
Memory Requirement Analysis (FP32 basis):
| Component | Memory |
|---|---|
| Model Parameters (FP32) | 28 GB |
| Gradients (FP32) | 28 GB |
| Adam Optimizer States | 56 GB |
| Activations (batch=1) | ~2 GB |
| Total | ~114 GB |
This is far beyond a 24GB GPU with FP32. Let's apply optimizations step by step.
9.2 Step-by-Step Optimization
Step 1: Mixed Precision Training (BF16)
```python
import torch
from torch.amp import autocast
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.bfloat16)
```
| Component | Memory |
|---|---|
| Model Parameters (BF16) | 14 GB |
| Gradients (BF16) | 14 GB |
| Adam Master Weights (FP32) | 28 GB |
| Adam Momentum (FP32) | 28 GB |
| Adam Variance (FP32) | 28 GB |
| Total | ~112 GB |
Still insufficient due to optimizer states.
Step 2: Apply LoRA (Reduce Trainable Parameters)
LoRA (Low-Rank Adaptation) sets only about 0.1-1% of total parameters as trainable. Applying Rank=16 LoRA to a 7B model results in approximately 20M trainable parameters.
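The ~20M figure is easy to check. Assuming Llama-2-7B-like dimensions (32 decoder layers, hidden size 4096) and the four attention projections as LoRA targets:

```python
num_layers, hidden, rank = 32, 4096, 16   # Llama-2-7B-like dimensions (assumed)
target_modules = 4                        # q_proj, k_proj, v_proj, o_proj

# Each LoRA adapter adds two low-rank matrices: A (hidden x r) and B (r x hidden)
params_per_module = 2 * hidden * rank
trainable = num_layers * target_modules * params_per_module
print(f"{trainable:,}")                   # 16,777,216 -> on the order of 20M
```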
```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
```
| Component | Memory |
|---|---|
| Full Model (BF16, frozen) | 14 GB |
| LoRA Parameters (BF16) | ~0.04 GB |
| LoRA Gradients (BF16) | ~0.04 GB |
| Adam States (FP32, LoRA only) | ~0.24 GB |
| Total | ~14.3 GB |
This fits comfortably in a 24GB GPU. Now there's room to increase batch size.
Step 3: Add Gradient Checkpointing
Save as much activation memory as possible in the remaining ~10GB to increase batch size:
```python
model.gradient_checkpointing_enable()
```
Step 4: Ensure Effective Batch Size with Gradient Accumulation
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,   # Effective batch size = 32
    bf16=True,
    gradient_checkpointing=True,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
)
```
9.3 Optimization Technique Combination Guide
| GPU VRAM | Model Scale | Recommended Combination |
|---|---|---|
| 8 GB | ~1B | BF16 + LoRA(r=8) + GC + GA |
| 16 GB | ~3B | BF16 + LoRA(r=16) + GC + GA |
| 24 GB | ~7B | BF16 + LoRA(r=16) + GC + GA |
| 40 GB | ~13B | BF16 + LoRA(r=32) + GC + GA |
| 80 GB | ~7B Full FT | BF16 + GC + GA |
| 8x80 GB | ~70B | BF16 + FSDP/DeepSpeed ZeRO-3 + GC |
GC = Gradient Checkpointing, GA = Gradient Accumulation, Full FT = Full Fine-tuning
9.4 Integrated Memory Profiling Script
You can use the following script to pre-check memory usage before training:
```python
import torch
from torch.amp import autocast

def profile_memory(model, dummy_input, dtype=torch.bfloat16):
    """Profile GPU memory usage during training."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    model = model.cuda()
    dummy_input = dummy_input.cuda()

    # Forward pass
    with autocast("cuda", dtype=dtype):
        outputs = model(dummy_input)
        if isinstance(outputs, dict):
            loss = outputs["loss"]
        else:
            loss = outputs.sum()

    print("=== After Forward Pass ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

    # Backward pass
    loss.backward()

    print("\n=== After Backward Pass ===")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Peak:      {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

    print("\n=== Full Memory Summary ===")
    print(torch.cuda.memory_summary(abbreviated=True))

# Usage example
# profile_memory(model, dummy_input)
```
10. Conclusion
GPU memory optimization is not about a single technique but rather a combination of multiple techniques. Here is a summary of the key points:
Understanding memory composition comes first: You must understand how much memory each of Parameters, Gradients, Optimizer States, and Activations consumes to decide where to optimize.
Mixed Precision is fundamental: Using BF16 allows you to halve memory and increase computation speed stably without Loss Scaling. On Ampere and later GPUs, use BF16 as the default.
Gradient Checkpointing targets activation memory: It significantly reduces activation memory at the cost of approximately 33% computational overhead. It is nearly essential for large model training.
Gradient Accumulation solves the batch size problem: It increases the effective batch size without additional memory burden.
FP8 is the next-generation standard: FP8 training is possible on Hopper/Blackwell GPUs through Transformer Engine, providing additional performance improvements over FP16/BF16.
Make profiling a habit: Regularly monitoring memory usage with `torch.cuda.memory_summary()` and Memory Snapshots can prevent OOM issues proactively.
References
- NVIDIA Mixed Precision Training Documentation
- NVIDIA Mixed Precision Training User Guide (PDF)
- NVIDIA Automatic Mixed Precision for Deep Learning
- NVIDIA Mixed Precision Training of Deep Neural Networks (Blog)
- NVIDIA Floating-Point 8: Introduction to Efficient Lower-Precision Training (Blog)
- NVIDIA Transformer Engine Documentation
- NVIDIA Transformer Engine FP8 Primer
- NVIDIA Transformer Engine GitHub Repository
- PyTorch Automatic Mixed Precision (torch.amp) Documentation
- PyTorch AMP Examples
- PyTorch AMP Recipe Tutorial
- PyTorch Blog: What Every User Should Know About Mixed Precision Training
- PyTorch torch.utils.checkpoint Documentation
- PyTorch Blog: Current and New Activation Checkpointing Techniques
- PyTorch Understanding CUDA Memory Usage
- PyTorch torch.cuda.memory_summary Documentation
- Hugging Face: Methods and Tools for Efficient Training on a Single GPU
- Google Cloud Blog: Decoding High-Bandwidth Memory for Fine-tuning AI Models