Speculative Decoding: Speed Up LLM Inference by 2-4× Without Sacrificing Quality
If you have ever run a large language model locally, you know the pain: tokens trickle out one by one, each requiring a full forward pass through billions of parameters. For single-stream generation the bottleneck is not compute but memory bandwidth. Every decoding step streams the entire set of weights from VRAM, so the GPU runs out of bandwidth long before it approaches its FLOPS ceiling.
Speculative decoding (also called assisted generation or draft-and-verify) is a lossless inference optimization that changes this cost structure. Instead of paying one full target-model forward pass per token, a lightweight draft model cheaply proposes several tokens, one after another. The large target model then verifies all of them in a single forward pass. The result is a speedup commonly reported in the 2-4× range, with no change to the output distribution.
This tutorial walks you through the algorithm, the intuition, and a working implementation.
Why Autoregressive Decoding Is Slow
Standard autoregressive generation looks like this:
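```python
import torch

for _ in range(max_tokens):
    logits = model(input_ids).logits                      # full forward pass: every weight is read
    probs = torch.softmax(logits[:, -1], dim=-1)           # distribution for the next position only
    next_token = torch.multinomial(probs, num_samples=1)   # pick one token
    input_ids = torch.cat([input_ids, next_token], dim=-1)
```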
Each iteration reads every weight matrix (attention, MLP, norms) from HBM. For a 70B model in FP16, that is about 140 GB of memory traffic per token. Even with 2 TB/s of memory bandwidth (A100-class), streaming 140 GB per step caps generation at roughly 14 tokens/sec. That ceiling does not yet include KV-cache reads, which grow with context length.
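The ~14 tokens/sec ceiling is simply bandwidth divided by the bytes read per step. Below is a minimal sketch of that back-of-the-envelope bound. It assumes every weight is read exactly once per step and ignores KV-cache traffic, activations, and kernel overheads. The helper name is illustrative, not a library function.

```python
def max_decode_tokens_per_sec(n_params: float, bytes_per_param: float,
                              bandwidth_bytes_per_sec: float) -> float:
    """Upper bound on single-stream decode speed when weight reads dominate.

    Assumes each step streams every weight from memory once; KV-cache reads,
    activations, and launch overheads are ignored, so real numbers are lower.
    """
    bytes_per_token = n_params * bytes_per_param
    return bandwidth_bytes_per_sec / bytes_per_token


# 70B parameters in FP16 (2 bytes each) on ~2 TB/s of HBM bandwidth
fp16 = max_decode_tokens_per_sec(70e9, 2, 2e12)
print(f"FP16: {fp16:.1f} tok/s")  # FP16: 14.3 tok/s

# Same model with 4-bit weights (0.5 bytes each): a quarter of the bytes per step
int4 = max_decode_tokens_per_sec(70e9, 0.5, 2e12)
print(f"INT4: {int4:.1f} tok/s")
```

Under the same assumption, 4-bit weights raise the ceiling fourfold. This is one reason a quantized copy can serve as a faster draft.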
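The key insight: a forward pass over N positions reads each weight once, just like a pass over a single position. So if we could guess the next N tokens, we could verify all of them with one forward pass instead of N.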
The Speculative Decoding Algorithm
The algorithm alternates between two phases:
Phase 1: Draft (Speculate)
A small, fast draft model generates γ candidate tokens autoregressively:
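```python
import torch

@torch.no_grad()
def draft_tokens(draft_model, input_ids, gamma, temperature=1.0):
    """Generate γ speculative tokens and keep each draft distribution for verification."""
    draft_ids, draft_probs = [], []
    for _ in range(gamma):
        # Recomputes the whole prefix each step (no KV cache) to keep the sketch short
        logits = draft_model(input_ids).logits[:, -1].float() / temperature
        probs = torch.softmax(logits, dim=-1)               # p_draft at this position
        next_id = torch.multinomial(probs, num_samples=1)    # shape (batch, 1)
        draft_ids.append(next_id)
        draft_probs.append(probs)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
    return torch.cat(draft_ids, dim=-1), draft_probs
```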
The draft model can be a smaller version of the target (e.g., Llama-7B drafting for Llama-70B), a quantized copy, or even a trained speculative head.
Phase 2: Verify (Correct)
The target model processes the entire sequence — original context plus all γ draft tokens — in one forward pass:
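```python
import torch

@torch.no_grad()
def verify_tokens(target_model, input_ids, draft_ids, draft_probs, temperature=1.0):
    """Verify draft tokens so the output follows the target distribution (batch size 1)."""
    gamma = draft_ids.shape[-1]
    # One forward pass over the original context plus all γ draft tokens
    logits = target_model(torch.cat([input_ids, draft_ids], dim=-1)).logits.float()
    # Logits at position j predict token j+1, so the last γ+1 positions give the target
    # distribution for each draft token plus one extra "bonus" position
    target_probs = torch.softmax(logits[0, -gamma - 1:] / temperature, dim=-1)
    accepted = []
    for i in range(gamma):
        tok = draft_ids[0, i]
        p, q = target_probs[i], draft_probs[i][0]
        # Acceptance test: min(1, p_target / p_draft)
        if torch.rand(()) < torch.clamp(p[tok] / q[tok], max=1.0):
            accepted.append(tok.item())
        else:
            # Reject: resample from the adjusted distribution max(0, p - q), renormalized
            residual = torch.clamp(p - q, min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            return accepted, i  # stop at first rejection; i draft tokens were kept
    # Every draft token was accepted: take one more token from the target for free
    accepted.append(torch.multinomial(target_probs[gamma], 1).item())
    return accepted, gamma
```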
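The two functions above need real models. Below is a dependency-free toy of one full draft-and-verify round. The "models" are hypothetical functions that map a token prefix to a probability dict over a tiny vocabulary. Everything else follows the accept/resample rule described above, plus the standard "bonus" token: when all γ drafts are accepted, one extra token is sampled from the target's distribution at the next position.

```python
import random
from typing import Callable, Dict, List, Sequence

Dist = Dict[str, float]
Model = Callable[[Sequence[str]], Dist]  # hypothetical: prefix -> next-token distribution


def sample(dist: Dist, rng: random.Random) -> str:
    tokens, weights = zip(*dist.items())
    return rng.choices(tokens, weights=weights, k=1)[0]


def speculative_step(target: Model, draft: Model, prefix: List[str],
                     gamma: int, rng: random.Random) -> List[str]:
    """One draft-then-verify round; returns the 1..gamma+1 new tokens to append."""
    # Phase 1: draft gamma tokens autoregressively, remembering q at each position
    ctx, drafted, qs = list(prefix), [], []
    for _ in range(gamma):
        q = draft(ctx)
        tok = sample(q, rng)
        drafted.append(tok)
        qs.append(q)
        ctx.append(tok)

    # Phase 2: target distributions at every drafted position plus one more
    # (a single batched forward pass in a real model; one call each here for clarity)
    ps = [target(list(prefix) + drafted[:i]) for i in range(gamma + 1)]

    out = []
    for i, tok in enumerate(drafted):
        p, q = ps[i], qs[i]
        if rng.random() < min(1.0, p.get(tok, 0.0) / q[tok]):
            out.append(tok)  # accepted
            continue
        # Rejected: resample from the renormalized residual max(0, p - q)
        residual = {t: max(0.0, p.get(t, 0.0) - q.get(t, 0.0)) for t in p}
        out.append(sample(residual, rng))
        return out
    out.append(sample(ps[gamma], rng))  # all accepted: bonus token from the target
    return out
```

If the draft agrees perfectly with the target, every round yields γ + 1 tokens. A draft that is always wrong still yields one correctly distributed token per target pass, so correctness never depends on draft quality; only speed does.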
The Acceptance Criterion
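The critical detail is the acceptance test. It is a modified rejection-sampling step whose ratio resembles a Metropolis-Hastings acceptance probability. Draft token $x_i$ is accepted with probability:
$$\min\left(1, \frac{p_{\text{target}}(x_i)}{p_{\text{draft}}(x_i)}\right)$$
If it is rejected, it is replaced by a token sampled from the residual distribution $\max\left(0,\, p_{\text{target}} - p_{\text{draft}}\right)$, renormalized to sum to 1, and the remaining draft tokens are discarded. This guarantees that the generated tokens follow exactly the same distribution as sampling from the target model alone.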
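The "same distribution" guarantee can be checked exactly on a toy vocabulary. The sketch below computes the probability that one accept/resample step emits each token. The distributions are made-up numbers and the helper name is hypothetical. A draft token is accepted with probability min(1, p/q); otherwise the step falls back to the residual. Using `Fraction` keeps the arithmetic exact, so the result should equal the target distribution token for token.

```python
from fractions import Fraction as F


def speculative_output_dist(p: dict, q: dict) -> dict:
    """Exact distribution of the token emitted by one accept/resample step."""
    vocab = set(p) | set(q)
    p = {x: p.get(x, F(0)) for x in vocab}
    q = {x: q.get(x, F(0)) for x in vocab}
    # P(draft proposes x and it is accepted) = q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    accept = {x: min(p[x], q[x]) for x in vocab}
    reject_mass = 1 - sum(accept.values())  # total probability of a rejection
    residual = {x: max(F(0), p[x] - q[x]) for x in vocab}
    z = sum(residual.values())
    return {x: accept[x] + (reject_mass * residual[x] / z if z else F(0)) for x in vocab}


p_target = {"the": F(1, 2), "a": F(3, 10), "cat": F(1, 5)}
p_draft = {"the": F(7, 10), "a": F(1, 10), "cat": F(1, 10), "dog": F(1, 10)}
out = speculative_output_dist(p_target, p_draft)
print(out == {**p_target, "dog": F(0)})  # True: matches the target exactly
```

The total accepted mass, sum of min(p, q), is here 7/10. This is the per-token acceptance rate for this pair of distributions.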
Practical Implementation with Hugging Face
Hugging Face Transformers supports speculative decoding out of the box:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load target (large) model
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)
# Load draft (small) model; it must share the target's tokenizer
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Assisted generation reads the draft length from the *draft* model's generation config.
# "constant" turns off the default heuristic that adapts it between rounds
# (recent versions may still end a round early when the draft is unsure).
draft_model.generation_config.num_assistant_tokens = 5  # γ = 5
draft_model.generation_config.num_assistant_tokens_schedule = "constant"

inputs = tokenizer("Explain how transformers handle positional encoding", return_tensors="pt")
inputs = inputs.to(target_model.device)

# Speculative decoding: the draft proposes tokens, the target verifies them
output = target_model.generate(
    **inputs,
    assistant_model=draft_model,
    max_new_tokens=200,
    temperature=0.7,
    do_sample=True,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
Optimizing the Speedup
Choosing γ (Speculation Depth)
The number of draft tokens is a tradeoff:
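- γ too small: few tokens can be accepted per target pass, so the speedup is minimal
- γ too large: a draft token is kept only if every earlier one was accepted, so the expected number of accepted tokens saturates while drafting cost keeps growing linearly
- γ = 3-5: a good starting point for most model pairs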
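A common idealized way to reason about this tradeoff makes two assumptions. First, each draft token is accepted independently with the same probability α. Second, one draft step costs a fraction c of a target forward pass. One round then yields (1 − α^(γ+1)) / (1 − α) tokens on average, at a cost of γc + 1 target passes. The sketch below only evaluates that model. α and c are inputs you would measure for your own model pair; the values used here are illustrative, with c = 0.1 loosely mirroring an 8B draft for a 70B target.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Mean tokens emitted per target pass, assuming i.i.d. acceptance with rate alpha.

    Counts the corrected/bonus token, so the value lies between 1 and gamma + 1.
    """
    if alpha >= 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Tokens per round divided by the round's cost in target-pass units (gamma*c + 1)."""
    return expected_tokens_per_round(alpha, gamma) / (gamma * c + 1)


def best_gamma(alpha: float, c: float, max_gamma: int = 16) -> int:
    return max(range(1, max_gamma + 1), key=lambda g: expected_speedup(alpha, g, c))


for alpha in (0.6, 0.7, 0.8, 0.9):
    g = best_gamma(alpha, c=0.1)
    print(f"alpha={alpha}: best gamma={g}, speedup ~{expected_speedup(alpha, g, 0.1):.2f}x")
```

Under this model, a cheaper draft (smaller c) or a better-aligned one (larger α) raises both the achievable speedup and the best γ. With α = 0.7 and c = 0.1, the optimum is γ = 4, inside the 3-5 range above.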
Draft Model Selection
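For the token-level verification described above, the draft must use the same tokenizer (vocabulary) as the target. Beyond that, best results come from a smaller model in the same family that shares the target's architecture: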
| Target Model | Recommended Draft | Typical Speedup |
|---|---|---|
| Llama-3-70B | Llama-3-8B | 2.0-2.5× |
| Llama-3-70B | Llama-3-8B (Q4) | 2.5-3.0× |
| Mixtral-8x22B | Mixtral-8x7B | 2.0-2.3× |
| Gemma-2-27B | Gemma-2-9B | 1.8-2.2× |
Self-Speculative Decoding
When you cannot fit a separate draft model, use an early-exit head — attach a lightweight LM head to an intermediate layer of the same model:
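```python
import torch.nn as nn

class EarlyExitDraft(nn.Module):
    """Use the first `exit_layer` layers (e.g. N/2) as the draft; the full model is the target."""
    def __init__(self, model, exit_layer):
        super().__init__()
        self.model = model  # assumed to expose .embed_tokens, .layers and .config
        self.exit_layer = exit_layer
        # Extra head, trained to map intermediate hidden states to next-token logits
        self.draft_head = nn.Linear(
            model.config.hidden_size,
            model.config.vocab_size,
            bias=False,
        )

    def forward(self, input_ids):
        hidden = self.model.embed_tokens(input_ids)
        for layer in self.model.layers[:self.exit_layer]:
            # Simplified: real Hugging Face decoder layers also take attention masks and
            # position information, and return a tuple rather than a bare tensor
            hidden = layer(hidden)
        return self.draft_head(hidden[:, -1])  # logits for the next token only
```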
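Because the draft reuses the target's own lower layers, the only extra memory is the draft head (hidden_size × vocab_size weights), which is small relative to a large model. This approach typically achieves a 1.5-2× speedup.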
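How small the extra head is depends on the model's dimensions, since an untied draft head is a hidden_size × vocab_size matrix. Below is a quick sketch of the arithmetic. It assumes Llama-3-70B-like dimensions (hidden size 8192, vocabulary of 128,256 tokens), treats the model as exactly 70B parameters, and ignores the optional bias.

```python
def head_overhead(hidden_size: int, vocab_size: int, total_params: float,
                  bytes_per_param: int = 2) -> tuple:
    """Extra parameters, bytes, and fraction of the model for one untied draft head."""
    extra = hidden_size * vocab_size  # weight of nn.Linear(hidden, vocab), bias ignored
    return extra, extra * bytes_per_param, extra / total_params


params, nbytes, frac = head_overhead(8192, 128_256, 70e9)
print(f"{params / 1e9:.2f}B params, {nbytes / 1e9:.1f} GB in FP16, {frac:.1%} of the model")
# 1.05B params, 2.1 GB in FP16, 1.5% of the model
```

In absolute terms the head is about a billion parameters, yet it is only a small fraction of the 140 GB the target already occupies.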
Medusa: Training Multiple Speculative Heads
Medusa takes self-speculation further by training multiple independent heads, each predicting a different position ahead:
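```python
import torch.nn as nn

class MedusaModel(nn.Module):
    def __init__(self, base_model, medusa_num_heads=5):
        super().__init__()
        self.base_model = base_model  # a backbone whose output has .last_hidden_state
        # Simplified: Medusa's heads also apply a residual SiLU block before the projection
        self.medusa_heads = nn.ModuleList([
            nn.Linear(base_model.config.hidden_size,
                      base_model.config.vocab_size)
            for _ in range(medusa_num_heads)
        ])

    def forward(self, input_ids):
        hidden = self.base_model(input_ids).last_hidden_state
        # The backbone's original LM head (not shown) covers t+1;
        # Medusa head i (0-indexed) predicts the token at position t+i+2
        return [head(hidden[:, -1]) for head in self.medusa_heads]
```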
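During inference, the base model's original LM head predicts token t+1, Medusa head 0 predicts t+2, head 1 predicts t+3, and so on. The backbone then checks the proposed candidates in a single forward pass. All heads share the same backbone, so the extra memory is just the head parameters.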
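To make "verify" concrete for Medusa-style heads, here is a simplified sketch of the greedy case only. Take each head's top-1 guess as one candidate continuation and run the backbone once over it. Then keep the longest prefix on which the backbone's own greedy predictions agree, plus the backbone's next prediction. Real Medusa builds a tree of top-k candidates per head, verifies them with tree attention, and can use a "typical acceptance" rule when sampling; none of that is shown. The helper is hypothetical and works on plain token-ID lists.

```python
from typing import List


def greedy_accept(candidate: List[int], target_argmax: List[int]) -> List[int]:
    """Accept the longest prefix of `candidate` that the target would also have produced.

    candidate:      proposed tokens for positions t+1 .. t+k
    target_argmax:  the target's greedy token at positions t+1 .. t+k+1, all computed in
                    one forward pass over the prefix plus `candidate` (length k + 1)
    Returns the accepted tokens followed by the target's own next token, so every
    round produces at least one token.
    """
    assert len(target_argmax) == len(candidate) + 1
    n = 0
    while n < len(candidate) and candidate[n] == target_argmax[n]:
        n += 1
    # target_argmax[n] depends only on tokens before it, all of which were accepted
    return candidate[:n] + [target_argmax[n]]


# The heads guessed [5, 9, 2, 7]; the target agrees on the first two, then prefers 4.
print(greedy_accept([5, 9, 2, 7], [5, 9, 4, 1, 3]))  # [5, 9, 4]
```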
When Speculative Decoding Shines (and When It Does Not)
Best cases:
- Creative writing / open-ended generation: acceptance rates can reach 70-90% when much of the text (function words, common phrases) is easy for a smaller model to predict
- Code completion: structured, repetitive patterns → 80%+ acceptance
- Translation: deterministic outputs → excellent draft-target alignment
Worst cases:
- High-temperature sampling: randomness hurts draft prediction
- Mathematical reasoning: each step depends critically on the last → lower acceptance
- Small target models (under about 7B): the target is already fast and there is little room for a much cheaper draft, so the draft overhead can outweigh the benefit
Putting It All Together: A Production Pipeline
Here is a compact wrapper that loads both models and reports throughput:
```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class SpeculativeGenerator:
    def __init__(self, target_path, draft_path, device_map="auto"):
        # "auto" lets Accelerate shard a 70B target across the available GPUs
        self.target = AutoModelForCausalLM.from_pretrained(
            target_path, torch_dtype=torch.float16, device_map=device_map
        )
        self.draft = AutoModelForCausalLM.from_pretrained(
            draft_path, torch_dtype=torch.float16, device_map=device_map
        )
        self.tokenizer = AutoTokenizer.from_pretrained(draft_path)  # shared vocabulary

    def generate(self, prompt: str, max_tokens: int = 200,
                 gamma: int = 5, temperature: float = 0.7) -> dict:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.target.device)
        # Assisted generation reads γ from the draft model's generation config
        self.draft.generation_config.num_assistant_tokens = gamma
        self.draft.generation_config.num_assistant_tokens_schedule = "constant"

        start = time.perf_counter()
        output = self.target.generate(
            **inputs,
            assistant_model=self.draft,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
        )
        elapsed = time.perf_counter() - start  # the first call also pays warm-up costs

        # Keep only the newly generated tokens, not the echoed prompt
        new_tokens = output[0][inputs.input_ids.shape[1]:]
        return {
            "text": self.tokenizer.decode(new_tokens, skip_special_tokens=True),
            "tokens_generated": len(new_tokens),
            "tokens_per_second": len(new_tokens) / elapsed,
            "elapsed_seconds": round(elapsed, 2),
            "speculation_depth": gamma,
        }


# Usage
gen = SpeculativeGenerator(
    target_path="meta-llama/Meta-Llama-3-70B-Instruct",
    draft_path="meta-llama/Meta-Llama-3-8B-Instruct",
)
result = gen.generate("Write a Python function that implements")
print(f"Speed: {result['tokens_per_second']:.1f} tok/s")
```
Conclusion
Speculative decoding is one of the few inference optimizations that gives you something for almost nothing: faster generation with an identical output distribution, at the cost of some extra memory for the draft. The technique is maturing rapidly. vLLM, TGI, and llama.cpp all support it, and research extensions such as Medusa, EAGLE, and Lookahead Decoding push speedups even higher.
Start with a small draft model and γ = 5, measure your acceptance rate, and tune from there. Your users will notice the difference.