Skip to content

필사 모드: 지식 증류(Knowledge Distillation) 완전 가이드: 모델 경량화와 압축 기법

한국어
0%
정확도 0%
💡 왼쪽 원문을 읽으면서 오른쪽에 따라 써보세요. Tab 키로 힌트를 받을 수 있습니다.
원문 렌더가 준비되기 전까지 텍스트 가이드로 표시합니다.

소개

딥러닝 모델이 강력해질수록 크기도 커집니다. GPT-4, Llama 3 405B 같은 대형 모델은 수백 GB의 메모리를 필요로 하며, 모바일 기기나 엣지 디바이스에서는 실행이 불가능합니다. **지식 증류(Knowledge Distillation)**와 **모델 압축** 기법은 큰 모델의 성능을 최대한 유지하면서 모델을 작고 빠르게 만드는 핵심 기술입니다.

이 가이드에서 다루는 내용:

- 지식 증류의 이론적 배경과 완전한 PyTorch 구현

- 다양한 증류 방법 (Response-based, Feature-based, Relation-based)

- LLM 증류 사례 (DistilBERT, TinyLLM)

- 구조적/비구조적 프루닝

- 가중치 공유와 NAS

1. 지식 증류 기초

1.1 Teacher-Student 프레임워크

지식 증류는 2015년 Hinton, Vinyals, Dean이 제안한 기법으로, **큰 Teacher 모델의 "지식"을 작은 Student 모델로 전달**하는 것이 핵심입니다.

Teacher Model (크고 정확)

│ 소프트 타겟 (확률 분포) 전달

Student Model (작고 빠름)

단순히 정답 레이블(하드 타겟)로 학습하는 것과 달리, Teacher의 **소프트 타겟(softmax 출력 확률 분포)**을 사용합니다.

예를 들어, 고양이 이미지 분류:

- 하드 타겟: [0, 0, 1, 0, 0] (정답만 1)

- Teacher 소프트 타겟: [0.01, 0.05, **0.85**, 0.07, 0.02]

소프트 타겟에는 "이 이미지는 고양이지만 호랑이와 약간 비슷하다"는 정보가 담겨 있습니다. 이런 **클래스 간 유사성 정보**가 Student 학습에 큰 도움이 됩니다.

1.2 온도(Temperature) 파라미터

소프트맥스 출력이 너무 confident하면 [0.001, 0.002, 0.997, ...]처럼 거의 하드 타겟과 같아져서 정보가 적습니다. 온도 파라미터 T로 분포를 부드럽게 만듭니다:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))

- T가 1일 때: 일반 소프트맥스

- T가 1보다 클 때: 분포가 더 균일해짐 (더 많은 정보 전달)

- T가 1보다 작을 때: 분포가 더 날카로워짐

def temperature_softmax(logits, temperature=1.0):

"""온도로 조정된 소프트맥스."""

return F.softmax(logits / temperature, dim=-1)

예시

logits = torch.tensor([2.0, 1.0, 0.1, 0.5])

print("T=1:", temperature_softmax(logits, T=1).numpy().round(3))

[0.596, 0.219, 0.090, 0.096]

print("T=4:", temperature_softmax(logits, T=4).numpy().round(3))

[0.345, 0.262, 0.195, 0.216] — 더 균일

1.3 Hinton의 KD 손실 함수

KD 손실은 두 항의 가중 합입니다:

**KL Divergence 항** (소프트 타겟 매칭):

L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))

T^2 스케일링: 그래디언트 크기를 T에 무관하게 유지하기 위함.

**Cross-Entropy 항** (하드 타겟 학습):

L_CE = CrossEntropy(student_logits, true_labels)

**총 손실**:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 완전한 PyTorch 구현

from torch.utils.data import DataLoader

class KnowledgeDistillationLoss(nn.Module):

"""

Hinton et al. (2015)의 지식 증류 손실.

L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)

"""

def __init__(self, temperature=4.0, alpha=0.5):

super().__init__()

self.T = temperature

self.alpha = alpha

self.ce_loss = nn.CrossEntropyLoss()

self.kl_loss = nn.KLDivLoss(reduction='batchmean')

def forward(self, student_logits, teacher_logits, labels):

하드 타겟 손실

loss_ce = self.ce_loss(student_logits, labels)

소프트 타겟 손실 (KL Divergence)

student_soft = F.log_softmax(student_logits / self.T, dim=-1)

teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)

loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

def train_with_distillation(

teacher, student, train_loader,

num_epochs=10, temperature=4.0, alpha=0.5,

device='cuda'

):

"""Teacher-Student 증류 학습 루프."""

teacher = teacher.to(device).eval() # Teacher는 frozen

student = student.to(device)

Teacher 파라미터 동결

for param in teacher.parameters():

param.requires_grad = False

criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)

optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):

student.train()

total_loss = 0.0

correct = 0

total = 0

for images, labels in train_loader:

images, labels = images.to(device), labels.to(device)

Teacher 추론 (그래디언트 불필요)

with torch.no_grad():

teacher_logits = teacher(images)

Student 추론

student_logits = student(images)

KD 손실 계산

loss = criterion(student_logits, teacher_logits, labels)

optimizer.zero_grad()

loss.backward()

optimizer.step()

total_loss += loss.item() * images.size(0)

correct += (student_logits.argmax(1) == labels).sum().item()

total += images.size(0)

scheduler.step()

print(f"Epoch {epoch+1}/{num_epochs} | "

f"Loss: {total_loss/total:.4f} | "

f"Acc: {correct/total:.4f}")

return student

실제 사용 예시

Teacher: ResNet50 (25M params), Student: ResNet18 (11M params)

teacher = models.resnet50(weights='DEFAULT')

teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)

student.fc = nn.Linear(512, 10)

파라미터 수 비교

t_params = sum(p.numel() for p in teacher.parameters())

s_params = sum(p.numel() for p in student.parameters())

print(f"Teacher: {t_params:,} params") # ~25.6M

print(f"Student: {s_params:,} params") # ~11.2M

print(f"압축률: {t_params/s_params:.1f}x")

2. 다양한 증류 방법

2.1 Response-based 증류 (로짓 매칭)

가장 기본적인 방법으로, 앞서 설명한 Hinton KD가 대표적입니다. Student가 Teacher의 **최종 출력(로짓)**을 모방합니다.

class ResponseBasedDistillation(nn.Module):

"""최종 출력만 사용하는 Response-based 증류."""

def __init__(self, temperature=4.0):

super().__init__()

self.T = temperature

def forward(self, student_logits, teacher_logits):

student_soft = F.log_softmax(student_logits / self.T, dim=-1)

teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)

return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 Feature-based 증류 (중간 레이어)

FitNets (Romero et al., 2015)에서 제안. Teacher의 중간 레이어 **feature map**을 Student가 따라하도록 학습합니다.

Teacher와 Student의 채널 수가 다를 수 있으므로 **Regressor** 네트워크로 차원을 맞춥니다.

class FeatureDistillationLoss(nn.Module):

"""중간 레이어 feature를 매칭하는 증류."""

def __init__(self, teacher_channels, student_channels):

super().__init__()

Student feature를 Teacher feature 공간으로 projection

self.regressor = nn.Sequential(

nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),

nn.BatchNorm2d(teacher_channels),

nn.ReLU(inplace=True),

)

def forward(self, student_feat, teacher_feat):

Student feature를 Teacher 차원으로 변환

projected = self.regressor(student_feat)

MSE 손실로 feature 맞추기

return F.mse_loss(projected, teacher_feat.detach())

class HookBasedDistillation:

"""

forward hook으로 중간 레이어 feature 추출.

"""

def __init__(self, model, layer_names):

self.features = {}

self.hooks = []

for name, layer in model.named_modules():

if name in layer_names:

hook = layer.register_forward_hook(

self._make_hook(name)

)

self.hooks.append(hook)

def _make_hook(self, name):

def hook(module, input, output):

self.features[name] = output

return hook

def remove(self):

for hook in self.hooks:

hook.remove()

사용 예시

teacher = models.resnet50(weights='DEFAULT')

student = models.resnet18(weights=None)

Teacher: layer3 출력 (1024채널), Student: layer3 출력 (256채널)

teacher_hook = HookBasedDistillation(teacher, ['layer3'])

student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(

teacher_channels=1024,

student_channels=256

)

학습 시

x = torch.randn(4, 3, 224, 224)

teacher_out = teacher(x)

student_out = student(x)

teacher_feat = teacher_hook.features['layer3']

student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 Relation-based 증류 (상관관계 학습)

RKD (Park et al., 2019). 샘플들 **사이의 관계**를 Student가 모방하도록 합니다. 절대적인 값이 아니라 상대적인 구조를 학습합니다.

class RelationalKnowledgeDistillation(nn.Module):

"""

Relational KD: 샘플 쌍의 거리 관계를 보존.

Teacher의 embedding 공간 구조를 Student가 학습.

"""

def __init__(self, distance_weight=25.0, angle_weight=50.0):

super().__init__()

self.dist_w = distance_weight

self.angle_w = angle_weight

def pdist(self, e, squared=False, eps=1e-12):

"""Pairwise distance computation."""

e_sq = (e ** 2).sum(dim=1)

prod = e @ e.t()

res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)

if not squared:

res = res.sqrt()

return res

def distance_loss(self, teacher_emb, student_emb):

"""거리 관계 보존 손실."""

t_d = self.pdist(teacher_emb)

정규화

t_d = t_d / (t_d.mean() + 1e-12)

s_d = self.pdist(student_emb)

s_d = s_d / (s_d.mean() + 1e-12)

return F.smooth_l1_loss(s_d, t_d.detach())

def angle_loss(self, teacher_emb, student_emb):

"""각도 관계 보존 손실."""

e_i - e_j 벡터 간 각도 관계

td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1) # (N, N, D)

sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)

cosine similarity

td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)

sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)

t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)

s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)

return F.smooth_l1_loss(s_angle, t_angle.detach())

def forward(self, teacher_emb, student_emb):

loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)

loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)

return loss

2.4 Attention Transfer

Zagoruyko & Komodakis (2017). Attention map (채널별 activation 합)을 Teacher에서 Student로 전달합니다.

class AttentionTransfer(nn.Module):

"""Attention Transfer: activation map의 공간적 패턴 전달."""

def attention_map(self, feat):

"""

Feature map에서 attention map 계산.

각 공간 위치에서 채널 방향으로 제곱합의 제곱근.

"""

return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

def forward(self, student_feats, teacher_feats):

"""

student_feats, teacher_feats: list of feature maps

각 쌍에 대해 AT 손실 계산.

"""

loss = 0.0

for s_feat, t_feat in zip(student_feats, teacher_feats):

s_attn = self.attention_map(s_feat)

t_attn = self.attention_map(t_feat)

loss += (s_attn - t_attn.detach()).pow(2).mean()

return loss

3. LLM 증류

3.1 DistilBERT

DistilBERT (Sanh et al., 2019)는 BERT-base (110M 파라미터)를 66M으로 압축한 모델입니다.

주요 기법:

- **레이어 수 절반**: 12개 → 6개

- **Token-type embedding 제거**

- **Pooler 제거**

- **Triple 손실**: MLM + Distillation + Cosine embedding

from transformers import (

DistilBertModel, BertModel,

DistilBertTokenizer

)

class DistilBERTLoss(nn.Module):

"""

DistilBERT 3중 손실:

1. MLM CE 손실 (언어 모델링)

2. Soft-target KD 손실 (Teacher logit 매칭)

3. Cosine embedding 손실 (hidden state 유사도)

"""

def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):

super().__init__()

self.T = temperature

self.alpha = alpha # MLM 가중치

self.beta = beta # Cosine 손실 가중치

def forward(self, student_logits, teacher_logits,

student_hidden, teacher_hidden, mlm_labels):

1. MLM 손실

loss_mlm = F.cross_entropy(

student_logits.view(-1, student_logits.size(-1)),

mlm_labels.view(-1),

ignore_index=-100

)

2. Soft KD 손실

s_soft = F.log_softmax(student_logits / self.T, dim=-1)

t_soft = F.softmax(teacher_logits / self.T, dim=-1)

loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

3. Cosine embedding 손실 (hidden state 유사도)

loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

return (self.alpha * loss_mlm

+ (1 - self.alpha) * loss_kd

+ self.beta * loss_cos)

추론 예시

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."

inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():

outputs = model(**inputs)

print(outputs.last_hidden_state.shape)

torch.Size([1, 11, 768])

DistilBERT 성능:

- BERT-base 대비 40% 크기 감소

- 60% 빠른 추론 속도

- GLUE 벤치마크 97% 성능 유지

3.2 TinyLLaMA 스타일 증류

LLM 증류에서는 Teacher의 다음 토큰 예측 분포를 Student가 학습합니다.

class LLMDistillationTrainer:

"""

LLM 증류 학습기.

Teacher의 로짓 분포를 Student에 전달.

"""

def __init__(self, teacher, student, temperature=2.0, alpha=0.5):

self.teacher = teacher

self.student = student

self.T = temperature

self.alpha = alpha

def compute_loss(self, input_ids, attention_mask, labels):

Teacher 추론 (no_grad)

with torch.no_grad():

teacher_out = self.teacher(

input_ids=input_ids,

attention_mask=attention_mask,

)

teacher_logits = teacher_out.logits # (B, seq_len, vocab_size)

Student 추론

student_out = self.student(

input_ids=input_ids,

attention_mask=attention_mask,

)

student_logits = student_out.logits

1. CE 손실 (하드 타겟)

shift_logits = student_logits[..., :-1, :].contiguous()

shift_labels = labels[..., 1:].contiguous()

loss_ce = F.cross_entropy(

shift_logits.view(-1, shift_logits.size(-1)),

shift_labels.view(-1),

ignore_index=-100

)

2. KD 손실 (소프트 타겟)

유효한 토큰에 대해서만 계산

mask = (shift_labels != -100).float()

s_log_soft = F.log_softmax(

shift_logits / self.T, dim=-1

)

t_soft = F.softmax(

teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1

)

토큰별 KL divergence

kl_per_token = F.kl_div(

s_log_soft, t_soft,

reduction='none'

).sum(-1) # (B, seq_len)

loss_kd = (kl_per_token * mask).sum() / mask.sum()

loss_kd = loss_kd * (self.T ** 2)

return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

3.3 Distil Whisper

OpenAI Whisper 음성 인식 모델의 증류판입니다.

from transformers import (

WhisperProcessor, WhisperForConditionalGeneration

)

Distil-Whisper: Whisper-large-v2의 증류

large-v2: 1550M params → distil-large-v2: 756M params (2x faster, same WER)

processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")

model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

6개 디코더 레이어만 유지 (원본 32개 → 2개)

인코더는 동일하게 유지

print(f"Encoder layers: {len(model.model.encoder.layers)}")

print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. 구조적 프루닝 (Structured Pruning)

프루닝은 모델에서 중요하지 않은 파라미터나 구조를 제거하는 기법입니다.

**구조적 프루닝**: 필터, 헤드, 레이어 전체 단위로 제거 → **실제 속도 향상**

**비구조적 프루닝**: 개별 가중치를 0으로 만들기 → 희소 행렬 (특수 하드웨어 필요)

4.1 필터 프루닝

CNN에서 L1/L2 norm이 작은 필터를 제거합니다.

def get_filter_importance(conv_layer, norm='l1'):

"""각 필터의 중요도(norm) 계산."""

weight = conv_layer.weight.data # (out_ch, in_ch, kH, kW)

if norm == 'l1':

return weight.abs().sum(dim=(1, 2, 3))

elif norm == 'l2':

return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()

elif norm == 'gm':

Geometric Median 기반 (더 robust)

return weight.view(weight.size(0), -1).norm(p=2, dim=1)

def prune_conv_layer(conv, keep_ratio=0.5):

"""

필터 프루닝: 중요도 낮은 필터 제거.

Returns: 새로운 Conv2d 레이어 (필터 수 감소)

"""

importance = get_filter_importance(conv)

n_keep = int(conv.out_channels * keep_ratio)

중요도 상위 n_keep개 필터 선택

_, indices = importance.topk(n_keep)

indices, _ = indices.sort()

새 conv 레이어 생성

new_conv = nn.Conv2d(

conv.in_channels,

n_keep,

kernel_size=conv.kernel_size,

stride=conv.stride,

padding=conv.padding,

bias=conv.bias is not None

)

new_conv.weight.data = conv.weight.data[indices]

if conv.bias is not None:

new_conv.bias.data = conv.bias.data[indices]

return new_conv, indices

class StructuredPruner:

"""ResNet 스타일 모델의 구조적 프루닝."""

def __init__(self, model, prune_ratio=0.5):

self.model = model

self.prune_ratio = prune_ratio

def compute_layer_importance(self):

"""각 레이어 필터의 중요도 계산."""

importance_dict = {}

for name, module in self.model.named_modules():

if isinstance(module, nn.Conv2d):

importance = get_filter_importance(module)

importance_dict[name] = importance

return importance_dict

def global_threshold_prune(self, sparsity=0.5):

"""

전체 필터에 대해 전역 임계값으로 프루닝.

가장 중요도 낮은 sparsity 비율의 필터 제거.

"""

all_importances = []

for name, module in self.model.named_modules():

if isinstance(module, nn.Conv2d):

imp = get_filter_importance(module)

all_importances.append(imp)

모든 필터 중요도를 하나로 합치기

all_imp = torch.cat(all_importances)

threshold = all_imp.kthvalue(int(len(all_imp) * sparsity)).values.item()

임계값보다 낮은 필터 제거

masks = {}

for name, module in self.model.named_modules():

if isinstance(module, nn.Conv2d):

imp = get_filter_importance(module)

masks[name] = imp >= threshold

return masks

4.2 헤드 프루닝 (Attention Head Pruning)

트랜스포머에서 중요하지 않은 Attention Head를 제거합니다.

class AttentionHeadPruner:

"""

Transformer의 Multi-Head Attention 헤드 프루닝.

Michel et al. (2019): Are Sixteen Heads Really Better than One?

"""

def compute_head_importance(self, model, dataloader, device='cuda'):

"""각 레이어, 각 헤드의 중요도 계산."""

head_importance = {}

model.eval()

for batch in dataloader:

inputs = {k: v.to(device) for k, v in batch.items()}

Attention 가중치 저장용 hook

attention_weights = {}

hooks = []

for name, module in model.named_modules():

if 'attention' in name.lower() and hasattr(module, 'num_heads'):

def hook_fn(m, inp, out, layer_name=name):

if hasattr(out, 'attentions') and out.attentions is not None:

attention_weights[layer_name] = out.attentions

hooks.append(module.register_forward_hook(hook_fn))

with torch.no_grad():

outputs = model(**inputs, output_attentions=True)

Head importance: attention 가중치의 분산 합

if hasattr(outputs, 'attentions') and outputs.attentions:

for layer_idx, attn in enumerate(outputs.attentions):

attn: (batch, heads, seq, seq)

head_imp = attn.mean(0).var(-1).mean(-1) # (heads,)

key = f"layer_{layer_idx}"

if key not in head_importance:

head_importance[key] = head_imp

else:

head_importance[key] += head_imp

for hook in hooks:

hook.remove()

return head_importance

def prune_heads(self, model, heads_to_prune):

"""

지정된 head를 제거.

heads_to_prune: {layer_idx: [head_indices]}

"""

model.prune_heads(heads_to_prune) # HuggingFace 모델 내장 기능

return model

HuggingFace transformers 활용

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

layer 0의 head 0, 3, 5 제거

heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}

model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())

print(f"Pruned BERT params: {total_params:,}")

4.3 레이어 프루닝

성능에 기여가 적은 트랜스포머 레이어 전체를 제거합니다.

class LayerPruner:

"""트랜스포머 레이어 프루닝."""

def compute_layer_importance_by_gradient(

self, model, dataloader, device='cuda'

):

"""

각 레이어에 대한 gradient 기반 중요도 계산.

중요도 = |gradient * weight|의 평균

"""

model.train()

layer_importance = {}

for batch in dataloader:

inputs = {k: v.to(device) for k, v in batch.items()}

outputs = model(**inputs)

loss = outputs.loss

loss.backward()

for i, layer in enumerate(model.encoder.layer):

grad_sum = 0.0

param_count = 0

for param in layer.parameters():

if param.grad is not None:

grad_sum += (param.grad * param.data).abs().sum().item()

param_count += param.numel()

key = f"layer_{i}"

imp = grad_sum / max(param_count, 1)

if key not in layer_importance:

layer_importance[key] = 0.0

layer_importance[key] += imp

model.zero_grad()

return layer_importance

def drop_layers(self, model, layers_to_drop):

"""지정된 레이어를 제거하고 남은 레이어로 재구성."""

new_model = copy.deepcopy(model)

remaining = [

layer for i, layer in enumerate(new_model.encoder.layer)

if i not in layers_to_drop

]

new_model.encoder.layer = nn.ModuleList(remaining)

new_model.config.num_hidden_layers = len(remaining)

return new_model

4.4 PyTorch torch.nn.utils.prune

model = models.resnet18(weights='DEFAULT')

특정 레이어에 비구조적 L1 프루닝 적용

conv = model.layer1[0].conv1

prune.l1_unstructured(conv, name='weight', amount=0.3)

weight의 30%를 0으로 만듦 (mask 사용)

확인

print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

구조적 프루닝 (필터 단위)

prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

영구 적용 (mask 제거, 실제 weight 0화)

prune.remove(conv, 'weight')

여러 레이어에 반복 적용

parameters_to_prune = []

for module in model.modules():

if isinstance(module, nn.Conv2d):

parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(

parameters_to_prune,

pruning_method=prune.L1Unstructured,

amount=0.4, # 전체 파라미터의 40% 프루닝

)

전체 희소성 확인

total_zero = sum(

(m.weight == 0).sum().item()

for m in model.modules() if isinstance(m, nn.Conv2d)

)

total_params = sum(

m.weight.numel()

for m in model.modules() if isinstance(m, nn.Conv2d)

)

print(f"Global sparsity: {total_zero/total_params:.2%}")

5. 비구조적 프루닝 (Unstructured Pruning)

5.1 크기 기반 프루닝

절대값이 작은 가중치를 0으로 만드는 가장 단순한 방법입니다.

class MagnitudePruner:

"""크기 기반 비구조적 프루닝."""

def __init__(self, model, sparsity=0.5):

self.model = model

self.sparsity = sparsity

self.masks = {}

def compute_global_threshold(self):

"""전체 가중치에 대한 전역 임계값 계산."""

all_weights = []

for name, param in self.model.named_parameters():

if 'weight' in name and param.dim() > 1:

all_weights.append(param.data.abs().view(-1))

all_weights = torch.cat(all_weights)

threshold = all_weights.kthvalue(

int(len(all_weights) * self.sparsity)

).values.item()

return threshold

def apply_pruning(self):

"""전역 임계값으로 프루닝 마스크 생성."""

threshold = self.compute_global_threshold()

for name, param in self.model.named_parameters():

if 'weight' in name and param.dim() > 1:

mask = (param.data.abs() >= threshold).float()

self.masks[name] = mask

param.data *= mask # 임계값 이하 가중치를 0으로

total_zeros = sum((m == 0).sum().item() for m in self.masks.values())

total_params = sum(m.numel() for m in self.masks.values())

print(f"실제 희소성: {total_zeros/total_params:.2%}")

def apply_masks(self):

"""역전파 후 마스크 재적용 (기울기로 0이 복원되지 않도록)."""

for name, param in self.model.named_parameters():

if name in self.masks:

param.data *= self.masks[name]

5.2 점진적 프루닝 (Gradual Magnitude Pruning)

처음부터 많이 프루닝하면 성능이 급락합니다. 학습 중에 서서히 희소성을 높이는 방법이 효과적입니다.

class GradualMagnitudePruner:

"""

학습 중 점진적으로 희소성을 높이는 프루닝.

Zhu & Gupta (2017): To Prune, or Not to Prune

"""

def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,

begin_step=0, end_step=1000, frequency=100):

self.model = model

self.initial_sparsity = initial_sparsity

self.final_sparsity = final_sparsity

self.begin_step = begin_step

self.end_step = end_step

self.frequency = frequency

self.masks = {}

self._init_masks()

def _init_masks(self):

for name, param in self.model.named_parameters():

if 'weight' in name and param.dim() > 1:

self.masks[name] = torch.ones_like(param.data)

def compute_sparsity(self, step):

"""현재 스텝에서의 목표 희소성 계산 (cubic schedule)."""

if step < self.begin_step:

return self.initial_sparsity

if step > self.end_step:

return self.final_sparsity

Cubic decay schedule

pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)

sparsity = (

self.final_sparsity

+ (self.initial_sparsity - self.final_sparsity)

* (1 - pct_done) ** 3

)

return sparsity

def step(self, global_step):

"""학습 스텝에서 프루닝 업데이트."""

if (global_step % self.frequency != 0 or

global_step < self.begin_step or

global_step > self.end_step):

return

target_sparsity = self.compute_sparsity(global_step)

self._update_masks(target_sparsity)

def _update_masks(self, sparsity):

"""현재 희소성 목표에 맞게 마스크 업데이트."""

for name, param in self.model.named_parameters():

if name not in self.masks:

continue

현재 활성 가중치만 고려 (이미 pruned된 것 제외)

alive = param.data[self.masks[name].bool()]

if len(alive) == 0:

continue

n_prune = int(sparsity * param.data.numel())

if n_prune == 0:

continue

전체 중 하위 n_prune개 찾기

threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()

self.masks[name] = (param.data.abs() > threshold).float()

param.data *= self.masks[name]

6. 가중치 공유 (Weight Sharing)

6.1 크로스 레이어 파라미터 공유 — ALBERT

ALBERT (Lan et al., 2019)는 BERT의 파라미터를 크게 줄이기 위해 모든 트랜스포머 레이어에서 **동일한 파라미터를 반복 사용**합니다.

class ALBERTEncoder(nn.Module):

"""

ALBERT 스타일 가중치 공유.

하나의 트랜스포머 레이어를 N번 반복 실행.

"""

def __init__(self, hidden_size=768, num_heads=12,

intermediate_size=3072, num_layers=12):

super().__init__()

단 하나의 트랜스포머 레이어 정의

self.shared_layer = nn.TransformerEncoderLayer(

d_model=hidden_size,

nhead=num_heads,

dim_feedforward=intermediate_size,

batch_first=True,

norm_first=True,

)

self.num_layers = num_layers

self.norm = nn.LayerNorm(hidden_size)

def forward(self, x, src_key_padding_mask=None):

동일한 레이어를 num_layers번 반복

for _ in range(self.num_layers):

x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)

return self.norm(x)

def count_parameters(self):

실제 파라미터 수 (공유되므로 한 번만 카운트)

return sum(p.numel() for p in self.parameters())

파라미터 비교

bert_encoder = nn.TransformerEncoder(

nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),

num_layers=12

)

albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())

albert_params = albert_encoder.count_parameters()

print(f"BERT encoder: {bert_params:,}") # ~85M

print(f"ALBERT encoder: {albert_params:,}") # ~7M

print(f"압축률: {bert_params/albert_params:.1f}x") # ~12x

6.2 팩토라이제이션 (ALBERT 임베딩)

ALBERT는 임베딩 레이어도 분해합니다: vocab_size x hidden_size → vocab_size x embedding_size + embedding_size x hidden_size.

class FactorizedEmbedding(nn.Module):

"""

ALBERT의 임베딩 팩토라이제이션.

vocab_size x H → (vocab_size x E) + (E x H)

E << H로 파라미터 대폭 감소.

"""

def __init__(self, vocab_size, embedding_size=128, hidden_size=768):

super().__init__()

self.word_embeddings = nn.Embedding(vocab_size, embedding_size)

E → H 선형 변환

self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

def forward(self, input_ids):

embed = self.word_embeddings(input_ids) # (B, seq, E)

return self.embedding_projection(embed) # (B, seq, H)

파라미터 비교 (vocab_size=30000)

standard_embed = nn.Embedding(30000, 768) # 23.0M params

factorized_embed = FactorizedEmbedding(30000, 128, 768) # 3.97M params

s_params = sum(p.numel() for p in standard_embed.parameters())

f_params = sum(p.numel() for p in factorized_embed.parameters())

print(f"Standard: {s_params:,} → Factorized: {f_params:,}")

print(f"절감: {(s_params - f_params)/s_params:.1%}")

7. 신경망 구조 탐색 (NAS)

7.1 Manual Design vs AutoML

전통적으로 CNN 구조는 연구자가 수작업으로 설계했습니다 (VGG, ResNet). NAS는 이 과정을 자동화합니다.

NAS의 세 가지 핵심 요소:

1. **검색 공간 (Search Space)**: 어떤 연산/구조를 탐색할지

2. **검색 전략 (Search Strategy)**: 강화학습, 진화 알고리즘, 기울기 기반 등

3. **성능 추정 (Performance Estimation)**: 후보 구조의 성능을 빠르게 평가

7.2 DARTS (Differentiable Architecture Search)

Liu et al. (2019). 구조 탐색을 연속적인 최적화 문제로 변환합니다.

후보 연산들

PRIMITIVES = [

'none', # 연결 없음

'skip_connect', # 항등 함수

'sep_conv_3x3', # 3x3 Separable Conv

'sep_conv_5x5', # 5x5 Separable Conv

'dil_conv_3x3', # 3x3 Dilated Conv

'dil_conv_5x5', # 5x5 Dilated Conv

'avg_pool_3x3', # 3x3 Average Pool

'max_pool_3x3', # 3x3 Max Pool

]

class MixedOperation(nn.Module):

"""

DARTS: 여러 연산의 가중 합으로 엣지 표현.

alpha (아키텍처 파라미터)가 각 연산의 가중치를 결정.

"""

def __init__(self, C, stride, primitives):

super().__init__()

self.ops = nn.ModuleList([

self._build_op(primitive, C, stride)

for primitive in primitives

])

아키텍처 파라미터 (학습 가능)

self.arch_params = nn.Parameter(

torch.randn(len(primitives)) * 1e-3

)

def _build_op(self, primitive, C, stride):

if primitive == 'none':

return nn.Sequential() # Zero operation

elif primitive == 'skip_connect':

return nn.Identity() if stride == 1 else nn.Sequential(

nn.AvgPool2d(stride, stride, padding=0),

nn.Conv2d(C, C, 1, bias=False),

nn.BatchNorm2d(C)

)

elif primitive == 'sep_conv_3x3':

return nn.Sequential(

nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),

nn.Conv2d(C, C, 1, bias=False),

nn.BatchNorm2d(C),

nn.ReLU(inplace=True),

)

elif primitive == 'avg_pool_3x3':

return nn.AvgPool2d(3, stride=stride, padding=1)

elif primitive == 'max_pool_3x3':

return nn.MaxPool2d(3, stride=stride, padding=1)

else:

return nn.Identity()

def forward(self, x):

Softmax로 가중치 계산 후 각 연산의 가중 합

weights = F.softmax(self.arch_params, dim=0)

results = []

for w, op in zip(weights, self.ops):

try:

result = op(x)

if result.shape == x.shape or len(results) == 0:

results.append(w * result)

else:

results.append(w * result)

except Exception:

pass

if results:

모든 연산 결과의 가중 합

output = results[0]

for r in results[1:]:

if r.shape == output.shape:

output = output + r

return output

return x

class DARTSCell(nn.Module):

"""DARTS 셀: 여러 Mixed Operation으로 구성된 DAG."""

def __init__(self, num_nodes, C, stride=1):

super().__init__()

self.num_nodes = num_nodes

self.ops = nn.ModuleList()

각 노드 쌍에 대해 Mixed Operation 생성

for i in range(num_nodes):

for j in range(i):

self.ops.append(

MixedOperation(C, stride, PRIMITIVES)

)

def forward(self, *inputs):

states = list(inputs)

op_idx = 0

for i in range(self.num_nodes):

모든 이전 상태로부터 합산

node_input = sum(

self.ops[op_idx + j](states[j])

for j in range(len(states))

)

op_idx += len(states)

states.append(node_input)

return states[-1]

def get_arch_params(self):

"""아키텍처 파라미터 반환."""

return [op.arch_params for op in self.ops]

def darts_discrete(cell):

"""

연속적인 아키텍처를 이산 구조로 변환.

각 엣지에서 가장 높은 가중치의 연산 선택.

"""

for op in cell.ops:

weights = F.softmax(op.arch_params, dim=0)

best_op_idx = weights.argmax().item()

print(f" Selected: {PRIMITIVES[best_op_idx]} "

f"(weight: {weights[best_op_idx]:.3f})")

7.3 EfficientNet의 NAS 프로세스

EfficientNet-B0는 MnasNet과 유사한 신경망 구조 탐색으로 찾은 기저 구조입니다.

EfficientNet의 MBConv 블록 (NAS로 찾은 핵심 구조)

class MBConvBlock(nn.Module):

"""

Mobile Inverted Bottleneck Convolution.

EfficientNet의 기본 빌딩 블록.

"""

def __init__(self, in_channels, out_channels, kernel_size,

stride=1, expand_ratio=6, se_ratio=0.25):

super().__init__()

self.use_residual = (stride == 1 and in_channels == out_channels)

hidden_dim = int(in_channels * expand_ratio)

layers = []

Expansion phase (1x1 conv)

if expand_ratio != 1:

layers.extend([

nn.Conv2d(in_channels, hidden_dim, 1, bias=False),

nn.BatchNorm2d(hidden_dim),

nn.SiLU(),

])

Depthwise conv

layers.extend([

nn.Conv2d(hidden_dim, hidden_dim, kernel_size,

stride=stride,

padding=kernel_size // 2,

groups=hidden_dim, bias=False),

nn.BatchNorm2d(hidden_dim),

nn.SiLU(),

])

Squeeze-and-Excitation (NAS가 발견한 중요 모듈)

se_channels = max(1, int(in_channels * se_ratio))

self.se = nn.Sequential(

nn.AdaptiveAvgPool2d(1),

nn.Conv2d(hidden_dim, se_channels, 1),

nn.SiLU(),

nn.Conv2d(se_channels, hidden_dim, 1),

nn.Sigmoid(),

)

Output phase

layers.extend([

nn.Conv2d(hidden_dim, out_channels, 1, bias=False),

nn.BatchNorm2d(out_channels),

])

self.conv = nn.Sequential(*layers)

def forward(self, x):

out = self.conv[:-2](x) if len(self.conv) > 2 else self.conv(x)

SE 모듈 적용

out_se = self.se(out)

out = out * out_se

out = self.conv[-2:](out)

if self.use_residual:

return x + out

return out

7.4 Once-for-All 네트워크

Cai et al. (2020). 하나의 초대형 네트워크를 학습한 후, 다양한 리소스 제약에 맞게 서브네트워크를 추출합니다.

class OFALayer(nn.Module):

"""

Once-for-All: 다양한 커널 크기를 지원하는 탄력적 레이어.

학습 시: 랜덤으로 커널 크기 선택 → 서브네트워크 학습

배포 시: 특정 커널 크기만 사용

"""

def __init__(self, channels, max_kernel=7):

super().__init__()

가장 큰 커널 크기의 conv 하나만 학습

self.max_conv = nn.Conv2d(

channels, channels,

kernel_size=max_kernel,

padding=max_kernel // 2,

groups=channels, bias=False

)

self.bn = nn.BatchNorm2d(channels)

self.act = nn.ReLU(inplace=True)

self.active_kernel = max_kernel

def set_active_kernel(self, kernel_size):

"""현재 활성 커널 크기 설정."""

assert kernel_size <= self.max_conv.kernel_size[0]

self.active_kernel = kernel_size

def forward(self, x):

if self.active_kernel == self.max_conv.kernel_size[0]:

weight = self.max_conv.weight

else:

중앙에서 active_kernel 크기만큼 잘라내기

center = self.max_conv.kernel_size[0] // 2

half = self.active_kernel // 2

weight = self.max_conv.weight[

:, :,

center - half: center + half + 1,

center - half: center + half + 1

]

padding = self.active_kernel // 2

out = F.conv2d(x, weight, padding=padding, groups=x.size(1))

return self.act(self.bn(out))

Progressive shrinking 학습 (OFA의 핵심)

def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):

"""

1단계: 최대 커널(7) 학습

2단계: 7, 5 커널 랜덤 샘플링 학습

3단계: 7, 5, 3 커널 랜덤 샘플링 학습

"""

for stage, active_kernels in enumerate(

[kernels[:1], kernels[:2], kernels]

):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

print(f"Stage {stage+1}: kernels = {active_kernels}")

for epoch in range(5):

for images, labels in dataloader:

랜덤으로 커널 크기 선택

k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

모든 레이어의 활성 커널 설정

for module in model.modules():

if isinstance(module, OFALayer):

module.set_active_kernel(k)

outputs = model(images)

loss = F.cross_entropy(outputs, labels)

optimizer.zero_grad()

loss.backward()

optimizer.step()

8. 통합 모델 압축 파이프라인

실제 배포 시에는 여러 기법을 조합합니다:

class ModelCompressionPipeline:

"""

지식 증류 → 프루닝 → 양자화의 통합 파이프라인.

"""

def __init__(self, teacher, student_arch, num_classes):

self.teacher = teacher

self.num_classes = num_classes

Student 모델 초기화

self.student = student_arch

def step1_distillation(self, train_loader, val_loader,

epochs=30, device='cuda'):

"""1단계: 지식 증류로 기본 학습."""

print("=== Step 1: Knowledge Distillation ===")

self.student = train_with_distillation(

self.teacher, self.student, train_loader,

num_epochs=epochs, temperature=4.0, alpha=0.5,

device=device

)

def step2_pruning(self, train_loader, sparsity=0.5,

finetune_epochs=10, device='cuda'):

"""2단계: 점진적 프루닝 + 파인튜닝."""

print(f"=== Step 2: Pruning (target sparsity: {sparsity:.0%}) ===")

pruner = GradualMagnitudePruner(

self.student,

initial_sparsity=0.0,

final_sparsity=sparsity,

begin_step=0,

end_step=len(train_loader) * finetune_epochs,

frequency=100

)

optimizer = torch.optim.Adam(

self.student.parameters(), lr=1e-4

)

self.student.to(device)

step = 0

for epoch in range(finetune_epochs):

for images, labels in train_loader:

images, labels = images.to(device), labels.to(device)

outputs = self.student(images)

loss = F.cross_entropy(outputs, labels)

optimizer.zero_grad()

loss.backward()

optimizer.step()

pruner.step(step)

pruner.apply_masks() # 프루닝 마스크 유지

step += 1

def step3_quantization(self, calibration_loader, device='cpu'):

"""3단계: Post-Training Quantization."""

print("=== Step 3: Quantization ===")

self.student.eval().to(device)

Static quantization 준비

self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')

torch.ao.quantization.prepare(self.student, inplace=True)

보정 데이터로 scale/zero_point 결정

with torch.no_grad():

for images, _ in calibration_loader:

self.student(images.to(device))

양자화 적용

torch.ao.quantization.convert(self.student, inplace=True)

print("Quantization complete (INT8)")

return self.student

def compare_models(self, val_loader, device='cuda'):

"""Teacher, Student, Pruned, Quantized 모델 성능 비교."""

def eval_model(model, loader, dev):

model.eval().to(dev)

correct, total = 0, 0

with torch.no_grad():

for images, labels in loader:

images, labels = images.to(dev), labels.to(dev)

preds = model(images).argmax(1)

correct += (preds == labels).sum().item()

total += labels.size(0)

return correct / total

teacher_acc = eval_model(self.teacher, val_loader, device)

student_acc = eval_model(self.student, val_loader, 'cpu')

t_params = sum(p.numel() for p in self.teacher.parameters())

s_params = sum(p.numel() for p in self.student.parameters())

print(f"\n{'='*50}")

print(f"Teacher | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")

print(f"Student | Params: {s_params:>10,} | Acc: {student_acc:.4f}")

print(f"압축률: {t_params/s_params:.1f}x | 성능 유지율: {student_acc/teacher_acc:.1%}")

결론

지식 증류와 모델 압축은 AI 모델을 현실 세계에서 사용 가능하게 만드는 핵심 기술입니다.

주요 요약:

1. **지식 증류**: Teacher의 소프트 타겟(확률 분포)을 통해 클래스 간 관계 정보 전달

2. **Feature 증류**: 중간 레이어의 표현을 Student에 전달

3. **Relation 증류**: 샘플 간 관계 구조 보존

4. **LLM 증류**: DistilBERT, TinyLLM 등 실용적 대형 모델 압축

5. **구조적 프루닝**: 필터/헤드/레이어 단위 제거로 실제 속도 향상

6. **비구조적 프루닝**: 점진적 희소화로 정확도 손실 최소화

7. **가중치 공유**: ALBERT처럼 동일 파라미터를 반복 사용

8. **NAS**: DARTS, OFA 등 자동화된 구조 탐색

실제 배포 시에는 증류 → 프루닝 → 양자화를 순차적으로 적용하는 파이프라인이 가장 효과적입니다.

참고 문헌

- Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531

- Romero et al. (2015). FitNets: Hints for Thin Deep Nets.

- Zagoruyko and Komodakis (2017). Paying More Attention to Attention.

- Park et al. (2019). Relational Knowledge Distillation.

- Sanh et al. (2019). DistilBERT. https://arxiv.org/abs/1910.01108

- Lan et al. (2019). ALBERT: A Lite BERT.

- Liu et al. (2019). DARTS: Differentiable Architecture Search.

- Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.

- Tan and Le (2019). EfficientNet. https://arxiv.org/abs/1905.11946

- PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html

현재 단락 (1/943)

딥러닝 모델이 강력해질수록 크기도 커집니다. GPT-4, Llama 3 405B 같은 대형 모델은 수백 GB의 메모리를 필요로 하며, 모바일 기기나 엣지 디바이스에서는 실행이 불가...

작성 글자: 0원문 글자: 30,618작성 단락: 0/943