Skip to content
Published on

Mixture of Experts(MoE) 아키텍처 논문 심층 분석: GShard에서 DeepSeek-MoE까지

Authors
  • Name
    Twitter
Mixture of Experts Architecture

들어가며

대규모 언어 모델(LLM)의 성능 향상은 전통적으로 파라미터 수 증가와 직결되어 왔다. 그러나 Dense 모델에서 파라미터를 두 배로 늘리면 학습과 추론의 연산량(FLOPs)도 비례하여 증가한다. Mixture of Experts(MoE) 아키텍처는 이 한계를 돌파하기 위한 핵심 전략이다. 모델의 총 파라미터 수는 대폭 늘리되, 각 입력 토큰에 대해 전체 파라미터 중 소수의 전문가(Expert)만 활성화하여 연산 비용을 일정 수준으로 유지한다.

MoE의 원래 아이디어는 1991년 Jacobs 등의 논문으로 거슬러 올라가지만, 2017년 Shazeer 등의 "Outrageously Large Neural Networks"에서 Sparsely-Gated MoE가 제안되면서 현대적 형태를 갖추기 시작했다. 이후 Google의 GShard(2020)가 600B 파라미터 규모의 분산 MoE 학습을 실현했고, Switch Transformer(2021)가 단일 전문가 라우팅으로 효율성을 극대화했다. Mistral AI의 Mixtral 8x7B(2024)는 오픈소스 MoE 모델의 새 기준을 세웠으며, DeepSeek의 DeepSeek-MoE(2024)는 세밀한 전문가 분할(Fine-grained Expert Segmentation)과 공유 전문가 격리(Shared Expert Isolation)로 전문가 특화를 극대화했다.

이 글에서는 각 논문의 핵심 기여를 분석하고, 라우팅 전략, 로드 밸런싱, 학습 안정성 기법을 코드와 함께 비교한다. 또한 MoE 학습 시 실제로 마주치는 불안정성 문제와 디버깅 전략을 다룬다.

MoE 아키텍처 기초: Sparse Gating과 Expert 선택

게이팅 네트워크의 원리

MoE 레이어의 핵심은 게이팅 네트워크(Gating Network) 이다. 입력 토큰 x가 주어지면, 게이팅 네트워크는 N개의 전문가 각각에 대한 확률 분포를 출력한다.

G(x)=softmax(Wgx+ϵ)G(x) = \text{softmax}(W_g \cdot x + \epsilon)

여기서 W_g는 학습 가능한 게이팅 가중치 행렬이며, 노이즈 항은 학습 초기에 전문가 탐색을 돕는 역할을 한다. MoE 레이어의 최종 출력은 선택된 Top-k 전문가의 가중합으로 계산된다.

y=iTop-kG(x)iEi(x)y = \sum_{i \in \text{Top-k}} G(x)_i \cdot E_i(x)

E_i(x)는 i번째 전문가 네트워크(통상 FFN)의 출력이다.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicMoEGating(nn.Module):
    """Sparsely-Gated MoE의 기본 게이팅 네트워크"""
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2,
                 noise_std: float = 1.0):
        super().__init__()
        self.top_k = top_k
        self.noise_std = noise_std
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.noise_gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, d_model)
        logits = self.gate(x)

        # 학습 시 노이즈 추가 (탐색 촉진)
        if self.training:
            noise = F.softplus(self.noise_gate(x))
            logits = logits + torch.randn_like(logits) * noise * self.noise_std

        # Top-k 전문가 선택
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)

        return top_k_gates, top_k_indices

Sparse Activation과 연산 효율성

Dense 모델에서는 모든 파라미터가 매 토큰마다 활성화되지만, Sparse MoE에서는 Top-k 전문가만 활성화된다. N개의 전문가 중 k개만 연산에 참여하므로, 모델 용량은 N배로 늘리면서 연산량은 k배 수준으로 유지할 수 있다.

예를 들어, Mixtral 8x7B는 8개 전문가 중 2개를 활성화한다. 총 파라미터는 약 47B이지만 활성 파라미터는 약 13B로, Dense 13B 모델과 유사한 추론 비용으로 훨씬 높은 성능을 달성한다.

Capacity Factor와 토큰 드롭

각 전문가가 처리할 수 있는 토큰 수의 상한을 Capacity Factor(CF) 로 제어한다. 이상적인 균등 분배 시 각 전문가는 T/N 개의 토큰을 처리하며, CF를 곱해서 실제 용량을 결정한다.

Expert Capacity=CF×TN\text{Expert Capacity} = \text{CF} \times \frac{T}{N}

CF가 1.0이면 완벽한 균등 분배를 가정한다. 실제로는 1.25에서 1.5 사이의 값을 사용한다. CF가 너무 낮으면 토큰이 드롭(overflow)되어 정보 손실이 발생하고, 너무 높으면 메모리가 낭비된다.

class MoELayer(nn.Module):
    """기본 Sparse MoE 레이어 (Capacity Factor 포함)"""
    def __init__(self, d_model: int, d_ff: int, num_experts: int,
                 top_k: int = 2, capacity_factor: float = 1.25):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.gate = BasicMoEGating(d_model, num_experts, top_k)

        # 각 전문가는 독립적인 FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor):
        B, T, D = x.shape
        gates, indices = self.gate(x)
        expert_capacity = int(self.capacity_factor * T / self.num_experts)

        output = torch.zeros_like(x)
        tokens_dropped = 0

        for i in range(self.num_experts):
            # i번째 전문가에 할당된 토큰 마스크
            expert_mask = (indices == i).any(dim=-1)  # (B, T)
            # 용량 초과 토큰 드롭
            token_count = expert_mask.sum(dim=-1)
            for b in range(B):
                if token_count[b] > expert_capacity:
                    overflow = token_count[b] - expert_capacity
                    tokens_dropped += overflow.item()

            if not expert_mask.any():
                continue

            expert_input = x[expert_mask]
            expert_output = self.experts[i](expert_input)
            gate_values = gates[indices == i]
            output[expert_mask] += gate_values.unsqueeze(-1) * expert_output

        return output, tokens_dropped

GShard: 대규모 분산 MoE 학습

논문 개요

GShard(Lepikhin et al., 2020)는 Google이 발표한 논문으로, MoE를 활용하여 600B 파라미터 규모의 다국어 기계번역 모델을 2048개의 TPU v3에서 4일 만에 학습시킨 연구다. 핵심 기여는 세 가지다.

  1. Top-2 게이팅과 랜덤 라우팅: 첫 번째 전문가는 Top-1으로 결정론적으로 선택하고, 두 번째 전문가는 게이트 확률에 비례하여 확률적으로 선택한다
  2. 로컬 그룹 디스패칭: 토큰을 G개의 로컬 그룹으로 나누어 병렬로 독립 처리하여 통신 비용을 절감한다
  3. 자동 샤딩(Auto-sharding): XLA 컴파일러를 활용한 자동 파티셔닝으로 분산 학습의 복잡성을 추상화한다

GShard의 랜덤 라우팅

GShard의 라우팅 전략은 독특하다. Top-1 전문가는 가장 높은 게이트 값을 가진 전문가로 결정론적으로 선택되지만, 두 번째 전문가는 나머지 전문가들의 게이트 확률에 비례하여 확률적으로 샘플링된다. 이 방식은 활용도가 낮은 전문가에게도 토큰이 배정될 기회를 부여하여, 전문가 간 부하 분산과 탐색(exploration)을 촉진한다.

class GShardRouter(nn.Module):
    """GShard Top-2 라우터 (랜덤 라우팅 포함)"""
    def __init__(self, d_model: int, num_experts: int,
                 capacity_factor: float = 2.0):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, d_model)
        B, T, D = x.shape
        logits = self.gate(x)
        probs = F.softmax(logits, dim=-1)

        # Top-1: 결정론적 선택
        top1_probs, top1_indices = probs.max(dim=-1)

        # Top-2: 확률적 선택 (나머지 전문가에서 샘플링)
        mask = torch.zeros_like(probs)
        mask.scatter_(-1, top1_indices.unsqueeze(-1), 1.0)
        remaining_probs = probs * (1 - mask)
        remaining_probs = remaining_probs / (remaining_probs.sum(dim=-1, keepdim=True) + 1e-8)

        # 확률에 비례하여 두 번째 전문가 샘플링
        top2_indices = torch.multinomial(
            remaining_probs.view(-1, self.num_experts), 1
        ).view(B, T)
        top2_probs = probs.gather(-1, top2_indices.unsqueeze(-1)).squeeze(-1)

        # 게이트 값 재정규화
        total = top1_probs + top2_probs
        top1_gates = top1_probs / total
        top2_gates = top2_probs / total

        return (top1_gates, top1_indices, top2_gates, top2_indices)

로컬 그룹 디스패칭

GShard는 전체 배치의 토큰을 G개의 로컬 그룹으로 균등 분할한다. 각 그룹은 S = N/G 개의 토큰을 포함하며, 각 그룹 내에서 독립적으로 전문가 디스패칭이 이루어진다. 이를 통해 게이팅 연산을 O(G)배 빠르게 수행할 수 있고, 각 전문가의 로컬 용량은 C = 2N/(G x E)로 설정된다.

Switch Transformer: 단일 Expert 라우팅의 효율성

논문 개요

Switch Transformer(Fedus et al., 2021)는 Google Brain에서 발표한 논문으로, MoE 라우팅의 상식을 뒤집었다. 기존에는 Top-2 이상의 전문가를 활성화해야 학습이 안정적이라고 여겨졌지만, Switch Transformer는 Top-1 단일 전문가 라우팅으로도 우수한 성능을 달성할 수 있음을 입증했다.

Top-1 라우팅의 세 가지 장점

  1. 라우터 연산 감소: 하나의 전문가만 선택하므로 게이팅 연산이 단순화된다
  2. 배치 효율 향상: 각 토큰이 하나의 전문가에만 배정되므로 동일 Capacity Factor에서 배치 크기를 두 배로 늘릴 수 있다
  3. 통신 비용 절감: 분산 학습 시 토큰을 하나의 전문가 디바이스로만 전송하면 된다

Switch Transformer의 라우터 구현

class SwitchRouter(nn.Module):
    """Switch Transformer 라우터 (Top-1 단일 전문가)"""
    def __init__(self, d_model: int, num_experts: int,
                 capacity_factor: float = 1.25):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        B, T, D = x.shape
        logits = self.gate(x)
        probs = F.softmax(logits, dim=-1)

        # Top-1 선택 (단일 전문가)
        gate_values, expert_indices = probs.max(dim=-1)  # (B, T)

        # Capacity 기반 토큰 드롭
        capacity = int(self.capacity_factor * T / self.num_experts)

        # 각 전문가별 할당 토큰 수 추적
        expert_counts = torch.zeros(B, self.num_experts, device=x.device)
        drop_mask = torch.zeros(B, T, dtype=torch.bool, device=x.device)

        for t in range(T):
            for b in range(B):
                eid = expert_indices[b, t].item()
                if expert_counts[b, eid] >= capacity:
                    drop_mask[b, t] = True  # 용량 초과 토큰 드롭
                else:
                    expert_counts[b, eid] += 1

        # 드롭된 토큰의 게이트 값을 0으로
        gate_values = gate_values * (~drop_mask).float()

        return gate_values, expert_indices, drop_mask

Simplified Load Balancing Loss

Switch Transformer는 로드 밸런싱을 위한 간소화된 보조 손실(auxiliary loss)을 제안한다.

Lbalance=αNi=1NfiPiL_{\text{balance}} = \alpha \cdot N \sum_{i=1}^{N} f_i \cdot P_i

여기서 N은 전문가 수, f_i는 i번째 전문가에 배정된 토큰 비율, P_i는 i번째 전문가에 대한 평균 라우터 확률, 알파는 하이퍼파라미터(논문에서는 0.01 권장)다. 이 손실은 f_iP_i 모두 1/N이 되도록, 즉 전문가들이 균등하게 토큰을 받도록 유도한다.

def switch_load_balancing_loss(router_probs: torch.Tensor,
                                expert_indices: torch.Tensor,
                                num_experts: int,
                                alpha: float = 0.01) -> torch.Tensor:
    """Switch Transformer의 Simplified Load Balancing Loss

    Args:
        router_probs: 라우터 확률 (B, T, num_experts)
        expert_indices: 선택된 전문가 인덱스 (B, T)
        num_experts: 전문가 수
        alpha: 밸런싱 계수
    """
    # f_i: 각 전문가에 배정된 토큰 비율
    one_hot = F.one_hot(expert_indices, num_experts).float()
    f = one_hot.mean(dim=[0, 1])  # (num_experts,)

    # P_i: 각 전문가에 대한 평균 라우터 확률
    P = router_probs.mean(dim=[0, 1])  # (num_experts,)

    # 밸런싱 손실: alpha * N * sum(f_i * P_i)
    loss = alpha * num_experts * (f * P).sum()
    return loss

bfloat16 혼합 정밀도 학습

Switch Transformer의 또 다른 중요한 기여는 대규모 Sparse 모델을 bfloat16 혼합 정밀도로 학습하는 기법이다. 라우터의 softmax 연산은 수치적으로 불안정할 수 있으므로, 라우터 부분만 float32로 유지하고 나머지는 bfloat16으로 처리한다. 이를 통해 메모리를 절약하면서도 학습 안정성을 확보했다.

Mixtral 8x7B: 오픈소스 MoE의 혁신

논문 개요

Mixtral 8x7B(Jiang et al., 2024)는 Mistral AI가 발표한 Sparse MoE 언어 모델이다. Mistral 7B와 동일한 아키텍처를 기반으로 하되, 각 레이어의 FFN(Feed-Forward Network)을 8개의 전문가로 교체하고 Top-2 라우팅을 적용한다.

아키텍처 세부 사항

  • 총 파라미터: 약 46.7B (8개 전문가 x 7B FFN + 공유 Attention)
  • 활성 파라미터: 약 12.9B (각 토큰당 2개 전문가 활성화)
  • 컨텍스트 길이: 32K 토큰
  • 어텐션: Grouped Query Attention(GQA) + Sliding Window Attention(SWA)

Mixtral의 Top-2 라우팅

GShard와 달리 Mixtral의 Top-2 라우팅은 결정론적이다. 각 토큰에 대해 가장 높은 게이트 값을 가진 2개의 전문가를 선택하고, 이 두 전문가의 출력을 정규화된 게이트 값으로 가중합한다.

y=iTop-2egijTop-2egjEi(x)y = \sum_{i \in \text{Top-2}} \frac{e^{g_i}}{\sum_{j \in \text{Top-2}} e^{g_j}} \cdot E_i(x)

전문가 특화 패턴 분석

Mixtral 논문에서 흥미로운 발견은 전문가 특화 패턴이다. The Pile 검증 세트의 여러 도메인에서 각 전문가의 활성화 빈도를 분석한 결과, 전문가들은 주제(topic)보다는 구문적 패턴(syntactic pattern) 에 따라 특화되는 경향을 보였다. 특정 전문가가 코드, 수학, 자연어 중 하나에만 특화되는 것이 아니라, 토큰 시퀀스의 구조적 특성에 기반하여 반복적으로 선택되는 양상이다.

벤치마크 성능

Mixtral 8x7B는 13B 활성 파라미터만으로 대부분의 벤치마크에서 LLaMA 2 70B와 동등하거나 우수한 성능을 보였다. 특히 수학과 코드 생성 벤치마크에서 두각을 나타냈으며, GPT-3.5 수준의 성능을 보이는 영역도 있었다.

DeepSeek-MoE: Fine-grained Expert Segmentation

논문 개요

DeepSeek-MoE(Dai et al., 2024)는 MoE 아키텍처에서 전문가 특화를 극대화하기 위한 두 가지 핵심 전략을 제안했다.

  1. Fine-grained Expert Segmentation: 전문가를 더 작은 단위로 세분화하여 유연한 조합을 가능하게 한다
  2. Shared Expert Isolation: 일부 전문가를 공유 전문가로 격리하여 공통 지식을 집중시키고, 라우팅 전문가 간 중복을 줄인다

Fine-grained Expert Segmentation의 원리

기존 MoE가 N개의 전문가 중 K개를 활성화한다면, DeepSeek-MoE는 동일한 파라미터 예산을 m배 더 많은 전문가(mN개)로 분할하고, m배 더 많은 전문가(mK개)를 활성화한다. 각 전문가의 FFN 중간 차원(intermediate hidden dimension)을 m으로 나누는 방식이다.

총 파라미터 수와 연산량은 동일하게 유지되지만, 더 세밀한 전문가 조합이 가능해지므로 지식이 더 정밀하게 분해되어 학습된다. 각 전문가는 더 높은 수준의 특화를 달성한다.

Shared Expert Isolation

DeepSeek-MoE는 일부 전문가를 공유 전문가(Shared Expert) 로 지정하여 모든 토큰에 대해 항상 활성화한다. 공유 전문가는 문맥에 관계없이 공통적으로 필요한 지식(일반 언어 구조, 기본 문법 등)을 학습하고, 나머지 라우팅 전문가들은 각자의 전문 영역에 집중할 수 있다. 이를 통해 라우팅 전문가 간의 지식 중복이 줄어든다.

class DeepSeekMoELayer(nn.Module):
    """DeepSeek-MoE 레이어 (Fine-grained Segmentation + Shared Expert)"""
    def __init__(self, d_model: int, d_ff: int, num_routed_experts: int,
                 num_shared_experts: int, top_k: int, segmentation_factor: int = 4):
        super().__init__()
        self.num_routed = num_routed_experts
        self.num_shared = num_shared_experts
        self.top_k = top_k

        # Fine-grained: FFN 중간 차원을 segmentation_factor로 나눔
        fine_d_ff = d_ff // segmentation_factor

        # 공유 전문가 (항상 활성화)
        self.shared_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.SiLU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_shared_experts)
        ])

        # 라우팅 전문가 (세분화된 크기)
        self.routed_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, fine_d_ff),
                nn.SiLU(),
                nn.Linear(fine_d_ff, d_model)
            ) for _ in range(num_routed_experts)
        ])

        # 게이팅 네트워크 (라우팅 전문가용)
        self.gate = nn.Linear(d_model, num_routed_experts, bias=False)

    def forward(self, x: torch.Tensor):
        B, T, D = x.shape

        # 1. 공유 전문가 출력 (모든 토큰에 적용)
        shared_output = sum(expert(x) for expert in self.shared_experts)

        # 2. 라우팅 전문가 선택 및 출력
        logits = self.gate(x)
        probs = F.softmax(logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
        top_k_gates = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        routed_output = torch.zeros_like(x)
        for k in range(self.top_k):
            expert_idx = top_k_indices[:, :, k]  # (B, T)
            gate_val = top_k_gates[:, :, k]       # (B, T)
            for i in range(self.num_routed):
                mask = (expert_idx == i)
                if mask.any():
                    inp = x[mask]
                    out = self.routed_experts[i](inp)
                    routed_output[mask] += gate_val[mask].unsqueeze(-1) * out

        return shared_output + routed_output

성능 비교

DeepSeek-MoE 2B는 GShard 2.9B와 동등한 성능을 1.5배 적은 전문가 파라미터로 달성했다. DeepSeek-MoE 16B는 LLaMA 2 7B와 동등한 성능을 약 40%의 연산량으로 달성하여, Fine-grained Segmentation의 효과를 실증적으로 검증했다.

MoE 모델 비교 분석표

Dense vs MoE 비교

특성Dense 모델Sparse MoE 모델
활성 파라미터전체 파라미터Top-k 전문가 파라미터만
연산량(FLOPs)파라미터에 비례활성 파라미터에 비례
메모리 사용량파라미터 크기에 비례전체 파라미터 메모리 필요 (추론 시)
학습 안정성상대적으로 안정로드 밸런싱 필요, 불안정 가능
분산 학습데이터/모델 병렬Expert 병렬 추가 필요
추론 지연시간예측 가능라우팅 오버헤드 존재
스케일링 효율선형 증가서브리니어 연산 증가로 효율적

MoE 변형 모델 비교

특성GShard (2020)Switch Transformer (2021)Mixtral 8x7B (2024)DeepSeek-MoE (2024)
라우팅 방식Top-2 (랜덤 2nd)Top-1Top-2 (결정론적)Top-K (세분화)
전문가 수2048 (분산)128+8 per layer64 routed + 2 shared
활성 전문가 수2126 routed + 2 shared
총 파라미터600B1.6T (최대)46.7B16.4B / 145B
활성 파라미터--12.9B2.8B / 22.2B
전문가 세분화표준표준표준Fine-grained
공유 전문가없음없음없음있음
로드 밸런싱Auxiliary loss + CFSimplified loss암묵적Auxiliary loss
학습 정밀도float32bfloat16 혼합bfloat16bfloat16
주요 기여대규모 분산 MoETop-1 라우팅 효율성오픈소스 MoE LLM전문가 특화 극대화

라우팅 전략과 로드 밸런싱

라우팅 전략 계보

MoE 라우팅 전략은 크게 세 가지 패러다임으로 분류할 수 있다.

1. Token-Choice Routing: 토큰이 전문가를 선택하는 방식 (GShard, Switch, Mixtral)

가장 일반적인 방식으로, 각 토큰의 히든 상태를 게이팅 네트워크에 입력하여 Top-k 전문가를 선택한다. 직관적이고 구현이 간단하지만, 로드 불균형이 발생할 수 있다.

2. Expert-Choice Routing: 전문가가 토큰을 선택하는 방식 (Expert Choice, 2022)

각 전문가가 자신이 처리할 토큰을 선택한다. 전문가 용량이 고정되므로 완벽한 로드 밸런싱이 보장되지만, 특정 토큰이 아무 전문가에게도 선택되지 않거나 여러 전문가에게 중복 선택될 수 있다.

3. Shared + Routed Hybrid: 공유 전문가와 라우팅 전문가를 병용하는 방식 (DeepSeek-MoE)

공통 지식은 공유 전문가가 처리하고, 전문 지식은 라우팅 전문가가 담당한다. 라우팅 전문가 간 중복이 줄어들어 전문가 특화가 향상된다.

로드 밸런싱의 중요성

MoE 학습에서 가장 흔한 문제 중 하나는 전문가 붕괴(Expert Collapse) 다. 학습 초기에 특정 전문가가 다른 전문가보다 약간이라도 우수한 성능을 보이면, 라우터가 해당 전문가에 더 많은 토큰을 배정하게 되고, 이는 다시 해당 전문가의 학습을 강화하여 positive feedback loop가 형성된다. 결국 소수의 전문가만 사용되고 나머지는 학습되지 않는 상태에 빠진다.

def compute_expert_utilization(expert_indices: torch.Tensor,
                                num_experts: int) -> dict:
    """전문가 활용도 모니터링 유틸리티

    Args:
        expert_indices: 선택된 전문가 인덱스 (B, T) 또는 (B, T, K)
        num_experts: 총 전문가 수

    Returns:
        활용도 메트릭 딕셔너리
    """
    flat = expert_indices.flatten()
    counts = torch.bincount(flat, minlength=num_experts).float()
    total = flat.numel()

    # 각 전문가의 토큰 비율
    fractions = counts / total

    # 활용도 메트릭
    max_fraction = fractions.max().item()
    min_fraction = fractions.min().item()
    ideal_fraction = 1.0 / num_experts

    # 불균형 비율 (1에 가까울수록 균등)
    balance_ratio = min_fraction / (max_fraction + 1e-8)

    # 비활성 전문가 수
    dead_experts = (counts == 0).sum().item()

    return {
        "balance_ratio": balance_ratio,
        "max_load": max_fraction,
        "min_load": min_fraction,
        "ideal_load": ideal_fraction,
        "dead_experts": dead_experts,
        "load_std": fractions.std().item(),
    }

Auxiliary Loss의 효과와 한계

Auxiliary loss(보조 손실)는 전문가 붕괴를 방지하는 가장 널리 사용되는 기법이다. 그러나 알파 하이퍼파라미터 설정이 까다롭다.

  • 알파가 너무 작으면: 밸런싱 효과 부족, 전문가 붕괴 위험
  • 알파가 너무 크면: 라우팅 품질 저하, 모든 전문가에 토큰을 균등 배분하려고 하여 전문가 특화가 약화됨
  • Switch Transformer 권장값: 알파 = 0.01
  • 실무 범위: 0.001 ~ 0.1 사이에서 탐색

MoE 학습 시 주의사항과 불안정성 해결

학습 불안정성의 원인

MoE 모델은 Dense 모델에 비해 학습이 불안정한 경향이 있다. 주요 원인은 다음과 같다.

1. 라우터의 이산적 결정(Discrete Decision)

Top-k 선택은 본질적으로 미분 불가능한 이산적 연산이다. softmax를 통해 근사적으로 미분 가능하게 만들지만, 게이트 값의 작은 변동이 전문가 선택을 급격하게 바꿀 수 있다. 이는 학습 불안정성의 근본 원인 중 하나다.

2. 전문가 간 비대칭 학습

인기 있는 전문가는 더 많은 그래디언트를 받아 빠르게 학습되고, 비인기 전문가는 그래디언트가 부족하여 정체된다. 이 비대칭은 시간이 갈수록 심화된다.

3. 라우터 softmax의 수치적 불안정성

대규모 로짓 값에서 softmax 연산은 수치적으로 불안정할 수 있다. 특히 float16이나 bfloat16 정밀도에서는 이 문제가 더 심각해진다.

안정화 기법

Router z-loss: Switch Transformer에서 제안된 z-loss는 라우터 로짓의 크기를 억제하여 수치적 안정성을 높인다.

Lz=1BTb,t(logi=1Nexb,t,i)2L_z = \frac{1}{B \cdot T} \sum_{b,t} \left( \log \sum_{i=1}^{N} e^{x_{b,t,i}} \right)^2
def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Router z-loss: 라우터 로짓의 크기를 억제하여 안정성 향상

    Args:
        router_logits: 라우터의 원시 로짓 (B, T, num_experts)
    """
    # log-sum-exp의 제곱
    log_z = torch.logsumexp(router_logits, dim=-1)  # (B, T)
    z_loss = (log_z ** 2).mean()
    return z_loss


class StableMoETrainer:
    """MoE 학습 안정화를 위한 트레이너 예시"""
    def __init__(self, model, optimizer, alpha_balance=0.01,
                 alpha_z=0.001):
        self.model = model
        self.optimizer = optimizer
        self.alpha_balance = alpha_balance
        self.alpha_z = alpha_z

    def train_step(self, input_ids, labels):
        self.optimizer.zero_grad()

        # Forward pass
        logits, router_logits, expert_indices = self.model(input_ids)

        # Task loss
        task_loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )

        # Load balancing loss
        router_probs = F.softmax(router_logits, dim=-1)
        balance_loss = switch_load_balancing_loss(
            router_probs, expert_indices, self.model.num_experts
        )

        # Router z-loss
        z_loss = router_z_loss(router_logits)

        # Total loss
        total_loss = task_loss + self.alpha_balance * balance_loss + self.alpha_z * z_loss

        total_loss.backward()

        # 그래디언트 클리핑 (MoE 학습에서 특히 중요)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return {
            "task_loss": task_loss.item(),
            "balance_loss": balance_loss.item(),
            "z_loss": z_loss.item(),
            "total_loss": total_loss.item(),
        }

혼합 정밀도 학습 전략

Switch Transformer에서 밝힌 중요한 교훈은, MoE 모델에서 라우터는 반드시 고정밀도(float32)로 유지해야 한다는 것이다. 라우터의 softmax 연산이 bfloat16에서는 수치적 오류가 누적되어 학습이 발산할 수 있다.

권장 전략은 다음과 같다.

  • 라우터 게이팅: float32 유지
  • 전문가 FFN: bfloat16 가능
  • Attention 레이어: bfloat16 가능
  • 손실 계산: float32 유지

실패 사례와 디버깅

사례 1: 전문가 붕괴 (Expert Collapse)

증상: 학습 진행 중 특정 1~2개의 전문가만 토큰의 90% 이상을 처리하고, 나머지 전문가의 게이트 확률이 0에 수렴한다.

원인: Auxiliary loss의 알파 값이 너무 작거나, 학습률이 너무 높아서 라우터가 초기 편향에서 벗어나지 못한다.

해결 방법:

  • 알파 값을 0.01에서 0.1로 증가시켜 본다
  • 라우터 가중치를 Xavier 초기화 대신 더 작은 표준편차(0.01)로 초기화한다
  • Jitter noise를 추가하여 탐색을 강제한다
  • 전문가 활용도를 매 스텝 모니터링하여 조기 감지한다

사례 2: 토큰 드롭으로 인한 성능 저하

증상: Capacity Factor가 낮은 상태에서 10% 이상의 토큰이 드롭되어 perplexity가 개선되지 않는다.

원인: CF가 너무 낮거나, 특정 전문가에 토큰이 지나치게 집중된다.

해결 방법:

  • CF를 1.0에서 1.5로 증가 (메모리 허용 범위 내에서)
  • Auxiliary loss 강화로 토큰 분포 균등화
  • Token dropping 전략 변경: "position" 기반에서 "probability" 기반으로 전환하여 저확률 토큰만 드롭

사례 3: 분산 학습 시 All-to-All 통신 병목

증상: 전문가 수를 늘렸는데도 학습 처리량(throughput)이 오히려 감소한다.

원인: Expert Parallelism에서 All-to-All 통신이 연산 시간을 초과한다.

해결 방법:

  • Expert Parallelism 그룹 크기를 줄이고 Data Parallelism을 늘린다
  • 전문가당 토큰 배치 크기를 키워 연산 대비 통신 비율을 개선한다
  • GShard의 로컬 그룹 디스패칭을 적용하여 통신 범위를 제한한다

사례 4: 추론 시 메모리 폭증

증상: MoE 모델 추론 시 GPU 메모리가 Dense 모델 대비 N배 이상 필요하다.

원인: Sparse MoE는 학습과 추론 시 모든 전문가 파라미터를 메모리에 올려야 한다. 활성 파라미터는 적지만, 전체 파라미터가 메모리에 상주한다.

해결 방법:

  • Expert Offloading: 비활성 전문가를 CPU/NVMe로 이동하고 필요 시 로드
  • Expert Quantization: 비활성 전문가를 INT8/INT4로 양자화
  • Expert Pruning: 활용도가 낮은 전문가를 제거하고 라우터를 재학습

디버깅 체크리스트

MoE 학습 시 반드시 모니터링해야 할 메트릭들은 다음과 같다.

  1. 전문가별 토큰 할당 비율: 균등해야 함 (이상적으로 1/N)
  2. 비활성 전문가 수: 0이어야 함
  3. 토큰 드롭률: 5% 미만 유지 권장
  4. 라우터 엔트로피: 너무 낮으면 전문가 붕괴, 너무 높으면 랜덤 라우팅
  5. Auxiliary loss 추이: 단조감소해야 함
  6. 전문가별 그래디언트 노름: 전문가 간 편차가 크면 불균형 학습

마치며

MoE 아키텍처는 LLM 스케일링의 핵심 전략으로 확고히 자리잡았다. GShard는 대규모 분산 MoE 학습의 가능성을 증명했고, Switch Transformer는 단일 전문가 라우팅의 효율성을 입증했다. Mixtral 8x7B는 오픈소스 커뮤니티에 실용적인 MoE 모델을 제공했으며, DeepSeek-MoE는 Fine-grained Expert Segmentation과 Shared Expert Isolation으로 전문가 특화를 극대화했다.

각 논문의 핵심 교훈을 정리하면 다음과 같다.

  • GShard: 랜덤 라우팅과 로컬 그룹 디스패칭으로 수천 대 디바이스에서의 MoE 학습이 실현 가능하다
  • Switch Transformer: Top-1 라우팅이 Top-2보다 단순하면서도 효율적일 수 있다. 라우터의 수치적 안정성이 핵심이다
  • Mixtral 8x7B: 적은 수의 전문가(8개)로도 실용적인 MoE LLM을 구축할 수 있다
  • DeepSeek-MoE: 전문가를 세분화하고 공유 전문가를 격리하면, 동일 연산 예산으로 더 높은 성능을 달성할 수 있다

MoE 연구는 계속 진화하고 있다. DeepSeek-V2에서는 Multi-head Latent Attention(MLA)과 결합한 MoE 아키텍처가 등장했고, DeepSeek-V3는 671B 파라미터에 37B만 활성화하는 구조로 효율성의 새 지평을 열었다. 앞으로도 라우팅 전략의 개선, 학습 안정성 향상, 추론 최적화 등의 방향에서 발전이 지속될 것으로 기대된다.

참고자료