Knowledge Distillation Complete Guide: Model Compression and Lightweight Techniques

Introduction

As deep learning models grow more powerful, they also grow larger. Models like GPT-4 or Llama 3 405B require hundreds of gigabytes of memory, making them impossible to run on mobile devices or edge hardware. Knowledge Distillation and model compression techniques are the essential tools for making these models smaller and faster while preserving as much of their accuracy as possible.

This guide covers:

  • Theoretical foundations of knowledge distillation with complete PyTorch implementations
  • Multiple distillation paradigms: Response-based, Feature-based, Relation-based
  • LLM distillation case studies (DistilBERT, generic LLM logit distillation, Distil-Whisper)
  • Structured and unstructured pruning
  • Weight sharing and Neural Architecture Search (NAS)

1. Knowledge Distillation Fundamentals

1.1 The Teacher-Student Framework

Knowledge distillation was introduced by Hinton, Vinyals, and Dean in 2015. The core idea: transfer the "knowledge" of a large Teacher model into a small Student model.

Teacher Model (large, accurate) → soft targets (probability distributions) → Student Model (small, fast)

Rather than training with hard labels (one-hot vectors), the student learns from the Teacher's soft targets — the full softmax output distribution.

For example, classifying a cat image:

  • Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)
  • Teacher soft target: [0.01, 0.05, 0.85, 0.07, 0.02]

The soft target encodes the information that "this image is a cat, but it has some resemblance to a tiger." This inter-class similarity information provides rich supervision for the student.
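
To see why, recall that the gradient of cross-entropy with respect to the logits is simply (predicted probabilities − target). With a hard target, only the correct class carries a distinct error signal; a soft target supervises every class individually. A quick standalone check:

```python
import torch
import torch.nn.functional as F

logits = torch.zeros(5, requires_grad=True)  # untrained student: uniform prediction

# Hard target: error signal concentrates on the single correct class
hard = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
(-(hard * F.log_softmax(logits, dim=-1)).sum()).backward()
print(logits.grad)  # softmax - hard = [0.2, 0.2, -0.8, 0.2, 0.2]

logits.grad = None

# Soft target: every class contributes its own error signal
soft = torch.tensor([0.01, 0.05, 0.85, 0.07, 0.02])
(-(soft * F.log_softmax(logits, dim=-1)).sum()).backward()
print(logits.grad)  # softmax - soft = [0.19, 0.15, -0.65, 0.13, 0.18]
```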

1.2 The Temperature Parameter

When softmax outputs are too confident ([0.001, 0.002, 0.997, ...]), they carry almost no more information than hard labels. The temperature parameter T makes distributions softer:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
  • When T is 1: standard softmax
  • When T is greater than 1: flatter, more uniform distribution (more information)
  • When T is less than 1: sharper, more peaked distribution
import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """Temperature-scaled softmax."""
    return F.softmax(logits / temperature, dim=-1)

# Example
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# [0.575 0.211 0.086 0.128]
print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# [0.324 0.252 0.201 0.223] — more uniform, more information

1.3 Hinton's KD Loss Function

The KD loss is a weighted sum of two terms:

KL Divergence term (soft target matching):

L_KD = T^2 * KL(softmax_T(teacher_logits) || softmax_T(student_logits))

The T^2 scaling keeps gradient magnitudes independent of the temperature value.

Cross-Entropy term (hard target learning):

L_CE = CrossEntropy(student_logits, true_labels)

Total loss:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 Complete PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton et al. (2015) Knowledge Distillation Loss.
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # Hard target loss
        loss_ce = self.ce_loss(student_logits, labels)

        # Soft target loss (KL Divergence)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """Teacher-Student distillation training loop."""
    teacher = teacher.to(device).eval()  # Teacher is frozen
    student = student.to(device)

    # Freeze teacher parameters
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher inference (no gradient needed)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # Student inference
            student_logits = student(images)

            # KD loss computation
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# Practical example
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# Parameter count comparison
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")   # ~23.5M (with the 10-class head)
print(f"Student: {s_params:,} params")   # ~11.2M
print(f"Compression ratio: {t_params/s_params:.1f}x")
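
The same recipe works end-to-end on toy data. A minimal smoke test with small MLPs (synthetic inputs and a randomly initialized "teacher", so no claims about real accuracy — the point is only that the combined loss is trainable):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 5))
student = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 5))

T, alpha = 4.0, 0.5
opt = torch.optim.Adam(student.parameters(), lr=1e-2)
x = torch.randn(256, 10)
labels = teacher(x).argmax(dim=1).detach()  # use the teacher's predictions as hard labels

losses = []
for step in range(200):
    with torch.no_grad():
        t_logits = teacher(x)
    s_logits = student(x)
    loss_ce = F.cross_entropy(s_logits, labels)
    loss_kd = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                       F.softmax(t_logits / T, dim=-1),
                       reduction='batchmean') * (T ** 2)
    loss = alpha * loss_ce + (1 - alpha) * loss_kd
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")  # combined KD loss decreases
```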

2. Distillation Paradigms

2.1 Response-based Distillation (Logit Matching)

The most basic approach — the student mimics the Teacher's final outputs (logits). Hinton's original KD is the canonical example.

class ResponseBasedDistillation(nn.Module):
    """Response-based distillation using only final outputs."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 Feature-based Distillation (Intermediate Layers)

Proposed in FitNets (Romero et al., 2015). The student learns to match the Teacher's intermediate feature maps, not just the final outputs.

Since the teacher and student may have different channel counts, a Regressor network projects the student features to match the teacher's dimensions.

class FeatureDistillationLoss(nn.Module):
    """Distillation via intermediate layer feature matching."""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # Project student features to teacher feature space
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        projected = self.regressor(student_feat)
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    Use forward hooks to extract intermediate layer features
    without modifying the model's forward() method.
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# Usage example
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)

# Teacher layer3: 1024 channels; Student layer3: 256 channels
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)

# During training
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 Relation-based Distillation

RKD (Park et al., 2019). Rather than matching absolute values, the student learns to mimic the relational structure between samples in the teacher's embedding space.

class RelationalKnowledgeDistillation(nn.Module):
    """
    Relational KD: preserve pairwise distance and angle relationships
    between samples in the embedding space.
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """Pairwise distance matrix computation."""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """Preserve pairwise distance relationships."""
        t_d = self.pdist(teacher_emb)
        t_d = t_d / (t_d.mean() + 1e-12)  # normalize
        s_d = self.pdist(student_emb)
        s_d = s_d / (s_d.mean() + 1e-12)
        return F.smooth_l1_loss(s_d, t_d.detach())

    def angle_loss(self, teacher_emb, student_emb):
        """Preserve angular relationships between sample triplets."""
        # Unit difference vectors between all sample pairs: (N, N, D)
        td = F.normalize(teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1), p=2, dim=2)
        sd = F.normalize(student_emb.unsqueeze(0) - student_emb.unsqueeze(1), p=2, dim=2)

        # Cosine of the angle formed at each sample by every pair of others: (N, N, N)
        t_angle = torch.bmm(td, td.transpose(1, 2)).view(-1)
        s_angle = torch.bmm(sd, sd.transpose(1, 2)).view(-1)

        return F.smooth_l1_loss(s_angle, t_angle.detach())

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss
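
Because the distance term compares mean-normalized relations rather than raw values, a student embedding that is just a uniformly scaled copy of the teacher's incurs essentially zero distance loss. A standalone check with the pdist computation inlined:

```python
import torch
import torch.nn.functional as F

def pdist(e, eps=1e-12):
    """Pairwise Euclidean distance matrix, as in RelationalKnowledgeDistillation."""
    sq = (e ** 2).sum(dim=1)
    return (sq.unsqueeze(1) + sq.unsqueeze(0) - 2 * e @ e.t()).clamp(min=eps).sqrt()

torch.manual_seed(0)
t = torch.randn(16, 64)   # teacher embeddings
s = 3.0 * t               # student: same relational structure, 3x the scale

t_d = pdist(t); t_d = t_d / t_d.mean()
s_d = pdist(s); s_d = s_d / s_d.mean()
loss = F.smooth_l1_loss(s_d, t_d)
print(f"{loss.item():.2e}")  # ~0: normalized distance structure is identical
```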

2.4 Attention Transfer

Zagoruyko and Komodakis (2017). Transfer attention maps — spatial activation patterns — from teacher to student.

class AttentionTransfer(nn.Module):
    """
    Attention Transfer: transfer spatial attention patterns
    derived from feature map activations.
    """

    def attention_map(self, feat):
        """
        Compute spatial attention map from feature maps.
        Sums squared activations across channels and normalizes.
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: lists of feature maps at multiple layers.
        Returns summed AT loss across all pairs.
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss
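
Unlike feature-based distillation, no regressor is needed here: squaring and averaging over channels collapses the channel dimension entirely, so teacher and student only have to agree spatially (spatial sizes must still match). A standalone shape check, with channel counts chosen to mirror ResNet50 vs ResNet18:

```python
import torch
import torch.nn.functional as F

def attention_map(feat):
    """(B, C, H, W) -> (B, H*W): channel-averaged squared activations, L2-normalized."""
    return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

t_feat = torch.randn(4, 1024, 14, 14)  # teacher feature map (1024 channels)
s_feat = torch.randn(4, 256, 14, 14)   # student feature map (256 channels)

t_attn = attention_map(t_feat)
s_attn = attention_map(s_feat)
print(t_attn.shape, s_attn.shape)  # both torch.Size([4, 196]): channels dropped out

at_loss = (s_attn - t_attn).pow(2).mean()
```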

3. LLM Distillation

3.1 DistilBERT

DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) to 66M parameters.

Key techniques:

  • Half the layers: 12 → 6
  • Remove token-type embeddings
  • Remove the pooler
  • Triple loss: MLM + Distillation + Cosine embedding loss
from transformers import (
    DistilBertModel,
    DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilBERTLoss(nn.Module):
    """
    DistilBERT triple loss:
    1. MLM CE loss (language modeling)
    2. Soft-target KD loss (teacher logit matching)
    3. Cosine embedding loss (hidden state similarity)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM weight
        self.beta = beta    # Cosine loss weight

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        # 1. MLM loss
        loss_mlm = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            mlm_labels.view(-1),
            ignore_index=-100
        )

        # 2. Soft KD loss
        s_soft = F.log_softmax(student_logits / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. Cosine embedding loss (hidden state alignment)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# Inference example
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])

DistilBERT performance:

  • 40% smaller than BERT-base
  • 60% faster inference
  • Retains 97% of BERT's GLUE performance

3.2 LLM Distillation Pattern

For large language models, the student learns from the teacher's next-token prediction distributions.

class LLMDistillationTrainer:
    """
    LLM distillation trainer.
    Transfers teacher's token distribution to student.
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher inference (no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # Student inference
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE loss (hard target)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD loss (soft target)
        # Compute only on valid tokens
        mask = (shift_labels != -100).float()

        s_log_soft = F.log_softmax(
            shift_logits / self.T, dim=-1
        )
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # Per-token KL divergence
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len-1)

        loss_kd = (kl_per_token * mask).sum() / mask.sum()
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

3.3 Distil Whisper

Distil-Whisper (Hugging Face, 2023) is a distilled version of OpenAI's Whisper speech recognition model and demonstrates aggressive encoder-decoder compression.

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration
)

# Distil-Whisper: distills Whisper-large-v2
# large-v2: 1550M params -> distil-large-v2: 756M params (~6x faster, within ~1% WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# Only 2 decoder layers retained (original: 32)
# The 32-layer encoder is kept intact
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. Structured Pruning

Pruning removes unimportant parameters or structures from a model.

  • Structured pruning: remove entire filters, heads, or layers → real speedup on all hardware
  • Unstructured pruning: set individual weights to zero → sparse matrices (speedup requires sparse kernels or specialized hardware)

4.1 Filter Pruning

Remove filters with small L1/L2 norms from CNN layers.

import torch
import torch.nn as nn
import numpy as np


def get_filter_importance(conv_layer, norm='l1'):
    """Compute the importance (norm) of each filter."""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)
    else:
        # Note: geometric-median criteria (FPGM) need a separate implementation
        raise ValueError(f"Unknown norm: {norm}")


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    Filter pruning: remove filters with lowest importance.
    Returns new Conv2d with reduced filter count.
    """
    importance = get_filter_importance(conv)
    n_keep = int(conv.out_channels * keep_ratio)
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices]
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices]

    return new_conv, indices
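
One step prune_conv_layer leaves out: the kept indices must be propagated downstream. The BatchNorm that follows the pruned conv and the input channels of the next conv both have to be sliced with the same indices, otherwise shapes no longer match. A sketch on a hypothetical conv → BN → conv block:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 64, 3, padding=1)
bn1 = nn.BatchNorm2d(64)
conv2 = nn.Conv2d(64, 128, 3, padding=1)

# Keep the 32 most important filters of conv1 (L1 norm)
importance = conv1.weight.data.abs().sum(dim=(1, 2, 3))
_, indices = importance.topk(32)
indices, _ = indices.sort()

new_conv1 = nn.Conv2d(3, 32, 3, padding=1)
new_conv1.weight.data = conv1.weight.data[indices]
new_conv1.bias.data = conv1.bias.data[indices]

# 1) Slice the following BatchNorm with the same indices
new_bn1 = nn.BatchNorm2d(32)
new_bn1.weight.data = bn1.weight.data[indices]
new_bn1.bias.data = bn1.bias.data[indices]
new_bn1.running_mean = bn1.running_mean[indices]
new_bn1.running_var = bn1.running_var[indices]

# 2) Slice the *input* channels of the next conv
new_conv2 = nn.Conv2d(32, 128, 3, padding=1)
new_conv2.weight.data = conv2.weight.data[:, indices]
new_conv2.bias.data = conv2.bias.data

out = new_conv2(new_bn1(new_conv1(torch.randn(1, 3, 32, 32))))
print(out.shape)  # torch.Size([1, 128, 32, 32])
```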


class StructuredPruner:
    """Structured pruning for ResNet-style models."""

    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def global_threshold_prune(self, sparsity=0.5):
        """
        Prune globally: remove the sparsity fraction of filters
        with lowest importance across all layers.
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                all_importances.append(imp)

        all_imp = torch.cat(all_importances)
        threshold = all_imp.kthvalue(max(1, int(len(all_imp) * sparsity))).values.item()

        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp >= threshold

        return masks

4.2 Attention Head Pruning

Remove unimportant attention heads from transformer models.

class AttentionHeadPruner:
    """
    Multi-Head Attention head pruning.
    Michel et al. (2019): Are Sixteen Heads Really Better than One?
    """

    def compute_head_importance(self, model, dataloader, device='cuda'):
        """Compute per-layer, per-head importance scores."""
        head_importance = {}

        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            if hasattr(outputs, 'attentions') and outputs.attentions:
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    # Importance: variance of attention weights (more focused = more important)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp

        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        Remove specified heads.
        heads_to_prune: dict mapping layer_idx to list of head indices.
        """
        model.prune_heads(heads_to_prune)  # Built-in HuggingFace method
        return model


# HuggingFace example
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Remove heads 0, 3, 5 from layer 0 and heads 1, 2 from layer 6
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")

4.3 Layer Pruning

Remove entire transformer layers that contribute least to model performance.

class LayerPruner:
    """Transformer layer pruning."""

    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        Gradient-based layer importance.
        Importance = mean |gradient * weight| across the layer.
        """
        model.train()
        layer_importance = {}

        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            for i, layer in enumerate(model.encoder.layer):  # assumes a BERT-style encoder.layer
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()

                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp

            model.zero_grad()

        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """Remove specified layers and rebuild the encoder."""
        import copy
        new_model = copy.deepcopy(model)

        remaining = [
            layer for i, layer in enumerate(new_model.encoder.layer)
            if i not in layers_to_drop
        ]
        new_model.encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)

        return new_model
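
drop_layers above assumes a HuggingFace-style model.encoder.layer; the same ModuleList surgery works on any model exposing that structure. A standalone version on a toy encoder (the TinyEncoder class is purely illustrative):

```python
import copy
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Toy stand-in for a transformer encoder exposing a .layer ModuleList."""
    def __init__(self, num_layers=6, dim=32):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layer:
            x = torch.relu(layer(x))
        return x

encoder = TinyEncoder()
layers_to_drop = {1, 4}

pruned = copy.deepcopy(encoder)
pruned.layer = nn.ModuleList(
    layer for i, layer in enumerate(pruned.layer) if i not in layers_to_drop
)

print(len(encoder.layer), "->", len(pruned.layer))  # 6 -> 4
out = pruned(torch.randn(2, 32))  # forward still works with fewer layers
```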

4.4 PyTorch torch.nn.utils.prune

import torch.nn.utils.prune as prune

model = models.resnet18(weights='DEFAULT')

# Apply unstructured L1 pruning to a specific layer
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# Sets 30% of weights to zero using a mask

print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# Structured pruning (filter-level)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# Make pruning permanent (remove mask, bake zeros into weights)
prune.remove(conv, 'weight')

# Apply to all Conv2d layers globally
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of all conv weights
)

# Check global sparsity
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")

5. Unstructured Pruning

5.1 Magnitude-based Pruning

The simplest method: set weights with the smallest absolute values to zero.

class MagnitudePruner:
    """Magnitude-based unstructured pruning."""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """Compute the global threshold across all weight tensors."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        all_weights = torch.cat(all_weights)
        threshold = all_weights.kthvalue(
            max(1, int(len(all_weights) * self.sparsity))
        ).values.item()
        return threshold

    def apply_pruning(self):
        """Create pruning masks using the global threshold."""
        threshold = self.compute_global_threshold()

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() >= threshold).float()
                self.masks[name] = mask
                param.data *= mask

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"Actual sparsity: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """Re-apply masks after each gradient step to prevent dead weights from reviving."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]
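
Putting apply_pruning and apply_masks together: in a fine-tuning loop the mask must be re-applied after every optimizer step, or gradient updates (and weight decay) will resurrect the pruned weights. A minimal standalone version on a single linear layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
w = model.weight

# 50% magnitude pruning: zero out the half of the weights with smallest |w|
threshold = w.data.abs().flatten().kthvalue(w.numel() // 2).values
mask = (w.data.abs() > threshold).float()
w.data *= mask

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
for _ in range(5):
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = (model(x) - y).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    w.data *= mask  # without this line, pruned weights become nonzero again

print(f"Sparsity after fine-tuning: {(w == 0).float().mean():.2f}")  # stays at 0.50
```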

5.2 Gradual Magnitude Pruning

Pruning too aggressively from the start causes a sharp accuracy drop. A gradual schedule that slowly increases sparsity during training works much better.

class GradualMagnitudePruner:
    """
    Gradually increase sparsity during training.
    Zhu and Gupta (2017): To Prune, or Not to Prune.
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """Compute target sparsity at current step using cubic schedule."""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity

        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """Update pruning masks at appropriate training steps."""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return

        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    def _update_masks(self, sparsity):
        """Update masks to match target sparsity."""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            n_prune = int(sparsity * param.data.numel())
            if n_prune == 0:
                continue

            threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.data.abs() > threshold).float()
            param.data *= self.masks[name]
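
The cubic schedule front-loads pruning while the network still has spare capacity to recover, then flattens out. Evaluating compute_sparsity's formula standalone (same constants as the class defaults: initial 0.0, final 0.8, steps 0 to 1000):

```python
def cubic_sparsity(step, s0=0.0, sf=0.8, begin=0, end=1000):
    """Zhu-Gupta cubic sparsity schedule."""
    if step <= begin:
        return s0
    if step >= end:
        return sf
    pct_done = (step - begin) / (end - begin)
    return sf + (s0 - sf) * (1 - pct_done) ** 3

for step in [0, 250, 500, 750, 1000]:
    print(f"step {step:4d}: sparsity {cubic_sparsity(step):.4f}")
# 0: 0.0000 | 250: 0.4625 | 500: 0.7000 | 750: 0.7875 | 1000: 0.8000
```

Note that more than half the final sparsity is already reached a quarter of the way through the schedule.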

6. Weight Sharing

6.1 Cross-Layer Parameter Sharing — ALBERT

ALBERT (Lan et al., 2019) drastically reduces BERT's parameters by reusing the same parameters across all transformer layers.

class ALBERTEncoder(nn.Module):
    """
    ALBERT-style weight sharing.
    A single transformer layer is executed N times.
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # Only ONE transformer layer is defined
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # Execute the same layer num_layers times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)


# Parameter comparison
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = sum(p.numel() for p in albert_encoder.parameters())

print(f"BERT encoder: {bert_params:,}")      # ~85M
print(f"ALBERT encoder: {albert_params:,}")  # ~7M
print(f"Compression: {bert_params/albert_params:.1f}x")

6.2 Factorized Embeddings (ALBERT)

ALBERT also factorizes the embedding matrix: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size), where embedding_size is much smaller than hidden_size.

class FactorizedEmbedding(nn.Module):
    """
    ALBERT-style factorized embedding.
    vocab_size x H factored into (vocab_size x E) + (E x H)
    where E << H.
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)        # (B, seq, E)
        return self.embedding_projection(embed)         # (B, seq, H)

# Parameter comparison (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)                  # 23.0M params
factorized_embed = FactorizedEmbedding(30000, 128, 768)    # 3.94M params

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
print(f"Savings: {(s_params - f_params)/s_params:.1%}")

7. Neural Architecture Search (NAS)

7.1 Manual Design vs AutoML

Traditional CNNs were designed by researchers by hand (VGG, ResNet). NAS automates this process.

Three core components of NAS:

  1. Search Space: What operations and structures to explore
  2. Search Strategy: Reinforcement learning, evolutionary algorithms, gradient-based methods
  3. Performance Estimation: Quickly evaluate candidate architectures

7.2 DARTS (Differentiable Architecture Search)

DARTS (Liu et al., 2019) turns architecture search into a continuous optimization problem: every candidate operation stays active on each edge, weighted by learnable architecture parameters, so the search itself can be run with gradient descent.

import torch
import torch.nn as nn
import torch.nn.functional as F


# Candidate operations
PRIMITIVES = [
    'none',           # Zero connection
    'skip_connect',   # Identity
    'sep_conv_3x3',   # 3x3 Separable Conv
    'sep_conv_5x5',   # 5x5 Separable Conv
    'dil_conv_3x3',   # 3x3 Dilated Conv
    'dil_conv_5x5',   # 5x5 Dilated Conv
    'avg_pool_3x3',   # 3x3 Average Pool
    'max_pool_3x3',   # 3x3 Max Pool
]


class Zero(nn.Module):
    """'none' operation: a zero connection (all-zeros output)."""
    def __init__(self, stride=1):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        x = x[:, :, ::self.stride, ::self.stride]
        return x * 0.0


class MixedOperation(nn.Module):
    """
    DARTS: represent each edge as a weighted mixture of operations.
    Architecture parameters (alpha) control each operation's weight.
    """
    def __init__(self, C, stride, primitives):
        super().__init__()
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in primitives
        ])
        # Learnable architecture parameters
        self.arch_params = nn.Parameter(
            torch.randn(len(primitives)) * 1e-3
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            return Zero(stride)
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive.startswith('sep_conv'):
            k = int(primitive[-1])                     # 3 or 5
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k // 2,
                          groups=C, bias=False),       # depthwise
                nn.Conv2d(C, C, 1, bias=False),        # pointwise
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive.startswith('dil_conv'):
            k = int(primitive[-1])                     # 3 or 5
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k - 1,
                          dilation=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        raise ValueError(f"Unknown primitive: {primitive}")

    def forward(self, x):
        # Every operation produces the same output shape, so the edge
        # output is simply the softmax-weighted sum over all candidates.
        weights = F.softmax(self.arch_params, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))


def darts_discrete(cell):
    """
    Convert the continuous architecture to a discrete one by
    keeping the highest-weight operation on each edge.
    (Standard DARTS excludes 'none' when selecting.)
    """
    for op in cell.ops:
        weights = F.softmax(op.arch_params, dim=0)
        # Mask out 'none' so a zero connection is never chosen
        candidates = [i for i, p in enumerate(PRIMITIVES) if p != 'none']
        best_op_idx = max(candidates, key=lambda i: weights[i].item())
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")
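DARTS trains the two parameter groups alternately: network weights on the training split, architecture parameters (alpha) on a held-out validation split. Below is a minimal, self-contained sketch of this alternating loop using the first-order approximation (the paper also proposes a second-order variant). The tiny two-operation supernet and random data are illustrative only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySupernet(nn.Module):
    """Toy supernet: one mixed edge choosing between two linear ops."""
    def __init__(self):
        super().__init__()
        self.op_a = nn.Linear(8, 4)
        self.op_b = nn.Linear(8, 4)
        self.alpha = nn.Parameter(torch.zeros(2))  # architecture params
        self.head = nn.Linear(4, 3)

    def forward(self, x):
        w = F.softmax(self.alpha, dim=0)
        return self.head(w[0] * self.op_a(x) + w[1] * self.op_b(x))

model = TinySupernet()

# Split parameters into the two DARTS groups
arch_params = [model.alpha]
weight_params = [p for n, p in model.named_parameters() if n != "alpha"]
w_opt = torch.optim.SGD(weight_params, lr=0.05)
a_opt = torch.optim.Adam(arch_params, lr=3e-3)

x_train, y_train = torch.randn(16, 8), torch.randint(0, 3, (16,))
x_val, y_val = torch.randn(16, 8), torch.randint(0, 3, (16,))

for step in range(10):
    # (1) Update architecture params on VALIDATION data
    a_opt.zero_grad()
    F.cross_entropy(model(x_val), y_val).backward()
    a_opt.step()
    # (2) Update network weights on TRAINING data
    w_opt.zero_grad()
    F.cross_entropy(model(x_train), y_train).backward()
    w_opt.step()

print(F.softmax(model.alpha, dim=0))  # current operation mixture
```

Using validation data for alpha is what prevents the architecture from simply overfitting the training set alongside the weights.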

7.3 EfficientNet's NAS Process

EfficientNet-B0 is the base architecture found by an MnasNet-style NAS. The search optimized for accuracy under a mobile latency constraint. The result is a network built from MBConv blocks with squeeze-and-excitation.

class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Convolution.
    EfficientNet's core building block, found by NAS.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # Expansion phase
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])

        # Depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])
        self.expand_dw = nn.Sequential(*layers)

        # Squeeze-and-Excitation (key module discovered by NAS);
        # the reduction size is based on the block's input channels
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Output projection (linear bottleneck: no activation)
        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.expand_dw(x)
        out = out * self.se(out)   # channel-wise reweighting
        out = self.project(out)
        if self.use_residual:
            return x + out
        return out
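Once B0 is found, EfficientNet scales it up with compound scaling: depth, width, and input resolution grow together, all controlled by a single coefficient phi. The coefficients alpha=1.2, beta=1.1, gamma=1.15 are from the paper; the base values and simple rounding below are illustrative, and the officially released models use slightly different per-model resolutions:

```python
# EfficientNet compound scaling: depth d = alpha^phi,
# width w = beta^phi, resolution r = gamma^phi.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # from Tan & Le (2019)

def compound_scale(phi, base_layers=16, base_width=1.0, base_res=224):
    """Scale a base architecture by compound coefficient phi.
    Base values and rounding are simplified for illustration."""
    return {
        "layers": round(base_layers * ALPHA ** phi),
        "width_multiplier": round(base_width * BETA ** phi, 2),
        "resolution": int(base_res * GAMMA ** phi),
    }

for phi in range(4):  # roughly B0 through B3
    print(f"phi={phi}: {compound_scale(phi)}")
```

The key design choice is that one knob (phi) scales all three dimensions jointly, since scaling only one (e.g. depth alone) gives diminishing returns.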

7.4 Once-for-All Network

Once-for-All (Cai et al., 2020) trains one large network once, then extracts sub-networks tailored to different resource constraints — without retraining.

class OFALayer(nn.Module):
    """
    Once-for-All elastic layer that supports multiple kernel sizes.
    Training: randomly sample kernel size -> train sub-networks
    Deployment: use only the required kernel size
    """
    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # Only one conv is stored — the largest kernel size
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """Set the active kernel size for this layer."""
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # Extract the center sub-kernel
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]

        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))


def progressive_shrinking_train(model, dataloader, kernels=(7, 5, 3)):
    """
    Stage 1: train with max kernel (7)
    Stage 2: randomly sample from [7, 5]
    Stage 3: randomly sample from [7, 5, 3]
    """
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {active_kernels}")

        for epoch in range(5):
            for images, labels in dataloader:
                # Randomly sample kernel size each batch
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
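The point of the elastic layer is that every kernel size shares one weight tensor: the 3x3 path is literally a view of the center of the stored 7x7 kernel. (The full OFA method additionally applies a small learned transformation to the sub-kernel, omitted here.) A self-contained sketch verifying the sharing:

```python
import torch
import torch.nn.functional as F

# One 7x7 depthwise weight is stored; smaller kernels are views
# into its center, so all sizes share the same parameters.
torch.manual_seed(0)
C = 4
conv = torch.nn.Conv2d(C, C, 7, padding=3, groups=C, bias=False)

center, half = 7 // 2, 3 // 2
sub = conv.weight[:, :, center - half:center + half + 1,
                        center - half:center + half + 1]  # center 3x3

# Basic indexing returns a view: editing the full kernel's center
# also changes the 3x3 sub-kernel (shared storage).
with torch.no_grad():
    conv.weight[:, :, center, center] += 1.0
assert bool(sub[0, 0, half, half] == conv.weight[0, 0, center, center])

x = torch.randn(2, C, 16, 16)
out7 = conv(x)                                 # full 7x7 path
out3 = F.conv2d(x, sub, padding=1, groups=C)   # elastic 3x3 path
print(out7.shape, out3.shape)                  # both (2, 4, 16, 16)
```

At deployment, choosing a kernel size is therefore just a matter of which slice of the shared weight the forward pass reads — no extra parameters are stored per sub-network.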

8. Integrated Model Compression Pipeline

In practice, multiple techniques are combined sequentially:

class ModelCompressionPipeline:
    """
    Integrated pipeline: Distillation -> Pruning -> Quantization.
    """

    def __init__(self, teacher, student, num_classes):
        self.teacher = teacher
        self.num_classes = num_classes
        self.student = student

    def step1_distillation(self, train_loader,
                           epochs=30, device='cuda'):
        """Step 1: Knowledge distillation."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """Step 2: Gradual pruning + fine-tuning."""
        print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )

        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
        self.student.to(device)
        step = 0

        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pruner.step(step)
                pruner.apply_masks()
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: Post-Training Static Quantization."""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)

        # Eager-mode static quantization; assumes the student's forward
        # wraps inputs/outputs with QuantStub/DeQuantStub
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)

        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))

        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Compare Teacher vs compressed Student."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        # The quantized INT8 student runs on CPU
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = sum(p.numel() for p in self.student.parameters())

        print(f"\n{'='*55}")
        print(f"Teacher  | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student  | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"Compression: {t_params/s_params:.1f}x | "
              f"Retained accuracy: {student_acc/teacher_acc:.1%}")
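The same distill -> prune -> quantize ordering can be exercised end to end on toy models. The sketch below uses tiny MLPs, random data, an untrained teacher, one-shot median-threshold pruning, and dynamic (rather than static) quantization — all simplifications chosen purely to show the order of operations, not the pipeline's actual components:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
student = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
x = torch.randn(64, 16)
y = torch.randint(0, 4, (64,))

# Step 1: distillation (soft targets, T=4, alpha=0.5)
opt = torch.optim.Adam(student.parameters(), lr=1e-2)
T, alpha = 4.0, 0.5
for _ in range(50):
    with torch.no_grad():
        t_logits = teacher(x)
    s_logits = student(x)
    kd = F.kl_div(F.log_softmax(s_logits / T, -1),
                  F.softmax(t_logits / T, -1),
                  reduction="batchmean") * T * T
    loss = alpha * kd + (1 - alpha) * F.cross_entropy(s_logits, y)
    opt.zero_grad(); loss.backward(); opt.step()

# Step 2: magnitude pruning (~50% of weights set to zero, one-shot)
with torch.no_grad():
    for m in student.modules():
        if isinstance(m, nn.Linear):
            thresh = m.weight.abs().flatten().median()
            m.weight *= (m.weight.abs() >= thresh).float()

# Step 3: quantize the pruned student (dynamic INT8 linears)
quantized = torch.ao.quantization.quantize_dynamic(
    student, {nn.Linear}, dtype=torch.qint8
)
print(quantized(x[:2]).shape)  # (2, 4)
```

Each step operates on the output of the previous one, which is exactly why the ordering matters: pruning works on the distilled weights, and quantization calibrates against the already-sparse model.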

Summary

Knowledge distillation and model compression are the essential bridge between research-scale models and real-world deployment.

Key takeaways:

  1. Knowledge Distillation: Transfer inter-class relationship information via teacher's soft targets (probability distributions)
  2. Feature Distillation: Transfer intermediate representations, not just final outputs
  3. Relation Distillation: Preserve the structural relationships between samples
  4. LLM Distillation: DistilBERT, TinyLlama, and Distil Whisper demonstrate practical large-model compression
  5. Structured Pruning: Remove filters, heads, and layers for real hardware speedup
  6. Unstructured Pruning: Gradual sparsity schedules minimize accuracy loss
  7. Weight Sharing: ALBERT-style parameter reuse achieves up to 12x compression
  8. NAS: DARTS and Once-for-All automate architecture design for target hardware

The best practical approach is a sequential pipeline: distillation -> pruning -> quantization. Each step builds on the previous, allowing each compression technique to work on an already-compressed model.


References

  • Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
  • Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
  • Zagoruyko and Komodakis (2017). Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer.
  • Park et al. (2019). Relational Knowledge Distillation.
  • Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. https://arxiv.org/abs/1910.01108
  • Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.
  • Liu et al. (2019). DARTS: Differentiable Architecture Search.
  • Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
  • Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.
  • Tan and Le (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. https://arxiv.org/abs/1905.11946
  • PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html