Knowledge Distillation Complete Guide: Model Compression and Lightweight Techniques
Introduction
As deep learning models grow more powerful, they also grow larger. Frontier-scale models such as GPT-4 or Llama 3.1 405B are far too large for mobile devices or edge hardware: 405B parameters at 16-bit precision take roughly 810 GB for the weights alone. **Knowledge Distillation** and **model compression** techniques are the essential tools for making these models smaller and faster while preserving as much of their accuracy as possible.
This guide covers:
- Theoretical foundations of knowledge distillation with complete PyTorch implementations
- Multiple distillation paradigms: Response-based, Feature-based, Relation-based
- LLM distillation case studies (DistilBERT, TinyLLM, Distil Whisper)
- Structured and unstructured pruning
- Weight sharing and Neural Architecture Search (NAS)
1. Knowledge Distillation Fundamentals
1.1 The Teacher-Student Framework
Knowledge distillation was introduced by Hinton, Vinyals, and Dean in 2015. The core idea: **transfer the "knowledge" of a large Teacher model into a small Student model**.
Teacher Model (large, accurate)
│
│ Soft targets (probability distributions)
▼
Student Model (small, fast)
Rather than training with hard labels (one-hot vectors), the student learns from the Teacher's **soft targets** — the full softmax output distribution.
For example, classifying a cat image:
- Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)
- Teacher soft target: [0.01, 0.05, **0.85**, 0.07, 0.02]
The soft target says "this image is a cat", and the 0.07 on the fourth class (a tiger, say) adds "but it somewhat resembles a tiger." This **inter-class similarity information** provides rich supervision for the student.
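One rough way to quantify the extra signal is to compare the Shannon entropy of the two targets above. This is a minimal sketch using only the standard library; `entropy_bits` is a helper name introduced here for illustration.

```python
import math

def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability entries contribute nothing."""
    return sum(-p * math.log2(p) for p in probs if p > 0)

hard_target = [0, 0, 1, 0, 0]
soft_target = [0.01, 0.05, 0.85, 0.07, 0.02]

# The one-hot label has zero entropy; the soft target spreads some
# probability over the other classes, which is what the student learns from.
print(f"hard: {entropy_bits(hard_target):.3f} bits")
print(f"soft: {entropy_bits(soft_target):.3f} bits")
```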
1.2 The Temperature Parameter
When softmax outputs are too confident ([0.001, 0.002, 0.997, ...]), they carry almost no more information than hard labels. The temperature parameter T makes distributions softer:
softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
- When T is 1: standard softmax
- When T is greater than 1: flatter, more uniform distribution (more information)
- When T is less than 1: sharper, more peaked distribution
    import torch
    import torch.nn.functional as F

    def temperature_softmax(logits, temperature=1.0):
        """Temperature-scaled softmax."""
        return F.softmax(logits / temperature, dim=-1)

    # Example
    logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
    print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
    # [0.575, 0.211, 0.086, 0.128]
    print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
    # [0.324, 0.252, 0.201, 0.223] (more uniform, so the non-target classes are visible)
1.3 Hinton's KD Loss Function
The KD loss is a weighted sum of two terms:
**KL Divergence term** (soft target matching):
L_KD = T^2 * KL(softmax_T(teacher_logits) || softmax_T(student_logits))
The gradients of the soft-target term shrink roughly as 1/T^2, so the T^2 factor keeps the relative weight of the two loss terms roughly constant when the temperature changes.
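The scaling argument can be written out, following the reasoning in Hinton et al. (2015). Here z are the student logits, v the teacher logits, q and p the corresponding temperature-softened distributions, and N the number of classes. The approximation assumes a large temperature and zero-mean logits, so it is a sketch of the argument, not an exact identity.

```latex
\frac{\partial}{\partial z_i}\,
\mathrm{KL}\!\left(p^{(T)} \,\middle\|\, q^{(T)}\right)
  = \frac{1}{T}\left(q_i^{(T)} - p_i^{(T)}\right)
  \approx \frac{1}{N T^{2}}\left(z_i - v_i\right),
\qquad
q_i^{(T)} = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}},\quad
p_i^{(T)} = \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}
```

Multiplying the soft-target term by T^2 cancels the 1/T^2 factor, so its gradient stays on a scale comparable to the hard-target term.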
**Cross-Entropy term** (hard target learning):
L_CE = CrossEntropy(student_logits, true_labels)
**Total loss**:
L = alpha * L_CE + (1 - alpha) * L_KD
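To make the combined loss concrete, here is a small worked example for a single sample in plain Python. The logits are made-up values, and `softmax` and `kd_loss` are illustrative helper names, not part of any library.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax (max-subtracted for numerical stability)."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.5):
    # Hard-target term: cross-entropy at T = 1
    ce = -math.log(softmax(student_logits)[label])
    # Soft-target term: KL(teacher || student) at temperature T, scaled by T^2
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return alpha * ce + (1 - alpha) * (T ** 2) * kl

teacher = [1.0, 2.0, 6.0, 3.0, 0.5]   # made-up teacher logits
student = [0.5, 1.0, 3.0, 2.5, 0.0]   # made-up student logits

print(kd_loss(student, teacher, label=2))
# With alpha = 0 only the KL term remains, and it vanishes when the
# student reproduces the teacher's logits exactly.
print(kd_loss(teacher, teacher, label=2, alpha=0.0))
```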
1.4 Complete PyTorch Implementation
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import models

    class KnowledgeDistillationLoss(nn.Module):
        """
        Hinton et al. (2015) Knowledge Distillation Loss.
        L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
        """
        def __init__(self, temperature=4.0, alpha=0.5):
            super().__init__()
            self.T = temperature
            self.alpha = alpha
            self.ce_loss = nn.CrossEntropyLoss()
            self.kl_loss = nn.KLDivLoss(reduction='batchmean')

        def forward(self, student_logits, teacher_logits, labels):
            # Hard target loss
            loss_ce = self.ce_loss(student_logits, labels)
            # Soft target loss (KL divergence). KLDivLoss expects log-probabilities
            # for the student and probabilities for the teacher.
            student_soft = F.log_softmax(student_logits / self.T, dim=-1)
            teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
            loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
            return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

    def train_with_distillation(
        teacher, student, train_loader,
        num_epochs=10, temperature=4.0, alpha=0.5,
        device='cuda'
    ):
        """Teacher-Student distillation training loop."""
        teacher = teacher.to(device).eval()  # Teacher is frozen
        student = student.to(device)
        # Freeze teacher parameters
        for param in teacher.parameters():
            param.requires_grad = False
        criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
        optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        for epoch in range(num_epochs):
            student.train()
            total_loss = 0.0
            correct = 0
            total = 0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                # Teacher inference (no gradient needed)
                with torch.no_grad():
                    teacher_logits = teacher(images)
                # Student inference
                student_logits = student(images)
                # KD loss computation
                loss = criterion(student_logits, teacher_logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * images.size(0)
                correct += (student_logits.argmax(1) == labels).sum().item()
                total += images.size(0)
            scheduler.step()
            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Loss: {total_loss/total:.4f} | "
                  f"Acc: {correct/total:.4f}")
        return student

    # Practical example
    # Teacher: ResNet50 (~25.6M params with the ImageNet head)
    # Student: ResNet18 (~11.7M params with the ImageNet head)
    teacher = models.resnet50(weights='DEFAULT')
    # NOTE: the new 10-class head is randomly initialized, so the teacher must be
    # fine-tuned on the target task before it is used for distillation.
    teacher.fc = nn.Linear(2048, 10)
    student = models.resnet18(weights=None)
    student.fc = nn.Linear(512, 10)
    # Parameter count comparison (with the 10-class heads)
    t_params = sum(p.numel() for p in teacher.parameters())
    s_params = sum(p.numel() for p in student.parameters())
    print(f"Teacher: {t_params:,} params")  # ~23.5M
    print(f"Student: {s_params:,} params")  # ~11.2M
    print(f"Compression ratio: {t_params/s_params:.1f}x")
2. Distillation Paradigms
2.1 Response-based Distillation (Logit Matching)
The most basic approach — the student mimics the Teacher's **final outputs (logits)**. Hinton's original KD is the canonical example.
    class ResponseBasedDistillation(nn.Module):
        """Response-based distillation using only final outputs."""
        def __init__(self, temperature=4.0):
            super().__init__()
            self.T = temperature

        def forward(self, student_logits, teacher_logits):
            student_soft = F.log_softmax(student_logits / self.T, dim=-1)
            teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
            return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)
2.2 Feature-based Distillation (Intermediate Layers)
Proposed in FitNets (Romero et al., 2015). The student learns to match the Teacher's **intermediate feature maps**, not just the final outputs.
Since the teacher and student may have different channel counts, a **Regressor** network projects the student features to match the teacher's dimensions.
    class FeatureDistillationLoss(nn.Module):
        """Distillation via intermediate layer feature matching."""
        def __init__(self, teacher_channels, student_channels):
            super().__init__()
            # Project student features to teacher feature space
            self.regressor = nn.Sequential(
                nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(teacher_channels),
                nn.ReLU(inplace=True),
            )

        def forward(self, student_feat, teacher_feat):
            projected = self.regressor(student_feat)
            return F.mse_loss(projected, teacher_feat.detach())

    class HookBasedDistillation:
        """
        Use forward hooks to extract intermediate layer features
        without modifying the model's forward() method.
        """
        def __init__(self, model, layer_names):
            self.features = {}
            self.hooks = []
            for name, layer in model.named_modules():
                if name in layer_names:
                    hook = layer.register_forward_hook(
                        self._make_hook(name)
                    )
                    self.hooks.append(hook)

        def _make_hook(self, name):
            def hook(module, input, output):
                self.features[name] = output
            return hook

        def remove(self):
            for hook in self.hooks:
                hook.remove()

    # Usage example
    teacher = models.resnet50(weights='DEFAULT').eval()
    student = models.resnet18(weights=None)
    # Teacher layer3: 1024 channels; Student layer3: 256 channels
    # (both produce 14x14 feature maps for 224x224 inputs)
    teacher_hook = HookBasedDistillation(teacher, ['layer3'])
    student_hook = HookBasedDistillation(student, ['layer3'])
    feat_distill = FeatureDistillationLoss(
        teacher_channels=1024,
        student_channels=256
    )
    # During training
    x = torch.randn(4, 3, 224, 224)
    with torch.no_grad():  # the teacher is frozen
        teacher_out = teacher(x)
    student_out = student(x)
    teacher_feat = teacher_hook.features['layer3']
    student_feat = student_hook.features['layer3']
    loss_feat = feat_distill(student_feat, teacher_feat)
2.3 Relation-based Distillation
RKD (Park et al., 2019). Rather than matching absolute values, the student learns to mimic the **relational structure** between samples in the teacher's embedding space.
    class RelationalKnowledgeDistillation(nn.Module):
        """
        Relational KD: preserve pairwise distance and angle relationships
        between samples in the embedding space.
        """
        def __init__(self, distance_weight=25.0, angle_weight=50.0):
            super().__init__()
            self.dist_w = distance_weight
            self.angle_w = angle_weight

        def pdist(self, e, squared=False, eps=1e-12):
            """Pairwise distance matrix computation."""
            e_sq = (e ** 2).sum(dim=1)
            prod = e @ e.t()
            res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
            if not squared:
                res = res.sqrt()
            return res

        def distance_loss(self, teacher_emb, student_emb):
            """Preserve pairwise distance relationships."""
            t_d = self.pdist(teacher_emb)
            t_d = t_d / (t_d.mean() + 1e-12)  # normalize by the mean distance
            s_d = self.pdist(student_emb)
            s_d = s_d / (s_d.mean() + 1e-12)
            return F.smooth_l1_loss(s_d, t_d.detach())

        def angle_loss(self, teacher_emb, student_emb):
            """Preserve the angles formed by every triplet of samples."""
            # td[i, j] = e_j - e_i, shape (N, N, D)
            td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)
            sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)
            td_norm = F.normalize(td, p=2, dim=2)
            sd_norm = F.normalize(sd, p=2, dim=2)
            # Entry [i, j, k] is the cosine of the angle at sample i
            # between samples j and k
            t_angle = torch.bmm(td_norm, td_norm.transpose(1, 2)).view(-1)
            s_angle = torch.bmm(sd_norm, sd_norm.transpose(1, 2)).view(-1)
            return F.smooth_l1_loss(s_angle, t_angle.detach())

        def forward(self, teacher_emb, student_emb):
            loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
            loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
            return loss
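Because the distance term compares mean-normalized pairwise distances, the teacher and student embeddings may differ in both dimensionality and scale. The following standard-library sketch illustrates that property with made-up points; `normalized_pairwise_distances` is a name introduced here.

```python
import math

def normalized_pairwise_distances(embeddings):
    """All pairwise Euclidean distances, divided by their mean."""
    n = len(embeddings)
    dists = [math.dist(embeddings[i], embeddings[j])
             for i in range(n) for j in range(i + 1, n)]
    mean = sum(dists) / len(dists)
    return [d / mean for d in dists]

# Made-up embeddings: same geometry, different dimension and scale
teacher_emb = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0), (6.0, 8.0, 0.0)]  # 3-D
student_emb = [(0.0, 0.0), (0.3, 0.4), (0.6, 0.8)]                 # 2-D

t_rel = normalized_pairwise_distances(teacher_emb)
s_rel = normalized_pairwise_distances(student_emb)
gap = max(abs(t - s) for t, s in zip(t_rel, s_rel))
print(gap)  # close to zero: the relational structure matches
```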
2.4 Attention Transfer
Zagoruyko and Komodakis (2017). Transfer attention maps — spatial activation patterns — from teacher to student.
    class AttentionTransfer(nn.Module):
        """
        Attention Transfer: transfer spatial attention patterns
        derived from feature map activations.
        """
        def attention_map(self, feat):
            """
            Compute spatial attention map from feature maps.
            Averages squared activations across channels, flattens the
            spatial dimensions, and L2-normalizes each sample.
            """
            return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

        def forward(self, student_feats, teacher_feats):
            """
            student_feats, teacher_feats: lists of feature maps at multiple layers.
            Each pair must share the same spatial size.
            Returns summed AT loss across all pairs.
            """
            loss = 0.0
            for s_feat, t_feat in zip(student_feats, teacher_feats):
                s_attn = self.attention_map(s_feat)
                t_attn = self.attention_map(t_feat)
                loss += (s_attn - t_attn.detach()).pow(2).mean()
            return loss
3. LLM Distillation
3.1 DistilBERT
DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) to 66M parameters.
Key techniques:
- **Half the layers**: 12 → 6
- **Remove token-type embeddings**
- **Remove the pooler**
- **Triple loss**: MLM + Distillation + Cosine embedding loss
    from transformers import (
        DistilBertModel,
        DistilBertTokenizer
    )

    class DistilBERTLoss(nn.Module):
        """
        DistilBERT triple loss:
        1. MLM CE loss (language modeling)
        2. Soft-target KD loss (teacher logit matching)
        3. Cosine embedding loss (hidden state similarity)
        """
        def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
            super().__init__()
            self.T = temperature
            self.alpha = alpha  # MLM weight
            self.beta = beta    # Cosine loss weight

        def forward(self, student_logits, teacher_logits,
                    student_hidden, teacher_hidden, mlm_labels):
            vocab = student_logits.size(-1)
            # 1. MLM loss
            loss_mlm = F.cross_entropy(
                student_logits.reshape(-1, vocab),
                mlm_labels.reshape(-1),
                ignore_index=-100
            )
            # 2. Soft KD loss. Flatten to (tokens, vocab) so that 'batchmean'
            #    averages over token positions, not only over sequences.
            s_soft = F.log_softmax(student_logits.reshape(-1, vocab) / self.T, dim=-1)
            t_soft = F.softmax(teacher_logits.reshape(-1, vocab) / self.T, dim=-1)
            loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)
            # 3. Cosine embedding loss (hidden state alignment)
            loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()
            return (self.alpha * loss_mlm
                    + (1 - self.alpha) * loss_kd
                    + self.beta * loss_cos)

    # Inference example
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertModel.from_pretrained('distilbert-base-uncased')
    text = "Knowledge distillation is a model compression technique."
    inputs = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)
    # torch.Size([1, seq_len, 768]), where seq_len is the number of WordPiece tokens
DistilBERT performance:
- 40% smaller than BERT-base
- 60% faster inference
- Retains 97% of BERT's GLUE performance
3.2 LLM Distillation Pattern
For large language models, the student learns from the teacher's next-token prediction distributions.
    class LLMDistillationTrainer:
        """
        LLM distillation trainer.
        Transfers teacher's token distribution to student.
        Teacher and student must share the same vocabulary.
        """
        def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
            self.teacher = teacher
            self.student = student
            self.T = temperature
            self.alpha = alpha

        def compute_loss(self, input_ids, attention_mask, labels):
            # Teacher inference (no_grad)
            with torch.no_grad():
                teacher_out = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)
            # Student inference
            student_out = self.student(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            student_logits = student_out.logits
            # 1. CE loss (hard target): position t predicts token t+1
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_ce = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
            # 2. KD loss (soft target)
            # Compute only on valid tokens
            mask = (shift_labels != -100).float()
            s_log_soft = F.log_softmax(
                shift_logits / self.T, dim=-1
            )
            t_soft = F.softmax(
                teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
            )
            # Per-token KL divergence
            kl_per_token = F.kl_div(
                s_log_soft, t_soft,
                reduction='none'
            ).sum(-1)  # (B, seq_len-1)
            # clamp avoids division by zero when a batch has no valid tokens
            loss_kd = (kl_per_token * mask).sum() / mask.sum().clamp(min=1.0)
            loss_kd = loss_kd * (self.T ** 2)
            return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
3.3 Distil Whisper
Whisper, OpenAI's speech recognition model, has a distilled variant, Distil-Whisper (released by Hugging Face), that shows how aggressively an encoder-decoder transformer can be compressed.
    from transformers import (
        WhisperProcessor,
        WhisperForConditionalGeneration
    )
    # Distil-Whisper: distills Whisper-large-v2
    # large-v2: 1550M params -> distil-large-v2: 756M params
    # (reported by its authors as roughly 6x faster, with WER within about 1%)
    processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
    model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")
    # Only 2 decoder layers retained (original: 32)
    # Encoder kept intact
    print(f"Encoder layers: {len(model.model.encoder.layers)}")
    print(f"Decoder layers: {len(model.model.decoder.layers)}")
4. Structured Pruning
Pruning removes unimportant parameters or structures from a model.
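**Structured pruning**: Remove entire filters, heads, or layers → the remaining tensors are smaller but still dense, so the speedup shows up on ordinary hardware
**Unstructured pruning**: Set individual weights to zero → sparse matrices with unchanged shapes (needs sparse kernels or specialized hardware to run faster)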
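The difference is easy to see by counting the multiply-accumulates a dense kernel performs. The toy weight matrix and thresholds below are made up for illustration, and `dense_macs` is a helper name introduced here.

```python
def dense_macs(weight):
    """Multiply-accumulates a dense kernel performs for one input vector."""
    return sum(len(row) for row in weight)

# Toy 4x3 weight matrix (4 output units, 3 inputs); values are made up
weight = [
    [0.9, -0.1, 0.4],
    [0.05, 0.02, -0.03],
    [-0.7, 0.6, 0.2],
    [0.01, -0.04, 0.03],
]

# Unstructured: zero individual small weights; the matrix keeps its shape,
# so a dense kernel still does the same amount of work.
unstructured = [[w if abs(w) >= 0.1 else 0.0 for w in row] for row in weight]

# Structured: drop whole rows (output units) with a small L1 norm;
# the matrix itself shrinks.
structured = [row for row in weight if sum(abs(w) for w in row) >= 0.5]

print(dense_macs(weight), dense_macs(unstructured), dense_macs(structured))
# prints: 12 12 6
```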
4.1 Filter Pruning
Remove filters with small L1/L2 norms from CNN layers.
    def get_filter_importance(conv_layer, norm='l1'):
        """Compute the importance of each filter."""
        weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
        if norm == 'l1':
            return weight.abs().sum(dim=(1, 2, 3))
        elif norm == 'l2':
            return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
        elif norm == 'gm':
            # Geometric-median-style score: total distance to all other filters.
            # Filters close to the others are the most redundant.
            flat = weight.view(weight.size(0), -1)
            return torch.cdist(flat, flat, p=2).sum(dim=1)
        raise ValueError(f"Unknown norm: {norm}")

    def prune_conv_layer(conv, keep_ratio=0.5):
        """
        Filter pruning: remove filters with lowest importance.
        Returns new Conv2d with reduced filter count and the kept indices.
        The next layer's input channels (and any BatchNorm in between)
        must be sliced with the same indices.
        """
        importance = get_filter_importance(conv)
        n_keep = max(1, int(conv.out_channels * keep_ratio))
        _, indices = importance.topk(n_keep)
        indices, _ = indices.sort()
        new_conv = nn.Conv2d(
            conv.in_channels,
            n_keep,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=conv.bias is not None
        )
        new_conv.weight.data = conv.weight.data[indices].clone()
        if conv.bias is not None:
            new_conv.bias.data = conv.bias.data[indices].clone()
        return new_conv, indices

    class StructuredPruner:
        """Structured pruning for ResNet-style models."""
        def __init__(self, model, prune_ratio=0.5):
            self.model = model
            self.prune_ratio = prune_ratio

        def global_threshold_prune(self, sparsity=None):
            """
            Prune globally: remove the sparsity fraction of filters
            with lowest importance across all layers.
            Returns a dict of boolean keep-masks, one per Conv2d layer.
            Note: raw L1 norms favor layers with larger filters.
            """
            if sparsity is None:
                sparsity = self.prune_ratio
            all_importances = []
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Conv2d):
                    imp = get_filter_importance(module)
                    all_importances.append(imp)
            all_imp = torch.cat(all_importances)
            k = max(1, int(len(all_imp) * sparsity))  # kthvalue requires k >= 1
            threshold = all_imp.kthvalue(k).values.item()
            masks = {}
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Conv2d):
                    imp = get_filter_importance(module)
                    masks[name] = imp >= threshold
            return masks
4.2 Attention Head Pruning
Remove unimportant attention heads from transformer models.
    class AttentionHeadPruner:
        """
        Multi-Head Attention head pruning, motivated by
        Michel et al. (2019): Are Sixteen Heads Really Better than One?
        """
        def compute_head_importance(self, model, dataloader, device='cuda'):
            """Compute per-layer, per-head importance scores."""
            head_importance = {}
            model.to(device).eval()
            for batch in dataloader:
                inputs = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = model(**inputs, output_attentions=True)
                if hasattr(outputs, 'attentions') and outputs.attentions:
                    for layer_idx, attn in enumerate(outputs.attentions):
                        # attn: (batch, heads, seq, seq)
                        # Heuristic importance: variance of attention weights
                        # (more focused = more important). Michel et al. use a
                        # gradient-based sensitivity score instead.
                        head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                        key = f"layer_{layer_idx}"
                        if key not in head_importance:
                            head_importance[key] = head_imp
                        else:
                            head_importance[key] += head_imp
            return head_importance

        def prune_heads(self, model, heads_to_prune):
            """
            Remove specified heads.
            heads_to_prune: dict mapping layer_idx to list of head indices.
            """
            model.prune_heads(heads_to_prune)  # Built-in HuggingFace method
            return model

    # HuggingFace example
    from transformers import BertForSequenceClassification
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    # Remove heads 0, 3, 5 from layer 0 and heads 1, 2 from layer 6
    heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
    model.prune_heads(heads_to_prune)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Pruned BERT params: {total_params:,}")
4.3 Layer Pruning
Remove entire transformer layers that contribute least to model performance.
    import copy

    class LayerPruner:
        """
        Transformer layer pruning.
        Assumes the layers live at model.encoder.layer (as in BertModel);
        task models such as BertForSequenceClassification nest them
        under model.bert.encoder.layer.
        """
        def compute_layer_importance_by_gradient(
            self, model, dataloader, device='cuda'
        ):
            """
            Gradient-based layer importance.
            Importance = mean |gradient * weight| across the layer.
            Batches must include labels so that outputs.loss is defined.
            """
            # eval() only disables dropout; gradients are still computed
            model.to(device).eval()
            layer_importance = {}
            for batch in dataloader:
                inputs = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**inputs)
                loss = outputs.loss
                loss.backward()
                for i, layer in enumerate(model.encoder.layer):
                    grad_sum = 0.0
                    param_count = 0
                    for param in layer.parameters():
                        if param.grad is not None:
                            grad_sum += (param.grad * param.data).abs().sum().item()
                            param_count += param.numel()
                    key = f"layer_{i}"
                    imp = grad_sum / max(param_count, 1)
                    if key not in layer_importance:
                        layer_importance[key] = 0.0
                    layer_importance[key] += imp
                model.zero_grad()
            return layer_importance

        def drop_layers(self, model, layers_to_drop):
            """Remove specified layers and rebuild the encoder."""
            new_model = copy.deepcopy(model)
            remaining = [
                layer for i, layer in enumerate(new_model.encoder.layer)
                if i not in layers_to_drop
            ]
            new_model.encoder.layer = nn.ModuleList(remaining)
            new_model.config.num_hidden_layers = len(remaining)
            return new_model
4.4 PyTorch torch.nn.utils.prune
    import torch.nn as nn
    import torch.nn.utils.prune as prune
    from torchvision import models

    model = models.resnet18(weights='DEFAULT')
    # Apply unstructured L1 pruning to a specific layer
    conv = model.layer1[0].conv1
    prune.l1_unstructured(conv, name='weight', amount=0.3)
    # Sets 30% of weights to zero using a mask
    print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")
    # Structured pruning (filter-level), stacked on top of the previous mask
    prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)
    # Make pruning permanent (remove mask, bake zeros into weights)
    prune.remove(conv, 'weight')
    # Apply to all Conv2d layers globally
    parameters_to_prune = []
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.4,  # Prune 40% of all conv weights
    )
    # Check global sparsity
    total_zero = sum(
        (m.weight == 0).sum().item()
        for m in model.modules() if isinstance(m, nn.Conv2d)
    )
    total_params = sum(
        m.weight.numel()
        for m in model.modules() if isinstance(m, nn.Conv2d)
    )
    print(f"Global sparsity: {total_zero/total_params:.2%}")
5. Unstructured Pruning
5.1 Magnitude-based Pruning
The simplest method: set weights with the smallest absolute values to zero.
    class MagnitudePruner:
        """Magnitude-based unstructured pruning."""
        def __init__(self, model, sparsity=0.5):
            self.model = model
            self.sparsity = sparsity
            self.masks = {}

        def compute_global_threshold(self):
            """Compute the global threshold across all weight tensors."""
            all_weights = []
            for name, param in self.model.named_parameters():
                if 'weight' in name and param.dim() > 1:
                    all_weights.append(param.data.abs().view(-1))
            all_weights = torch.cat(all_weights)
            k = max(1, int(len(all_weights) * self.sparsity))  # kthvalue requires k >= 1
            threshold = all_weights.kthvalue(k).values.item()
            return threshold

        def apply_pruning(self):
            """Create pruning masks using the global threshold."""
            threshold = self.compute_global_threshold()
            for name, param in self.model.named_parameters():
                if 'weight' in name and param.dim() > 1:
                    mask = (param.data.abs() >= threshold).float()
                    self.masks[name] = mask
                    param.data *= mask
            total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
            total_params = sum(m.numel() for m in self.masks.values())
            print(f"Actual sparsity: {total_zeros/total_params:.2%}")

        def apply_masks(self):
            """Re-apply masks after each gradient step to prevent dead weights from reviving."""
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]
5.2 Gradual Magnitude Pruning
Pruning too aggressively from the start causes a sharp accuracy drop. A gradual schedule that slowly increases sparsity during training works much better.
    class GradualMagnitudePruner:
        """
        Gradually increase sparsity during training.
        Zhu and Gupta (2017): To Prune, or Not to Prune.
        """
        def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                     begin_step=0, end_step=1000, frequency=100):
            self.model = model
            self.initial_sparsity = initial_sparsity
            self.final_sparsity = final_sparsity
            self.begin_step = begin_step
            self.end_step = end_step
            self.frequency = frequency
            self.masks = {}
            self._init_masks()

        def _init_masks(self):
            for name, param in self.model.named_parameters():
                if 'weight' in name and param.dim() > 1:
                    self.masks[name] = torch.ones_like(param.data)

        def compute_sparsity(self, step):
            """Compute target sparsity at current step using cubic schedule."""
            if step < self.begin_step:
                return self.initial_sparsity
            if step > self.end_step:
                return self.final_sparsity
            pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
            sparsity = (
                self.final_sparsity
                + (self.initial_sparsity - self.final_sparsity)
                * (1 - pct_done) ** 3
            )
            return sparsity

        def step(self, global_step):
            """Update pruning masks at appropriate training steps."""
            if (global_step % self.frequency != 0 or
                    global_step < self.begin_step or
                    global_step > self.end_step):
                return
            target_sparsity = self.compute_sparsity(global_step)
            self._update_masks(target_sparsity)

        def _update_masks(self, sparsity):
            """Update masks to match target sparsity (per tensor)."""
            for name, param in self.model.named_parameters():
                if name not in self.masks:
                    continue
                n_prune = int(sparsity * param.data.numel())
                if n_prune == 0:
                    continue
                threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
                self.masks[name] = (param.data.abs() > threshold).float()
                param.data *= self.masks[name]

        def apply_masks(self):
            """Call after every optimizer step so pruned weights stay at zero."""
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]
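The cubic schedule prunes quickly at first and slows down as it approaches the target. The standalone sketch below evaluates the same formula at a few steps, using the default values from the class above; `cubic_sparsity` is a name introduced here for illustration.

```python
def cubic_sparsity(step, begin_step=0, end_step=1000,
                   initial_sparsity=0.0, final_sparsity=0.8):
    """Target sparsity at a given step under the cubic schedule."""
    if step <= begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    pct_done = (step - begin_step) / (end_step - begin_step)
    return (final_sparsity
            + (initial_sparsity - final_sparsity) * (1 - pct_done) ** 3)

for step in (0, 250, 500, 750, 1000):
    print(step, round(cubic_sparsity(step), 4))
```

Halfway through the schedule the target is already 0.7 of the final 0.8, which leaves the later steps for the network to recover accuracy.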
6. Weight Sharing
6.1 Cross-Layer Parameter Sharing — ALBERT
ALBERT (Lan et al., 2019) drastically reduces BERT's parameters by **reusing the same parameters across all transformer layers**.
    class ALBERTEncoder(nn.Module):
        """
        ALBERT-style weight sharing.
        A single transformer layer is executed N times.
        """
        def __init__(self, hidden_size=768, num_heads=12,
                     intermediate_size=3072, num_layers=12):
            super().__init__()
            # Only ONE transformer layer is defined
            self.shared_layer = nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=intermediate_size,
                batch_first=True,
                norm_first=True,
            )
            self.num_layers = num_layers
            self.norm = nn.LayerNorm(hidden_size)

        def forward(self, x, src_key_padding_mask=None):
            # Execute the same layer num_layers times
            for _ in range(self.num_layers):
                x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
            return self.norm(x)

    # Parameter comparison
    bert_encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
        num_layers=12
    )
    albert_encoder = ALBERTEncoder(num_layers=12)
    bert_params = sum(p.numel() for p in bert_encoder.parameters())
    albert_params = sum(p.numel() for p in albert_encoder.parameters())
    print(f"BERT encoder: {bert_params:,}")      # ~85M
    print(f"ALBERT encoder: {albert_params:,}")  # ~7.1M
    print(f"Compression: {bert_params/albert_params:.1f}x")
6.2 Factorized Embeddings (ALBERT)
ALBERT also factorizes the embedding matrix: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size), where embedding_size is much smaller than hidden_size.
    class FactorizedEmbedding(nn.Module):
        """
        ALBERT-style factorized embedding.
        vocab_size x H factored into (vocab_size x E) + (E x H)
        where E << H.
        """
        def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
            self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

        def forward(self, input_ids):
            embed = self.word_embeddings(input_ids)   # (B, seq, E)
            return self.embedding_projection(embed)   # (B, seq, H)

    # Parameter comparison (vocab_size=30000)
    standard_embed = nn.Embedding(30000, 768)                 # 30000 * 768 = 23.04M params
    factorized_embed = FactorizedEmbedding(30000, 128, 768)   # 30000 * 128 + 128 * 768 = 3.94M params
    s_params = sum(p.numel() for p in standard_embed.parameters())
    f_params = sum(p.numel() for p in factorized_embed.parameters())
    print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
    print(f"Savings: {(s_params - f_params)/s_params:.1%}")
7. Neural Architecture Search (NAS)
7.1 Manual Design vs AutoML
Traditional CNNs were designed by researchers by hand (VGG, ResNet). NAS automates this process.
Three core components of NAS:
1. **Search Space**: What operations and structures to explore
2. **Search Strategy**: Reinforcement learning, evolutionary algorithms, gradient-based methods
3. **Performance Estimation**: Quickly evaluate candidate architectures
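The three components can be sketched with the simplest possible strategy, random search. Everything below is hypothetical and for illustration only: the search space, the proxy score in `estimate_performance` (a stand-in for briefly training a candidate and measuring validation accuracy), and the function names.

```python
import random

# 1. Search space: every combination of these choices is a candidate.
SEARCH_SPACE = {
    "depth": [2, 4, 6],
    "width": [16, 32, 64],
    "kernel_size": [3, 5],
}

def sample_architecture(rng):
    # 2. Search strategy: plain random search, the simplest baseline.
    return {name: rng.choice(options) for name, options in SEARCH_SPACE.items()}

def estimate_performance(arch):
    # 3. Performance estimation: a made-up proxy that rewards capacity
    #    and penalizes compute cost.
    capacity = arch["depth"] * arch["width"]
    cost = capacity * arch["kernel_size"] ** 2
    return capacity / (1.0 + 0.001 * cost)

def random_search(num_trials=20, seed=0):
    rng = random.Random(seed)
    candidates = [sample_architecture(rng) for _ in range(num_trials)]
    return max(candidates, key=estimate_performance)

best = random_search()
print(best, round(estimate_performance(best), 3))
```

Practical NAS methods replace the random sampler with reinforcement learning, evolution, or gradients, and replace the proxy with weight sharing or short training runs.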
7.2 DARTS (Differentiable Architecture Search)
Liu et al. (2019). Transforms architecture search into a continuous optimization problem.
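The continuous relaxation can be stated compactly. On each edge (i, j) of the cell, the discrete choice among candidate operations is replaced by a softmax-weighted mixture governed by architecture parameters α; after the search, each edge keeps its highest-weighted operation. The notation below follows the DARTS paper.

```latex
\bar{o}^{(i,j)}(x) \;=\; \sum_{o \in \mathcal{O}}
  \frac{\exp\!\big(\alpha_o^{(i,j)}\big)}
       {\sum_{o' \in \mathcal{O}} \exp\!\big(\alpha_{o'}^{(i,j)}\big)}\; o(x),
\qquad
o^{(i,j)} \;=\; \arg\max_{o \in \mathcal{O}} \alpha_o^{(i,j)}
```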
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Candidate operations
    PRIMITIVES = [
        'none',          # Zero connection
        'skip_connect',  # Identity
        'sep_conv_3x3',  # 3x3 Separable Conv
        'sep_conv_5x5',  # 5x5 Separable Conv
        'dil_conv_3x3',  # 3x3 Dilated Conv
        'dil_conv_5x5',  # 5x5 Dilated Conv
        'avg_pool_3x3',  # 3x3 Average Pool
        'max_pool_3x3',  # 3x3 Max Pool
    ]

    class Zero(nn.Module):
        """The 'none' primitive: outputs zeros, subsampled to match the stride."""
        def __init__(self, stride):
            super().__init__()
            self.stride = stride

        def forward(self, x):
            return x[:, :, ::self.stride, ::self.stride] * 0.0

    def conv_block(C, kernel_size, stride, dilation=1):
        """Simplified separable conv: depthwise + pointwise + BN + ReLU."""
        padding = dilation * (kernel_size // 2)
        return nn.Sequential(
            nn.Conv2d(C, C, kernel_size, stride=stride, padding=padding,
                      dilation=dilation, groups=C, bias=False),
            nn.Conv2d(C, C, 1, bias=False),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
        )

    class MixedOperation(nn.Module):
        """
        DARTS: represent each edge as a weighted mixture of operations.
        Architecture parameters (alpha) control each operation's weight.
        """
        def __init__(self, C, stride, primitives):
            super().__init__()
            self.primitives = list(primitives)
            self.ops = nn.ModuleList([
                self._build_op(primitive, C, stride)
                for primitive in self.primitives
            ])
            # Learnable architecture parameters
            self.arch_params = nn.Parameter(
                torch.randn(len(self.primitives)) * 1e-3
            )

        def _build_op(self, primitive, C, stride):
            if primitive == 'none':
                return Zero(stride)
            elif primitive == 'skip_connect':
                return nn.Identity() if stride == 1 else nn.Sequential(
                    nn.Conv2d(C, C, 1, stride=stride, bias=False),
                    nn.BatchNorm2d(C)
                )
            elif primitive.startswith('sep_conv'):
                return conv_block(C, int(primitive[-1]), stride)
            elif primitive.startswith('dil_conv'):
                return conv_block(C, int(primitive[-1]), stride, dilation=2)
            elif primitive == 'avg_pool_3x3':
                return nn.AvgPool2d(3, stride=stride, padding=1)
            elif primitive == 'max_pool_3x3':
                return nn.MaxPool2d(3, stride=stride, padding=1)
            raise ValueError(f"Unknown primitive: {primitive}")

        def forward(self, x):
            # All ops produce the same output shape, so they can be summed directly
            weights = F.softmax(self.arch_params, dim=0)
            return sum(w * op(x) for w, op in zip(weights, self.ops))

    def darts_discrete(cell):
        """
        Convert continuous architecture to discrete by
        selecting the highest-weight operation at each edge.
        cell.ops is assumed to be an iterable of MixedOperation, one per edge.
        """
        for mixed_op in cell.ops:
            weights = F.softmax(mixed_op.arch_params, dim=0).detach()
            # DARTS excludes the 'none' op when discretizing
            candidates = [i for i, p in enumerate(mixed_op.primitives) if p != 'none']
            best_op_idx = max(candidates, key=lambda i: weights[i].item())
            print(f" Selected: {mixed_op.primitives[best_op_idx]} "
                  f"(weight: {weights[best_op_idx]:.3f})")
7.3 EfficientNet's NAS Process
EfficientNet-B0 is the baseline architecture found with an MnasNet-style multi-objective NAS. Unlike MnasNet, which targeted latency on a mobile device, the EfficientNet search optimized accuracy against FLOPs. The result is a network built from **MBConv blocks** with squeeze-and-excitation.
    class MBConvBlock(nn.Module):
        """
        Mobile Inverted Bottleneck Convolution.
        EfficientNet's core building block, selected by NAS.
        """
        def __init__(self, in_channels, out_channels, kernel_size,
                     stride=1, expand_ratio=6, se_ratio=0.25):
            super().__init__()
            self.use_residual = (stride == 1 and in_channels == out_channels)
            hidden_dim = int(in_channels * expand_ratio)
            layers = []
            # Expansion phase (1x1 conv)
            if expand_ratio != 1:
                layers.extend([
                    nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                    nn.BatchNorm2d(hidden_dim),
                    nn.SiLU(),
                ])
            # Depthwise conv
            layers.extend([
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                          stride=stride,
                          padding=kernel_size // 2,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])
            self.expand_dw = nn.Sequential(*layers)
            # Squeeze-and-Excitation: an existing module (Hu et al., 2018)
            # included in the search space, not something NAS invented
            se_channels = max(1, int(in_channels * se_ratio))
            self.se = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(hidden_dim, se_channels, 1),
                nn.SiLU(),
                nn.Conv2d(se_channels, hidden_dim, 1),
                nn.Sigmoid(),
            )
            # Output projection (linear, no activation)
            self.project = nn.Sequential(
                nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        def forward(self, x):
            out = self.expand_dw(x)
            out = out * self.se(out)  # channel-wise re-weighting
            out = self.project(out)
            if self.use_residual:
                return x + out
            return out
7.4 Once-for-All Network
Cai et al. (2020). Train one large network and then extract sub-networks tailored to different resource constraints — without retraining.
class OFALayer(nn.Module):
    """
    Once-for-All elastic layer that supports multiple kernel sizes.

    Training: randomly sample a kernel size -> train sub-networks
    Deployment: use only the required kernel size
    """

    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # Only one conv is stored: the one with the largest kernel size
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """Set the active kernel size for this layer."""
        assert kernel_size % 2 == 1, "kernel size must be odd"
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # Extract the center sub-kernel
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]
        # "Same" padding for the active (odd) kernel size
        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))
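The center sub-kernel indexing in `forward` can be checked in isolation. The sketch below reproduces the same index arithmetic on a plain nested list instead of a tensor; `center_slice` and `extract_sub_kernel` are hypothetical helper names introduced only for this illustration.

```python
def center_slice(max_kernel: int, active_kernel: int) -> slice:
    """Slice selecting the centered active_kernel window of a max_kernel axis."""
    if max_kernel % 2 == 0 or active_kernel % 2 == 0:
        raise ValueError("kernel sizes must be odd so a unique center exists")
    if active_kernel > max_kernel:
        raise ValueError("active kernel cannot exceed the stored kernel")
    center = max_kernel // 2
    half = active_kernel // 2
    return slice(center - half, center + half + 1)


def extract_sub_kernel(kernel, active_kernel):
    """Crop a square 2D kernel (list of rows) to its centered sub-kernel."""
    window = center_slice(len(kernel), active_kernel)
    return [row[window] for row in kernel[window]]


# Encode each entry as 10*row + col so the cropped region is easy to read
kernel_7x7 = [[10 * r + c for c in range(7)] for r in range(7)]
kernel_3x3 = extract_sub_kernel(kernel_7x7, 3)
for row in kernel_3x3:
    print(row)
```

For a 7x7 kernel, the 3x3 window covers rows and columns 2 through 4, and the 5x5 window covers 1 through 5, so every smaller kernel is nested inside the larger one and shares its weights.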
def progressive_shrinking_train(model, dataloader, kernels=(7, 5, 3)):
    """
    Stage 1: train with max kernel (7)
    Stage 2: randomly sample from [7, 5]
    Stage 3: randomly sample from [7, 5, 3]
    """
    model.train()
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        # A fresh optimizer is created for each stage
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {list(active_kernels)}")
        for epoch in range(5):
            for images, labels in dataloader:
                # Randomly sample kernel size each batch
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]
                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)
                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
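The three hard-coded stages above follow a simple pattern: stage i may sample from the i largest kernel sizes. A minimal sketch of that schedule for an arbitrary kernel list is shown here, using only the standard library; `progressive_stages` and `sample_kernel` are hypothetical names, and the seeded `random.Random` stands in for the `torch.randint` call purely for illustration.

```python
import random


def progressive_stages(kernels):
    """Return the kernel choices per stage, largest kernel first."""
    ordered = sorted(kernels, reverse=True)
    return [ordered[: i + 1] for i in range(len(ordered))]


def sample_kernel(active_kernels, rng):
    """Pick one kernel size uniformly at random for the next batch."""
    return rng.choice(active_kernels)


stages = progressive_stages([3, 7, 5])
rng = random.Random(0)  # seeded only to make the demo repeatable
for i, active in enumerate(stages, start=1):
    picks = [sample_kernel(active, rng) for _ in range(4)]
    print(f"Stage {i}: choices={active} sampled={picks}")
```

Starting from the largest kernel and widening the choice set gradually means smaller sub-networks are only trained once the shared weights are already reasonable.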
8. Integrated Model Compression Pipeline
In practice, multiple techniques are combined sequentially:
class ModelCompressionPipeline:
    """
    Integrated pipeline: Distillation -> Pruning -> Quantization.
    """

    def __init__(self, teacher, student, num_classes):
        self.teacher = teacher
        self.student = student
        self.num_classes = num_classes

    def step1_distillation(self, train_loader,
                           epochs=30, device='cuda'):
        """Step 1: Knowledge distillation."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """Step 2: Gradual pruning + fine-tuning."""
        print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )
        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
        self.student.to(device)
        self.student.train()
        step = 0
        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Update the sparsity schedule, then re-zero pruned weights
                pruner.step(step)
                pruner.apply_masks()
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: Post-Training Static Quantization."""
        print("=== Step 3: Quantization ===")
        # Eager-mode static quantization expects the model to wrap its
        # forward pass in QuantStub / DeQuantStub; 'fbgemm' targets x86 CPUs.
        self.student.eval().to(device)
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)
        # Calibration: run representative data to collect activation ranges
        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))
        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Compare Teacher vs compressed Student."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        # The INT8 student produced by step3 runs on CPU
        student_acc = eval_model(self.student, val_loader, 'cpu')
        t_params = sum(p.numel() for p in self.teacher.parameters())
        # Note: converted quantized modules keep packed weights that
        # parameters() may not enumerate, so this count is most
        # meaningful when taken before step3.
        s_params = sum(p.numel() for p in self.student.parameters())
        print(f"\n{'='*55}")
        print(f"Teacher | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"Compression: {t_params/max(s_params, 1):.1f}x | "
              f"Retained accuracy: {student_acc/teacher_acc:.1%}")
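As a rough back-of-the-envelope view of how the three steps compound, the sketch below estimates weight storage after each stage. All figures (100M-parameter teacher, 25M-parameter student, 50% sparsity) are illustrative assumptions, and `model_size_mb` is a hypothetical helper. It counts only nonzero weights and ignores the index overhead of sparse storage formats, so real savings from unstructured pruning will be smaller.

```python
def model_size_mb(num_params: int, bits_per_weight: int,
                  sparsity: float = 0.0) -> float:
    """Estimated weight storage in MB (10^6 bytes), nonzero weights only."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    nonzero = num_params * (1.0 - sparsity)
    return nonzero * bits_per_weight / 8 / 1e6


# Illustrative numbers, not measurements
teacher_mb = model_size_mb(100_000_000, bits_per_weight=32)
distilled_mb = model_size_mb(25_000_000, bits_per_weight=32)
pruned_mb = model_size_mb(25_000_000, bits_per_weight=32, sparsity=0.5)
quantized_mb = model_size_mb(25_000_000, bits_per_weight=8, sparsity=0.5)

for name, size in [("teacher (FP32)", teacher_mb),
                   ("after distillation", distilled_mb),
                   ("after pruning", pruned_mb),
                   ("after INT8 quantization", quantized_mb)]:
    print(f"{name:<26} {size:>7.1f} MB  ({teacher_mb / size:.0f}x smaller)")
```

Under these assumptions the factors multiply (4x from distillation, 2x from pruning, 4x from INT8), which is the main argument for chaining the techniques rather than relying on one.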
Summary
Knowledge distillation and model compression are the essential bridge between research-scale models and real-world deployment.
Key takeaways:
1. **Knowledge Distillation**: Transfer inter-class relationship information via the teacher's soft targets (probability distributions)
2. **Feature Distillation**: Transfer intermediate representations, not just final outputs
3. **Relation Distillation**: Preserve the structural relationships between samples
4. **LLM Distillation**: DistilBERT, TinyLLaMA, and Distil Whisper demonstrate practical large-model compression
5. **Structured Pruning**: Remove filters, heads, and layers for real hardware speedup
6. **Unstructured Pruning**: Gradual sparsity schedules minimize accuracy loss
7. **Weight Sharing**: ALBERT-style parameter reuse achieves up to 12x compression
8. **NAS**: DARTS and Once-for-All automate architecture design for target hardware
A practical default is a sequential pipeline: distillation -> pruning -> quantization. Each step starts from the output of the previous one, so pruning operates on the distilled student and quantization on the pruned student.
References
- Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
- Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
- Zagoruyko and Komodakis (2017). Paying More Attention to Attention.
- Park et al. (2019). Relational Knowledge Distillation.
- Sanh et al. (2019). DistilBERT, a distilled version of BERT. https://arxiv.org/abs/1910.01108
- Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.
- Liu et al. (2019). DARTS: Differentiable Architecture Search.
- Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
- Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.
- Tan and Le (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. https://arxiv.org/abs/1905.11946
- PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html