Skip to content
Published on

Deep Learning Debugging Complete Guide: From Diagnosing Training Failures to Performance Optimization

Authors

When training deep learning models, you will frequently encounter unexpected failures. Loss suddenly becoming NaN, models refusing to converge no matter how long you wait, or out-of-memory errors are experiences every deep learning developer has faced. This guide provides systematic methods for diagnosing and resolving all major issues that arise during deep learning training, complete with practical code examples.

1. Common Deep Learning Training Failure Patterns

Deep learning training failures can be broadly categorized into three types.

Loss Not Decreasing

Training starts but Loss barely decreases or stays close to its initial value. The most common causes are learning rate too low, bugs in model implementation, or issues in data preprocessing.

Loss Becoming NaN

Loss suddenly changes to NaN (Not a Number) or Inf (Infinity). This occurs due to numerical instability, typically when the learning rate is too high or the data contains outliers.

Training Loss Decreasing But Validation Loss Increasing

This is overfitting. The model is memorizing training data but failing to generalize to new data.

Checklist-Based Diagnostic Framework

def diagnose_training(model, train_loader, val_loader, optimizer, loss_fn, device):
    """
    Quick diagnostic function to run before training begins.

    Checks one training batch for NaN/Inf, prints parameter counts,
    smoke-tests the forward pass (eval mode, no grad) and the backward
    pass, and reports the five largest per-layer gradient norms.
    Gradients accumulated by the diagnostic backward pass are cleared
    before returning so a following training loop starts clean.

    NOTE(review): `val_loader` is accepted for interface symmetry with the
    training functions below but is not used by the current checks.
    """
    print("=== Deep Learning Training Diagnostic Checklist ===\n")

    # 1. Data validation
    print("[1] Validating data...")
    batch = next(iter(train_loader))
    X, y = batch
    print(f"  Input shape: {X.shape}")
    print(f"  Label shape: {y.shape}")
    print(f"  Input range: [{X.min():.4f}, {X.max():.4f}]")
    print(f"  Input has NaN: {torch.isnan(X).any()}")
    print(f"  Input has Inf: {torch.isinf(X).any()}")

    # 2. Model parameter validation
    print("\n[2] Validating model parameters...")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # 3. Forward pass test (eval mode so Dropout/BatchNorm behave deterministically)
    print("\n[3] Testing forward pass...")
    model.eval()
    with torch.no_grad():
        try:
            output = model(X.to(device))
            print(f"  Output shape: {output.shape}")
            print(f"  Output has NaN: {torch.isnan(output).any()}")
            loss = loss_fn(output, y.to(device))
            print(f"  Initial Loss: {loss.item():.4f}")
        except Exception as e:
            print(f"  Forward pass failed: {e}")

    # 4. Backward pass test
    print("\n[4] Testing backward pass...")
    model.train()
    optimizer.zero_grad()
    output = model(X.to(device))
    loss = loss_fn(output, y.to(device))
    loss.backward()

    # Gradient validation: collect per-parameter gradient norms
    grad_norms = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms.append((name, param.grad.norm().item()))

    if grad_norms:
        print("  Top 5 gradient norms by layer:")
        for name, norm in sorted(grad_norms, key=lambda x: x[1], reverse=True)[:5]:
            print(f"    {name}: {norm:.6f}")

    # Fix: don't leave the diagnostic's gradients behind — a subsequent
    # training loop that forgets zero_grad() would silently accumulate them.
    optimizer.zero_grad(set_to_none=True)

    print("\nDiagnosis complete!")

2. Loss Problem Diagnosis

NaN Loss Causes and Solutions

NaN Loss is one of the most frustrating problems in deep learning. Multiple causes exist, each requiring a different approach.

Learning Rate Too High

The most common cause of NaN Loss. When the learning rate is too high, parameter update magnitudes become excessive and Loss explodes.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


def find_learning_rate(model, train_loader, loss_fn, device,
                        start_lr=1e-7, end_lr=10, num_iter=100):
    """
    Use LR Range Test to find the optimal learning rate range.

    Runs `num_iter` SGD steps while growing the learning rate geometrically
    from `start_lr` to `end_lr`, recording the loss at each step; stops
    early when the loss turns NaN or exceeds 4x the best loss seen.
    Saves the loss-vs-LR curve to lr_range_test.png and returns (lrs, losses).

    NOTE(review): this mutates the model's weights (real optimizer steps are
    taken) — reload/re-init the model before actual training.
    """
    optimizer = optim.SGD(model.parameters(), lr=start_lr)
    # Per-step multiplier so that lr reaches end_lr after num_iter steps.
    lr_multiplier = (end_lr / start_lr) ** (1 / num_iter)

    lrs = []
    losses = []
    best_loss = float('inf')

    model.train()
    data_iter = iter(train_loader)

    for i in range(num_iter):
        try:
            X, y = next(data_iter)
        except StopIteration:
            # Loader exhausted: restart it so the test can run longer than one epoch.
            data_iter = iter(train_loader)
            X, y = next(data_iter)

        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output, y)

        # Divergence guard: stop recording once the loss blows up.
        if torch.isnan(loss) or loss.item() > best_loss * 4:
            print(f"Loss explosion detected at lr={optimizer.param_groups[0]['lr']:.2e}")
            break

        if loss.item() < best_loss:
            best_loss = loss.item()

        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

        # Geometric LR schedule: scale every parameter group's lr each step.
        for pg in optimizer.param_groups:
            pg['lr'] *= lr_multiplier

    plt.figure(figsize=(10, 4))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('LR Range Test')
    plt.grid(True)
    plt.savefig('lr_range_test.png')
    plt.show()

    return lrs, losses


def safe_training_step(model, X, y, optimizer, loss_fn, scaler=None):
    """
    Run one training step, skipping the parameter update entirely when the
    input batch or the resulting loss contains NaN/Inf values.

    When `scaler` is given, uses AMP (autocast + GradScaler); gradients are
    clipped to norm 1.0 in both paths.  Returns the scalar loss on success,
    or None when the step was skipped.
    """
    optimizer.zero_grad()

    # Bail out before the forward pass if the batch itself is corrupted.
    if not torch.isfinite(X).all():
        print("Warning: NaN/Inf in input, skipping step")
        return None

    use_amp = scaler is not None

    # Forward pass (under autocast when AMP is active).
    if use_amp:
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(X), y)
    else:
        loss = loss_fn(model(X), y)

    # A non-finite loss means this step would poison the weights — skip it.
    if torch.isnan(loss) or torch.isinf(loss):
        print(f"Warning: Loss is {loss.item()}, skipping step")
        return None

    if use_amp:
        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm is in true gradient units.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return loss.item()

Preventing log(0) Computation

In cross-entropy loss or log-based loss functions, log(0) returns -Inf, causing NaN.

# Wrong: log(0) possible — if any predicted probability is exactly 0,
# torch.log returns -inf and the loss becomes NaN.
def bad_cross_entropy(pred, target):
    # Intentionally broken example for illustration; do not use in practice.
    return -torch.sum(target * torch.log(pred))

# Correct: clamp predictions so log() never receives exactly 0 or 1
def safe_cross_entropy(pred, target, eps=1e-8):
    """Cross-entropy with predictions clamped to [eps, 1-eps] to avoid log(0) -> -inf."""
    clipped = pred.clamp(eps, 1 - eps)
    return -(target * clipped.log()).sum()

# Best: use PyTorch built-ins (they apply the log-sum-exp trick internally,
# so no manual clamping is needed)
loss_fn = nn.CrossEntropyLoss()  # numerically stable; expects raw logits, not probabilities
log_softmax = nn.LogSoftmax(dim=1)  # combined log + softmax computed in one stable op

def numerically_stable_log_loss(logits, targets):
    """Cross-entropy computed directly from raw logits — the internal
    log-sum-exp formulation sidesteps the log(0) problem entirely."""
    from torch.nn import functional as F
    return F.cross_entropy(logits, targets)

Using torch.autograd.set_detect_anomaly

import torch

# Enable anomaly detection mode (development/debugging only)
# It slows training noticeably, so disable it in production.
# NOTE(review): `model`, `X`, `y`, `loss_fn` are placeholders from the
# surrounding article — this snippet is illustrative, not runnable as-is.
with torch.autograd.detect_anomaly():
    output = model(X)
    loss = loss_fn(output, y)
    loss.backward()  # Prints exact location when NaN/Inf occurs

# Alternatively, switch anomaly detection on globally for the whole process.
torch.autograd.set_detect_anomaly(True)

def train_with_anomaly_detection(model, loader, optimizer, loss_fn, device, epochs=5):
    """
    Training loop wrapped in torch.autograd.detect_anomaly().

    When a NaN loss appears, prints diagnostic input/output statistics and
    aborts training entirely.  (Fix: the original `break` only left the
    batch loop, so the next epoch silently kept training the same failing
    state — `return` stops the run.)
    """
    model.train()
    for epoch in range(epochs):
        for batch_idx, (X, y) in enumerate(loader):
            X, y = X.to(device), y.to(device)

            with torch.autograd.detect_anomaly():
                optimizer.zero_grad()
                output = model(X)
                loss = loss_fn(output, y)

                if torch.isnan(loss):
                    print(f"NaN Loss at epoch {epoch}, batch {batch_idx}")
                    print(f"Input stats: mean={X.mean():.4f}, std={X.std():.4f}")
                    print(f"Output stats: mean={output.mean():.4f}, std={output.std():.4f}")
                    return  # stop training: continuing would just repeat the failure

                loss.backward()
                optimizer.step()

3. Gradient Problems

Diagnosing Vanishing Gradients

Vanishing gradients occur in deep networks when backpropagation causes gradients to become extremely small as they propagate to earlier layers.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def check_gradient_flow(model):
    """
    Visualize gradient magnitudes per layer to diagnose vanishing/exploding gradients.

    Call right after loss.backward() — gradients must be populated.  Saves a
    log-scale bar chart to gradient_flow.png, prints a warning for any layer
    whose mean |grad| falls below 1e-6, and returns
    (layer_names, mean_grads, max_grads).
    """
    ave_grads = []
    max_grads = []
    layers = []

    # Only trainable parameters that actually received a gradient.
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            layers.append(name)
            ave_grads.append(param.grad.abs().mean().item())
            max_grads.append(param.grad.abs().max().item())

    plt.figure(figsize=(12, 6))
    plt.bar(range(len(ave_grads)), ave_grads, alpha=0.5, label='Mean Gradient')
    plt.bar(range(len(max_grads)), max_grads, alpha=0.5, label='Max Gradient')
    plt.xticks(range(len(layers)), layers, rotation=90)
    plt.xlabel("Layer")
    plt.ylabel("Gradient Magnitude")
    plt.title("Gradient Flow by Layer")
    plt.legend()
    plt.yscale('log')  # gradient norms span many orders of magnitude
    plt.tight_layout()
    plt.savefig('gradient_flow.png')

    # Heuristic: a mean |grad| under 1e-6 usually means the layer learns nothing.
    for name, avg_grad in zip(layers, ave_grads):
        if avg_grad < 1e-6:
            print(f"Warning: Vanishing gradient possible in {name} (avg={avg_grad:.2e})")

    return layers, ave_grads, max_grads


def register_gradient_hooks(model):
    """
    Attach a backward hook to every trainable parameter so gradient
    statistics can be inspected live during training.

    Returns (stats_dict, hook_handles).  The dict is updated in place on
    every backward pass; call .remove() on each handle when monitoring is
    no longer needed.
    """
    stats = {}

    def _hook_for(param_name):
        def _record(grad):
            abs_grad = grad.abs()
            stats[param_name] = {
                'mean': abs_grad.mean().item(),
                'max': abs_grad.max().item(),
                'std': grad.std().item(),
                'has_nan': torch.isnan(grad).any().item(),
                'has_inf': torch.isinf(grad).any().item(),
            }
            if torch.isnan(grad).any():
                print(f"NaN gradient detected: {param_name}")
            return grad
        return _record

    handles = [
        param.register_hook(_hook_for(name))
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    return stats, handles


# Fix vanishing gradients: He initialization + BatchNorm + Residual Connection
class ResidualBlock(nn.Module):
    """Two Linear+BatchNorm layers with a skip connection: relu(F(x) + x).

    The identity shortcut gives gradients a direct path to earlier layers,
    which counteracts vanishing gradients in deep stacks.
    """

    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        shortcut = x
        out = self.block(x)
        out = out + shortcut  # residual connection
        return self.relu(out)


def init_weights(m):
    """
    He (Kaiming) initialization for ReLU networks; apply with model.apply(init_weights).

    Linear/Conv2d weights get kaiming_normal_ and zero biases; BatchNorm
    affine parameters are reset to the identity transform (gamma=1, beta=0).
    Fix: also handles BatchNorm1d — the models in this guide (ResidualBlock,
    RegularizedModel) use BatchNorm1d, which the original silently skipped.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # BatchNorm starts out as the identity transform.
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

Solving Exploding Gradients with Gradient Clipping

import torch
import torch.nn as nn


def train_with_gradient_clipping(model, loader, optimizer, loss_fn, device,
                                   max_norm=1.0, epochs=10):
    """
    Safe training loop with gradient clipping.

    Clips the global gradient norm to `max_norm` each step and records the
    pre-clipping norm.  Fix: clip_grad_norm_ already returns the total norm
    measured *before* clipping, so the original's manual per-parameter norm
    loop was redundant and is removed.

    Returns a history dict with per-epoch average loss and gradient norm.
    """
    model.train()
    history = {'train_loss': [], 'grad_norm': []}

    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_grad_norms = []

        for X, y in loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()

            # clip_grad_norm_ clips in place and returns the pre-clip total norm.
            total_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm
            ).item()
            epoch_grad_norms.append(total_norm)

            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(loader)
        avg_grad_norm = sum(epoch_grad_norms) / len(epoch_grad_norms)

        history['train_loss'].append(avg_loss)
        history['grad_norm'].append(avg_grad_norm)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Grad Norm={avg_grad_norm:.4f}")

        # Persistently huge pre-clip norms mean clipping is masking instability.
        if avg_grad_norm > max_norm * 10:
            print(f"Warning: Gradient norm very large ({avg_grad_norm:.4f}). Consider reducing learning rate.")

    return history

4. Solving Overfitting

Diagnosing Overfitting

import matplotlib.pyplot as plt
import numpy as np


def plot_learning_curves(train_losses, val_losses, train_accs=None, val_accs=None):
    """
    Diagnose overfitting using training/validation loss and accuracy curves.

    Plots the loss curves (and accuracy curves when provided), marks the
    epoch with the lowest validation loss, prints the final val-train loss
    gap, and saves the figure to learning_curves.png.
    """
    # One subplot for loss; a second one only when accuracies were provided.
    fig, axes = plt.subplots(1, 2 if train_accs else 1, figsize=(14, 5))

    # With a single subplot, plt.subplots returns a bare Axes object —
    # normalize to a list so the indexing below works in both cases.
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    axes[0].plot(train_losses, label='Train Loss', color='blue')
    axes[0].plot(val_losses, label='Val Loss', color='red', linestyle='--')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training/Validation Loss')
    axes[0].legend()
    axes[0].grid(True)

    # Mark the best epoch (lowest validation loss).
    min_val_idx = np.argmin(val_losses)
    axes[0].axvline(x=min_val_idx, color='green', linestyle=':', label=f'Best epoch: {min_val_idx}')
    axes[0].legend()

    if train_accs and val_accs:
        axes[1].plot(train_accs, label='Train Acc', color='blue')
        axes[1].plot(val_accs, label='Val Acc', color='red', linestyle='--')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Training/Validation Accuracy')
        axes[1].legend()
        axes[1].grid(True)

    # Rule of thumb used here: a final gap above 0.1 signals meaningful overfitting.
    final_gap = val_losses[-1] - train_losses[-1]
    print(f"Final overfitting gap (Val-Train Loss): {final_gap:.4f}")
    if final_gap > 0.1:
        print("Warning: Severe overfitting detected!")

    plt.tight_layout()
    plt.savefig('learning_curves.png')
    plt.show()

Implementing Early Stopping

class EarlyStopping:
    """
    Stop training once validation loss has not improved by at least
    `min_delta` for `patience` consecutive epochs.

    Call the instance once per epoch with (val_loss, model, epoch); it
    returns True when training should stop.  With restore_best=True the
    model's best-performing weights are loaded back at stop time.
    """

    def __init__(self, patience=10, min_delta=0.001, restore_best=True, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.verbose = verbose

        self.best_loss = float('inf')  # lowest validation loss seen so far
        self.best_epoch = 0            # epoch at which best_loss occurred
        self.counter = 0               # consecutive epochs without improvement
        self.best_weights = None       # deep copy of the best state_dict
        self.stopped_epoch = 0         # epoch at which training stopped

    def __call__(self, val_loss, model, epoch):
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            if self.restore_best:
                import copy
                self.best_weights = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f"Validation Loss improved: {val_loss:.6f} (epoch {epoch})")
            return False

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter < self.patience:
            return False

        # Patience exhausted: optionally roll back to the best weights.
        self.stopped_epoch = epoch
        if self.restore_best and self.best_weights:
            model.load_state_dict(self.best_weights)
            print(f"Best weights restored (epoch {self.best_epoch})")
        return True


def train_with_regularization(model, train_loader, val_loader,
                               optimizer, loss_fn, device, epochs=100):
    """
    Training loop combining several regularization techniques: gradient
    clipping, ReduceLROnPlateau learning-rate decay, and early stopping on
    the validation loss (best weights restored when stopping).

    Returns (train_losses, val_losses), one entry per completed epoch.
    """
    early_stopping = EarlyStopping(patience=15, min_delta=0.001)

    # Fix: the deprecated `verbose=True` argument was removed — recent
    # PyTorch releases dropped it from ReduceLROnPlateau, so passing it
    # raises a TypeError.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        # --- training phase ---
        model.train()
        train_loss = 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_fn(output, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        # --- validation phase (no grad) ---
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X, y in val_loader:
                X, y = X.to(device), y.to(device)
                output = model(X)
                val_loss += loss_fn(output, y).item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # The scheduler halves the LR after 5 epochs without val-loss improvement.
        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")

        if early_stopping(val_loss, model, epoch):
            print(f"Early stopping at epoch {epoch+1}")
            break

    return train_losses, val_losses


# Dropout + L2 Regularization example
class RegularizedModel(nn.Module):
    """MLP with BatchNorm + Dropout after every hidden layer to curb overfitting."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.3):
        super().__init__()
        # Two hidden blocks (Linear -> BatchNorm -> ReLU -> Dropout)
        # followed by a plain linear output layer.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# L2 regularization via weight_decay in optimizer.
# AdamW decouples the weight-decay term from the adaptive gradient update,
# which generally regularizes better than classic Adam's built-in L2 penalty.
# NOTE(review): `model` refers to a model defined elsewhere in the article.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4
)

Data Augmentation Strategies

import torchvision.transforms as transforms


# Training-time augmentation: random geometric and photometric perturbations
# inflate the effective dataset size and regularize the model.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    # ImageNet channel statistics — match whatever pretrained backbone you use.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation: deterministic resize + center crop only, so the metric is
# stable across evaluations (no randomness).
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def mixup_data(x, y, alpha=0.2, device='cuda'):
    """
    Mixup augmentation: blend each sample with a randomly chosen partner
    from the same batch.

    Returns (mixed_x, y_a, y_b, lam); the effective label is the same
    convex combination, computed by mixup_criterion.
    """
    # Mixing coefficient drawn from Beta(alpha, alpha); alpha <= 0 disables mixup.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    n = x.size(0)
    perm = torch.randperm(n).to(device)

    mixed_x = lam * x + (1 - lam) * x[perm, :]
    return mixed_x, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for mixup batches: the lam-weighted convex combination of the losses
    against the two original targets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

5. Training Speed Issues

Resolving Data Loading Bottlenecks

import torch
from torch.utils.data import DataLoader
import time


def profile_dataloader(dataset, batch_size=32, num_workers_list=None):
    """
    Time the first 10 batches for each candidate num_workers value and
    report the fastest setting.

    Args:
        dataset: any map-style Dataset.
        batch_size: batch size used for every trial.
        num_workers_list: worker counts to try; defaults to [0, 2, 4, 8].
            (Fix: the default is now None instead of a mutable list
            literal, which would be shared across calls.)

    Returns:
        dict mapping num_workers -> elapsed seconds for 10 batches.
    """
    if num_workers_list is None:
        num_workers_list = [0, 2, 4, 8]

    results = {}

    for num_workers in num_workers_list:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            # prefetch_factor / persistent_workers are only valid with workers.
            prefetch_factor=2 if num_workers > 0 else None,
            persistent_workers=num_workers > 0,
        )

        start = time.time()
        for i, batch in enumerate(loader):
            if i >= 10:
                break
        elapsed = time.time() - start

        results[num_workers] = elapsed
        print(f"num_workers={num_workers}: {elapsed:.3f}s (10 batches)")

    best_workers = min(results, key=results.get)
    print(f"\nOptimal num_workers: {best_workers}")
    return results


def create_optimized_dataloader(dataset, batch_size, is_train=True):
    """DataLoader preset tuned for GPU training throughput.

    Training loaders shuffle and drop the last partial batch; validation
    loaders keep order and all samples.  Workers stay alive between epochs
    and batches are prefetched into pinned host memory.
    """
    common = dict(
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=True,
    )
    return DataLoader(dataset, shuffle=is_train, drop_last=is_train, **common)

Mixed Precision Training

import torch
from torch.cuda.amp import autocast, GradScaler


def train_mixed_precision(model, loader, optimizer, loss_fn, device, epochs=10):
    """
    FP16 mixed precision training for 2-3x speedup on tensor-core GPUs.

    Fix: the original called `autocast(device_type='cuda', ...)` on the
    `torch.cuda.amp.autocast` import, which does not accept a
    `device_type` argument and raises TypeError.  `torch.amp.autocast`
    is the API that takes device_type.
    """
    scaler = GradScaler()
    model.train()

    for epoch in range(epochs):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()

            # Forward pass runs in float16 where safe; autograd records the casts.
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                output = model(X)
                loss = loss_fn(output, y)

            # Scale the loss so tiny fp16 gradients don't underflow, then
            # unscale before clipping so max_norm is in true gradient units.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

        print(f"Epoch {epoch+1} complete, scaler scale: {scaler.get_scale()}")

Applying torch.compile

import torch

# NOTE(review): `MyModel` and `device` are placeholders from the article —
# this snippet is illustrative, not runnable as-is.
model = MyModel().to(device)

# Default compilation: solid speedup, fast compile.
compiled_model = torch.compile(model)

# Maximum performance mode (longer compile time, more autotuning).
compiled_model = torch.compile(model, mode='max-autotune')

# For frequently changing input sizes (avoids per-shape recompilation).
compiled_model = torch.compile(model, dynamic=True)


def benchmark_model(model, inputs, n_iters=100):
    """
    Average forward-pass latency in seconds over `n_iters` iterations.

    Requires a CUDA device: torch.cuda.synchronize() is used to make
    wall-clock timing honest, since CUDA kernel launches are asynchronous.
    NOTE(review): runs without torch.no_grad(), so autograd state is built
    each pass — fine for relative comparisons, but consider no_grad for
    pure inference benchmarks.
    """
    # Warmup: lets cuDNN autotune and torch.compile finish compiling
    # so the timed loop measures steady-state latency.
    for _ in range(10):
        _ = model(inputs)

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        _ = model(inputs)
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    elapsed = time.time() - start

    return elapsed / n_iters

6. Out-of-Memory (OOM) Solutions

GPU Memory Analysis

import torch
import gc


def print_gpu_memory_summary(device=0):
    """
    Print detailed GPU memory usage for `device`: total vs reserved vs
    allocated (cached = reserved - allocated), followed by PyTorch's full
    allocator summary.  Prints a notice and returns early without CUDA.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)

    print(f"=== GPU {device} Memory Summary ===")
    print(f"Total memory: {total / 1e9:.2f} GB")
    print(f"Reserved memory: {reserved / 1e9:.2f} GB")
    print(f"Allocated memory: {allocated / 1e9:.2f} GB")
    print(f"Cached memory: {(reserved - allocated) / 1e9:.2f} GB")
    print()
    print(torch.cuda.memory_summary(device=device, abbreviated=False))


def clear_gpu_memory():
    """
    Clear GPU memory cache.

    gc.collect() first drops Python references to dead tensors, then
    empty_cache() hands the freed blocks in PyTorch's caching allocator
    back to CUDA.  NOTE(review): synchronize() initializes/requires CUDA,
    so this raises on CPU-only machines.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()  # let pending kernels finish before reporting
    print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

Implementing Gradient Checkpointing

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential


class MemoryEfficientModel(nn.Module):
    """
    Memory-efficient model using Gradient Checkpointing: activations inside
    checkpointed segments are not stored and are recomputed during the
    backward pass, trading compute time for memory.

    Fix: pass use_reentrant=False explicitly — recent PyTorch warns (and is
    moving to require it) when the argument is omitted, and the
    non-reentrant implementation is the recommended one.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # checkpoint_sequential splits the stack into `segments` chunks and
        # checkpoints each chunk (more segments = more memory saved, slower).
        return checkpoint_sequential(self.layers, segments=4, input=x,
                                     use_reentrant=False)

    def forward_with_manual_checkpoints(self, x):
        # Keep the first and last layers un-checkpointed so input/output
        # activations remain available without recomputation.
        x = self.layers[0](x)
        for layer in self.layers[1:-1]:
            x = checkpoint(layer, x, use_reentrant=False)
        x = self.layers[-1](x)
        return x


# Enable Gradient Checkpointing in Transformers
# NOTE(review): requires the third-party `transformers` package and downloads
# pretrained GPT-2 weights on first run.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable()  # Simple activation for Hugging Face models
def find_optimal_batch_size(model, loss_fn, device,
                              start_batch=8, max_batch=512,
                              input_shape=(3, 224, 224), num_classes=1000):
    """
    Probe increasing batch sizes (doubling from start_batch) to find the
    largest one that survives a forward+backward pass without OOM.

    Args:
        input_shape: per-sample input shape (generalized from the
            hard-coded ImageNet-style 3x224x224; default unchanged).
        num_classes: exclusive upper bound for the dummy integer targets.

    Returns the largest batch size that succeeded.
    """
    batch_size = start_batch
    optimal_batch_size = start_batch

    while batch_size <= max_batch:
        try:
            dummy_input = torch.randn(batch_size, *input_shape).to(device)
            dummy_target = torch.randint(0, num_classes, (batch_size,)).to(device)

            output = model(dummy_input)
            loss = loss_fn(output, dummy_target)
            loss.backward()

            optimal_batch_size = batch_size
            print(f"Batch size {batch_size}: Success")
            batch_size *= 2

            # Fix: also clear accumulated gradients — the original left the
            # .grad buffers from every probe in place, inflating memory and
            # risking a premature OOM on the next, larger attempt.
            model.zero_grad(set_to_none=True)
            del dummy_input, dummy_target, output, loss
            torch.cuda.empty_cache()

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"Batch size {batch_size}: OOM")
                model.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
                break
            else:
                raise

    print(f"\nRecommended batch size: {optimal_batch_size} (with safety margin: {optimal_batch_size // 2})")
    return optimal_batch_size

7. Data Pipeline Debugging

Data Sample Visualization

import matplotlib.pyplot as plt
import numpy as np
import torch
from collections import Counter


def visualize_batch(loader, num_samples=16, class_names=None):
    """
    Visualize one batch of image samples to verify preprocessing by eye.

    Un-normalizes with ImageNet statistics, draws a fixed 4x4 grid (so at
    most 16 images are shown regardless of num_samples) and saves the
    figure to data_samples.png.
    """
    X, y = next(iter(loader))

    # Fixed 4x4 grid — note num_samples values above 16 would overflow `axes`.
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    axes = axes.flatten()

    for i in range(min(num_samples, len(X))):
        img = X[i].numpy()

        # Reverse ImageNet normalization (assumes the loader applied it —
        # images preprocessed with other statistics will look wrong here).
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = std[:, None, None] * img + mean[:, None, None]
        img = np.clip(img, 0, 1)

        axes[i].imshow(img.transpose(1, 2, 0))  # CHW -> HWC for matplotlib
        label = y[i].item()
        title = class_names[label] if class_names else f"Label: {label}"
        axes[i].set_title(title)
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('data_samples.png')
    plt.show()


def check_label_distribution(dataset):
    """
    Print a text histogram of class labels and warn when the
    majority/minority ratio exceeds 10:1.  Returns the label Counter.

    Note: iterates the whole dataset once, so this can be slow for
    datasets with expensive __getitem__.
    """
    counter = Counter(dataset[i][1] for i in range(len(dataset)))

    classes = sorted(counter)
    counts = [counter[c] for c in classes]
    total = sum(counts)

    print("Label Distribution:")
    for cls, count in zip(classes, counts):
        pct = count / total * 100
        bar = '#' * int(pct / 2)  # one '#' per 2% of the dataset
        print(f"  Class {cls}: {count:5d} ({pct:.1f}%) {bar}")

    imbalance_ratio = max(counts) / min(counts)
    if imbalance_ratio > 10:
        print(f"\nWarning: Severe class imbalance! (ratio: {imbalance_ratio:.1f}:1)")
        print("Solution: Consider weighted sampling or class-weighted loss function.")

    return counter


def create_weighted_sampler(dataset):
    """
    Build a WeightedRandomSampler whose per-sample weight is the inverse of
    its class frequency, so minority classes are drawn as often as the
    majority class on average.
    """
    labels = [dataset[i][1] for i in range(len(dataset))]
    freq = Counter(labels)

    # Inverse-frequency weight per sample.
    sample_weights = torch.DoubleTensor([1.0 / freq[lbl] for lbl in labels])

    return torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

8. Model Architecture Debugging

Model Structure Analysis with torchinfo

from torchinfo import summary
import torch
import torch.nn as nn


def analyze_model(model, input_size):
    """
    Analyze model structure and identify bottleneck layers.

    Uses torchinfo.summary to print per-layer input/output sizes, parameter
    counts and multiply-adds, then lists parameter counts per named module.
    Returns the torchinfo statistics object.
    """
    model_stats = summary(
        model,
        input_size=input_size,
        col_names=["input_size", "output_size", "num_params", "kernel_size",
                   "mult_adds"],
        verbose=1
    )

    print("\nParameter distribution by layer:")
    for name, module in model.named_modules():
        # recurse=False counts only this module's own parameters, so
        # container modules don't double-count their children.
        num_params = sum(p.numel() for p in module.parameters(recurse=False))
        if num_params > 0:
            print(f"  {name}: {num_params:,} parameters")

    return model_stats


def monitor_activations(model, X):
    """
    Run one forward pass and collect the outputs of every
    ReLU/GELU/Tanh/Sigmoid module, reporting mean/std and the fraction of
    exactly-zero ("dead") activations per layer.

    Returns {module_name: detached activation tensor}.
    """
    activations = {}

    def _capture(name):
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    handles = [
        module.register_forward_hook(_capture(name))
        for name, module in model.named_modules()
        if isinstance(module, (nn.ReLU, nn.GELU, nn.Tanh, nn.Sigmoid))
    ]

    with torch.no_grad():
        model(X)

    print("\nActivation statistics:")
    for name, act in activations.items():
        dead_ratio = (act == 0).float().mean().item()
        print(f"  {name}:")
        print(f"    mean: {act.mean():.4f}, std: {act.std():.4f}")
        print(f"    dead neuron ratio: {dead_ratio:.2%}")
        if dead_ratio > 0.5:
            print(f"    Warning: {dead_ratio:.0%} of neurons are inactive!")

    # Detach the hooks so later forward passes aren't monitored.
    for handle in handles:
        handle.remove()

    return activations


def visualize_weight_distribution(model):
    """
    Visualize weight distributions to detect initialization issues.

    Plots histograms for the first six Linear/Conv2d layers and saves the
    figure to weight_distribution.png.  Healthy initializations look
    roughly zero-centered with layer-appropriate spread.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    linear_layers = [(name, m) for name, m in model.named_modules()
                     if isinstance(m, (nn.Linear, nn.Conv2d))]

    # Only the first 6 layers fit the 2x3 grid.
    for i, (name, layer) in enumerate(linear_layers[:6]):
        if i >= len(axes):
            break
        weight_data = layer.weight.data.cpu().numpy().flatten()
        axes[i].hist(weight_data, bins=50, color='blue', alpha=0.7)
        axes[i].set_title(f"{name}\n(mean={weight_data.mean():.4f}, std={weight_data.std():.4f})")
        axes[i].set_xlabel('Weight Value')
        axes[i].set_ylabel('Frequency')
        axes[i].grid(True, alpha=0.3)

    plt.suptitle("Weight Distribution by Layer")
    plt.tight_layout()
    plt.savefig('weight_distribution.png')
    plt.show()

9. Training Monitoring

TensorBoard Usage

from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np


class TensorBoardLogger:
    """Thin wrapper around SummaryWriter for scalar, histogram and LR logging."""

    def __init__(self, log_dir='runs/experiment'):
        self.writer = SummaryWriter(log_dir)

    def log_scalars(self, metrics: dict, epoch: int):
        # One scalar curve per metric name.
        for key in metrics:
            self.writer.add_scalar(key, metrics[key], epoch)

    def log_model_gradients(self, model, epoch: int):
        # Histograms of both gradients and weights for every parameter
        # that received a gradient.
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            self.writer.add_histogram(f'gradients/{name}', param.grad, epoch)
            self.writer.add_histogram(f'weights/{name}', param.data, epoch)

    def log_learning_rate(self, optimizer, epoch: int):
        for idx, group in enumerate(optimizer.param_groups):
            self.writer.add_scalar(f'lr/group_{idx}', group['lr'], epoch)

    def close(self):
        self.writer.close()


def train_with_tensorboard(model, train_loader, val_loader,
                             optimizer, loss_fn, device, epochs=50):
    """Train `model`, streaming losses, accuracies, gradient histograms,
    and learning rates to TensorBoard every epoch."""
    logger = TensorBoardLogger(log_dir='runs/debug_session')

    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)

    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        running_loss, running_correct = 0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            preds = model(inputs)
            batch_loss = loss_fn(preds, targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            running_correct += (preds.argmax(1) == targets).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = running_correct / n_train

        # --- validation pass (no gradients) ---
        model.eval()
        running_loss, running_correct = 0, 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                preds = model(inputs)
                running_loss += loss_fn(preds, targets).item()
                running_correct += (preds.argmax(1) == targets).sum().item()

        val_loss = running_loss / len(val_loader)
        val_acc = running_correct / n_val

        logger.log_scalars({
            'Loss/Train': train_loss,
            'Loss/Val': val_loss,
            'Accuracy/Train': train_acc,
            'Accuracy/Val': val_acc,
        }, epoch)
        logger.log_model_gradients(model, epoch)
        logger.log_learning_rate(optimizer, epoch)

    logger.close()
    print("TensorBoard: run with 'tensorboard --logdir=runs'")

10. Reproducibility

import random
import numpy as np
import torch
import os


def set_seed(seed: int = 42):
    """
    Fix all random number generators for complete reproducibility

    Seeds Python, NumPy, and PyTorch (CPU and all CUDA devices), forces
    deterministic cuDNN kernels, and enables PyTorch's deterministic-
    algorithm enforcement.

    Args:
        seed: Seed value applied to every generator.
    """
    # Environment variables must be in place BEFORE deterministic CUDA
    # kernels (cuBLAS) are first used, so set them up front rather than
    # after enabling deterministic algorithms.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)       # no-op when CUDA is unavailable
    torch.cuda.manual_seed_all(seed)   # Multi-GPU

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Raise (rather than silently diverge) if an op has no deterministic
    # implementation.
    torch.use_deterministic_algorithms(True)

    print(f"All random generators initialized with seed {seed}")


def save_experiment_config(config: dict, save_path: str = 'experiment_config.json'):
    """
    Record the complete experiment environment

    Augments `config` with library versions, hardware info, and the
    current git commit, then writes everything to `save_path` as JSON.

    Args:
        config: Experiment hyperparameters / settings to record.
        save_path: Destination JSON file path.

    Returns:
        The augmented config dict that was written to disk.
    """
    import json
    import subprocess

    full_config = config.copy()

    full_config['environment'] = {
        'python': subprocess.getoutput('python --version'),
        'torch': torch.__version__,
        'cuda': torch.version.cuda,
        'cudnn': str(torch.backends.cudnn.version()),
        'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    }

    # subprocess.getoutput never raises; on failure it returns the error
    # text (e.g. "fatal: not a git repository ..."), so the old try/except
    # was dead code and garbage could leak into git_hash. Validate that
    # the output actually looks like a 40-char hex commit hash.
    git_hash = subprocess.getoutput('git rev-parse HEAD').strip()
    if len(git_hash) == 40 and all(c in '0123456789abcdef' for c in git_hash):
        full_config['git_hash'] = git_hash
    else:
        full_config['git_hash'] = 'unknown'

    with open(save_path, 'w') as f:
        json.dump(full_config, f, indent=2)

    print(f"Experiment config saved: {save_path}")
    return full_config


def test_reproducibility(model_fn, train_fn, seed=42, n_runs=3):
    """
    Verify reproducibility by running multiple times with same seed

    Re-seeds, rebuilds, and retrains `n_runs` times; if the final losses
    spread by more than 1e-5 the setup is flagged as non-reproducible.

    Args:
        model_fn: Zero-argument factory returning a fresh model.
        train_fn: Callable taking a model and returning its final loss.
        seed: Seed applied before each run.
        n_runs: Number of identical runs to compare.

    Returns:
        List of the final losses from each run.
    """
    results = []

    for run_idx in range(n_runs):
        set_seed(seed)
        net = model_fn()
        final_loss = train_fn(net)
        results.append(final_loss)
        print(f"Run {run_idx+1}: Final Loss = {final_loss:.6f}")

    spread = max(results) - min(results)
    print(f"\nMax difference: {spread:.8f}")

    # Tolerance chosen for float32 training noise.
    if spread < 1e-5:
        print("Reproducibility check passed!")
    else:
        print("Warning: Reproducibility issues detected.")

    return results

11. Distributed Training Debugging

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os


def setup_distributed(rank, world_size, backend='nccl'):
    """
    Initialize distributed training environment

    Configures the single-node rendezvous address, joins the process
    group, and binds this process to its own GPU.

    Args:
        rank: This process's index in [0, world_size).
        world_size: Total number of participating processes.
        backend: Communication backend ('nccl' for GPU, 'gloo' for CPU).
    """
    # Single-node rendezvous: every process dials the same host/port.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    # One GPU per process, indexed by rank.
    torch.cuda.set_device(rank)
    print(f"Process {rank}/{world_size} initialized")


def cleanup_distributed():
    """Tear down the process group created in setup_distributed()."""
    dist.destroy_process_group()


def debug_ddp_training(rank, world_size, model, dataset):
    """
    DDP training debugging example

    One process per GPU: wraps the model in DDP, shards the dataset with a
    DistributedSampler, trains for 10 epochs, and checkpoints from rank 0.

    Args:
        rank: This process's rank (also its GPU index).
        world_size: Total number of processes.
        model: Model to replicate across ranks.
        dataset: Dataset to shard across ranks.
    """
    setup_distributed(rank, world_size)

    device = torch.device(f'cuda:{rank}')
    # find_unused_parameters=True helps while debugging models whose
    # forward pass may skip some parameters.
    ddp_model = DDP(model.to(device), device_ids=[rank],
                    find_unused_parameters=True)

    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=32, sampler=sampler, num_workers=4)

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(10):
        # Critical: update sampler seed each epoch
        sampler.set_epoch(epoch)

        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(batch_x), batch_y)
            loss.backward()
            optimizer.step()

        # Only log from rank 0
        if rank == 0:
            print(f"Epoch {epoch+1} complete")
            # .module unwraps the DDP container to save raw weights.
            torch.save(ddp_model.module.state_dict(), f'checkpoint_epoch{epoch}.pt')

        # Synchronize all ranks
        dist.barrier()

    cleanup_distributed()

12. MLflow and Experiment Management

import mlflow
import mlflow.pytorch
import torch
import optuna


def train_with_mlflow(model, train_loader, val_loader, optimizer,
                       loss_fn, device, params: dict):
    """
    Experiment tracking and model version management with MLflow

    Logs hyperparameters, per-epoch train/val losses, and the
    best-performing model checkpoint to a local MLflow tracking server.

    Args:
        model: Model to train.
        train_loader / val_loader: Train and validation DataLoaders.
        optimizer: Optimizer for `model`'s parameters.
        loss_fn: Loss function (logits, labels) -> scalar.
        device: Device to train on.
        params: Hyperparameter dict; must contain 'epochs'.

    Returns:
        Best (lowest) validation loss observed across all epochs.
    """
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_experiment("deep-learning-debug")

    with mlflow.start_run():
        mlflow.log_params(params)
        best_val_loss = float('inf')

        for epoch in range(params['epochs']):
            # --- one training epoch ---
            model.train()
            epoch_train_loss = 0
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(batch_x), batch_y)
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item()
            epoch_train_loss /= len(train_loader)

            # --- validation (no gradients) ---
            model.eval()
            epoch_val_loss = 0
            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                    epoch_val_loss += loss_fn(model(batch_x), batch_y).item()
            epoch_val_loss /= len(val_loader)

            mlflow.log_metrics({
                'train_loss': epoch_train_loss,
                'val_loss': epoch_val_loss
            }, step=epoch)

            # Keep only the best-performing checkpoint.
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                mlflow.pytorch.log_model(model, "best_model")

        mlflow.log_metric("best_val_loss", best_val_loss)

    return best_val_loss


def hyperparameter_optimization_with_optuna(model_fn, train_loader,
                                              val_loader, device, n_trials=50):
    """
    Hyperparameter optimization with Optuna

    Searches lr / weight-decay / dropout / batch-size with a seeded TPE
    sampler, training each candidate via train_with_mlflow, and reports
    the best configuration found.

    Args:
        model_fn: Factory accepting `dropout=` and returning a model.
        train_loader / val_loader: Train and validation DataLoaders.
        device: Device to train on.
        n_trials: Number of Optuna trials to run.

    Returns:
        Dict of the best hyperparameters found.
    """
    def objective(trial):
        # Sample one candidate configuration.
        hp = {
            'lr': trial.suggest_float('lr', 1e-5, 1e-1, log=True),
            'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
            'dropout': trial.suggest_float('dropout', 0.0, 0.5),
            'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64, 128]),
        }

        candidate = model_fn(dropout=hp['dropout']).to(device)
        optimizer = torch.optim.AdamW(candidate.parameters(),
                                      lr=hp['lr'], weight_decay=hp['weight_decay'])
        loss_fn = torch.nn.CrossEntropyLoss()

        return train_with_mlflow(
            candidate, train_loader, val_loader, optimizer, loss_fn, device,
            params={**hp, 'epochs': 10}
        )

    study = optuna.create_study(
        direction='minimize',
        sampler=optuna.samplers.TPESampler(seed=42),  # seeded for reproducibility
        pruner=optuna.pruners.MedianPruner()
    )

    study.optimize(objective, n_trials=n_trials)

    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")
    print(f"Best validation Loss: {study.best_value:.4f}")

    return study.best_params

Conclusion: Deep Learning Debugging Workflow

A systematic approach is critical for deep learning debugging. Follow this order to diagnose problems:

  1. Start with data: 80% of problems originate in the data. Check for NaN values, wrong labels, and normalization errors first.

  2. Start small: Before full-batch training, test whether the model can overfit a single batch.

  3. Check gradients: Verify that gradients flow correctly backward from the loss function through every layer.

  4. Use monitoring tools: Choose one of TensorBoard, W&B, or MLflow and track all experiments.

  5. Ensure reproducibility: Debugging without fixed seeds is extremely difficult. Always set seeds.

def minimum_debug_checklist(model, train_loader, device):
    """
    Minimum checklist to verify before training starts

    Runs the single-batch overfitting test: a healthy model/optimizer/data
    combination should be able to drive the loss down sharply (>10x) on
    one fixed batch within 100 optimizer steps.

    Args:
        model: Model to check (must output class logits).
        train_loader: DataLoader yielding (inputs, labels) batches.
        device: Device to run the check on.

    Returns:
        True if the single-batch loss dropped by more than 10x, else False.
    """
    print("Deep Learning Pre-Training Checklist")
    print("=" * 50)

    # 1. Single-batch overfitting test
    print("[1] Single-batch overfitting test...")
    model.train()  # ensure dropout/batchnorm are in training mode
    X, y = next(iter(train_loader))
    X, y = X.to(device), y.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    initial_loss = None
    for step in range(100):
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fn(output, y)
        if initial_loss is None:
            initial_loss = loss.item()
        loss.backward()
        optimizer.step()

    final_loss = loss.item()
    # Guard against division by zero when the batch is fit perfectly.
    overfit_ratio = initial_loss / final_loss if final_loss > 0 else float('inf')
    passed = overfit_ratio > 10

    if passed:
        print(f"  Pass: Loss decreased from {initial_loss:.4f} to {final_loss:.4f} (ratio: {overfit_ratio:.1f}x)")
    else:
        print(f"  Warning: Cannot overfit a single batch (ratio: {overfit_ratio:.1f}x)")
        print("  -> Check model capacity, learning rate, and data errors")

    print("\nChecklist complete!")
    return passed

By systematically applying the techniques covered in this guide, you can quickly diagnose and resolve most issues that arise during deep learning training. Debugging becomes faster with experience, but having the right tools and methodology is paramount.