Building a Mixture of Experts (MoE) Model from Scratch in PyTorch
Mixture of Experts (MoE) has become one of the most influential architectural patterns in modern large language models. Models such as Mixtral 8x7B and Grok-1 use MoE, and Google has described Gemini 1.5 as an MoE model; the approach lets parameter counts grow while the compute spent on each token stays roughly fixed. In this tutorial, you will build a sparse MoE layer from scratch in PyTorch, understand the routing mechanics, and see how it fits into a complete transformer block.
What Is Mixture of Experts?
Instead of using a single dense feed-forward network (FFN) in each transformer layer, MoE splits the FFN into multiple parallel "expert" sub-networks. A learned gating network (router) decides which experts process each token. Crucially, only the top-k experts are activated per token — the rest remain dormant. This means you can have billions of parameters in the model while each token only touches a fraction of them.
The mathematical formulation is straightforward. Given a token embedding x and experts E1, ..., En:
y = Σ (G(x)_i · E_i(x)) for i in Top-K(G(x))
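Where G(x) is the router output (a softmax over expert logits) and Top-K selects the k highest-probability experts. In the implementation below, the k selected weights are additionally rescaled so that they sum to 1.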
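To make the formula concrete, here is a minimal sketch in plain Python for a single token. The scalar "experts", the logits, and the helper names (`softmax`, `moe_output`) are made up for illustration; the `renormalize` flag shows the optional rescaling of the selected weights.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def moe_output(x, logits, experts, k, renormalize=False):
    """y = sum over the top-k experts of G(x)_i * E_i(x)."""
    probs = softmax(logits)
    # Indices of the k largest gate probabilities.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in top]
    if renormalize:
        # Rescale the selected weights so they sum to 1.
        s = sum(weights)
        weights = [w / s for w in weights]
    y = sum(w * experts[i](x) for w, i in zip(weights, top))
    return y, top, weights

# Hypothetical scalar experts so the arithmetic is easy to follow.
experts = [lambda x: 1.0 * x, lambda x: 2.0 * x,
           lambda x: 3.0 * x, lambda x: 4.0 * x]
logits = [2.0, 0.0, 1.0, -1.0]  # made-up router logits for one token

y, top, weights = moe_output(1.0, logits, experts, k=2)
y_norm, _, weights_norm = moe_output(1.0, logits, experts, k=2, renormalize=True)
print(top, weights, y, y_norm)
```

Experts 0 and 2 have the largest logits, so only those two are evaluated; the other two contribute nothing and cost nothing.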
Step 1: Implement the Expert Network
Each expert is a gated feed-forward network with a SwiGLU activation, built from three linear projections. Llama and Mistral use the same structure:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation
        return self.down_proj(
            F.silu(self.gate_proj(x)) * self.up_proj(x)
        )
SwiGLU-style gated FFNs have been reported to outperform standard ReLU or GELU FFNs at comparable parameter budgets, which is why many recent LLMs adopt them. The three-projection pattern (gate, up, down) matches what you will find in Llama's FFN implementation.
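The third projection has a cost worth spelling out. The sketch below counts the weights of one bias-free expert and compares them with a conventional two-matrix FFN of the same width; the dimensions are illustrative values, not something the tutorial prescribes.

```python
def swiglu_expert_params(d_model: int, d_ff: int) -> int:
    # gate_proj + up_proj (d_model x d_ff each) + down_proj (d_ff x d_model)
    return 3 * d_model * d_ff

def dense_ffn_params(d_model: int, d_ff: int) -> int:
    # A conventional FFN has only an up and a down projection.
    return 2 * d_model * d_ff

d_model, d_ff = 4096, 14336  # illustrative sizes (assumption)
per_expert = swiglu_expert_params(d_model, d_ff)
ratio = per_expert / dense_ffn_params(d_model, d_ff)
print(f"{per_expert:,} weights per expert, {ratio:.1f}x a two-matrix FFN")
```

Because of this 1.5x factor, gated FFNs are often given a smaller d_ff (roughly two thirds of the usual width) when the goal is to match the parameter count of a non-gated baseline.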
Step 2: Implement the Top-K Router with Auxiliary Load Balancing
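The router is where MoE gets interesting. A naive softmax router can collapse: most tokens end up routed to the same few experts, which defeats the purpose. The usual fix is an auxiliary load-balancing loss that penalizes imbalanced routing. The idea dates back to the sparsely-gated MoE work of Shazeer et al. (2017); the simplified form used here follows GShard and the Switch Transformer: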
class MoERouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        # Shape: (batch, seq_len, num_experts)
        router_logits = self.router(x)
        # Softmax over experts
        probs = F.softmax(router_logits, dim=-1)
        # Get top-k experts and their weights
        topk_weights, topk_indices = torch.topk(probs, self.top_k, dim=-1)
        # Renormalize the selected probabilities so they sum to 1.
        # (A second softmax here would flatten the weights toward uniform.)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # Create sparse multi-hot mask for expert selection
        # Shape: (batch, seq_len, num_experts)
        mask = F.one_hot(topk_indices, num_classes=self.num_experts).sum(dim=-2)
        # Auxiliary load-balancing loss
        # Importance = mean probability per expert (sums to 1)
        # Load = fraction of tokens assigned to each expert (sums to top_k)
        importance = probs.mean(dim=(0, 1))  # (num_experts,)
        load = mask.float().mean(dim=(0, 1))  # (num_experts,)
        aux_loss = self.num_experts * (importance * load).sum()
        return topk_indices, topk_weights, aux_loss
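The auxiliary loss is added to the total training loss with a small coefficient (0.01 is a common choice). With the normalization above it equals top_k when routing is perfectly uniform and grows as routing concentrates on a few experts, so minimizing it nudges the router toward even utilization without any hard assignment constraint.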
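A quick way to build intuition for this loss is to evaluate it on two hand-made extremes. The sketch below reimplements the same formula as a standalone function (the name `load_balancing_loss` and the toy tensors are assumptions for illustration), using flattened tokens and top-1 routing so the arithmetic stays readable.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(probs, topk_indices, num_experts):
    # probs: (tokens, num_experts), topk_indices: (tokens, top_k)
    mask = F.one_hot(topk_indices, num_classes=num_experts).sum(dim=-2)
    importance = probs.mean(dim=0)   # mean router probability per expert
    load = mask.float().mean(dim=0)  # fraction of tokens sent to each expert
    return num_experts * (importance * load).sum()

num_experts, tokens = 4, 8

# Balanced: uniform probabilities, tokens assigned round-robin.
balanced_probs = torch.full((tokens, num_experts), 0.25)
balanced_idx = (torch.arange(tokens) % num_experts).unsqueeze(-1)

# Collapsed: every token puts all its probability on expert 0.
collapsed_probs = torch.zeros(tokens, num_experts)
collapsed_probs[:, 0] = 1.0
collapsed_idx = torch.zeros(tokens, 1, dtype=torch.long)

balanced = load_balancing_loss(balanced_probs, balanced_idx, num_experts)
collapsed = load_balancing_loss(collapsed_probs, collapsed_idx, num_experts)
print(balanced.item(), collapsed.item())
```

The collapsed case scores num_experts times higher than the balanced one, which is the gradient signal that pushes the router away from collapse.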
Step 3: Assemble the MoE Layer
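Now combine the router and experts. The key idea is to loop over experts rather than tokens: each expert processes all of its assigned tokens in one batched call. A per-token loop would destroy GPU utilization; the remaining loop over experts is acceptable for a tutorial, and production kernels remove it as well: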
class SparseMoE(nn.Module):
    def __init__(self, d_model: int, d_ff: int,
                 num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])
        self.router = MoERouter(d_model, num_experts, top_k)

    def forward(self, x: torch.Tensor):
        # x shape: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        # Route tokens to experts
        topk_indices, topk_weights, aux_loss = self.router(x)
        # Flatten for processing; reshape also handles non-contiguous inputs
        x_flat = x.reshape(-1, d_model)  # (batch*seq_len, d_model)
        output = torch.zeros_like(x_flat)
        # Process each expert with its assigned tokens
        for expert_id in range(self.num_experts):
            # Find tokens routed to this expert
            token_mask = (topk_indices == expert_id)  # (batch, seq_len, top_k)
            # Combine weights across top-k slots
            token_weight = (topk_weights * token_mask).sum(dim=-1)  # (batch, seq_len)
            # Select tokens for this expert
            indices = token_mask.any(dim=-1).reshape(-1)  # (batch*seq_len,)
            if not indices.any():
                continue  # no token chose this expert in this batch
            selected = x_flat[indices]
            # Compute expert output and scale by routing weights
            expert_output = self.experts[expert_id](selected)
            weights = token_weight.reshape(-1)[indices].unsqueeze(-1)
            # Accumulate weighted outputs
            output[indices] += expert_output * weights
        return output.view(batch, seq_len, d_model), aux_loss
This implementation uses token-choice routing, in which each token independently selects its top-k experts. The alternative, expert-choice routing, has each expert select its top-k tokens. Token-choice is simpler and is what Mixtral uses; expert-choice (introduced by Zhou et al., 2022) balances load by construction, but a token may be picked by no expert at all, and the extra bookkeeping adds complexity.
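For comparison, the core of expert-choice routing can be sketched in a few lines. This is a simplified illustration under stated assumptions (flattened tokens, a fixed per-expert `capacity`, and a hypothetical function name `expert_choice_route`), not a reproduction of any particular paper's implementation.

```python
import torch
import torch.nn.functional as F

def expert_choice_route(router_logits: torch.Tensor, capacity: int):
    """router_logits: (num_tokens, num_experts). Each expert picks `capacity` tokens."""
    probs = F.softmax(router_logits, dim=-1)  # token-to-expert affinities
    # Transpose so each row is one expert, then take the top tokens per row.
    weights, token_idx = torch.topk(probs.t(), capacity, dim=-1)
    return weights, token_idx  # both (num_experts, capacity)

torch.manual_seed(0)
num_tokens, num_experts = 16, 4
logits = torch.randn(num_tokens, num_experts)

# capacity = num_tokens * k / num_experts with an average of k = 2 experts per token
weights, token_idx = expert_choice_route(logits, capacity=8)

# How many experts picked each token; some tokens may be picked by none.
picks_per_token = torch.bincount(token_idx.flatten(), minlength=num_tokens)
print(token_idx.shape, picks_per_token.tolist())
```

Every expert processes exactly `capacity` tokens, so load is balanced without an auxiliary loss, while the number of experts per token varies.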
Step 4: Integrate into a Transformer Block
Replace the standard FFN in your transformer with the MoE layer:
class MoETransformerBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.moe = SparseMoE(d_model, d_ff, num_experts, top_k)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor):
        # Self-attention with residual
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        # MoE with residual
        moe_out, aux_loss = self.moe(x)
        x = self.norm2(x + moe_out)
        return x, aux_loss
The training loop accumulates auxiliary loss across all MoE layers:
total_loss = cross_entropy_loss + 0.01 * sum(aux_losses)
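The sketch below shows one way the per-layer losses might be collected in a model's forward pass. To stay self-contained it uses a hypothetical `StandInBlock` that only mimics the `(x) -> (x, aux_loss)` contract of `MoETransformerBlock`; `TinyModel` and all sizes are likewise made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StandInBlock(nn.Module):
    """Hypothetical placeholder with the same return contract as MoETransformerBlock."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        out = x + self.proj(x)
        aux_loss = out.pow(2).mean()  # dummy scalar standing in for the router loss
        return out, aux_loss

class TinyModel(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList([StandInBlock(d_model) for _ in range(n_layers)])
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        x = self.embed(tokens)
        aux_losses = []
        for block in self.blocks:
            x, aux = block(x)
            aux_losses.append(aux)  # one scalar per MoE layer
        return self.head(x), aux_losses

torch.manual_seed(0)
vocab_size = 50
model = TinyModel(vocab_size=vocab_size, d_model=16, n_layers=3)
tokens = torch.randint(0, vocab_size, (2, 5))
targets = torch.randint(0, vocab_size, (2, 5))

logits, aux_losses = model(tokens)
cross_entropy_loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
total_loss = cross_entropy_loss + 0.01 * sum(aux_losses)
total_loss.backward()
```

Summing the per-layer scalars keeps each layer's router under its own balancing pressure while still producing a single loss for the optimizer.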
Step 5: Performance Considerations
Production MoE implementations add several optimizations that our toy version lacks:
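- Token dropping / capacity factors: Each expert gets a fixed capacity (a token budget per batch). When more tokens are routed to an expert than it can hold, the excess tokens are dropped (they pass through the residual connection only) or rerouted. This bounds per-expert memory and keeps tensor shapes static, but the capacity factor needs careful tuning.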
- Expert parallelism: Distribute experts across multiple GPUs. Each GPU holds a subset of experts, and all-to-all communication routes tokens. DeepSpeed and Megatron-LM implement this.
- Grouped GEMM: Batch matrix multiplications across experts using CUTLASS kernels for better GPU utilization.
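- FP8 / mixed-precision training: Expert FFN matmuls account for most of an MoE layer's compute and memory traffic, so they are the natural place to use reduced precision. The router is commonly kept in FP32, because small numerical errors in its logits can change which experts are selected.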
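The capacity mechanism from the first item can be sketched in plain Python. The formula `ceil(capacity_factor * num_tokens * top_k / num_experts)` is one common convention, and the function names and toy assignments are assumptions for illustration.

```python
import math
from collections import Counter

def expert_capacity(num_tokens, num_experts, top_k, capacity_factor):
    # Even share of assignments per expert, padded by the capacity factor.
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)

def apply_capacity(assignments, capacity):
    """assignments: (token_id, expert_id) pairs in arrival order."""
    used = Counter()
    kept, dropped = [], []
    for token_id, expert_id in assignments:
        if used[expert_id] < capacity:
            used[expert_id] += 1
            kept.append((token_id, expert_id))
        else:
            dropped.append((token_id, expert_id))  # over capacity
    return kept, dropped

capacity = expert_capacity(num_tokens=8, num_experts=4, top_k=1, capacity_factor=1.25)

# Five of eight tokens want expert 0; the rest are spread over experts 1-3.
assignments = [(t, 0) for t in range(5)] + [(5, 1), (6, 2), (7, 3)]
kept, dropped = apply_capacity(assignments, capacity)
print(capacity, dropped)
```

Raising the capacity factor drops fewer tokens at the cost of larger, mostly padded expert buffers.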
Practical Tips
- Start with 4 experts, top-2 routing for initial experiments. Scale to 8-16 experts once routing stabilizes.
- Use a warmup phase for the auxiliary loss coefficient — start at 0.0 and ramp to 0.01 over the first 10% of training.
- Monitor expert utilization histograms during training. If one expert dominates, increase the aux loss weight.
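- For inference, routing depends on each token's hidden state and cannot be computed ahead of time; reduce its overhead by grouping tokens per expert and running each expert once per batch.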
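Two of these tips translate directly into small helpers. The sketch below shows a linear warmup for the auxiliary loss coefficient and a utilization summary computed from routing indices; the function names, the linear shape of the ramp, and the sample data are assumptions.

```python
from collections import Counter

def aux_coeff(step, total_steps, final_coeff=0.01, warmup_frac=0.10):
    # Linear ramp from 0 to final_coeff over the first warmup_frac of training.
    warmup_steps = max(1, int(total_steps * warmup_frac))
    return final_coeff * min(1.0, step / warmup_steps)

def expert_utilization(topk_indices, num_experts):
    """topk_indices: one list of chosen expert ids per token."""
    counts = Counter(e for token in topk_indices for e in token)
    total = sum(counts.values())
    # Fraction of all routing assignments that went to each expert.
    return [counts.get(e, 0) / total for e in range(num_experts)]

coeffs = [aux_coeff(s, total_steps=1000) for s in (0, 50, 100, 500)]

# Made-up routing decisions for four tokens with top-2 routing.
util = expert_utilization([[0, 1], [0, 2], [0, 3], [0, 1]], num_experts=4)
print(coeffs, util)
```

In this toy data expert 0 receives half of all assignments, which is the kind of skew that would suggest increasing the auxiliary loss weight.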
When to Use MoE
MoE is not always the answer. It shines when you have abundant memory but limited compute budget per token — typical for deployment of large models. For smaller models or training on a single GPU, a well-tuned dense model often outperforms a sparse one due to better hardware utilization and simpler optimization dynamics.
The sweet spot: models with 7B+ parameters where you want to push capability without proportional inference cost increases. That is exactly why Mistral chose MoE for Mixtral — 46.7B parameters with only 12.9B active per token.
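Those two figures are enough for a back-of-the-envelope split between shared and expert parameters. The sketch assumes 8 equally sized experts with top-2 routing and treats everything outside the experts (attention, embeddings, norms, routers) as always active; the results are estimates derived from the quoted totals, not official numbers.

```python
def split_params(total_b, active_b, num_experts, top_k):
    # total  = shared + num_experts * per_expert
    # active = shared + top_k * per_expert
    per_expert = (total_b - active_b) / (num_experts - top_k)
    shared = active_b - top_k * per_expert
    return shared, per_expert

total_b, active_b = 46.7, 12.9  # billions of parameters, as quoted above
shared, per_expert = split_params(total_b, active_b, num_experts=8, top_k=2)
active_fraction = active_b / total_b

print(f"shared ~{shared:.2f}B, per expert ~{per_expert:.2f}B, "
      f"active fraction ~{active_fraction:.0%}")
```

Under these assumptions each token uses a little over a quarter of the model's weights, which is where the inference savings come from.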
Next Steps
Once you have the basic MoE working, experiment with:
- Noisy Top-K gating: Add noise to router logits for exploration during training (Shazeer et al., 2017, use Gaussian noise with a learned per-expert scale).
- Soft routing: Use all experts with smoothed weights instead of hard Top-K.
- Cross-layer expert sharing: Reuse the same expert set across multiple transformer layers to reduce total parameters.
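As a starting point for the first experiment, here is a minimal sketch of a noisy top-k router. It uses Gaussian noise with a learned, input-dependent scale, in the spirit of the sparsely-gated MoE paper; other noise distributions could be swapped in. The class name `NoisyTopKRouter` and the `noise` layer are assumptions for illustration, and the auxiliary loss is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyTopKRouter(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Predicts a per-expert noise scale from the token itself.
        self.noise = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        logits = self.router(x)
        if self.training:
            # softplus keeps the standard deviation positive.
            noise_std = F.softplus(self.noise(x))
            logits = logits + torch.randn_like(logits) * noise_std
        topk_logits, topk_indices = torch.topk(logits, self.top_k, dim=-1)
        # Softmax over the selected logits gives weights that sum to 1.
        topk_weights = F.softmax(topk_logits, dim=-1)
        return topk_indices, topk_weights

torch.manual_seed(0)
router = NoisyTopKRouter(d_model=16, num_experts=4, top_k=2)
x = torch.randn(2, 5, 16)

router.train()
idx_train, w_train = router(x)  # noisy routing

router.eval()
idx_a, w_a = router(x)          # deterministic routing
idx_b, w_b = router(x)
```

Noise is applied only in training mode, so evaluation and inference remain deterministic.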
Building MoE from scratch teaches you why modern LLMs are architected the way they are — and gives you the foundation to experiment with the next generation of sparse models.