MoBA: How Moonshot AI Serves 1M-Token Contexts in Production with Learned Sparse Attention

Hook

Moonshot AI's Kimi assistant handles million-token contexts in production not by throwing more hardware at the problem, but by teaching their models which parts of context to ignore—without adding a single trainable parameter.

Context

The quadratic complexity of transformer attention has haunted large language model development since the architecture's inception. As context windows expanded from 2K to 8K to 100K+ tokens, the memory and compute requirements became untenable. With full attention, a 1M-token context requires computing one trillion (10^6 × 10^6 = 10^12) query-key scores per attention head in every layer, an engineering challenge that has spawned dozens of solutions.
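To make the scaling concrete, here is a back-of-the-envelope count of query-key scores for a single head in a single layer. The block size and top-k values are illustrative assumptions chosen for the example, not Moonshot AI's production settings, and causal masking (which roughly halves the dense count) is ignored.

```python
# Back-of-the-envelope count of query-key scores for ONE head in ONE layer.
# Causal masking is ignored here; it roughly halves the dense count.

def dense_scores(seq_len: int) -> int:
    # Full attention: every query is scored against every key.
    return seq_len * seq_len

def block_sparse_scores(seq_len: int, block_size: int, top_k: int) -> int:
    # Block-sparse attention: each query scores one pooled key per block (the gate),
    # then every key inside its top_k selected blocks.
    num_blocks = seq_len // block_size
    return seq_len * (num_blocks + top_k * block_size)

n = 1_000_000
print(f"{dense_scores(n):.2e}")                                     # 1.00e+12
print(f"{block_sparse_scores(n, block_size=4096, top_k=12):.2e}")   # 4.94e+10 (illustrative settings)
```

With these made-up settings each query scores about 49K keys instead of 1M, roughly a 20x reduction in score computations.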

The existing approaches fall into predictable camps: structured sparse attention patterns (Longformer, BigBird) that predefine which tokens attend to which; exact attention optimizations (Flash Attention) that make the quadratic operation faster and shrink its memory footprint but leave its compute quadratic; and linear attention variants that change the mechanism entirely, with unpredictable impacts on model quality. What's been missing is a solution that learns which parts of context matter for each query without requiring hand-crafted attention patterns or architectural overhauls. MoBA, deployed in Moonshot AI's production Kimi service, represents a different approach: applying the mixture-of-experts gating philosophy to context blocks themselves.

Technical Insight

MoBA's core insight is deceptively simple: divide the key-value cache into fixed-size blocks, then use a lightweight gating mechanism to select only the top-k most relevant blocks for each query token. The brilliance lies in the gating function being parameter-less: it scores each block directly from the query and that block's keys (the dot product between the query and the block's mean-pooled key), so there are no new weights to train.
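Written as equations (our reading of the formulation in the MoBA paper), with $I_i$ the key positions in block $i$, $K_j$ the $j$-th key, $B$ the block size, $n$ the number of blocks, $k$ the number of selected blocks, and $d$ the head dimension:

$$
s_i = \Big\langle q,\ \tfrac{1}{B}\textstyle\sum_{j \in I_i} K_j \Big\rangle = \tfrac{1}{B}\textstyle\sum_{j \in I_i} \langle q, K_j \rangle,
\qquad
g_i = \begin{cases} 1, & s_i \in \operatorname{TopK}\big(\{s_1,\dots,s_n\},\, k\big) \\ 0, & \text{otherwise} \end{cases}
$$

$$
\operatorname{MoBA}(q, K, V) = \operatorname{softmax}\!\Big(\frac{q\,K[I]^{\top}}{\sqrt{d}}\Big)\,V[I],
\qquad I = \bigcup_{i\,:\,g_i = 1} I_i
$$

The second equality in $s_i$ is why averaging a block's logits and mean-pooling its keys are the same thing, and why the gate costs one dot product per block rather than one per key.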

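The architecture works by first chunking the context into blocks of size B (a tunable hyperparameter, e.g. 512 tokens). For each query token, the system computes a block-level score by averaging the attention logits within each block, which is equivalent to taking the dot product between the query and the block's mean-pooled key. The top-k highest-scoring blocks proceed to full attention computation, and the rest are skipped. To preserve causality, blocks that lie in the query's future are never selected, and the block containing the query itself is always attended to with a causal mask. The result is a sparse attention pattern that adapts per query rather than following rigid structural rules.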

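Here is a simplified sketch of the block selection logic for a single decoding step:

```python
import torch

# Simplified MoBA block selection for one decoding step (a single new query token).
# Gate score per block = query . mean-pooled keys of that block (no trainable weights).
def select_kv_blocks(query, key_blocks, value_blocks, top_k):
    # query:        [batch, heads, 1, dim]
    # key_blocks:   [batch, heads, num_blocks, block_size, dim]
    # value_blocks: [batch, heads, num_blocks, block_size, dim]
    batch, heads, num_blocks, block_size, dim = key_blocks.shape

    # Parameter-less gate: one pooled key per block, one score per block
    block_keys = key_blocks.mean(dim=3)                               # [batch, heads, num_blocks, dim]
    block_scores = torch.einsum('bhqd,bhnd->bhn', query, block_keys)  # [batch, heads, num_blocks]

    # The block holding the current token (the last one while decoding) is always kept
    block_scores[..., -1] = float('inf')

    # Select top-k blocks per head
    top_k = min(top_k, num_blocks)
    selected_indices = torch.topk(block_scores, k=top_k, dim=-1).indices  # [batch, heads, top_k]

    # Expand indices so gather copies whole [block_size, dim] blocks
    index = selected_indices[..., None, None].expand(batch, heads, top_k, block_size, dim)
    selected_keys = torch.gather(key_blocks, 2, index)
    selected_values = torch.gather(value_blocks, 2, index)

    return selected_keys, selected_values  # each [batch, heads, top_k, block_size, dim]
```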
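The selection code above handles one new query at a time, where every cached block is already in the past. For a full sequence (prefill, or a training forward pass), causality adds two rules: a query can never be routed to a block in its future, and the block containing the query is always attended to (with ordinary causal masking inside it). Below is a minimal single-head sketch of that gate in PyTorch; the function name and the assumption that the sequence length divides evenly into blocks are ours, for illustration.

```python
import torch

def moba_block_gate(q: torch.Tensor, k: torch.Tensor, block_size: int, top_k: int) -> torch.Tensor:
    """Boolean gate [seq_len, num_blocks]; True where a query attends to a block.

    Hypothetical single-head sketch; assumes seq_len is a multiple of block_size.
    """
    seq_len, dim = k.shape
    num_blocks = seq_len // block_size

    # Parameter-less gate: score every block by query . mean-pooled key.
    block_keys = k.reshape(num_blocks, block_size, dim).mean(dim=1)   # [num_blocks, dim]
    scores = q @ block_keys.T                                          # [seq_len, num_blocks]

    query_block = torch.arange(seq_len) // block_size                  # block holding each query
    block_ids = torch.arange(num_blocks)
    future = block_ids[None, :] > query_block[:, None]                 # [seq_len, num_blocks]
    current = block_ids[None, :] == query_block[:, None]

    # Causality: future blocks can never win; the query's own block always wins.
    scores = scores.masked_fill(future, float("-inf")).masked_fill(current, float("inf"))

    chosen = scores.topk(min(top_k, num_blocks), dim=-1).indices      # [seq_len, top_k]
    gate = torch.zeros(seq_len, num_blocks, dtype=torch.bool)
    gate.scatter_(1, chosen, torch.ones_like(chosen, dtype=torch.bool))
    # Early queries have fewer than top_k past blocks, so topk may return
    # future blocks with -inf scores; remove them again.
    return gate & ~future
```

Each row of the gate then selects at most `top_k` blocks, always including the query's own block and never one from its future.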

The repository provides two implementations that illuminate the design trade-offs. The naive version masks out non-selected blocks (setting their logits to negative infinity), which integrates cleanly with existing attention kernels but still computes scores over the entire context. The efficient production implementation instead removes unselected blocks before the attention computation and operates only on the selected subset; the repository reports up to a 40x speedup for it over the naive version on 32K-token sequences.
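The two strategies compute the same result and differ only in how much work they do. A plain-PyTorch sketch for one head and one decoding query makes this concrete; both functions are hypothetical illustrations, not the repository's code.

```python
import torch

def masked_moba_decode(q, k, v, selected, block_size):
    """Naive style: dense scores over the whole cache, unselected blocks masked to -inf."""
    # q: [dim]; k, v: [seq_len, dim]; selected: 1-D LongTensor of block indices
    scores = (k @ q) / q.shape[-1] ** 0.5                  # one score per cached key
    block_of_key = torch.arange(k.shape[0]) // block_size
    scores = scores.masked_fill(~torch.isin(block_of_key, selected), float("-inf"))
    return torch.softmax(scores, dim=-1) @ v               # work still scales with seq_len

def gathered_moba_decode(q, k, v, selected, block_size):
    """Efficient style: gather the selected blocks first, then attend densely over them."""
    dim = k.shape[-1]
    k_sel = k.reshape(-1, block_size, dim)[selected].reshape(-1, dim)  # [top_k * block_size, dim]
    v_sel = v.reshape(-1, block_size, dim)[selected].reshape(-1, dim)
    scores = (k_sel @ q) / dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v_sel           # work scales with top_k * block_size
```

The outputs match up to floating-point error; only the gathered version's cost stops growing with the full cache length.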

What makes MoBA production-ready is its seamless mode switching. During training, you can toggle between full attention (for stable gradient signal) and sparse attention (for efficiency), or even gradually transition as training progresses. This flexibility addresses a critical concern with sparse attention methods: ensuring the model learns robust representations despite seeing only partial context during forward passes.
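A hypothetical way to express such switching as configuration is to choose full attention or MoBA per layer and per training step. The class name, fields, and default values below are illustrative assumptions, not the repository's API or Moonshot AI's recipe.

```python
from dataclasses import dataclass

@dataclass
class HybridSchedule:
    """Hypothetical config choosing 'moba' or 'full' attention per layer and step."""
    total_steps: int
    full_attn_final_fraction: float = 0.1  # illustrative: final 10% of steps use full attention
    full_attn_last_layers: int = 0         # illustrative: top layers that always use full attention

    def mode(self, step: int, layer: int, num_layers: int) -> str:
        # Layer-wise hybrid: keep the last few layers on full attention.
        if layer >= num_layers - self.full_attn_last_layers:
            return "full"
        # Stage-wise hybrid: sparse attention for most of training, full attention at the end.
        switch_step = round(self.total_steps * (1 - self.full_attn_final_fraction))
        return "full" if step >= switch_step else "moba"
```

Because the gate has no weights, flipping a layer between modes needs no checkpoint surgery; the same query, key, and value projections serve both.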

The continued training requirement deserves emphasis: you cannot simply apply MoBA to an existing pretrained model and expect it to run faster while keeping its quality. Because the gate has no weights of its own, what has to be learned lives in the model's query and key representations, which must adapt so that routing by mean-pooled keys picks the blocks that preserve task performance. Moonshot AI's experiments show this requires meaningful continued training compute, but the investment pays off: their 1M-token needle-in-a-haystack results demonstrate that models can learn to attend to relevant context blocks with high precision.
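One way to check whether routing has adapted is to measure how much of full attention's probability mass lands inside the blocks the gate selects. The diagnostic below is a hypothetical sketch for one head and one query, not a metric from the MoBA paper or repository.

```python
import torch

def routed_attention_mass(q, k, selected, block_size):
    """Fraction of full-attention probability mass inside the selected blocks.

    Hypothetical diagnostic: values near 1.0 mean block routing keeps what
    full attention would have looked at.
    """
    # q: [dim]; k: [seq_len, dim]; selected: 1-D LongTensor of block indices
    probs = torch.softmax((k @ q) / q.shape[-1] ** 0.5, dim=-1)   # full attention weights
    block_of_key = torch.arange(k.shape[0]) // block_size
    return probs[torch.isin(block_of_key, selected)].sum().item()
```

Tracking this value on held-out long documents during continued training gives a rough signal of whether the query and key representations are learning to route well.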

The integration with Flash Attention is particularly elegant. Rather than reimplementing attention from scratch, MoBA acts as a selection step that filters blocks before handing the surviving keys and values to Flash Attention kernels for the actual attention computation. This composability means the two savings stack: Flash Attention keeps memory usage linear in the number of keys it processes, and block selection shrinks how many keys that is.
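Why can block filtering compose with an exact attention kernel? Attention over a union of blocks can be computed chunk by chunk and merged exactly, provided each chunk also reports its softmax log-sum-exp, a quantity FlashAttention-style kernels track as part of their online softmax. The plain-PyTorch sketch below shows that merge; it illustrates the principle and is not the repository's kernel.

```python
import torch

def partial_attention(q, k, v):
    """Attention of one query over one KV chunk, returning (output, log-sum-exp)."""
    scores = (k @ q) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse

def merge_partials(outputs, lses):
    """Combine per-chunk results into exact attention over the union of the chunks."""
    # Each chunk's weight is its share of the total softmax mass: exp(lse_c) / sum exp(lse).
    weights = torch.softmax(torch.stack(lses), dim=0)            # [num_chunks]
    return (weights[:, None] * torch.stack(outputs)).sum(dim=0)  # [dim]
```

The same merge lets a kernel attend to the query's current block and to the selected past blocks in separate calls, then combine the results without approximation.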

Gotcha

The continued training requirement is not a minor implementation detail; it's a fundamental constraint that affects deployment strategy. If you have a pretrained model you want to use with MoBA, budget for substantial continued training compute. Moonshot AI's paper doesn't specify exact numbers, but the model needs enough gradient steps for its query and key representations to adapt to block-level routing. For organizations without significant compute resources or the ability to run large-scale continued training, this makes MoBA a non-starter despite its impressive results.

The hard dependency on flash-attn==2.6.3 creates integration friction. Flash Attention development moves quickly, and pinning to a specific version means missing optimizations and bug fixes in newer releases. More problematically, the kernel implementation has architectural assumptions baked in (block sizes, head dimensions, sequence length constraints) that may not align with your model architecture. The repository's kernel code is educational, but adapting it to non-standard configurations requires CUDA expertise. Teams should expect to invest engineering time debugging kernel launches and memory layouts when moving beyond the reference configurations.
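A small startup guard can at least surface version drift before it turns into a confusing kernel error. This is a hypothetical helper; it assumes the package is installed under the distribution name `flash-attn` and uses only the standard library's `importlib.metadata`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

PINNED_FLASH_ATTN = "2.6.3"  # the pin discussed above

def flash_attn_version_ok(installed: Optional[str] = None,
                          pinned: str = PINNED_FLASH_ATTN) -> Tuple[bool, str]:
    """Return (ok, message); pass `installed` explicitly to test without the package."""
    if installed is None:
        try:
            installed = version("flash-attn")  # assumed distribution name
        except PackageNotFoundError:
            return False, "flash-attn is not installed"
    if installed != pinned:
        return False, f"flash-attn {installed} found; expected pinned version {pinned}"
    return True, f"flash-attn {installed} matches the pinned version"
```

Logging the message at startup, rather than failing hard, leaves room to trial newer releases deliberately.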

Verdict

Use MoBA if you're building or fine-tuning models for extreme context lengths (100K+ tokens) and have the training infrastructure to perform meaningful continued training. The production deployment at Moonshot AI shows it works at scale, and because the work saved by skipping blocks grows with context length, the savings translate directly into lower serving costs for long prompts. Note that the headline 40x figure compares the efficient implementation with the naive masked one, not with full attention. The parameter-less gating is genuinely novel: you get adaptive sparsity without the optimizer complexity of adding MoE parameters. This is particularly compelling if you're already using Flash Attention, since MoBA layers on top of it rather than replacing it.

Skip MoBA if you need a drop-in solution for existing pretrained models without retraining, or if your context lengths stay under 32K, where full attention with Flash Attention 2/3 remains practical and simpler to deploy. Also skip it if you lack the CUDA development capacity to debug kernel issues, since the dependency pinning and custom kernels will create a maintenance burden. For moderate context lengths, the engineering complexity outweighs the performance gains.
