Speculative Decoding: 2x Faster LLM Inference Without Quality Loss
Learn how speculative decoding achieves 2-3x LLM inference speedup with zero quality loss. Complete tutorial with Python implementation, n-gram drafting, tree-based variants, and vLLM integration.
By CoddyKit · 7 min read · 1327 words
If you've ever waited for a large language model to generate token-by-token, you know the bottleneck is real. Every token requires a full forward pass through billions of parameters. Speculative decoding changes that equation entirely — and in this tutorial, we'll implement it from scratch.
What Is Speculative Decoding?
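Speculative decoding (also called speculative sampling, and exposed as assisted generation in Hugging Face transformers) exploits a simple insight: a small draft model can guess the next several tokens much faster than the large target model can produce them one at a time. The large model then verifies all guesses in a single forward pass. At small batch sizes, decoding is limited by memory bandwidth (loading the weights) rather than arithmetic, so scoring k tokens in one pass costs about the same as generating one.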
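The result: typically a 2–3× speedup over standard autoregressive decoding, depending on how often the target accepts the draft's guesses, with zero quality loss. The output distribution is mathematically identical to sampling from the target model alone.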
The Algorithm, Step by Step
Here's the core loop:
- Draft phase: The small draft model generates k tokens autoregressively (fast, cheap).
- Verification phase: The large target model processes the entire draft sequence in one pass, producing its own probability distributions for each position.
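- Accept/reject: Walk through the draft tokens in order, comparing the draft model's probability with the target model's probability for each one. Accept if a uniform random draw falls below the acceptance ratio. On the first rejection, discard that token and every draft token after it, and sample a replacement from the adjusted target distribution.
- Bonus token: If all k draft tokens are accepted, sample one extra token from the target's distribution at the next position, which the verification pass has already computed.
- Repeat from the last emitted token. Each iteration therefore emits between 1 and k + 1 tokens for a single target forward pass.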
The acceptance probability at position i is:
accept_prob = min(1, P_target(x_i) / P_draft(x_i))
When P_target(x_i) ≥ P_draft(x_i), the token is always accepted. When P_target(x_i) < P_draft(x_i), it is accepted with probability equal to the ratio. Combined with resampling from the adjusted distribution on rejection, this is exactly what guarantees that the output matches the target model's distribution.
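To see the rule at work without any model, here is a minimal, framework-free sketch on explicit probability lists. The helper names `accept_or_resample` and `emitted_distribution` are hypothetical and exist only for illustration: the first performs one accept/reject step, and the second computes exactly how often each token is emitted by that step, which should reproduce the target distribution.

```python
import random

def accept_or_resample(p, q, rng=random):
    """One speculative step on explicit distributions p (target) and q (draft).

    Returns (token, accepted). The draft proposes x ~ q; it is accepted with
    probability min(1, p[x] / q[x]); otherwise a token is drawn from the
    residual max(0, p - q) (random.choices normalizes the weights).
    """
    tokens = range(len(q))
    x = rng.choices(tokens, weights=q)[0]
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(tokens, weights=residual)[0], False


def emitted_distribution(p, q):
    """Exact probability of each token being emitted by accept_or_resample."""
    accept = [min(pi, qi) for pi, qi in zip(p, q)]   # q[x] * min(1, p[x] / q[x])
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accept)                  # chance the draft is rejected
    z = sum(residual)
    return [a + (reject_mass * r / z if z > 0 else 0.0)
            for a, r in zip(accept, residual)]


p = [0.5, 0.3, 0.2]  # toy target distribution
q = [0.2, 0.3, 0.5]  # toy draft distribution (over-proposes token 2)
print(emitted_distribution(p, q))  # equals p up to floating-point rounding
```

Even though the draft favors token 2, sampling many steps of `accept_or_resample` with a seeded `random.Random` gives empirical frequencies close to p.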
Implementation: A Minimal Speculative Decoder
Let's build this with transformers. We'll use Llama-3-8B-Instruct as the target model and Llama-3.2-1B-Instruct as the draft. The two share the Llama 3 tokenizer, which this implementation requires: the accept/reject step compares both models' probabilities for the same token IDs.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load models. The draft must share the target's tokenizer/vocabulary,
# because verification compares both models' probabilities for the same token IDs.
target_id = "meta-llama/Meta-Llama-3-8B-Instruct"
draft_id = "meta-llama/Llama-3.2-1B-Instruct"
device = "cuda"

target_tokenizer = AutoTokenizer.from_pretrained(target_id)
target_model = AutoModelForCausalLM.from_pretrained(
    target_id, torch_dtype=torch.float16, device_map=device
)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_id, torch_dtype=torch.float16, device_map=device
)


@torch.no_grad()
def speculative_decode(input_ids, max_new_tokens=50, k=5):
    """
    Speculative decoding with a draft model.

    Minimal version: batch size 1, temperature 1, and no KV cache, so every
    forward pass re-encodes the whole sequence.

    Args:
        input_ids: Tokenized input tensor of shape (1, seq_len)
        max_new_tokens: Maximum tokens to generate
        k: Number of draft tokens per iteration
    """
    generated = input_ids.clone()
    target_len = input_ids.shape[1] + max_new_tokens

    while generated.shape[1] < target_len:
        # --- DRAFT PHASE ---
        # Sample from the draft distribution Q (not argmax): the acceptance
        # rule below is exact only when draft tokens are drawn from Q.
        draft_tokens = []
        draft_probs = []
        draft_input = generated
        for _ in range(k):
            logits = draft_model(draft_input).logits[:, -1, :].float()
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            draft_tokens.append(next_token)
            draft_probs.append(probs)
            draft_input = torch.cat([draft_input, next_token], dim=1)

        # --- VERIFICATION PHASE ---
        # One target pass over the prefix plus all k draft tokens.
        # Logits at position t are the target's prediction for token t + 1.
        target_logits = target_model(draft_input).logits.float()
        start = generated.shape[1] - 1

        new_tokens = []
        for i in range(k):
            draft_prob = draft_probs[i]
            target_prob = torch.softmax(target_logits[:, start + i, :], dim=-1)
            tok = draft_tokens[i].item()
            p_draft = draft_prob[0, tok].item()
            p_target = target_prob[0, tok].item()

            # Acceptance criterion: accept with probability min(1, p_target / p_draft)
            if torch.rand(1).item() < min(1.0, p_target / p_draft):
                new_tokens.append(draft_tokens[i])
            else:
                # Reject: sample from the adjusted distribution max(0, P - Q),
                # renormalized, and discard the remaining draft tokens
                adjusted = torch.clamp(target_prob - draft_prob, min=0)
                adjusted /= adjusted.sum()
                new_tokens.append(torch.multinomial(adjusted, num_samples=1))
                break
        else:
            # All draft tokens accepted: the target's last position yields
            # one extra "bonus" token at no additional cost
            bonus_prob = torch.softmax(target_logits[:, -1, :], dim=-1)
            new_tokens.append(torch.multinomial(bonus_prob, num_samples=1))

        generated = torch.cat([generated] + new_tokens, dim=1)

    # Each iteration adds up to k + 1 tokens, so trim any overshoot
    return generated[:, :target_len]
```
Optimizing the Draft: N-Gram and Tree-Based Variants
The basic approach above is a good start, but production systems use smarter drafting strategies:
N-Gram Drafting (No Extra Model Needed)
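Instead of loading a separate draft model, you can look up n-grams in the existing context (the prompt plus the tokens generated so far) and propose whatever followed the last match. This is often called prompt lookup decoding; Hugging Face transformers exposes it through the `prompt_lookup_num_tokens` argument to `generate()`.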
```python
import torch

def ngram_draft(input_ids, k=5, n=3):
    """Draft up to k tokens by matching n-grams from existing context.

    Finds the most recent earlier occurrence of the last n tokens and proposes
    the token that followed it, then repeats on the extended sequence.
    """
    seq = input_ids[0].tolist()
    drafts = []
    for _ in range(k):
        suffix = seq[-n:]
        next_tok = None
        # Scan right to left, skipping the suffix's own position
        for j in range(len(seq) - n - 1, -1, -1):
            if seq[j:j + n] == suffix:
                next_tok = seq[j + n]
                break
        if next_tok is None:
            break  # no match: stop drafting
        drafts.append(next_tok)
        seq.append(next_tok)
    return torch.tensor([drafts], dtype=torch.long, device=input_ids.device)
```
This works surprisingly well for code generation and structured text where patterns repeat frequently.
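N-gram drafts come with token IDs but no probabilities. One common way to plug them into the accept/reject rule, sketched below under that assumption, is to treat the draft as a point mass: Q(x) = 1 for the proposed token, so the acceptance probability becomes P_target(x) and the residual max(0, P - Q) is simply P_target with x zeroed out. The function name and the list-of-lists input are hypothetical; in the torch code the rows would come from a softmax over the target logits.

```python
import random

def verify_deterministic_draft(target_rows, draft_tokens, rng=random):
    """
    Verify drafts from a deterministic source (e.g. n-gram lookup).

    target_rows holds len(draft_tokens) + 1 target distributions (lists of floats):
    row i scores draft position i, and the last row gives the bonus token.
    With a point-mass draft (Q(x) = 1 for the proposed token), the acceptance
    probability min(1, P(x)/Q(x)) reduces to P(x), and the residual
    max(0, P - Q) is P with the proposed token zeroed out.
    Returns the emitted token ids.
    """
    emitted = []
    for i, tok in enumerate(draft_tokens):
        probs = target_rows[i]
        if rng.random() < probs[tok]:
            emitted.append(tok)  # accepted
            continue
        residual = list(probs)
        residual[tok] = 0.0  # random.choices normalizes the remaining weights
        emitted.append(rng.choices(range(len(residual)), weights=residual)[0])
        break  # the first rejection ends this step
    else:
        bonus = target_rows[len(draft_tokens)]
        emitted.append(rng.choices(range(len(bonus)), weights=bonus)[0])
    return emitted


rows = [[0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.5, 0.25, 0.25]]  # toy target rows
print(verify_deterministic_draft(rows, [1, 2], random.Random(0)))
```

Because the draft is deterministic, acceptance here depends only on how much probability the target itself gives the looked-up token, which is why repetitive text, where the target is confident about the repeat, benefits most.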
Tree-Based Speculative Decoding
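Instead of drafting a single sequence, the draft model generates a tree of candidate tokens. The target model verifies all branches in one forward pass using a custom attention mask (often called tree attention), in which each node attends only to the prompt and its own ancestors. This increases the expected number of accepted tokens per step, because several candidate continuations are checked at once.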
```python
import torch

@torch.no_grad()
def draft_tree(input_ids, draft_model, branching=3, depth=2):
    """
    Generate a token tree for speculative decoding.
    Each node branches into `branching` candidates up to `depth` levels.
    Returns {node_id: (tokens, cumulative_prob, parent_id)}; node 0 is the prompt.
    """
    tree = {0: (input_ids, 1.0, None)}
    frontier = [0]  # only the newest level is expanded each round
    next_id = 1
    for _ in range(depth):
        next_frontier = []
        for node_id in frontier:
            tokens, cum_prob, _ = tree[node_id]
            logits = draft_model(tokens).logits[:, -1, :].float()
            probs = torch.softmax(logits, dim=-1)
            top = torch.topk(probs, branching, dim=-1)
            for rank in range(branching):
                child_token = top.indices[:, rank:rank + 1]  # shape (1, 1)
                child_prob = cum_prob * top.values[0, rank].item()
                child_tokens = torch.cat([tokens, child_token], dim=1)
                tree[next_id] = (child_tokens, child_prob, node_id)
                next_frontier.append(next_id)
                next_id += 1
        frontier = next_frontier
    return tree
```
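Tree-based approaches can achieve 2.5–4× speedups on models like Llama-3-70B, especially for deterministic outputs like code. Note that keeping the output exactly equal to the target's distribution requires a multi-branch version of the acceptance rule; the single-sequence test above does not directly apply to a tree.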
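The custom attention mask can be built from parent pointers alone. The sketch below is a hypothetical helper using plain lists: after flattening the draft nodes (parents before children), each node may attend to itself and its ancestors, never to siblings or cousins. Index -1 stands for the prompt, which every node can always see. A real implementation would also give each node at tree level d (counting from 0) the position id `prompt_len + d`, so siblings share a position.

```python
def tree_attention_mask(parents):
    """
    Boolean attention mask for flattened draft-tree nodes.

    parents[i] is the index of node i's parent, or -1 if its parent is the prompt.
    mask[i][j] is True when node i may attend to node j, i.e. j is i itself
    or one of its ancestors. Attention to the prompt is implied and not stored.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up the ancestor chain
            mask[i][j] = True
            j = parents[j]
    return mask


# Two first-level candidates (0, 1), each with one child (2 under 0, 3 under 1)
for row in tree_attention_mask([-1, -1, 0, 1]):
    print(["x" if allowed else "." for allowed in row])
```

Node 2 sees only node 0 and itself, so the target scores it as the continuation "prompt, token 0, token 2" even though all four nodes share one forward pass.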
Measuring Real-World Performance
Let's benchmark standard vs. speculative decoding:
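```python
import time

prompt = "Write a Python function that checks whether a string is a palindrome."  # example prompt

def benchmark(generate_fn, prompt, runs=5):
    # The tokenizer returns a BatchEncoding; take the input_ids tensor
    input_ids = target_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generate_fn(input_ids)  # warm-up run (CUDA kernels, memory allocator)
    rates = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        output = generate_fn(input_ids)
        torch.cuda.synchronize()  # wait for GPU work before stopping the clock
        elapsed = time.perf_counter() - start
        new_tokens = output.shape[1] - input_ids.shape[1]
        rates.append(new_tokens / elapsed)
    return sum(rates) / len(rates)  # mean tokens per second

# Standard autoregressive decoding with the target model alone
def standard_generate(input_ids, max_new=50):
    return target_model.generate(
        input_ids, max_new_tokens=max_new, do_sample=False,
        pad_token_id=target_tokenizer.eos_token_id,
    )

print(f"Standard: {benchmark(standard_generate, prompt):.1f} tok/s")
print(f"Speculative: {benchmark(speculative_decode, prompt):.1f} tok/s")
```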
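Typical results on A100 GPUs (at fp16, the 70B and 8x22B targets do not fit in a single 80 GB A100 and need several GPUs or quantization; expect your numbers to vary with prompt, k, and acceptance rate):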
| Model | Standard | Speculative (k=5) | Speedup |
|---|---|---|---|
| Llama-3-8B + TinyLlama | 45 tok/s | 92 tok/s | 2.04× |
| Llama-3-70B + Llama-3-8B | 12 tok/s | 31 tok/s | 2.58× |
| Mixtral-8x22B + Mistral-7B | 8 tok/s | 22 tok/s | 2.75× |
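A back-of-the-envelope model helps explain why speedups vary with the model pair and k. It assumes, as in the original speculative decoding analysis (Leviathan et al., 2023), that each draft token is accepted independently with the same average rate, so one target pass yields (1 − a^(k+1)) / (1 − a) tokens on average for acceptance rate a. Dividing by the cost of k draft passes plus one target pass gives the expected speedup. The cost ratio `c` and the function names are assumptions for illustration, and real systems add overheads this ignores.

```python
def expected_tokens_per_step(accept_rate, k):
    """Expected tokens emitted per target forward pass with k draft tokens,
    assuming each draft token is accepted independently with probability accept_rate."""
    if accept_rate >= 1.0:
        return k + 1.0  # every draft accepted, plus the bonus token
    return (1 - accept_rate ** (k + 1)) / (1 - accept_rate)


def expected_speedup(accept_rate, k, c):
    """Speedup over plain decoding if one draft pass costs c target passes and
    verifying k tokens costs about one target pass. Ignores other overheads."""
    return expected_tokens_per_step(accept_rate, k) / (k * c + 1)


for k in (1, 3, 5, 8):
    print(k, round(expected_speedup(0.8, k, 0.1), 2))
```

Under this model, raising k pays off only while the extra accepted tokens outpace the extra draft cost, which is the same trade-off the dynamic-k tip in the production tips manages.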
When Speculative Decoding Shines
Best cases:
- Code generation: Highly structured, predictable tokens → high acceptance rates (70–90%).
- Translation: Deterministic output with limited vocabulary variation.
- Technical writing: Repetitive patterns, formulaic structures.
Worst cases:
- Creative writing: High temperature sampling reduces draft accuracy.
- Multi-step reasoning: Hard reasoning steps are where a small draft model most often disagrees with the target, so acceptance rates drop and each target pass yields fewer tokens.
Production Tips
- Match tokenizers: Draft and target models must share the same tokenizer. If they don't, you'll need a token mapping layer.
- Dynamic k: Adjust the draft length based on the observed acceptance rate. If acceptance drops below 40%, reduce k. If it's above 80%, increase k.
- Batching: Speculative decoding works with batches, but sequences accept different numbers of draft tokens per step, which complicates verification. Pad shorter sequences to the longest accepted length in the batch and mask the padding.
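- vLLM integration: The `vLLM` library has built-in speculative decoding support. Older releases use the `--speculative-model` and `--num-speculative-tokens` flags shown below; newer releases fold these options into a single `--speculative-config` JSON argument, so check the documentation for your version.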
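```bash
# vLLM speculative decoding (flag names from older vLLM releases);
# the draft shares the Llama 3 tokenizer with the target
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-70B-Instruct \
    --speculative-model meta-llama/Meta-Llama-3-8B-Instruct \
    --num-speculative-tokens 5 \
    --max-model-len 4096
```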
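The dynamic-k tip can be sketched as a small controller that runs between decoding iterations. The class name, the moving-average smoothing, and the one-step adjustments are assumptions for illustration; only the 40% and 80% thresholds come from the tip itself.

```python
class DynamicK:
    """Hypothetical controller that tunes the draft length k between iterations."""

    def __init__(self, k=5, k_min=1, k_max=10, low=0.4, high=0.8, smoothing=0.9):
        self.k, self.k_min, self.k_max = k, k_min, k_max
        self.low, self.high = low, high  # acceptance-rate thresholds
        self.smoothing = smoothing       # weight on past acceptance rates
        self.rate = None                 # moving average of the acceptance rate

    def update(self, accepted, proposed):
        """Record one iteration (accepted out of proposed draft tokens); return the next k."""
        if proposed == 0:
            return self.k
        rate = accepted / proposed
        if self.rate is None:
            self.rate = rate
        else:
            self.rate = self.smoothing * self.rate + (1 - self.smoothing) * rate
        if self.rate < self.low:
            self.k = max(self.k_min, self.k - 1)
        elif self.rate > self.high:
            self.k = min(self.k_max, self.k + 1)
        return self.k
```

In a decode loop, you would call `update` once per iteration with the number of accepted draft tokens and the k that was used, then draft with the returned value next time. Smoothing keeps k from oscillating on a single lucky or unlucky step.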
The Math Behind Zero Quality Loss
The acceptance criterion isn't heuristic — it's mathematically proven to preserve the target model's output distribution. Here's why:
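Let P be the target distribution and Q be the draft distribution at a given position, and let the draft propose x ~ Q. Accepting with probability α = min(1, P(x)/Q(x)) means each token x is emitted directly with probability min(P(x), Q(x)): tokens the draft over-proposes (Q > P) are thinned down to exactly P, while tokens it under-proposes (P > Q) fall short by P(x) − Q(x). When a token is rejected, we resample from (P - Q)⁺ / ||(P - Q)⁺||₁, which supplies exactly that shortfall. The emitted token, whether accepted or resampled, is therefore distributed exactly according to P.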
This is a variant of rejection sampling with two twists: a rejection still produces a token (drawn from the residual), and the k proposals are verified together in a single target forward pass.
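Writing out the two ways a token x can be emitted makes the argument explicit. The first term is "the draft proposed x and it was accepted"; the second is "some draft token was rejected, and x was drawn from the residual". Here ρ denotes the overall rejection probability, a symbol introduced only for this derivation.

$$
\begin{aligned}
\Pr[\text{emit } x]
  &= Q(x)\,\min\!\left(1, \frac{P(x)}{Q(x)}\right) + \rho\,\frac{(P(x) - Q(x))^+}{\lVert (P - Q)^+ \rVert_1} \\
  &= \min\bigl(P(x), Q(x)\bigr) + (P(x) - Q(x))^+ \\
  &= P(x),
\end{aligned}
$$

where

$$
\rho = \sum_y Q(y)\left(1 - \min\!\left(1, \frac{P(y)}{Q(y)}\right)\right) = 1 - \sum_y \min\bigl(P(y), Q(y)\bigr) = \lVert (P - Q)^+ \rVert_1 .
$$

The second line uses ρ = ‖(P − Q)⁺‖₁, which cancels the normalizer, and the last line holds because min(P(x), Q(x)) + max(0, P(x) − Q(x)) = P(x) for every token.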
What's Next?
Speculative decoding is rapidly evolving. Recent extensions include:
- Medusa heads: Adding lightweight decoding heads directly on the target model (no separate draft model needed).
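- EAGLE: A lightweight draft head that extrapolates the target model's own hidden states to predict upcoming tokens, instead of running a separate draft model.
- Lookahead decoding: Generating and verifying n-grams in parallel via Jacobi iteration, with no draft model at all.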
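- Sampling controls: Combining speculative decoding with temperature, top-k, or top-p for controlled creativity at speed. The guarantee still holds as long as the acceptance test uses the adjusted target distribution as P and the distribution the draft actually sampled from as Q.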
The key takeaway: you don't need better hardware to serve LLMs faster. You need smarter decoding.
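Try it on your own models and share your speedup ratios. The best part? Your users get responses two to three times as fast. Keep in mind that the GPU bill does not drop in proportion: the draft passes and rejected tokens still cost compute, and the gains shrink at large batch sizes, where decoding is no longer memory-bound.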