Knowledge Distillation Complete Guide: Model Compression and Lightweight Techniques
Author: Youngju Kim (@fjvbn20031)
Introduction
As deep learning models grow more powerful, they also grow larger. Models like GPT-4 or Llama 3.1 405B require hundreds of gigabytes of memory, making them impossible to run on mobile devices or edge hardware. Knowledge distillation and model compression techniques are the essential tools for making these models smaller and faster while preserving as much of their accuracy as possible.
This guide covers:
- Theoretical foundations of knowledge distillation with complete PyTorch implementations
- Multiple distillation paradigms: Response-based, Feature-based, Relation-based
- LLM distillation case studies (DistilBERT, TinyLLM, Distil-Whisper)
- Structured and unstructured pruning
- Weight sharing and Neural Architecture Search (NAS)
1. Knowledge Distillation Fundamentals
1.1 The Teacher-Student Framework
Knowledge distillation was introduced by Hinton, Vinyals, and Dean in 2015. The core idea: transfer the "knowledge" of a large Teacher model into a small Student model.
Teacher Model (large, accurate)
│
│ Soft targets (probability distributions)
▼
Student Model (small, fast)
Rather than training with hard labels (one-hot vectors), the student learns from the Teacher's soft targets — the full softmax output distribution.
For example, classifying a cat image:
- Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)
- Teacher soft target: [0.01, 0.05, 0.85, 0.07, 0.02]
The soft target encodes the information that "this image is a cat, but it has some resemblance to a tiger." This inter-class similarity information provides rich supervision for the student.
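This extra inter-class information can be made concrete by comparing label entropy: the one-hot label carries zero entropy, while the teacher's soft target carries measurable signal. A small sketch using the illustrative distributions above (the values are toy numbers, not from a trained model):

```python
import torch

# Hypothetical 5-class example: class 2 ("cat") is correct
hard_target = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
soft_target = torch.tensor([0.01, 0.05, 0.85, 0.07, 0.02])

def entropy(p, eps=1e-12):
    """Shannon entropy in nats; higher = more information per label."""
    return -(p * p.clamp_min(eps).log()).sum().item()

print(f"Hard target entropy: {entropy(hard_target):.3f} nats")  # 0.000
print(f"Soft target entropy: {entropy(soft_target):.3f} nats")  # ~0.598
```

Every extra nat of entropy in the target is supervision the student would not get from the one-hot label alone.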
1.2 The Temperature Parameter
When softmax outputs are too confident ([0.001, 0.002, 0.997, ...]), they carry almost no more information than hard labels. The temperature parameter T makes distributions softer:
softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
- When T is 1: standard softmax
- When T is greater than 1: flatter, more uniform distribution (more information)
- When T is less than 1: sharper, more peaked distribution
import torch
import torch.nn.functional as F
def temperature_softmax(logits, temperature=1.0):
"""Temperature-scaled softmax."""
return F.softmax(logits / temperature, dim=-1)
# Example
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# [0.575 0.211 0.086 0.128]
print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# [0.324 0.252 0.201 0.223] — more uniform, more information
1.3 Hinton's KD Loss Function
The KD loss is a weighted sum of two terms:
KL Divergence term (soft target matching):
L_KD = T^2 * KL(softmax_T(teacher_logits) || softmax_T(student_logits))
The T^2 scaling keeps gradient magnitudes independent of the temperature value.
Cross-Entropy term (hard target learning):
L_CE = CrossEntropy(student_logits, true_labels)
Total loss:
L = alpha * L_CE + (1 - alpha) * L_KD
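The T^2 claim can be sanity-checked numerically. The sketch below (not from the original guide; the shapes and seed are arbitrary) computes the soft-target gradient at several temperatures and shows that, with the T^2 scaling applied, the gradient magnitudes stay on a comparable scale:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)

grad_scales = {}
for T in [1.0, 2.0, 4.0, 8.0]:
    s = F.log_softmax(student_logits / T, dim=-1)
    t = F.softmax(teacher_logits / T, dim=-1)
    # Without the T**2 factor, this gradient would shrink roughly as 1/T^2
    loss = F.kl_div(s, t, reduction='batchmean') * (T ** 2)
    grad, = torch.autograd.grad(loss, student_logits)
    grad_scales[T] = grad.abs().mean().item()
    print(f"T={T}: mean |grad| = {grad_scales[T]:.4f}")
```

All four magnitudes land within a small factor of each other, so the relative weight of L_KD against L_CE does not silently change when the temperature is tuned.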
1.4 Complete PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
class KnowledgeDistillationLoss(nn.Module):
"""
Hinton et al. (2015) Knowledge Distillation Loss.
L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
"""
def __init__(self, temperature=4.0, alpha=0.5):
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits, teacher_logits, labels):
# Hard target loss
loss_ce = self.ce_loss(student_logits, labels)
# Soft target loss (KL Divergence)
student_soft = F.log_softmax(student_logits / self.T, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
def train_with_distillation(
teacher, student, train_loader,
num_epochs=10, temperature=4.0, alpha=0.5,
device='cuda'
):
"""Teacher-Student distillation training loop."""
teacher = teacher.to(device).eval() # Teacher is frozen
student = student.to(device)
# Freeze teacher parameters
for param in teacher.parameters():
param.requires_grad = False
criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
for epoch in range(num_epochs):
student.train()
total_loss = 0.0
correct = 0
total = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# Teacher inference (no gradient needed)
with torch.no_grad():
teacher_logits = teacher(images)
# Student inference
student_logits = student(images)
# KD loss computation
loss = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
correct += (student_logits.argmax(1) == labels).sum().item()
total += images.size(0)
scheduler.step()
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {total_loss/total:.4f} | "
f"Acc: {correct/total:.4f}")
return student
# Practical example
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)
student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)
# Parameter count comparison
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")  # ~23.5M (fc replaced for 10 classes)
print(f"Student: {s_params:,} params") # ~11.2M
print(f"Compression ratio: {t_params/s_params:.1f}x")
2. Distillation Paradigms
2.1 Response-based Distillation (Logit Matching)
The most basic approach — the student mimics the Teacher's final outputs (logits). Hinton's original KD is the canonical example.
class ResponseBasedDistillation(nn.Module):
"""Response-based distillation using only final outputs."""
def __init__(self, temperature=4.0):
super().__init__()
self.T = temperature
def forward(self, student_logits, teacher_logits):
student_soft = F.log_softmax(student_logits / self.T, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)
2.2 Feature-based Distillation (Intermediate Layers)
Proposed in FitNets (Romero et al., 2015). The student learns to match the Teacher's intermediate feature maps, not just the final outputs.
Since the teacher and student may have different channel counts, a Regressor network projects the student features to match the teacher's dimensions.
class FeatureDistillationLoss(nn.Module):
"""Distillation via intermediate layer feature matching."""
def __init__(self, teacher_channels, student_channels):
super().__init__()
# Project student features to teacher feature space
self.regressor = nn.Sequential(
nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(teacher_channels),
nn.ReLU(inplace=True),
)
def forward(self, student_feat, teacher_feat):
projected = self.regressor(student_feat)
return F.mse_loss(projected, teacher_feat.detach())
class HookBasedDistillation:
"""
Use forward hooks to extract intermediate layer features
without modifying the model's forward() method.
"""
def __init__(self, model, layer_names):
self.features = {}
self.hooks = []
for name, layer in model.named_modules():
if name in layer_names:
hook = layer.register_forward_hook(
self._make_hook(name)
)
self.hooks.append(hook)
def _make_hook(self, name):
def hook(module, input, output):
self.features[name] = output
return hook
def remove(self):
for hook in self.hooks:
hook.remove()
# Usage example
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)
# Teacher layer3: 1024 channels; Student layer3: 256 channels
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])
feat_distill = FeatureDistillationLoss(
teacher_channels=1024,
student_channels=256
)
# During training
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)
teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']
loss_feat = feat_distill(student_feat, teacher_feat)
2.3 Relation-based Distillation
RKD (Park et al., 2019). Rather than matching absolute values, the student learns to mimic the relational structure between samples in the teacher's embedding space.
class RelationalKnowledgeDistillation(nn.Module):
"""
Relational KD: preserve pairwise distance and angle relationships
between samples in the embedding space.
"""
def __init__(self, distance_weight=25.0, angle_weight=50.0):
super().__init__()
self.dist_w = distance_weight
self.angle_w = angle_weight
def pdist(self, e, squared=False, eps=1e-12):
"""Pairwise distance matrix computation."""
e_sq = (e ** 2).sum(dim=1)
prod = e @ e.t()
res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
if not squared:
res = res.sqrt()
return res
def distance_loss(self, teacher_emb, student_emb):
"""Preserve pairwise distance relationships."""
t_d = self.pdist(teacher_emb)
t_d = t_d / (t_d.mean() + 1e-12) # normalize
s_d = self.pdist(student_emb)
s_d = s_d / (s_d.mean() + 1e-12)
return F.smooth_l1_loss(s_d, t_d.detach())
def angle_loss(self, teacher_emb, student_emb):
"""Preserve angular relationships between triplets."""
td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1) # (N, N, D)
sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)
td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)
t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)
return F.smooth_l1_loss(s_angle, t_angle.detach())
def forward(self, teacher_emb, student_emb):
loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
return loss
2.4 Attention Transfer
Zagoruyko and Komodakis (2017). Transfer attention maps — spatial activation patterns — from teacher to student.
class AttentionTransfer(nn.Module):
"""
Attention Transfer: transfer spatial attention patterns
derived from feature map activations.
"""
def attention_map(self, feat):
"""
Compute spatial attention map from feature maps.
Sums squared activations across channels and normalizes.
"""
return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))
def forward(self, student_feats, teacher_feats):
"""
student_feats, teacher_feats: lists of feature maps at multiple layers.
Returns summed AT loss across all pairs.
"""
loss = 0.0
for s_feat, t_feat in zip(student_feats, teacher_feats):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat)
loss += (s_attn - t_attn.detach()).pow(2).mean()
return loss
3. LLM Distillation
3.1 DistilBERT
DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) to 66M parameters.
Key techniques:
- Half the layers: 12 → 6
- Remove token-type embeddings
- Remove the pooler
- Triple loss: MLM + Distillation + Cosine embedding loss
from transformers import (
DistilBertModel,
DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistilBERTLoss(nn.Module):
"""
DistilBERT triple loss:
1. MLM CE loss (language modeling)
2. Soft-target KD loss (teacher logit matching)
3. Cosine embedding loss (hidden state similarity)
"""
def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
super().__init__()
self.T = temperature
self.alpha = alpha # MLM weight
self.beta = beta # Cosine loss weight
def forward(self, student_logits, teacher_logits,
student_hidden, teacher_hidden, mlm_labels):
# 1. MLM loss
loss_mlm = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
mlm_labels.view(-1),
ignore_index=-100
)
# 2. Soft KD loss
s_soft = F.log_softmax(student_logits / self.T, dim=-1)
t_soft = F.softmax(teacher_logits / self.T, dim=-1)
loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)
# 3. Cosine embedding loss (hidden state alignment)
loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()
return (self.alpha * loss_mlm
+ (1 - self.alpha) * loss_kd
+ self.beta * loss_cos)
# Inference example
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])
DistilBERT performance:
- 40% smaller than BERT-base
- 60% faster inference
- Retains 97% of BERT's GLUE performance
3.2 LLM Distillation Pattern
For large language models, the student learns from the teacher's next-token prediction distributions.
class LLMDistillationTrainer:
"""
LLM distillation trainer.
Transfers teacher's token distribution to student.
"""
def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
self.teacher = teacher
self.student = student
self.T = temperature
self.alpha = alpha
def compute_loss(self, input_ids, attention_mask, labels):
# Teacher inference (no_grad)
with torch.no_grad():
teacher_out = self.teacher(
input_ids=input_ids,
attention_mask=attention_mask,
)
teacher_logits = teacher_out.logits # (B, seq_len, vocab_size)
# Student inference
student_out = self.student(
input_ids=input_ids,
attention_mask=attention_mask,
)
student_logits = student_out.logits
# 1. CE loss (hard target)
shift_logits = student_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_ce = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
# 2. KD loss (soft target)
# Compute only on valid tokens
mask = (shift_labels != -100).float()
s_log_soft = F.log_softmax(
shift_logits / self.T, dim=-1
)
t_soft = F.softmax(
teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
)
# Per-token KL divergence
kl_per_token = F.kl_div(
s_log_soft, t_soft,
reduction='none'
).sum(-1) # (B, seq_len-1)
loss_kd = (kl_per_token * mask).sum() / mask.sum()
loss_kd = loss_kd * (self.T ** 2)
return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
3.3 Distil-Whisper
Distil-Whisper, a distilled variant of OpenAI's Whisper speech recognition model, demonstrates aggressive encoder-decoder compression.
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration
)
# Distil-Whisper: distilled from Whisper-large-v2
# large-v2: 1550M params -> distil-large-v2: 756M params (~6x faster, similar WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")
# Only 2 decoder layers retained (original: 32)
# Encoder kept intact
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")
4. Structured Pruning
Pruning removes unimportant parameters or structures from a model.
- Structured pruning: remove entire filters, heads, or layers → real speedup on all hardware
- Unstructured pruning: set individual weights to zero → sparse matrices (needs specialized hardware or kernels for actual speedup)
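The practical difference fits in a few lines. A minimal sketch (layer sizes are arbitrary): unstructured pruning leaves the weight tensor's shape intact, while structured pruning yields a genuinely smaller layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
x = torch.randn(1, 16, 8, 8)

# Unstructured: zero out the smallest individual weights -- shape unchanged
w = conv.weight.data.clone()
threshold = w.abs().flatten().kthvalue(w.numel() // 2).values
w[w.abs() < threshold] = 0.0
print(w.shape)  # torch.Size([32, 16, 3, 3]) -- speedup needs sparse kernels

# Structured: keep only the 16 filters with the largest L1 norm -- layer shrinks
keep = conv.weight.data.abs().sum(dim=(1, 2, 3)).topk(16).indices
small = nn.Conv2d(16, 16, kernel_size=3, padding=1)
small.weight.data = conv.weight.data[keep]
small.bias.data = conv.bias.data[keep]
print(small(x).shape)  # torch.Size([1, 16, 8, 8]) -- dense compute, real speedup
```

The structured version runs as an ordinary dense convolution on any device, which is why the rest of this section focuses on filter, head, and layer pruning.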
4.1 Filter Pruning
Remove filters with small L1/L2 norms from CNN layers.
import torch
import torch.nn as nn
import numpy as np
def get_filter_importance(conv_layer, norm='l1'):
"""Compute the importance (norm) of each filter."""
weight = conv_layer.weight.data # (out_ch, in_ch, kH, kW)
if norm == 'l1':
return weight.abs().sum(dim=(1, 2, 3))
elif norm == 'l2':
return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        # Placeholder for FPGM-style geometric-median importance;
        # approximated here by the plain L2 norm per filter
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)
def prune_conv_layer(conv, keep_ratio=0.5):
"""
Filter pruning: remove filters with lowest importance.
Returns new Conv2d with reduced filter count.
"""
importance = get_filter_importance(conv)
n_keep = int(conv.out_channels * keep_ratio)
_, indices = importance.topk(n_keep)
indices, _ = indices.sort()
new_conv = nn.Conv2d(
conv.in_channels,
n_keep,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=conv.bias is not None
)
new_conv.weight.data = conv.weight.data[indices]
if conv.bias is not None:
new_conv.bias.data = conv.bias.data[indices]
return new_conv, indices
class StructuredPruner:
"""Structured pruning for ResNet-style models."""
def __init__(self, model, prune_ratio=0.5):
self.model = model
self.prune_ratio = prune_ratio
def global_threshold_prune(self, sparsity=0.5):
"""
Prune globally: remove the sparsity fraction of filters
with lowest importance across all layers.
"""
all_importances = []
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
imp = get_filter_importance(module)
all_importances.append(imp)
all_imp = torch.cat(all_importances)
threshold = all_imp.kthvalue(int(len(all_imp) * sparsity)).values.item()
masks = {}
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
imp = get_filter_importance(module)
masks[name] = imp >= threshold
return masks
4.2 Attention Head Pruning
Remove unimportant attention heads from transformer models.
class AttentionHeadPruner:
"""
Multi-Head Attention head pruning.
Michel et al. (2019): Are Sixteen Heads Really Better than One?
"""
def compute_head_importance(self, model, dataloader, device='cuda'):
"""Compute per-layer, per-head importance scores."""
head_importance = {}
model.eval()
for batch in dataloader:
inputs = {k: v.to(device) for k, v in batch.items()}
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
if hasattr(outputs, 'attentions') and outputs.attentions:
for layer_idx, attn in enumerate(outputs.attentions):
# attn: (batch, heads, seq, seq)
# Importance: variance of attention weights (more focused = more important)
head_imp = attn.mean(0).var(-1).mean(-1) # (heads,)
key = f"layer_{layer_idx}"
if key not in head_importance:
head_importance[key] = head_imp
else:
head_importance[key] += head_imp
return head_importance
def prune_heads(self, model, heads_to_prune):
"""
Remove specified heads.
heads_to_prune: dict mapping layer_idx to list of head indices.
"""
model.prune_heads(heads_to_prune) # Built-in HuggingFace method
return model
# HuggingFace example
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# Remove heads 0, 3, 5 from layer 0 and heads 1, 2 from layer 6
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)
total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")
4.3 Layer Pruning
Remove entire transformer layers that contribute least to model performance.
class LayerPruner:
"""Transformer layer pruning."""
def compute_layer_importance_by_gradient(
self, model, dataloader, device='cuda'
):
"""
Gradient-based layer importance.
Importance = mean |gradient * weight| across the layer.
"""
model.train()
layer_importance = {}
for batch in dataloader:
inputs = {k: v.to(device) for k, v in batch.items()}
outputs = model(**inputs)
loss = outputs.loss
loss.backward()
for i, layer in enumerate(model.encoder.layer):
grad_sum = 0.0
param_count = 0
for param in layer.parameters():
if param.grad is not None:
grad_sum += (param.grad * param.data).abs().sum().item()
param_count += param.numel()
key = f"layer_{i}"
imp = grad_sum / max(param_count, 1)
if key not in layer_importance:
layer_importance[key] = 0.0
layer_importance[key] += imp
model.zero_grad()
return layer_importance
def drop_layers(self, model, layers_to_drop):
"""Remove specified layers and rebuild the encoder."""
import copy
new_model = copy.deepcopy(model)
remaining = [
layer for i, layer in enumerate(new_model.encoder.layer)
if i not in layers_to_drop
]
new_model.encoder.layer = nn.ModuleList(remaining)
new_model.config.num_hidden_layers = len(remaining)
return new_model
4.4 PyTorch torch.nn.utils.prune
import torch.nn.utils.prune as prune
model = models.resnet18(weights='DEFAULT')
# Apply unstructured L1 pruning to a specific layer
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# Sets 30% of weights to zero using a mask
print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")
# Structured pruning (filter-level)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)
# Make pruning permanent (remove mask, bake zeros into weights)
prune.remove(conv, 'weight')
# Apply to all Conv2d layers globally
parameters_to_prune = []
for module in model.modules():
if isinstance(module, nn.Conv2d):
parameters_to_prune.append((module, 'weight'))
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.4, # Prune 40% of all conv weights
)
# Check global sparsity
total_zero = sum(
(m.weight == 0).sum().item()
for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
m.weight.numel()
for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")
5. Unstructured Pruning
5.1 Magnitude-based Pruning
The simplest method: set weights with the smallest absolute values to zero.
class MagnitudePruner:
"""Magnitude-based unstructured pruning."""
def __init__(self, model, sparsity=0.5):
self.model = model
self.sparsity = sparsity
self.masks = {}
def compute_global_threshold(self):
"""Compute the global threshold across all weight tensors."""
all_weights = []
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
all_weights.append(param.data.abs().view(-1))
all_weights = torch.cat(all_weights)
threshold = all_weights.kthvalue(
int(len(all_weights) * self.sparsity)
).values.item()
return threshold
def apply_pruning(self):
"""Create pruning masks using the global threshold."""
threshold = self.compute_global_threshold()
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
mask = (param.data.abs() >= threshold).float()
self.masks[name] = mask
param.data *= mask
total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
total_params = sum(m.numel() for m in self.masks.values())
print(f"Actual sparsity: {total_zeros/total_params:.2%}")
def apply_masks(self):
"""Re-apply masks after each gradient step to prevent dead weights from reviving."""
for name, param in self.model.named_parameters():
if name in self.masks:
param.data *= self.masks[name]
5.2 Gradual Magnitude Pruning
Pruning too aggressively from the start causes a sharp accuracy drop. A gradual schedule that slowly increases sparsity during training works much better.
class GradualMagnitudePruner:
"""
Gradually increase sparsity during training.
Zhu and Gupta (2017): To Prune, or Not to Prune.
"""
def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
begin_step=0, end_step=1000, frequency=100):
self.model = model
self.initial_sparsity = initial_sparsity
self.final_sparsity = final_sparsity
self.begin_step = begin_step
self.end_step = end_step
self.frequency = frequency
self.masks = {}
self._init_masks()
def _init_masks(self):
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
self.masks[name] = torch.ones_like(param.data)
def compute_sparsity(self, step):
"""Compute target sparsity at current step using cubic schedule."""
if step < self.begin_step:
return self.initial_sparsity
if step > self.end_step:
return self.final_sparsity
pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
sparsity = (
self.final_sparsity
+ (self.initial_sparsity - self.final_sparsity)
* (1 - pct_done) ** 3
)
return sparsity
def step(self, global_step):
"""Update pruning masks at appropriate training steps."""
if (global_step % self.frequency != 0 or
global_step < self.begin_step or
global_step > self.end_step):
return
target_sparsity = self.compute_sparsity(global_step)
self._update_masks(target_sparsity)
def _update_masks(self, sparsity):
"""Update masks to match target sparsity."""
for name, param in self.model.named_parameters():
if name not in self.masks:
continue
n_prune = int(sparsity * param.data.numel())
if n_prune == 0:
continue
threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
self.masks[name] = (param.data.abs() > threshold).float()
param.data *= self.masks[name]
6. Weight Sharing
6.1 Cross-Layer Parameter Sharing — ALBERT
ALBERT (Lan et al., 2019) drastically reduces BERT's parameters by reusing the same parameters across all transformer layers.
class ALBERTEncoder(nn.Module):
"""
ALBERT-style weight sharing.
A single transformer layer is executed N times.
"""
def __init__(self, hidden_size=768, num_heads=12,
intermediate_size=3072, num_layers=12):
super().__init__()
# Only ONE transformer layer is defined
self.shared_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=intermediate_size,
batch_first=True,
norm_first=True,
)
self.num_layers = num_layers
self.norm = nn.LayerNorm(hidden_size)
def forward(self, x, src_key_padding_mask=None):
# Execute the same layer num_layers times
for _ in range(self.num_layers):
x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
return self.norm(x)
# Parameter comparison
bert_encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)
bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = sum(p.numel() for p in albert_encoder.parameters())
print(f"BERT encoder: {bert_params:,}") # ~85M
print(f"ALBERT encoder: {albert_params:,}") # ~7M
print(f"Compression: {bert_params/albert_params:.1f}x")
6.2 Factorized Embeddings (ALBERT)
ALBERT also factorizes the embedding matrix: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size), where embedding_size is much smaller than hidden_size.
class FactorizedEmbedding(nn.Module):
"""
ALBERT-style factorized embedding.
vocab_size x H factored into (vocab_size x E) + (E x H)
where E << H.
"""
def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)
def forward(self, input_ids):
embed = self.word_embeddings(input_ids) # (B, seq, E)
return self.embedding_projection(embed) # (B, seq, H)
# Parameter comparison (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768) # 23.0M params
factorized_embed = FactorizedEmbedding(30000, 128, 768)  # 3.94M params
s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
print(f"Savings: {(s_params - f_params)/s_params:.1%}")
7. Neural Architecture Search (NAS)
7.1 Manual Design vs AutoML
Traditional CNNs (VGG, ResNet) were designed by hand by researchers. NAS automates this process.
Three core components of NAS:
- Search Space: What operations and structures to explore
- Search Strategy: Reinforcement learning, evolutionary algorithms, gradient-based methods
- Performance Estimation: Quickly evaluate candidate architectures
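Even the simplest strategy, random search, already exercises all three components. A toy sketch (the search space and the proxy score below are hypothetical stand-ins, not from any real NAS system):

```python
import random

# 1. Search space: a few depths, widths, and kernel sizes (hypothetical)
SEARCH_SPACE = {
    'num_layers': [2, 4, 6],
    'width': [32, 64, 128],
    'kernel': [3, 5],
}

def sample_architecture(rng):
    """2. Search strategy: here, plain random sampling."""
    return {k: rng.choice(v) for k, v in SEARCH_SPACE.items()}

def estimate_performance(arch):
    """3. Performance estimation: stand-in proxy score.
    In practice this would be a short training run or a
    zero-cost proxy; here it is a toy heuristic."""
    return arch['width'] * arch['num_layers'] / (arch['kernel'] ** 2)

rng = random.Random(42)
best = max((sample_architecture(rng) for _ in range(20)),
           key=estimate_performance)
print("Best found:", best)
```

Real NAS methods replace the sampler (reinforcement learning, evolution, gradients) and the scorer (weight sharing, early stopping), but the loop structure is the same.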
7.2 DARTS (Differentiable Architecture Search)
Liu et al. (2019). Transforms architecture search into a continuous optimization problem.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Candidate operations
PRIMITIVES = [
'none', # Zero connection
'skip_connect', # Identity
'sep_conv_3x3', # 3x3 Separable Conv
'sep_conv_5x5', # 5x5 Separable Conv
'dil_conv_3x3', # 3x3 Dilated Conv
'dil_conv_5x5', # 5x5 Dilated Conv
'avg_pool_3x3', # 3x3 Average Pool
'max_pool_3x3', # 3x3 Max Pool
]
class MixedOperation(nn.Module):
"""
DARTS: represent each edge as a weighted mixture of operations.
Architecture parameters (alpha) control each operation's weight.
"""
def __init__(self, C, stride, primitives):
super().__init__()
self.ops = nn.ModuleList([
self._build_op(primitive, C, stride)
for primitive in primitives
])
# Learnable architecture parameters
self.arch_params = nn.Parameter(
torch.randn(len(primitives)) * 1e-3
)
    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            # Zero connection: kills this edge (all-zero output,
            # spatially subsampled to match the stride)
            class Zero(nn.Module):
                def __init__(self, stride):
                    super().__init__()
                    self.stride = stride
                def forward(self, x):
                    if self.stride == 1:
                        return x * 0.0
                    return x[:, :, ::self.stride, ::self.stride] * 0.0
            return Zero(stride)
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive in ('sep_conv_3x3', 'sep_conv_5x5'):
            k = 3 if primitive == 'sep_conv_3x3' else 5
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k // 2,
                          groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive in ('dil_conv_3x3', 'dil_conv_5x5'):
            k = 3 if primitive == 'dil_conv_3x3' else 5
            # dilation=2: padding k-1 keeps the spatial size at stride 1
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k - 1,
                          dilation=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError(f"Unknown primitive: {primitive}")
def forward(self, x):
weights = F.softmax(self.arch_params, dim=0)
results = []
for w, op in zip(weights, self.ops):
try:
result = op(x)
results.append(w * result)
except Exception:
pass
if results:
output = results[0]
for r in results[1:]:
if r.shape == output.shape:
output = output + r
return output
return x
def darts_discrete(mixed_ops):
    """
    Convert the continuous architecture to a discrete one by
    selecting the highest-weight operation on each edge.
    mixed_ops: an iterable of MixedOperation edges.
    """
    for mixed_op in mixed_ops:
        weights = F.softmax(mixed_op.arch_params, dim=0)
        best_op_idx = weights.argmax().item()
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")
7.3 EfficientNet's NAS Process
EfficientNet-B0 is the base architecture found by an MnasNet-style NAS. The search optimized for accuracy under a mobile latency constraint. The result is a network built from MBConv blocks with squeeze-and-excitation.
class MBConvBlock(nn.Module):
"""
Mobile Inverted Bottleneck Convolution.
EfficientNet's core building block, found by NAS.
"""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)
        expand_layers = []
        # Expansion phase (1x1 conv widens the channels)
        if expand_ratio != 1:
            expand_layers = [
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ]
        # Expansion + depthwise conv
        self.expand_dw = nn.Sequential(
            *expand_layers,
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        )
        # Squeeze-and-Excitation (key module discovered by NAS)
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )
        # Output projection (1x1 conv back down to out_channels)
        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.expand_dw(x)
        out = out * self.se(out)  # channel-wise recalibration
        out = self.project(out)
        if self.use_residual:
            return x + out
        return out
7.4 Once-for-All Network
Cai et al. (2020). Train one large network and then extract sub-networks tailored to different resource constraints — without retraining.
class OFALayer(nn.Module):
"""
Once-for-All elastic layer that supports multiple kernel sizes.
Training: randomly sample kernel size -> train sub-networks
Deployment: use only the required kernel size
"""
def __init__(self, channels, max_kernel=7):
super().__init__()
# Only one conv is stored — the largest kernel size
self.max_conv = nn.Conv2d(
channels, channels,
kernel_size=max_kernel,
padding=max_kernel // 2,
groups=channels, bias=False
)
self.bn = nn.BatchNorm2d(channels)
self.act = nn.ReLU(inplace=True)
self.active_kernel = max_kernel
    def set_active_kernel(self, kernel_size):
        """Set the active (odd) kernel size for this layer."""
        assert kernel_size % 2 == 1, "kernel size must be odd"
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size
def forward(self, x):
if self.active_kernel == self.max_conv.kernel_size[0]:
weight = self.max_conv.weight
else:
# Extract the center sub-kernel
center = self.max_conv.kernel_size[0] // 2
half = self.active_kernel // 2
weight = self.max_conv.weight[
:, :,
center - half: center + half + 1,
center - half: center + half + 1
]
padding = self.active_kernel // 2
out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
return self.act(self.bn(out))
def progressive_shrinking_train(model, dataloader, kernels=(7, 5, 3)):
    """
    Progressive shrinking: start with the largest kernel, then widen
    the sampling pool each stage.
    Stage 1: train with max kernel (7)
    Stage 2: randomly sample from [7, 5]
    Stage 3: randomly sample from [7, 5, 3]
    """
for stage, active_kernels in enumerate(
[kernels[:1], kernels[:2], kernels]
):
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(f"Stage {stage+1}: kernels = {active_kernels}")
for epoch in range(5):
for images, labels in dataloader:
# Randomly sample kernel size each batch
k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]
for module in model.modules():
if isinstance(module, OFALayer):
module.set_active_kernel(k)
outputs = model(images)
loss = F.cross_entropy(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
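The weight slicing in `OFALayer.forward` can be checked on its own: the active sub-kernel is simply the centered window of the stored 7x7 depthwise weight, so batches trained with a small kernel still update the center of the shared tensor. A quick standalone check:

```python
import torch

max_kernel, active = 7, 3
# Depthwise weight layout: (channels, 1, k, k)
weight = torch.randn(32, 1, max_kernel, max_kernel)
center = max_kernel // 2   # index 3
half = active // 2         # 1
sub = weight[:, :, center - half: center + half + 1,
             center - half: center + half + 1]
print(tuple(sub.shape))  # (32, 1, 3, 3)
# The slice is a view of the full weight, so gradients from
# small-kernel forward passes flow into the shared 7x7 tensor.
```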
8. Integrated Model Compression Pipeline
In practice, multiple techniques are combined sequentially:
class ModelCompressionPipeline:
"""
Integrated pipeline: Distillation -> Pruning -> Quantization.
"""
def __init__(self, teacher, student, num_classes):
self.teacher = teacher
self.num_classes = num_classes
self.student = student
def step1_distillation(self, train_loader,
epochs=30, device='cuda'):
"""Step 1: Knowledge distillation."""
print("=== Step 1: Knowledge Distillation ===")
self.student = train_with_distillation(
self.teacher, self.student, train_loader,
num_epochs=epochs, temperature=4.0, alpha=0.5,
device=device
)
def step2_pruning(self, train_loader, sparsity=0.5,
finetune_epochs=10, device='cuda'):
"""Step 2: Gradual pruning + fine-tuning."""
print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
pruner = GradualMagnitudePruner(
self.student,
initial_sparsity=0.0,
final_sparsity=sparsity,
begin_step=0,
end_step=len(train_loader) * finetune_epochs,
frequency=100
)
optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
self.student.to(device)
step = 0
for epoch in range(finetune_epochs):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
outputs = self.student(images)
loss = F.cross_entropy(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pruner.step(step)
pruner.apply_masks()
step += 1
    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: Post-Training Static Quantization (eager mode)."""
        print("=== Step 3: Quantization ===")
        # Eager-mode static quantization expects the model to wrap its
        # forward pass in QuantStub/DeQuantStub; the 'fbgemm' backend
        # targets x86 CPUs, so calibration and inference run on CPU.
        self.student.eval().to(device)
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)
with torch.no_grad():
for images, _ in calibration_loader:
self.student(images.to(device))
torch.ao.quantization.convert(self.student, inplace=True)
print("Quantization complete (INT8)")
return self.student
def compare_models(self, val_loader, device='cuda'):
"""Compare Teacher vs compressed Student."""
def eval_model(model, loader, dev):
model.eval().to(dev)
correct, total = 0, 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(dev), labels.to(dev)
preds = model(images).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return correct / total
        teacher_acc = eval_model(self.teacher, val_loader, device)
        # The INT8-quantized student only runs on CPU
        student_acc = eval_model(self.student, val_loader, 'cpu')
t_params = sum(p.numel() for p in self.teacher.parameters())
s_params = sum(p.numel() for p in self.student.parameters())
print(f"\n{'='*55}")
print(f"Teacher | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
print(f"Student | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
print(f"Compression: {t_params/s_params:.1f}x | "
f"Retained accuracy: {student_acc/teacher_acc:.1%}")
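Before handing the student to step 3, it is also worth confirming that step 2 actually reached the target sparsity, since magnitude pruning only zeroes weights rather than removing them. A small helper along these lines (the `weight_sparsity` name is mine):

```python
import torch
import torch.nn as nn

def weight_sparsity(model):
    """Fraction of exactly-zero weights across Conv2d/Linear layers."""
    total, zeros = 0, 0
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            total += module.weight.numel()
            zeros += (module.weight == 0).sum().item()
    return zeros / total

# Toy check: zero out half of the first layer's rows
m = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
with torch.no_grad():
    for layer in m:
        layer.weight.fill_(1.0)
    m[0].weight[:4].zero_()   # 32 of 96 weights -> 1/3 sparsity
print(f"sparsity = {weight_sparsity(m):.2%}")  # sparsity = 33.33%
```

If the measured sparsity falls short of the schedule's target, the masks were likely not re-applied after the final optimizer step.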
Summary
Knowledge distillation and model compression are the essential bridge between research-scale models and real-world deployment.
Key takeaways:
- Knowledge Distillation: Transfer inter-class relationship information via teacher's soft targets (probability distributions)
- Feature Distillation: Transfer intermediate representations, not just final outputs
- Relation Distillation: Preserve the structural relationships between samples
- LLM Distillation: DistilBERT, TinyLLaMA, and Distil Whisper demonstrate practical large-model compression
- Structured Pruning: Remove filters, heads, and layers for real hardware speedup
- Unstructured Pruning: Gradual sparsity schedules minimize accuracy loss
- Weight Sharing: ALBERT-style parameter reuse achieves up to 12x compression
- NAS: DARTS and Once-for-All automate architecture design for target hardware
The best practical approach is a sequential pipeline: distillation -> pruning -> quantization. Each step builds on the previous, allowing each compression technique to work on an already-compressed model.
References
- Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
- Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
- Zagoruyko and Komodakis (2017). Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer.
- Park et al. (2019). Relational Knowledge Distillation.
- Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. https://arxiv.org/abs/1910.01108
- Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.
- Liu et al. (2019). DARTS: Differentiable Architecture Search.
- Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
- Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.
- Tan and Le (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. https://arxiv.org/abs/1905.11946
- PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html