A Complete Guide to Knowledge Distillation: Model Lightweighting and Compression Techniques

Introduction

As deep learning models grow more powerful, they also grow larger. Large models such as GPT-4 and Llama 3 405B require hundreds of gigabytes of memory and cannot run on mobile or edge devices. **Knowledge distillation** and model compression are the key techniques for making models small and fast while preserving as much of the large model's performance as possible.

What this guide covers:

  • The theory behind knowledge distillation and a complete PyTorch implementation
  • Distillation variants (response-based, feature-based, relation-based)
  • LLM distillation case studies (DistilBERT, TinyLLM)
  • Structured and unstructured pruning
  • Weight sharing and NAS

1. Knowledge Distillation Basics

1.1 The Teacher-Student Framework

Knowledge distillation, proposed by Hinton, Vinyals, and Dean in 2015, transfers the "knowledge" of a large Teacher model into a small Student model.

Teacher Model (large, accurate)
      │ passes soft targets (probability distributions)
Student Model (small, fast)

Instead of training only on ground-truth labels (hard targets), the Student also learns from the Teacher's **soft targets (softmax output distributions)**.

For example, classifying a cat image:

  • Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)
  • Teacher soft target: [0.01, 0.05, 0.85, 0.07, 0.02]

The soft target encodes information like "this image is a cat, but it looks a bit like a tiger." This inter-class similarity information is a strong training signal for the Student.
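One way to see how much extra information the soft target carries is to compare the Shannon entropy of the two target vectors above. A toy calculation (using exactly the example numbers from this section):

```python
import torch

hard = torch.tensor([0., 0., 1., 0., 0.])
soft = torch.tensor([0.01, 0.05, 0.85, 0.07, 0.02])

def entropy(p, eps=1e-12):
    """Shannon entropy in bits; 0*log(0) is treated as 0 via clamping."""
    p = p.clamp(min=eps)
    return -(p * p.log2()).sum().item()

print(f"hard target entropy: {entropy(hard):.3f} bits")  # ≈ 0.000 bits
print(f"soft target entropy: {entropy(soft):.3f} bits")  # ≈ 0.863 bits
```

The hard target carries no information beyond the class index, while the soft target distributes almost a bit of additional signal over the wrong-but-similar classes.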

1.2 The Temperature Parameter

If the softmax output is too confident, e.g. [0.001, 0.002, 0.997, ...], it is nearly identical to the hard target and carries little extra information. The temperature parameter T smooths the distribution:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))

  • T = 1: the ordinary softmax
  • T > 1: the distribution becomes more uniform (transfers more information)
  • T < 1: the distribution becomes sharper
import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """Temperature-scaled softmax."""
    return F.softmax(logits / temperature, dim=-1)

# Example
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# [0.575 0.211 0.086 0.128]
print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# [0.324 0.252 0.201 0.223] — more uniform

1.3 Hinton's KD Loss Function

The KD loss is a weighted sum of two terms:

KL divergence term (soft-target matching):

L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))

The T^2 scaling keeps the gradient magnitude roughly independent of T.
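The effect of the T^2 factor can be checked numerically. Without it, the gradient of the KL term with respect to the student logits shrinks roughly as 1/T^2; with it, the magnitude stays comparable across temperatures. A toy sketch with random logits (the specific shapes and seed are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
t_logits = torch.randn(8, 10)   # fixed "teacher" logits
s_init = torch.randn(8, 10)     # fixed "student" logits

def kd_grad_norm(T, use_t2):
    """Norm of d(KD term)/d(student logits) at temperature T."""
    s = s_init.clone().requires_grad_(True)
    loss = F.kl_div(F.log_softmax(s / T, dim=-1),
                    F.softmax(t_logits / T, dim=-1),
                    reduction='batchmean')
    if use_t2:
        loss = loss * T ** 2
    loss.backward()
    return s.grad.norm().item()

g1 = kd_grad_norm(1.0, use_t2=True)       # at T=1 the scaling is a no-op
g8 = kd_grad_norm(8.0, use_t2=True)       # comparable magnitude to g1
g8_raw = kd_grad_norm(8.0, use_t2=False)  # ~64x smaller without T^2
print(f"{g1:.4f} {g8:.4f} {g8_raw:.4f}")
```

Scaled gradients at T=1 and T=8 land in the same order of magnitude, while the unscaled T=8 gradient is tiny; this is why the factor matters when mixing the KD term with the cross-entropy term.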

Cross-entropy term (hard-target learning):

L_CE = CrossEntropy(student_logits, true_labels)

Total loss:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 A Complete PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets


class KnowledgeDistillationLoss(nn.Module):
    """
    Knowledge distillation loss from Hinton et al. (2015).
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # Hard-target loss
        loss_ce = self.ce_loss(student_logits, labels)

        # Soft-target loss (KL divergence)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """Teacher-Student 증류 학습 루프."""
    teacher = teacher.to(device).eval()  # Teacher는 frozen
    student = student.to(device)

    # Teacher 파라미터 동결
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher inference (no gradients needed)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # Student inference
            student_logits = student(images)

            # Compute the KD loss
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# Example usage
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
# (In practice the teacher should first be fine-tuned on the 10-class task.)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# Compare parameter counts
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")   # ~25.6M
print(f"Student: {s_params:,} params")   # ~11.2M
print(f"압축률: {t_params/s_params:.1f}x")

2. Distillation Variants

2.1 Response-Based Distillation (Logit Matching)

The most basic approach, of which Hinton KD is the canonical example: the Student mimics the Teacher's **final outputs (logits)**.

class ResponseBasedDistillation(nn.Module):
    """최종 출력만 사용하는 Response-based 증류."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 Feature-Based Distillation (Intermediate Layers)

Proposed in FitNets (Romero et al., 2015): the Student learns to reproduce the Teacher's intermediate-layer feature maps.

Because the Teacher's and Student's channel counts may differ, a regressor network aligns the dimensions.

class FeatureDistillationLoss(nn.Module):
    """Distillation that matches intermediate-layer features."""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # Project student features into the teacher's feature space
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        # Map student features to the teacher's dimensionality
        projected = self.regressor(student_feat)
        # Match features with an MSE loss
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    Extract intermediate-layer features with forward hooks.
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# Example usage
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)

# Teacher: layer3 output (1024 channels), Student: layer3 output (256 channels)
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)

# During training
x = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 Relation-Based Distillation (Learning Relationships)

RKD (Park et al., 2019): the Student mimics the relations between samples, learning the relative structure of the Teacher's embedding space rather than absolute values.

class RelationalKnowledgeDistillation(nn.Module):
    """
    Relational KD: preserve the distance relations between sample pairs
    so the Student learns the structure of the Teacher's embedding space.
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """Pairwise distance computation."""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """Distance-wise loss: preserve the pairwise-distance structure."""
        t_d = self.pdist(teacher_emb)
        # Normalize by the mean distance
        t_d = t_d / (t_d.mean() + 1e-12)
        s_d = self.pdist(student_emb)
        s_d = s_d / (s_d.mean() + 1e-12)
        return F.smooth_l1_loss(s_d, t_d.detach())

    def angle_loss(self, teacher_emb, student_emb):
        """Angle-wise loss: preserve angular relations.
        (A simplified pairing; the original RKD uses full triplets.)"""
        # Angles between pairwise difference vectors e_i - e_j
        td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)  # (N, N, D)
        sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)

        # cosine similarity
        td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
        sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)

        t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
        s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)

        return F.smooth_l1_loss(s_angle, t_angle.detach())

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss
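A quick sanity check of the distance term: because pairwise distances are normalized by their mean, a student embedding that is just a uniformly scaled copy of the teacher's incurs (near-)zero distance loss, while an unrelated embedding does not. A small standalone re-statement of the distance term (shapes and seed are illustrative):

```python
import torch
import torch.nn.functional as F

def pdist(e, eps=1e-12):
    """Pairwise Euclidean distances between rows of e."""
    e_sq = (e ** 2).sum(dim=1)
    res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * e @ e.t()).clamp(min=eps)
    return res.sqrt()

def distance_loss(t_emb, s_emb):
    """RKD distance term: compare mean-normalized distance matrices."""
    t_d = pdist(t_emb); t_d = t_d / (t_d.mean() + 1e-12)
    s_d = pdist(s_emb); s_d = s_d / (s_d.mean() + 1e-12)
    return F.smooth_l1_loss(s_d, t_d).item()

torch.manual_seed(0)
teacher = torch.randn(8, 16)
scaled_student = 0.5 * teacher       # same relational structure, different scale
random_student = torch.randn(8, 16)  # unrelated structure

loss_scaled = distance_loss(teacher, scaled_student)
loss_random = distance_loss(teacher, random_student)
print(loss_scaled, loss_random)  # near zero vs. clearly nonzero
```

This scale invariance is exactly why RKD transfers "structure" rather than absolute embedding values.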

2.4 Attention Transfer

Zagoruyko & Komodakis (2017): attention maps (activations aggregated across channels) are transferred from Teacher to Student.

class AttentionTransfer(nn.Module):
    """Attention Transfer: match the spatial pattern of activation maps."""

    def attention_map(self, feat):
        """
        Compute an attention map from a feature map:
        mean of squared activations over channels, flattened and L2-normalized.
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: lists of feature maps.
        Computes the AT loss for each pair.
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss
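Because the attention map collapses the channel dimension, it is invariant to channel permutations: a student whose channels are shuffled relative to the teacher's still matches perfectly. A small standalone check of the map defined above (toy tensor sizes):

```python
import torch
import torch.nn.functional as F

def attention_map(feat):
    # Mean of squared activations over channels, flattened, L2-normalized
    return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

torch.manual_seed(0)
t_feat = torch.randn(2, 16, 8, 8)
perm = torch.randperm(16)
s_feat = t_feat[:, perm]  # same features, channels shuffled

at_loss = (attention_map(s_feat) - attention_map(t_feat)).pow(2).mean()
print(at_loss.item())  # ≈ 0: the spatial attention pattern is unchanged
```

This is why AT constrains *where* the network looks rather than *which* channels fire — a much looser (and often easier to satisfy) target than raw feature matching.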

3. LLM Distillation

3.1 DistilBERT

DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) down to 66M.

Key techniques:

  • Half the layers: 12 → 6
  • Token-type embeddings removed
  • Pooler removed
  • Triple loss: MLM + distillation + cosine embedding

from transformers import (
    DistilBertModel, BertModel,
    DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilBERTLoss(nn.Module):
    """
    DistilBERT's triple loss:
    1. MLM cross-entropy loss (language modeling)
    2. Soft-target KD loss (matching teacher logits)
    3. Cosine embedding loss (hidden-state similarity)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM weight
        self.beta = beta    # cosine-loss weight

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        # 1. MLM loss
        loss_mlm = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            mlm_labels.view(-1),
            ignore_index=-100
        )

        # 2. Soft KD loss
        s_soft = F.log_softmax(student_logits / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. Cosine embedding loss (hidden-state similarity)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# Inference example
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])

DistilBERT results:

  • 40% smaller than BERT-base
  • 60% faster inference
  • Retains 97% of BERT's GLUE benchmark performance

3.2 TinyLlama-Style Distillation

In LLM distillation, the Student learns the Teacher's next-token prediction distribution.

class LLMDistillationTrainer:
    """
    LLM distillation trainer:
    transfers the Teacher's logit distribution to the Student.
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher inference (no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # Student 추론
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE loss (hard targets)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD loss (soft targets)
        # Computed only over valid tokens
        mask = (shift_labels != -100).float()

        s_log_soft = F.log_softmax(
            shift_logits / self.T, dim=-1
        )
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # Per-token KL divergence
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len)

        loss_kd = (kl_per_token * mask).sum() / mask.sum()
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
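The token-level masking above matters: a plain `batchmean` reduction would average the KL over padding positions too. A minimal standalone check that masked averaging ignores tokens labeled -100 (toy logits; shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, V = 2, 5, 11
s_logits = torch.randn(B, L, V)
t_logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
labels[:, -2:] = -100  # treat the last two positions as padding

mask = (labels != -100).float()

def masked_kd(student_logits):
    """Per-token KL, averaged only over non-padding tokens."""
    kl = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(t_logits, dim=-1),
        reduction='none'
    ).sum(-1)                               # (B, L)
    return (kl * mask).sum() / mask.sum()

loss_masked = masked_kd(s_logits)

# Corrupting student logits at padding positions must not change the loss
s_corrupted = s_logits.clone()
s_corrupted[:, -2:] += 100.0
loss_masked2 = masked_kd(s_corrupted)
print(torch.allclose(loss_masked, loss_masked2))  # True
```

With `reduction='batchmean'` the two losses would differ, silently letting padding tokens drive the gradient.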

3.3 Distil-Whisper

A distilled version of OpenAI's Whisper speech recognition model.

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration
)

# Distil-Whisper: distilled from Whisper-large-v2
# large-v2: 1550M params → distil-large-v2: 756M params (2x faster, comparable WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# Only 2 decoder layers are kept (original: 32)
# The encoder is kept identical
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. Structured Pruning

Pruning removes unimportant parameters or structures from a model.

  • Structured pruning: removes whole filters, heads, or layers → real speedups on standard hardware
  • Unstructured pruning: zeroes individual weights → sparse matrices (requires specialized hardware/kernels)
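The practical difference shows up in tensor shapes: structured pruning yields a genuinely smaller dense layer, while unstructured pruning only zeroes entries and keeps the shape (so it needs sparse kernels to run faster). A toy illustration (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 8, 3, padding=1)

# Structured: keep the 4 filters with the largest L1 norm → smaller dense conv
imp = conv.weight.data.abs().sum(dim=(1, 2, 3))
keep = imp.topk(4).indices.sort().values
small = nn.Conv2d(4, 4, 3, padding=1)
small.weight.data = conv.weight.data[keep]
small.bias.data = conv.bias.data[keep]

# Unstructured: zero the 50% smallest weights → same shape, sparse content
w = conv.weight.data
thr = w.abs().flatten().kthvalue(w.numel() // 2).values
w.mul_((w.abs() > thr).float())

x = torch.randn(1, 4, 16, 16)
print(small(x).shape)           # [1, 4, 16, 16] — fewer output channels
print(conv(x).shape)            # [1, 8, 16, 16] — shape unchanged
print((w == 0).float().mean())  # ≈ 0.5 sparsity
```

The structured variant shrinks compute on any hardware; the unstructured variant only pays off with sparse-aware kernels.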

4.1 Filter Pruning

In CNNs, remove the filters with the smallest L1/L2 norms.

import torch
import torch.nn as nn
import numpy as np


def get_filter_importance(conv_layer, norm='l1'):
    """Compute the importance (norm) of each filter."""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        # Placeholder for geometric-median-based importance (FPGM);
        # the true criterion ranks filters by distance to the geometric median.
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    Filter pruning: drop the least important filters.
    Returns: a new Conv2d layer (with fewer filters) and the kept indices.
    """
    importance = get_filter_importance(conv)
    n_keep = int(conv.out_channels * keep_ratio)
    # Keep the n_keep most important filters
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    # Build the new conv layer
    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices]
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices]

    return new_conv, indices


class StructuredPruner:
    """Structured pruning for ResNet-style models."""

    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def compute_layer_importance(self):
        """Compute filter importance for every layer."""
        importance_dict = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                importance = get_filter_importance(module)
                importance_dict[name] = importance
        return importance_dict

    def global_threshold_prune(self, sparsity=0.5):
        """
        Prune with a single global threshold across all filters:
        remove the least important `sparsity` fraction of filters.
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                all_importances.append(imp)

        # Concatenate the importances of all filters
        all_imp = torch.cat(all_importances)
        k = max(1, int(len(all_imp) * sparsity))
        threshold = all_imp.kthvalue(k).values.item()

        # Drop filters below the threshold
        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp >= threshold

        return masks

4.2 Attention Head Pruning

Remove unimportant attention heads from a transformer.

class AttentionHeadPruner:
    """
    Prunes heads from a transformer's multi-head attention.
    Michel et al. (2019): Are Sixteen Heads Really Better than One?
    """

    def compute_head_importance(self, model, dataloader, device='cuda'):
        """각 레이어, 각 헤드의 중요도 계산."""
        head_importance = {}

        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            # Forward pass with attention maps exposed
            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            # Head importance: variance of the attention weights
            if getattr(outputs, 'attentions', None):
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp

        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        Remove the specified heads.
        heads_to_prune: {layer_idx: [head_indices]}
        """
        model.prune_heads(heads_to_prune)  # built into HuggingFace models
        return model


# Using HuggingFace transformers
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Remove heads 0, 3, 5 of layer 0 and heads 1, 2 of layer 6
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")

4.3 Layer Pruning

Remove entire transformer layers that contribute little to performance.

class LayerPruner:
    """Transformer layer pruning."""

    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        각 레이어에 대한 gradient 기반 중요도 계산.
        중요도 = |gradient * weight|의 평균
        """
        model.train()
        layer_importance = {}

        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            for i, layer in enumerate(model.encoder.layer):
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()

                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp

            model.zero_grad()

        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """Remove the given layers and rebuild the model from the rest."""
        import copy
        new_model = copy.deepcopy(model)

        remaining = [
            layer for i, layer in enumerate(new_model.encoder.layer)
            if i not in layers_to_drop
        ]
        new_model.encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)

        return new_model

4.4 PyTorch torch.nn.utils.prune

import torch.nn.utils.prune as prune

model = models.resnet18(weights='DEFAULT')

# Apply unstructured L1 pruning to a specific layer
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# Zeroes 30% of the weights (via a mask)

# Check
print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# Structured pruning (whole filters)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# Make permanent (remove the mask, bake the zeros into the weights)
prune.remove(conv, 'weight')

# Apply to several layers at once
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # prune 40% of all covered parameters
)

# Check global sparsity
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")

5. Unstructured Pruning

5.1 Magnitude-Based Pruning

The simplest approach: zero out the weights with the smallest absolute values.

class MagnitudePruner:
    """Magnitude-based unstructured pruning."""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """Compute a global threshold over all weights."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        all_weights = torch.cat(all_weights)
        threshold = all_weights.kthvalue(
            int(len(all_weights) * self.sparsity)
        ).values.item()
        return threshold

    def apply_pruning(self):
        """Build pruning masks from the global threshold."""
        threshold = self.compute_global_threshold()

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() >= threshold).float()
                self.masks[name] = mask
                param.data *= mask  # zero weights below the threshold

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"실제 희소성: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """Re-apply masks after backprop (so updates don't revive pruned weights)."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

5.2 Gradual Magnitude Pruning

Pruning aggressively from the start causes a sharp drop in accuracy. Gradually increasing sparsity during training works much better.

class GradualMagnitudePruner:
    """
    Gradually increases sparsity during training.
    Zhu & Gupta (2017): To Prune, or Not to Prune
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """Target sparsity at the current step (cubic schedule)."""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity

        # Cubic decay schedule
        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """Pruning update at a training step."""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return

        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    def _update_masks(self, sparsity):
        """Update the masks to match the current sparsity target."""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            # Consider only currently active weights (exclude already-pruned ones)
            alive = param.data[self.masks[name].bool()]

            if len(alive) == 0:
                continue

            n_prune = int(sparsity * param.data.numel())
            if n_prune == 0:
                continue

            # Find the n_prune smallest-magnitude weights overall
            threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.data.abs() > threshold).float()
            param.data *= self.masks[name]
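The cubic schedule can be inspected in isolation. A standalone re-statement of compute_sparsity shows that most of the pruning happens early, with the target flattening out toward end_step (the parameter values mirror the class defaults above):

```python
def cubic_sparsity(step, s0=0.0, sf=0.8, begin=0, end=1000):
    """Target sparsity at `step` under the Zhu & Gupta cubic schedule."""
    if step < begin:
        return s0
    if step > end:
        return sf
    pct = (step - begin) / (end - begin)
    return sf + (s0 - sf) * (1 - pct) ** 3

for step in (0, 250, 500, 750, 1000):
    print(step, cubic_sparsity(step))
# roughly: 0 → 0.0, 250 → 0.46, 500 → 0.70, 750 → 0.79, 1000 → 0.8
```

Note that a quarter of the way through, more than half of the final sparsity has already been applied; this front-loading gives the network the remaining steps to recover accuracy.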

6. Weight Sharing

6.1 Cross-Layer Parameter Sharing — ALBERT

ALBERT (Lan et al., 2019) drastically reduces BERT's parameter count by reusing the same parameters in every transformer layer.

class ALBERTEncoder(nn.Module):
    """
    ALBERT-style weight sharing:
    a single transformer layer applied N times.
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # Define one shared transformer layer
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # Apply the same layer num_layers times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)

    def count_parameters(self):
        # Actual parameter count (shared weights counted once)
        return sum(p.numel() for p in self.parameters())


# Parameter comparison
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = albert_encoder.count_parameters()

print(f"BERT encoder: {bert_params:,}")     # ~85M
print(f"ALBERT encoder: {albert_params:,}") # ~7M
print(f"압축률: {bert_params/albert_params:.1f}x")  # ~12x

6.2 Factorization (ALBERT Embeddings)

ALBERT also factorizes the embedding layer: vocab_size x hidden_size → vocab_size x embedding_size + embedding_size x hidden_size.

class FactorizedEmbedding(nn.Module):
    """
    ALBERT's embedding factorization:
    vocab_size x H → (vocab_size x E) + (E x H).
    With E << H, this sharply reduces parameters.
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        # E → H linear projection
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)        # (B, seq, E)
        return self.embedding_projection(embed)         # (B, seq, H)

# Parameter comparison (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)             # 23.04M params
factorized_embed = FactorizedEmbedding(30000, 128, 768)  # 3.94M params

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} → Factorized: {f_params:,}")
print(f"절감: {(s_params - f_params)/s_params:.1%}")

7. Neural Architecture Search (NAS)

7.1 Manual Design vs. AutoML

Traditionally, CNN architectures were hand-designed by researchers (VGG, ResNet). NAS automates this process.

The three core components of NAS:

  1. Search space: which operations/structures to explore
  2. Search strategy: reinforcement learning, evolutionary algorithms, gradient-based methods, etc.
  3. Performance estimation: quickly evaluating candidate architectures
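All three components appear even in the most naive setup. A toy random search over a tiny search space, with a made-up proxy score standing in for performance estimation (everything here — the space, the score, the names — is illustrative, not a real NAS system):

```python
import random

# 1. Search space: choice of operation, width, and depth
SPACE = {
    'op': ['conv3x3', 'conv5x5', 'sep_conv'],
    'width': [16, 32, 64],
    'depth': [2, 4, 6],
}

def sample_architecture(rng):
    """Draw one candidate from the search space."""
    return {k: rng.choice(v) for k, v in SPACE.items()}

def proxy_score(arch):
    # 3. Performance estimation: a stand-in heuristic.
    #    Real NAS trains/evaluates candidates or uses a shared supernet.
    flops_penalty = arch['width'] * arch['depth'] / 384
    bonus = 0.2 if arch['op'] == 'sep_conv' else 0.0
    return 0.1 * arch['depth'] + bonus - 0.1 * flops_penalty

# 2. Search strategy: plain random search
rng = random.Random(0)
best = max((sample_architecture(rng) for _ in range(50)), key=proxy_score)
print(best)
```

Swapping random search for reinforcement learning, evolution, or the gradient-based relaxation discussed next changes only component 2; the decomposition stays the same.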

7.2 DARTS (Differentiable Architecture Search)

DARTS (Liu et al., 2019) relaxes architecture search into a continuous optimization problem.

import torch
import torch.nn as nn
import torch.nn.functional as F


# Candidate operations
PRIMITIVES = [
    'none',           # no connection
    'skip_connect',   # identity
    'sep_conv_3x3',   # 3x3 separable conv
    'sep_conv_5x5',   # 5x5 separable conv
    'dil_conv_3x3',   # 3x3 dilated conv
    'dil_conv_5x5',   # 5x5 dilated conv
    'avg_pool_3x3',   # 3x3 average pool
    'max_pool_3x3',   # 3x3 max pool
]


class Zero(nn.Module):
    """Zero operation: severs an edge by outputting zeros."""
    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x * 0.0
        return x[:, :, ::self.stride, ::self.stride] * 0.0


class MixedOperation(nn.Module):
    """
    DARTS: represent an edge as a weighted sum of operations.
    alpha (the architecture parameters) weights each operation.
    """
    def __init__(self, C, stride, primitives):
        super().__init__()
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in primitives
        ])
        # Architecture parameters (learnable)
        self.arch_params = nn.Parameter(
            torch.randn(len(primitives)) * 1e-3
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            return Zero(stride)
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'sep_conv_3x3':
            return nn.Sequential(
                nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            # Remaining convs ('sep_conv_5x5', 'dil_conv_*'):
            # depthwise conv with padding chosen to preserve spatial size
            k = 3 if '3x3' in primitive else 5
            dilation = 2 if primitive.startswith('dil') else 1
            pad = dilation * (k - 1) // 2
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=pad,
                          dilation=dilation, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
            )

    def forward(self, x):
        # Softmax over the architecture parameters, then a weighted sum
        # of every candidate operation. All ops are built to preserve the
        # spatial size for a given stride, so the outputs can be summed.
        weights = F.softmax(self.arch_params, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))


class DARTSCell(nn.Module):
    """DARTS cell: a DAG of Mixed Operations over intermediate nodes."""
    def __init__(self, num_nodes, C, stride=1, num_inputs=1):
        super().__init__()
        self.num_nodes = num_nodes
        self.ops = nn.ModuleList()

        # One MixedOperation per edge: node i receives an edge from every
        # earlier state (the cell inputs plus all previous nodes), so it
        # needs num_inputs + i edges — this must match forward() below.
        for i in range(num_nodes):
            for j in range(num_inputs + i):
                self.ops.append(
                    MixedOperation(C, stride, PRIMITIVES)
                )

    def forward(self, *inputs):
        states = list(inputs)

        op_idx = 0
        for i in range(self.num_nodes):
            # Sum the mixed ops applied to all previous states
            node_input = sum(
                self.ops[op_idx + j](states[j])
                for j in range(len(states))
            )
            op_idx += len(states)
            states.append(node_input)

        return states[-1]

    def get_arch_params(self):
        """Return the architecture parameters of every edge."""
        return [op.arch_params for op in self.ops]


def darts_discrete(cell):
    """
    Convert the continuous architecture into a discrete one:
    on each edge, keep the operation with the highest weight.
    """
    for op in cell.ops:
        weights = F.softmax(op.arch_params, dim=0)
        best_op_idx = weights.argmax().item()
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")

7.3 EfficientNet's NAS Process

EfficientNet-B0 is a baseline architecture discovered through a neural architecture search similar to MnasNet's.

# EfficientNet's MBConv block (the core structure found by NAS)
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Convolution:
    the basic building block of EfficientNet.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # Expansion phase (1x1 conv)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])

        # Depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])

        # Squeeze-and-Excitation (a key module surfaced by NAS)
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Output phase
        layers.extend([
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Everything except the final projection (1x1 conv + BN)
        out = self.conv[:-2](x)
        # Squeeze-and-Excitation: channel-wise recalibration
        out = out * self.se(out)
        # Final projection
        out = self.conv[-2:](out)

        if self.use_residual:
            return x + out
        return out

7.4 Once-for-All Networks

Cai et al. (2020). Train one over-parameterized supernet, then extract subnetworks tailored to different resource constraints.

class OFALayer(nn.Module):
    """
    Once-for-All: an elastic layer supporting multiple kernel sizes.
    Training: sample a kernel size at random and train that subnetwork.
    Deployment: use only the chosen kernel size.
    """
    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # Train a single conv at the largest kernel size
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """현재 활성 커널 크기 설정."""
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # Slice the active_kernel-sized center out of the full kernel
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]

        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))
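The center-slicing in forward() can be sanity-checked in isolation: the central 3x3 of a 7x7 depthwise kernel, convolved with correspondingly reduced padding, still produces an output at the same spatial resolution.

```python
import torch
import torch.nn.functional as F

C = 8
full = torch.randn(C, 1, 7, 7)  # depthwise 7x7 kernel bank (one per channel)

# Slice the central 3x3 out of each 7x7 kernel
center, half = 7 // 2, 3 // 2
small = full[:, :, center - half:center + half + 1,
                   center - half:center + half + 1]

x = torch.randn(2, C, 16, 16)
y7 = F.conv2d(x, full, padding=3, groups=C)   # full kernel, padding 7//2
y3 = F.conv2d(x, small, padding=1, groups=C)  # sliced kernel, padding 3//2
print(small.shape, y7.shape, y3.shape)
# torch.Size([8, 1, 3, 3]) torch.Size([2, 8, 16, 16]) torch.Size([2, 8, 16, 16])
```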


# Progressive shrinking (the core of OFA training)
def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):
    """
    Stage 1: train with the largest kernel (7) only
    Stage 2: randomly sample kernels from {7, 5}
    Stage 3: randomly sample kernels from {7, 5, 3}
    """
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {active_kernels}")

        for epoch in range(5):
            for images, labels in dataloader:
                # Pick a kernel size at random
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

                # Set the active kernel on every elastic layer
                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

8. An Integrated Model Compression Pipeline

In real deployments, several techniques are combined:

class ModelCompressionPipeline:
    """
    Integrated pipeline: knowledge distillation → pruning → quantization.
    """

    def __init__(self, teacher, student_arch, num_classes):
        self.teacher = teacher
        self.num_classes = num_classes

        # Initialize the student model
        self.student = student_arch

    def step1_distillation(self, train_loader, val_loader,
                           epochs=30, device='cuda'):
        """1단계: 지식 증류로 기본 학습."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """2단계: 점진적 프루닝 + 파인튜닝."""
        print(f"=== Step 2: Pruning (target sparsity: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )

        optimizer = torch.optim.Adam(
            self.student.parameters(), lr=1e-4
        )
        self.student.to(device)
        step = 0

        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pruner.step(step)
                pruner.apply_masks()  # keep the pruning masks applied
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """3단계: Post-Training Quantization."""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)

        # Prepare static quantization
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)

        # Determine scale/zero_point from calibration data
        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))

        # Apply quantization
        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Teacher, Student, Pruned, Quantized 모델 성능 비교."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = sum(p.numel() for p in self.student.parameters())

        print(f"\n{'='*50}")
        print(f"Teacher  | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student  | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"압축률: {t_params/s_params:.1f}x | 성능 유지율: {student_acc/teacher_acc:.1%}")

Conclusion

Knowledge distillation and model compression are the key technologies that make AI models usable in the real world.

Key takeaways:

  1. Knowledge distillation: transfer inter-class relationship information through the teacher's soft targets (probability distributions)
  2. Feature distillation: transfer intermediate-layer representations to the student
  3. Relation distillation: preserve the relational structure between samples
  4. LLM distillation: practical large-model compression, e.g. DistilBERT and TinyLLM
  5. Structured pruning: remove whole filters/heads/layers for real speedups
  6. Unstructured pruning: gradual sparsification minimizes accuracy loss
  7. Weight sharing: reuse the same parameters across layers, as in ALBERT
  8. NAS: automated architecture search such as DARTS and OFA

In practice, a pipeline that applies distillation → pruning → quantization in sequence is the most effective deployment strategy.



Knowledge Distillation Complete Guide: Model Compression and Lightweight Techniques

Introduction

As deep learning models grow more powerful, they also grow larger. Models like GPT-4 or Llama 3 405B require hundreds of gigabytes of memory, making them impossible to run on mobile devices or edge hardware. Knowledge Distillation and model compression techniques are the essential tools for making these models smaller and faster while preserving as much of their accuracy as possible.

This guide covers:

  • Theoretical foundations of knowledge distillation with complete PyTorch implementations
  • Multiple distillation paradigms: Response-based, Feature-based, Relation-based
  • LLM distillation case studies (DistilBERT, TinyLLM, Distil Whisper)
  • Structured and unstructured pruning
  • Weight sharing and Neural Architecture Search (NAS)

1. Knowledge Distillation Fundamentals

1.1 The Teacher-Student Framework

Knowledge distillation was introduced by Hinton, Vinyals, and Dean in 2015. The core idea: transfer the "knowledge" of a large Teacher model into a small Student model.

Teacher Model (large, accurate)
      │ soft targets (probability distributions)
Student Model (small, fast)

Rather than training with hard labels (one-hot vectors), the student learns from the Teacher's soft targets — the full softmax output distribution.

For example, classifying a cat image:

  • Hard target: [0, 0, 1, 0, 0] (only the correct class is 1)
  • Teacher soft target: [0.01, 0.05, 0.85, 0.07, 0.02]

The soft target encodes the information that "this image is a cat, but it has some resemblance to a tiger." This inter-class similarity information provides rich supervision for the student.
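The extra supervision in a soft target can be quantified as entropy: the one-hot label carries none, while the teacher distribution above carries roughly 0.6 nats. A quick check:

```python
import torch


def entropy(p, eps=1e-12):
    """Shannon entropy in nats."""
    p = p.clamp_min(eps)
    return -(p * p.log()).sum().item()


hard = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
soft = torch.tensor([0.01, 0.05, 0.85, 0.07, 0.02])

print(f"hard: {entropy(hard):.3f} nats")  # 0.000
print(f"soft: {entropy(soft):.3f} nats")  # 0.598
```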

1.2 The Temperature Parameter

When softmax outputs are too confident ([0.001, 0.002, 0.997, ...]), they carry almost no more information than hard labels. The temperature parameter T makes distributions softer:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))

  • When T is 1: standard softmax
  • When T is greater than 1: flatter, more uniform distribution (more information)
  • When T is less than 1: sharper, more peaked distribution

import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """Temperature-scaled softmax."""
    return F.softmax(logits / temperature, dim=-1)

# Example
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# [0.596, 0.219, 0.090, 0.096]
print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# [0.345, 0.262, 0.195, 0.216] — more uniform, more information

1.3 Hinton's KD Loss Function

The KD loss is a weighted sum of two terms:

KL Divergence term (soft target matching):

L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))

The T^2 scaling keeps gradient magnitudes independent of the temperature value.
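This can be verified numerically with made-up logits (a standalone check, not part of the training code): without the T^2 factor the soft-target gradient shrinks roughly as 1/T^2, and with it the magnitude stays on the same scale across temperatures.

```python
import torch
import torch.nn.functional as F


def kd_grad_norm(T, scale=1.0):
    """Gradient norm of the soft-target KD term w.r.t. student logits."""
    s = torch.tensor([2.0, 1.0, 0.1], requires_grad=True)  # student logits
    t = torch.tensor([2.5, 0.5, 0.2])                      # teacher logits
    loss = F.kl_div(F.log_softmax(s / T, dim=-1),
                    F.softmax(t / T, dim=-1),
                    reduction='batchmean') * scale
    loss.backward()
    return s.grad.norm().item()


for T in [1.0, 2.0, 4.0]:
    print(f"T={T}: raw {kd_grad_norm(T):.4f}, "
          f"with T^2 {kd_grad_norm(T, T ** 2):.4f}")
```

The raw gradient at T=4 is more than an order of magnitude smaller than at T=1, while the T^2-scaled gradients stay comparable.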

Cross-Entropy term (hard target learning):

L_CE = CrossEntropy(student_logits, true_labels)

Total loss:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 Complete PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton et al. (2015) Knowledge Distillation Loss.
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # Hard target loss
        loss_ce = self.ce_loss(student_logits, labels)

        # Soft target loss (KL Divergence)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """Teacher-Student distillation training loop."""
    teacher = teacher.to(device).eval()  # Teacher is frozen
    student = student.to(device)

    # Freeze teacher parameters
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher inference (no gradient needed)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # Student inference
            student_logits = student(images)

            # KD loss computation
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# Practical example
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# Parameter count comparison
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")   # ~25.6M
print(f"Student: {s_params:,} params")   # ~11.2M
print(f"Compression ratio: {t_params/s_params:.1f}x")

2. Distillation Paradigms

2.1 Response-based Distillation (Logit Matching)

The most basic approach — the student mimics the Teacher's final outputs (logits). Hinton's original KD is the canonical example.

class ResponseBasedDistillation(nn.Module):
    """Response-based distillation using only final outputs."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 Feature-based Distillation (Intermediate Layers)

Proposed in FitNets (Romero et al., 2015). The student learns to match the Teacher's intermediate feature maps, not just the final outputs.

Since the teacher and student may have different channel counts, a Regressor network projects the student features to match the teacher's dimensions.

class FeatureDistillationLoss(nn.Module):
    """Distillation via intermediate layer feature matching."""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # Project student features to teacher feature space
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        projected = self.regressor(student_feat)
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    Use forward hooks to extract intermediate layer features
    without modifying the model's forward() method.
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# Usage example
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)

# Teacher layer3: 1024 channels; Student layer3: 256 channels
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)

# During training
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 Relation-based Distillation

RKD (Park et al., 2019). Rather than matching absolute values, the student learns to mimic the relational structure between samples in the teacher's embedding space.

class RelationalKnowledgeDistillation(nn.Module):
    """
    Relational KD: preserve pairwise distance and angle relationships
    between samples in the embedding space.
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """Pairwise distance matrix computation."""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """Preserve pairwise distance relationships."""
        t_d = self.pdist(teacher_emb)
        t_d = t_d / (t_d.mean() + 1e-12)  # normalize
        s_d = self.pdist(student_emb)
        s_d = s_d / (s_d.mean() + 1e-12)
        return F.smooth_l1_loss(s_d, t_d.detach())

    def angle_loss(self, teacher_emb, student_emb):
        """Preserve angular relationships between triplets."""
        td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)  # (N, N, D)
        sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)

        td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
        sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)

        t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
        s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)

        return F.smooth_l1_loss(s_angle, t_angle.detach())

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss

2.4 Attention Transfer

Zagoruyko and Komodakis (2017). Transfer attention maps — spatial activation patterns — from teacher to student.

class AttentionTransfer(nn.Module):
    """
    Attention Transfer: transfer spatial attention patterns
    derived from feature map activations.
    """

    def attention_map(self, feat):
        """
        Compute spatial attention map from feature maps.
        Sums squared activations across channels and normalizes.
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: lists of feature maps at multiple layers.
        Returns summed AT loss across all pairs.
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss

3. LLM Distillation

3.1 DistilBERT

DistilBERT (Sanh et al., 2019) compresses BERT-base (110M parameters) to 66M parameters.

Key techniques:

  • Half the layers: 12 → 6
  • Remove token-type embeddings
  • Remove the pooler
  • Triple loss: MLM + Distillation + Cosine embedding loss

from transformers import (
    DistilBertModel,
    DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilBERTLoss(nn.Module):
    """
    DistilBERT triple loss:
    1. MLM CE loss (language modeling)
    2. Soft-target KD loss (teacher logit matching)
    3. Cosine embedding loss (hidden state similarity)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM weight
        self.beta = beta    # Cosine loss weight

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        # 1. MLM loss
        loss_mlm = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            mlm_labels.view(-1),
            ignore_index=-100
        )

        # 2. Soft KD loss
        s_soft = F.log_softmax(student_logits / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. Cosine embedding loss (hidden state alignment)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# Inference example
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
# torch.Size([1, seq_len, 768]); seq_len depends on the tokenization

DistilBERT performance:

  • 40% smaller than BERT-base
  • 60% faster inference
  • Retains 97% of BERT's GLUE performance

3.2 LLM Distillation Pattern

For large language models, the student learns from the teacher's next-token prediction distributions.

class LLMDistillationTrainer:
    """
    LLM distillation trainer.
    Transfers teacher's token distribution to student.
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher inference (no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # Student inference
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE loss (hard target)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD loss (soft target)
        # Compute only on valid tokens
        mask = (shift_labels != -100).float()

        s_log_soft = F.log_softmax(
            shift_logits / self.T, dim=-1
        )
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # Per-token KL divergence
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len-1)

        loss_kd = (kl_per_token * mask).sum() / mask.sum()
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

3.3 Distil Whisper

OpenAI's Whisper speech recognition model has a distilled variant that demonstrates aggressive LLM compression.

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration
)

# Distil-Whisper: distills Whisper-large-v2
# large-v2: 1550M params -> distil-large-v2: 756M params (about 6x faster, similar WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# Only 2 decoder layers retained (original: 32)
# Encoder kept intact
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. Structured Pruning

Pruning removes unimportant parameters or structures from a model.

Structured pruning: remove entire filters, heads, or layers → real speedup on all hardware.
Unstructured pruning: set individual weights to zero → sparse matrices (needs specialized hardware/kernels for speedup).
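The distinction is visible on a single weight matrix (a standalone sketch): unstructured pruning leaves the dense shape intact, while structured pruning genuinely shrinks the tensor.

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 128)  # e.g. a Linear layer's weight

# Unstructured: zero the 50% smallest-magnitude entries; shape is unchanged
thresh = w.abs().flatten().kthvalue(w.numel() // 2).values
unstructured = torch.where(w.abs() > thresh, w, torch.zeros_like(w))
print(unstructured.shape, (unstructured == 0).float().mean().item())
# torch.Size([64, 128]) 0.5

# Structured: keep only the 32 output rows with the largest L1 norm
keep = w.abs().sum(dim=1).topk(32).indices.sort().values
structured = w[keep]
print(structured.shape)
# torch.Size([32, 128])
```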

4.1 Filter Pruning

Remove filters with small L1/L2 norms from CNN layers.

import torch
import torch.nn as nn
import numpy as np


def get_filter_importance(conv_layer, norm='l1'):
    """Compute the importance (norm) of each output filter."""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        # Placeholder: true geometric-median importance (FPGM) is more involved;
        # this reduces to the L2 norm of the flattened filter
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)
    else:
        raise ValueError(f"unknown norm: {norm}")


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    Filter pruning: remove filters with lowest importance.
    Returns new Conv2d with reduced filter count.
    """
    importance = get_filter_importance(conv)
    n_keep = int(conv.out_channels * keep_ratio)
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices]
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices]

    return new_conv, indices


class StructuredPruner:
    """Structured pruning for ResNet-style models."""

    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def global_threshold_prune(self, sparsity=0.5):
        """
        Prune globally: remove the sparsity fraction of filters
        with lowest importance across all layers.
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                all_importances.append(imp)

        all_imp = torch.cat(all_importances)
        k = max(1, int(len(all_imp) * sparsity))  # kthvalue requires k >= 1
        threshold = all_imp.kthvalue(k).values.item()

        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp >= threshold

        return masks

4.2 Attention Head Pruning

Remove unimportant attention heads from transformer models.

class AttentionHeadPruner:
    """
    Multi-Head Attention head pruning.
    Michel et al. (2019): Are Sixteen Heads Really Better than One?
    """

    def compute_head_importance(self, model, dataloader, device='cuda'):
        """Compute per-layer, per-head importance scores."""
        head_importance = {}

        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            if hasattr(outputs, 'attentions') and outputs.attentions:
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    # Importance: variance of attention weights (more focused = more important)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp

        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        Remove specified heads.
        heads_to_prune: dict mapping layer_idx to list of head indices.
        """
        model.prune_heads(heads_to_prune)  # Built-in HuggingFace method
        return model


# HuggingFace example
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Remove heads 0, 3, 5 from layer 0 and heads 1, 2 from layer 6
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")

4.3 Layer Pruning

Remove entire transformer layers that contribute least to model performance.

class LayerPruner:
    """Transformer layer pruning."""

    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        Gradient-based layer importance.
        Importance = mean |gradient * weight| across the layer.
        """
        model.train()
        layer_importance = {}

        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            # Assumes a BERT-style backbone exposing `encoder.layer`
            # (for BertForSequenceClassification this is `model.bert.encoder.layer`)
            for i, layer in enumerate(model.encoder.layer):
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()

                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp

            model.zero_grad()

        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """Remove specified layers and rebuild the encoder."""
        import copy
        new_model = copy.deepcopy(model)

        remaining = [
            layer for i, layer in enumerate(new_model.encoder.layer)
            if i not in layers_to_drop
        ]
        new_model.encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)

        return new_model

4.4 PyTorch torch.nn.utils.prune

import torch.nn as nn
import torch.nn.utils.prune as prune
from torchvision import models

model = models.resnet18(weights='DEFAULT')

# Apply unstructured L1 pruning to a specific layer
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# Sets 30% of weights to zero using a mask

print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# Structured pruning (filter-level): zero whole filters by L2 norm along dim 0.
# Applied to the same parameter, this stacks on top of the existing L1 mask.
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# Make pruning permanent (remove mask, bake zeros into weights)
prune.remove(conv, 'weight')

# Apply to all Conv2d layers globally
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of all conv weights
)

# Check global sparsity
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")

5. Unstructured Pruning

5.1 Magnitude-based Pruning

The simplest method: set weights with the smallest absolute values to zero.

class MagnitudePruner:
    """Magnitude-based unstructured pruning."""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """Compute the global threshold across all weight tensors."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        all_weights = torch.cat(all_weights)
        k = max(1, int(len(all_weights) * self.sparsity))  # guard against k=0
        threshold = all_weights.kthvalue(k).values.item()
        return threshold

    def apply_pruning(self):
        """Create pruning masks using the global threshold."""
        threshold = self.compute_global_threshold()

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() >= threshold).float()
                self.masks[name] = mask
                param.data *= mask

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"Actual sparsity: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """Re-apply masks after each gradient step to prevent dead weights from reviving."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]
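
The global-threshold rule above can be sanity-checked with a tiny pure-Python example (hypothetical toy weights, no torch). It mirrors the kthvalue logic, but uses a strict comparison (as `GradualMagnitudePruner` does below) so that exactly the k smallest-magnitude weights are zeroed:

```python
# Toy check of the global magnitude threshold used by MagnitudePruner
# (pure Python, hypothetical toy weights).
weights = [0.9, -0.05, 0.4, -0.7, 0.01, 0.3, -0.2, 0.08]
sparsity = 0.5

# Threshold = k-th smallest absolute value, k = sparsity * n
magnitudes = sorted(abs(w) for w in weights)
k = max(1, int(len(magnitudes) * sparsity))
threshold = magnitudes[k - 1]

# Strict comparison: exactly the k smallest-magnitude weights are zeroed
mask = [1 if abs(w) > threshold else 0 for w in weights]
pruned = [w * m for w, m in zip(weights, mask)]

print(f"threshold={threshold}, sparsity={mask.count(0) / len(mask):.0%}")
# threshold=0.2, sparsity=50%
```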

5.2 Gradual Magnitude Pruning

Pruning too aggressively from the start causes a sharp accuracy drop. A gradual schedule that slowly increases sparsity during training works much better.

class GradualMagnitudePruner:
    """
    Gradually increase sparsity during training.
    Zhu and Gupta (2017): To Prune, or Not to Prune.
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """Compute target sparsity at current step using cubic schedule."""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity

        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """Update pruning masks at appropriate training steps."""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return

        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    def _update_masks(self, sparsity):
        """Update masks to match target sparsity."""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            n_prune = int(sparsity * param.data.numel())
            if n_prune == 0:
                continue

            threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.data.abs() > threshold).float()
            param.data *= self.masks[name]
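
To see what the cubic schedule actually does, it helps to evaluate it at a few steps. The sketch below re-implements the same formula as `compute_sparsity` standalone, with illustrative hyperparameters (0% to 80% over 1000 steps):

```python
# Standalone check of the Zhu-Gupta cubic sparsity schedule
# (same formula as compute_sparsity above; hyperparameters are illustrative).
def cubic_sparsity(step, s0=0.0, sf=0.8, begin=0, end=1000):
    if step < begin:
        return s0
    if step > end:
        return sf
    pct = (step - begin) / (end - begin)
    return sf + (s0 - sf) * (1 - pct) ** 3

for step in [0, 250, 500, 750, 1000]:
    print(f"step {step:>4}: target sparsity = {cubic_sparsity(step):.3f}")
```

Sparsity rises quickly early on (when many weights are redundant) and flattens near the end, giving the network time to recover accuracy before each new round of pruning.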

6. Weight Sharing

6.1 Cross-Layer Parameter Sharing — ALBERT

ALBERT (Lan et al., 2019) drastically reduces BERT's parameters by reusing the same parameters across all transformer layers.

class ALBERTEncoder(nn.Module):
    """
    ALBERT-style weight sharing.
    A single transformer layer is executed N times.
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # Only ONE transformer layer is defined
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # Execute the same layer num_layers times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)


# Parameter comparison
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = sum(p.numel() for p in albert_encoder.parameters())

print(f"BERT encoder: {bert_params:,}")      # ~85M
print(f"ALBERT encoder: {albert_params:,}")  # ~7M
print(f"Compression: {bert_params/albert_params:.1f}x")

6.2 Factorized Embeddings (ALBERT)

ALBERT also factorizes the embedding matrix: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size), where embedding_size is much smaller than hidden_size.

class FactorizedEmbedding(nn.Module):
    """
    ALBERT-style factorized embedding.
    vocab_size x H factored into (vocab_size x E) + (E x H)
    where E << H.
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)        # (B, seq, E)
        return self.embedding_projection(embed)         # (B, seq, H)

# Parameter comparison (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)                  # 23.04M params
factorized_embed = FactorizedEmbedding(30000, 128, 768)    # 3.94M params

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
print(f"Savings: {(s_params - f_params)/s_params:.1%}")
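
The same comparison works as closed-form arithmetic, which also shows when factorization pays off (pure Python, no torch):

```python
# Closed-form parameter counts for the factorization above.
V, H, E = 30_000, 768, 128

standard = V * H              # full embedding matrix
factorized = V * E + E * H    # two smaller matrices

print(f"standard:   {standard:,}")     # 23,040,000
print(f"factorized: {factorized:,}")   # 3,938,304

# Factorization saves parameters whenever V*E + E*H < V*H,
# i.e. E < V*H / (V + H) — here any E below ~749.
breakeven = V * H / (V + H)
print(f"breakeven E = {breakeven:.0f}")
```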

7. Neural Architecture Search (NAS)

7.1 Manual Design vs AutoML

Traditional CNNs were designed by researchers by hand (VGG, ResNet). NAS automates this process.

Three core components of NAS:

  1. Search Space: What operations and structures to explore
  2. Search Strategy: Reinforcement learning, evolutionary algorithms, gradient-based methods
  3. Performance Estimation: Quickly evaluate candidate architectures
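
The interaction of the three components can be sketched with a toy random-search loop (pure Python). The search space, the proxy "performance estimator", and its weights below are all made up for illustration; a real NAS run would train or partially train each candidate instead:

```python
import random

# Minimal random-search NAS sketch. The "estimator" is a stand-in proxy,
# not real training: it rewards depth/width and penalizes parameter count.
SEARCH_SPACE = {                       # 1. search space
    'depth': [2, 4, 6, 8],
    'width': [32, 64, 128],
    'kernel': [3, 5, 7],
}

def estimate(arch):                    # 3. performance estimation (fake proxy)
    params = arch['depth'] * arch['width'] ** 2 * arch['kernel'] ** 2
    score = arch['depth'] * 0.5 + arch['width'] * 0.01
    return score - params * 1e-6

random.seed(0)
best, best_score = None, float('-inf')
for _ in range(50):                    # 2. search strategy: random sampling
    arch = {k: random.choice(v) for k, v in SEARCH_SPACE.items()}
    s = estimate(arch)
    if s > best_score:
        best, best_score = arch, s

print(f"best architecture: {best} (score {best_score:.3f})")
```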

7.2 DARTS: Differentiable Architecture Search

DARTS (Liu et al., 2019) transforms architecture search into a continuous optimization problem: every candidate operation on an edge is kept, weighted by learnable architecture parameters, so the search itself can be done with plain gradient descent.

import torch
import torch.nn as nn
import torch.nn.functional as F


# Candidate operations
PRIMITIVES = [
    'none',           # Zero connection
    'skip_connect',   # Identity
    'sep_conv_3x3',   # 3x3 Separable Conv
    'sep_conv_5x5',   # 5x5 Separable Conv
    'dil_conv_3x3',   # 3x3 Dilated Conv
    'dil_conv_5x5',   # 5x5 Dilated Conv
    'avg_pool_3x3',   # 3x3 Average Pool
    'max_pool_3x3',   # 3x3 Max Pool
]


class MixedOperation(nn.Module):
    """
    DARTS: represent each edge as a weighted mixture of operations.
    Architecture parameters (alpha) control each operation's weight.
    """
    def __init__(self, C, stride, primitives):
        super().__init__()
        self.primitives = list(primitives)
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in self.primitives
        ])
        # Learnable architecture parameters (alpha)
        self.arch_params = nn.Parameter(
            torch.randn(len(self.primitives)) * 1e-3
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            # Zero connection: handled in forward (contributes nothing)
            return nn.Identity()
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
            )
        elif primitive in ('sep_conv_3x3', 'sep_conv_5x5'):
            k = 3 if primitive == 'sep_conv_3x3' else 5
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k // 2,
                          groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive in ('dil_conv_3x3', 'dil_conv_5x5'):
            k = 3 if primitive == 'dil_conv_3x3' else 5
            return nn.Sequential(
                nn.Conv2d(C, C, k, stride=stride, padding=k - 1,
                          dilation=2, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError(f"Unknown primitive: {primitive}")

    def forward(self, x):
        weights = F.softmax(self.arch_params, dim=0)
        output = 0.0
        # Every op preserves the output shape for a given stride,
        # so the weighted outputs can be summed directly.
        for w, prim, op in zip(weights, self.primitives, self.ops):
            if prim == 'none':
                continue  # zero connection: contributes nothing
            output = output + w * op(x)
        return output


def darts_discrete(cell):
    """
    Convert the continuous architecture to a discrete one by
    selecting the highest-weight operation on each edge.
    `cell.ops` is assumed to be an iterable of MixedOperation edges.
    """
    for op in cell.ops:
        weights = F.softmax(op.arch_params, dim=0)
        best_op_idx = weights.argmax().item()
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")

7.3 EfficientNet's NAS Process

EfficientNet-B0 is the base architecture found by an MnasNet-style NAS. The search optimized for accuracy under a mobile latency constraint. The result is a network built from MBConv blocks with squeeze-and-excitation.

class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Convolution.
    EfficientNet's core building block, found by NAS.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # Expansion phase (1x1 conv)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])

        # Depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])
        self.features = nn.Sequential(*layers)

        # Squeeze-and-Excitation (channel attention from the NAS search space)
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Output projection (1x1 conv, no activation)
        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.features(x)
        out = out * self.se(out)  # rescale channels by SE attention
        out = self.project(out)
        if self.use_residual:
            return x + out
        return out

7.4 Once-for-All Network

Cai et al. (2020). Train one large network and then extract sub-networks tailored to different resource constraints — without retraining.

class OFALayer(nn.Module):
    """
    Once-for-All elastic layer that supports multiple kernel sizes.
    Training: randomly sample kernel size -> train sub-networks
    Deployment: use only the required kernel size
    """
    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # Only one conv is stored — the largest kernel size
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """Set the active kernel size for this layer."""
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # Extract the center sub-kernel
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]

        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))
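
The center-slice logic in `forward` can be checked on a toy 7x7 "kernel" of index pairs (pure Python, no torch; the grid is just illustrative data):

```python
# Sanity check of the center sub-kernel slicing used by OFALayer.
max_k, active_k = 7, 3
kernel = [[(r, c) for c in range(max_k)] for r in range(max_k)]

center = max_k // 2          # 3
half = active_k // 2         # 1
sub = [row[center - half: center + half + 1]
       for row in kernel[center - half: center + half + 1]]

print(sub[0][0], sub[-1][-1])  # (2, 2) (4, 4): the central 3x3 window
```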


def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):
    """
    Stage 1: train with max kernel (7)
    Stage 2: randomly sample from [7, 5]
    Stage 3: randomly sample from [7, 5, 3]
    """
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {active_kernels}")

        for epoch in range(5):
            for images, labels in dataloader:
                # Randomly sample kernel size each batch
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

8. Integrated Model Compression Pipeline

In practice, multiple techniques are combined sequentially:

class ModelCompressionPipeline:
    """
    Integrated pipeline: Distillation -> Pruning -> Quantization.
    """

    def __init__(self, teacher, student, num_classes):
        self.teacher = teacher
        self.student = student
        self.num_classes = num_classes

    def step1_distillation(self, train_loader,
                           epochs=30, device='cuda'):
        """Step 1: Knowledge distillation."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """Step 2: Gradual pruning + fine-tuning."""
        print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )

        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
        self.student.to(device)
        step = 0

        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pruner.step(step)
                pruner.apply_masks()
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: Post-Training Static Quantization."""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)

        # Eager-mode static quantization assumes the model wraps its
        # forward pass with QuantStub/DeQuantStub
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)

        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))

        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Compare Teacher vs compressed Student."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        # The INT8-quantized student runs on CPU (fbgemm backend)
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = sum(p.numel() for p in self.student.parameters())

        print(f"\n{'='*55}")
        print(f"Teacher  | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student  | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"Compression: {t_params/s_params:.1f}x | "
              f"Retained accuracy: {student_acc/teacher_acc:.1%}")

Summary

Knowledge distillation and model compression are the essential bridge between research-scale models and real-world deployment.

Key takeaways:

  1. Knowledge Distillation: Transfer inter-class relationship information via teacher's soft targets (probability distributions)
  2. Feature Distillation: Transfer intermediate representations, not just final outputs
  3. Relation Distillation: Preserve the structural relationships between samples
  4. LLM Distillation: DistilBERT, TinyLLaMA, and Distil Whisper demonstrate practical large-model compression
  5. Structured Pruning: Remove filters, heads, and layers for real hardware speedup
  6. Unstructured Pruning: Gradual sparsity schedules minimize accuracy loss
  7. Weight Sharing: ALBERT-style parameter reuse achieves up to 12x compression
  8. NAS: DARTS and Once-for-All automate architecture design for target hardware

The best practical approach is a sequential pipeline: distillation -> pruning -> quantization. Each step builds on the previous, allowing each compression technique to work on an already-compressed model.


References

  • Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
  • Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
  • Zagoruyko and Komodakis (2017). Paying More Attention to Attention.
  • Park et al. (2019). Relational Knowledge Distillation.
  • Sanh et al. (2019). DistilBERT, a distilled version of BERT. https://arxiv.org/abs/1910.01108
  • Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.
  • Liu et al. (2019). DARTS: Differentiable Architecture Search.
  • Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
  • Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.
  • Michel et al. (2019). Are Sixteen Heads Really Better than One? https://arxiv.org/abs/1905.10650
  • Tan and Le (2019). EfficientNet. https://arxiv.org/abs/1905.11946
  • PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html