Skip to content
Published on

지식 증류(Knowledge Distillation) 완전 가이드: 모델 경량화와 압축 기법

Authors

소개

딥러닝 모델이 강력해질수록 크기도 커집니다. GPT-4, Llama 3 405B 같은 대형 모델은 수백 GB의 메모리를 필요로 하며, 모바일 기기나 엣지 디바이스에서는 실행이 불가능합니다. **지식 증류(Knowledge Distillation)**와 모델 압축 기법은 큰 모델의 성능을 최대한 유지하면서 모델을 작고 빠르게 만드는 핵심 기술입니다.

이 가이드에서 다루는 내용:

  • 지식 증류의 이론적 배경과 완전한 PyTorch 구현
  • 다양한 증류 방법 (Response-based, Feature-based, Relation-based)
  • LLM 증류 사례 (DistilBERT, TinyLLM)
  • 구조적/비구조적 프루닝
  • 가중치 공유와 NAS

1. 지식 증류 기초

1.1 Teacher-Student 프레임워크

지식 증류는 2015년 Hinton, Vinyals, Dean이 제안한 기법으로, 큰 Teacher 모델의 "지식"을 작은 Student 모델로 전달하는 것이 핵심입니다.

Teacher Model (크고 정확)
      │ 소프트 타겟 (확률 분포) 전달
Student Model (작고 빠름)

단순히 정답 레이블(하드 타겟)로 학습하는 것과 달리, Teacher의 **소프트 타겟(softmax 출력 확률 분포)**을 사용합니다.

예를 들어, 고양이 이미지 분류:

  • 하드 타겟: [0, 0, 1, 0, 0] (정답만 1)
  • Teacher 소프트 타겟: [0.01, 0.05, 0.85, 0.07, 0.02]

소프트 타겟에는 "이 이미지는 고양이지만 호랑이와 약간 비슷하다"는 정보가 담겨 있습니다. 이런 클래스 간 유사성 정보가 Student 학습에 큰 도움이 됩니다.

1.2 온도(Temperature) 파라미터

소프트맥스 출력이 너무 confident하면 [0.001, 0.002, 0.997, ...]처럼 거의 하드 타겟과 같아져서 정보가 적습니다. 온도 파라미터 T로 분포를 부드럽게 만듭니다:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
  • T가 1일 때: 일반 소프트맥스
  • T가 1보다 클 때: 분포가 더 균일해짐 (더 많은 정보 전달)
  • T가 1보다 작을 때: 분포가 더 날카로워짐
import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """온도로 조정된 소프트맥스."""
    return F.softmax(logits / temperature, dim=-1)

# 예시
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, T=1).numpy().round(3))
# [0.596, 0.219, 0.090, 0.096]
print("T=4:", temperature_softmax(logits, T=4).numpy().round(3))
# [0.345, 0.262, 0.195, 0.216] — 더 균일

1.3 Hinton의 KD 손실 함수

KD 손실은 두 항의 가중 합입니다:

KL Divergence 항 (소프트 타겟 매칭):

L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))

T^2 스케일링: 그래디언트 크기를 T에 무관하게 유지하기 위함.

Cross-Entropy 항 (하드 타겟 학습):

L_CE = CrossEntropy(student_logits, true_labels)

총 손실:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 완전한 PyTorch 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets


class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton et al. (2015)의 지식 증류 손실.
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # 하드 타겟 손실
        loss_ce = self.ce_loss(student_logits, labels)

        # 소프트 타겟 손실 (KL Divergence)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """Teacher-Student 증류 학습 루프."""
    teacher = teacher.to(device).eval()  # Teacher는 frozen
    student = student.to(device)

    # Teacher 파라미터 동결
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Teacher 추론 (그래디언트 불필요)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # Student 추론
            student_logits = student(images)

            # KD 손실 계산
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# 실제 사용 예시
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# 파라미터 수 비교
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")   # ~25.6M
print(f"Student: {s_params:,} params")   # ~11.2M
print(f"압축률: {t_params/s_params:.1f}x")

2. 다양한 증류 방법

2.1 Response-based 증류 (로짓 매칭)

가장 기본적인 방법으로, 앞서 설명한 Hinton KD가 대표적입니다. Student가 Teacher의 **최종 출력(로짓)**을 모방합니다.

class ResponseBasedDistillation(nn.Module):
    """최종 출력만 사용하는 Response-based 증류."""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 Feature-based 증류 (중간 레이어)

FitNets (Romero et al., 2015)에서 제안. Teacher의 중간 레이어 feature map을 Student가 따라하도록 학습합니다.

Teacher와 Student의 채널 수가 다를 수 있으므로 Regressor 네트워크로 차원을 맞춥니다.

class FeatureDistillationLoss(nn.Module):
    """중간 레이어 feature를 매칭하는 증류."""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # Student feature를 Teacher feature 공간으로 projection
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        # Student feature를 Teacher 차원으로 변환
        projected = self.regressor(student_feat)
        # MSE 손실로 feature 맞추기
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    forward hook으로 중간 레이어 feature 추출.
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# 사용 예시
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)

# Teacher: layer3 출력 (1024채널), Student: layer3 출력 (256채널)
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)

# 학습 시
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 Relation-based 증류 (상관관계 학습)

RKD (Park et al., 2019). 샘플들 사이의 관계를 Student가 모방하도록 합니다. 절대적인 값이 아니라 상대적인 구조를 학습합니다.

class RelationalKnowledgeDistillation(nn.Module):
    """
    Relational KD: 샘플 쌍의 거리 관계를 보존.
    Teacher의 embedding 공간 구조를 Student가 학습.
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """Pairwise distance computation."""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """거리 관계 보존 손실."""
        t_d = self.pdist(teacher_emb)
        # 정규화
        t_d = t_d / (t_d.mean() + 1e-12)
        s_d = self.pdist(student_emb)
        s_d = s_d / (s_d.mean() + 1e-12)
        return F.smooth_l1_loss(s_d, t_d.detach())

    def angle_loss(self, teacher_emb, student_emb):
        """각도 관계 보존 손실."""
        # e_i - e_j 벡터 간 각도 관계
        td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)  # (N, N, D)
        sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)

        # cosine similarity
        td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
        sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)

        t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
        s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)

        return F.smooth_l1_loss(s_angle, t_angle.detach())

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss

2.4 Attention Transfer

Zagoruyko & Komodakis (2017). Attention map (채널별 activation 합)을 Teacher에서 Student로 전달합니다.

class AttentionTransfer(nn.Module):
    """Attention Transfer: activation map의 공간적 패턴 전달."""

    def attention_map(self, feat):
        """
        Feature map에서 attention map 계산.
        각 공간 위치에서 채널 방향으로 제곱합의 제곱근.
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: list of feature maps
        각 쌍에 대해 AT 손실 계산.
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss

3. LLM 증류

3.1 DistilBERT

DistilBERT (Sanh et al., 2019)는 BERT-base (110M 파라미터)를 66M으로 압축한 모델입니다.

주요 기법:

  • 레이어 수 절반: 12개 → 6개
  • Token-type embedding 제거
  • Pooler 제거
  • Triple 손실: MLM + Distillation + Cosine embedding
from transformers import (
    DistilBertModel, BertModel,
    DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilBERTLoss(nn.Module):
    """
    DistilBERT 3중 손실:
    1. MLM CE 손실 (언어 모델링)
    2. Soft-target KD 손실 (Teacher logit 매칭)
    3. Cosine embedding 손실 (hidden state 유사도)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM 가중치
        self.beta = beta    # Cosine 손실 가중치

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        # 1. MLM 손실
        loss_mlm = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            mlm_labels.view(-1),
            ignore_index=-100
        )

        # 2. Soft KD 손실
        s_soft = F.log_softmax(student_logits / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. Cosine embedding 손실 (hidden state 유사도)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# 추론 예시
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])

DistilBERT 성능:

  • BERT-base 대비 40% 크기 감소
  • 60% 빠른 추론 속도
  • GLUE 벤치마크 97% 성능 유지

3.2 TinyLLaMA 스타일 증류

LLM 증류에서는 Teacher의 다음 토큰 예측 분포를 Student가 학습합니다.

class LLMDistillationTrainer:
    """
    LLM 증류 학습기.
    Teacher의 로짓 분포를 Student에 전달.
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher 추론 (no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # Student 추론
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE 손실 (하드 타겟)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD 손실 (소프트 타겟)
        # 유효한 토큰에 대해서만 계산
        mask = (shift_labels != -100).float()

        s_log_soft = F.log_softmax(
            shift_logits / self.T, dim=-1
        )
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # 토큰별 KL divergence
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len)

        loss_kd = (kl_per_token * mask).sum() / mask.sum()
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

3.3 Distil Whisper

OpenAI Whisper 음성 인식 모델의 증류판입니다.

from transformers import (
    WhisperProcessor, WhisperForConditionalGeneration
)

# Distil-Whisper: Whisper-large-v2의 증류
# large-v2: 1550M params → distil-large-v2: 756M params (2x faster, same WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# 6개 디코더 레이어만 유지 (원본 32개 → 2개)
# 인코더는 동일하게 유지
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. 구조적 프루닝 (Structured Pruning)

프루닝은 모델에서 중요하지 않은 파라미터나 구조를 제거하는 기법입니다.

구조적 프루닝: 필터, 헤드, 레이어 전체 단위로 제거 → 실제 속도 향상 비구조적 프루닝: 개별 가중치를 0으로 만들기 → 희소 행렬 (특수 하드웨어 필요)

4.1 필터 프루닝

CNN에서 L1/L2 norm이 작은 필터를 제거합니다.

import torch
import torch.nn as nn
import numpy as np


def get_filter_importance(conv_layer, norm='l1'):
    """각 필터의 중요도(norm) 계산."""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        # Geometric Median 기반 (더 robust)
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    필터 프루닝: 중요도 낮은 필터 제거.
    Returns: 새로운 Conv2d 레이어 (필터 수 감소)
    """
    importance = get_filter_importance(conv)
    n_keep = int(conv.out_channels * keep_ratio)
    # 중요도 상위 n_keep개 필터 선택
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    # 새 conv 레이어 생성
    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices]
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices]

    return new_conv, indices


class StructuredPruner:
    """ResNet 스타일 모델의 구조적 프루닝."""

    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def compute_layer_importance(self):
        """각 레이어 필터의 중요도 계산."""
        importance_dict = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                importance = get_filter_importance(module)
                importance_dict[name] = importance
        return importance_dict

    def global_threshold_prune(self, sparsity=0.5):
        """
        전체 필터에 대해 전역 임계값으로 프루닝.
        가장 중요도 낮은 sparsity 비율의 필터 제거.
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                all_importances.append(imp)

        # 모든 필터 중요도를 하나로 합치기
        all_imp = torch.cat(all_importances)
        threshold = all_imp.kthvalue(int(len(all_imp) * sparsity)).values.item()

        # 임계값보다 낮은 필터 제거
        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp >= threshold

        return masks

4.2 헤드 프루닝 (Attention Head Pruning)

트랜스포머에서 중요하지 않은 Attention Head를 제거합니다.

class AttentionHeadPruner:
    """
    Transformer의 Multi-Head Attention 헤드 프루닝.
    Michel et al. (2019): Are Sixteen Heads Really Better than One?
    """

    def compute_head_importance(self, model, dataloader, device='cuda'):
        """각 레이어, 각 헤드의 중요도 계산."""
        head_importance = {}

        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            # Attention 가중치 저장용 hook
            attention_weights = {}

            hooks = []
            for name, module in model.named_modules():
                if 'attention' in name.lower() and hasattr(module, 'num_heads'):
                    def hook_fn(m, inp, out, layer_name=name):
                        if hasattr(out, 'attentions') and out.attentions is not None:
                            attention_weights[layer_name] = out.attentions
                    hooks.append(module.register_forward_hook(hook_fn))

            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            # Head importance: attention 가중치의 분산 합
            if hasattr(outputs, 'attentions') and outputs.attentions:
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp

            for hook in hooks:
                hook.remove()

        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        지정된 head를 제거.
        heads_to_prune: {layer_idx: [head_indices]}
        """
        model.prune_heads(heads_to_prune)  # HuggingFace 모델 내장 기능
        return model


# HuggingFace transformers 활용
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# layer 0의 head 0, 3, 5 제거
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")

4.3 레이어 프루닝

성능에 기여가 적은 트랜스포머 레이어 전체를 제거합니다.

class LayerPruner:
    """트랜스포머 레이어 프루닝."""

    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        각 레이어에 대한 gradient 기반 중요도 계산.
        중요도 = |gradient * weight|의 평균
        """
        model.train()
        layer_importance = {}

        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            for i, layer in enumerate(model.encoder.layer):
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()

                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp

            model.zero_grad()

        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """지정된 레이어를 제거하고 남은 레이어로 재구성."""
        import copy
        new_model = copy.deepcopy(model)

        remaining = [
            layer for i, layer in enumerate(new_model.encoder.layer)
            if i not in layers_to_drop
        ]
        new_model.encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)

        return new_model

4.4 PyTorch torch.nn.utils.prune

import torch.nn.utils.prune as prune

model = models.resnet18(weights='DEFAULT')

# 특정 레이어에 비구조적 L1 프루닝 적용
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# weight의 30%를 0으로 만듦 (mask 사용)

# 확인
print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# 구조적 프루닝 (필터 단위)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# 영구 적용 (mask 제거, 실제 weight 0화)
prune.remove(conv, 'weight')

# 여러 레이어에 반복 적용
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # 전체 파라미터의 40% 프루닝
)

# 전체 희소성 확인
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")

5. 비구조적 프루닝 (Unstructured Pruning)

5.1 크기 기반 프루닝

절대값이 작은 가중치를 0으로 만드는 가장 단순한 방법입니다.

class MagnitudePruner:
    """크기 기반 비구조적 프루닝."""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """전체 가중치에 대한 전역 임계값 계산."""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        all_weights = torch.cat(all_weights)
        threshold = all_weights.kthvalue(
            int(len(all_weights) * self.sparsity)
        ).values.item()
        return threshold

    def apply_pruning(self):
        """전역 임계값으로 프루닝 마스크 생성."""
        threshold = self.compute_global_threshold()

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() >= threshold).float()
                self.masks[name] = mask
                param.data *= mask  # 임계값 이하 가중치를 0으로

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"실제 희소성: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """역전파 후 마스크 재적용 (기울기로 0이 복원되지 않도록)."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

5.2 점진적 프루닝 (Gradual Magnitude Pruning)

처음부터 많이 프루닝하면 성능이 급락합니다. 학습 중에 서서히 희소성을 높이는 방법이 효과적입니다.

class GradualMagnitudePruner:
    """
    학습 중 점진적으로 희소성을 높이는 프루닝.
    Zhu & Gupta (2017): To Prune, or Not to Prune
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """현재 스텝에서의 목표 희소성 계산 (cubic schedule)."""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity

        # Cubic decay schedule
        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """학습 스텝에서 프루닝 업데이트."""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return

        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    def _update_masks(self, sparsity):
        """현재 희소성 목표에 맞게 마스크 업데이트."""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            # 현재 활성 가중치만 고려 (이미 pruned된 것 제외)
            alive = param.data[self.masks[name].bool()]

            if len(alive) == 0:
                continue

            n_prune = int(sparsity * param.data.numel())
            if n_prune == 0:
                continue

            # 전체 중 하위 n_prune개 찾기
            threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.data.abs() > threshold).float()
            param.data *= self.masks[name]

6. 가중치 공유 (Weight Sharing)

6.1 크로스 레이어 파라미터 공유 — ALBERT

ALBERT (Lan et al., 2019)는 BERT의 파라미터를 크게 줄이기 위해 모든 트랜스포머 레이어에서 동일한 파라미터를 반복 사용합니다.

class ALBERTEncoder(nn.Module):
    """
    ALBERT 스타일 가중치 공유.
    하나의 트랜스포머 레이어를 N번 반복 실행.
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # 단 하나의 트랜스포머 레이어 정의
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # 동일한 레이어를 num_layers번 반복
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)

    def count_parameters(self):
        # 실제 파라미터 수 (공유되므로 한 번만 카운트)
        return sum(p.numel() for p in self.parameters())


# 파라미터 비교
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = albert_encoder.count_parameters()

print(f"BERT encoder: {bert_params:,}")     # ~85M
print(f"ALBERT encoder: {albert_params:,}") # ~7M
print(f"압축률: {bert_params/albert_params:.1f}x")  # ~12x

6.2 팩토라이제이션 (ALBERT 임베딩)

ALBERT는 임베딩 레이어도 분해합니다: vocab_size x hidden_size → vocab_size x embedding_size + embedding_size x hidden_size.

class FactorizedEmbedding(nn.Module):
    """
    ALBERT의 임베딩 팩토라이제이션.
    vocab_size x H → (vocab_size x E) + (E x H)
    E << H로 파라미터 대폭 감소.
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        # E → H 선형 변환
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)        # (B, seq, E)
        return self.embedding_projection(embed)         # (B, seq, H)

# 파라미터 비교 (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)             # 23.0M params
factorized_embed = FactorizedEmbedding(30000, 128, 768)  # 3.97M params

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} → Factorized: {f_params:,}")
print(f"절감: {(s_params - f_params)/s_params:.1%}")

7. 신경망 구조 탐색 (NAS)

7.1 Manual Design vs AutoML

전통적으로 CNN 구조는 연구자가 수작업으로 설계했습니다 (VGG, ResNet). NAS는 이 과정을 자동화합니다.

NAS의 세 가지 핵심 요소:

  1. 검색 공간 (Search Space): 어떤 연산/구조를 탐색할지
  2. 검색 전략 (Search Strategy): 강화학습, 진화 알고리즘, 기울기 기반 등
  3. 성능 추정 (Performance Estimation): 후보 구조의 성능을 빠르게 평가

Liu et al. (2019). 구조 탐색을 연속적인 최적화 문제로 변환합니다.

import torch
import torch.nn as nn
import torch.nn.functional as F


# 후보 연산들
PRIMITIVES = [
    'none',           # 연결 없음
    'skip_connect',   # 항등 함수
    'sep_conv_3x3',   # 3x3 Separable Conv
    'sep_conv_5x5',   # 5x5 Separable Conv
    'dil_conv_3x3',   # 3x3 Dilated Conv
    'dil_conv_5x5',   # 5x5 Dilated Conv
    'avg_pool_3x3',   # 3x3 Average Pool
    'max_pool_3x3',   # 3x3 Max Pool
]


class MixedOperation(nn.Module):
    """
    DARTS: 여러 연산의 가중 합으로 엣지 표현.
    alpha (아키텍처 파라미터)가 각 연산의 가중치를 결정.
    """
    def __init__(self, C, stride, primitives):
        super().__init__()
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in primitives
        ])
        # 아키텍처 파라미터 (학습 가능)
        self.arch_params = nn.Parameter(
            torch.randn(len(primitives)) * 1e-3
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            return nn.Sequential()  # Zero operation
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'sep_conv_3x3':
            return nn.Sequential(
                nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            return nn.Identity()

    def forward(self, x):
        # Softmax로 가중치 계산 후 각 연산의 가중 합
        weights = F.softmax(self.arch_params, dim=0)
        results = []
        for w, op in zip(weights, self.ops):
            try:
                result = op(x)
                if result.shape == x.shape or len(results) == 0:
                    results.append(w * result)
                else:
                    results.append(w * result)
            except Exception:
                pass

        if results:
            # 모든 연산 결과의 가중 합
            output = results[0]
            for r in results[1:]:
                if r.shape == output.shape:
                    output = output + r
            return output
        return x


class DARTSCell(nn.Module):
    """DARTS 셀: 여러 Mixed Operation으로 구성된 DAG."""
    def __init__(self, num_nodes, C, stride=1):
        super().__init__()
        self.num_nodes = num_nodes
        self.ops = nn.ModuleList()

        # 각 노드 쌍에 대해 Mixed Operation 생성
        for i in range(num_nodes):
            for j in range(i):
                self.ops.append(
                    MixedOperation(C, stride, PRIMITIVES)
                )

    def forward(self, *inputs):
        states = list(inputs)

        op_idx = 0
        for i in range(self.num_nodes):
            # 모든 이전 상태로부터 합산
            node_input = sum(
                self.ops[op_idx + j](states[j])
                for j in range(len(states))
            )
            op_idx += len(states)
            states.append(node_input)

        return states[-1]

    def get_arch_params(self):
        """아키텍처 파라미터 반환."""
        return [op.arch_params for op in self.ops]


def darts_discrete(cell):
    """
    연속적인 아키텍처를 이산 구조로 변환.
    각 엣지에서 가장 높은 가중치의 연산 선택.
    """
    for op in cell.ops:
        weights = F.softmax(op.arch_params, dim=0)
        best_op_idx = weights.argmax().item()
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")

7.3 EfficientNet의 NAS 프로세스

EfficientNet-B0는 MnasNet과 유사한 신경망 구조 탐색으로 찾은 기저 구조입니다.

# EfficientNet의 MBConv 블록 (NAS로 찾은 핵심 구조)
class MBConvBlock(nn.Module):
    """
    Mobile Inverted Bottleneck Convolution.
    EfficientNet의 기본 빌딩 블록.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # Expansion phase (1x1 conv)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])

        # Depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])

        # Squeeze-and-Excitation (NAS가 발견한 중요 모듈)
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Output phase
        layers.extend([
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv[:-2](x) if len(self.conv) > 2 else self.conv(x)
        # SE 모듈 적용
        out_se = self.se(out)
        out = out * out_se
        out = self.conv[-2:](out)

        if self.use_residual:
            return x + out
        return out

7.4 Once-for-All 네트워크

Cai et al. (2020). 하나의 초대형 네트워크를 학습한 후, 다양한 리소스 제약에 맞게 서브네트워크를 추출합니다.

class OFALayer(nn.Module):
    """
    Once-for-All: 다양한 커널 크기를 지원하는 탄력적 레이어.
    학습 시: 랜덤으로 커널 크기 선택 → 서브네트워크 학습
    배포 시: 특정 커널 크기만 사용
    """
    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # 가장 큰 커널 크기의 conv 하나만 학습
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """현재 활성 커널 크기 설정."""
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # 중앙에서 active_kernel 크기만큼 잘라내기
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]

        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))


# Progressive shrinking 학습 (OFA의 핵심)
def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):
    """
    1단계: 최대 커널(7) 학습
    2단계: 7, 5 커널 랜덤 샘플링 학습
    3단계: 7, 5, 3 커널 랜덤 샘플링 학습
    """
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {active_kernels}")

        for epoch in range(5):
            for images, labels in dataloader:
                # 랜덤으로 커널 크기 선택
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

                # 모든 레이어의 활성 커널 설정
                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

8. 통합 모델 압축 파이프라인

실제 배포 시에는 여러 기법을 조합합니다:

class ModelCompressionPipeline:
    """
    지식 증류 → 프루닝 → 양자화의 통합 파이프라인.
    """

    def __init__(self, teacher, student_arch, num_classes):
        self.teacher = teacher
        self.num_classes = num_classes

        # Student 모델 초기화
        self.student = student_arch

    def step1_distillation(self, train_loader, val_loader,
                           epochs=30, device='cuda'):
        """1단계: 지식 증류로 기본 학습."""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """2단계: 점진적 프루닝 + 파인튜닝."""
        print(f"=== Step 2: Pruning (target sparsity: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )

        optimizer = torch.optim.Adam(
            self.student.parameters(), lr=1e-4
        )
        self.student.to(device)
        step = 0

        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pruner.step(step)
                pruner.apply_masks()  # 프루닝 마스크 유지
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """3단계: Post-Training Quantization."""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)

        # Static quantization 준비
        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)

        # 보정 데이터로 scale/zero_point 결정
        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))

        # 양자화 적용
        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """Teacher, Student, Pruned, Quantized 모델 성능 비교."""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = sum(p.numel() for p in self.student.parameters())

        print(f"\n{'='*50}")
        print(f"Teacher  | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student  | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"압축률: {t_params/s_params:.1f}x | 성능 유지율: {student_acc/teacher_acc:.1%}")

결론

지식 증류와 모델 압축은 AI 모델을 현실 세계에서 사용 가능하게 만드는 핵심 기술입니다.

주요 요약:

  1. 지식 증류: Teacher의 소프트 타겟(확률 분포)을 통해 클래스 간 관계 정보 전달
  2. Feature 증류: 중간 레이어의 표현을 Student에 전달
  3. Relation 증류: 샘플 간 관계 구조 보존
  4. LLM 증류: DistilBERT, TinyLLM 등 실용적 대형 모델 압축
  5. 구조적 프루닝: 필터/헤드/레이어 단위 제거로 실제 속도 향상
  6. 비구조적 프루닝: 점진적 희소화로 정확도 손실 최소화
  7. 가중치 공유: ALBERT처럼 동일 파라미터를 반복 사용
  8. NAS: DARTS, OFA 등 자동화된 구조 탐색

실제 배포 시에는 증류 → 프루닝 → 양자화를 순차적으로 적용하는 파이프라인이 가장 효과적입니다.


참고 문헌