Skip to content
Published on

知識蒸留完全ガイド: モデル圧縮と軽量化技術

Authors

はじめに

深層学習モデルは高性能になる一方で、巨大化も続けています。GPT-4やLlama 3 405Bのようなモデルは数百GBのメモリを要求するため、モバイルデバイスやエッジハードウェアでの動作が困難です。知識蒸留モデル圧縮技術は、精度をできる限り保ちながらモデルを小さく高速化するための必須ツールです。

このガイドが扱う内容:

  • 完全なPyTorch実装を含む知識蒸留の理論的基礎
  • 複数の蒸留パラダイム: レスポンスベース、フィーチャーベース、関係ベース
  • LLM蒸留のケーススタディ(DistilBERT、TinyLLM、Distil Whisper)
  • 構造化・非構造化プルーニング
  • 重み共有とニューラルアーキテクチャ探索(NAS)

1. 知識蒸留の基礎

1.1 教師-学生フレームワーク

知識蒸留は2015年にHinton、Vinyals、Deanによって提案されました。核心アイデアは大きな教師モデルの「知識」を小さな学生モデルに転移することです。

Teacher Model (大きく、高精度)
       │  ソフトターゲット(確率分布)
Student Model (小さく、高速)

ハードラベル(one-hotベクトル)での学習ではなく、学生は教師のソフトターゲット — 完全なソフトマックス出力分布 — から学習します。

例えば、猫の画像を分類する場合:

  • ハードターゲット: [0, 0, 1, 0, 0](正解クラスのみ1)
  • 教師のソフトターゲット: [0.01, 0.05, 0.85, 0.07, 0.02]

ソフトターゲットは「この画像は猫だが、トラに似た特徴もある」という情報を含んでいます。このクラス間の類似性情報が学生への豊富な教師信号となります。

1.2 温度パラメータ

ソフトマックス出力が非常に確信的な場合([0.001, 0.002, 0.997, ...])、ハードラベルとほとんど情報量が変わりません。温度パラメータTにより分布を滑らかにします:

softmax_T(z_i) = exp(z_i / T) / sum_j(exp(z_j / T))
  • Tが1のとき: 標準ソフトマックス
  • Tが1より大きいとき: より平坦で均一な分布(情報量が多い)
  • Tが1より小さいとき: より尖った分布
import torch
import torch.nn.functional as F

def temperature_softmax(logits, temperature=1.0):
    """温度スケールソフトマックス。"""
    return F.softmax(logits / temperature, dim=-1)

# 例
logits = torch.tensor([2.0, 1.0, 0.1, 0.5])
print("T=1:", temperature_softmax(logits, temperature=1).numpy().round(3))
# [0.596, 0.219, 0.090, 0.096]
print("T=4:", temperature_softmax(logits, temperature=4).numpy().round(3))
# [0.345, 0.262, 0.195, 0.216] — より均一で情報量が多い

1.3 HintonのKD損失関数

KD損失は2つの項の加重和です:

KL発散項(ソフトターゲットのマッチング):

L_KD = T^2 * KLDiv(softmax_T(student_logits), softmax_T(teacher_logits))

T^2のスケーリングにより、温度値に依存せず勾配の大きさを一定に保ちます。

クロスエントロピー項(ハードターゲットの学習):

L_CE = CrossEntropy(student_logits, true_labels)

総損失:

L = alpha * L_CE + (1 - alpha) * L_KD

1.4 完全なPyTorch実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton et al. (2015) 知識蒸留損失。
    L = alpha * CE(student, labels) + (1-alpha) * T^2 * KLDiv(student_soft, teacher_soft)
    """
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # ハードターゲット損失
        loss_ce = self.ce_loss(student_logits, labels)

        # ソフトターゲット損失(KL発散)
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd


def train_with_distillation(
    teacher, student, train_loader,
    num_epochs=10, temperature=4.0, alpha=0.5,
    device='cuda'
):
    """教師-学生蒸留学習ループ。"""
    teacher = teacher.to(device).eval()  # 教師は固定
    student = student.to(device)

    # 教師パラメータを凍結
    for param in teacher.parameters():
        param.requires_grad = False

    criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(num_epochs):
        student.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # 教師推論(勾配不要)
            with torch.no_grad():
                teacher_logits = teacher(images)

            # 学生推論
            student_logits = student(images)

            # KD損失の計算
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (student_logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Loss: {total_loss/total:.4f} | "
              f"Acc: {correct/total:.4f}")

    return student


# 実践例
# 教師: ResNet50(25Mパラメータ)、学生: ResNet18(11Mパラメータ)
teacher = models.resnet50(weights='DEFAULT')
teacher.fc = nn.Linear(2048, 10)

student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

# パラメータ数の比較
t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")   # ~25.6M
print(f"Student: {s_params:,} params")   # ~11.2M
print(f"Compression ratio: {t_params/s_params:.1f}x")

2. 蒸留パラダイム

2.1 レスポンスベース蒸留(ロジットマッチング)

最も基本的なアプローチ — 学生が教師の**最終出力(ロジット)**を模倣します。HintonのオリジナルKDがその代表例です。

class ResponseBasedDistillation(nn.Module):
    """最終出力のみを使ったレスポンスベース蒸留。"""
    def __init__(self, temperature=4.0):
        super().__init__()
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        student_soft = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.T, dim=-1)
        return F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (self.T ** 2)

2.2 フィーチャーベース蒸留(中間層)

FitNets(Romero et al., 2015)で提案。学生は最終出力だけでなく、教師の中間特徴マップのマッチングも学習します。

教師と学生のチャンネル数が異なる場合があるため、Regressorネットワークが学生の特徴量を教師の次元に射影します。

class FeatureDistillationLoss(nn.Module):
    """中間層特徴マッチングによる蒸留。"""
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # 学生特徴量を教師特徴空間に射影
        self.regressor = nn.Sequential(
            nn.Conv2d(student_channels, teacher_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, student_feat, teacher_feat):
        projected = self.regressor(student_feat)
        return F.mse_loss(projected, teacher_feat.detach())


class HookBasedDistillation:
    """
    forwardフックを使って中間層特徴量を抽出する。
    モデルのforward()メソッドを変更せずに特徴量を取得できる。
    """
    def __init__(self, model, layer_names):
        self.features = {}
        self.hooks = []
        for name, layer in model.named_modules():
            if name in layer_names:
                hook = layer.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, name):
        def hook(module, input, output):
            self.features[name] = output
        return hook

    def remove(self):
        for hook in self.hooks:
            hook.remove()


# 使用例
teacher = models.resnet50(weights='DEFAULT')
student = models.resnet18(weights=None)

# 教師layer3: 1024チャンネル、学生layer3: 256チャンネル
teacher_hook = HookBasedDistillation(teacher, ['layer3'])
student_hook = HookBasedDistillation(student, ['layer3'])

feat_distill = FeatureDistillationLoss(
    teacher_channels=1024,
    student_channels=256
)

# 学習中
x = torch.randn(4, 3, 224, 224)
teacher_out = teacher(x)
student_out = student(x)

teacher_feat = teacher_hook.features['layer3']
student_feat = student_hook.features['layer3']

loss_feat = feat_distill(student_feat, teacher_feat)

2.3 関係ベース蒸留

RKD(Park et al., 2019)。絶対値ではなく、教師の埋め込み空間におけるサンプル間の関係構造を模倣するよう学生を学習させます。

class RelationalKnowledgeDistillation(nn.Module):
    """
    関係ベースKD: 埋め込み空間でのペアワイズ距離と角度関係を保持する。
    """
    def __init__(self, distance_weight=25.0, angle_weight=50.0):
        super().__init__()
        self.dist_w = distance_weight
        self.angle_w = angle_weight

    def pdist(self, e, squared=False, eps=1e-12):
        """ペアワイズ距離行列の計算。"""
        e_sq = (e ** 2).sum(dim=1)
        prod = e @ e.t()
        res = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        return res

    def distance_loss(self, teacher_emb, student_emb):
        """ペアワイズ距離関係の保持。"""
        t_d = self.pdist(teacher_emb)
        t_d = t_d / (t_d.mean() + 1e-12)  # 正規化
        s_d = self.pdist(student_emb)
        s_d = s_d / (s_d.mean() + 1e-12)
        return F.smooth_l1_loss(s_d, t_d.detach())

    def angle_loss(self, teacher_emb, student_emb):
        """トリプレット間の角度関係の保持。"""
        td = teacher_emb.unsqueeze(0) - teacher_emb.unsqueeze(1)  # (N, N, D)
        sd = student_emb.unsqueeze(0) - student_emb.unsqueeze(1)

        td_norm = F.normalize(td.view(-1, td.size(-1)), dim=-1)
        sd_norm = F.normalize(sd.view(-1, sd.size(-1)), dim=-1)

        t_angle = (td_norm * td_norm.flip(0)).sum(dim=-1)
        s_angle = (sd_norm * sd_norm.flip(0)).sum(dim=-1)

        return F.smooth_l1_loss(s_angle, t_angle.detach())

    def forward(self, teacher_emb, student_emb):
        loss = self.dist_w * self.distance_loss(teacher_emb, student_emb)
        loss += self.angle_w * self.angle_loss(teacher_emb, student_emb)
        return loss

2.4 アテンション転移

Zagoruyko and Komodakis(2017)。教師のアテンションマップ — 空間的活性化パターン — を学生に転移します。

class AttentionTransfer(nn.Module):
    """
    アテンション転移: 特徴マップ活性化から導出した
    空間的アテンションパターンを転移する。
    """

    def attention_map(self, feat):
        """
        特徴マップから空間的アテンションマップを計算。
        チャンネル方向に二乗活性化を合計して正規化する。
        """
        return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

    def forward(self, student_feats, teacher_feats):
        """
        student_feats, teacher_feats: 複数層の特徴マップのリスト。
        全ペアにわたるAT損失の合計を返す。
        """
        loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            s_attn = self.attention_map(s_feat)
            t_attn = self.attention_map(t_feat)
            loss += (s_attn - t_attn.detach()).pow(2).mean()
        return loss

3. LLM蒸留

3.1 DistilBERT

DistilBERT(Sanh et al., 2019)はBERT-base(1億1千万パラメータ)を6600万パラメータに圧縮します。

主要技術:

  • レイヤー数を半減: 12層 → 6層
  • トークンタイプ埋め込みを削除
  • プーラーを削除
  • トリプル損失: MLM + 蒸留 + コサイン埋め込み損失
from transformers import (
    DistilBertModel,
    DistilBertTokenizer
)
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilBERTLoss(nn.Module):
    """
    DistilBERTトリプル損失:
    1. MLM CE損失(言語モデリング)
    2. ソフトターゲットKD損失(教師ロジットマッチング)
    3. コサイン埋め込み損失(隠れ状態の類似性)
    """
    def __init__(self, temperature=2.0, alpha=0.5, beta=0.1):
        super().__init__()
        self.T = temperature
        self.alpha = alpha  # MLM重み
        self.beta = beta    # コサイン損失重み

    def forward(self, student_logits, teacher_logits,
                student_hidden, teacher_hidden, mlm_labels):
        # 1. MLM損失
        loss_mlm = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            mlm_labels.view(-1),
            ignore_index=-100
        )

        # 2. ソフトKD損失
        s_soft = F.log_softmax(student_logits / self.T, dim=-1)
        t_soft = F.softmax(teacher_logits / self.T, dim=-1)
        loss_kd = F.kl_div(s_soft, t_soft, reduction='batchmean') * (self.T ** 2)

        # 3. コサイン埋め込み損失(隠れ状態の整合)
        loss_cos = 1 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

        return (self.alpha * loss_mlm
                + (1 - self.alpha) * loss_kd
                + self.beta * loss_cos)


# 推論例
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

text = "Knowledge distillation is a model compression technique."
inputs = tokenizer(text, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)
# torch.Size([1, 11, 768])

DistilBERTの性能:

  • BERT-baseより40%小型
  • 推論速度60%向上
  • BERTのGLUE性能の97%を維持

3.2 LLM蒸留パターン

大規模言語モデルの場合、学生は教師の次トークン予測分布から学習します。

class LLMDistillationTrainer:
    """
    LLM蒸留トレーナー。
    教師のトークン分布を学生に転移する。
    """
    def __init__(self, teacher, student, temperature=2.0, alpha=0.5):
        self.teacher = teacher
        self.student = student
        self.T = temperature
        self.alpha = alpha

    def compute_loss(self, input_ids, attention_mask, labels):
        # 教師推論(no_grad)
        with torch.no_grad():
            teacher_out = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            teacher_logits = teacher_out.logits  # (B, seq_len, vocab_size)

        # 学生推論
        student_out = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_out.logits

        # 1. CE損失(ハードターゲット)
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        # 2. KD損失(ソフトターゲット)
        # 有効トークンのみで計算
        mask = (shift_labels != -100).float()

        s_log_soft = F.log_softmax(
            shift_logits / self.T, dim=-1
        )
        t_soft = F.softmax(
            teacher_logits[..., :-1, :].contiguous() / self.T, dim=-1
        )

        # トークンごとのKL発散
        kl_per_token = F.kl_div(
            s_log_soft, t_soft,
            reduction='none'
        ).sum(-1)  # (B, seq_len-1)

        loss_kd = (kl_per_token * mask).sum() / mask.sum()
        loss_kd = loss_kd * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kd

3.3 Distil Whisper

OpenAIのWhisper音声認識モデルには、積極的なLLM圧縮を示す蒸留バリアントがあります。

from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration
)

# Distil-Whisper: Whisper-large-v2を蒸留
# large-v2: 1550Mパラメータ -> distil-large-v2: 756Mパラメータ(2倍高速、同等WER)
processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# デコーダは2層のみ保持(オリジナル: 32層)
# エンコーダはそのまま維持
print(f"Encoder layers: {len(model.model.encoder.layers)}")
print(f"Decoder layers: {len(model.model.decoder.layers)}")

4. 構造化プルーニング

プルーニングはモデルから重要度の低いパラメータや構造を除去します。

構造化プルーニング: フィルター、ヘッド、レイヤー全体を削除 → あらゆるハードウェアで実際の高速化が得られる 非構造化プルーニング: 個々の重みをゼロに設定 → 疎行列(専用ハードウェアが必要)

4.1 フィルタープルーニング

CNNレイヤーからL1/L2ノルムが小さいフィルターを除去します。

import torch
import torch.nn as nn
import numpy as np


def get_filter_importance(conv_layer, norm='l1'):
    """各フィルターの重要度(ノルム)を計算する。"""
    weight = conv_layer.weight.data  # (out_ch, in_ch, kH, kW)
    if norm == 'l1':
        return weight.abs().sum(dim=(1, 2, 3))
    elif norm == 'l2':
        return weight.pow(2).sum(dim=(1, 2, 3)).sqrt()
    elif norm == 'gm':
        return weight.view(weight.size(0), -1).norm(p=2, dim=1)


def prune_conv_layer(conv, keep_ratio=0.5):
    """
    フィルタープルーニング: 重要度が最低のフィルターを削除。
    フィルター数を減らした新しいConv2dを返す。
    """
    importance = get_filter_importance(conv)
    n_keep = int(conv.out_channels * keep_ratio)
    _, indices = importance.topk(n_keep)
    indices, _ = indices.sort()

    new_conv = nn.Conv2d(
        conv.in_channels,
        n_keep,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None
    )
    new_conv.weight.data = conv.weight.data[indices]
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[indices]

    return new_conv, indices


class StructuredPruner:
    """ResNetスタイルモデルの構造化プルーニング。"""

    def __init__(self, model, prune_ratio=0.5):
        self.model = model
        self.prune_ratio = prune_ratio

    def global_threshold_prune(self, sparsity=0.5):
        """
        グローバルプルーニング: 全レイヤーを通じて
        重要度が低いsparsity割合のフィルターを削除。
        """
        all_importances = []
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                all_importances.append(imp)

        all_imp = torch.cat(all_importances)
        threshold = all_imp.kthvalue(int(len(all_imp) * sparsity)).values.item()

        masks = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                imp = get_filter_importance(module)
                masks[name] = imp >= threshold

        return masks

4.2 アテンションヘッドプルーニング

トランスフォーマーモデルから重要度の低いアテンションヘッドを削除します。

class AttentionHeadPruner:
    """
    マルチヘッドアテンションのヘッドプルーニング。
    Michel et al. (2019): Are Sixteen Heads Really Better than One?
    """

    def compute_head_importance(self, model, dataloader, device='cuda'):
        """レイヤーごと、ヘッドごとの重要度スコアを計算する。"""
        head_importance = {}

        model.eval()
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)

            if hasattr(outputs, 'attentions') and outputs.attentions:
                for layer_idx, attn in enumerate(outputs.attentions):
                    # attn: (batch, heads, seq, seq)
                    # 重要度: アテンション重みの分散(焦点が絞られているほど重要)
                    head_imp = attn.mean(0).var(-1).mean(-1)  # (heads,)
                    key = f"layer_{layer_idx}"
                    if key not in head_importance:
                        head_importance[key] = head_imp
                    else:
                        head_importance[key] += head_imp

        return head_importance

    def prune_heads(self, model, heads_to_prune):
        """
        指定ヘッドを削除。
        heads_to_prune: layer_idxをヘッドインデックスリストにマッピングするdict。
        """
        model.prune_heads(heads_to_prune)  # HuggingFaceの組み込みメソッド
        return model


# HuggingFaceの例
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# レイヤー0のヘッド0, 3, 5とレイヤー6のヘッド1, 2を削除
heads_to_prune = {0: [0, 3, 5], 6: [1, 2]}
model.prune_heads(heads_to_prune)

total_params = sum(p.numel() for p in model.parameters())
print(f"Pruned BERT params: {total_params:,}")

4.3 レイヤープルーニング

モデル性能への貢献が最も低いトランスフォーマーレイヤー全体を削除します。

class LayerPruner:
    """トランスフォーマーレイヤープルーニング。"""

    def compute_layer_importance_by_gradient(
        self, model, dataloader, device='cuda'
    ):
        """
        勾配ベースのレイヤー重要度。
        重要度 = レイヤー全体にわたる |勾配 * 重み| の平均。
        """
        model.train()
        layer_importance = {}

        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()

            for i, layer in enumerate(model.encoder.layer):
                grad_sum = 0.0
                param_count = 0
                for param in layer.parameters():
                    if param.grad is not None:
                        grad_sum += (param.grad * param.data).abs().sum().item()
                        param_count += param.numel()

                key = f"layer_{i}"
                imp = grad_sum / max(param_count, 1)
                if key not in layer_importance:
                    layer_importance[key] = 0.0
                layer_importance[key] += imp

            model.zero_grad()

        return layer_importance

    def drop_layers(self, model, layers_to_drop):
        """指定レイヤーを削除してエンコーダーを再構築する。"""
        import copy
        new_model = copy.deepcopy(model)

        remaining = [
            layer for i, layer in enumerate(new_model.encoder.layer)
            if i not in layers_to_drop
        ]
        new_model.encoder.layer = nn.ModuleList(remaining)
        new_model.config.num_hidden_layers = len(remaining)

        return new_model

4.4 PyTorch torch.nn.utils.prune

import torch.nn.utils.prune as prune

model = models.resnet18(weights='DEFAULT')

# 特定レイヤーに非構造化L1プルーニングを適用
conv = model.layer1[0].conv1
prune.l1_unstructured(conv, name='weight', amount=0.3)
# マスクを使って重みの30%をゼロにする

print(f"Sparsity: {(conv.weight == 0).float().mean():.2f}")

# 構造化プルーニング(フィルターレベル)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)

# プルーニングを永続化(マスクを削除し、ゼロを重みに焼き付ける)
prune.remove(conv, 'weight')

# 全Conv2dレイヤーにグローバルに適用
parameters_to_prune = []
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # 全conv重みの40%をプルーニング
)

# グローバルスパース度を確認
total_zero = sum(
    (m.weight == 0).sum().item()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
total_params = sum(
    m.weight.numel()
    for m in model.modules() if isinstance(m, nn.Conv2d)
)
print(f"Global sparsity: {total_zero/total_params:.2%}")

5. 非構造化プルーニング

5.1 マグニチュードベースプルーニング

最もシンプルな方法: 絶対値が最も小さい重みをゼロに設定します。

class MagnitudePruner:
    """マグニチュードベースの非構造化プルーニング。"""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def compute_global_threshold(self):
        """全重みテンソルにわたるグローバル閾値を計算する。"""
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        all_weights = torch.cat(all_weights)
        threshold = all_weights.kthvalue(
            int(len(all_weights) * self.sparsity)
        ).values.item()
        return threshold

    def apply_pruning(self):
        """グローバル閾値を使ってプルーニングマスクを作成する。"""
        threshold = self.compute_global_threshold()

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() >= threshold).float()
                self.masks[name] = mask
                param.data *= mask

        total_zeros = sum((m == 0).sum().item() for m in self.masks.values())
        total_params = sum(m.numel() for m in self.masks.values())
        print(f"Actual sparsity: {total_zeros/total_params:.2%}")

    def apply_masks(self):
        """勾配ステップ後にマスクを再適用して、ゼロ重みが復活しないようにする。"""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

5.2 段階的マグニチュードプルーニング

最初から積極的にプルーニングすると精度が急激に低下します。学習中にスパース度を徐々に増加させる段階的スケジュールの方がはるかに効果的です。

class GradualMagnitudePruner:
    """
    学習中にスパース度を徐々に増加させる。
    Zhu and Gupta (2017): To Prune, or Not to Prune.
    """
    def __init__(self, model, initial_sparsity=0.0, final_sparsity=0.8,
                 begin_step=0, end_step=1000, frequency=100):
        self.model = model
        self.initial_sparsity = initial_sparsity
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency
        self.masks = {}
        self._init_masks()

    def _init_masks(self):
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                self.masks[name] = torch.ones_like(param.data)

    def compute_sparsity(self, step):
        """三次スケジュールを使って現在ステップの目標スパース度を計算する。"""
        if step < self.begin_step:
            return self.initial_sparsity
        if step > self.end_step:
            return self.final_sparsity

        pct_done = (step - self.begin_step) / (self.end_step - self.begin_step)
        sparsity = (
            self.final_sparsity
            + (self.initial_sparsity - self.final_sparsity)
            * (1 - pct_done) ** 3
        )
        return sparsity

    def step(self, global_step):
        """適切な学習ステップでプルーニングマスクを更新する。"""
        if (global_step % self.frequency != 0 or
                global_step < self.begin_step or
                global_step > self.end_step):
            return

        target_sparsity = self.compute_sparsity(global_step)
        self._update_masks(target_sparsity)

    def _update_masks(self, sparsity):
        """目標スパース度に合わせてマスクを更新する。"""
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            n_prune = int(sparsity * param.data.numel())
            if n_prune == 0:
                continue

            threshold = param.data.abs().view(-1).kthvalue(n_prune).values.item()
            self.masks[name] = (param.data.abs() > threshold).float()
            param.data *= self.masks[name]

6. 重み共有

6.1 クロスレイヤーパラメータ共有 — ALBERT

ALBERT(Lan et al., 2019)は全トランスフォーマーレイヤーで同じパラメータを再利用することで、BERTのパラメータを大幅に削減します。

class ALBERTEncoder(nn.Module):
    """
    ALBERTスタイルの重み共有。
    1つのトランスフォーマーレイヤーをN回実行する。
    """
    def __init__(self, hidden_size=768, num_heads=12,
                 intermediate_size=3072, num_layers=12):
        super().__init__()
        # 1つのトランスフォーマーレイヤーのみ定義
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=intermediate_size,
            batch_first=True,
            norm_first=True,
        )
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, src_key_padding_mask=None):
        # 同じレイヤーをnum_layers回実行
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_key_padding_mask=src_key_padding_mask)
        return self.norm(x)


# パラメータ比較
bert_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(768, 12, 3072, batch_first=True, norm_first=True),
    num_layers=12
)
albert_encoder = ALBERTEncoder(num_layers=12)

bert_params = sum(p.numel() for p in bert_encoder.parameters())
albert_params = sum(p.numel() for p in albert_encoder.parameters())

print(f"BERT encoder: {bert_params:,}")      # ~85M
print(f"ALBERT encoder: {albert_params:,}")  # ~7M
print(f"Compression: {bert_params/albert_params:.1f}x")

6.2 因子化埋め込み(ALBERT)

ALBERTは埋め込み行列も因子化します: vocab_size x hidden_size → (vocab_size x embedding_size) + (embedding_size x hidden_size)(embedding_sizeはhidden_sizeよりはるかに小さい)。

class FactorizedEmbedding(nn.Module):
    """
    ALBERTスタイルの因子化埋め込み。
    vocab_size x H を (vocab_size x E) + (E x H) に因子化。
    E << H。
    """
    def __init__(self, vocab_size, embedding_size=128, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.embedding_projection = nn.Linear(embedding_size, hidden_size, bias=False)

    def forward(self, input_ids):
        embed = self.word_embeddings(input_ids)        # (B, seq, E)
        return self.embedding_projection(embed)         # (B, seq, H)

# パラメータ比較(vocab_size=30000)
standard_embed = nn.Embedding(30000, 768)                  # 23.0Mパラメータ
factorized_embed = FactorizedEmbedding(30000, 128, 768)    # 3.97Mパラメータ

s_params = sum(p.numel() for p in standard_embed.parameters())
f_params = sum(p.numel() for p in factorized_embed.parameters())
print(f"Standard: {s_params:,} -> Factorized: {f_params:,}")
print(f"Savings: {(s_params - f_params)/s_params:.1%}")

7. ニューラルアーキテクチャ探索(NAS)

7.1 手動設計 vs AutoML

従来のCNNは研究者が手動で設計していました(VGG、ResNet)。NASはこのプロセスを自動化します。

NASの3つのコアコンポーネント:

  1. 探索空間: 探索する演算と構造
  2. 探索戦略: 強化学習、進化的アルゴリズム、勾配ベース手法
  3. 性能推定: 候補アーキテクチャの高速評価

7.2 DARTS(微分可能アーキテクチャ探索)

Liu et al.(2019)。アーキテクチャ探索を連続最適化問題に変換します。

import torch
import torch.nn as nn
import torch.nn.functional as F


# 候補演算
PRIMITIVES = [
    'none',           # ゼロ接続
    'skip_connect',   # 恒等写像
    'sep_conv_3x3',   # 3x3分離畳み込み
    'sep_conv_5x5',   # 5x5分離畳み込み
    'dil_conv_3x3',   # 3x3膨張畳み込み
    'dil_conv_5x5',   # 5x5膨張畳み込み
    'avg_pool_3x3',   # 3x3平均プール
    'max_pool_3x3',   # 3x3最大プール
]


class MixedOperation(nn.Module):
    """
    DARTS: 各エッジを演算の加重混合として表現する。
    アーキテクチャパラメータ(alpha)が各演算の重みを制御する。
    """
    def __init__(self, C, stride, primitives):
        super().__init__()
        self.ops = nn.ModuleList([
            self._build_op(primitive, C, stride)
            for primitive in primitives
        ])
        # 学習可能なアーキテクチャパラメータ
        self.arch_params = nn.Parameter(
            torch.randn(len(primitives)) * 1e-3
        )

    def _build_op(self, primitive, C, stride):
        if primitive == 'none':
            return nn.Sequential()
        elif primitive == 'skip_connect':
            return nn.Identity() if stride == 1 else nn.Sequential(
                nn.AvgPool2d(stride, stride, padding=0),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        elif primitive == 'sep_conv_3x3':
            return nn.Sequential(
                nn.Conv2d(C, C, 3, stride=stride, padding=1, groups=C, bias=False),
                nn.Conv2d(C, C, 1, bias=False),
                nn.BatchNorm2d(C),
                nn.ReLU(inplace=True),
            )
        elif primitive == 'avg_pool_3x3':
            return nn.AvgPool2d(3, stride=stride, padding=1)
        elif primitive == 'max_pool_3x3':
            return nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            return nn.Identity()

    def forward(self, x):
        weights = F.softmax(self.arch_params, dim=0)
        results = []
        for w, op in zip(weights, self.ops):
            try:
                result = op(x)
                results.append(w * result)
            except Exception:
                pass

        if results:
            output = results[0]
            for r in results[1:]:
                if r.shape == output.shape:
                    output = output + r
            return output
        return x


def darts_discrete(cell):
    """
    各エッジで最も重みの高い演算を選択して
    連続アーキテクチャを離散アーキテクチャに変換する。
    """
    for op in cell.ops:
        weights = F.softmax(op.arch_params, dim=0)
        best_op_idx = weights.argmax().item()
        print(f"  Selected: {PRIMITIVES[best_op_idx]} "
              f"(weight: {weights[best_op_idx]:.3f})")

7.3 EfficientNetのNASプロセス

EfficientNet-B0はMnasNetスタイルのNASで発見されたベースアーキテクチャです。モバイルレイテンシの制約下での精度を最適化した結果、スクイーズ・アンド・エクシテーションを持つMBConvブロックから構成されるネットワークが得られました。

class MBConvBlock(nn.Module):
    """
    モバイル転置ボトルネック畳み込み。
    NASで発見されたEfficientNetのコアビルディングブロック。
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, expand_ratio=6, se_ratio=0.25):
        super().__init__()
        self.use_residual = (stride == 1 and in_channels == out_channels)
        hidden_dim = int(in_channels * expand_ratio)

        layers = []
        # 拡張フェーズ
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
            ])

        # 深度方向畳み込み
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size,
                      stride=stride,
                      padding=kernel_size // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])

        # スクイーズ・アンド・エクシテーション(NASで発見された主要モジュール)
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid(),
        )

        # 出力射影
        layers.extend([
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        out = self.layers(x)
        if self.use_residual:
            return x + out
        return out

7.4 Once-for-All ネットワーク

Cai et al.(2020)。1つの大きなネットワークを学習し、再学習なしで異なるリソース制約に合わせたサブネットワークを抽出します。

class OFALayer(nn.Module):
    """
    複数のカーネルサイズをサポートするOnce-for-All弾性レイヤー。
    学習時: カーネルサイズをランダムサンプリング -> サブネットワークを学習
    デプロイ時: 必要なカーネルサイズのみを使用
    """
    def __init__(self, channels, max_kernel=7):
        super().__init__()
        # 最大カーネルサイズの畳み込みのみ保存
        self.max_conv = nn.Conv2d(
            channels, channels,
            kernel_size=max_kernel,
            padding=max_kernel // 2,
            groups=channels, bias=False
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        self.active_kernel = max_kernel

    def set_active_kernel(self, kernel_size):
        """このレイヤーのアクティブカーネルサイズを設定する。"""
        assert kernel_size <= self.max_conv.kernel_size[0]
        self.active_kernel = kernel_size

    def forward(self, x):
        if self.active_kernel == self.max_conv.kernel_size[0]:
            weight = self.max_conv.weight
        else:
            # 中心サブカーネルを抽出
            center = self.max_conv.kernel_size[0] // 2
            half = self.active_kernel // 2
            weight = self.max_conv.weight[
                :, :,
                center - half: center + half + 1,
                center - half: center + half + 1
            ]

        padding = self.active_kernel // 2
        out = F.conv2d(x, weight, padding=padding, groups=x.size(1))
        return self.act(self.bn(out))


def progressive_shrinking_train(model, dataloader, kernels=[7, 5, 3]):
    """
    Stage 1: 最大カーネル(7)で学習
    Stage 2: [7, 5]からランダムサンプリング
    Stage 3: [7, 5, 3]からランダムサンプリング
    """
    for stage, active_kernels in enumerate(
        [kernels[:1], kernels[:2], kernels]
    ):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        print(f"Stage {stage+1}: kernels = {active_kernels}")

        for epoch in range(5):
            for images, labels in dataloader:
                # バッチごとにカーネルサイズをランダムサンプリング
                k = active_kernels[torch.randint(len(active_kernels), (1,)).item()]

                for module in model.modules():
                    if isinstance(module, OFALayer):
                        module.set_active_kernel(k)

                outputs = model(images)
                loss = F.cross_entropy(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

8. 統合モデル圧縮パイプライン

実際には、複数の技術を順次組み合わせます:

class ModelCompressionPipeline:
    """
    統合パイプライン: 蒸留 -> プルーニング -> 量子化。
    """

    def __init__(self, teacher, student, num_classes):
        self.teacher = teacher
        self.num_classes = num_classes
        self.student = student

    def step1_distillation(self, train_loader,
                           epochs=30, device='cuda'):
        """Step 1: 知識蒸留。"""
        print("=== Step 1: Knowledge Distillation ===")
        self.student = train_with_distillation(
            self.teacher, self.student, train_loader,
            num_epochs=epochs, temperature=4.0, alpha=0.5,
            device=device
        )

    def step2_pruning(self, train_loader, sparsity=0.5,
                      finetune_epochs=10, device='cuda'):
        """Step 2: 段階的プルーニング + ファインチューニング。"""
        print(f"=== Step 2: Pruning (target: {sparsity:.0%}) ===")
        pruner = GradualMagnitudePruner(
            self.student,
            initial_sparsity=0.0,
            final_sparsity=sparsity,
            begin_step=0,
            end_step=len(train_loader) * finetune_epochs,
            frequency=100
        )

        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-4)
        self.student.to(device)
        step = 0

        for epoch in range(finetune_epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.student(images)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pruner.step(step)
                pruner.apply_masks()
                step += 1

    def step3_quantization(self, calibration_loader, device='cpu'):
        """Step 3: ポストトレーニング静的量子化。"""
        print("=== Step 3: Quantization ===")
        self.student.eval().to(device)

        self.student.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        torch.ao.quantization.prepare(self.student, inplace=True)

        with torch.no_grad():
            for images, _ in calibration_loader:
                self.student(images.to(device))

        torch.ao.quantization.convert(self.student, inplace=True)
        print("Quantization complete (INT8)")
        return self.student

    def compare_models(self, val_loader, device='cuda'):
        """教師モデルと圧縮学生モデルを比較する。"""
        def eval_model(model, loader, dev):
            model.eval().to(dev)
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(dev), labels.to(dev)
                    preds = model(images).argmax(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            return correct / total

        teacher_acc = eval_model(self.teacher, val_loader, device)
        student_acc = eval_model(self.student, val_loader, 'cpu')

        t_params = sum(p.numel() for p in self.teacher.parameters())
        s_params = sum(p.numel() for p in self.student.parameters())

        print(f"\n{'='*55}")
        print(f"Teacher  | Params: {t_params:>10,} | Acc: {teacher_acc:.4f}")
        print(f"Student  | Params: {s_params:>10,} | Acc: {student_acc:.4f}")
        print(f"Compression: {t_params/s_params:.1f}x | "
              f"Retained accuracy: {student_acc/teacher_acc:.1%}")

まとめ

知識蒸留とモデル圧縮は、研究スケールのモデルと実世界デプロイメントを繋ぐ不可欠な橋渡しです。

重要なポイント:

  1. 知識蒸留: 教師のソフトターゲット(確率分布)によるクラス間関係情報の転移
  2. フィーチャー蒸留: 最終出力だけでなく中間表現の転移
  3. 関係蒸留: サンプル間の構造的関係の保持
  4. LLM蒸留: DistilBERT、TinyLLaMA、Distil Whisperが実用的な大規模モデル圧縮を実証
  5. 構造化プルーニング: フィルター、ヘッド、レイヤーの削除による実際のハードウェア高速化
  6. 非構造化プルーニング: 段階的スパース度スケジュールで精度低下を最小化
  7. 重み共有: ALBERTスタイルのパラメータ再利用で最大12倍の圧縮を達成
  8. NAS: DARTSとOnce-for-Allがターゲットハードウェアに向けたアーキテクチャ設計を自動化

最良の実践的アプローチは順次パイプラインです: 蒸留 -> プルーニング -> 量子化。各ステップが前のステップの上に構築されるため、各圧縮技術がすでに圧縮されたモデルに作用できます。


参考文献

  • Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531
  • Romero et al. (2015). FitNets: Hints for Thin Deep Nets.
  • Zagoruyko and Komodakis (2017). Paying More Attention to Attention.
  • Park et al. (2019). Relational Knowledge Distillation.
  • Sanh et al. (2019). DistilBERT, a distilled version of BERT. https://arxiv.org/abs/1910.01108
  • Lan et al. (2019). ALBERT: A Lite BERT for Self-supervised Learning of Language Representations.
  • Liu et al. (2019). DARTS: Differentiable Architecture Search.
  • Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment.
  • Zhu and Gupta (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression.
  • Tan and Le (2019). EfficientNet. https://arxiv.org/abs/1905.11946
  • PyTorch Pruning Documentation: https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.html