Speculative Decoding: 2x–3x Faster LLM Inference Without Quality Loss
If you’ve ever waited for a large language model to generate a response token by token, you know the frustration. Autoregressive decoding is inherently sequential — each token depends on all previous ones. But what if you could predict several tokens ahead, verify them in parallel, and dramatically speed up generation? That’s exactly what speculative decoding does, and in this tutorial, we’ll build it from scratch.
What Is Speculative Decoding?
Speculative decoding (called assisted generation in Hugging Face transformers) uses a small, fast draft model to propose several tokens ahead. A larger target model then verifies all of those proposals in a single forward pass. The key insight: token-by-token decoding of a large model is usually limited by memory bandwidth rather than compute, so scoring k tokens in one forward pass costs roughly the same as generating a single token.
The algorithm works in rounds:
- Draft phase: The draft model generates γ (gamma) tokens autoregressively — very fast because it’s small.
- Verification phase: The target model runs a single forward pass on the input + all γ draft tokens, producing its own probability distribution for each position.
- Acceptance check: Each draft token is accepted or rejected using a probabilistic criterion that guarantees the output distribution matches the target model exactly.
- Repeat: Append the accepted draft tokens, plus the replacement token sampled at the first rejection, and start the next round from there.
The beauty? The final output distribution is mathematically identical to what the target model would produce alone. No quality loss — just speed.
The Acceptance Criterion
For each position i, let q(x) and p(x) be the probabilities that the draft and target models, respectively, assign to token x at that position. We accept the draft token x with probability:
accept_prob = min(1, p(x) / q(x))
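If rejected, we discard x and all later draft tokens in the round. We then sample a replacement from the residual distribution max(0, p(x) − q(x)), renormalized to sum to 1. The acceptance test and this resampling step together guarantee that every emitted token follows the target distribution p exactly.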
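A quick way to convince yourself that the criterion is exact is to simulate it on toy distributions. The sketch below is plain Python, and the three-token distributions p and q are made up for illustration. It proposes a token from q, applies the min(1, p(x)/q(x)) test, and falls back to the residual on rejection. It then counts how often each token comes out. The empirical frequencies should approach p, not q.

```python
import random
from collections import Counter


def speculative_sample(p, q, rng):
    """Emit one token: propose from q, accept with min(1, p/q), else resample from the residual."""
    tokens = list(range(len(p)))
    x = rng.choices(tokens, weights=q)[0]            # draft proposal x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):          # acceptance test
        return x
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(tokens, weights=residual)[0]   # choices() normalizes the weights


p = [0.5, 0.3, 0.2]   # toy target distribution
q = [0.2, 0.3, 0.5]   # toy draft distribution
rng = random.Random(0)
n = 100_000
counts = Counter(speculative_sample(p, q, rng) for _ in range(n))
empirical = [counts[t] / n for t in range(len(p))]
print(empirical)  # close to p, not q
```

With these numbers, token 0 is always accepted (p/q = 2.5) and token 2 is accepted 40% of the time. The 30% of probability mass lost to rejections is routed entirely to token 0 by the residual, which restores p(0) = 0.5.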
Implementation in Python with Hugging Face
Let’s build a working implementation. We’ll use a small model (Llama-3.2-1B-Instruct) as the drafter and a larger model (Llama-3.1-8B-Instruct) as the target. The pairing matters. The acceptance test compares p and q entry by entry, so both models must share a tokenizer and vocabulary. A drafter such as TinyLlama uses the 32,000-token Llama 2 vocabulary, so it cannot be paired directly with Llama 3.1 and its 128,256-token vocabulary.
Step 1: Setup
```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models. The drafter must share the target's tokenizer/vocabulary,
# so we use a smaller member of the same Llama 3 family.
draft_name = "meta-llama/Llama-3.2-1B-Instruct"
target_name = "meta-llama/Llama-3.1-8B-Instruct"

draft_tok = AutoTokenizer.from_pretrained(draft_name)
draft_tok.pad_token = draft_tok.eos_token
target_tok = AutoTokenizer.from_pretrained(target_name)
target_tok.pad_token = target_tok.eos_token

draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name, torch_dtype=torch.float16, device_map=device
)
target_model = AutoModelForCausalLM.from_pretrained(
    target_name, torch_dtype=torch.float16, device_map=device
)
draft_model.eval()
target_model.eval()
```
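The acceptance test compares p and q entry by entry. A mismatched vocabulary therefore either crashes (the tensor shapes differ) or, worse, silently compares unrelated tokens. A small guard catches this before any generation happens. The helper below, check_shared_vocab, is hypothetical and relies only on the tokenizers' get_vocab() method.

```python
def check_shared_vocab(draft_tok, target_tok):
    """Raise ValueError unless both tokenizers map every token string to the same ID.

    Hypothetical helper: it only uses get_vocab(), which Hugging Face tokenizers provide.
    """
    draft_vocab = draft_tok.get_vocab()
    target_vocab = target_tok.get_vocab()
    if draft_vocab == target_vocab:
        return
    # Summarize how the vocabularies differ to make the error actionable.
    missing = len(target_vocab.keys() - draft_vocab.keys())
    remapped = sum(
        1 for tok, idx in draft_vocab.items()
        if tok in target_vocab and target_vocab[tok] != idx
    )
    raise ValueError(
        f"Incompatible vocabularies: {len(draft_vocab)} vs {len(target_vocab)} tokens, "
        f"{missing} target tokens missing from the drafter, {remapped} mapped to different IDs"
    )
```

Calling check_shared_vocab(draft_tok, target_tok) right after loading is meant to pass for a same-family pair and fail loudly for a pairing such as TinyLlama with Llama 3.1. A model's logit dimension (config.vocab_size) can exceed the tokenizer's size because of embedding padding. Comparing draft_model.config.vocab_size with target_model.config.vocab_size is therefore a sensible second check.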
Step 2: Core Speculative Decoding Loop
```python
def speculative_decode(input_ids, draft_model, target_model,
                       draft_tok, target_tok, gamma=5,
                       max_new_tokens=100, temperature=0.7):
    """
    Speculative decoding with exact distribution matching.

    Args:
        input_ids: Initial token IDs, shape (1, seq_len)
        gamma: Number of draft tokens to generate per round (>= 1)
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature
    Returns:
        Generated token IDs (prompt excluded)
    """
    current_ids = input_ids.clone()
    prompt_len = input_ids.size(1)

    while current_ids.size(1) - prompt_len < max_new_tokens:
        # --- DRAFT PHASE ---
        draft_tokens = []
        draft_probs = []
        draft_input = current_ids
        for _ in range(gamma):
            with torch.no_grad():
                out = draft_model(draft_input)
            logits = out.logits[:, -1, :].float() / temperature
            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            draft_tokens.append(token.item())
            draft_probs.append(probs.squeeze(0))
            draft_input = torch.cat([draft_input, token], dim=1)
            if token.item() == draft_tok.eos_token_id:
                break
        k = len(draft_tokens)

        # --- VERIFICATION PHASE ---
        # One target forward pass over prefix + drafts scores every position.
        with torch.no_grad():
            target_logits = target_model(draft_input).logits
        # Logits at index j predict token j + 1, so draft token i is scored at
        # prefix_len - 1 + i. Fix prefix_len before appending anything.
        prefix_len = current_ids.size(1)
        new_tokens = []
        for i in range(k):
            q = draft_probs[i]  # draft distribution
            p = F.softmax(target_logits[0, prefix_len - 1 + i].float() / temperature, dim=-1)
            x = draft_tokens[i]
            accept_prob = min(1.0, (p[x] / q[x]).item())
            if torch.rand(1).item() < accept_prob:
                new_tokens.append(x)
            else:
                # Rejection: resample from the normalized residual max(0, p - q)
                residual = torch.clamp(p - q, min=0)
                dist = residual / residual.sum() if residual.sum() > 0 else p
                new_tokens.append(torch.multinomial(dist, num_samples=1).item())
                break  # discard the remaining drafts
        else:
            # All k drafts accepted: sample a free bonus token from the target's
            # prediction after the last draft (unless that draft was EOS).
            if draft_tokens[-1] != target_tok.eos_token_id:
                p = F.softmax(target_logits[0, -1].float() / temperature, dim=-1)
                new_tokens.append(torch.multinomial(p, num_samples=1).item())

        current_ids = torch.cat([
            current_ids,
            torch.tensor([new_tokens], device=current_ids.device)
        ], dim=1)
        if target_tok.eos_token_id in new_tokens:
            break

    return current_ids[:, prompt_len:prompt_len + max_new_tokens]
```
Step 3: Running the Generator
```python
import time

prompt = "Explain how transformers use attention mechanisms"
input_ids = target_tok(prompt, return_tensors="pt").input_ids.to(device)

# Standard autoregressive decoding (baseline)
start = time.time()
std_output = target_model.generate(
    input_ids, max_new_tokens=50, do_sample=True, temperature=0.7
)
std_time = time.time() - start
print(f"Standard: {std_time:.2f}s")

# Speculative decoding
start = time.time()
spec_output = speculative_decode(
    input_ids, draft_model, target_model,
    draft_tok, target_tok, gamma=5,
    max_new_tokens=50, temperature=0.7
)
spec_time = time.time() - start
print(f"Speculative: {spec_time:.2f}s")

speedup = std_time / spec_time
print(f"Speedup: {speedup:.2f}x")
print(target_tok.decode(spec_output[0], skip_special_tokens=True))
```
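A single wall-clock measurement like the one above is noisy. The first CUDA calls include warm-up work, and sampling makes run lengths vary (the speculative loop may stop early at EOS). The harness below is a slightly fairer sketch. It assumes each call returns the number of new tokens it produced. It warms up first, repeats the run, and reports the median tokens per second. The helper name tokens_per_second is hypothetical. On a GPU, the wrapped function should end with torch.cuda.synchronize() so the timer only stops once the work has finished.

```python
import statistics
import time


def tokens_per_second(generate_fn, warmup=1, repeats=3):
    """Median throughput of generate_fn, which must return the number of new tokens produced."""
    for _ in range(warmup):
        generate_fn()                      # discarded: first calls pay one-off setup costs
    rates = []
    for _ in range(repeats):
        start = time.perf_counter()
        n_tokens = generate_fn()
        elapsed = time.perf_counter() - start
        rates.append(n_tokens / elapsed)
    return statistics.median(rates)
```

For example, the speculative run could be wrapped as `lambda: speculative_decode(input_ids, draft_model, target_model, draft_tok, target_tok, gamma=5, max_new_tokens=50).size(1)`. The baseline could be wrapped as `lambda: target_model.generate(input_ids, max_new_tokens=50, do_sample=True, temperature=0.7).size(1) - input_ids.size(1)`. Comparing tokens per second rather than raw seconds keeps early EOS stops from skewing the result.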
Optimizations That Matter
1. KV-Cache for the Target Model
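The naive implementation above re-runs the target model over the entire sequence every round. It also re-runs the draft model over the entire sequence at every draft step. Because generate() keeps a KV-cache by default, the Step 3 comparison is tilted toward the baseline, and the naive loop may not beat it at all. In production, maintain a KV-cache to avoid the redundant computation: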
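```python
# Keep the target's attention states between rounds. Once a cache exists,
# feed only the tokens it has not seen yet (new_ids is a hypothetical tensor
# holding exactly those tokens).
target_out = target_model(
    new_ids,
    past_key_values=past_kv,
    use_cache=True,
)
logits = target_out.logits
past_kv = target_out.past_key_values
# After a rejection, the cache also holds entries for discarded draft tokens.
# Truncate it to the accepted length before the next round (recent transformers
# versions provide DynamicCache.crop for this).
```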
2. Draft Model Selection Strategies
The choice of draft model dramatically affects speedup. Practical options:
- Smaller model from the same family: Llama-3.2-1B → Llama-3.1-8B (shared tokenizer; usually the closest match)
- Early layers of the target model: use its first N layers as a drafter (no separate model to load)
- n-gram matching: reuse n-grams that already appear in the context (e.g., the prompt) as a free drafter (no model needed)
- Medusa heads: train lightweight extra decoding heads on top of the target model, each predicting a token further ahead
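The n-gram idea is easy to sketch. The hypothetical ngram_draft below implements one common variant, often called prompt lookup. It finds the most recent earlier occurrence of the last few tokens and proposes whatever followed it. Such a drafter is deterministic, so q is 1 for the proposed token and 0 elsewhere. Plugged into the acceptance rule, a draft token x is then accepted with probability p(x). On rejection, the residual is simply p with x removed and renormalized. Hugging Face's generate() offers a built-in drafter of this kind through its prompt_lookup_num_tokens argument.

```python
def ngram_draft(tokens, ngram_size=3, max_draft=5):
    """Propose up to max_draft tokens by matching the trailing n-gram against earlier context.

    Hypothetical helper: works on any list of token IDs (or strings).
    """
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Scan right to left so the most recent earlier match wins;
    # the trailing n-gram itself (start = len - ngram_size) is excluded.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == tail:
            return tokens[start + ngram_size:start + ngram_size + max_draft]
    return []


words = "the cat sat on the mat and the cat sat".split()
print(ngram_draft(words, ngram_size=2, max_draft=5))  # ['on', 'the', 'mat', 'and', 'the']
```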
3. Adaptive Gamma
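Instead of a fixed γ, adapt it dynamically based on the recent acceptance rate (the fraction of drafted tokens the target accepted over the last few rounds):
```python
# acceptance_rate: fraction of recently drafted tokens that were accepted
if acceptance_rate > 0.8:
    gamma = min(gamma + 1, max_gamma)   # drafts mostly right: speculate further
elif acceptance_rate < 0.4:
    gamma = max(gamma - 1, 1)           # drafts mostly wrong: waste less work
```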
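The rule above leaves "recent acceptance rate" open. One option is an exponential moving average of each round's accepted/proposed ratio, sketched below as a hypothetical AdaptiveGamma helper. It reacts to shifts in the text (say, from prose into code) without jumping on a single lucky or unlucky round. The 0.8/0.4 thresholds come from the rule above; the smoothing weight of 0.3 is an arbitrary choice.

```python
class AdaptiveGamma:
    """Adjust the draft length from a smoothed acceptance rate (hypothetical helper)."""

    def __init__(self, gamma=5, max_gamma=10, smoothing=0.3):
        self.gamma = gamma
        self.max_gamma = max_gamma
        self.smoothing = smoothing   # weight given to the newest round
        self.rate = None             # exponential moving average of acceptance

    def update(self, accepted, proposed):
        """Record one round (accepted out of proposed drafts) and return the next gamma."""
        if proposed <= 0:
            return self.gamma
        round_rate = accepted / proposed
        if self.rate is None:
            self.rate = round_rate
        else:
            self.rate = self.smoothing * round_rate + (1 - self.smoothing) * self.rate
        if self.rate > 0.8:
            self.gamma = min(self.gamma + 1, self.max_gamma)
        elif self.rate < 0.4:
            self.gamma = max(self.gamma - 1, 1)
        return self.gamma


ctrl = AdaptiveGamma(gamma=5)
rounds = [(5, 5), (6, 6), (0, 7), (0, 7), (0, 7)]
history = [ctrl.update(acc, prop) for acc, prop in rounds]
print(history)  # [6, 7, 7, 7, 6]
```

The smoothing is why γ only shrinks on the third consecutive bad round: the average has to fall from 1.0 through roughly 0.7 and 0.49 before it drops below 0.4.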
Expected Performance
Indicative figures for a well-matched draft/target pair on a GPU (actual results depend on the hardware, γ, and the implementation):
| Scenario | Typical Speedup | Acceptance Rate |
|---|---|---|
| Code generation | 2.0x – 2.8x | 70–85% |
| Creative writing | 1.5x – 2.2x | 50–70% |
| Math/STEM | 1.3x – 1.8x | 40–60% |
| n-gram only (no draft model) | 1.2x – 1.5x | 30–50% |
Code generation sees the highest speedups because code is highly predictable — the draft model gets most tokens right.
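The table's link between acceptance rate and speedup can be made concrete with the idealized analysis from the original speculative decoding paper (Leviathan et al., 2023). Assume each draft token is accepted independently with probability α, and that a round accepting all γ drafts also yields one bonus token from the target. A round then emits (1 - α^(γ+1)) / (1 - α) tokens on average. Each round costs γ·c + 1 target-forward equivalents, where c is the draft-to-target cost ratio, and dividing by that cost gives the expected walltime speedup. The sketch below evaluates both. Treat its outputs as rough idealized estimates, since real acceptance is not independent across positions and verifying γ tokens is not entirely free.

```python
def expected_tokens_per_round(alpha, gamma):
    """Mean tokens emitted per target forward pass when each draft is accepted
    independently with probability alpha (a full accept adds one bonus token)."""
    if alpha >= 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha, gamma, cost_ratio):
    """Idealized walltime speedup; cost_ratio = draft forward time / target forward time."""
    return expected_tokens_per_round(alpha, gamma) / (gamma * cost_ratio + 1)


for alpha in (0.5, 0.7, 0.85):
    tokens = expected_tokens_per_round(alpha, gamma=5)
    speedup = expected_speedup(alpha, gamma=5, cost_ratio=0.1)
    print(f"alpha={alpha}: {tokens:.2f} tokens/round, ~{speedup:.2f}x")
```

Take γ = 5 and a drafter costing a tenth of the target. Then α = 0.7 works out to about 2.9 tokens per round and roughly a 2x speedup, while α = 0.5 gives only about 1.3x.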
Production-Ready Alternatives
If you don’t want to implement this yourself, several frameworks support speculative decoding out of the box:
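```python
# Hugging Face transformers: passing assistant_model enables assisted generation
output = target_model.generate(
    input_ids,
    assistant_model=draft_model,  # triggers speculative decoding
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
)
# Recent versions (4.46+) also accept a drafter with a different tokenizer
# if you pass tokenizer=target_tok and assistant_tokenizer=draft_tok as well.

# vLLM (high-throughput serving): configure the drafter at server launch.
# Older releases: --speculative-model [draft_path] --num-speculative-tokens 5
# Newer releases replace these with a JSON --speculative-config option;
# check the docs for your version.

# TensorRT-LLM: the speculative decoding mode (draft/target, Medusa, EAGLE,
# lookahead, ...) is chosen when the engine is built; see its speculative
# decoding documentation for the build options.
```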
When NOT to Use Speculative Decoding
Speculative decoding isn’t always the right choice:
- Small models (<3B): The overhead often outweighs the benefit. Speculative decoding shines when the target model is large enough that its forward pass is the bottleneck.
- Batch size > 1: Acceptance rates vary across sequences, which makes batched speculative decoding complex. At large batch sizes decoding also becomes compute-bound, so verifying extra tokens is no longer nearly free. vLLM handles the bookkeeping with per-sequence speculation.
- Memory-constrained: You need to load two models. If you can’t fit both on your GPU, consider Medusa heads or n-gram drafting instead.
Key Takeaways
- Speculative decoding gives exact target model distribution — zero quality tradeoff.
- Speedup depends on draft model quality and task predictability. Code = fast, creative = moderate.
- KV-caching is essential for production performance; adaptive γ is a useful refinement on top.
- Frameworks like HF Transformers, vLLM, and TensorRT-LLM support it natively — no custom code needed in many cases.
Speculative decoding is one of those rare techniques that comes close to a free lunch. You get faster inference with the same output distribution, paid for with a second model’s memory and some extra compute. If you’re serving large LLMs at low batch sizes and not using it, you may be leaving a 2–3x latency improvement on the table.