- Authors

- Name
- Youngju Kim
- @fjvbn20031
소개
딥러닝 모델이 강력해질수록 크기도 커집니다. GPT-4, Llama 3 405B 같은 대형 모델은 수백 GB의 메모리를 필요로 하며, 모바일 기기나 엣지 디바이스에서는 실행이 불가능합니다. **지식 증류(Knowledge Distillation)**와 모델 압축 기법은 큰 모델의 성능을 최대한 유지하면서 모델을 작고 빠르게 만드는 핵심 기술입니다.
이 가이드에서 다루는 내용:
- 지식 증류의 이론적 배경과 완전한 PyTorch 구현
- 다양한 증류 방법 (Response-based, Feature-based, Relation-based)
- LLM 증류 사례 (DistilBERT, TinyLLM)
- 구조적/비구조적 프루닝
- 가중치 공유와 NAS
1. 지식 증류 기초
1.1 Teacher-Student 프레임워크
지식 증류는 2015년 Hinton, Vinyals, Dean이 제안한 기법으로, 큰 Teacher 모델의 "지식"을 작은 Student 모델로 전달하는 것이 핵심입니다.
Teacher Model (크고 정확)
│
│ 소프트 타겟 (확률 분포) 전달
▼
Student Model (작고 빠름)
단순히 정답 레이블(하드 타겟)로 학습하는 것과 달리, Teacher의 **소프트 타겟(softmax 출력 확률 분포)**을 사용합니다.
예를 들어, 고양이 이미지 분류:
- 하드 타겟: [0, 0, 1, 0, 0] (정답만 1)
- Teacher 소프트 타겟: [0.01, 0.05, 0.85, 0.07, 0.02]
소프트 타겟에는 "이 이미지는 고양이지만 호랑이와 약간 비슷하다"는 정보가 담겨 있습니다. 이런 클래스 간 유사성 정보가 Student 학습에 큰 도움이 됩니다.
1.2 온도(Temperature) 파라미터
소프트맥스 출력이 너무 confident하면 [0.001, 0.002, 0.997, ...]처럼 거의 하드 타겟과 같아져서 정보가 적습니다. 온도 파라미터 T로 분포를 부드럽게 만듭니다:
softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
- T가 1일 때: 일반 소프트맥스
- T가 1보다 클 때: 분포가 더 균일해짐 (더 많은 정보 전달)
- T가 1보다 작을 때: 분포가 더 날카로워짐
import torch
import torch.nn.functional as F
def temperature_softmax(logits, temperature=1.0):
"""온도로 조정된 소프트맥스."""
return F.softmax(logits / temperature, dim=-1)
# 예시
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, T=1).numpy().round(3))
# [0.596, 0.219, 0.090, 0.096]
print("T=4:", temperature_softmax(logits, T=4).numpy().round(3))
# [0.345, 0.262, 0.195, 0.216] — 더 균일
1.3 Hinton의 KD 손실 함수
KD 손실은 두 항의 가중 합입니다:
KL Divergence 항 (소프트 타겟 매칭):
L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))
T^2 스케일링: 그래디언트 크기를 T에 무관하게 유지하기 위함.
Cross-Entropy 항 (하드 타겟 학습):
L_CE = CrossEntropy(student_logits, true_labels)
총 손실:
L = alpha * L_CE + (1 - alpha) * L_KD
1.4 완전한 PyTorch 구현
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
class KnowledgeDistillationLoss(nn.Module):
"""
Hinton et al. (2015)의 지식 증류 손실.
L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
"""
def __init__(self, temperature=4.0, alpha=0.5):
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits, teacher_logits, labels):
# 하드 타겟 손실
loss_ce = self.ce_loss(student_logits, labels)
# 소프트 타겟 손실 (KL Divergence)
student_soft = F.log_softmax(student_logits / self.T, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
def train_with_distillation(
teacher, student, train_loader,
num_epochs=10, temperature=4.0, alpha=0.5,
device='cuda'
):
"""Teacher-Student 증류 학습 루프."""
teacher = teacher.to(device).eval() # Teacher는 frozen
student = student.to(device)
# Teacher 파라미터 동결
for param in teacher.parameters():
param.requires_grad = False
criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
for epoch in range(num_epochs):
student.train()
total_loss = 0.0
correct = 0
total = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# Teacher 추론 (그래디언트 불필요)
with torch.no_grad():
teacher_logits = teacher(images)
# Student 추론
student_logits = student(images)
# KD 손실 계산
loss = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
correct += (student_logits.argmax(1) == labels).sum().item()
total += images.size(0)
scheduler.step()
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {total_loss/total:.4f} | "
f"Acc: {correct/total:.4f}")
return student
# 실제 사용 예시
# Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)
student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)
# 파라미터 수 비교
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params") # ~25.6M
print(f"Student: {s_params:,} params") # ~11.2M
print(f"압축률: {t_params/s_params:.1f}x")
2. 다양한 증류 방법
2.1 Response-based 증류 (로짓 매칭)
가장 기본적인 방법으로, 앞서 설명한 Hinton KD가 대표적입니다. Student가 Teacher의 **최종 출력(로짓)**을 모방합니다.
class ResponseBasedDistillation(nn.Module):
"""최종 출력만 사용하는 Response-based 증류."""
def __init__(self, temperature=4.0):
super().__init__()
self.T = temperature
def forward(self, student_logits, teacher_logits):
student_soft = F.log_softmax(student_logits / self.T, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)
2.2 Feature-based 증류 (중간 레이어)
FitNets (Romero et al., 2015)에서 제안. Teacher의 중간 레이어 feature map을 Student가 따라하도록 학습합니다.
Teacher와 Student의 채널 수가 다를 수 있으므로 Regressor 네트워크로 차원을 맞춥니다.
class FeatureDistillationLoss(nn.Module):
"""중간 레이어 feature를 매칭하는 증류."""
def __init__(self, teacher_channels, student_channels):
super().__init__()
# Student feature를 Teacher feature 공간으로 projection
self.regressor = nn.Sequential(
nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(teacher_channels),
nn.ReLU(inplace=True),
)
def forward(self, student_feat, teacher_feat):
# Student feature를 Teacher 차원으로 변환
projected = self.regressor(student_feat)
# MSE 손실로 feature 맞추기
return F.mse_loss(projected, teacher_feat.detach())
class HookBasedDistillation:
"""
forward hook으로 중간 레이어 feature 추출.
"""
def __init__(self, model, layer_names):
self.features = {}
self.hooks = []
for name, layer in model.named_modules():
if name in layer_names:
hook = layer.register_forward_hook(
self._make_hook(name)
)
self.hooks.append(hook)
def _make_hook(self, name):
def hook(module, input, output):
self.features[name] = output
return hook
def remove(self):
for hook in self.hooks:
hook.remove()
# 사용 예시
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)
# Teacher: layer3 출력 (1024채널), Student: layer3 출력 (256채널)
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])
feat_distill = FeatureDistillationLoss(
teacher_channels=1024,
student_channels=256
)
# 학습 시
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)
teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']
loss_feat = feat_distill(student_feat, teacher_feat)
2.3 Relation-based 증류 (상관관계 학습)
RKD (Park et al., 2019). 샘플들 사이의 관계를 Student가 모방하도록 합니다. 절대적인 값이 아니라 상대적인 구조를 학습합니다.
class RelationalKnowledgeDistillation(nn.Module):
"""
Relational KD: 샘플 쌍의 거리 관계를 보존.
Teacher의 embedding 공간 구조를 Student가 학습.
"""
def __init__(self, distance_weight=25.0, angle_weight=50.0):
super().__init__()
self.dist_w = distance_weight
self.angle_w = angle_weight
def pdist(self, e, squared=False, eps=1e-12):
"""Pairwise distance computation."""
e_sq = (e ** 2).sum(dim=1)
prod = e @ e.t()
res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
if not squared:
res = res.sqrt()
return res
def distance_loss(self, teacher_emb, student_emb):
"""거리 관계 보존 손실."""
t_d = self.pdist(teacher_emb)
# 정규화
t_d = t_d / (t_d.mean() + 1e-12)
s_d = self.pdist(student_emb)
s_d = s_d / (s_d.mean() + 1e-12)
return F.smooth_l1_loss(s_d, t_d.detach())
def angle_loss(self, teacher_emb, student_emb):
"""각도 관계 보존 손실."""
# e_i - e_j 벡터 간 각도 관계
td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1) # (N, N, D)
sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)
# cosine similarity
td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)
t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)
return F.smooth_l1_loss(s_angle, t_angle.detach())
def forward(self, teacher_emb, student_emb):
loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
return loss
2.4 Attention Transfer
Zagoruyko & Komodakis (2017). Attention map (채널별 activation 합)을 Teacher에서 Student로 전달합니다.
class AttentionTransfer(nn.Module):
"""Attention Transfer: activation map의 공간적 패턴 전달."""
def attention_map(self, feat):
"""
Feature map에서 attention map 계산.
각 공간 위치에서 채널 방향으로 제곱합의 제곱근.
"""
return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))
def forward(self, student_feats, teacher_feats):
"""
student_feats, teacher_feats: list of feature maps
각 쌍에 대해 AT 손실 계산.
"""
loss = 0.0
for s_feat, t_feat in zip(student_feats, teacher_feats):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat)
loss += (s_attn - t_attn.detach()).pow(2).mean()
return loss
3. LLM 증류
3.1 DistilBERT
DistilBERT (Sanh et al., 2019)는 BERT-base (110M 파라미터)를 66M으로 압축한 모델입니다.
주요 기법:
- 레이어 수 절반: 12개 → 6개
- Token-type embedding 제거
- Pooler 제거
- Triple 손실: MLM + Distillation + Cosine embedding
from transformers import (
DistilBertModel, BertModel,
DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistilBERTLoss(nn.Module):
"""
DistilBERT 3중 손실:
1. MLM CE 손실 (언어 모델링)
2. Soft-target KD 손실 (Teacher logit 매칭)
3. Cosine embedding 손실 (hidden state 유사도)
"""
def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
super().__init__()
self.T = temperature
self.alpha = alpha # MLM 가중치
self.beta = beta # Cosine 손실 가중치
def forward(self, student_logits, teacher_logits,
student_hidden, teacher_hidden, mlm_labels):
# 1. MLM 손실
loss_mlm = F.cross_entropy(
student_logits.view(-1, student_logits.size(-1)),
mlm_labels.view(-1),
ignore_index=-100
)
# 2. Soft KD 손실
s_soft = F.log_softmax(student_logits / self.T, dim=-1)
t_soft = F.softmax(teacher_logits / self.T, dim=-1)
loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)
# 3. Cosine embedding 손실 (hidden state 유사도)
loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()
return (self.alpha * loss_mlm
+ (1 - self.alpha) * loss_kd
+ self.beta * loss_cos)
# 추론 예시
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])
DistilBERT 성능:
- BERT-base 대비 40% 크기 감소
- 60% 빠른 추론 속도
- GLUE 벤치마크 97% 성능 유지
3.2 TinyLLaMA 스타일 증류
LLM 증류에서는 Teacher의 다음 토큰 예측 분포를 Student가 학습합니다.
class LLMDistillationTrainer:
"""
LLM 증류 학습기.
Teacher의 로짓 분포를 Student에 전달.
"""
def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
self.teacher = teacher
self.student = student
self.T = temperature
self.alpha = alpha
def compute_loss(self, input_ids, attention_mask, labels):
# Teacher 추론 (no_grad)
with torch.no_grad():
teacher_out = self.teacher(
input_ids=input_ids,
attention_mask=attention_mask,
)
teacher_logits = teacher_out.logits # (B, seq_len, vocab_size)
# Student 추론
student_out = self.student(
input_ids=input_ids,
attention_mask=attention_mask,
)
student_logits = student_out.logits
# 1. CE 손실 (하드 타겟)
shift_logits = student_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_ce = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
# 2. KD 손실 (소프트 타겟)
# 유효한 토큰에 대해서만 계산
mask = (shift_labels != -100).float()
s_log_soft = F.log_softmax(
shift_logits / self.T, dim=-1
)
t_soft = F.softmax(
teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
)
# 토큰별 KL divergence
kl_per_token = F.kl_div(
s_log_soft, t_soft,
reduction='none'
).sum(-1) # (B, seq_len)
loss_kd = (kl_per_token * mask).sum() / mask.sum()
loss_kd = loss_kd * (self.T ** 2)
return self.alpha * loss_ce + (1 - self.alpha) * loss_kd
3.3 Distil Whisper
OpenAI Whisper 음성 인식 모델의 증류판입니다.
from transformers import (
WhisperProcessor, WhisperForConditionalGeneration
)
# Distil-Whisper: Whisper-large-v2의 증류
# large-v2: 1550M params → distil-large-v2: 756M params (2x faster, same WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")
# 6개 디코더 레이어만 유지 (원본 32개 → 2개)
# 인코더는 동일하게 유지
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")
4. 구조적 프루닝 (Structured Pruning)
프루닝은 모델에서 중요하지 않은 파라미터나 구조를 제거하는 기법입니다.
구조적 프루닝: 필터, 헤드, 레이어 전체 단위로 제거 → 실제 속도 향상 비구조적 프루닝: 개별 가중치를 0으로 만들기 → 희소 행렬 (특수 하드웨어 필요)
4.1 필터 프루닝
CNN에서 L1/L2 norm이 작은 필터를 제거합니다.
import torch
import torch.nn as nn
import numpy as np
def get_filter_importance(conv_layer, norm='l1'):
"""각 필터의 중요도(norm) 계산."""
weight = conv_layer.weight.data # (out_ch, in_ch, kH, kW)
if norm == 'l1':
return weight.abs().sum(dim=(1, 2, 3))
elif norm == 'l2':
return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
elif norm == 'gm':
# Geometric Median 기반 (더 robust)
return weight.view(weight.size(0), -1).norm(p=2, dim=1)
def prune_conv_layer(conv, keep_ratio=0.5):
"""
필터 프루닝: 중요도 낮은 필터 제거.
Returns: 새로운 Conv2d 레이어 (필터 수 감소)
"""
importance = get_filter_importance(conv)
n_keep = int(conv.out_channels * keep_ratio)
# 중요도 상위 n_keep개 필터 선택
_, indices = importance.topk(n_keep)
indices, _ = indices.sort()
# 새 conv 레이어 생성
new_conv = nn.Conv2d(
conv.in_channels,
n_keep,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=conv.bias is not None
)
new_conv.weight.data = conv.weight.data[indices]
if conv.bias is not None:
new_conv.bias.data = conv.bias.data[indices]
return new_conv, indices
class StructuredPruner:
"""ResNet 스타일 모델의 구조적 프루닝."""
def __init__(self, model, prune_ratio=0.5):
self.model = model
self.prune_ratio = prune_ratio
def compute_layer_importance(self):
"""각 레이어 필터의 중요도 계산."""
importance_dict = {}
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
importance = get_filter_importance(module)
importance_dict[name] = importance
return importance_dict
def global_threshold_prune(self, sparsity=0.5):
"""
전체 필터에 대해 전역 임계값으로 프루닝.
가장 중요도 낮은 sparsity 비율의 필터 제거.
"""
all_importances = []
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
imp = get_filter_importance(module)
all_importances.append(imp)
# 모든 필터 중요도를 하나로 합치기
all_imp = torch.cat(all_importances)
threshold = all_imp.kthvalue(int(len(all_imp) * sparsity)).values.item()
# 임계값보다 낮은 필터 제거
masks = {}
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
imp = get_filter_importance(module)
masks[name] = imp >= threshold
return masks
4.2 헤드 프루닝 (Attention Head Pruning)
트랜스포머에서 중요하지 않은 Attention Head를 제거합니다.
class AttentionHeadPruner:
"""
Transformer의 Multi-Head Attention 헤드 프루닝.
Michel et al. (2019): Are Sixteen Heads Really Better than One?
"""
def compute_head_importance(self, model, dataloader, device='cuda'):
"""각 레이어, 각 헤드의 중요도 계산."""
head_importance = {}
model.eval()
for batch in dataloader:
inputs = {k: v.to(device) for k, v in batch.items()}
# Attention 가중치 저장용 hook
attention_weights = {}
hooks = []
for name, module in model.named_modules():
if 'attention' in name.lower() and hasattr(module, 'num_heads'):
def hook_fn(m, inp, out, layer_name=name):
if hasattr(out, 'attentions') and out.attentions is not None:
attention_weights[layer_name] = out.attentions
hooks.append(module.register_forward_hook(hook_fn))
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
# Head importance: attention 가중치의 분산 합
if hasattr(outputs, 'attentions') and outputs.attentions:
for layer_idx, attn in enumerate(outputs.attentions):
# attn: (batch, heads, seq, seq)
head_imp = attn.mean(0).var(-1).mean(-1) # (heads,)
key = f"layer_{layer_idx}"
if key not in head_importance:
head_importance[key] = head_imp
else:
head_importance[key] += head_imp
for hook in hooks:
hook.remove()
return head_importance
def prune_heads(self, model, heads_to_prune):
"""
지정된 head를 제거.
heads_to_prune: {layer_idx: [head_indices]}
"""
model.prune_heads(heads_to_prune) # HuggingFace 모델 내장 기능
return model
# HuggingFace transformers 활용
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# layer 0의 head 0, 3, 5 제거
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)
total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")
4.3 레이어 프루닝
성능에 기여가 적은 트랜스포머 레이어 전체를 제거합니다.
class LayerPruner:
"""트랜스포머 레이어 프루닝."""
def compute_layer_importance_by_gradient(
self, model, dataloader, device='cuda'
):
"""
각 레이어에 대한 gradient 기반 중요도 계산.
중요도 = |gradient * weight|의 평균
"""
model.train()
layer_importance = {}
for batch in dataloader:
inputs = {k: v.to(device) for k, v in batch.items()}
outputs = model(**inputs)
loss = outputs.loss
loss.backward()
for i, layer in enumerate(model.encoder.layer):
grad_sum = 0.0
param_count = 0
for param in layer.parameters():
if param.grad is not None:
grad_sum += (param.grad * param.data).abs().sum().item()
param_count += param.numel()
key = f"layer_{i}"
imp = grad_sum / max(param_count, 1)
if key not in layer_importance:
layer_importance[key] = 0.0
layer_importance[key] += imp
model.zero_grad()
return layer_importance
def drop_layers(self, model, layers_to_drop):
"""지정된 레이어를 제거하고 남은 레이어로 재구성."""
import copy
new_model = copy.deepcopy(model)
remaining = [
layer for i, layer in enumerate(new_model.encoder.layer)
if i not in layers_to_drop
]
new_model.encoder.layer = nn.ModuleList(remaining)
new_model.config.num_hidden_layers = len(remaining)
return new_model
4.4 PyTorch torch.nn.utils.prune
import torch.nn.utils.prune as prune
model = models.resnet18(weights='DEFAULT')
# 특정 레이어에 비구조적 L1 프루닝 적용
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# weight의 30%를 0으로 만듦 (mask 사용)
# 확인
print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")
# 구조적 프루닝 (필터 단위)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)
# 영구 적용 (mask 제거, 실제 weight 0화)
prune.remove(conv, 'weight')
# 여러 레이어에 반복 적용
parameters_to_prune = []
for module in model.modules():
if isinstance(module, nn.Conv2d):
parameters_to_prune.append((module, 'weight'))
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.4, # 전체 파라미터의 40% 프루닝
)
# 전체 희소성 확인
total_zero = sum(
(m.weight == 0).sum().item()
for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
m.weight.numel()
for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")
5. 비구조적 프루닝 (Unstructured Pruning)
5.1 크기 기반 프루닝
절대값이 작은 가중치를 0으로 만드는 가장 단순한 방법입니다.
class MagnitudePruner:
"""크기 기반 비구조적 프루닝."""
def __init__(self, model, sparsity=0.5):
self.model = model
self.sparsity = sparsity
self.masks = {}
def compute_global_threshold(self):
"""전체 가중치에 대한 전역 임계값 계산."""
all_weights = []
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
all_weights.append(param.data.abs().view(-1))
all_weights = torch.cat(all_weights)
threshold = all_weights.kthvalue(
int(len(all_weights) * self.sparsity)
).values.item()
return threshold
def apply_pruning(self):
"""전역 임계값으로 프루닝 마스크 생성."""
threshold = self.compute_global_threshold()
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
mask = (param.data.abs() >= threshold).float()
self.masks[name] = mask
param.data *= mask # 임계값 이하 가중치를 0으로
total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
total_params = sum(m.numel() for m in self.masks.values())
print(f"실제 희소성: {total_zeros/total_params:.2%}")
def apply_masks(self):
"""역전파 후 마스크 재적용 (기울기로 0이 복원되지 않도록)."""
for name, param in self.model.named_parameters():
if name in self.masks:
param.data *= self.masks[name]
5.2 점진적 프루닝 (Gradual Magnitude Pruning)
처음부터 많이 프루닝하면 성능이 급락합니다. 학습 중에 서서히 희소성을 높이는 방법이 효과적입니다.
class GradualMagnitudePruner:
"""
학습 중 점진적으로 희소성을 높이는 프루닝.
Zhu & Gupta (2017): To Prune, or Not to Prune
"""
def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
begin_step=0, end_step=1000, frequency=100):
self.model = model
self.initial_sparsity = initial_sparsity
self.final_sparsity = final_sparsity
self.begin_step = begin_step
self.end_step = end_step
self.frequency = frequency
self.masks = {}
self._init_masks()
def _init_masks(self):
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
self.masks[name] = torch.ones_like(param.data)
def compute_sparsity(self, step):
"""현재 스텝에서의 목표 희소성 계산 (cubic schedule)."""
if step < self.begin_step:
return self.initial_sparsity
if step > self.end_step:
return self.final_sparsity
# Cubic decay schedule
pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
sparsity = (
self.final_sparsity
+ (self.initial_sparsity - self.final_sparsity)
* (1 - pct_done) ** 3
)
return sparsity
def step(self, global_step):
"""학습 스텝에서 프루닝 업데이트."""
if (global_step % self.frequency != 0 or
global_step < self.begin_step or
global_step > self.end_step):
return
target_sparsity = self.compute_sparsity(global_step)
self._update_masks(target_sparsity)
def _update_masks(self, sparsity):
"""현재 희소성 목표에 맞게 마스크 업데이트."""
for name, param in self.model.named_parameters():
if name not in self.masks:
continue
# 현재 활성 가중치만 고려 (이미 pruned된 것 제외)
alive = param.data[self.masks[name].bool()]
if len(alive) == 0:
continue
n_prune = int(sparsity * param.data.numel())
if n_prune == 0:
continue
# 전체 중 하위 n_prune개 찾기
threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
self.masks[name] = (param.data.abs() > threshold).float()
param.data *= self.masks[name]
6. 가중치 공유 (Weight Sharing)
6.1 크로스 레이어 파라미터 공유 — ALBERT
ALBERT (Lan et al., 2019)는 BERT의 파라미터를 크게 줄이기 위해 모든 트랜스포머 레이어에서 동일한 파라미터를 반복 사용합니다.
class ALBERTEncoder(nn.Module):
"""
ALBERT 스타일 가중치 공유.
하나의 트랜스포머 레이어를 N번 반복 실행.
"""
def __init__(self, hidden_size=768, num_heads=12,
intermediate_size=3072, num_layers=12):
super().__init__()
# 단 하나의 트랜스포머 레이어 정의
self.shared_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=intermediate_size,
batch_first=True,
norm_first=True,
)
self.num_layers = num_layers
self.norm = nn.LayerNorm(hidden_size)
def forward(self, x, src_key_padding_mask=None):
# 동일한 레이어를 num_layers번 반복
for _ in range(self.num_layers):
x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
return self.norm(x)
def count_parameters(self):
# 실제 파라미터 수 (공유되므로 한 번만 카운트)
return sum(p.numel() for p in self.parameters())
# 파라미터 비교
bert_encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)
bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = albert_encoder.count_parameters()
print(f"BERT encoder: {bert_params:,}") # ~85M
print(f"ALBERT encoder: {albert_params:,}") # ~7M
print(f"압축률: {bert_params/albert_params:.1f}x") # ~12x
6.2 팩토라이제이션 (ALBERT 임베딩)
ALBERT는 임베딩 레이어도 분해합니다: vocab_size x hidden_size → vocab_size x embedding_size + embedding_size x hidden_size.
class FactorizedEmbedding(nn.Module):
"""
ALBERT의 임베딩 팩토라이제이션.
vocab_size x H → (vocab_size x E) + (E x H)
E << H로 파라미터 대폭 감소.
"""
def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
# E → H 선형 변환
self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)
def forward(self, input_ids):
embed = self.word_embeddings(input_ids) # (B, seq, E)
return self.embedding_projection(embed) # (B, seq, H)
# 파라미터 비교 (vocab_size=30000)
standard_embed = nn.Embedding(30000, 768) # 23.0M params
factorized_embed = FactorizedEmbedding(30000, 128, 768) # 3.97M params
s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} → Factorized: {f_params:,}")
print(f"절감: {(s_params - f_params)/s_params:.1%}")
7. 신경망 구조 탐색 (NAS)
7.1 Manual Design vs AutoML
전통적으로 CNN 구조는 연구자가 수작업으로 설계했습니다 (VGG, ResNet). NAS는 이 과정을 자동화합니다.
NAS의 세 가지 핵심 요소:
- 검색 공간 (Search Space): 어떤 연산/구조를 탐색할지
- 검색 전략 (Search Strategy): 강화학습, 진화 알고리즘, 기울기 기반 등
- 성능 추정 (Performance Estimation): 후보 구조의 성능을 빠르게 평가
7.2 DARTS (Differentiable Architecture Search)
Liu et al. (2019). 구조 탐색을 연속적인 최적화 문제로 변환합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
# 후보 연산들
PRIMITIVES = [
'none', # 연결 없음
'skip_connect', # 항등 함수
'sep_conv_3x3', # 3x3 Separable Conv
'sep_conv_5x5', # 5x5 Separable Conv
'dil_conv_3x3', # 3x3 Dilated Conv
'dil_conv_5x5', # 5x5 Dilated Conv
'avg_pool_3x3', # 3x3 Average Pool
'max_pool_3x3', # 3x3 Max Pool
]
class MixedOperation(nn.Module):
"""
DARTS: 여러 연산의 가중 합으로 엣지 표현.
alpha (아키텍처 파라미터)가 각 연산의 가중치를 결정.
"""
def __init__(self, C, stride, primitives):
super().__init__()
self.ops = nn.ModuleList([
self._build_op(primitive, C, stride)
for primitive in primitives
])
# 아키텍처 파라미터 (학습 가능)
self.arch_params = nn.Parameter(
torch.randn(len(primitives)) * 1e-3
)
def _build_op(self, primitive, C, stride):
if primitive == 'none':
return nn.Sequential() # Zero operation
elif primitive == 'skip_connect':
return nn.Identity() if stride == 1 else nn.Sequential(
nn.AvgPool2d(stride, stride, padding=0),
nn.Conv2d(C, C, 1, bias=False),
nn.BatchNorm2d(C)
)
elif primitive == 'sep_conv_3x3':
return nn.Sequential(
nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),
nn.Conv2d(C, C, 1, bias=False),
nn.BatchNorm2d(C),
nn.ReLU(inplace=True),
)
elif primitive == 'avg_pool_3x3':
return nn.AvgPool2d(3, stride=stride, padding=1)
elif primitive == 'max_pool_3x3':
return nn.MaxPool2d(3, stride=stride, padding=1)
else:
return nn.Identity()
def forward(self, x):
# Softmax로 가중치 계산 후 각 연산의 가중 합
weights = F.softmax(self.arch_params, dim=0)
results = []
for w, op in zip(weights, self.ops):
try:
result = op(x)
if result.shape == x.shape or len(results) == 0:
results.append(w * result)
else:
results.append(w * result)
except Exception:
pass
if results:
# 모든 연산 결과의 가중 합
output = results[0]
for r in results[1:]:
if r.shape == output.shape:
output = output + r
return output
return x
class DARTSCell(nn.Module):
"""DARTS 셀: 여러 Mixed Operation으로 구성된 DAG."""
def __init__(self, num_nodes, C, stride=1):
super().__init__()
self.num_nodes = num_nodes
self.ops = nn.ModuleList()
# 각 노드 쌍에 대해 Mixed Operation 생성
for i in range(num_nodes):
for j in range(i):
self.ops.append(
MixedOperation(C, stride, PRIMITIVES)
)
def forward(self, *inputs):
states = list(inputs)
op_idx = 0
for i in range(self.num_nodes):
# 모든 이전 상태로부터 합산
node_input = sum(
self.ops[op_idx + j](states[j])
for j in range(len(states))
)
op_idx += len(states)
states.append(node_input)
return states[-1]
def get_arch_params(self):
"""아키텍처 파라미터 반환."""
return [op.arch_params for op in self.ops]
def darts_discrete(cell):
"""
연속적인 아키텍처를 이산 구조로 변환.
각 엣지에서 가장 높은 가중치의 연산 선택.
"""
for op in cell.ops:
weights = F.softmax(op.arch_params, dim=0)
best_op_idx = weights.argmax().item()
print(f" Selected: {PRIMITIVES[best_op_idx]} "
f"(weight: {weights[best_op_idx]:.3f})")
7.3 EfficientNet의 NAS 프로세스
EfficientNet-B0는 MnasNet과 유사한 신경망 구조 탐색으로 찾은 기저 구조입니다.
# EfficientNet의 MBConv 블록 (NAS로 찾은 핵심 구조)
class MBConvBlock(nn.Module):
"""
Mobile Inverted Bottleneck Convolution.
EfficientNet의 기본 빌딩 블록.
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, expand_ratio=6, se_ratio=0.25):
super().__init__()
self.use_residual = (stride == 1 and in_channels == out_channels)
hidden_dim = int(in_channels * expand_ratio)
layers = []
# Expansion phase (1x1 conv)
if expand_ratio != 1:
layers.extend([
nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
])
# Depthwise conv
layers.extend([
nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
stride=stride,
padding=kernel_size // 2,
groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.SiLU(),
])
# Squeeze-and-Excitation (NAS가 발견한 중요 모듈)
se_channels = max(1, int(in_channels * se_ratio))
self.se = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(hidden_dim, se_channels, 1),
nn.SiLU(),
nn.Conv2d(se_channels, hidden_dim, 1),
nn.Sigmoid(),
)
# Output phase
layers.extend([
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
out = self.conv[:-2](x) if len(self.conv) > 2 else self.conv(x)
# SE 모듈 적용
out_se = self.se(out)
out = out * out_se
out = self.conv[-2:](out)
if self.use_residual:
return x + out
return out
7.4 Once-for-All 네트워크
Cai et al. (2020). 하나의 초대형 네트워크를 학습한 후, 다양한 리소스 제약에 맞게 서브네트워크를 추출합니다.
class OFALayer(nn.Module):
"""
Once-for-All: 다양한 커널 크기를 지원하는 탄력적 레이어.
학습 시: 랜덤으로 커널 크기 선택 → 서브네트워크 학습
배포 시: 특정 커널 크기만 사용
"""
def __init__(self, channels, max_kernel=7):
super().__init__()
# 가장 큰 커널 크기의 conv 하나만 학습
self.max_conv = nn.Conv2d(
channels, channels,
kernel_size=max_kernel,
padding=max_kernel // 2,
groups=channels, bias=False
)
self.bn = nn.BatchNorm2d(channels)
self.act = nn.ReLU(inplace=True)
self.active_kernel = max_kernel
def set_active_kernel(self, kernel_size):
"""현재 활성 커널 크기 설정."""
assert kernel_size <= self.max_conv.kernel_size[0]
self.active_kernel = kernel_size
def forward(self, x):
if self.active_kernel == self.max_conv.kernel_size[0]:
weight = self.max_conv.weight
else:
# 중앙에서 active_kernel 크기만큼 잘라내기
center = self.max_conv.kernel_size[0] // 2
half = self.active_kernel // 2
weight = self.max_conv.weight[
:, :,
center - half: center + half + 1,
center - half: center + half + 1
]
padding = self.active_kernel // 2
out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
return self.act(self.bn(out))
# Progressive shrinking 학습 (OFA의 핵심)
def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):
"""
1단계: 최대 커널(7) 학습
2단계: 7, 5 커널 랜덤 샘플링 학습
3단계: 7, 5, 3 커널 랜덤 샘플링 학습
"""
for stage, active_kernels in enumerate(
[kernels[:1], kernels[:2], kernels]
):
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(f"Stage {stage+1}: kernels = {active_kernels}")
for epoch in range(5):
for images, labels in dataloader:
# 랜덤으로 커널 크기 선택
k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]
# 모든 레이어의 활성 커널 설정
for module in model.modules():
if isinstance(module, OFALayer):
module.set_active_kernel(k)
outputs = model(images)
loss = F.cross_entropy(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
8. 통합 모델 압축 파이프라인
실제 배포 시에는 여러 기법을 조합합니다:
class ModelCompressionPipeline:
"""
지식 증류 → 프루닝 → 양자화의 통합 파이프라인.
"""
def __init__(self, teacher, student_arch, num_classes):
self.teacher = teacher
self.num_classes = num_classes
# Student 모델 초기화
self.student = student_arch
def step1_distillation(self, train_loader, val_loader,
epochs=30, device='cuda'):
"""1단계: 지식 증류로 기본 학습."""
print("=== Step 1: Knowledge Distillation ===")
self.student = train_with_distillation(
self.teacher, self.student, train_loader,
num_epochs=epochs, temperature=4.0, alpha=0.5,
device=device
)
def step2_pruning(self, train_loader, sparsity=0.5,
finetune_epochs=10, device='cuda'):
"""2단계: 점진적 프루닝 + 파인튜닝."""
print(f"=== Step 2: Pruning (target sparsity: {sparsity:.0%}) ===")
pruner = GradualMagnitudePruner(
self.student,
initial_sparsity=0.0,
final_sparsity=sparsity,
begin_step=0,
end_step=len(train_loader) * finetune_epochs,
frequency=100
)
optimizer = torch.optim.Adam(
self.student.parameters(), lr=1e-4
)
self.student.to(device)
step = 0
for epoch in range(finetune_epochs):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
outputs = self.student(images)
loss = F.cross_entropy(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
pruner.step(step)
pruner.apply_masks() # 프루닝 마스크 유지
step += 1
def step3_quantization(self, calibration_loader, device='cpu'):
"""3단계: Post-Training Quantization."""
print("=== Step 3: Quantization ===")
self.student.eval().to(device)
# Static quantization 준비
self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
torch.ao.quantization.prepare(self.student, inplace=True)
# 보정 데이터로 scale/zero_point 결정
with torch.no_grad():
for images, _ in calibration_loader:
self.student(images.to(device))
# 양자화 적용
torch.ao.quantization.convert(self.student, inplace=True)
print("Quantization complete (INT8)")
return self.student
def compare_models(self, val_loader, device='cuda'):
"""Teacher, Student, Pruned, Quantized 모델 성능 비교."""
def eval_model(model, loader, dev):
model.eval().to(dev)
correct, total = 0, 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(dev), labels.to(dev)
preds = model(images).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return correct / total
teacher_acc = eval_model(self.teacher, val_loader, device)
student_acc = eval_model(self.student, val_loader, 'cpu')
t_params = sum(p.numel() for p in self.teacher.parameters())
s_params = sum(p.numel() for p in self.student.parameters())
print(f"\n{'='*50}")
print(f"Teacher | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
print(f"Student | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
print(f"압축률: {t_params/s_params:.1f}x | 성능 유지율: {student_acc/teacher_acc:.1%}")
결론
지식 증류와 모델 압축은 AI 모델을 현실 세계에서 사용 가능하게 만드는 핵심 기술입니다.
주요 요약:
- 지식 증류: Teacher의 소프트 타겟(확률 분포)을 통해 클래스 간 관계 정보 전달
- Feature 증류: 중간 레이어의 표현을 Student에 전달
- Relation 증류: 샘플 간 관계 구조 보존
- LLM 증류: DistilBERT, TinyLLM 등 실용적 대형 모델 압축
- 구조적 프루닝: 필터/헤드/레이어 단위 제거로 실제 속도 향상
- 비구조적 프루닝: 점진적 희소화로 정확도 손실 최소화
- 가중치 공유: ALBERT처럼 동일 파라미터를 반복 사용
- NAS: DARTS, OFA 등 자동화된 구조 탐색
실제 배포 시에는 증류 → 프루닝 → 양자화를 순차적으로 적용하는 파이프라인이 가장 효과적입니다.
참고 문헌
- Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
- Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
- Zagoruyko and Komodakis (2017). Paying More Attention to Attention.
- Park et al. (2019). Relational Knowledge Distillation.
- Sanh et al. (2019). DistilBERT. https://arxiv.org/abs/1910.01108
- Lan et al. (2019). ALBERT: A Lite BERT.
- Liu et al. (2019). DARTS: Differentiable Architecture Search.
- Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
- Tan and Le (2019). EfficientNet. https://arxiv.org/abs/1905.11946
- PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html