Speculative Decoding for LLM Inference: 2x Speed Without Quality Loss
Speculative decoding is one of the most impactful inference optimization techniques of 2025. It lets you generate text 2-3x faster without any quality degradation, using a small "draft" model to predict tokens that a larger "target" model then verifies. Here's how to implement it from scratch.
The Bottleneck: Autoregressive Generation
Large language models generate text one token at a time. Each token requires a full forward pass through the entire network, followed by sampling from the output distribution. This autoregressive loop is memory-bound — most of the time is spent loading model weights from GPU memory, not computing.
Prompt → [Forward Pass → Sample t₁] → [Forward Pass → Sample t₂] → [Forward Pass → Sample t₃] → ...
← 1 token per forward pass →
For a 70B parameter model, each forward pass takes tens of milliseconds. Generating 500 tokens means 500 sequential forward passes. Because every pass must finish before the next can begin, the GPU's compute units sit mostly idle while weights stream in from memory. This is fundamentally a latency problem, not a throughput problem.
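A back-of-the-envelope estimate makes the memory-bound claim concrete. The sketch below assumes each forward pass reads every weight once and ignores compute, KV cache traffic, and communication overhead. The 2000 GB/s aggregate bandwidth is an assumed figure for illustration, and min_decode_latency_ms is a hypothetical helper name.

```python
def min_decode_latency_ms(n_params: float, bytes_per_param: float,
                          bandwidth_gb_s: float) -> float:
    """Rough lower bound on per-token latency for memory-bound decoding.

    Assumes every weight is read from GPU memory exactly once per pass.
    """
    weight_bytes = n_params * bytes_per_param
    seconds = weight_bytes / (bandwidth_gb_s * 1e9)
    return seconds * 1e3


# 70B parameters in fp16 (2 bytes each).
# 2000 GB/s is an assumed aggregate memory bandwidth, not a measured value.
per_token = min_decode_latency_ms(70e9, 2, 2000)
print(f"{per_token:.0f} ms per token, "
      f"{per_token * 500 / 1000:.0f} s for 500 tokens")
```

Under these assumptions the bound comes out at roughly 70 ms per token, consistent with the "tens of milliseconds" figure above, and it scales with model size rather than with the amount of arithmetic per token.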
The Core Idea: Draft + Verify
Speculative decoding introduces a second, smaller model (e.g., a 1B-parameter model) that drafts several tokens ahead. The draft model still generates one token at a time, but each of its forward passes is cheap. The large target model then verifies all drafted tokens in a single forward pass:
Draft Model (small, fast): [Forward → t₁] [Forward → t₂] [Forward → t₃] [Forward → t₄] ← sequential but cheap
Target Model (large, slow): [Single Forward Pass → Verify ALL 4 tokens simultaneously] ← one pass
Result: 4 tokens generated with ~1 large forward pass instead of ~4
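The target model computes the probability distribution over its vocabulary at each draft token position. Each draft token is then accepted with a probability based on the ratio of the target's and the draft's probabilities for that token (defined below); under greedy decoding this reduces to checking that the draft token matches the target's top choice. At the first rejection, the target model samples a replacement token from a corrected distribution, the remaining draft tokens are discarded, and the next round starts from that point.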
Key Insight
The target model's forward pass is the expensive operation. By verifying K draft tokens in a single pass, we amortize that cost across multiple tokens. Even if some tokens are rejected, the average speedup is significant when the draft model is reasonably accurate.
Mathematical Foundation
Let p(x|context) be the target model's distribution and q(x|context) be the draft model's distribution. The acceptance probability for a draft token x is:
accept_prob(x) = min(1, p(x|context) / q(x|context))
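This acceptance rule, combined with resampling a rejected position from the normalized residual max(0, p - q), ensures the final output distribution exactly matches the target model's distribution. If the target model assigns at least as much probability to a token as the draft model does (p ≥ q), the token is always accepted. If the draft model is overconfident (q > p), the token is accepted only with probability p/q, and on rejection the target model's corrected sample replaces it.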
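The exactness claim can be checked by hand on a toy vocabulary. The sketch below computes, in closed form, the distribution of the token emitted by one accept-or-resample step, assuming a rejected token is resampled from the normalized residual max(0, p - q). The function name and the three-token distributions are illustrative only.

```python
def speculative_output_distribution(p, q):
    """Exact distribution of the token emitted by one accept/resample step.

    p: target probabilities, q: draft probabilities (same vocabulary).
    """
    # P(draft proposes x and x is accepted) = q(x) * min(1, p(x)/q(x))
    #                                       = min(p(x), q(x))
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accepted)

    # On rejection, resample from the normalized residual max(0, p - q)
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0.0:  # p == q everywhere, so nothing is ever rejected
        return accepted
    return [a + reject_mass * r / total for a, r in zip(accepted, residual)]


p = [0.5, 0.3, 0.2]  # toy target distribution
q = [0.2, 0.7, 0.1]  # toy draft distribution, overconfident on token 1
out = speculative_output_distribution(p, q)
acceptance_rate = sum(min(pi, qi) for pi, qi in zip(p, q))
print(out, acceptance_rate)
```

The emitted distribution equals p even though q is a poor match. The draft's quality only affects the acceptance rate, which here is the overlap sum of min(p, q) over the vocabulary.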
The expected acceptance rate depends on how well the draft model approximates the target. In practice, for a 1B draft and 70B target on common text, acceptance rates of 60-80% are typical.
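To relate acceptance rate to speedup, a common simplification treats each draft token as accepted independently with probability alpha. Under that assumption, and assuming every round also emits one extra token from the target (the replacement on rejection, or a bonus token when all K drafts are accepted), the expected tokens per round is (1 - alpha^(K+1)) / (1 - alpha). The sketch below uses that formula; cost_ratio (draft pass cost relative to a target pass) and the value 0.1 are assumptions for illustration.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens emitted per target pass, assuming i.i.d. acceptance."""
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def estimated_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Speedup over plain decoding.

    One round costs k draft passes plus one target pass, measured in
    units of a single target pass.
    """
    return expected_tokens_per_round(alpha, k) / (k * cost_ratio + 1)


alpha = 0.7       # within the 60-80% range quoted above
cost_ratio = 0.1  # assumed: a draft pass costs 10% of a target pass
for k in range(1, 9):
    print(k, round(estimated_speedup(alpha, k, cost_ratio), 3))

best_k = max(range(1, 9), key=lambda k: estimated_speedup(alpha, k, cost_ratio))
```

The estimate rises with K and then falls, because long drafts are rarely accepted in full while their cost grows linearly. That is why K is worth tuning against a measured acceptance rate.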
Implementation: Full PyTorch Code
Step 1: Setup Models
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load target model (large, high quality)
target_name = "meta-llama/Llama-3.1-70B-Instruct"
target_model = AutoModelForCausalLM.from_pretrained(
    target_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Load draft model (small, fast)
draft_name = "meta-llama/Llama-3.2-1B-Instruct"
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Both models must share the same vocabulary, so one tokenizer serves both
tokenizer = AutoTokenizer.from_pretrained(target_name)
# Configuration
K = 4  # Number of draft tokens per speculation round
Step 2: Draft Token Generation
@torch.no_grad()
def generate_draft_tokens(draft_model, input_ids, K=4):
    """
    Generate K speculative tokens using the draft model.
    Returns: (draft_token_ids, draft_logits)
    """
    draft_ids = []
    draft_logits = []
    current_ids = input_ids.clone()
    for _ in range(K):
        # No KV cache here for simplicity: each step re-encodes the sequence
        outputs = draft_model(current_ids)
        logits = outputs.logits[:, -1, :]
        # Sample from the draft distribution q. The acceptance test in Step 3
        # assumes tokens were drawn from softmax(logits); taking the argmax
        # here would break the exact-distribution guarantee.
        probs = torch.softmax(logits.float(), dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        draft_ids.append(next_token)
        draft_logits.append(logits)
        current_ids = torch.cat([current_ids, next_token], dim=-1)
    return torch.cat(draft_ids, dim=-1), torch.stack(draft_logits, dim=1)
Step 3: Verification with Rejection Sampling
@torch.no_grad()
def verify_and_accept(
    target_model, input_ids, draft_tokens, draft_logits
):
    """
    Verify draft tokens using the target model (batch size 1).
    Implements the speculative decoding acceptance criterion.
    Returns: (tokens to append, number of draft tokens accepted)
    """
    # Run target model once on input + draft tokens
    extended_ids = torch.cat([input_ids, draft_tokens], dim=-1)
    target_outputs = target_model(extended_ids)
    # Target logits that predict each draft position, plus the position
    # after the last draft token (K + 1 positions in total)
    target_logits = target_outputs.logits[:, input_ids.size(1) - 1 :, :]
    accepted_count = 0
    accepted_tokens = []
    for i in range(draft_tokens.size(-1)):
        draft_token = draft_tokens[0, i].item()
        # Compute probabilities in float32, on the same device
        target_probs = torch.softmax(target_logits[0, i, :].float(), dim=-1)
        draft_probs = torch.softmax(
            draft_logits[0, i, :].float(), dim=-1
        ).to(target_probs.device)
        # Acceptance probability: min(1, p(x) / q(x))
        accept_prob = min(
            1.0,
            target_probs[draft_token].item() /
            (draft_probs[draft_token].item() + 1e-10)
        )
        if torch.rand(1).item() < accept_prob:
            # Token accepted
            accepted_tokens.append(draft_token)
            accepted_count += 1
        else:
            # Token rejected: sample from the residual max(0, p - q)
            corrected_probs = torch.clamp(
                target_probs - draft_probs, min=0
            )
            corrected_probs /= corrected_probs.sum() + 1e-10
            replacement = torch.multinomial(corrected_probs, 1).item()
            accepted_tokens.append(replacement)
            break  # Speculation chain broken
    else:
        # All K draft tokens accepted: the same pass already produced the
        # target distribution for the next position, so sample a bonus token
        bonus_probs = torch.softmax(target_logits[0, -1, :].float(), dim=-1)
        accepted_tokens.append(torch.multinomial(bonus_probs, 1).item())
    return accepted_tokens, accepted_count
Step 4: The Speculative Decoding Loop
@torch.no_grad()
def speculative_decode(target_model, draft_model, tokenizer,
                       prompt, max_tokens=200, K=4):
    """
    Full speculative decoding loop.
    Each round costs one target forward pass and yields one or more
    tokens, depending on how many draft tokens are accepted.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(target_model.device)
    generated_tokens = []
    total_draft = 0
    total_accepted = 0
    rounds = 0  # Number of target forward passes
    while len(generated_tokens) < max_tokens:
        # Phase 1: Draft K tokens
        draft_tokens, draft_logits = generate_draft_tokens(
            draft_model, input_ids, K
        )
        total_draft += K
        # Phase 2: Verify with target model
        accepted, count = verify_and_accept(
            target_model, input_ids, draft_tokens, draft_logits
        )
        total_accepted += count
        rounds += 1
        # Append accepted tokens
        generated_tokens.extend(accepted)
        # Update input_ids with accepted tokens
        accepted_tensor = torch.tensor(
            [accepted], device=input_ids.device
        )
        input_ids = torch.cat([input_ids, accepted_tensor], dim=-1)
        # Early stopping on EOS
        if tokenizer.eos_token_id in accepted:
            break
    # A round can overshoot the budget, so trim to max_tokens
    generated_tokens = generated_tokens[:max_tokens]
    # Report efficiency
    acceptance_rate = total_accepted / max(total_draft, 1) * 100
    tokens_per_pass = len(generated_tokens) / max(rounds, 1)
    print(f"Acceptance rate: {acceptance_rate:.1f}%")
    print(f"Tokens drafted: {total_draft}")
    print(f"Tokens accepted: {total_accepted}")
    print(f"Tokens per target forward pass: {tokens_per_pass:.2f}")
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
Usage
prompt = "Explain how attention mechanisms work in transformers"
output = speculative_decode(
    target_model, draft_model, tokenizer,
    prompt, max_tokens=300, K=5
)
print(output)
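To see the draft-then-verify loop without loading any weights, here is a toy sketch of the greedy variant, where a draft token is accepted only if it equals the target's top choice. This is a simplification of the rejection-sampling rule used above, and toy_target_next and toy_draft_next are hypothetical stand-ins for real models: simple deterministic rules over a vocabulary of 11 token ids.

```python
from typing import List, Tuple


def toy_target_next(ctx: List[int]) -> int:
    """Stand-in for the large model's greedy next token."""
    return (ctx[-1] * 3 + len(ctx)) % 11


def toy_draft_next(ctx: List[int]) -> int:
    """Stand-in for the draft model: wrong at every fourth position."""
    guess = toy_target_next(ctx)
    return guess if len(ctx) % 4 else (guess + 1) % 11


def greedy_decode(prompt: List[int], max_tokens: int) -> List[int]:
    """Baseline: one target call per generated token."""
    ctx = list(prompt)
    for _ in range(max_tokens):
        ctx.append(toy_target_next(ctx))
    return ctx[len(prompt):]


def greedy_speculative_decode(prompt: List[int], max_tokens: int,
                              k: int) -> Tuple[List[int], int]:
    ctx = list(prompt)
    target_passes = 0
    while len(ctx) - len(prompt) < max_tokens:
        # Draft k tokens one after another with the cheap model
        draft: List[int] = []
        for _ in range(k):
            draft.append(toy_draft_next(ctx + draft))
        # One target "pass" scores all k + 1 positions at once
        target_passes += 1
        verified = [toy_target_next(ctx + draft[:i]) for i in range(k + 1)]
        # Accept the longest matching prefix, then take the target's own
        # token at the first mismatch (or the bonus position)
        n = 0
        while n < k and draft[n] == verified[n]:
            n += 1
        ctx.extend(draft[:n] + [verified[n]])
    return ctx[len(prompt):len(prompt) + max_tokens], target_passes


baseline = greedy_decode([1, 2, 3], 20)
spec, passes = greedy_speculative_decode([1, 2, 3], 20, k=4)
print(spec == baseline, passes)
```

The speculative output is token-for-token identical to the baseline, while the number of target passes drops well below the number of generated tokens. The gap depends entirely on how often the toy draft agrees with the toy target.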
Performance Benchmarks
| Configuration | Target Model | Draft Model | Speedup |
|---|---|---|---|
| 70B + 1B (K=4) | Llama-3.1-70B | Llama-3.2-1B | 1.8x - 2.3x |
| 13B + 0.3B (K=5) | Llama-3.1-13B | Llama-3.2-0.3B | 2.0x - 2.8x |
| 7B + 1.1B (K=6) | Llama-3.1-7B | TinyLlama-1.1B | 1.5x - 2.1x |
Advanced Optimizations
1. Medusa Heads
Instead of a separate draft model, attach lightweight "Medusa heads" (small MLP layers) to the target model itself. Each head predicts the token at position t+i given the target model's hidden states. This eliminates the need for a second model and reduces memory overhead:
import torch.nn as nn

class MedusaHead(nn.Module):
    """Single-step lookahead prediction head."""
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, vocab_size)
        )
    def forward(self, hidden_states):
        # Predict from the hidden state at the last sequence position
        return self.head(hidden_states[:, -1, :])
# Attach multiple heads for multi-step lookahead.
# 4096 and 32000 are example sizes; in practice use the target model's
# config.hidden_size and config.vocab_size.
medusa_heads = nn.ModuleList([
    MedusaHead(4096, 32000) for _ in range(5)  # 5 steps ahead
])
2. EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)
EAGLE improves on Medusa by drafting at the feature level. A lightweight autoregressive head takes the target model's hidden states (not just its output tokens) and extrapolates the next hidden state, from which the next draft token is decoded. Because every candidate is still verified by the target model, the output distribution is preserved, with reported speedups of roughly 2.5-3x.
3. KV Cache Reuse
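When draft tokens are accepted, their KV cache entries have already been computed by the target model's verification pass. The next verification round can therefore start with a warm target cache and process only the new tokens. Entries for rejected draft positions must be dropped first, and the draft model keeps its own separate cache, since the two models' caches are not interchangeable:
# After verification, keep the target model's KV cache
past_key_values = target_outputs.past_key_values
# Drop entries for rejected draft positions. num_valid_tokens (prompt plus
# accepted draft tokens) is an illustrative name; crop() assumes the cache
# is a transformers DynamicCache
past_key_values.crop(num_valid_tokens)
# Next verification round reuses this cache and feeds only uncached tokens
next_outputs = target_model(
    new_input_ids,
    past_key_values=past_key_values,
    use_cache=True
)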
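The bookkeeping around cache reuse is easy to get wrong, so here is a small sketch of the arithmetic only, with no real cache involved. It assumes the verification pass processed the prefix plus K draft tokens, and that each round emits exactly one extra token from the target (a replacement or a bonus token) that has not yet been through the target model. plan_cache_after_round is a hypothetical helper, not a library function.

```python
def plan_cache_after_round(prefix_len: int, k: int, num_accepted: int) -> dict:
    """Work out which target-cache entries stay valid after one round."""
    if not 0 <= num_accepted <= k:
        raise ValueError("num_accepted must be between 0 and k")
    cached = prefix_len + k           # entries written by the verification pass
    keep = prefix_len + num_accepted  # prefix plus accepted draft tokens
    return {
        "crop_to": keep,              # cache length to keep
        "discard": cached - keep,     # entries for rejected draft positions
        "uncached_emitted_tokens": 1, # replacement/bonus token, fed next round
    }


# Prefix of 10 tokens, 4 drafted, 2 accepted before the first rejection
plan = plan_cache_after_round(prefix_len=10, k=4, num_accepted=2)
print(plan)
```

In the next round the target model would then receive that one uncached token followed by the new draft tokens, rather than the full sequence.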
When to Use Speculative Decoding
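- Best case: Latency-sensitive serving at small batch sizes (interactive chat, code completion), where decoding is memory-bound and the GPU has spare compute for verification
- Good case: Summarization, translation, and creative text generation, provided the draft model tracks the target closely enough to keep acceptance rates high
- Less effective: Large-batch, high-throughput serving, where the GPU is already compute-bound and verification competes for the same resources; and outputs the draft model predicts poorly, such as strict formats (JSON, SQL) it has not learned or multi-step mathematical reasoning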
Real-World Framework Support
You don't always need to implement this from scratch. Major frameworks now support speculative decoding natively:
- Transformers: model.generate(assistant_model=draft_model), with the draft length controlled by num_assistant_tokens
- vLLM: Built-in speculative decoding with the --speculative-model flag (recent releases configure it through --speculative-config)
- TGI (Text Generation Inference): Native support via the --speculate flag
- llama.cpp: Supports speculative decoding with GGUF models via the -md (--model-draft) flag
- TensorRT-LLM: GPU-optimized speculative decoding with batch support
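# HuggingFace Transformers: built-in assisted generation
# The draft length is read from the assistant model's generation config
draft_model.generation_config.num_assistant_tokens = 4
output = target_model.generate(
    input_ids,
    assistant_model=draft_model,
    max_new_tokens=200
)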
Conclusion
Speculative decoding is a rare case where you get close to something for nothing: faster inference with the same output distribution, at the cost of hosting a small draft model. The technique is mathematically sound (exact distribution matching), practically effective (2-3x speedup), and increasingly supported by production frameworks.
The key insight is simple: use a cheap model to guess, and the expensive model to check. When the guesses are good (and they often are), you save massive amounts of computation. As models grow larger and inference costs rise, speculative decoding transitions from "nice optimization" to "essential infrastructure."
Start with the HuggingFace one-liner, profile your acceptance rate, tune K, and if you need more performance, implement Medusa heads or EAGLE. The investment pays for itself in GPU-hours saved.