Knowledge Distillation Complete Guide: Model Compression and Lightweight Techniques


Introduction

As deep learning models grow more powerful, they also grow larger. Models like GPT-4 or Llama 3.1 405B require hundreds of gigabytes of memory (405B parameters at 2 bytes each in 16-bit precision is about 810 GB for the weights alone), which makes them impossible to run on mobile devices or edge hardware. **Knowledge Distillation** and **model compression** techniques are the essential tools for making these models smaller and faster while preserving as much of their accuracy as possible.

This guide covers:

- Theoretical foundations of knowledge distillation with complete PyTorch implementations

- Multiple distillation paradigms: Response-based, Feature-based, Relation-based

- LLM distillation case studies (DistilBERT, TinyLLM, Distil Whisper)

- Structured and unstructured pruning

- Weight sharing and Neural Architecture Search (NAS)

1. Knowledge Distillation Fundamentals

1.1 The Teacher-Student Framework

Knowledge distillation was introduced by Hinton, Vinyals, and Dean in 2015. The core idea: **transfer the "knowledge" of a large Teacher model into a small Student model**.

```
Teacher Model (large, accurate)
        │
        │  Soft targets (probability distributions)
        ▼
Student Model (small, fast)
```

Rather than training with hard labels (one-hot vectors), the student learns from the Teacher's **soft targets** — the full softmax output distribution.

For example, classifying a cat image:

- Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)

- Teacher soft target: [0.01, 0.05, **0.85**, 0.07, 0.02]

The soft target encodes the information that "this image is a cat, but it has some resemblance to a tiger." This **inter-class similarity information** provides rich supervision for the student.

1.2 The Temperature Parameter

When softmax outputs are too confident ([0.001, 0.002, 0.997, ...]), they carry almost no more information than hard labels. The temperature parameter T makes distributions softer:

$$\mathrm{softmax}_T(z_i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

- When T is 1: standard softmax

- When T is greater than 1: flatter, more uniform distribution (more information)

- When T is less than 1: sharper, more peaked distribution

```python
import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """Temperature-scaled softmax."""
    return F.softmax(logits / temperature, dim=-1)

# Example
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])

print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# T=1: [0.575 0.211 0.086 0.128]

print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# T=4: [0.324 0.252 0.201 0.223]  (more uniform, more information)
```

1.3 Hinton's KD Loss Function

The KD loss is a weighted sum of two terms:

**KL Divergence term** (soft target matching):

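$$\mathcal{L}_{KD} = T^2 \cdot \mathrm{KL}\big(\mathrm{softmax}_T(z^{t}) \,\|\, \mathrm{softmax}_T(z^{s})\big)$$

where $z^{t}$ and $z^{s}$ are the teacher and student logits. The gradients of the soft-target term scale roughly as $1/T^2$, so multiplying by $T^2$ keeps their magnitude approximately constant across temperatures and preserves the balance with the hard-target term.

**Cross-Entropy term** (hard target learning):

$$\mathcal{L}_{CE} = \mathrm{CE}(z^{s}, y)$$

where $y$ is the true label.

**Total loss**:

$$\mathcal{L} = \alpha \, \mathcal{L}_{CE} + (1 - \alpha) \, \mathcal{L}_{KD}$$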
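The effect of the $T^2$ factor can be checked numerically. The sketch below computes the gradient norm of the soft-target loss with respect to the student logits at several temperatures, with and without the scaling; the logit values and the helper name `kd_grad_norm` are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

def kd_grad_norm(student_logits, teacher_logits, T, scale_by_t2=True):
    """Norm of d(soft-target loss)/d(student logits)."""
    s = student_logits.clone().requires_grad_(True)
    loss = F.kl_div(
        F.log_softmax(s / T, dim=-1),         # student log-probabilities
        F.softmax(teacher_logits / T, dim=-1),  # teacher probabilities
        reduction='batchmean',
    )
    if scale_by_t2:
        loss = loss * T ** 2
    loss.backward()
    return s.grad.norm().item()

# Arbitrary example logits for one sample with 4 classes
student = torch.tensor([[1.0, 0.5, -0.5, -1.0]])
teacher = torch.tensor([[2.0, 0.0, -1.0, 0.5]])

for T in (2.0, 4.0, 8.0):
    scaled = kd_grad_norm(student, teacher, T)
    unscaled = kd_grad_norm(student, teacher, T, scale_by_t2=False)
    print(f"T={T}: with T^2 {scaled:.4f} | without {unscaled:.5f}")
```

With the scaling, the gradient norm stays in the same range as T grows; without it, the norm shrinks roughly fourfold each time T doubles, so the soft-target term would silently lose weight relative to the cross-entropy term.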

1.4 Complete PyTorch Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models

class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton et al. (2015) Knowledge Distillation Loss.
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KL(teacher_soft || student_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        # KLDivLoss expects log-probabilities as input and probabilities as target
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # Hard target loss
        loss_ce = self.ce_loss(student_logits, labels)

        # Soft target loss (KL divergence)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader: DataLoader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """Teacher-Student distillation training loop."""
    teacher = teacher.to(device).eval()  # Teacher is frozen
    student = student.to(device)

    # Freeze teacher parameters
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher inference (no gradient needed)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # Student inference
            student_logits = student(images)

            # KD loss computation
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# Practical example
# Teacher: ResNet50 (~25.6M params with the 1000-class head)
# Student: ResNet18 (~11.7M params with the 1000-class head)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)  # New head is random: fine-tune the teacher before distilling
student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# Parameter count comparison
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")  # ~23.5M with the 10-class head
print(f"Student: {s_params:,} params")  # ~11.2M with the 10-class head
print(f"Compression ratio: {t_params/s_params:.1f}x")  # ~2.1x
```

2. Distillation Paradigms

2.1 Response-based Distillation (Logit Matching)

The most basic approach — the student mimics the Teacher's **final outputs (logits)**. Hinton's original KD is the canonical example.

```python
import torch.nn as nn
import torch.nn.functional as F

class ResponseBasedDistillation(nn.Module):
    """Response-based distillation using only final outputs."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)
```

2.2 Feature-based Distillation (Intermediate Layers)

Proposed in FitNets (Romero et al., 2015). The student learns to match the Teacher's **intermediate feature maps**, not just the final outputs.

Since the teacher and student may have different channel counts, a **Regressor** network projects the student features to match the teacher's dimensions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class FeatureDistillationLoss(nn.Module):
    """Distillation via intermediate layer feature matching."""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # Project student features to teacher feature space
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        projected = self.regressor(student_feat)
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    Use forward hooks to extract intermediate layer features
    without modifying the model's forward() method.
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(self._make_hook(name))
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# Usage example
teacher = models.resnet50(weights='DEFAULT').eval()
student = models.resnet18(weights=None)

# Teacher layer3: 1024 channels; Student layer3: 256 channels
# (both 14x14 spatially for a 224x224 input)
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])
feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)
# The regressor is trainable: add feat_distill.parameters() to the student's optimizer

# During training
x = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']
loss_feat = feat_distill(student_feat, teacher_feat)
```

2.3 Relation-based Distillation

RKD (Park et al., 2019). Rather than matching absolute values, the student learns to mimic the **relational structure** between samples in the teacher's embedding space.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RelationalKnowledgeDistillation(nn.Module):
    """
    Relational KD: preserve pairwise distance and angle relationships
    between samples in the embedding space.
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """Pairwise distance matrix computation."""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        # Zero the diagonal (distance of each sample to itself)
        res = res.clone()
        idx = torch.arange(len(e), device=e.device)
        res[idx, idx] = 0
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """Preserve pairwise distance relationships."""
        with torch.no_grad():
            t_d = self.pdist(teacher_emb)
            t_d = t_d / t_d[t_d > 0].mean()  # normalize by mean off-diagonal distance
        s_d = self.pdist(student_emb)
        s_d = s_d / s_d[s_d > 0].mean()
        return F.smooth_l1_loss(s_d, t_d)

    def angle_loss(self, teacher_emb, student_emb):
        """Preserve angular relationships between triplets."""
        # For each anchor j: cosine of the angle between (e_i - e_j) and (e_k - e_j)
        with torch.no_grad():
            td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)  # (N, N, D)
            td_norm = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(td_norm, td_norm.transpose(1, 2)).view(-1)  # N*N*N values
        sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)
        sd_norm = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(sd_norm, sd_norm.transpose(1, 2)).view(-1)
        return F.smooth_l1_loss(s_angle, t_angle)

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss
```

2.4 Attention Transfer

Zagoruyko and Komodakis (2017). Transfer attention maps — spatial activation patterns — from teacher to student.

```python
import torch.nn as nn
import torch.nn.functional as F

class AttentionTransfer(nn.Module):
    """
    Attention Transfer: transfer spatial attention patterns
    derived from feature map activations.
    """
    def attention_map(self, feat):
        """
        Compute a spatial attention map from feature maps of shape (B, C, H, W).
        Averages squared activations across channels, flattens, and L2-normalizes.
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: lists of feature maps at multiple layers.
        Paired maps must share spatial size (H, W); channel counts may differ.
        Returns summed AT loss across all pairs.
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss
```

3. LLM Distillation

3.1 DistilBERT

DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) to 66M parameters.

Key techniques:

- **Half the layers**: 12 → 6

- **Remove token-type embeddings**

- **Remove the pooler**

- **Triple loss**: MLM + Distillation + Cosine embedding loss
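The DistilBERT paper also initializes the student from the teacher by taking one layer out of two, which gives the 6-layer student a strong starting point. A minimal sketch of that idea in plain PyTorch, assuming the layers live in an `nn.ModuleList`; the helper name `init_student_layers`, the toy teacher built from `nn.TransformerEncoderLayer`, and the choice of even indices (0, 2, ..., 10) are illustrative assumptions, not the exact layers DistilBERT copies:

```python
import copy
import torch
import torch.nn as nn

def init_student_layers(teacher_layers: nn.ModuleList, step: int = 2) -> nn.ModuleList:
    """Copy every `step`-th teacher layer (deep copies, so no weights are shared)."""
    return nn.ModuleList(
        [copy.deepcopy(teacher_layers[i]) for i in range(0, len(teacher_layers), step)]
    )

# Hypothetical 12-layer teacher stack with small dimensions for illustration
teacher_layers = nn.ModuleList(
    [nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128, batch_first=True)
     for _ in range(12)]
)
student_layers = init_student_layers(teacher_layers)
print(len(student_layers))  # 6
```

Deep copies keep the two models independent, so training the student never modifies the teacher it is distilling from.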

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DistilBertModel,
    DistilBertTokenizer
)

class DistilBERTLoss(nn.Module):
    """
    DistilBERT triple loss:
    1. MLM CE loss (language modeling)
    2. Soft-target KD loss (teacher logit matching)
    3. Cosine embedding loss (hidden state similarity)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM weight
        self.beta = beta    # Cosine loss weight

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        vocab = student_logits.size(-1)

        # 1. MLM loss (ignore_index=-100 skips unmasked positions)
        loss_mlm = F.cross_entropy(
            student_logits.reshape(-1, vocab),
            mlm_labels.reshape(-1),
            ignore_index=-100
        )

        # 2. Soft KD loss; flatten to (tokens, vocab) so 'batchmean' averages per token
        s_soft = F.log_softmax(student_logits.reshape(-1, vocab) / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits.reshape(-1, vocab) / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. Cosine embedding loss (hidden state alignment)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# Inference example
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
# torch.Size([1, seq_len, 768]), where seq_len counts [CLS], [SEP] and word pieces
```

DistilBERT performance:

- 40% smaller than BERT-base

- 60% faster inference

- Retains 97% of BERT's GLUE performance

3.2 LLM Distillation Pattern

For large language models, the student learns from the teacher's next-token prediction distributions.

```python
import torch
import torch.nn.functional as F

class LLMDistillationTrainer:
    """
    LLM distillation trainer.
    Transfers teacher's token distribution to student.
    Assumes teacher and student share the same tokenizer/vocabulary,
    so their logits have the same last dimension.
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher inference (no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # Student inference
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE loss (hard target): position t predicts token t+1
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD loss (soft target), computed only on valid tokens
        mask = (shift_labels != -100).float()
        s_log_soft = F.log_softmax(shift_logits / self.T, dim=-1)
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # Per-token KL divergence
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len-1)
        loss_kd = (kl_per_token * mask).sum() / mask.sum().clamp(min=1.0)
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
```

3.3 Distil Whisper

Distil-Whisper (Gandhi et al., 2023) is Hugging Face's distilled version of OpenAI's Whisper speech recognition model. It shows how aggressively the decoder of a sequence-to-sequence model can be compressed.

```python
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration
)

# Distil-Whisper: distills Whisper large-v2
# large-v2: ~1550M params -> distil-large-v2: ~756M params
# (the model card reports ~6x faster inference, within 1% WER on out-of-distribution sets)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# Only 2 decoder layers retained (original: 32); encoder kept intact
print(f"Encoder layers: {len(model.model.encoder.layers)}")  # 32
print(f"Decoder layers: {len(model.model.decoder.layers)}")  # 2
```

4. Structured Pruning

Pruning removes unimportant parameters or structures from a model.

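**Structured pruning**: Remove entire filters, heads, or layers → the remaining tensors are smaller and still dense, so standard hardware and libraries run them faster without special support

**Unstructured pruning**: Set individual weights to zero → sparse matrices of the original shape, which only yield speedups with sparse kernels or dedicated hardware support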
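A small illustration of the difference, using a single `nn.Linear` layer as a stand-in for a real model: unstructured pruning leaves an 8x8 weight matrix that is half zeros, while structured pruning produces a genuinely smaller 4x8 layer. The 50% ratio and the L1 row-norm criterion are arbitrary choices for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 8)
w = layer.weight.detach()

# Unstructured: zero the 50% smallest-magnitude weights; the tensor stays 8x8
threshold = w.abs().flatten().kthvalue(w.numel() // 2).values
unstructured = w * (w.abs() > threshold)

# Structured: keep the 4 output neurons (rows) with the largest L1 norm
keep = w.abs().sum(dim=1).topk(4).indices.sort().values
structured = nn.Linear(8, 4)
with torch.no_grad():
    structured.weight.copy_(w[keep])
    structured.bias.copy_(layer.bias[keep])

print(unstructured.shape, (unstructured == 0).float().mean().item())  # torch.Size([8, 8]) 0.5
print(structured.weight.shape)  # torch.Size([4, 8])
```

The structured layer computes exactly the kept outputs of the original layer with half the multiply-adds, whereas the unstructured matrix still costs the full 8x8 dense product unless a sparse kernel exploits the zeros.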

4.1 Filter Pruning

Remove filters with small L1/L2 norms from CNN layers.

```python
import torch
import torch.nn as nn

def get_filter_importance(conv_layer, norm='l1'):
    """Compute the importance score of each output filter."""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        # FPGM-style score: filters close to all the others (near the geometric
        # median) are redundant, so score = total distance to the other filters
        flat = weight.view(weight.size(0), -1)
        return torch.cdist(flat, flat).sum(dim=1)
    raise ValueError(f"Unknown norm: {norm}")


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    Filter pruning: remove filters with lowest importance (assumes groups=1).
    Returns a new Conv2d with reduced filter count and the kept indices.
    The following BatchNorm and the next layer's input channels must be
    sliced with the same indices.
    """
    importance = get_filter_importance(conv)
    n_keep = max(1, int(conv.out_channels * keep_ratio))
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices].clone()
    return new_conv, indices


class StructuredPruner:
    """Structured pruning for ResNet-style models."""
    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def global_threshold_prune(self, sparsity=0.5):
        """
        Prune globally: mark the `sparsity` fraction of filters with the
        lowest importance across all layers. Returns boolean keep-masks per
        Conv2d; physically removing filters still requires rebuilding layers.
        Raw L1 norms are not directly comparable across layers of different
        sizes, so per-layer normalization is often applied first.
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                all_importances.append(get_filter_importance(module))
        all_imp = torch.cat(all_importances)
        k = max(1, int(len(all_imp) * sparsity))
        threshold = all_imp.kthvalue(k).values.item()

        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp > threshold  # True = keep
        return masks
```
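Removing filters from one convolution also changes the shapes around it: the following BatchNorm loses the same channels, and the next convolution loses the matching *input* channels. A minimal sketch of that bookkeeping for a plain conv → BN → conv chain, assuming `groups=1` and no residual connection (ResNet skip connections additionally require the pruned channels to agree across a block); the helper name `prune_conv_bn_conv` is hypothetical:

```python
import torch
import torch.nn as nn

def prune_conv_bn_conv(conv1, bn1, conv2, keep_ratio=0.5):
    """Drop low-L1-norm filters of conv1 plus the matching BN and conv2 channels."""
    importance = conv1.weight.detach().abs().sum(dim=(1, 2, 3))
    n_keep = max(1, int(conv1.out_channels * keep_ratio))
    keep = importance.topk(n_keep).indices.sort().values

    new_conv1 = nn.Conv2d(conv1.in_channels, n_keep, conv1.kernel_size,
                          stride=conv1.stride, padding=conv1.padding,
                          dilation=conv1.dilation, bias=conv1.bias is not None)
    new_bn1 = nn.BatchNorm2d(n_keep, eps=bn1.eps, momentum=bn1.momentum)
    new_conv2 = nn.Conv2d(n_keep, conv2.out_channels, conv2.kernel_size,
                          stride=conv2.stride, padding=conv2.padding,
                          dilation=conv2.dilation, bias=conv2.bias is not None)
    with torch.no_grad():
        # conv1 keeps a subset of its output filters (dim 0)
        new_conv1.weight.copy_(conv1.weight[keep])
        if conv1.bias is not None:
            new_conv1.bias.copy_(conv1.bias[keep])
        # BatchNorm has one affine pair and one set of running stats per channel
        new_bn1.weight.copy_(bn1.weight[keep])
        new_bn1.bias.copy_(bn1.bias[keep])
        new_bn1.running_mean.copy_(bn1.running_mean[keep])
        new_bn1.running_var.copy_(bn1.running_var[keep])
        # conv2 loses the matching input channels (dim 1)
        new_conv2.weight.copy_(conv2.weight[:, keep])
        if conv2.bias is not None:
            new_conv2.bias.copy_(conv2.bias)
    return new_conv1, new_bn1, new_conv2, keep

# Usage: a toy conv -> BN -> conv chain
conv1 = nn.Conv2d(3, 16, 3, padding=1)
bn1 = nn.BatchNorm2d(16)
conv2 = nn.Conv2d(16, 32, 3, padding=1)
new_conv1, new_bn1, new_conv2, keep = prune_conv_bn_conv(conv1, bn1, conv2, keep_ratio=0.5)
print(new_conv1.weight.shape, new_conv2.weight.shape)
# torch.Size([8, 3, 3, 3]) torch.Size([32, 8, 3, 3])
```

In eval mode the pruned conv and BN reproduce exactly the kept channels of the original pair; only the next layer's output changes, which is why a short fine-tuning phase usually follows filter pruning.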

4.2 Attention Head Pruning

Remove unimportant attention heads from transformer models.

```python
import torch
from transformers import BertForSequenceClassification

class AttentionHeadPruner:
    """
    Multi-Head Attention head pruning.
    Inspired by Michel et al. (2019): Are Sixteen Heads Really Better than One?
    (The paper scores heads by the loss gradient w.r.t. a head mask; the
    attention-variance score below is a simpler heuristic proxy.)
    """
    def compute_head_importance(self, model, dataloader, device='cuda'):
        """Compute per-layer, per-head importance scores."""
        head_importance = {}
        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)
            if getattr(outputs, 'attentions', None):
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    # Heuristic: variance of attention weights (more focused = more important)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp
        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        Remove specified heads.
        heads_to_prune: dict mapping layer_idx (int) to list of head indices.
        """
        model.prune_heads(heads_to_prune)  # Built-in Hugging Face method
        return model


# Hugging Face example
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Remove heads 0, 3, 5 from layer 0 and heads 1, 2 from layer 6
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")
```

4.3 Layer Pruning

Remove entire transformer layers that contribute least to model performance.

```python
import copy
import torch.nn as nn

class LayerPruner:
    """
    Transformer layer pruning for BERT-style Hugging Face models.
    `model.base_model` resolves to the bare encoder model (e.g. `model.bert`
    for BertForSequenceClassification, or the model itself for BertModel).
    """
    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        Gradient-based layer importance.
        Importance = mean |gradient * weight| across the layer.
        Batches must include labels so that outputs.loss is defined.
        """
        model.train()
        layer_importance = {}
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            for i, layer in enumerate(model.base_model.encoder.layer):
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()
                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp
            model.zero_grad()
        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """Remove specified layers and rebuild the encoder."""
        new_model = copy.deepcopy(model)
        encoder = new_model.base_model.encoder
        remaining = [
            layer for i, layer in enumerate(encoder.layer)
            if i not in layers_to_drop
        ]
        encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)
        return new_model
```

4.4 PyTorch torch.nn.utils.prune

```python
import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision import models

model = models.resnet18(weights='DEFAULT')

# Apply unstructured L1 pruning to a specific layer
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# Sets 30% of weights to zero using a mask (weight = weight_orig * weight_mask)
print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# Structured pruning (filter-level): zero 30% of output filters by L2 norm (dim=0).
# The tensor keeps its shape; filters are zeroed, not removed.
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# Make pruning permanent (remove mask, bake zeros into weights)
prune.remove(conv, 'weight')

# Apply to all Conv2d layers globally
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of all conv weights
)

# Check global sparsity
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")
```

5. Unstructured Pruning

5.1 Magnitude-based Pruning

The simplest method: set weights with the smallest absolute values to zero.

```python
import torch

class MagnitudePruner:
    """Magnitude-based unstructured pruning."""
    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """Compute the global threshold across all weight tensors."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))
        all_weights = torch.cat(all_weights)
        k = max(1, int(len(all_weights) * self.sparsity))
        threshold = all_weights.kthvalue(k).values.item()
        return threshold

    def apply_pruning(self):
        """Create pruning masks using the global threshold."""
        threshold = self.compute_global_threshold()
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() > threshold).float()
                self.masks[name] = mask
                param.data *= mask

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"Actual sparsity: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """Re-apply masks after each optimizer step so pruned weights stay at zero."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]
```

5.2 Gradual Magnitude Pruning

Pruning too aggressively from the start causes a sharp accuracy drop. A gradual schedule that slowly increases sparsity during training works much better.
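Written out, the cubic schedule from Zhu and Gupta (2017), which `compute_sparsity` below implements with $n\Delta t$ = `end_step - begin_step`, moves sparsity from an initial value $s_i$ to a final value $s_f$ starting at step $t_0$:

$$s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{n\,\Delta t}\right)^3, \qquad t \in \{t_0,\ t_0 + \Delta t,\ \dots,\ t_0 + n\Delta t\}$$

The cubic term prunes quickly at first, while many redundant weights remain, and slowly near the end. For example, with $s_i = 0$ and $s_f = 0.8$, halfway through the schedule the sparsity is already $0.8 - 0.8 \times 0.5^3 = 0.7$.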

```python
import torch

class GradualMagnitudePruner:
    """
    Gradually increase sparsity during training.
    Zhu and Gupta (2017): To Prune, or Not to Prune.
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """Compute target sparsity at current step using cubic schedule."""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity
        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """Update pruning masks at appropriate training steps."""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return
        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    @torch.no_grad()
    def _update_masks(self, sparsity):
        """Update per-layer masks to match the target sparsity."""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            n_prune = int(sparsity * param.numel())
            if n_prune == 0:
                continue
            # Already-pruned weights are zero, so they stay below the threshold
            threshold = param.abs().reshape(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.abs() > threshold).float()
            param.mul_(self.masks[name])

    @torch.no_grad()
    def apply_masks(self):
        """Call after every optimizer.step() so pruned weights stay at zero."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.mul_(self.masks[name])
```

6. Weight Sharing

6.1 Cross-Layer Parameter Sharing — ALBERT

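ALBERT (Lan et al., 2019) drastically reduces BERT's parameter count by **reusing the same parameters across all transformer layers**. Sharing saves memory, not compute: the shared layer still runs once per layer position, so inference FLOPs match an unshared encoder of the same depth.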

```python
import torch.nn as nn

class ALBERTEncoder(nn.Module):
    """
    ALBERT-style weight sharing.
    A single transformer layer is executed N times.
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # Only ONE transformer layer is defined
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # Execute the same layer num_layers times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)


# Parameter comparison: a BERT-base-sized encoder with 12 independent layers
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = sum(p.numel() for p in albert_encoder.parameters())
print(f"BERT encoder: {bert_params:,}")      # ~85M
print(f"ALBERT encoder: {albert_params:,}")  # ~7M
print(f"Compression: {bert_params/albert_params:.1f}x")  # ~12.0x
```

6.2 Factorized Embeddings (ALBERT)

ALBERT also factorizes the embedding matrix: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size), where embedding_size is much smaller than hidden_size.

```python
import torch.nn as nn

class FactorizedEmbedding(nn.Module):
    """
    ALBERT-style factorized embedding.
    vocab_size x H factored into (vocab_size x E) + (E x H)
    where E << H.
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)  # (B, seq, E)
        return self.embedding_projection(embed)  # (B, seq, H)


# Parameter comparison (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)                # 23,040,000 params (~23.0M)
factorized_embed = FactorizedEmbedding(30000, 128, 768)  # 3,840,000 + 98,304 = 3,938,304 (~3.94M)

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
print(f"Savings: {(s_params - f_params)/s_params:.1%}")  # 82.9%
```

7. Neural Architecture Search (NAS)

7.1 Manual Design vs AutoML

Classic CNN architectures such as VGG and ResNet were designed by hand. NAS automates this design process.

Three core components of NAS:

1. **Search Space**: What operations and structures to explore

2. **Search Strategy**: Reinforcement learning, evolutionary algorithms, gradient-based methods

3. **Performance Estimation**: Quickly evaluate candidate architectures
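To make the three components concrete, here is a toy sketch that uses plain random search as the search strategy. The search space, the `toy_estimator` scoring function, and all other names are hypothetical; a real estimator would train (or partially train) each candidate and measure validation accuracy, which is the expensive step that NAS methods try to shortcut.

```python
import random
from typing import Callable, Dict, List, Tuple

# 1. Search space: each architectural choice and its options (hypothetical values)
SEARCH_SPACE: Dict[str, List[int]] = {
    "depth": [2, 4, 6],
    "width": [32, 64, 128],
    "kernel_size": [3, 5, 7],
}

def sample_architecture(space: Dict[str, List[int]], rng: random.Random) -> Dict[str, int]:
    """2. Search strategy (random search): pick one option for every choice."""
    return {name: rng.choice(options) for name, options in space.items()}

def random_search(space: Dict[str, List[int]],
                  estimate: Callable[[Dict[str, int]], float],
                  num_trials: int = 20,
                  seed: int = 0) -> Tuple[Dict[str, int], float]:
    """Return the best-scoring architecture among `num_trials` random samples."""
    rng = random.Random(seed)
    best_arch, best_score = None, float("-inf")
    for _ in range(num_trials):
        arch = sample_architecture(space, rng)
        score = estimate(arch)  # 3. Performance estimation
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch, best_score

def toy_estimator(arch: Dict[str, int]) -> float:
    """Hypothetical proxy score: reward capacity, penalize a crude compute cost."""
    capacity = arch["depth"] * arch["width"]
    cost = capacity * arch["kernel_size"] ** 2
    return capacity - 0.01 * cost

best_arch, best_score = random_search(SEARCH_SPACE, toy_estimator)
print(best_arch, round(best_score, 2))
```

Swapping `sample_architecture` for an RL controller, an evolutionary loop, or a gradient-based relaxation changes the search strategy without touching the other two components.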

7.2 DARTS (Differentiable Architecture Search)

Liu et al. (2019). Transforms architecture search into a continuous optimization problem.
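In the DARTS formulation, each edge $(i, j)$ of a cell computes a softmax-weighted mixture over the candidate operation set $\mathcal{O}$ with architecture weights $\alpha^{(i,j)}$, and $\alpha$ and the network weights $w$ are optimized as a bilevel problem on the validation and training losses:

$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o' \in \mathcal{O}} \exp\big(\alpha_{o'}^{(i,j)}\big)}\, o(x)$$

$$\min_{\alpha} \; \mathcal{L}_{val}\big(w^*(\alpha), \alpha\big) \quad \text{s.t.} \quad w^*(\alpha) = \arg\min_{w} \mathcal{L}_{train}(w, \alpha)$$

In practice DARTS approximates $w^*(\alpha)$ with a single gradient step on the training loss (or, in its first-order variant, with the current weights), and after search each edge keeps its strongest non-zero operation, $\arg\max_{o \neq \text{zero}} \alpha_o^{(i,j)}$.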

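```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Candidate operations
PRIMITIVES = [
    'none',          # Zero connection
    'skip_connect',  # Identity
    'sep_conv_3x3',  # 3x3 Separable Conv
    'sep_conv_5x5',  # 5x5 Separable Conv
    'dil_conv_3x3',  # 3x3 Dilated Conv
    'dil_conv_5x5',  # 5x5 Dilated Conv
    'avg_pool_3x3',  # 3x3 Average Pool
    'max_pool_3x3',  # 3x3 Max Pool
]


class Zero(nn.Module):
    """'none' op: outputs zeros with the (possibly strided) output shape."""
    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        return x[:, :, ::self.stride, ::self.stride].mul(0.0)


class MixedOperation(nn.Module):
    """
    DARTS: represent each edge as a softmax-weighted mixture of operations.
    Architecture parameters (alpha) control each operation's weight.
    Ops are simplified versions of the DARTS primitives (e.g. one separable
    conv instead of two stacked ones); stride 2 assumes even spatial sizes.
    """
    def __init__(self, C, stride, primitives=PRIMITIVES):
        super().__init__()
        self.primitives = list(primitives)
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in self.primitives
        ])
        # Learnable architecture parameters
        self.arch_params = nn.Parameter(torch.randn(len(self.primitives)) * 1e-3)

    @staticmethod
    def _sep_conv(C, k, stride, dilation=1):
        # Depthwise + pointwise conv; padding keeps the size at stride 1.
        # ReLU must not be in-place: every op reads the same input x.
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C, C, k, stride=stride, padding=dilation * (k - 1) // 2,
                      dilation=dilation, groups=C, bias=False),
            nn.Conv2d(C, C, 1, bias=False),
            nn.BatchNorm2d(C),
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            return Zero(stride)
        if primitive == 'skip_connect':
            # Stride 2: simplified reduction (DARTS uses FactorizedReduce)
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
            )
        if primitive == 'sep_conv_3x3':
            return self._sep_conv(C, 3, stride)
        if primitive == 'sep_conv_5x5':
            return self._sep_conv(C, 5, stride)
        if primitive == 'dil_conv_3x3':
            return self._sep_conv(C, 3, stride, dilation=2)
        if primitive == 'dil_conv_5x5':
            return self._sep_conv(C, 5, stride, dilation=2)
        if primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        if primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        raise ValueError(f"Unknown primitive: {primitive}")

    def forward(self, x):
        weights = F.softmax(self.arch_params, dim=0)
        # Every op yields the same output shape, so the mixture is a plain weighted sum
        return sum(w * op(x) for w, op in zip(weights, self.ops))


def darts_discrete(model):
    """
    Convert the continuous architecture to a discrete one by selecting the
    highest-weight operation on each edge ('none' excluded, as in DARTS).
    """
    chosen = []
    for module in model.modules():
        if isinstance(module, MixedOperation):
            weights = F.softmax(module.arch_params, dim=0)
            candidates = [i for i, p in enumerate(module.primitives) if p != 'none']
            best_idx = max(candidates, key=lambda i: weights[i].item())
            print(f"  Selected: {module.primitives[best_idx]} "
                  f"(weight: {weights[best_idx]:.3f})")
            chosen.append(module.primitives[best_idx])
    return chosen
```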

7.3 EfficientNet's NAS Process

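EfficientNet-B0 is the base architecture found by NAS over the MnasNet search space (Tan and Le, 2019). Unlike MnasNet, which optimized measured mobile latency, the EfficientNet search traded accuracy against FLOPS, since it did not target a specific device. The result is a network built from **MBConv blocks** with squeeze-and-excitation.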
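For reference, the EfficientNet paper writes its search objective in terms of FLOPS rather than measured latency, as a soft trade-off instead of a hard constraint, where $m$ is a candidate model, $T$ is the target FLOPS, and $w = -0.07$ controls the trade-off:

$$\max_{m} \; \mathrm{ACC}(m) \times \left[\frac{\mathrm{FLOPS}(m)}{T}\right]^{w}$$

Because $w < 0$, exceeding the target is penalized smoothly: doubling the FLOPS multiplies the score by $2^{-0.07} \approx 0.953$.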

```python
import torch
import torch.nn as nn

class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Convolution with squeeze-and-excitation.
    EfficientNet's core building block; the NAS chose the per-stage settings
    (kernel size, expansion ratio, number of blocks) within this block family.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # Expansion phase (1x1 conv)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])
        # Depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])
        self.expand_depthwise = nn.Sequential(*layers)

        # Squeeze-and-Excitation (Hu et al., 2018): per-channel gating,
        # included in the search space rather than discovered by it
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Output projection (linear bottleneck: no activation)
        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.expand_depthwise(x)
        out = out * self.se(out)  # rescale channels by the SE gates
        out = self.project(out)
        if self.use_residual:
            return x + out
        return out
```

7.4 Once-for-All Network

Cai et al. (2020). Train one large network and then extract sub-networks tailored to different resource constraints — without retraining.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OFALayer(nn.Module):
    """
    Once-for-All elastic depthwise layer that supports multiple kernel sizes.

    Training: randomly sample a kernel size -> train that sub-network
    Deployment: use only the required kernel size

    Simplification: the OFA paper also passes the cropped centre kernel
    through learned kernel transformation matrices; this sketch uses the raw crop.
    """

    def __init__(self, channels, max_kernel=7):
        super().__init__()
        assert max_kernel % 2 == 1, "kernel sizes must be odd"
        # Only one conv is stored: the largest kernel size.
        # Smaller kernels reuse its centre weights.
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False,
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """Set the active kernel size for this layer."""
        assert kernel_size % 2 == 1
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        max_k = self.max_conv.kernel_size[0]
        if self.active_kernel == max_k:
            weight = self.max_conv.weight
        else:
            # Extract the centre sub-kernel; slicing keeps it in the autograd
            # graph, so gradients flow back into the shared max-size weight
            center = max_k // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1,
            ]
        # "Same" padding for the active kernel; depthwise via groups
        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=self.max_conv.groups)
        return self.act(self.bn(out))
```
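Because the layer slices the smaller kernel out of the stored 7x7 weight, every kernel size trains the same parameters. A minimal check with plain tensors shows that a 3x3 forward pass only sends gradient to the centre of the 7x7 kernel. The shapes are arbitrary, and OFA's kernel transformation matrices are left out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
channels, max_k, sub_k = 4, 7, 3
# Shared depthwise weight for the largest kernel, shape (C, 1, 7, 7)
weight = torch.randn(channels, 1, max_k, max_k, requires_grad=True)

# Slice the centre 3x3 out of the 7x7 kernel
c, h = max_k // 2, sub_k // 2
sub_weight = weight[:, :, c - h:c + h + 1, c - h:c + h + 1]

x = torch.randn(2, channels, 16, 16)
out = F.conv2d(x, sub_weight, padding=sub_k // 2, groups=channels)
out.sum().backward()

# 1.0 inside the centre 3x3 window, 0.0 on the border
centre_mask = torch.zeros(max_k, max_k)
centre_mask[c - h:c + h + 1, c - h:c + h + 1] = 1.0

border_grad = (weight.grad * (1 - centre_mask)).abs().max().item()
centre_grad = (weight.grad * centre_mask).abs().sum().item()
print(border_grad)  # 0.0: border weights were never used
print(centre_grad > 0)
```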

```python
import torch
import torch.nn.functional as F


def progressive_shrinking_train(model, dataloader, kernels=(7, 5, 3),
                                epochs_per_stage=5, lr=0.1):
    """
    Progressive shrinking over kernel sizes:
    Stage 1: train with the max kernel (7) only
    Stage 2: randomly sample from [7, 5]
    Stage 3: randomly sample from [7, 5, 3]
    """
    model.train()
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        print(f"Stage {stage + 1}: kernels = {list(active_kernels)}")
        for epoch in range(epochs_per_stage):
            for images, labels in dataloader:
                # Randomly sample one kernel size per batch and apply it
                # to every elastic layer in the model
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]
                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
```
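The training loop above applies a single sampled kernel size to every `OFALayer` in a batch. In the OFA paper, kernel size (like depth and width) is elastic per layer, so sub-networks can mix kernel sizes across layers, which is what makes the sub-network space so large. A sketch of per-layer sampling follows; `sample_subnet_config` and `count_subnets` are hypothetical helper names.

```python
import random


def sample_subnet_config(num_layers, active_kernels, rng=random):
    """Draw one kernel size independently for each layer (hypothetical helper)."""
    return [rng.choice(active_kernels) for _ in range(num_layers)]


def count_subnets(num_layers, active_kernels):
    # Each layer chooses independently, so the space grows as K ** L
    return len(active_kernels) ** num_layers


rng = random.Random(0)
config = sample_subnet_config(num_layers=20, active_kernels=(7, 5, 3), rng=rng)
print(config[:5])
print(count_subnets(20, (7, 5, 3)))  # 3**20 = 3486784401
```

To use it in the loop, one would replace the single `k` with `for layer, k in zip(ofa_layers, config): layer.set_active_kernel(k)`, where `ofa_layers` is an assumed list of the model's `OFALayer` instances in order.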

8. Integrated Model Compression Pipeline

In practice, multiple techniques are combined sequentially:

```python
import torch
import torch.nn.functional as F


class ModelCompressionPipeline:
    """
    Integrated pipeline: Distillation -> Pruning -> Quantization.

    Assumes train_with_distillation and GradualMagnitudePruner from the
    earlier distillation and pruning sections of this guide.
    """

    def __init__(self, teacher, student, num_classes):
        self.teacher = teacher
        self.student = student
        self.num_classes = num_classes
        self.student_params = None  # FP32 count, recorded before quantization

    def step1_distillation(self, train_loader, epochs=30, device='cuda'):
        """Step 1: Knowledge distillation."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device,
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """Step 2: Gradual pruning + fine-tuning."""
        print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100,
        )
        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
        self.student.to(device).train()
        step = 0
        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Update the sparsity target, then re-zero pruned weights
                # that the optimizer step may have changed
                pruner.step(step)
                pruner.apply_masks()
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: Post-Training Static Quantization (eager mode, CPU)."""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)
        # Quantized modules keep packed weights outside .parameters(),
        # so record the FP32 count now (pruned zeros are still counted)
        self.student_params = sum(p.numel() for p in self.student.parameters())
        # Assumption: pruning masks are already folded into the weights
        # (e.g. torch.nn.utils.prune.remove if that API was used), and the
        # student uses only quantizable ops (residual adds need FloatFunctional).
        # Eager-mode PTQ needs explicit quant/dequant boundaries:
        # QuantWrapper adds a QuantStub before and a DeQuantStub after the model.
        model = torch.ao.quantization.QuantWrapper(self.student)
        # 'fbgemm' targets x86 CPUs; use 'qnnpack' on ARM
        model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibration: observers record activation ranges
        with torch.no_grad():
            for images, _ in calibration_loader:
                model(images.to(device))
        torch.ao.quantization.convert(model, inplace=True)
        self.student = model
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Compare Teacher vs compressed Student."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        # The INT8 model runs only on CPU
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = self.student_params or sum(
            p.numel() for p in self.student.parameters())

        # Parameter-count ratio only; INT8 storage and sparsity save more
        print(f"\n{'=' * 55}")
        print(f"Teacher | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"Compression: {t_params / s_params:.1f}x | "
              f"Retained accuracy: {student_acc / teacher_acc:.1%}")
```
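Step 3's `convert` replaces FP32 values with 8-bit integers using a scale and zero point derived from the calibration ranges. A minimal sketch of that per-tensor affine mapping follows. It assumes a signed INT8 range of [-128, 127]; PyTorch's observers choose their own dtypes and ranges per tensor, so this is an illustration of the arithmetic, not their implementation.

```python
def affine_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Per-tensor scale and zero point mapping [x_min, x_max] onto [qmin, qmax]."""
    # The range must contain 0 so that 0.0 is represented exactly
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = max((x_max - x_min) / (qmax - qmin), 1e-8)  # avoid division by zero
    zero_point = qmin - round(x_min / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


scale, zp = affine_qparams(-1.0, 1.0)
for x in (-1.0, -0.3, 0.0, 0.42, 1.0):
    q = quantize(x, scale, zp)
    print(f"{x:+.2f} -> q={q:4d} -> {dequantize(q, scale, zp):+.4f}")
```

For values inside the calibrated range the round-trip error is at most half a quantization step (`scale / 2`), which is why calibration data that covers the real activation range matters.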

Summary

Knowledge distillation and model compression are the essential bridge between research-scale models and real-world deployment.

Key takeaways:

1. **Knowledge Distillation**: Transfer inter-class relationship information via teacher's soft targets (probability distributions)

2. **Feature Distillation**: Transfer intermediate representations, not just final outputs

3. **Relation Distillation**: Preserve the structural relationships between samples

4. **LLM Distillation**: DistilBERT, TinyLLaMA, and Distil Whisper demonstrate practical large-model compression

5. **Structured Pruning**: Remove filters, heads, and layers for real hardware speedup

6. **Unstructured Pruning**: Gradual sparsity schedules help limit accuracy loss by letting the network adapt as weights are removed

7. **Weight Sharing**: ALBERT-style cross-layer sharing stores one layer's weights for all layers, so a 12-layer encoder keeps 1/12 of its per-layer parameters

8. **NAS**: DARTS and Once-for-All automate architecture design for target hardware
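Takeaway 6 refers to gradual schedules such as the cubic one proposed by Zhu and Gupta (2017): s_t = s_f + (s_i - s_f)(1 - (t - t0) / (n·Δt))^3. A minimal sketch of that schedule follows, using the same argument names that the pipeline's step 2 passes to `GradualMagnitudePruner`. Whether that class uses exactly this formula is an assumption.

```python
def cubic_sparsity(step, initial_sparsity, final_sparsity,
                   begin_step, end_step, frequency):
    """Target sparsity at `step` under the Zhu & Gupta (2017) cubic schedule.

    Sparsity rises quickly at first and flattens out near the end,
    and is only updated every `frequency` steps.
    """
    if step < begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    # Snap to the most recent pruning update
    t = begin_step + ((step - begin_step) // frequency) * frequency
    progress = (t - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3


for s in (0, 250, 500, 750, 1000):
    print(s, round(cubic_sparsity(s, 0.0, 0.5, 0, 1000, 100), 4))
```

Pruning aggressively early, while many redundant weights remain, and gently near the target gives the network the most fine-tuning time to recover from the final, most damaging cuts.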

A common practical approach is a sequential pipeline: distillation -> pruning -> quantization. Each step operates on the output of the previous one, so every technique compresses an already-compressed model and the savings multiply.
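The compounding effect can be estimated with back-of-the-envelope arithmetic. The sketch below assumes a hypothetical 100M-parameter FP32 teacher, a 10M-parameter student, 50% unstructured sparsity stored with no index overhead (optimistic, since real sparse formats need indices), and INT8 weights. All numbers are illustrative.

```python
def estimated_size_mb(num_params, density=1.0, bytes_per_weight=4):
    """Weight storage in MB (1 MB = 1e6 bytes), ignoring sparse-index overhead."""
    return num_params * density * bytes_per_weight / 1e6


teacher_mb = estimated_size_mb(100e6)                # FP32 teacher
student_mb = estimated_size_mb(10e6)                 # after distillation
pruned_mb = estimated_size_mb(10e6, density=0.5)     # + 50% sparsity
quant_mb = estimated_size_mb(10e6, density=0.5, bytes_per_weight=1)  # + INT8

print(teacher_mb, student_mb, pruned_mb, quant_mb)   # 400.0 40.0 20.0 5.0
print(f"Overall: {teacher_mb / quant_mb:.0f}x smaller")  # 10x * 2x * 4x = 80x
```

Unstructured sparsity only shrinks storage if the weights are saved in a sparse format, and it only speeds up inference on kernels that exploit it; the distillation and INT8 factors do not depend on that.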

References

- Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531

- Romero et al. (2015). FitNets: Hints for Thin Deep Nets.

- Zagoruyko and Komodakis (2017). Paying More Attention to Attention.

- Park et al. (2019). Relational Knowledge Distillation.

- Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. https://arxiv.org/abs/1910.01108

- Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.

- Liu et al. (2019). DARTS: Differentiable Architecture Search.

- Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.

- Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.

- Tan and Le (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. https://arxiv.org/abs/1905.11946

- PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html
