Split View: Sparse Mixture of Experts(MoE) 아키텍처 심층 분석: 설계 원리부터 DeepSeek-V3·Qwen3까지
Sparse Mixture of Experts(MoE) 아키텍처 심층 분석: 설계 원리부터 DeepSeek-V3·Qwen3까지
- 들어가며
- MoE 기본 구조와 수학적 원리
- 라우팅 전략: Top-k, Expert Choice, Hash Routing
- 로드 밸런싱과 Auxiliary Loss
- Switch Transformer에서 DeepSeek-V3까지 진화
- 학습 안정성과 트러블슈팅
- 추론 최적화: Expert Parallelism, Offloading
- 운영 체크리스트
- 실패 사례와 복구
- 참고자료

들어가며
대규모 언어 모델(LLM)의 파라미터 수는 기하급수적으로 증가하고 있지만, 모든 파라미터를 매 토큰마다 활성화하는 Dense 모델은 연산 비용의 벽에 부딪혔다. GPT-4 수준의 Dense 모델을 학습하려면 수만 장의 GPU가 수개월간 가동되어야 하며, 추론 비용 역시 파라미터 수에 비례하여 증가한다. 이 근본적 비효율을 해결하기 위해 조건부 연산(Conditional Computation) 패러다임이 부상했고, 그 중심에 Sparse Mixture of Experts(MoE) 아키텍처가 있다.
MoE의 핵심 아이디어는 단순하다. 수백 개의 전문가(Expert) 네트워크를 두되, 각 입력 토큰에 대해 소수의 전문가만 활성화하여 연산량을 극적으로 줄이는 것이다. 전체 파라미터 수가 모델 용량(capacity)을 결정하고, 활성 파라미터 수가 실제 연산 비용을 결정하므로, 높은 품질과 낮은 비용을 동시에 달성할 수 있다. Mixtral 8x7B는 총 47B 파라미터에서 토큰당 13B만 활성화하며, DeepSeek-V3는 671B 파라미터 중 37B만 활성화한다.
이 글에서는 MoE의 수학적 기초부터 라우팅 전략, 로드 밸런싱, Switch Transformer에서 DeepSeek-V3와 Qwen3-235B-A22B까지의 아키텍처 진화, 그리고 학습과 추론의 실전 최적화 전략을 심층적으로 다룬다. 각 주제를 PyTorch 코드와 함께 설명하며, 운영 환경에서 발생하는 실패 사례와 복구 절차까지 포함한다.
MoE 기본 구조와 수학적 원리
Sparse Activation의 수학적 정의
MoE 레이어는 N개의 전문가 네트워크 E_1, E_2, ..., E_N과 게이팅 네트워크 G로 구성된다. 입력 토큰 x에 대한 MoE 레이어의 출력 y는 다음과 같이 정의된다.
y = sum_{i=1}^{N} G(x)_i * E_i(x)
여기서 G(x)는 게이팅 함수로, 입력 x에 대해 N차원 벡터를 출력한다. Dense MoE에서는 모든 G(x)_i가 0이 아닌 값을 가지지만, Sparse MoE에서는 Top-K 전문가만 선택하고 나머지의 게이팅 값을 0으로 만든다.
G(x)_i = softmax(W_g * x + noise)_i (if i in TopK)
G(x)_i = 0 (otherwise)
노이즈 항은 학습 시 탐색(exploration)을 촉진하여 전문가 활용의 다양성을 높이는 역할을 한다. 이 sparsity 덕분에 전체 파라미터 수는 N에 비례하여 증가하지만, 실제 연산량(FLOPs)은 K에 비례하여 Dense 모델 대비 N/K 배 효율적이다.
전문가 네트워크의 구조
각 전문가는 일반적으로 Transformer의 Feed-Forward Network(FFN)을 대체한다. Self-Attention은 모든 토큰이 공유하고, FFN 부분만 전문가별로 분리하는 것이 표준적인 설계다. 이는 Self-Attention이 토큰 간 관계를 포착하는 전역적 역할을 하고, FFN이 개별 토큰의 표현을 변환하는 지역적 역할을 하기 때문이다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class MoEConfig:
d_model: int = 1024
d_ff: int = 4096
num_experts: int = 8
top_k: int = 2
dropout: float = 0.1
aux_loss_weight: float = 0.01
class SwiGLUExpert(nn.Module):
"""SwiGLU 활성화를 사용하는 전문가 FFN.
LLaMA, Mistral 등 최신 모델에서 표준으로 채택된 구조."""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
self.w_up = nn.Linear(d_model, d_ff, bias=False)
self.w_down = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = F.silu(self.w_gate(x))
up = self.w_up(x)
return self.w_down(self.dropout(gate * up))
class TopKGating(nn.Module):
"""Top-K 게이팅 네트워크.
Noisy Top-K Gating (Shazeer et al., 2017) 구현."""
def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.noise_linear = nn.Linear(d_model, num_experts, bias=False)
def forward(self, x: torch.Tensor):
# x: (batch * seq_len, d_model)
logits = self.gate(x)
if self.training:
noise = F.softplus(self.noise_linear(x))
logits = logits + noise * torch.randn_like(logits)
probs = F.softmax(logits, dim=-1)
top_k_probs, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
# 재정규화: 선택된 전문가 가중치 합이 1이 되도록
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
return top_k_probs, top_k_indices, probs
위 코드에서 SwiGLUExpert는 LLaMA, Mistral, Qwen 등 최신 모델이 채택한 SwiGLU 활성화 함수를 사용하는 전문가다. 기존 ReLU나 GELU 대비 학습 효율이 높다는 것이 경험적으로 확인되었다. TopKGating은 Shazeer et al.(2017)이 제안한 Noisy Top-K Gating을 구현한 것으로, 학습 시 게이팅 로짓에 학습 가능한 노이즈를 추가하여 전문가 탐색을 촉진한다.
Dense vs Sparse MoE 정량 비교
| 항목 | Dense 70B | Sparse MoE 8x7B (Top-2) | Sparse MoE 256x3B (Top-2) |
|---|---|---|---|
| 총 파라미터 | 70B | 47B | 768B |
| 활성 파라미터 | 70B | 13B | 6B |
| FLOPs/token | 140 TFLOPs | 26 TFLOPs | 12 TFLOPs |
| GPU 메모리 (FP16) | 140 GB | 94 GB | 1.5 TB |
| 학습 비용 비율 | 1.0x | 0.35x (FLOPs 기준) | 0.17x (FLOPs 기준) |
| 추론 속도 | 기준 | 2-3x 빠름 | 전문가 로딩 병목 |
주목할 점은 MoE 모델의 메모리 요구량이다. 활성 파라미터는 적지만 전체 파라미터를 메모리에 올려야 하므로, 전문가 수가 매우 많은 경우 메모리가 Dense 모델보다 오히려 더 클 수 있다. 이것이 Expert Parallelism과 Offloading 전략이 필요한 근본적 이유다.
라우팅 전략: Top-k, Expert Choice, Hash Routing
라우팅(Routing)은 MoE의 핵심이자 가장 어려운 설계 문제다. 어떤 전문가를 활성화할지 결정하는 전략에 따라 모델의 품질, 학습 안정성, 추론 효율이 크게 달라진다.
Top-K Routing
가장 전통적인 라우팅 방식으로, 게이팅 네트워크가 각 전문가에 대한 점수를 계산하고 상위 K개를 선택한다. Shazeer et al.(2017)이 Top-2를 제안했고, Switch Transformer(Fedus et al., 2022)가 Top-1으로 단순화하여 통신 비용을 절반으로 줄였다.
Top-1의 장점: 각 토큰이 정확히 하나의 전문가만 사용하므로 분산 환경에서 All-to-All 통신량이 최소화된다. 구현도 단순하다.
Top-1의 단점: 단일 전문가에 의존하므로 표현력이 제한되고, 게이팅 결정이 이산적(discrete)이라 학습 초기에 전문가 붕괴(expert collapse) 위험이 높다.
Top-2의 절충: Mixtral 8x7B와 다수의 최신 모델이 Top-2를 채택한다. 두 전문가의 출력을 가중합하므로 표현력이 풍부하고, 하나의 전문가가 불안정해도 다른 전문가가 보완한다.
Expert Choice Routing
Zhou et al.(2022)이 제안한 Expert Choice 라우팅은 관점을 전환한다. 토큰이 전문가를 선택하는 것이 아니라, 전문가가 처리할 토큰을 선택한다. 각 전문가가 자신에게 가장 적합한 토큰 K개를 선택하므로, 로드 밸런싱이 구조적으로 보장된다.
class ExpertChoiceGating(nn.Module):
"""Expert Choice Routing 구현.
각 전문가가 처리할 토큰을 직접 선택하여
로드 밸런싱을 구조적으로 보장한다."""
def __init__(
self,
d_model: int,
num_experts: int,
capacity_factor: float = 1.0,
):
super().__init__()
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.gate = nn.Linear(d_model, num_experts, bias=False)
def forward(self, x: torch.Tensor):
# x: (num_tokens, d_model)
num_tokens = x.shape[0]
expert_capacity = int(
num_tokens * self.capacity_factor / self.num_experts
)
# 게이팅 점수: (num_tokens, num_experts)
gate_logits = self.gate(x)
# 전문가 관점에서 점수 계산: (num_experts, num_tokens)
gate_scores = F.softmax(gate_logits.T, dim=-1)
# 각 전문가가 상위 capacity개 토큰 선택
top_k_scores, top_k_indices = torch.topk(
gate_scores, expert_capacity, dim=-1
) # (num_experts, capacity)
# 디스패치 마스크 생성
dispatch_mask = torch.zeros(
self.num_experts, num_tokens,
device=x.device, dtype=x.dtype,
)
dispatch_mask.scatter_(1, top_k_indices, top_k_scores)
return dispatch_mask, top_k_indices, top_k_scores
Expert Choice의 핵심 장점은 보조 손실(auxiliary loss) 없이도 완벽한 로드 밸런싱을 달성한다는 것이다. 각 전문가가 동일한 수의 토큰을 처리하도록 강제되므로 전문가 붕괴 문제가 원천적으로 제거된다. 단, 하나의 토큰이 여러 전문가에게 선택되거나 아무 전문가에게도 선택되지 않을 수 있다는 비대칭성이 존재한다.
Hash Routing
Roller et al.(2021)이 제안한 Hash Routing은 학습 가능한 게이팅을 완전히 제거하고, 해시 함수로 토큰을 전문가에 배정한다. 게이팅 네트워크의 파라미터와 연산이 사라지므로 추론 시 오버헤드가 최소화된다. 그러나 고정된 배정 규칙이므로 입력 의미를 반영하지 못하고, 실전에서는 학습 가능한 라우팅 대비 품질이 낮아 주류로 채택되지는 않았다.
라우팅 전략 비교
| 전략 | 로드 밸런싱 | 표현력 | 통신 비용 | 구현 복잡도 | 대표 모델 |
|---|---|---|---|---|---|
| Top-1 | Aux Loss 필요 | 낮음 | 최소 | 낮음 | Switch Transformer |
| Top-2 | Aux Loss 필요 | 중간 | 중간 | 중간 | Mixtral 8x7B |
| Top-K (K=6,8) | Aux Loss 필요 | 높음 | 높음 | 중간 | DeepSeek-V3 (Top-8/256) |
| Expert Choice | 구조적 보장 | 높음 | 중간 | 높음 | 연구용 모델 |
| Hash Routing | 완벽 | 낮음 | 최소 | 최소 | 연구용 모델 |
로드 밸런싱과 Auxiliary Loss
MoE 학습에서 가장 심각한 문제는 전문가 붕괴(Expert Collapse)다. 게이팅 네트워크가 소수의 전문가에만 토큰을 집중적으로 보내고, 나머지 전문가는 학습 기회를 잃어 사실상 죽은 파라미터가 되는 현상이다. 이를 방지하기 위해 다양한 로드 밸런싱 기법이 개발되었다.
Auxiliary Loss (보조 손실)
Switch Transformer가 제안한 표준적인 접근법이다. 전문가별 토큰 배분이 균등해지도록 유도하는 손실 항을 메인 언어 모델링 손실에 추가한다.
def compute_load_balancing_loss(
gate_probs: torch.Tensor,
top_k_indices: torch.Tensor,
num_experts: int,
top_k: int,
) -> torch.Tensor:
"""Switch Transformer 스타일의 로드 밸런싱 보조 손실 계산.
Args:
gate_probs: 게이팅 확률 (num_tokens, num_experts)
top_k_indices: 선택된 전문가 인덱스 (num_tokens, top_k)
num_experts: 전문가 수
top_k: 선택되는 전문가 수
Returns:
보조 손실 스칼라 값
"""
num_tokens = gate_probs.shape[0]
# f_i: 전문가 i에 배정된 토큰의 비율
expert_mask = F.one_hot(top_k_indices, num_experts).float()
# (num_tokens, top_k, num_experts) -> (num_tokens, num_experts)
expert_mask = expert_mask.sum(dim=1)
tokens_per_expert = expert_mask.sum(dim=0) # (num_experts,)
f = tokens_per_expert / (num_tokens * top_k)
# P_i: 전문가 i에 대한 평균 게이팅 확률
P = gate_probs.mean(dim=0) # (num_experts,)
# 보조 손실: N * sum(f_i * P_i)
# 균등 배분일 때 최소값을 가짐
aux_loss = num_experts * (f * P).sum()
return aux_loss
class MoELayerWithAuxLoss(nn.Module):
"""보조 손실이 포함된 완전한 MoE 레이어 구현."""
def __init__(self, config: MoEConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
SwiGLUExpert(config.d_model, config.d_ff, config.dropout)
for _ in range(config.num_experts)
])
self.gating = TopKGating(
config.d_model, config.num_experts, config.top_k
)
self.aux_loss_weight = config.aux_loss_weight
def forward(self, x: torch.Tensor):
batch_size, seq_len, d_model = x.shape
x_flat = x.view(-1, d_model)
top_k_probs, top_k_indices, gate_probs = self.gating(x_flat)
# 보조 손실 계산
aux_loss = self.aux_loss_weight * compute_load_balancing_loss(
gate_probs, top_k_indices,
self.config.num_experts, self.config.top_k,
)
# 전문가 출력 계산
output = torch.zeros_like(x_flat)
for k in range(self.config.top_k):
expert_indices = top_k_indices[:, k] # (num_tokens,)
expert_weights = top_k_probs[:, k] # (num_tokens,)
for i in range(self.config.num_experts):
mask = (expert_indices == i)
if mask.any():
expert_input = x_flat[mask]
expert_output = self.experts[i](expert_input)
output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
output = output.view(batch_size, seq_len, d_model)
return output, aux_loss
Auxiliary Loss의 가중치 alpha는 매우 민감한 하이퍼파라미터다. Switch Transformer는 alpha=0.01을 권장했으나, 모델 규모와 전문가 수에 따라 조정이 필요하다. alpha가 너무 크면 언어 모델링 품질이 저하되고, 너무 작으면 로드 밸런싱 효과가 미미하다. ST-MoE(Zoph et al., 2022)는 라우터 z-loss를 추가하여 게이팅 로짓의 크기 자체를 제약하는 방법을 제안했다.
DeepSeek의 Auxiliary-Loss-Free 전략
DeepSeek-V3(2024)는 보조 손실 없이 로드 밸런싱을 달성하는 혁신적 방법을 제안했다. 각 전문가에 학습 불가능한 바이어스 항을 추가하고, 학습 중 토큰이 과도하게 집중되는 전문가의 바이어스를 낮추고 활용이 부족한 전문가의 바이어스를 높이는 동적 조정 메커니즘을 사용한다. 이 접근법은 보조 손실이 메인 학습 목적 함수에 간섭하는 문제를 완전히 제거하며, 학습 안정성과 최종 모델 품질 모두에서 이점을 보였다.
Switch Transformer에서 DeepSeek-V3까지 진화
Switch Transformer (Fedus et al., 2022)
Switch Transformer는 MoE 라우팅을 Top-1으로 단순화한 핵심 논문이다. 기존 Top-2 라우팅의 통신 비용을 절반으로 줄이면서, 적절한 capacity factor와 auxiliary loss로 학습 안정성을 확보했다. 1.6T 파라미터 모델을 T5-XXL 대비 4배 빠르게 학습하며 동등한 품질을 달성했다.
핵심 설계 결정:
- Top-1 라우팅: 통신 비용 최소화
- Capacity Factor: 전문가 버퍼 크기를 동적 조절하여 토큰 드롭 방지
- Selective Precision: 게이팅은 FP32, 전문가 연산은 BF16으로 혼합하여 안정성과 효율성 동시 달성
GShard (Lepikhin et al., 2021)
Google이 제안한 GShard는 600B 파라미터 MoE 모델의 분산 학습 파이프라인을 확립했다. Top-2 라우팅과 Group-level 밸런싱을 사용했으며, SPMD(Single Program Multiple Data) 프로그래밍 모델로 수천 개의 TPU에서 효율적으로 학습할 수 있는 프레임워크를 제시했다.
Mixtral 8x7B (Jiang et al., 2024)
Mistral AI의 Mixtral은 오픈소스 MoE 모델의 실용성을 증명한 이정표다. 8개 전문가 중 Top-2를 선택하는 구조로, 총 47B 파라미터에서 활성 13B를 사용한다. LLaMA-2 70B와 동등하거나 우수한 벤치마크 성능을 보이면서, 추론 FLOPs는 70B의 1/3 수준이다.
DeepSeek-V3 (DeepSeek, 2024)
DeepSeek-V3는 MoE 아키텍처의 여러 설계 영역에서 혁신을 이뤘다.
- Fine-grained Expert Segmentation: 256개의 소규모 전문가를 두고 Top-8을 선택한다. 전문가 수를 늘리고 크기를 줄이면 전문화(specialization)가 심화되어 모델 품질이 향상된다.
- Shared Expert: 1개의 공유 전문가가 모든 토큰을 처리하여 공통 지식을 담당하고, 라우팅된 전문가는 특화 지식을 담당한다.
- Auxiliary-Loss-Free Load Balancing: 앞서 설명한 바이어스 기반 동적 밸런싱으로 보조 손실의 부작용을 제거했다.
- Multi-Token Prediction (MTP): 한 번에 여러 토큰을 예측하는 학습 목표를 추가하여 데이터 효율을 높였다.
- FP8 학습: 671B 파라미터 모델을 2048대의 H800 GPU에서 FP8 정밀도로 학습하여 비용을 극적으로 절감했다.
Qwen3-235B-A22B (Alibaba, 2025)
Qwen3-235B-A22B는 총 235B 파라미터 중 22B만 활성화하는 MoE 모델로, 128개의 전문가 중 Top-8을 선택한다. 기존 Qwen2.5 시리즈의 Dense 아키텍처를 MoE로 전환하면서 GPT-4o 수준의 성능을 10분의 1 수준의 추론 비용으로 달성했다.
MoE 모델 비교
| 모델 | 총 파라미터 | 활성 파라미터 | 전문가 수 | Top-K | 공유 전문가 | 라우팅 전략 |
|---|---|---|---|---|---|---|
| Switch Transformer | 1.6T | 약 100B | 128 | 1 | 없음 | Learned Top-1 |
| GShard | 600B | 약 20B | 2048 | 2 | 없음 | Learned Top-2 |
| Mixtral 8x7B | 47B | 13B | 8 | 2 | 없음 | Learned Top-2 |
| DeepSeek-V3 | 671B | 37B | 256+1 | 8 | 1개 | Bias-adjusted |
| Qwen3-235B-A22B | 235B | 22B | 128 | 8 | 있음 | Learned Top-K |
| DBRX | 132B | 36B | 16 | 4 | 없음 | Learned Top-4 |
학습 안정성과 트러블슈팅
전문가 붕괴(Expert Collapse) 진단
전문가 붕괴는 MoE 학습에서 가장 흔하고 치명적인 문제다. 게이팅 네트워크가 특정 전문가들에만 토큰을 집중적으로 보내면, 나머지 전문가의 그래디언트가 0에 수렴하여 학습이 멈추고, 이 불균형이 자기 강화(self-reinforcing) 루프를 형성하여 악화된다.
import logging
from collections import defaultdict
logger = logging.getLogger(__name__)
class ExpertUtilizationMonitor:
"""전문가 활용 모니터링 및 붕괴 탐지 도구.
학습 중 각 전문가의 활용률을 추적하고,
붕괴 징후를 조기에 감지한다.
"""
def __init__(
self,
num_experts: int,
collapse_threshold: float = 0.01,
window_size: int = 100,
):
self.num_experts = num_experts
self.collapse_threshold = collapse_threshold
self.window_size = window_size
self.history: list[dict[int, float]] = []
def record(self, expert_counts: dict[int, int], total_tokens: int):
"""배치별 전문가 활용 기록."""
utilization = {
i: expert_counts.get(i, 0) / max(total_tokens, 1)
for i in range(self.num_experts)
}
self.history.append(utilization)
if len(self.history) > self.window_size:
self.history = self.history[-self.window_size:]
def detect_collapse(self) -> list[int]:
"""전문가 붕괴 탐지. 활용률이 threshold 이하인 전문가 반환."""
if len(self.history) < self.window_size // 2:
return []
collapsed = []
for expert_id in range(self.num_experts):
recent_util = [
h[expert_id] for h in self.history[-self.window_size:]
]
avg_util = sum(recent_util) / len(recent_util)
if avg_util < self.collapse_threshold:
collapsed.append(expert_id)
if collapsed:
logger.warning(
f"Expert collapse detected! "
f"Experts {collapsed} have utilization below "
f"{self.collapse_threshold:.2%}. "
f"Consider increasing aux_loss_weight or "
f"reinitializing collapsed experts."
)
return collapsed
def get_load_imbalance_ratio(self) -> float:
"""로드 불균형 비율 계산.
1.0이면 완벽한 균형, 값이 클수록 불균형."""
if not self.history:
return 0.0
latest = self.history[-1]
utils = list(latest.values())
max_util = max(utils) if utils else 0
min_util = min(utils) if utils else 0
avg_util = sum(utils) / len(utils) if utils else 0
if avg_util == 0:
return float("inf")
return max_util / avg_util
학습 불안정 원인과 대응
| 증상 | 원인 | 대응 방법 |
|---|---|---|
| Loss spike (손실 급등) | 게이팅 로짓 폭발 | Router z-loss 추가, 게이팅 FP32 유지 |
| 전문가 붕괴 | Aux loss 부족, LR 과다 | Aux loss 가중치 증가, LR warm-up 연장 |
| 전문가 간 중복 | 초기화 유사성 | 전문가 직교 초기화, 다양성 정규화 |
| 토큰 드롭 | Capacity factor 부족 | Capacity factor 1.25-1.5로 증가 |
| 게이팅 진동 | 학습률 과대 | 게이팅 LR을 메인 LR의 0.1배로 분리 |
안정적 학습을 위한 하이퍼파라미터 가이드
학습 안정성을 위한 핵심 원칙은 다음과 같다. 첫째, 게이팅 네트워크의 연산은 반드시 FP32로 수행한다. BF16이나 FP16에서는 softmax의 수치 불안정으로 라우팅이 진동한다. 둘째, 학습률 워밍업을 Dense 모델보다 2-3배 길게 가져간다. 게이팅이 안정되기 전에 높은 학습률을 적용하면 전문가 붕괴가 발생한다. 셋째, 배치 크기를 가능한 한 크게 설정한다. 작은 배치에서는 게이팅의 토큰 분배가 노이즈에 민감하여 불안정하다.
추론 최적화: Expert Parallelism, Offloading
MoE 모델의 추론은 Dense 모델과 근본적으로 다른 도전을 제기한다. 활성 파라미터는 적지만 전체 파라미터를 접근 가능한 상태로 유지해야 하므로, 메모리 관리와 전문가 배치 전략이 핵심이다.
Expert Parallelism
Expert Parallelism(EP)은 전문가를 여러 GPU에 분산 배치하는 전략이다. N개의 전문가를 P개의 GPU에 분배하면, 각 GPU는 N/P개의 전문가만 저장한다. 토큰이 특정 전문가에 라우팅되면 All-to-All 통신으로 해당 GPU에 토큰을 전송하고, 연산 결과를 다시 원래 GPU로 반환한다.
import torch
import torch.distributed as dist
from typing import Optional
class ExpertParallelRouter:
"""Expert Parallelism을 위한 토큰 디스패치/수집 구현.
각 GPU가 전문가 일부를 담당하고,
All-to-All 통신으로 토큰을 라우팅한다.
"""
def __init__(
self,
num_experts: int,
ep_group: Optional[dist.ProcessGroup] = None,
):
self.num_experts = num_experts
self.ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group else 0
self.experts_per_rank = num_experts // self.ep_size
def dispatch(
self,
tokens: torch.Tensor,
expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""토큰을 담당 GPU로 디스패치.
Args:
tokens: (num_tokens, d_model) 입력 토큰
expert_indices: (num_tokens,) 각 토큰의 대상 전문가 인덱스
Returns:
dispatched_tokens: 이 GPU가 처리할 토큰
recv_counts: 각 GPU에서 받은 토큰 수
"""
# 각 GPU로 보낼 토큰 수 계산
send_counts = torch.zeros(
self.ep_size, dtype=torch.long, device=tokens.device
)
for rank in range(self.ep_size):
start_expert = rank * self.experts_per_rank
end_expert = start_expert + self.experts_per_rank
mask = (expert_indices >= start_expert) & (
expert_indices < end_expert
)
send_counts[rank] = mask.sum()
# All-to-All로 수신 카운트 교환
recv_counts = torch.zeros_like(send_counts)
dist.all_to_all_single(
recv_counts, send_counts, group=self.ep_group
)
# 토큰 정렬 및 All-to-All 전송
sorted_indices = torch.argsort(expert_indices)
sorted_tokens = tokens[sorted_indices]
send_splits = send_counts.tolist()
recv_splits = recv_counts.tolist()
dispatched_tokens = torch.zeros(
int(recv_counts.sum()), tokens.shape[1],
dtype=tokens.dtype, device=tokens.device,
)
dist.all_to_all_single(
dispatched_tokens, sorted_tokens,
output_split_sizes=recv_splits,
input_split_sizes=send_splits,
group=self.ep_group,
)
return dispatched_tokens, recv_counts
Expert Offloading
GPU 메모리가 부족할 때, 비활성 전문가를 CPU 메모리나 NVMe SSD에 저장하고 필요할 때만 GPU로 로드하는 전략이다. DeepSpeed-MoE와 Mixtral의 추론 최적화에서 핵심적으로 활용된다.
Offloading의 핵심은 프리페칭(prefetching)이다. 현재 레이어의 전문가 연산과 동시에 다음 레이어에서 활성화될 전문가를 비동기적으로 GPU에 로드하면, 전문가 교체 지연을 숨길 수 있다. PCIe 4.0 x16 기준으로 약 32 GB/s의 대역폭이 가능하므로, 전문가 하나(수백 MB)를 수 밀리초 내에 전송할 수 있다.
추론 최적화 전략 비교
| 전략 | GPU 메모리 | 추론 지연 | 처리량 | 적합 시나리오 |
|---|---|---|---|---|
| Full Model on GPU | 최대 | 최소 | 최대 | 고사양 멀티GPU 서버 |
| Expert Parallelism | 분산 | 통신 오버헤드 | 높음 | 멀티GPU 클러스터 |
| CPU Offloading | 최소 | 로딩 지연 | 중간 | 제한된 GPU 환경 |
| NVMe Offloading | 최소 | 높은 로딩 지연 | 낮음 | 단일 GPU 환경 |
| Speculative Expert Prefetch | 중간 | 중간 | 높음 | 배치 추론 서버 |
운영 체크리스트
MoE 모델을 프로덕션에 배포할 때 반드시 확인해야 할 항목들을 정리한다.
학습 단계
- 게이팅 정밀도 확인: 게이팅 네트워크의 forward/backward가 FP32로 수행되는지 검증한다. BF16 게이팅은 학습 초기에는 정상적으로 보이지만 수만 스텝 이후 불안정을 유발할 수 있다.
- 로드 밸런싱 메트릭 대시보드 구축: 전문가별 토큰 할당량, 최대/최소 활용 비율, 보조 손실 값을 실시간으로 모니터링한다.
- 체크포인트 전략: 전문가 병렬화 환경에서는 체크포인트가 GPU별로 분리 저장될 수 있다. 전체 모델을 하나로 병합(consolidate)하는 스크립트를 미리 준비한다.
- Capacity Factor 튜닝: 토큰 드롭률이 1% 이상이면 capacity factor를 높인다. 드롭된 토큰은 residual connection을 통해서만 전달되므로 품질이 저하된다.
- 전문가 붕괴 알림 설정: 특정 전문가의 활용률이 평균의 10% 이하로 떨어지면 알림을 발생시키고, 필요시 해당 전문가를 재초기화한다.
추론/배포 단계
- 메모리 프로파일링: 전체 파라미터의 GPU 메모리 적재 가능 여부를 확인하고, 불가능하면 EP 또는 Offloading 전략을 선택한다.
- 배치 크기 최적화: MoE 추론에서 배치 크기는 전문가 활용 효율에 직접적으로 영향을 미친다. 작은 배치에서는 일부 전문가만 활성화되어 GPU 활용률이 저하된다.
- KV Cache 관리: MoE 모델도 Attention 레이어는 Dense와 동일하므로 KV Cache 관리가 필요하다. PagedAttention(vLLM)과 결합하면 효율적이다.
- 라우팅 일관성 테스트: 동일 입력에 대해 동일한 전문가가 선택되는지 확인한다. 특히 Tensor Parallelism과 Expert Parallelism을 혼합할 때 수치 오차로 라우팅이 달라질 수 있다.
- Fallback 전략: 특정 전문가가 로딩에 실패하면 차순위 전문가로 대체하는 fallback 로직을 구현한다.
- A/B 테스트 파이프라인: Dense 모델 대비 MoE 모델의 품질 동등성을 서빙 환경에서 검증한다.
실패 사례와 복구
사례 1: 전문가 붕괴로 인한 품질 저하
증상: 학습 3만 스텝 이후 갑자기 벤치마크 점수가 하락한다. Loss 자체는 정상적으로 감소하지만, 생성 품질이 떨어진다.
원인 분석: 모니터링 결과, 8개 전문가 중 2개가 전체 토큰의 60% 이상을 처리하고, 3개 전문가는 활용률 2% 미만이었다. 보조 손실 가중치(alpha=0.001)가 너무 낮아 밸런싱 효과가 부족했다.
복구 절차:
- 전문가 붕괴 직전의 체크포인트(2만 스텝)로 롤백
- 보조 손실 가중치를 0.001에서 0.01로 10배 증가
- 붕괴된 전문가의 파라미터를 활성 전문가의 파라미터로 재초기화
- 게이팅 네트워크의 학습률을 메인 학습률의 0.1배로 분리 설정
- 재학습 후 전문가 활용률이 균등(평균 12.5% 기준 8-17% 범위)해질 때까지 모니터링
사례 2: All-to-All 통신 병목
증상: Expert Parallelism으로 64개 GPU에서 학습할 때, GPU 활용률이 40%로 급감한다. 프로파일러에서 All-to-All 통신이 전체 학습 시간의 45%를 차지하는 것으로 확인된다.
원인 분석: 네트워크 토폴로지 분석 결과, 전문가 배치가 네트워크 구조를 고려하지 않아 노드 간 통신이 과도하게 발생했다. 같은 노드 내 GPU 간 통신(NVLink, 900 GB/s)과 노드 간 통신(InfiniBand, 400 Gb/s)의 대역폭 차이가 20배 이상이었다.
복구 절차:
- Hierarchical All-to-All로 전환: 노드 내 통신과 노드 간 통신을 2단계로 분리
- 전문가 배치를 토폴로지 인지(topology-aware)로 재배치: 자주 함께 활성화되는 전문가를 같은 노드에 배치
- 통신-연산 중첩(overlap): 전문가 연산과 다음 배치의 토큰 디스패치를 파이프라인으로 중첩
사례 3: 추론 시 전문가 로딩 지연
증상: CPU Offloading으로 Mixtral 8x7B를 단일 GPU(24GB)에서 서빙할 때, 첫 토큰 지연(TTFT)이 5초를 초과한다.
원인 분석: 각 레이어에서 2개의 전문가를 CPU에서 GPU로 로드할 때마다 100-200ms가 소요되고, 32개 레이어를 순차적으로 처리하므로 누적 지연이 3.2-6.4초에 달한다.
복구 절차:
- 전문가 프리페칭 구현: 현재 레이어 처리 중 다음 레이어의 게이팅 점수를 사전 계산하고, 필요한 전문가를 비동기 로드
- 핫 전문가 캐싱: 활성화 빈도가 높은 상위 2-3개 전문가를 GPU에 상주시킴
- 전문가 가중치 양자화: INT4 양자화로 전문가 크기를 75% 축소하여 전송 시간 단축
- PCIe 대역폭이 병목인 경우, pinned memory 사용으로 CPU-GPU 전송을 최적화
사례 4: 학습 중 Loss Spike 발생
증상: 대규모 MoE 모델(100B 이상) 학습 시 수천 스텝마다 반복적으로 loss가 급등한다. 각 spike 이후 회복은 되지만 학습 시간이 낭비된다.
원인 분석: 게이팅 네트워크의 softmax 입력 로짓이 간헐적으로 매우 큰 값을 가져 수치 불안정을 유발한다. 특히 BF16 학습 시 게이팅 로짓의 범위가 FP32 대비 좁아 overflow가 발생하기 쉽다.
복구 절차:
- Router z-loss 추가로 게이팅 로짓의 크기를 직접 제약한다.
def router_z_loss(gate_logits: torch.Tensor) -> torch.Tensor:
"""ST-MoE 스타일의 Router z-loss.
게이팅 로짓의 크기를 제약하여 수치 안정성을 높인다.
Args:
gate_logits: (num_tokens, num_experts) 게이팅 로짓
Returns:
z_loss 스칼라
"""
log_z = torch.logsumexp(gate_logits, dim=-1) # (num_tokens,)
z_loss = (log_z ** 2).mean()
return z_loss
- 게이팅 연산을 FP32로 강제하여 수치 안정성을 확보한다.
- 그래디언트 클리핑을 게이팅 네트워크에 별도로 적용(max_norm=1.0)한다.
- 학습률 워밍업 기간을 전체 학습의 5-10%까지 확장한다.
참고자료
-
Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR, 23(120), 1-39. https://arxiv.org/abs/2101.03961
-
DeepSeek-AI. (2024). DeepSeek-V3 Technical Report. https://arxiv.org/abs/2401.06066
-
Cai, W. et al. (2024). A Survey on Mixture of Experts. https://arxiv.org/abs/2407.10671
-
FriendliAI. (2024). MoE Models Comparison: Architectures and Performance. https://friendli.ai/blog/moe-models-comparison
-
Zilliz. (2024). What is Mixture of Experts? A Complete Guide. https://zilliz.com/learn/what-is-mixture-of-experts
-
Wikipedia. Mixture of Experts. https://en.wikipedia.org/wiki/Mixture_of_experts
-
Shazeer, N. et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017. https://arxiv.org/abs/1701.06538
-
Jiang, A. Q. et al. (2024). Mixtral of Experts. https://arxiv.org/abs/2401.04088
Deep Dive into Sparse Mixture of Experts (MoE) Architecture: From Design Principles to DeepSeek-V3 and Qwen3
- Introduction
- MoE Basic Structure and Mathematical Principles
- Routing Strategies: Top-k, Expert Choice, Hash Routing
- Load Balancing and Auxiliary Loss
- Evolution from Switch Transformer to DeepSeek-V3
- Training Stability and Troubleshooting
- Inference Optimization: Expert Parallelism, Offloading
- Operations Checklist
- Failure Cases and Recovery
- References
- Quiz

Introduction
The number of parameters in large language models (LLMs) is growing exponentially, but Dense models that activate all parameters for every token have hit a wall of computational cost. Training a GPT-4 level Dense model requires tens of thousands of GPUs running for months, and inference costs also increase proportionally to the number of parameters. To address this fundamental inefficiency, the conditional computation paradigm has emerged, with Sparse Mixture of Experts (MoE) architecture at its center.
The core idea of MoE is simple. Place hundreds of expert networks, but for each input token, activate only a small number of experts to dramatically reduce computation. Since the total number of parameters determines model capacity, and the number of active parameters determines actual computational cost, high quality and low cost can be achieved simultaneously. Mixtral 8x7B activates only 13B out of a total 47B parameters per token, and DeepSeek-V3 activates only 37B out of 671B parameters.
This article provides an in-depth treatment of the mathematical foundations of MoE, routing strategies, load balancing, architectural evolution from Switch Transformer to DeepSeek-V3 and Qwen3-235B-A22B, and practical optimization strategies for training and inference. Each topic is explained with PyTorch code, including failure cases and recovery procedures encountered in production environments.
MoE Basic Structure and Mathematical Principles
Mathematical Definition of Sparse Activation
An MoE layer consists of N expert networks E_1, E_2, ..., E_N and a gating network G. The output y of an MoE layer for input token x is defined as follows:
y = sum_{i=1}^{N} G(x)_i * E_i(x)
Here, G(x) is the gating function, which outputs an N-dimensional vector for input x. In Dense MoE, all G(x)_i have non-zero values, but in Sparse MoE, only Top-K experts are selected and the gating values of the rest are set to 0.
G(x)_i = softmax(W_g * x + noise)_i (if i in TopK)
G(x)_i = 0 (otherwise)
The noise term promotes exploration during training to increase the diversity of expert utilization. Thanks to this sparsity, while total parameters grow proportionally to N, actual computation (FLOPs) is proportional to K, making it N/K times more efficient than Dense models.
Expert Network Structure
Each expert typically replaces the Feed-Forward Network (FFN) in a Transformer. The standard design is for Self-Attention to be shared across all tokens, with only the FFN portion separated by expert. This is because Self-Attention plays a global role in capturing inter-token relationships, while FFN plays a local role in transforming individual token representations.
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class MoEConfig:
d_model: int = 1024
d_ff: int = 4096
num_experts: int = 8
top_k: int = 2
dropout: float = 0.1
aux_loss_weight: float = 0.01
class SwiGLUExpert(nn.Module):
"""Expert FFN using SwiGLU activation.
The standard structure adopted by modern models like LLaMA, Mistral, etc."""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.w_gate = nn.Linear(d_model, d_ff, bias=False)
self.w_up = nn.Linear(d_model, d_ff, bias=False)
self.w_down = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = F.silu(self.w_gate(x))
up = self.w_up(x)
return self.w_down(self.dropout(gate * up))
class TopKGating(nn.Module):
"""Top-K Gating Network.
Noisy Top-K Gating (Shazeer et al., 2017) implementation."""
def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.noise_linear = nn.Linear(d_model, num_experts, bias=False)
def forward(self, x: torch.Tensor):
# x: (batch * seq_len, d_model)
logits = self.gate(x)
if self.training:
noise = F.softplus(self.noise_linear(x))
logits = logits + noise * torch.randn_like(logits)
probs = F.softmax(logits, dim=-1)
top_k_probs, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
# Renormalization: ensure selected expert weights sum to 1
top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
return top_k_probs, top_k_indices, probs
In the code above, SwiGLUExpert is an expert using the SwiGLU activation function adopted as the standard by modern models like LLaMA, Mistral, and Qwen. It has been empirically confirmed to have higher training efficiency compared to conventional ReLU or GELU. TopKGating implements the Noisy Top-K Gating proposed by Shazeer et al. (2017), which adds learnable noise to gating logits during training to promote expert exploration.
Dense vs Sparse MoE Quantitative Comparison
| Item | Dense 70B | Sparse MoE 8x7B (Top-2) | Sparse MoE 256x3B (Top-2) |
|---|---|---|---|
| Total Parameters | 70B | 47B | 768B |
| Active Parameters | 70B | 13B | 6B |
| FLOPs/token | 140 TFLOPs | 26 TFLOPs | 12 TFLOPs |
| GPU Memory (FP16) | 140 GB | 94 GB | 1.5 TB |
| Training Cost Ratio | 1.0x | 0.35x (FLOPs basis) | 0.17x (FLOPs basis) |
| Inference Speed | Baseline | 2-3x faster | Expert loading bottleneck |
A noteworthy point is the memory requirement of MoE models. Although active parameters are few, all parameters must be kept in memory, so when the number of experts is very large, memory can actually be larger than Dense models. This is the fundamental reason why Expert Parallelism and Offloading strategies are necessary.
Routing Strategies: Top-k, Expert Choice, Hash Routing
Routing is both the core and the most challenging design problem of MoE. The strategy for deciding which experts to activate significantly affects model quality, training stability, and inference efficiency.
Top-K Routing
The most traditional routing method, where the gating network computes scores for each expert and selects the top K. Shazeer et al. (2017) proposed Top-2, and Switch Transformer (Fedus et al., 2022) simplified it to Top-1, cutting communication costs in half.
Advantages of Top-1: Each token uses exactly one expert, minimizing All-to-All communication in distributed environments. Implementation is also straightforward.
Disadvantages of Top-1: Reliance on a single expert limits expressiveness, and the discrete gating decision increases the risk of expert collapse during early training.
Top-2 Compromise: Mixtral 8x7B and many modern models adopt Top-2. Since the outputs of two experts are combined through weighted sum, expressiveness is richer, and if one expert becomes unstable, the other compensates.
Expert Choice Routing
Expert Choice routing, proposed by Zhou et al. (2022), reverses the perspective. Instead of tokens choosing experts, experts choose which tokens to process. Since each expert selects the K tokens most suitable for itself, load balancing is structurally guaranteed.
class ExpertChoiceGating(nn.Module):
"""Expert Choice Routing implementation.
Each expert directly selects which tokens to process,
structurally guaranteeing load balancing."""
def __init__(
self,
d_model: int,
num_experts: int,
capacity_factor: float = 1.0,
):
super().__init__()
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.gate = nn.Linear(d_model, num_experts, bias=False)
def forward(self, x: torch.Tensor):
# x: (num_tokens, d_model)
num_tokens = x.shape[0]
expert_capacity = int(
num_tokens * self.capacity_factor / self.num_experts
)
# Gating scores: (num_tokens, num_experts)
gate_logits = self.gate(x)
# Scores from expert perspective: (num_experts, num_tokens)
gate_scores = F.softmax(gate_logits.T, dim=-1)
# Each expert selects top capacity tokens
top_k_scores, top_k_indices = torch.topk(
gate_scores, expert_capacity, dim=-1
) # (num_experts, capacity)
# Create dispatch mask
dispatch_mask = torch.zeros(
self.num_experts, num_tokens,
device=x.device, dtype=x.dtype,
)
dispatch_mask.scatter_(1, top_k_indices, top_k_scores)
return dispatch_mask, top_k_indices, top_k_scores
The key advantage of Expert Choice is that it achieves perfect load balancing without auxiliary loss. Since each expert is forced to process the same number of tokens, the expert collapse problem is fundamentally eliminated. However, there is an asymmetry where a single token may be selected by multiple experts or by no expert at all.
Hash Routing
Hash Routing, proposed by Roller et al. (2021), completely eliminates learnable gating and assigns tokens to experts using hash functions. Since the gating network's parameters and computation are eliminated, inference overhead is minimized. However, because the assignment rule is fixed, it cannot reflect input semantics, and in practice, quality is lower compared to learnable routing, so it has not been adopted as mainstream.
Routing Strategy Comparison
| Strategy | Load Balancing | Expressiveness | Comm. Cost | Implementation Complexity | Representative Model |
|---|---|---|---|---|---|
| Top-1 | Aux Loss Required | Low | Minimum | Low | Switch Transformer |
| Top-2 | Aux Loss Required | Medium | Medium | Medium | Mixtral 8x7B |
| Top-K (K=6,8) | Aux Loss Required | High | High | Medium | DeepSeek-V3 (Top-8/256) |
| Expert Choice | Structurally Guaranteed | High | Medium | High | Research models |
| Hash Routing | Perfect | Low | Minimum | Minimum | Research models |
Load Balancing and Auxiliary Loss
The most serious problem in MoE training is expert collapse. This is a phenomenon where the gating network concentrates tokens on only a few experts, while the remaining experts lose training opportunities and effectively become dead parameters. Various load balancing techniques have been developed to prevent this.
Auxiliary Loss
This is the standard approach proposed by Switch Transformer. A loss term that encourages uniform token distribution across experts is added to the main language modeling loss.
def compute_load_balancing_loss(
gate_probs: torch.Tensor,
top_k_indices: torch.Tensor,
num_experts: int,
top_k: int,
) -> torch.Tensor:
"""Switch Transformer-style load balancing auxiliary loss computation.
Args:
gate_probs: Gating probabilities (num_tokens, num_experts)
top_k_indices: Selected expert indices (num_tokens, top_k)
num_experts: Number of experts
top_k: Number of selected experts
Returns:
Auxiliary loss scalar value
"""
num_tokens = gate_probs.shape[0]
# f_i: fraction of tokens assigned to expert i
expert_mask = F.one_hot(top_k_indices, num_experts).float()
# (num_tokens, top_k, num_experts) -> (num_tokens, num_experts)
expert_mask = expert_mask.sum(dim=1)
tokens_per_expert = expert_mask.sum(dim=0) # (num_experts,)
f = tokens_per_expert / (num_tokens * top_k)
# P_i: average gating probability for expert i
P = gate_probs.mean(dim=0) # (num_experts,)
# Auxiliary loss: N * sum(f_i * P_i)
# Achieves minimum value at uniform distribution
aux_loss = num_experts * (f * P).sum()
return aux_loss
class MoELayerWithAuxLoss(nn.Module):
"""Complete MoE layer implementation with auxiliary loss."""
def __init__(self, config: MoEConfig):
super().__init__()
self.config = config
self.experts = nn.ModuleList([
SwiGLUExpert(config.d_model, config.d_ff, config.dropout)
for _ in range(config.num_experts)
])
self.gating = TopKGating(
config.d_model, config.num_experts, config.top_k
)
self.aux_loss_weight = config.aux_loss_weight
def forward(self, x: torch.Tensor):
batch_size, seq_len, d_model = x.shape
x_flat = x.view(-1, d_model)
top_k_probs, top_k_indices, gate_probs = self.gating(x_flat)
# Compute auxiliary loss
aux_loss = self.aux_loss_weight * compute_load_balancing_loss(
gate_probs, top_k_indices,
self.config.num_experts, self.config.top_k,
)
# Compute expert outputs
output = torch.zeros_like(x_flat)
for k in range(self.config.top_k):
expert_indices = top_k_indices[:, k] # (num_tokens,)
expert_weights = top_k_probs[:, k] # (num_tokens,)
for i in range(self.config.num_experts):
mask = (expert_indices == i)
if mask.any():
expert_input = x_flat[mask]
expert_output = self.experts[i](expert_input)
output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
output = output.view(batch_size, seq_len, d_model)
return output, aux_loss
The auxiliary loss weight alpha is a very sensitive hyperparameter. Switch Transformer recommended alpha=0.01, but adjustment is needed depending on model scale and number of experts. If alpha is too large, language modeling quality degrades; if too small, the load balancing effect is negligible. ST-MoE (Zoph et al., 2022) proposed adding router z-loss to constrain the magnitude of gating logits themselves.
DeepSeek's Auxiliary-Loss-Free Strategy
DeepSeek-V3 (2024) proposed an innovative method to achieve load balancing without auxiliary loss. It adds a non-trainable bias term to each expert and uses a dynamic adjustment mechanism that lowers the bias of experts where tokens are excessively concentrated and raises the bias of underutilized experts during training. This approach completely eliminates the problem of auxiliary loss interfering with the main training objective, showing benefits in both training stability and final model quality.
Evolution from Switch Transformer to DeepSeek-V3
Switch Transformer (Fedus et al., 2022)
Switch Transformer is the key paper that simplified MoE routing to Top-1. It reduced the communication cost of existing Top-2 routing by half while ensuring training stability through appropriate capacity factor and auxiliary loss. It trained a 1.6T parameter model 4x faster than T5-XXL while achieving equivalent quality.
Key design decisions:
- Top-1 Routing: Minimized communication cost
- Capacity Factor: Dynamically adjusted expert buffer size to prevent token dropping
- Selective Precision: Mixed gating in FP32 and expert computation in BF16 for simultaneous stability and efficiency
GShard (Lepikhin et al., 2021)
Google's GShard established the distributed training pipeline for 600B parameter MoE models. It used Top-2 routing with Group-level balancing and presented a framework for efficient training across thousands of TPUs using the SPMD (Single Program Multiple Data) programming model.
Mixtral 8x7B (Jiang et al., 2024)
Mistral AI's Mixtral is a milestone that proved the practicality of open-source MoE models. With a structure selecting Top-2 from 8 experts, it uses an active 13B out of a total 47B parameters. It showed benchmark performance equal to or better than LLaMA-2 70B, while inference FLOPs are at 1/3 the level of 70B.
DeepSeek-V3 (DeepSeek, 2024)
DeepSeek-V3 achieved innovation across multiple design areas of MoE architecture.
- Fine-grained Expert Segmentation: Uses 256 small-scale experts with Top-8 selection. Increasing the number of experts while reducing their size deepens specialization, improving model quality.
- Shared Expert: One shared expert processes all tokens to handle common knowledge, while routed experts handle specialized knowledge.
- Auxiliary-Loss-Free Load Balancing: The bias-based dynamic balancing described above eliminates side effects of auxiliary loss.
- Multi-Token Prediction (MTP): Added a training objective that predicts multiple tokens at once to improve data efficiency.
- FP8 Training: Trained the 671B parameter model on 2048 H800 GPUs with FP8 precision, dramatically reducing costs.
Qwen3-235B-A22B (Alibaba, 2025)
Qwen3-235B-A22B is an MoE model that activates only 22B out of a total 235B parameters, selecting Top-8 from 128 experts. By converting the existing Qwen2.5 series Dense architecture to MoE, it achieved GPT-4o level performance at approximately one-tenth the inference cost.
MoE Model Comparison
| Model | Total Params | Active Params | Num Experts | Top-K | Shared Expert | Routing Strategy |
|---|---|---|---|---|---|---|
| Switch Transformer | 1.6T | ~100B | 128 | 1 | None | Learned Top-1 |
| GShard | 600B | ~20B | 2048 | 2 | None | Learned Top-2 |
| Mixtral 8x7B | 47B | 13B | 8 | 2 | None | Learned Top-2 |
| DeepSeek-V3 | 671B | 37B | 256+1 | 8 | 1 | Bias-adjusted |
| Qwen3-235B-A22B | 235B | 22B | 128 | 8 | Yes | Learned Top-K |
| DBRX | 132B | 36B | 16 | 4 | None | Learned Top-4 |
Training Stability and Troubleshooting
Diagnosing Expert Collapse
Expert collapse is the most common and critical problem in MoE training. When the gating network concentrates tokens on specific experts, the gradients of remaining experts converge to 0, stopping their learning, and this imbalance forms a self-reinforcing loop that worsens over time.
import logging
from collections import defaultdict
logger = logging.getLogger(__name__)
class ExpertUtilizationMonitor:
"""Expert utilization monitoring and collapse detection tool.
Tracks each expert's utilization during training
and detects early signs of collapse.
"""
def __init__(
self,
num_experts: int,
collapse_threshold: float = 0.01,
window_size: int = 100,
):
self.num_experts = num_experts
self.collapse_threshold = collapse_threshold
self.window_size = window_size
self.history: list[dict[int, float]] = []
def record(self, expert_counts: dict[int, int], total_tokens: int):
"""Record expert utilization per batch."""
utilization = {
i: expert_counts.get(i, 0) / max(total_tokens, 1)
for i in range(self.num_experts)
}
self.history.append(utilization)
if len(self.history) > self.window_size:
self.history = self.history[-self.window_size:]
def detect_collapse(self) -> list[int]:
"""Detect expert collapse. Returns experts with utilization below threshold."""
if len(self.history) < self.window_size // 2:
return []
collapsed = []
for expert_id in range(self.num_experts):
recent_util = [
h[expert_id] for h in self.history[-self.window_size:]
]
avg_util = sum(recent_util) / len(recent_util)
if avg_util < self.collapse_threshold:
collapsed.append(expert_id)
if collapsed:
logger.warning(
f"Expert collapse detected! "
f"Experts {collapsed} have utilization below "
f"{self.collapse_threshold:.2%}. "
f"Consider increasing aux_loss_weight or "
f"reinitializing collapsed experts."
)
return collapsed
def get_load_imbalance_ratio(self) -> float:
"""Calculate load imbalance ratio.
1.0 means perfect balance; higher values mean more imbalance."""
if not self.history:
return 0.0
latest = self.history[-1]
utils = list(latest.values())
max_util = max(utils) if utils else 0
min_util = min(utils) if utils else 0
avg_util = sum(utils) / len(utils) if utils else 0
if avg_util == 0:
return float("inf")
return max_util / avg_util
Training Instability Causes and Countermeasures
| Symptom | Cause | Countermeasure |
|---|---|---|
| Loss spike | Gating logit explosion | Add Router z-loss, keep gating in FP32 |
| Expert collapse | Insufficient aux loss, high LR | Increase aux loss weight, extend LR warm-up |
| Expert overlap | Initialization similarity | Orthogonal expert initialization, diversity regularization |
| Token dropping | Insufficient capacity factor | Increase capacity factor to 1.25-1.5 |
| Gating oscillation | Excessive learning rate | Separate gating LR to 0.1x of main LR |
Hyperparameter Guide for Stable Training
The key principles for training stability are as follows. First, gating network computations must be performed in FP32. In BF16 or FP16, numerical instability of softmax causes routing oscillation. Second, extend learning rate warmup 2-3x longer than for Dense models. Applying a high learning rate before gating stabilizes causes expert collapse. Third, set batch size as large as possible. With small batches, gating token distribution is sensitive to noise and becomes unstable.
Inference Optimization: Expert Parallelism, Offloading
MoE model inference poses fundamentally different challenges from Dense models. Although active parameters are few, all parameters must be kept accessible, so memory management and expert placement strategies are critical.
Expert Parallelism
Expert Parallelism (EP) is a strategy for distributing experts across multiple GPUs. When distributing N experts across P GPUs, each GPU stores only N/P experts. When a token is routed to a specific expert, All-to-All communication sends the token to the corresponding GPU, and the computation result is returned to the original GPU.
import torch
import torch.distributed as dist
from typing import Optional
class ExpertParallelRouter:
"""Token dispatch/gather implementation for Expert Parallelism.
Each GPU handles a subset of experts,
routing tokens via All-to-All communication.
"""
def __init__(
self,
num_experts: int,
ep_group: Optional[dist.ProcessGroup] = None,
):
self.num_experts = num_experts
self.ep_group = ep_group
self.ep_size = dist.get_world_size(ep_group) if ep_group else 1
self.ep_rank = dist.get_rank(ep_group) if ep_group else 0
self.experts_per_rank = num_experts // self.ep_size
def dispatch(
self,
tokens: torch.Tensor,
expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Dispatch tokens to responsible GPUs.
Args:
tokens: (num_tokens, d_model) input tokens
expert_indices: (num_tokens,) target expert index for each token
Returns:
dispatched_tokens: tokens this GPU should process
recv_counts: number of tokens received from each GPU
"""
# Calculate number of tokens to send to each GPU
send_counts = torch.zeros(
self.ep_size, dtype=torch.long, device=tokens.device
)
for rank in range(self.ep_size):
start_expert = rank * self.experts_per_rank
end_expert = start_expert + self.experts_per_rank
mask = (expert_indices >= start_expert) & (
expert_indices < end_expert
)
send_counts[rank] = mask.sum()
# Exchange receive counts via All-to-All
recv_counts = torch.zeros_like(send_counts)
dist.all_to_all_single(
recv_counts, send_counts, group=self.ep_group
)
# Sort tokens and All-to-All transfer
sorted_indices = torch.argsort(expert_indices)
sorted_tokens = tokens[sorted_indices]
send_splits = send_counts.tolist()
recv_splits = recv_counts.tolist()
dispatched_tokens = torch.zeros(
int(recv_counts.sum()), tokens.shape[1],
dtype=tokens.dtype, device=tokens.device,
)
dist.all_to_all_single(
dispatched_tokens, sorted_tokens,
output_split_sizes=recv_splits,
input_split_sizes=send_splits,
group=self.ep_group,
)
return dispatched_tokens, recv_counts
Expert Offloading
When GPU memory is insufficient, this strategy stores inactive experts in CPU memory or NVMe SSD and loads them to GPU only when needed. It is critically utilized in DeepSpeed-MoE and Mixtral's inference optimization.
The key to offloading is prefetching. By asynchronously loading experts that will be activated in the next layer onto the GPU while the current layer's expert computation is in progress, expert swap latency can be hidden. With PCIe 4.0 x16, approximately 32 GB/s bandwidth is available, allowing transfer of a single expert (several hundred MB) within a few milliseconds.
Inference Optimization Strategy Comparison
| Strategy | GPU Memory | Inference Latency | Throughput | Suitable Scenario |
|---|---|---|---|---|
| Full Model on GPU | Maximum | Minimum | Maximum | High-end multi-GPU server |
| Expert Parallelism | Distributed | Communication overhead | High | Multi-GPU cluster |
| CPU Offloading | Minimum | Loading latency | Medium | Limited GPU environment |
| NVMe Offloading | Minimum | High loading latency | Low | Single GPU environment |
| Speculative Expert Prefetch | Medium | Medium | High | Batch inference server |
Operations Checklist
The following is a list of items that must be checked when deploying MoE models to production.
Training Phase
- Verify gating precision: Confirm that the gating network's forward/backward operations are performed in FP32. BF16 gating may appear normal initially but can cause instability after tens of thousands of steps.
- Build load balancing metrics dashboard: Monitor per-expert token allocation, max/min utilization ratios, and auxiliary loss values in real-time.
- Checkpoint strategy: In expert parallel environments, checkpoints may be saved separately per GPU. Prepare a script to consolidate the full model in advance.
- Capacity factor tuning: If the token drop rate exceeds 1%, increase the capacity factor. Dropped tokens are only passed through residual connections, degrading quality.
- Set expert collapse alerts: Trigger alerts when a specific expert's utilization drops below 10% of the average, and reinitialize the affected expert if necessary.
Inference/Deployment Phase
- Memory profiling: Verify whether all parameters can fit in GPU memory; if not, choose EP or Offloading strategy.
- Batch size optimization: In MoE inference, batch size directly affects expert utilization efficiency. With small batches, only some experts are activated, leading to poor GPU utilization.
- KV Cache management: MoE models also require KV Cache management since Attention layers are identical to Dense models. Combining with PagedAttention (vLLM) is efficient.
- Routing consistency testing: Verify that the same experts are selected for the same input. Especially when mixing Tensor Parallelism and Expert Parallelism, numerical errors can cause different routing decisions.
- Fallback strategy: Implement fallback logic to substitute the next-ranked expert when a specific expert fails to load.
- A/B testing pipeline: Verify quality equivalence of MoE models versus Dense models in the serving environment.
Failure Cases and Recovery
Case 1: Quality Degradation from Expert Collapse
Symptom: After 30,000 training steps, benchmark scores suddenly decline. The loss itself decreases normally, but generation quality deteriorates.
Root Cause Analysis: Monitoring revealed that 2 out of 8 experts were processing over 60% of all tokens, while 3 experts had utilization under 2%. The auxiliary loss weight (alpha=0.001) was too low for effective balancing.
Recovery Procedure:
- Roll back to checkpoint just before expert collapse (20,000 steps)
- Increase auxiliary loss weight 10x from 0.001 to 0.01
- Reinitialize collapsed expert parameters with parameters from active experts
- Set gating network learning rate to 0.1x of main learning rate
- Monitor expert utilization until it becomes balanced (8-17% range based on 12.5% average) after retraining
Case 2: All-to-All Communication Bottleneck
Symptom: When training with Expert Parallelism across 64 GPUs, GPU utilization plummets to 40%. The profiler shows All-to-All communication accounting for 45% of total training time.
Root Cause Analysis: Network topology analysis revealed that expert placement did not consider network structure, causing excessive inter-node communication. The bandwidth difference between intra-node GPU communication (NVLink, 900 GB/s) and inter-node communication (InfiniBand, 400 Gb/s) was over 20x.
Recovery Procedure:
- Switch to Hierarchical All-to-All: separate intra-node and inter-node communication into two stages
- Rearrange expert placement to be topology-aware: place frequently co-activated experts on the same node
- Communication-computation overlap: pipeline expert computation with token dispatch for the next batch
Case 3: Expert Loading Latency During Inference
Symptom: When serving Mixtral 8x7B with CPU Offloading on a single GPU (24GB), time to first token (TTFT) exceeds 5 seconds.
Root Cause Analysis: Loading 2 experts from CPU to GPU at each layer takes 100-200ms, and processing 32 layers sequentially results in cumulative latency of 3.2-6.4 seconds.
Recovery Procedure:
- Implement expert prefetching: pre-compute gating scores for the next layer during current layer processing and asynchronously load required experts
- Hot expert caching: keep the top 2-3 most frequently activated experts resident on GPU
- Expert weight quantization: reduce expert size by 75% with INT4 quantization to shorten transfer time
- When PCIe bandwidth is the bottleneck, optimize CPU-GPU transfer using pinned memory
Case 4: Loss Spike During Training
Symptom: During large-scale MoE model (over 100B) training, loss spikes repeatedly every few thousand steps. Recovery occurs after each spike, but training time is wasted.
Root Cause Analysis: The softmax input logits of the gating network intermittently take very large values, causing numerical instability. Especially during BF16 training, the range of gating logits is narrower than FP32, making overflow more likely.
Recovery Procedure:
- Add Router z-loss to directly constrain the magnitude of gating logits.
def router_z_loss(gate_logits: torch.Tensor) -> torch.Tensor:
"""ST-MoE style Router z-loss.
Constrains the magnitude of gating logits to improve numerical stability.
Args:
gate_logits: (num_tokens, num_experts) gating logits
Returns:
z_loss scalar
"""
log_z = torch.logsumexp(gate_logits, dim=-1) # (num_tokens,)
z_loss = (log_z ** 2).mean()
return z_loss
- Force gating computations to FP32 to ensure numerical stability.
- Apply gradient clipping separately to the gating network (max_norm=1.0).
- Extend learning rate warmup period to 5-10% of total training.
References
-
Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR, 23(120), 1-39. https://arxiv.org/abs/2101.03961
-
DeepSeek-AI. (2024). DeepSeek-V3 Technical Report. https://arxiv.org/abs/2401.06066
-
Cai, W. et al. (2024). A Survey on Mixture of Experts. https://arxiv.org/abs/2407.10671
-
FriendliAI. (2024). MoE Models Comparison: Architectures and Performance. https://friendli.ai/blog/moe-models-comparison
-
Zilliz. (2024). What is Mixture of Experts? A Complete Guide. https://zilliz.com/learn/what-is-mixture-of-experts
-
Wikipedia. Mixture of Experts. https://en.wikipedia.org/wiki/Mixture_of_experts
-
Shazeer, N. et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017. https://arxiv.org/abs/1701.06538
-
Jiang, A. Q. et al. (2024). Mixtral of Experts. https://arxiv.org/abs/2401.04088
Quiz
Q1: What is the main topic covered in "Deep Dive into Sparse Mixture of Experts (MoE)
Architecture: From Design Principles to DeepSeek-V3 and Qwen3"?
Analyzing the mathematical principles, routing strategies, and load balancing techniques of Sparse MoE architecture, covering the design choices and practical training/inference optimization of modern MoE models from Switch Transformer to DeepSeek-V3 and Qwen3-235B.
Q2: What is MoE Basic Structure and Mathematical Principles?
Mathematical Definition of Sparse Activation An MoE layer consists of N expert networks E_1, E_2,
..., E_N and a gating network G.
Q3: Explain the core concept of Routing Strategies: Top-k, Expert Choice, Hash Routing.
Routing is both the core and the most challenging design problem of MoE. The strategy for deciding which experts to activate significantly affects model quality, training stability, and inference efficiency.
Q4: What are the key aspects of Load Balancing and Auxiliary Loss?
The most serious problem in MoE training is expert collapse. This is a phenomenon where the gating
network concentrates tokens on only a few experts, while the remaining experts lose training
opportunities and effectively become dead parameters.
Q5: How does Evolution from Switch Transformer to DeepSeek-V3 work?
Switch Transformer (Fedus et al., 2022) Switch Transformer is the key paper that simplified MoE
routing to Top-1. It reduced the communication cost of existing Top-2 routing by half while
ensuring training stability through appropriate capacity factor and auxiliary loss.