- Published on
Deep Learning Debugging Complete Guide: From Diagnosing Training Failures to Performance Optimization
- Authors

- Name
- Youngju Kim
- @fjvbn20031
When training deep learning models, you will frequently encounter unexpected failures. Loss suddenly becoming NaN, models refusing to converge no matter how long you wait, or out-of-memory errors are experiences every deep learning developer has faced. This guide provides systematic methods for diagnosing and resolving all major issues that arise during deep learning training, complete with practical code examples.
1. Common Deep Learning Training Failure Patterns
Deep learning training failures can be broadly categorized into three types.
Loss Not Decreasing
Training starts but Loss barely decreases or stays close to its initial value. The most common causes are learning rate too low, bugs in model implementation, or issues in data preprocessing.
Loss Becoming NaN
Loss suddenly changes to NaN (Not a Number) or Inf (Infinity). This occurs due to numerical instability, typically when the learning rate is too high or the data contains outliers.
Training Loss Decreasing But Validation Loss Increasing
This is overfitting. The model is memorizing training data but failing to generalize to new data.
Checklist-Based Diagnostic Framework
def diagnose_training(model, train_loader, val_loader, optimizer, loss_fn, device):
    """Run a quick pre-training sanity check on data, model, forward and backward passes.

    Prints diagnostics and returns None. `val_loader` is accepted for interface
    completeness but only `train_loader` is sampled here.
    """
    print("=== Deep Learning Training Diagnostic Checklist ===\n")
    # 1. Data validation: shapes, value range, NaN/Inf contamination
    print("[1] Validating data...")
    batch = next(iter(train_loader))
    X, y = batch
    print(f" Input shape: {X.shape}")
    print(f" Label shape: {y.shape}")
    print(f" Input range: [{X.min():.4f}, {X.max():.4f}]")
    print(f" Input has NaN: {torch.isnan(X).any()}")
    print(f" Input has Inf: {torch.isinf(X).any()}")
    # 2. Model parameter validation
    print("\n[2] Validating model parameters...")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    # 3. Forward pass test (eval mode, no grad; failures are reported, not raised)
    print("\n[3] Testing forward pass...")
    model.eval()
    with torch.no_grad():
        try:
            output = model(X.to(device))
            print(f" Output shape: {output.shape}")
            print(f" Output has NaN: {torch.isnan(output).any()}")
            loss = loss_fn(output, y.to(device))
            print(f" Initial Loss: {loss.item():.4f}")
        except Exception as e:
            print(f" Forward pass failed: {e}")
    # 4. Backward pass test (train mode)
    print("\n[4] Testing backward pass...")
    model.train()
    optimizer.zero_grad()
    output = model(X.to(device))
    loss = loss_fn(output, y.to(device))
    loss.backward()
    # Gradient validation: collect per-parameter gradient norms
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms.append((name, param.grad.norm().item()))
    if grad_norms:
        print(" Top 5 gradient norms by layer:")
        for name, norm in sorted(grad_norms, key=lambda x: x[1], reverse=True)[:5]:
            print(f" {name}: {norm:.6f}")
    print("\nDiagnosis complete!")
2. Loss Problem Diagnosis
NaN Loss Causes and Solutions
NaN Loss is one of the most frustrating problems in deep learning. Multiple causes exist, each requiring a different approach.
Learning Rate Too High
The most common cause of NaN Loss. When the learning rate is too high, parameter update magnitudes become excessive and Loss explodes.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
def find_learning_rate(model, train_loader, loss_fn, device,
                       start_lr=1e-7, end_lr=10, num_iter=100):
    """LR Range Test: sweep the learning rate geometrically from start_lr to
    end_lr over num_iter steps, recording the loss at each step.

    Stops early when the loss explodes (NaN or > 4x the best seen). Saves and
    shows the loss-vs-LR curve; returns (lrs, losses).
    """
    optimizer = optim.SGD(model.parameters(), lr=start_lr)
    # Per-step multiplicative factor so lr reaches end_lr after num_iter steps.
    lr_multiplier = (end_lr / start_lr) ** (1 / num_iter)
    lrs = []
    losses = []
    best_loss = float('inf')
    model.train()
    data_iter = iter(train_loader)
    for i in range(num_iter):
        try:
            X, y = next(data_iter)
        except StopIteration:
            # Restart the loader if it has fewer than num_iter batches.
            data_iter = iter(train_loader)
            X, y = next(data_iter)
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output, y)
        # Divergence guard: NaN or loss 4x worse than the best so far.
        if torch.isnan(loss) or loss.item() > best_loss * 4:
            print(f"Loss explosion detected at lr={optimizer.param_groups[0]['lr']:.2e}")
            break
        if loss.item() < best_loss:
            best_loss = loss.item()
        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        for pg in optimizer.param_groups:
            pg['lr'] *= lr_multiplier
    plt.figure(figsize=(10, 4))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('LR Range Test')
    plt.grid(True)
    plt.savefig('lr_range_test.png')
    plt.show()
    return lrs, losses
def safe_training_step(model, X, y, optimizer, loss_fn, scaler=None):
    """One optimizer step that skips NaN/Inf inputs and losses.

    Returns the loss value (float) on success, or None when the step was
    skipped. Pass a GradScaler as `scaler` to use the mixed-precision path.
    """
    optimizer.zero_grad()
    # Refuse to train on corrupted inputs.
    if torch.isnan(X).any() or torch.isinf(X).any():
        print("Warning: NaN/Inf in input, skipping step")
        return None
    if scaler is not None:
        # Mixed-precision path: autocast forward, scaled backward; unscale
        # before clipping so the clip threshold applies in true units.
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = loss_fn(output, y)
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Loss is {loss.item()}, skipping step")
            return None
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        output = model(X)
        loss = loss_fn(output, y)
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Loss is {loss.item()}, skipping step")
            return None
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    return loss.item()
Preventing log(0) Computation
In cross-entropy loss or log-based loss functions, log(0) returns -Inf, causing NaN.
# Wrong: log(0) evaluates to -inf, which turns the loss into Inf/NaN.
def bad_cross_entropy(pred, target):
    return -torch.sum(target * torch.log(pred))

# Correct: clamp predictions away from 0 and 1 for numerical stability.
def safe_cross_entropy(pred, target, eps=1e-8):
    pred = torch.clamp(pred, min=eps, max=1-eps)
    return -torch.sum(target * torch.log(pred))

# Best: use PyTorch built-ins (internally applies log-sum-exp trick)
loss_fn = nn.CrossEntropyLoss() # numerically stable
log_softmax = nn.LogSoftmax(dim=1) # combined log + softmax

def numerically_stable_log_loss(logits, targets):
    """Stable cross-entropy straight from logits (no explicit softmax/log)."""
    import torch.nn.functional as F
    return F.cross_entropy(logits, targets)
Using torch.autograd.set_detect_anomaly
import torch
# Enable anomaly detection mode (development/debugging only).
# It slows training noticeably, so disable it in production. Usage:
#
#     with torch.autograd.detect_anomaly():
#         output = model(X)
#         loss = loss_fn(output, y)
#         loss.backward()  # prints the exact op where NaN/Inf appeared
#
# or globally:
#
#     torch.autograd.set_detect_anomaly(True)

def train_with_anomaly_detection(model, loader, optimizer, loss_fn, device, epochs=5):
    """Training loop wrapped in autograd anomaly detection.

    When a NaN loss appears, prints the epoch/batch plus input/output
    statistics and aborts the current epoch.
    """
    model.train()
    for epoch in range(epochs):
        for batch_idx, (X, y) in enumerate(loader):
            X, y = X.to(device), y.to(device)
            with torch.autograd.detect_anomaly():
                optimizer.zero_grad()
                output = model(X)
                loss = loss_fn(output, y)
                if torch.isnan(loss):
                    print(f"NaN Loss at epoch {epoch}, batch {batch_idx}")
                    print(f"Input stats: mean={X.mean():.4f}, std={X.std():.4f}")
                    print(f"Output stats: mean={output.mean():.4f}, std={output.std():.4f}")
                    break
                loss.backward()
                optimizer.step()
3. Gradient Problems
Diagnosing Vanishing Gradients
Vanishing gradients occur in deep networks when backpropagation causes gradients to become extremely small as they propagate to earlier layers.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
def check_gradient_flow(model):
    """Collect per-layer mean/max gradient magnitudes, plot them on a log
    scale, and warn about layers whose average gradient is vanishingly small.

    Call after loss.backward(). Returns (layers, ave_grads, max_grads).
    """
    ave_grads = []
    max_grads = []
    layers = []
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            layers.append(name)
            ave_grads.append(param.grad.abs().mean().item())
            max_grads.append(param.grad.abs().max().item())
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(ave_grads)), ave_grads, alpha=0.5, label='Mean Gradient')
    plt.bar(range(len(max_grads)), max_grads, alpha=0.5, label='Max Gradient')
    plt.xticks(range(len(layers)), layers, rotation=90)
    plt.xlabel("Layer")
    plt.ylabel("Gradient Magnitude")
    plt.title("Gradient Flow by Layer")
    plt.legend()
    plt.yscale('log')
    plt.tight_layout()
    plt.savefig('gradient_flow.png')
    # Heuristic threshold for "vanishing": average |grad| below 1e-6.
    for name, avg_grad in zip(layers, ave_grads):
        if avg_grad < 1e-6:
            print(f"Warning: Vanishing gradient possible in {name} (avg={avg_grad:.2e})")
    return layers, ave_grads, max_grads
def register_gradient_hooks(model):
    """Attach per-parameter gradient hooks that record live statistics.

    Returns (gradient_stats, hooks): `gradient_stats` is a dict updated on
    every backward pass; call .remove() on each hook when done.
    """
    gradient_stats = {}

    def make_hook(name):
        # Closure factory so each hook remembers its parameter's name.
        def hook(grad):
            gradient_stats[name] = {
                'mean': grad.abs().mean().item(),
                'max': grad.abs().max().item(),
                'std': grad.std().item(),
                'has_nan': torch.isnan(grad).any().item(),
                'has_inf': torch.isinf(grad).any().item()
            }
            if torch.isnan(grad).any():
                print(f"NaN gradient detected: {name}")
            return grad
        return hook

    hooks = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            hook = param.register_hook(make_hook(name))
            hooks.append(hook)
    return gradient_stats, hooks
# Fix vanishing gradients: He initialization + BatchNorm + Residual Connection
class ResidualBlock(nn.Module):
    """Two Linear+BatchNorm layers with a skip connection."""

    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        # Residual connection keeps a direct gradient path to earlier layers.
        return self.relu(self.block(x) + x)
def init_weights(m):
    """He (Kaiming) init for Linear/Conv2d weights; identity for BatchNorm2d.

    Apply with model.apply(init_weights).
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # The original Linear and Conv2d branches were identical, so they
        # are merged: fan_out Kaiming-normal suits ReLU networks.
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
Solving Exploding Gradients with Gradient Clipping
import torch
import torch.nn as nn
def train_with_gradient_clipping(model, loader, optimizer, loss_fn, device,
                                 max_norm=1.0, epochs=10):
    """Training loop with gradient-norm clipping and per-epoch norm tracking.

    Returns a history dict with 'train_loss' and 'grad_norm' lists (one
    entry per epoch; grad norms are the pre-clipping totals).
    """
    model.train()
    history = {'train_loss': [], 'grad_norm': []}
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_grad_norms = []
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output, y)
            loss.backward()
            # clip_grad_norm_ returns the total norm BEFORE clipping, so the
            # original's separate two-pass norm computation is redundant.
            total_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm).item()
            epoch_grad_norms.append(total_norm)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(loader)
        avg_grad_norm = sum(epoch_grad_norms) / len(epoch_grad_norms)
        history['train_loss'].append(avg_loss)
        history['grad_norm'].append(avg_grad_norm)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Grad Norm={avg_grad_norm:.4f}")
        if avg_grad_norm > max_norm * 10:
            print(f"Warning: Gradient norm very large ({avg_grad_norm:.4f}). Consider reducing learning rate.")
    return history
4. Solving Overfitting
Diagnosing Overfitting
import matplotlib.pyplot as plt
import numpy as np
def plot_learning_curves(train_losses, val_losses, train_accs=None, val_accs=None):
    """Plot train/val loss (and optionally accuracy) curves and flag overfitting.

    Marks the epoch with the lowest validation loss and prints the final
    val-train loss gap.
    """
    fig, axes = plt.subplots(1, 2 if train_accs else 1, figsize=(14, 5))
    # subplots() returns a bare Axes for a single panel; normalize to a
    # sequence so the indexing below works either way.
    if not isinstance(axes, np.ndarray):
        axes = [axes]
    axes[0].plot(train_losses, label='Train Loss', color='blue')
    axes[0].plot(val_losses, label='Val Loss', color='red', linestyle='--')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training/Validation Loss')
    axes[0].legend()
    axes[0].grid(True)
    min_val_idx = np.argmin(val_losses)
    axes[0].axvline(x=min_val_idx, color='green', linestyle=':', label=f'Best epoch: {min_val_idx}')
    axes[0].legend()
    if train_accs and val_accs:
        axes[1].plot(train_accs, label='Train Acc', color='blue')
        axes[1].plot(val_accs, label='Val Acc', color='red', linestyle='--')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Training/Validation Accuracy')
        axes[1].legend()
        axes[1].grid(True)
    # Positive gap = validation worse than training (overfitting indicator).
    final_gap = val_losses[-1] - train_losses[-1]
    print(f"Final overfitting gap (Val-Train Loss): {final_gap:.4f}")
    if final_gap > 0.1:
        print("Warning: Severe overfitting detected!")
    plt.tight_layout()
    plt.savefig('learning_curves.png')
    plt.show()
Implementing Early Stopping
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Call the instance each epoch with (val_loss, model, epoch); it returns
    True when training should stop. With restore_best=True the best
    state_dict is reloaded into the model at the stopping point.
    """

    def __init__(self, patience=10, min_delta=0.001, restore_best=True, verbose=True):
        self.patience = patience        # epochs to wait without improvement
        self.min_delta = min_delta      # minimum decrease that counts as improvement
        self.restore_best = restore_best
        self.verbose = verbose
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.counter = 0
        self.best_weights = None
        self.stopped_epoch = 0

    def __call__(self, val_loss, model, epoch):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            if self.restore_best:
                import copy
                # Deep copy: the state_dict's tensors would otherwise keep
                # updating in place as training continues.
                self.best_weights = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f"Validation Loss improved: {val_loss:.6f} (epoch {epoch})")
            return False
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.stopped_epoch = epoch
                if self.restore_best and self.best_weights:
                    model.load_state_dict(self.best_weights)
                    print(f"Best weights restored (epoch {self.best_epoch})")
                return True
            return False
def train_with_regularization(model, train_loader, val_loader,
                              optimizer, loss_fn, device, epochs=100):
    """Training loop combining early stopping, LR reduction on plateau,
    and gradient clipping.

    Returns (train_losses, val_losses). Uses the EarlyStopping class
    defined in this module.
    """
    early_stopping = EarlyStopping(patience=15, min_delta=0.001)
    # NOTE: ReduceLROnPlateau's `verbose` kwarg is deprecated/removed in
    # recent PyTorch releases, so it is intentionally omitted here.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                output = model(X)
                val_loss += loss_fn(output, y).item()
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        # The plateau scheduler must be stepped with the validation metric.
        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")
        if early_stopping(val_loss, model, epoch):
            print(f"Early stopping at epoch {epoch+1}")
            break
    return train_losses, val_losses
# Dropout + L2 Regularization example
class RegularizedModel(nn.Module):
    """MLP with BatchNorm + Dropout after each hidden layer."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.3):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.network(x)
# L2 regularization via weight_decay in optimizer
# NOTE(review): illustrative snippet — assumes a `model` already exists in
# scope. AdamW applies decoupled weight decay (true L2-style regularization),
# unlike Adam where weight_decay is folded into the gradient.
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-3,
weight_decay=1e-4
)
Data Augmentation Strategies
import torchvision.transforms as transforms

# Training-time augmentation: random geometric + color perturbations to
# increase effective dataset diversity and fight overfitting.
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
# ImageNet channel statistics.
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Validation: deterministic resize + center crop only — no augmentation.
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def mixup_data(x, y, alpha=0.2, device='cuda'):
    """Mixup augmentation: blend each sample with a randomly paired one.

    Returns (mixed_x, y_a, y_b, lam) where lam ~ Beta(alpha, alpha);
    combine the two label losses with mixup_criterion.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1  # alpha <= 0 disables mixing
    batch_size = x.size(0)
    # Random pairing permutation, moved to the target device.
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Convex combination of the loss against both mixed labels."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
5. Training Speed Issues
Resolving Data Loading Bottlenecks
import torch
from torch.utils.data import DataLoader
import time
def profile_dataloader(dataset, batch_size=32, num_workers_list=None):
    """Time the first 10 batches for each num_workers setting.

    Returns {num_workers: elapsed_seconds} and prints the fastest setting.
    """
    # Avoid a mutable default argument (the original used a list literal).
    if num_workers_list is None:
        num_workers_list = [0, 2, 4, 8]
    results = {}
    for num_workers in num_workers_list:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            # prefetch_factor / persistent_workers are only valid with workers.
            prefetch_factor=2 if num_workers > 0 else None,
            persistent_workers=True if num_workers > 0 else False
        )
        start = time.time()
        for i, batch in enumerate(loader):
            if i >= 10:
                break
        elapsed = time.time() - start
        results[num_workers] = elapsed
        print(f"num_workers={num_workers}: {elapsed:.3f}s (10 batches)")
    best_workers = min(results, key=results.get)
    print(f"\nOptimal num_workers: {best_workers}")
    return results
def create_optimized_dataloader(dataset, batch_size, is_train=True):
    """DataLoader with throughput-oriented defaults: pinned memory,
    persistent prefetching workers, and drop_last/shuffle for training."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True,
        drop_last=is_train
    )
Mixed Precision Training
import torch
from torch.amp import autocast, GradScaler

def train_mixed_precision(model, loader, optimizer, loss_fn, device, epochs=10):
    """FP16 mixed-precision training loop (CUDA) with gradient scaling.

    BUG FIX: the original imported `autocast` from torch.cuda.amp, whose
    signature does not accept `device_type`; the torch.amp variants do.
    """
    scaler = GradScaler('cuda')
    model.train()
    for epoch in range(epochs):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            with autocast(device_type='cuda', dtype=torch.float16):
                output = model(X)
                loss = loss_fn(output, y)
            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        print(f"Epoch {epoch+1} complete, scaler scale: {scaler.get_scale()}")
Applying torch.compile
import torch
# NOTE(review): illustrative snippet — assumes MyModel and device exist in scope.
model = MyModel().to(device)
# Default compilation
compiled_model = torch.compile(model)
# Maximum performance mode (longer compile time)
compiled_model = torch.compile(model, mode='max-autotune')
# For frequently changing input sizes (avoids recompiling per shape)
compiled_model = torch.compile(model, dynamic=True)
def benchmark_model(model, inputs, n_iters=100):
    """Average forward-pass latency in seconds over n_iters runs.

    Runs 10 warmup iterations first (important after torch.compile, which
    compiles on the first calls).
    """
    for _ in range(10):
        _ = model(inputs)
    # Synchronize only when CUDA is available: torch.cuda.synchronize()
    # raises on CPU-only machines, and CPU timing needs no sync.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        _ = model(inputs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.time() - start
    return elapsed / n_iters
6. Out-of-Memory (OOM) Solutions
GPU Memory Analysis
import torch
import gc
def print_gpu_memory_summary(device=0):
    """Print total/reserved/allocated/cached GPU memory plus the full
    torch.cuda.memory_summary() report. Prints a notice and returns
    early when CUDA is unavailable."""
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return
    print(f"=== GPU {device} Memory Summary ===")
    print(f"Total memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB")
    print(f"Reserved memory: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")
    print(f"Allocated memory: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    # "Cached" here = reserved by the allocator but not currently allocated.
    print(f"Cached memory: {(torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9:.2f} GB")
    print()
    print(torch.cuda.memory_summary(device=device, abbreviated=False))
def clear_gpu_memory():
    """Run gc and release cached GPU memory; safe no-op without CUDA.

    BUG FIX: torch.cuda.synchronize() raises on CPU-only machines, so the
    CUDA calls are now guarded by torch.cuda.is_available().
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
Implementing Gradient Checkpointing
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
class MemoryEfficientModel(nn.Module):
    """Gradient-checkpointed stack of layers.

    Trades extra forward recomputation during backward for a large
    reduction in stored activations.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # checkpoint_sequential splits the stack into `segments` chunks and
        # stores activations only at chunk boundaries (more segments =
        # more memory saved, slower backward).
        return checkpoint_sequential(self.layers, segments=4, input=x)

    def forward_with_manual_checkpoints(self, x):
        # First and last layers run normally; inner layers are recomputed
        # during backward.
        x = self.layers[0](x)
        for layer in self.layers[1:-1]:
            x = checkpoint(layer, x)
        x = self.layers[-1](x)
        return x
# Enable Gradient Checkpointing in Transformers
# NOTE(review): from_pretrained downloads gpt2 weights on first use —
# requires network access and the `transformers` package.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable() # Simple activation for Hugging Face models
Automatic Batch Size Search
def find_optimal_batch_size(model, loss_fn, device,
                            start_batch=8, max_batch=512):
    """Double the batch size until OOM to find the largest workable size.

    Assumes ImageNet-shaped inputs (3x224x224, 1000 classes). Returns the
    largest batch size that completed a forward+backward pass.
    """
    batch_size = start_batch
    optimal_batch_size = start_batch
    while batch_size <= max_batch:
        try:
            dummy_input = torch.randn(batch_size, 3, 224, 224).to(device)
            dummy_target = torch.randint(0, 1000, (batch_size,)).to(device)
            output = model(dummy_input)
            loss = loss_fn(output, dummy_target)
            loss.backward()
            # Drop accumulated gradients so each trial's memory footprint
            # reflects a single step (the original let them accumulate).
            model.zero_grad(set_to_none=True)
            optimal_batch_size = batch_size
            print(f"Batch size {batch_size}: Success")
            batch_size *= 2
            del dummy_input, dummy_target, output, loss
            torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"Batch size {batch_size}: OOM")
                torch.cuda.empty_cache()
                break
            else:
                # Re-raise non-OOM errors with the original traceback.
                raise
    print(f"\nRecommended batch size: {optimal_batch_size} (with safety margin: {optimal_batch_size // 2})")
    return optimal_batch_size
7. Data Pipeline Debugging
Data Sample Visualization
import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import Counter
def visualize_batch(loader, num_samples=16, class_names=None):
    """Show up to 16 images from the first batch in a 4x4 grid, undoing
    ImageNet normalization so the images are human-viewable."""
    X, y = next(iter(loader))
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    axes = axes.flatten()
    for i in range(min(num_samples, len(X))):
        img = X[i].numpy()
        # Reverse ImageNet normalization (img was (x - mean) / std).
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = std[:, None, None] * img + mean[:, None, None]
        img = np.clip(img, 0, 1)
        # CHW -> HWC for imshow.
        axes[i].imshow(img.transpose(1, 2, 0))
        label = y[i].item()
        title = class_names[label] if class_names else f"Label: {label}"
        axes[i].set_title(title)
        axes[i].axis('off')
    plt.tight_layout()
    plt.savefig('data_samples.png')
    plt.show()
def check_label_distribution(dataset):
    """Print per-class sample counts and warn on heavy class imbalance.

    Returns a Counter of labels. Assumes dataset[i] -> (sample, label).
    """
    labels = [dataset[i][1] for i in range(len(dataset))]
    counter = Counter(labels)
    classes = sorted(counter.keys())
    counts = [counter[c] for c in classes]
    total = sum(counts)
    print("Label Distribution:")
    for cls, count in zip(classes, counts):
        pct = count / total * 100
        bar = '#' * int(pct / 2)  # crude text histogram: one char per 2%
        print(f" Class {cls}: {count:5d} ({pct:.1f}%) {bar}")
    max_count = max(counts)
    min_count = min(counts)
    imbalance_ratio = max_count / min_count
    # Heuristic: >10:1 majority/minority ratio is treated as severe.
    if imbalance_ratio > 10:
        print(f"\nWarning: Severe class imbalance! (ratio: {imbalance_ratio:.1f}:1)")
        print("Solution: Consider weighted sampling or class-weighted loss function.")
    return counter
def create_weighted_sampler(dataset):
    """Build a WeightedRandomSampler with per-sample weight 1/class_count,
    so each class is drawn roughly uniformly despite imbalance."""
    labels = [dataset[i][1] for i in range(len(dataset))]
    class_counts = Counter(labels)
    weights = [1.0 / class_counts[label] for label in labels]
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=weights,
        num_samples=len(weights),
        replacement=True
    )
    return sampler
8. Model Architecture Debugging
Model Structure Analysis with torchinfo
from torchinfo import summary
import torch
import torch.nn as nn
def analyze_model(model, input_size):
    """Summarize the model with torchinfo and print per-layer parameter
    counts. Returns the torchinfo statistics object."""
    model_stats = summary(
        model,
        input_size=input_size,
        col_names=["input_size", "output_size", "num_params", "kernel_size",
                   "mult_adds"],
        verbose=1
    )
    print("\nParameter distribution by layer:")
    for name, module in model.named_modules():
        # recurse=False: count only parameters owned directly by this module,
        # so parents don't double-count their children's parameters.
        num_params = sum(p.numel() for p in module.parameters(recurse=False))
        if num_params > 0:
            print(f" {name}: {num_params:,} parameters")
    return model_stats
def monitor_activations(model, X):
    """Capture activations after every ReLU/GELU/Tanh/Sigmoid and report
    dead-neuron ratios (fraction of exact zeros).

    Returns the {layer_name: activation_tensor} dict; all hooks are removed
    before returning.
    """
    activations = {}

    def make_activation_hook(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook

    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.ReLU, nn.GELU, nn.Tanh, nn.Sigmoid)):
            hook = module.register_forward_hook(make_activation_hook(name))
            hooks.append(hook)
    with torch.no_grad():
        model(X)
    print("\nActivation statistics:")
    for name, act in activations.items():
        # "Dead" = output exactly zero (meaningful mainly for ReLU).
        dead_neurons = (act == 0).float().mean().item()
        print(f" {name}:")
        print(f" mean: {act.mean():.4f}, std: {act.std():.4f}")
        print(f" dead neuron ratio: {dead_neurons:.2%}")
        if dead_neurons > 0.5:
            print(f" Warning: {dead_neurons:.0%} of neurons are inactive!")
    for hook in hooks:
        hook.remove()
    return activations
def visualize_weight_distribution(model):
    """Histogram the weights of up to six Linear/Conv2d layers to spot
    initialization problems (collapsed or overly wide distributions)."""
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    linear_layers = [(name, m) for name, m in model.named_modules()
                     if isinstance(m, (nn.Linear, nn.Conv2d))]
    for i, (name, layer) in enumerate(linear_layers[:6]):
        if i >= len(axes):
            break
        weight_data = layer.weight.data.cpu().numpy().flatten()
        axes[i].hist(weight_data, bins=50, color='blue', alpha=0.7)
        axes[i].set_title(f"{name}\n(mean={weight_data.mean():.4f}, std={weight_data.std():.4f})")
        axes[i].set_xlabel('Weight Value')
        axes[i].set_ylabel('Frequency')
        axes[i].grid(True, alpha=0.3)
    plt.suptitle("Weight Distribution by Layer")
    plt.tight_layout()
    plt.savefig('weight_distribution.png')
    plt.show()
9. Training Monitoring
TensorBoard Usage
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
class TensorBoardLogger:
    """Thin wrapper around SummaryWriter for scalars, gradient/weight
    histograms, and learning rates. Call close() when finished."""

    def __init__(self, log_dir='runs/experiment'):
        self.writer = SummaryWriter(log_dir)

    def log_scalars(self, metrics: dict, epoch: int):
        for name, value in metrics.items():
            self.writer.add_scalar(name, value, epoch)

    def log_model_gradients(self, model, epoch: int):
        # Log gradient and weight histograms for parameters that received
        # gradients in the last backward pass.
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.writer.add_histogram(f'gradients/{name}', param.grad, epoch)
                self.writer.add_histogram(f'weights/{name}', param.data, epoch)

    def log_learning_rate(self, optimizer, epoch: int):
        for i, pg in enumerate(optimizer.param_groups):
            self.writer.add_scalar(f'lr/group_{i}', pg['lr'], epoch)

    def close(self):
        self.writer.close()
def train_with_tensorboard(model, train_loader, val_loader,
                           optimizer, loss_fn, device, epochs=50):
    """Full train/validate loop that logs losses, accuracies, gradient and
    weight histograms, and learning rates to TensorBoard.

    Uses the TensorBoardLogger class defined above.
    """
    logger = TensorBoardLogger(log_dir='runs/debug_session')
    for epoch in range(epochs):
        model.train()
        train_loss, train_correct = 0, 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_correct += (output.argmax(1) == y).sum().item()
        train_loss /= len(train_loader)
        train_acc = train_correct / len(train_loader.dataset)
        model.eval()
        val_loss, val_correct = 0, 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                output = model(X)
                val_loss += loss_fn(output, y).item()
                val_correct += (output.argmax(1) == y).sum().item()
        val_loss /= len(val_loader)
        val_acc = val_correct / len(val_loader.dataset)
        logger.log_scalars({
            'Loss/Train': train_loss,
            'Loss/Val': val_loss,
            'Accuracy/Train': train_acc,
            'Accuracy/Val': val_acc,
        }, epoch)
        logger.log_model_gradients(model, epoch)
        logger.log_learning_rate(optimizer, epoch)
    logger.close()
    print("TensorBoard: run with 'tensorboard --logdir=runs'")
10. Reproducibility
import random
import numpy as np
import torch
import os
def set_seed(seed: int = 42):
    """Seed every RNG (Python, NumPy, PyTorch CPU/GPU) and force
    deterministic kernels for reproducible runs.

    Note: deterministic algorithms and disabled cudnn.benchmark can slow
    training, and use_deterministic_algorithms may raise at runtime for
    ops with no deterministic implementation.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # Multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Required by cuBLAS for deterministic GEMM kernels.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    print(f"All random generators initialized with seed {seed}")
def save_experiment_config(config: dict, save_path: str = 'experiment_config.json'):
    """Snapshot the experiment config plus environment info (library
    versions, GPU, git commit) into a JSON file.

    Returns the full config dict that was written.
    """
    import json
    import subprocess
    full_config = config.copy()
    full_config['environment'] = {
        'python': subprocess.getoutput('python --version'),
        'torch': torch.__version__,
        'cuda': torch.version.cuda,
        'cudnn': str(torch.backends.cudnn.version()),
        'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    }
    try:
        # getoutput doesn't raise on a failed command, but keep the guard
        # for restricted environments where subprocess itself may fail.
        full_config['git_hash'] = subprocess.getoutput('git rev-parse HEAD')
    except Exception:
        full_config['git_hash'] = 'unknown'
    with open(save_path, 'w') as f:
        json.dump(full_config, f, indent=2)
    print(f"Experiment config saved: {save_path}")
    return full_config
def test_reproducibility(model_fn, train_fn, seed=42, n_runs=3):
    """Train n_runs times from the same seed and check the final losses
    agree to within 1e-5. Uses set_seed() defined above.

    Returns the list of final losses.
    """
    results = []
    for run in range(n_runs):
        set_seed(seed)
        model = model_fn()
        loss = train_fn(model)
        results.append(loss)
        print(f"Run {run+1}: Final Loss = {loss:.6f}")
    max_diff = max(results) - min(results)
    print(f"\nMax difference: {max_diff:.8f}")
    if max_diff < 1e-5:
        print("Reproducibility check passed!")
    else:
        print("Warning: Reproducibility issues detected.")
    return results
11. Distributed Training Debugging
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
def setup_distributed(rank, world_size, backend='nccl'):
    """Initialize a single-node process group and pin this process to GPU
    `rank`. Call once per worker process before any collective op."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size
    )
    torch.cuda.set_device(rank)
    print(f"Process {rank}/{world_size} initialized")
def cleanup_distributed():
    """Tear down the process group created by setup_distributed."""
    dist.destroy_process_group()
def debug_ddp_training(rank, world_size, model, dataset):
    """Reference DDP training loop for one worker process.

    Demonstrates the common pitfalls: sampler.set_epoch() for per-epoch
    shuffling, rank-0-only logging/checkpointing, and a barrier to keep
    ranks in lockstep.
    """
    setup_distributed(rank, world_size)
    device = torch.device(f'cuda:{rank}')
    model = model.to(device)
    # find_unused_parameters=True helps debug "parameter was not used in
    # forward" DDP errors, at some performance cost.
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in range(10):
        # Critical: update sampler seed each epoch; without this every
        # epoch sees the same shuffling order.
        sampler.set_epoch(epoch)
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output, y)
            loss.backward()
            optimizer.step()
        # Only log/checkpoint from rank 0; save the unwrapped module so the
        # checkpoint loads without the DDP wrapper.
        if rank == 0:
            print(f"Epoch {epoch+1} complete")
            torch.save(model.module.state_dict(), f'checkpoint_epoch{epoch}.pt')
        # Synchronize all ranks before the next epoch.
        dist.barrier()
    cleanup_distributed()
12. MLflow and Experiment Management
import mlflow
import mlflow.pytorch
import torch
import optuna
def train_with_mlflow(model, train_loader, val_loader, optimizer,
                      loss_fn, device, params: dict):
    """Train while logging params/metrics/models to an MLflow run.

    Logs the model whenever validation loss improves and returns the best
    validation loss. Requires a tracking server at localhost:5000.
    """
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_experiment("deep-learning-debug")
    with mlflow.start_run():
        mlflow.log_params(params)
        best_val_loss = float('inf')
        for epoch in range(params['epochs']):
            model.train()
            train_loss = 0
            for X, y in train_loader:
                X, y = X.to(device), y.to(device)
                optimizer.zero_grad()
                output = model(X)
                loss = loss_fn(output, y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_loss /= len(train_loader)
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for X, y in val_loader:
                    X, y = X.to(device), y.to(device)
                    output = model(X)
                    val_loss += loss_fn(output, y).item()
            val_loss /= len(val_loader)
            mlflow.log_metrics({
                'train_loss': train_loss,
                'val_loss': val_loss
            }, step=epoch)
            # Version the best model as an MLflow artifact.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                mlflow.pytorch.log_model(model, "best_model")
        mlflow.log_metric("best_val_loss", best_val_loss)
    return best_val_loss
def hyperparameter_optimization_with_optuna(model_fn, train_loader,
                                            val_loader, device, n_trials=50):
    """Search lr / weight_decay / dropout / batch_size with Optuna's TPE
    sampler, minimizing validation loss via train_with_mlflow.

    Returns the best parameter dict.
    """
    def objective(trial):
        # Log-uniform sampling suits scale-type hyperparameters.
        lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
        dropout = trial.suggest_float('dropout', 0.0, 0.5)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
        model = model_fn(dropout=dropout).to(device)
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=lr, weight_decay=weight_decay)
        loss_fn = torch.nn.CrossEntropyLoss()
        val_loss = train_with_mlflow(
            model, train_loader, val_loader, optimizer, loss_fn, device,
            params={'lr': lr, 'weight_decay': weight_decay,
                    'dropout': dropout, 'batch_size': batch_size, 'epochs': 10}
        )
        return val_loss

    study = optuna.create_study(
        direction='minimize',
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner()
    )
    study.optimize(objective, n_trials=n_trials)
    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f" {key}: {value}")
    print(f"Best validation Loss: {study.best_value:.4f}")
    return study.best_params
Conclusion: Deep Learning Debugging Workflow
A systematic approach is critical for deep learning debugging. Follow this order to diagnose problems:
1. Start with data: 80% of problems originate with data. Check for NaN, wrong labels, and normalization errors first.
2. Start small: Before full-batch training, test whether the model can overfit a single batch.
3. Check gradients: Verify that gradients flow correctly after the loss function.
4. Use monitoring tools: Choose one of TensorBoard, W&B, or MLflow and track all experiments.
5. Ensure reproducibility: Debugging without fixed seeds is extremely difficult. Always set seeds.
def minimum_debug_checklist(model, train_loader, device):
    """Single-batch overfitting test: a healthy model/optimizer/data setup
    should drive the loss down sharply on one fixed batch.

    Prints pass/warning based on the initial/final loss ratio; returns None.
    """
    print("Deep Learning Pre-Training Checklist")
    print("=" * 50)
    # 1. Single-batch overfitting test
    print("[1] Single-batch overfitting test...")
    X, y = next(iter(train_loader))
    X, y = X.to(device), y.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    initial_loss = None
    for step in range(100):
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output, y)
        if initial_loss is None:
            initial_loss = loss.item()
        loss.backward()
        optimizer.step()
    final_loss = loss.item()
    # Ratio >> 1 means the model can memorize the batch (healthy pipeline).
    overfit_ratio = initial_loss / final_loss if final_loss > 0 else float('inf')
    if overfit_ratio > 10:
        print(f" Pass: Loss decreased from {initial_loss:.4f} to {final_loss:.4f} (ratio: {overfit_ratio:.1f}x)")
    else:
        print(f" Warning: Cannot overfit a single batch (ratio: {overfit_ratio:.1f}x)")
        print(" -> Check model capacity, learning rate, and data errors")
    print("\nChecklist complete!")
By systematically applying the techniques covered in this guide, you can quickly diagnose and resolve most issues that arise during deep learning training. Debugging becomes faster with experience, but having the right tools and methodology is paramount.