
Mixture of Experts (MoE) Architecture: A Complete Analysis


1. What is MoE?

Mixture of Experts (MoE) is an architecture that improves computational efficiency by activating only a subset of the model's total parameters for each input. Dense models use all of their parameters for every input. An MoE model instead uses a learned router to send each token to a small number of experts, typically the ones with the highest router scores.

Dense vs Sparse Models

  • Dense Model: All parameters are activated for every input (e.g., LLaMA, GPT-3)
  • Sparse MoE: Only a fraction of parameters are activated (e.g., Mixtral, DeepSeek-V3)

The key advantage is that MoE decouples model capacity from per-token compute. Compute scales with the active parameters, although all parameters must still be held in memory. Mixtral 8x7B has 46.7B total parameters, but only about 12.9B are active for any given token.
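You can reproduce the 46.7B and 12.9B figures with a short back-of-the-envelope count. The sketch below assumes the hyperparameters from Mixtral 8x7B's public Hugging Face config: hidden size 4096, expert FFN size 14336, 32 layers, 32 query heads and 8 key/value heads of size 128, a 32000-token vocabulary, and separate input and output embeddings. The variable names are ours.

```python
# Back-of-the-envelope parameter count for Mixtral 8x7B.
# Hyperparameters are assumed from the public Hugging Face config.
D_MODEL, D_FF, N_LAYERS = 4096, 14336, 32
N_HEADS, N_KV_HEADS, HEAD_DIM = 32, 8, 128
VOCAB, N_EXPERTS, TOP_K = 32000, 8, 2

expert = 3 * D_MODEL * D_FF                           # w1, w2, w3 of one SwiGLU expert
attention = (2 * N_HEADS * HEAD_DIM * D_MODEL         # q_proj and o_proj
             + 2 * N_KV_HEADS * HEAD_DIM * D_MODEL)   # k_proj and v_proj (grouped-query)
router = D_MODEL * N_EXPERTS                          # gate: d_model -> num_experts
norms = 2 * D_MODEL                                   # two RMSNorm weight vectors per layer
embeddings = 2 * VOCAB * D_MODEL + D_MODEL            # input embedding, LM head, final norm

def count(experts_used: int) -> int:
    """Parameters touched when `experts_used` experts per layer are counted."""
    per_layer = experts_used * expert + attention + router + norms
    return N_LAYERS * per_layer + embeddings

total = count(N_EXPERTS)   # every expert must be stored
active = count(TOP_K)      # only the top-2 experts run for a given token
print(f"total  ≈ {total / 1e9:.1f}B")   # total  ≈ 46.7B
print(f"active ≈ {active / 1e9:.1f}B")  # active ≈ 12.9B
```

About 27.6% of the weights are used per token. All 46.7B must still be held in memory, so MoE saves compute rather than memory.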

2. Core Components of MoE Architecture

Expert Network

Each Expert is an independent FFN (Feed-Forward Network):

import torch
import torch.nn as nn

class Expert(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # SwiGLU gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

Router (Gating Network)

The Router determines which Expert each token is sent to:

class TopKRouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # x shape: (batch, seq_len, d_model)
        logits = self.gate(x)  # (batch, seq_len, num_experts)
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        top_k_weights = torch.softmax(top_k_logits, dim=-1)
        return top_k_weights, top_k_indices
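Here is the same computation for a single token in plain Python, as a concrete check of the router logic. It picks the two largest logits, then applies softmax over just those two. `top_k_gate` is a hypothetical helper that mirrors `TopKRouter.forward`, and the logits are made up.

```python
import math

def top_k_gate(logits, k):
    """Pure-Python analogue of TopKRouter.forward for one token."""
    # Indices of the k largest logits, highest first
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax restricted to the selected logits (subtract the max for stability)
    m = max(logits[i] for i in top)
    exps = [math.exp(logits[i] - m) for i in top]
    z = sum(exps)
    return [e / z for e in exps], top

weights, experts = top_k_gate([2.0, 1.0, 0.1, 3.0], k=2)
print(experts)                          # [3, 0]
print([round(w, 4) for w in weights])   # [0.7311, 0.2689]
```

Because the softmax is taken after selection, unselected Experts get exactly zero weight and the selected weights always sum to 1. Switch Transformer takes the other approach: it applies softmax over all Experts first and uses the top-1 probability directly, without renormalizing.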

Full MoE Layer Implementation

class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int,
                 num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])
        self.router = TopKRouter(d_model, num_experts, top_k)
        self.num_experts = num_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape
        weights, indices = self.router(x)

        # Reshape for expert processing
        flat_x = x.view(-1, d_model)
        flat_weights = weights.view(-1, weights.shape[-1])
        flat_indices = indices.view(-1, indices.shape[-1])

        output = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Find tokens routed to this expert
            mask = (flat_indices == i).any(dim=-1)
            if mask.any():
                expert_input = flat_x[mask]
                expert_output = expert(expert_input)
                # Weight by router probability
                idx = (flat_indices[mask] == i).float()
                w = (flat_weights[mask] * idx).sum(dim=-1, keepdim=True)
                output[mask] += w * expert_output

        return output.view(batch_size, seq_len, d_model)

3. Major MoE Model Analysis

Mixtral 8x7B (Mistral AI)

  • 8 Experts, Top-2 routing
  • 46.7B total parameters, 12.9B active
  • Attention layers are shared; only FFN is split into Experts

DeepSeek-V3 MoE

DeepSeek-V3 employs a more sophisticated MoE design:

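class DeepSeekMoE(nn.Module):
    """Simplified DeepSeek-V3-style layer: Shared Experts + fine-grained Routed Experts.
    Reuses the softmax TopKRouter above; the actual model scores Experts with a
    sigmoid and adds a load-balancing bias (see Section 5)."""
    def __init__(self, d_model, d_expert, num_shared=1,
                 num_routed=256, top_k=8):
        super().__init__()
        # Shared Experts that all tokens pass through
        self.shared_experts = nn.ModuleList([
            Expert(d_model, d_expert) for _ in range(num_shared)
        ])
        # Routed Experts selected per token (small FFNs, same width as the shared one)
        self.routed_experts = nn.ModuleList([
            Expert(d_model, d_expert) for _ in range(num_routed)
        ])
        self.router = TopKRouter(d_model, num_routed, top_k)

    def _route_tokens(self, x, weights, indices):
        # Same dispatch-and-combine loop as MoELayer.forward, over the Routed Experts
        flat_x = x.reshape(-1, x.shape[-1])
        flat_w = weights.reshape(-1, weights.shape[-1])
        flat_idx = indices.reshape(-1, indices.shape[-1])
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.routed_experts):
            hit = flat_idx == i                    # (num_tokens, top_k)
            mask = hit.any(dim=-1)
            if mask.any():
                w = (flat_w[mask] * hit[mask]).sum(dim=-1, keepdim=True)
                out[mask] += w * expert(flat_x[mask])
        return out.view_as(x)

    def forward(self, x):
        # Shared Expert output
        shared_out = sum(e(x) for e in self.shared_experts)
        # Routed Expert output
        weights, indices = self.router(x)
        routed_out = self._route_tokens(x, weights, indices)
        return shared_out + routed_out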

  • 1 Shared Expert + 256 Routed Experts (Top-8 routed Experts per token)
  • 671B total parameters, 37B active
  • Auxiliary-loss-free load balancing (described in Section 5)
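The TopKRouter used above is a simplification. The DeepSeek-V3 technical report describes a different gate:

  • Each routed Expert is scored with a sigmoid.
  • The Top-K scores are kept and normalized among themselves.
  • The shared Expert's output is added without a gate weight.

The pure-Python sketch below follows that description for a single token. The function names, logits, and scalar "experts" are illustrative assumptions, not DeepSeek's code.

```python
import math

def sigmoid_top_k_gate(logits, k):
    """Sketch of DeepSeek-V3-style gating for one token, per the technical report."""
    scores = [1.0 / (1.0 + math.exp(-u)) for u in logits]   # per-Expert affinity in (0, 1)
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    total = sum(scores[i] for i in top)
    # Normalize among the selected Experts only; unselected Experts get weight 0
    return {i: scores[i] / total for i in top}

def moe_output(x, shared_experts, routed_experts, logits, k):
    # Shared Experts always run; a routed Expert runs only if it is selected
    out = sum(f(x) for f in shared_experts)
    for i, g in sigmoid_top_k_gate(logits, k).items():
        out += g * routed_experts[i](x)
    return out

routed = [lambda x, c=c: c * x for c in (1.0, 2.0, 3.0, 4.0)]  # toy scalar "experts"
logits = [0.0, 2.0, -1.0, 1.0]
print(sigmoid_top_k_gate(logits, k=2))                   # Experts 1 and 3 are selected
print(moe_output(1.0, [lambda x: x], routed, logits, k=2))
```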

Model Comparison

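| Model | Total Params | Active Params | Experts (routed + shared) | Top-K (routed + shared) |
|---|---|---|---|---|
| Mixtral 8x7B | 46.7B | 12.9B | 8 | 2 |
| Mixtral 8x22B | 141B | 39B | 8 | 2 |
| DeepSeek-V3 | 671B | 37B | 256 + 1 | 8 + 1 |
| Qwen1.5-MoE-A2.7B | 14.3B | 2.7B | 60 + 4 | 4 + 4 |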
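Dividing the two parameter columns gives the fraction of the model each token actually touches. The sketch uses the Mixtral and DeepSeek-V3 rows of the table. It is only a ratio of parameter counts, not a FLOP or latency estimate. The result shows how much smaller the active fraction is in DeepSeek-V3's fine-grained design (about 5.5%) than in Mixtral (about 28%).

```python
# (total, active) parameter counts in billions, taken from the table
models = {
    "Mixtral 8x7B": (46.7, 12.9),
    "Mixtral 8x22B": (141.0, 39.0),
    "DeepSeek-V3": (671.0, 37.0),
}
ratios = {name: active / total for name, (total, active) in models.items()}
for name, ratio in ratios.items():
    print(f"{name:<14} {ratio:.1%} of parameters active per token")
```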

4. Routing Strategies

Token Choice vs Expert Choice

# Token Choice: Each token selects its Experts
def token_choice_routing(logits, top_k=2):
    top_k_vals, top_k_idx = logits.topk(top_k, dim=-1)
    weights = torch.softmax(top_k_vals, dim=-1)
    return weights, top_k_idx

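# Expert Choice: Each Expert selects its tokens
def expert_choice_routing(logits, capacity_factor=1.25):
    # logits: (num_tokens, num_experts)
    num_tokens, num_experts = logits.shape
    # Tokens per Expert: at least 1, at most every token
    capacity = min(num_tokens, max(1, int(num_tokens * capacity_factor / num_experts)))

    # Token-to-Expert affinities, normalized over Experts
    scores = torch.softmax(logits, dim=-1)
    expert_scores = scores.T  # (num_experts, num_tokens)
    # Per Expert: gating weights and indices of the tokens it will process
    top_k_vals, top_k_idx = expert_scores.topk(capacity, dim=-1)
    return top_k_vals, top_k_idx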
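A toy example shows the trade-off between the two strategies. The pure-Python sketch below reproduces the selection logic of the two functions above without tensors. It uses hypothetical helpers and made-up scores for 4 tokens and 2 Experts, with top-1 for Token Choice and a capacity of 2 for Expert Choice.

```python
def token_choice_top1(scores):
    """Each token picks its single highest-scoring Expert."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]

def expert_choice(scores, capacity):
    """Each Expert picks its `capacity` highest-scoring tokens."""
    picks = []
    for e in range(len(scores[0])):
        ranked = sorted(range(len(scores)), key=lambda t: scores[t][e], reverse=True)
        picks.append(sorted(ranked[:capacity]))
    return picks

# Toy router scores: rows are tokens, columns are Experts
scores = [[0.9, 0.1],
          [0.8, 0.2],
          [0.7, 0.1],
          [0.4, 0.6]]

choice = token_choice_top1(scores)
loads = [choice.count(e) for e in range(2)]
print(choice, loads)   # [0, 0, 0, 1] [3, 1]

picks = expert_choice(scores, capacity=2)   # 4 tokens * 1.0 / 2 Experts
covered = {t for p in picks for t in p}
print(picks, sorted(set(range(4)) - covered))   # [[0, 1], [1, 3]] [2]
```

Token Choice sends three of the four tokens to Expert 0. Expert Choice keeps both Experts at exactly two tokens, but token 1 is processed twice and token 2 not at all.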

5. Load Balancing

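Load imbalance across Experts is a core challenge in MoE. Left unchecked, the router tends to favor a few Experts. Those Experts become compute bottlenecks, or drop tokens once they exceed their capacity, while the rest receive little training signal. A common remedy is an auxiliary loss that penalizes uneven routing: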

def load_balancing_loss(router_logits, top_k_indices, num_experts):
    """Auxiliary load balancing loss (Switch Transformer style).
    router_logits: (num_tokens, num_experts); top_k_indices: (num_tokens, top_k)"""
    # f_i: fraction of tokens routed to each Expert (sums to top_k over Experts)
    mask = torch.zeros_like(router_logits)
    mask.scatter_(-1, top_k_indices, 1.0)
    tokens_per_expert = mask.mean(dim=0)  # (num_experts,)

    # P_i: average routing probability per Expert
    router_probs = torch.softmax(router_logits, dim=-1)
    router_prob_per_expert = router_probs.mean(dim=0)

    # N * sum_i f_i * P_i: equals top_k when both are uniform, and grows as
    # tokens and probability mass concentrate on the same Experts
    loss = num_experts * (tokens_per_expert * router_prob_per_expert).sum()
    return loss
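The sketch below evaluates the same N * sum_i f_i * P_i formula in plain Python, to show the scale of this loss. It uses top-1 routing over 2 Experts, where f_i is the fraction of tokens sent to Expert i and P_i is its mean router probability. The probabilities are made up. The balanced case gives 1.0. Sending every token to a confidently preferred Expert raises the loss to 1.8.

```python
def switch_aux_loss(probs, assignment):
    """probs: per-token router probabilities; assignment: chosen Expert per token (top-1)."""
    n_tokens, n_experts = len(probs), len(probs[0])
    f = [assignment.count(e) / n_tokens for e in range(n_experts)]          # token fraction
    p = [sum(row[e] for row in probs) / n_tokens for e in range(n_experts)]  # mean probability
    return n_experts * sum(fi * pi for fi, pi in zip(f, p))

balanced = switch_aux_loss([[0.5, 0.5], [0.5, 0.5]], [0, 1])
skewed = switch_aux_loss([[0.9, 0.1], [0.9, 0.1]], [0, 0])
print(balanced, skewed)  # 1.0 1.8
```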

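DeepSeek-V3's auxiliary-loss-free approach instead adds a per-Expert bias term to the routing scores. The bias is used only to decide which Experts are selected; the gating weights still come from the unbiased scores. After each training step, the bias is lowered for overloaded Experts and raised for underutilized ones. This balances the load without relying on an auxiliary loss term whose gradients can hurt model quality.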
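Below is a minimal sketch of the bias mechanism, assuming the rule as described in the DeepSeek-V3 report. The bias affects only Expert selection, and after each step every Expert's bias moves by a fixed step gamma toward balance. Comparing each Expert's load with the mean load, the value gamma = 0.1, and the scores are our illustrative choices. The helpers are hypothetical.

```python
def select_experts(scores, bias, k):
    """Pick top-k Experts by score + bias; gate weights use the raw scores only."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i] + bias[i], reverse=True)
    chosen = ranked[:k]
    total = sum(scores[i] for i in chosen)
    return chosen, [scores[i] / total for i in chosen]

def update_bias(bias, loads, gamma):
    """After a step: move overloaded Experts down and underloaded Experts up by gamma."""
    mean_load = sum(loads) / len(loads)
    new_bias = []
    for b, load in zip(bias, loads):
        if load > mean_load:
            b -= gamma
        elif load < mean_load:
            b += gamma
        new_bias.append(b)
    return new_bias

bias = update_bias([0.0] * 4, loads=[5, 1, 1, 1], gamma=0.1)
print(bias)                                        # [-0.1, 0.1, 0.1, 0.1]
scores = [0.9, 0.85, 0.8, 0.2]
print(select_experts(scores, [0.0] * 4, k=2)[0])   # [0, 1] before the update
print(select_experts(scores, bias, k=2)[0])        # [1, 2] after the update
```

Before the update, Experts 0 and 1 win. After Expert 0 is marked overloaded, its bias drops and Experts 1 and 2 are chosen instead. The gating weights are still computed from the unbiased scores.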

6. Inference Optimization

# Expert Parallelism: Distribute Experts across multiple GPUs
# GPU 0: Expert 0-3, GPU 1: Expert 4-7
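import torch.distributed as dist

class ExpertParallel(nn.Module):
    """Inference-only sketch: top-1 routing, process group already initialized."""
    def __init__(self, d_model, d_ff, experts_per_gpu, rank, world_size):
        super().__init__()
        self.local_experts = nn.ModuleList(
            [Expert(d_model, d_ff) for _ in range(experts_per_gpu)])
        self.experts_per_gpu, self.rank, self.world_size = experts_per_gpu, rank, world_size

    def forward(self, x, indices):
        # x: (num_tokens, d_model); indices: (num_tokens,) global Expert id per token
        dest = indices // self.experts_per_gpu             # rank that owns each Expert
        order = torch.argsort(dest)                        # group tokens by destination
        send_x, send_ids = x[order], indices[order]
        send_counts = torch.bincount(dest, minlength=self.world_size)
        recv_counts = torch.empty_like(send_counts)
        dist.all_to_all_single(recv_counts, send_counts)   # exchange token counts
        in_sp, out_sp = send_counts.tolist(), recv_counts.tolist()
        # Redistribute tokens and their Expert ids via All-to-All communication
        recv_x = x.new_empty(sum(out_sp), x.shape[-1])
        dist.all_to_all_single(recv_x, send_x, out_sp, in_sp)
        recv_ids = indices.new_empty(sum(out_sp))
        dist.all_to_all_single(recv_ids, send_ids, out_sp, in_sp)
        # Process with local Experts
        local_out = torch.empty_like(recv_x)
        for j, expert in enumerate(self.local_experts):
            mask = recv_ids == self.rank * self.experts_per_gpu + j
            local_out[mask] = expert(recv_x[mask])
        # Send results back (splits reversed) and restore the original token order
        back = torch.empty_like(send_x)
        dist.all_to_all_single(back, local_out, in_sp, out_sp)
        output = torch.empty_like(x)
        output[order] = back
        return output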
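You can trace the bookkeeping behind the All-to-All by hand. The pure-Python sketch below assumes top-1 routing and the contiguous layout from the comment above, where rank r owns Experts r*experts_per_gpu through (r+1)*experts_per_gpu - 1. `plan_dispatch` is a hypothetical helper, not part of any library.

```python
def plan_dispatch(expert_ids, experts_per_gpu, world_size):
    """For one rank: destination rank per token, tokens per rank, and send order."""
    dest = [e // experts_per_gpu for e in expert_ids]               # owner rank of each chosen Expert
    send_counts = [dest.count(r) for r in range(world_size)]        # tokens to ship to each rank
    order = sorted(range(len(expert_ids)), key=lambda t: dest[t])   # group tokens by destination
    return dest, send_counts, order

dest, counts, order = plan_dispatch([0, 5, 3, 7, 4, 1], experts_per_gpu=4, world_size=2)
print(dest)    # [0, 1, 0, 1, 1, 0]
print(counts)  # [3, 3]
print(order)   # [0, 2, 5, 1, 3, 4]
```

Each rank computes its send counts this way. Exchanging the counts first tells every rank how many tokens to expect before the tokens themselves are sent.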

7. Quiz

Q1: Approximately how many parameters does a single token use during inference in Mixtral 8x7B?

Approximately 12.9B parameters. Only the Top-2 of the 8 Experts run in each layer. A token therefore uses the FFN parameters of 2 Experts plus the parameters every token shares: attention, router, norms, and embeddings. That is roughly 28% of the total 46.7B.

Q2: What is the core idea behind DeepSeek-V3's Auxiliary-loss-free load balancing?

Traditional MoE adds an auxiliary loss for load balancing, and a loss strong enough to enforce balance can degrade model performance. DeepSeek-V3 instead adds a per-Expert bias to the routing scores that is used only for Expert selection. After each training step, the bias is lowered for Experts receiving too many tokens and raised for underutilized ones. This keeps the load balanced without relying on a balancing loss; V3 keeps only a complementary sequence-wise loss with a very small weight.

Q3: What are the differences and trade-offs between Token Choice and Expert Choice routing?
  • Token Choice: Each token selects its Top-K Experts. Simple to implement but can lead to load imbalance where tokens concentrate on certain Experts.
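  • Expert Choice: Each Expert selects the tokens it processes, up to a fixed capacity. This guarantees perfectly balanced load, but a token may be picked by several Experts or by none. A token that no Expert picks gets no Expert output and passes through only via the residual connection.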

In practice, Token Choice combined with a load balancing loss is the most commonly used approach.