Medusa: How Multiple Prediction Heads Eliminate the Draft Model Problem in LLM Inference
Hook
What if you could make your LLM generate 2-3 tokens per forward pass instead of one, without doubling your memory footprint or managing a second model? That’s exactly what Medusa does by rethinking how speculative decoding should work.
Context
Large language models have a fundamental bottleneck: they generate tokens one at a time. Each token requires a full forward pass through billions of parameters, making inference painfully slow even on powerful GPUs. The standard solution—speculative decoding—tries to predict multiple tokens ahead using a smaller “draft” model, then verifies those predictions with the main model. It works, but introduces operational complexity: you need to find or train a compatible draft model, manage two models in memory, and carefully tune the interaction between them. The draft model needs to be similar enough to the main model to make good predictions, but small enough to run quickly. Get this balance wrong, and you waste resources on poor speculations.
Medusa takes a fundamentally different approach. Instead of using a separate draft model, it augments your existing LLM with lightweight “Medusa heads”—small neural networks attached to the model’s hidden states that predict multiple future tokens in parallel. The base model stays frozen, and these heads are trained to speculate about tokens 1, 2, 3, or more positions ahead simultaneously. During inference, Medusa combines predictions from all heads using a tree-based attention mechanism that explores multiple candidate futures at once, then verifies them in a single forward pass. This eliminates the draft model entirely while achieving comparable or better speedups.
Technical Insight
At its core, Medusa’s architecture is elegantly simple. It adds N additional prediction heads (typically 3-5) on top of your LLM’s final hidden layer. Each head is a small network, usually just a few residual blocks followed by a language modeling head, that shares the same vocabulary as the base model. The key innovation is that each head specializes in predicting tokens at a specific future offset: the base model’s own LM head still predicts the next token, Medusa head 1 predicts the token after that, head 2 the one after that, and so on.
During training (the Medusa-1 recipe), the base LLM weights remain completely frozen. Only the new heads are trained, using standard next-token prediction with offset targets. If the model is currently processing “The cat sat on the”, the base LM head predicts “mat”, head 1 learns to predict the token after “mat”, head 2 the token after that, and so forth. This training is remarkably parameter-efficient: the heads are tiny compared to the base model, often adding less than 5% additional parameters. You can train Medusa heads on a single GPU in hours, not the days or weeks required for base model training.
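Mechanically, each head’s training labels are just the token sequence shifted by that head’s offset; positions near the end of the sequence that have no label are dropped. A minimal, framework-free sketch (the helper name and offsets are illustrative):

```python
# Toy sketch: a head that speculates `offset` positions ahead trains on
# the same sequence with its labels shifted by that offset. Trailing
# positions without a full target are simply dropped.
def shift_targets(token_ids, offset):
    """Labels for a head predicting `offset` positions ahead."""
    return token_ids[offset:]

tokens = [101, 7, 42, 9, 500]
# One shifted label stream per head, at increasing offsets
labels_by_head = {k: shift_targets(tokens, k) for k in (1, 2, 3)}
```

In a real training loop each shifted label stream feeds a standard per-head cross-entropy loss, with the frozen base model’s hidden states detached from the graph.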
# Simplified Medusa head architecture
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Linear + SiLU with a skip connection (one common choice)."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

class MedusaHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        # Residual blocks for processing hidden states
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size) for _ in range(num_layers)
        ])
        # Final projection to vocabulary
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states):
        x = hidden_states
        for block in self.blocks:
            x = block(x)
        return self.lm_head(x)

# Multiple heads predicting different future positions
medusa_heads = nn.ModuleList([
    MedusaHead(hidden_size=4096, vocab_size=32000)
    for _ in range(4)  # Predict 4 tokens ahead
])

# During inference: one base-model pass, then every head applied to the
# final hidden state (one extra matmul each, no extra forward passes)
hidden_state = base_model(input_ids)  # final-layer hidden states
candidates = [head(hidden_state) for head in medusa_heads]
The real magic happens during inference through tree-based attention. Instead of simply taking the top prediction from each head independently (which would compound errors), Medusa constructs a tree of candidate sequences. It takes the top-k predictions from each head and builds a tree where each path represents a possible future sequence. For example, if you use top-2 predictions from 3 heads, you might explore up to 8 different candidate futures simultaneously.
This tree is then processed using a modified attention mechanism. Medusa creates a special attention mask that allows each candidate in the tree to attend to its ancestors but not to other branches. This means multiple speculative paths are evaluated in a single forward pass through the base model. The base model computes logits for every position in the tree simultaneously, leveraging the parallel nature of transformer attention.
# Tree-based candidate generation (simplified)
def generate_candidates(medusa_logits, top_k=2):
    """Build a tree of candidate sequences from the Medusa heads.
    Heads predict independently, so the candidate paths are the
    Cartesian product of each head's top-k tokens."""
    tree = []
    head_0_top = torch.topk(medusa_logits[0], top_k)
    head_1_top = torch.topk(medusa_logits[1], top_k)  # independent of head 0
    for token_1 in head_0_top.indices:
        for token_2 in head_1_top.indices:
            # Deeper heads would extend each path further...
            tree.append([token_1, token_2])
    return tree

# Verify all candidates in one pass
candidates = generate_candidates(medusa_predictions)
verification_logits = base_model(
    candidates,
    attention_mask=build_tree_mask(candidates)
)
accepted = find_longest_valid_prefix(verification_logits, candidates)
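The build_tree_mask helper referenced above is where branch isolation happens. A minimal sketch of such a mask over the speculative tree nodes, using an explicit parent-pointer representation (the parents array and boolean-matrix layout here are illustrative, not the library’s actual format):

```python
# Each speculative node may attend to itself and its ancestors, but not
# to sibling branches. parents[i] is the index of node i's parent in the
# tree (-1 for a root node hanging directly off the prompt).
def build_tree_mask(parents):
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:          # walk up to the root, marking ancestors
            mask[i][j] = True
            j = parents[j]
    return mask

# Two heads with top-2 each: 2 roots, each with 2 children (6 nodes)
parents = [-1, -1, 0, 0, 1, 1]
mask = build_tree_mask(parents)
```

Row i of the mask says which tree positions node i may attend to; in practice this mask is appended to the usual causal mask over the prompt so the whole tree is scored in one forward pass.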
The verification step uses a greedy acceptance mechanism. Medusa walks through each position in the tree and checks whether the base model’s predicted token matches the speculated token, accepting tokens position by position until it hits a mismatch, then continuing generation from that point. Under greedy decoding this is exactly equivalent to running the base model autoregressively, just faster; for temperature sampling, Medusa instead uses a relaxed “typical acceptance” criterion, trading exact distribution matching for higher acceptance rates.
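The find_longest_valid_prefix step reduces to a few lines for a single candidate path. A toy version comparing speculated tokens against the base model’s greedy choices (function and argument names are illustrative):

```python
# Greedy acceptance along one candidate path: keep speculated tokens
# while they match the base model's own argmax at each position.
def accept_prefix(speculated, base_argmax):
    """speculated[i] is the tree token at step i; base_argmax[i] is the
    base model's greedy choice given the accepted prefix up to step i."""
    accepted = []
    for spec, truth in zip(speculated, base_argmax):
        if spec != truth:
            break
        accepted.append(spec)
    return accepted
```

Note that even at a mismatch the base model’s own prediction at that position is a correct token for free, so every verification pass advances generation by at least one token.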
Medusa-2 introduces an important variant: instead of freezing the base model, it allows fine-tuning the entire model while adding Medusa heads. This requires careful training recipes to avoid catastrophic forgetting—the model needs to maintain its original capabilities while learning to help the Medusa heads make better predictions. The technique involves jointly optimizing the base model and heads with a combined loss, using techniques like LoRA to make fine-tuning efficient. For practitioners with access to training data, Medusa-2 can achieve even better speedups because the base model learns to produce hidden states that are easier for Medusa heads to predict from.
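The joint objective can be pictured as the ordinary LM loss plus a down-weighted sum of per-head losses; the weight (and warmup schedule) is what keeps the base model from drifting. A hedged sketch with an illustrative weight, not the paper’s exact recipe:

```python
# Medusa-2-style combined objective (illustrative weighting): the base
# model's own next-token loss plus a scaled sum of the speculation
# losses from each Medusa head.
def combined_loss(base_lm_loss, head_losses, head_weight=0.5):
    return base_lm_loss + head_weight * sum(head_losses)

loss = combined_loss(2.0, [1.0, 0.5])  # 2.0 + 0.5 * 1.5
```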
The framework also includes self-distillation capabilities, allowing you to add Medusa heads to any fine-tuned LLM. You simply run inference with your fine-tuned model on any text corpus (doesn’t need to be the original training data), collect the hidden states and generated tokens, then train Medusa heads to predict future tokens from those hidden states. This makes Medusa applicable even when you don’t have access to the original training pipeline.
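The self-distillation data collection above amounts to pairing each position’s hidden state with the tokens the model itself generated next. A framework-free sketch (names and data shapes are illustrative):

```python
# Self-distillation pairs: run the fine-tuned model over any corpus,
# then keep (hidden_state, next num_heads generated tokens) for each
# position that has a full set of future targets.
def collect_pairs(generated_tokens, hidden_states, num_heads):
    pairs = []
    for t in range(len(generated_tokens)):
        future = generated_tokens[t + 1 : t + 1 + num_heads]
        if len(future) == num_heads:  # skip tail positions w/o full targets
            pairs.append((hidden_states[t], future))
    return pairs
```

The heads then train on these pairs exactly as in Medusa-1, which is why no access to the original training pipeline is needed.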
Gotcha
Medusa’s most significant limitation is its optimization for batch size 1. The tree-based attention mechanism that makes it fast for single requests becomes inefficient when you’re processing multiple requests simultaneously. The tree structure is different for each request, making it difficult to pack them efficiently into a single batch. If you’re running a high-throughput API serving hundreds of concurrent users, traditional continuous batching approaches in frameworks like vLLM will likely give you better overall throughput, even though Medusa wins on per-request latency.
The training requirement is another real constraint. Unlike plug-and-play inference optimizations like quantization or better kernels, Medusa requires you to train heads for each model you want to accelerate. While this training is relatively cheap (hours on a single GPU), it’s still overhead. If you’re experimenting with dozens of different models or need to support arbitrary user-uploaded models, the per-model training tax adds up. You also need access to appropriate training data—while self-distillation helps, you still need to run inference on a reasonably large corpus to collect training examples.
There’s also a subtle quality consideration: Medusa works best with greedy or low-temperature sampling. When you use high-temperature sampling or diverse beam search, the prediction accuracy of Medusa heads degrades because there are more valid futures to predict. The speedup is still there, but it’s less dramatic than with greedy decoding. This isn’t a dealbreaker for most applications, but if you’re building something that requires highly creative or diverse outputs, you might see smaller gains than the advertised 2-3x speedup.
Verdict
Use if: You’re deploying LLMs for single-user scenarios like personal assistants, local coding copilots, or interactive research tools where per-request latency matters more than aggregate throughput. You have the GPU resources to train lightweight heads (even a consumer GPU works), and you’re working with a relatively stable set of models where one-time training overhead is acceptable. Medusa shines when you want speculative decoding’s benefits without the operational complexity of finding, hosting, and tuning a separate draft model.
Skip if: You’re building a high-throughput serving system that needs to efficiently batch many concurrent requests; traditional continuous batching will serve you better. You need truly plug-and-play acceleration without any training, or you’re rapidly iterating across many different models where per-model training becomes a bottleneck. Also skip if you’re memory-constrained to the point where even 5% additional parameters matter, or if modifying model architectures (even just adding heads) creates deployment friction with your MLOps pipeline.