Speculative Decoding: 2-3x Faster LLM Inference Without Quality Loss
If you have ever stared at a terminal waiting for an LLM to finish generating a response, you have felt the fundamental bottleneck of modern language models: autoregressive decoding. Every new token requires a full forward pass of the model, and each pass depends on the token before it, so generating 500 tokens means 500 sequential forward passes. Batching more requests together raises overall throughput, but it cannot hide this per-sequence latency.
Speculative decoding breaks this constraint. A small, fast "draft" model proposes several tokens, one after another but cheaply, and the large "target" model then verifies all of them in a single forward pass. The result is often a 2-3x reduction in generation latency with no quality degradation: with the acceptance rule described below, the output distribution is mathematically identical to what the target model would produce on its own, whether you decode greedily or sample (e.g., with nucleus sampling).
In this tutorial, we will build a working speculative decoder from scratch using Hugging Face Transformers and PyTorch.
How Speculative Decoding Works
The algorithm has two phases per step:
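- Draft Phase: A small model (e.g., Llama-3.2-1B) generates γ tokens autoregressively. This step is sequential, but each forward pass is cheap.
- Verification Phase: The large model (e.g., Llama-3.1-70B) runs a single forward pass over the context plus all γ draft tokens, computing its probabilities for every draft position at once. Draft tokens are checked left to right with a probabilistic acceptance test; at the first rejection, that token is replaced by a corrected sample and the remaining drafts are discarded. If all γ tokens pass, the same forward pass also yields one extra token from the target.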
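The key insight: at small batch sizes, LLM decoding is limited by memory bandwidth rather than compute. Each forward pass has to stream all of the target model's weights from GPU memory whether it scores one position or γ + 1, so verifying γ draft tokens takes roughly the same wall-clock time as generating a single token with the large model.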
The Acceptance Criterion
At position i, let the draft model produce token xᵢ with probability q(xᵢ), and the target model assign probability p(xᵢ). We accept xᵢ with probability:
α = min(1, p(xᵢ) / q(xᵢ))
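If accepted, we move on to position i+1. If rejected, we discard xᵢ and every later draft token, and sample a replacement from the normalized residual distribution norm(max(0, p − q)). If all γ draft tokens are accepted, we sample one bonus token from the target's distribution at position γ + 1. This scheme guarantees that every emitted token is distributed exactly as if it had been sampled from the target model, provided q is the distribution the draft actually sampled from.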
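To check the guarantee numerically, the sketch below computes the exact distribution of the token emitted by one accept/reject step over a hypothetical four-token vocabulary: a draft token drawn from q is accepted with probability min(1, p/q), otherwise a replacement is drawn from the normalized residual. The values in `p` and `q` are made up for illustration, and `speculative_output_dist` is a hypothetical helper, not part of any library.

```python
# Exact output distribution of one accept/reject step over a toy vocabulary.
# p (target) and q (draft) are hypothetical numbers chosen for illustration.

def speculative_output_dist(p, q):
    """Return P(emitted token = x) for a single speculative position.

    A draft token x ~ q is accepted with probability min(1, p[x] / q[x]);
    on rejection, a replacement is drawn from norm(max(0, p - q)).
    """
    # Mass that reaches the output via acceptance: q[x] * min(1, p[x] / q[x]).
    accepted = [qx * min(1.0, px / qx) if qx > 0 else 0.0 for px, qx in zip(p, q)]
    reject_prob = 1.0 - sum(accepted)
    # Residual distribution used for the correction sample.
    residual = [max(0.0, px - qx) for px, qx in zip(p, q)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]
    return [a + reject_prob * r for a, r in zip(accepted, residual)]


p = [0.5, 0.3, 0.2, 0.0]  # target distribution
q = [0.2, 0.2, 0.4, 0.2]  # draft distribution

out = speculative_output_dist(p, q)
print([round(x, 6) for x in out])  # → [0.5, 0.3, 0.2, 0.0], i.e. exactly p
```

Acceptance contributes min(p(x), q(x)) for every token, and the rejection mass is spread over exactly the tokens where p exceeds q, which is why the two pieces add back up to p.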
Implementation from Scratch
1. Setup and Model Loading
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load draft model (small, fast)
draft_name = "meta-llama/Llama-3.2-1B"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_name)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# Load target model (large, accurate)
target_name = "meta-llama/Llama-3.1-8B"
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
target_model = AutoModelForCausalLM.from_pretrained(
    target_name,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# The acceptance test compares p and q for the same token IDs,
# so both models must share a vocabulary.
assert draft_model.config.vocab_size == target_model.config.vocab_size
```
2. Core Speculative Decoding Loop
```python
@torch.no_grad()
def speculative_decode(draft_model, target_model, draft_tok, target_tok,
                       prompt, max_new_tokens=256, gamma=5):
    """
    Speculative sampling with a single draft model (batch size 1).

    Teaching version: no KV cache, so every forward pass re-reads the
    whole sequence. Both models must share a vocabulary.

    Args:
        gamma: Number of draft tokens proposed per verification step.
    Returns:
        The decoded prompt plus the generated continuation.
    """
    input_ids = draft_tok(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = input_ids.shape[1]
    accepted_total = drafted_total = steps = 0

    while input_ids.shape[1] < prompt_len + max_new_tokens:
        steps += 1
        seq_len = input_ids.shape[1]

        # --- DRAFT PHASE: sample gamma tokens from q and remember each q ---
        draft_ids = input_ids
        draft_probs = []
        for _ in range(gamma):
            logits = draft_model(draft_ids).logits[:, -1, :].float().to(device)
            q = torch.softmax(logits, dim=-1)
            # Sample (not argmax): the acceptance test needs the true q.
            next_token = torch.multinomial(q, 1)
            draft_probs.append(q[0])
            draft_ids = torch.cat([draft_ids, next_token], dim=1)
        draft_tokens = draft_ids[0, seq_len:]  # the gamma proposed tokens

        # --- VERIFICATION PHASE: one target pass scores gamma + 1 positions ---
        # Logits at position t predict token t + 1, so start at seq_len - 1.
        target_logits = target_model(draft_ids).logits[0, seq_len - 1:, :]
        target_probs = torch.softmax(target_logits.float().to(device), dim=-1)

        # --- ACCEPT / REJECT, left to right ---
        n_accepted = 0
        for i in range(gamma):
            tok = draft_tokens[i].item()
            p_val = target_probs[i, tok].item()
            q_val = draft_probs[i][tok].item()  # > 0, since tok was sampled from q
            if torch.rand(1).item() < min(1.0, p_val / q_val):
                n_accepted += 1
                continue
            # Rejected: resample from norm(max(0, p - q)), discard later drafts.
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            if residual.sum() > 0:
                fix_token = torch.multinomial(residual / residual.sum(), 1)
            else:  # p == q up to rounding, so sampling from p is exact
                fix_token = torch.multinomial(target_probs[i], 1)
            break
        else:
            # All gamma drafts accepted: take a bonus token from the target.
            fix_token = torch.multinomial(target_probs[gamma], 1)

        input_ids = torch.cat([input_ids,
                               draft_tokens[:n_accepted].unsqueeze(0),
                               fix_token.view(1, 1)], dim=1)
        accepted_total += n_accepted
        drafted_total += gamma

    input_ids = input_ids[:, :prompt_len + max_new_tokens]  # drop overshoot
    new_tokens = input_ids.shape[1] - prompt_len
    acceptance_rate = accepted_total / max(drafted_total, 1)
    # Tokens per target pass is an upper bound on speedup (ignores draft cost).
    print(f"Target passes: {steps} | Acceptance rate: {acceptance_rate:.2%} | "
          f"Tokens per target pass: {new_tokens / max(steps, 1):.2f}")
    return draft_tok.decode(input_ids[0], skip_special_tokens=True)
```
3. Running It
```python
prompt = "Explain how attention mechanisms work in transformer models, including"
result = speculative_decode(
    draft_model, target_model,
    draft_tokenizer, target_tokenizer,
    prompt, max_new_tokens=200, gamma=4,
)
print(result)
```
Expected Performance
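Illustrative numbers with a 1B draft model and an 8B target model on a single A100 GPU. Actual results vary with hardware, prompts, sampling settings, and how optimized the implementation is; the teaching loop above recomputes the full sequence on every forward pass (no KV cache), so expect it to be slower than these figures.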
| Metric | Baseline | Speculative Decoding |
|---|---|---|
| Tokens/sec | ~45 | ~95-120 |
| Acceptance rate (γ=4) | — | ~75-85% |
| Effective speedup | 1x | 2.0-2.5x |
Acceptance rate depends on how closely the draft model's distribution matches the target's. Fine-tuning the draft on the target's outputs (a form of knowledge distillation) can raise acceptance rates substantially, in some cases above 90%.
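One way to see why alignment matters: at a single position, the probability that a draft token sampled from q passes the test is Σₓ min(p(x), q(x)), which equals one minus the total variation distance between p and q. The sketch below evaluates this for two hypothetical draft distributions, one far from the target and one close to it; the helper names and numbers are illustrative only.

```python
# Per-position acceptance probability of a draft distribution q against
# a target distribution p. All numbers are hypothetical.

def acceptance_prob(p, q):
    """P(accept) = sum_x q(x) * min(1, p(x)/q(x)) = sum_x min(p(x), q(x))."""
    return sum(min(px, qx) for px, qx in zip(p, q))


def total_variation(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * sum(abs(px - qx) for px, qx in zip(p, q))


p = [0.5, 0.3, 0.2, 0.0]          # target
q_far = [0.2, 0.2, 0.4, 0.2]      # poorly aligned draft
q_close = [0.45, 0.3, 0.2, 0.05]  # draft pulled toward the target

for name, q in [("far", q_far), ("close", q_close)]:
    print(f"{name}: P(accept)={acceptance_prob(p, q):.2f}, "
          f"1 - TV={1 - total_variation(p, q):.2f}")
```

In this toy case, pulling q toward p (which is what fine-tuning the draft on target outputs aims to do) raises the per-position acceptance probability from 0.60 to 0.95.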
Optimization Tips
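- Choose γ wisely: Too small, and each expensive target pass verifies only a few tokens. Too large, and returns diminish: the chance that the whole draft survives shrinks geometrically, so late draft tokens are usually thrown away along with the draft compute spent on them. γ ∈ [3, 6] is a common starting range for 1B→8B pairs; tune it on your own workload.
- Shared tokenizer: Both models must use the same tokenizer and vocabulary, because the acceptance test compares p and q for the same token ID.
- KV cache reuse: With the prefix cached, the target only needs to process the newest token plus the γ draft positions in each verification pass, instead of re-reading the whole sequence as the teaching loop above does. After a rejection, both models' caches must be truncated back to the last accepted position.
- Draft heads instead of a draft model: Methods such as Medusa and EAGLE replace the separate draft model with lightweight heads attached to the target model itself, with reported speedups of around 3-4x in some settings.
- Speculative sampling: For greedy output, the test reduces to accepting a draft token only if it equals the target's argmax. For sampled output, apply the temperature (and any top-p filtering) you want in the final output to the target's logits, and make sure q is exactly the distribution the draft sampled from; using the same temperature for the draft usually keeps the two distributions closer and improves acceptance.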
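To reason about γ quantitatively, a simple model from the original speculative decoding analysis (Leviathan et al., 2023) assumes each draft token is accepted independently with probability α. Each target pass then emits (1 − α^(γ+1)) / (1 − α) tokens on average, and if one draft step costs a fraction c of a target step, the expected wall-clock speedup is that quantity divided by (γc + 1). The sketch below sweeps γ for hypothetical values α = 0.8 and c = 0.1; substitute your own measured acceptance rate and cost ratio.

```python
# Back-of-the-envelope gamma selection under an i.i.d. acceptance assumption.
# alpha and c below are hypothetical inputs, not measurements.

def expected_tokens_per_pass(alpha, gamma):
    """Expected tokens emitted per target forward pass (accepted + 1 correction/bonus)."""
    if alpha >= 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha, gamma, c):
    """Expected wall-clock speedup; c = cost of one draft step / one target step."""
    return expected_tokens_per_pass(alpha, gamma) / (gamma * c + 1)


alpha = 0.8  # hypothetical per-token acceptance rate
c = 0.1      # hypothetical draft-to-target cost ratio

for gamma in range(1, 9):
    print(f"gamma={gamma}: tokens/pass={expected_tokens_per_pass(alpha, gamma):.2f}, "
          f"speedup={expected_speedup(alpha, gamma, c):.2f}x")

best = max(range(1, 9), key=lambda g: expected_speedup(alpha, g, c))
print("best gamma:", best)
```

With these inputs the optimum lands at γ = 6, but the curve is flat around it (γ from 4 to 7 is within a few percent). A higher cost ratio c pushes the optimum lower, and a lower acceptance rate does the same.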
When to Use Speculative Decoding
Great for:
- Real-time chat applications where latency matters
- Code completion (high acceptance rates due to deterministic syntax)
- Long-form generation (blog posts, summaries, reports)
- Any scenario where the target model is memory-bandwidth-bound, as is typical at small batch sizes
Less effective for:
- Very short responses (< 20 tokens) — overhead outweighs gains
- Highly creative/divergent generation where draft-target alignment is low
- When the draft model is too small relative to the target (> 50x parameter gap)
Production Implementations
Don't roll your own for production. These frameworks support speculative decoding out of the box:
- vLLM: `--speculative-model` flag with n-gram or model-based drafts
- TGI (Text Generation Inference): Medusa and draft model support
- Hugging Face Transformers: assisted generation via `generate(..., assistant_model=draft_model)`
- llama.cpp: speculative decoding with small draft models
Summary
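Speculative decoding is one of the few inference optimizations that speeds up generation without changing output quality. By letting a cheap draft model propose tokens and a large model verify them in parallel, you can often cut latency by 2-3x, paid for with extra memory for the draft model and some wasted compute on rejected tokens. As long as the acceptance test uses the draft's true sampling distribution, the math guarantees identical output distributions. For latency-sensitive, small-batch workloads, it is one of the first optimizations worth trying.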
Start with Hugging Face's built-in assisted generation, profile your acceptance rates, tune γ, and watch your tokens/sec climb. 🦊