Speculative Decoding: Accelerate LLM Inference with Draft Models
Speculative decoding accelerates LLM inference by using a small draft model to generate candidate tokens, then verifying them in parallel with the large target model. This tutorial explains the algorithm, provides a complete PyTorch implementation, and covers production deployment strategies for 2-4x speedup with zero quality loss.
By CoddyKit · 5 min read · 1026 words
Large language models have become incredibly capable, but their sequential token-by-token generation remains a bottleneck. Every autoregressive step requires a full forward pass through billions of parameters — making inference slow and expensive. Enter speculative decoding, a breakthrough technique that can accelerate LLM inference by 2–4× without sacrificing output quality. In this tutorial, you will learn how it works under the hood and implement a working version from scratch.
The Problem with Autoregressive Generation
Standard LLM decoding generates one token at a time. For a response of length N, the model performs N sequential forward passes. Each pass streams every model weight (plus the cached keys and values for the context) through the accelerator to produce a single token distribution, so decoding is usually limited by memory bandwidth rather than arithmetic. The cost adds up quickly, especially for long outputs.
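To get a feel for the bottleneck, here is a rough back-of-the-envelope estimate. It assumes decoding is purely memory-bandwidth bound and that each generated token requires reading every weight once. The 13B parameter count matches the target model used later in this tutorial; the 1,000 GB/s bandwidth figure and the function name are hypothetical, chosen only for illustration.

```python
def max_tokens_per_second(num_params: float, bytes_per_param: float,
                          bandwidth_gb_s: float) -> float:
    """Upper bound on single-stream decode speed when each token
    requires reading every weight once from memory."""
    weight_bytes = num_params * bytes_per_param
    return bandwidth_gb_s * 1e9 / weight_bytes


# Assumed numbers: 13B parameters in fp16 (2 bytes each) on a
# hypothetical accelerator with 1,000 GB/s of memory bandwidth.
bound = max_tokens_per_second(13e9, 2, 1000)
print(f"At most {bound:.1f} tokens/s")
```

Under these assumptions the ceiling is roughly 38 tokens per second no matter how much compute is available, which is why committing several tokens per weight read is attractive.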
The fundamental insight of speculative decoding is simple: not every token is equally hard to predict. Many tokens are straightforward — common words, code syntax, punctuation. If we could cheaply guess several tokens ahead and verify them all in parallel, we could amortize the expensive large-model forward pass across multiple tokens.
How Speculative Decoding Works
Speculative decoding pairs a small, fast draft model with the large target model. The algorithm operates in cycles:
- Draft phase: The small model generates K tokens autoregressively (cheap and fast).
- Verification phase: The large model evaluates all K drafted tokens in a single forward pass.
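- Acceptance: Each drafted token is accepted with probability min(1, P_target / P_draft), where both probabilities are those assigned to that token at that position.
- Continuation: Accepted tokens are committed. At the first rejection, a corrected token is sampled from the normalized residual max(0, P_target - P_draft), the remaining drafts are discarded, and a new cycle begins. If all K drafts are accepted, the target model's final position supplies one extra token at no additional cost. Accepting and resampling this way makes the output distribution match the target model exactly.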
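The claim that the output distribution matches the target can be checked numerically for a single step. The sketch below assumes the standard accept-or-resample rule: accept a drafted token with probability min(1, p/q), and on rejection draw a replacement from the normalized residual max(0, p - q). The function name and the toy probabilities are made up for illustration.

```python
def speculative_output_distribution(p, q):
    """Distribution of the token emitted by one accept-or-resample step.

    p: target probabilities, q: draft probabilities (same vocabulary).
    The draft proposes token x with probability q[x]; it is accepted with
    probability min(1, p[x] / q[x]); otherwise a replacement is drawn from
    the normalized residual max(0, p - q).
    """
    # q[x] * min(1, p[x] / q[x]) simplifies to min(p[x], q[x])
    accept_mass = [min(px, qx) for px, qx in zip(p, q)]
    reject_prob = 1.0 - sum(accept_mass)
    residual = [max(0.0, px - qx) for px, qx in zip(p, q)]
    total = sum(residual)
    if total == 0.0:  # p == q, so nothing is ever rejected
        return accept_mass
    return [a + reject_prob * r / total for a, r in zip(accept_mass, residual)]


# Toy 4-token vocabulary with made-up probabilities.
target = [0.50, 0.30, 0.15, 0.05]
draft = [0.25, 0.25, 0.25, 0.25]
out = speculative_output_distribution(target, draft)
print([round(x, 6) for x in out])
```

The printed distribution equals `target` even though the draft is uniform. A poor draft costs extra rejections, not output quality.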
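The key mathematical property: when the acceptance rate is high, each target-model forward pass commits several tokens. If every draft token is accepted independently with probability α, a cycle yields (1 - α^(K+1)) / (1 - α) tokens on average, ranging from 1 (every draft rejected) to K + 1 (every draft accepted plus the extra token). The wall-clock speedup is somewhat lower than this figure because the K draft-model passes are not free.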
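To put numbers on this, here is a small sketch of the usual estimate. It assumes each draft token is accepted independently with the same probability `alpha`, and that verifying K tokens costs about as much as one ordinary target pass. The `draft_cost` ratio and the example values are illustrative assumptions, not measurements.

```python
def expected_tokens_per_cycle(alpha: float, K: int) -> float:
    """Expected tokens committed per target forward pass, assuming each
    draft token is accepted independently with probability alpha."""
    if alpha >= 1.0:
        return float(K + 1)
    return (1 - alpha ** (K + 1)) / (1 - alpha)


def estimated_speedup(alpha: float, K: int, draft_cost: float) -> float:
    """Speedup over plain decoding when one draft step costs `draft_cost`
    target forward passes (e.g. 0.1 for a draft that is 10x cheaper)."""
    return expected_tokens_per_cycle(alpha, K) / (K * draft_cost + 1)


for K in (3, 5, 8):
    tokens = expected_tokens_per_cycle(0.8, K)
    speedup = estimated_speedup(0.8, K, draft_cost=0.1)
    print(f"K={K}: {tokens:.2f} tokens/cycle, {speedup:.2f}x estimated speedup")
```

With these assumed inputs, the estimated speedup peaks at a moderate K and then falls: past a certain draft length, the extra draft passes cost more than the additional accepted tokens are worth.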
Implementation from Scratch
Let us build a minimal speculative decoder using PyTorch and the Hugging Face transformers library.
Step 1: Setup Models
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Target model: large, high-quality
target_name = "meta-llama/Llama-2-13b-hf"
target_model = AutoModelForCausalLM.from_pretrained(
    target_name, torch_dtype=torch.float16, device_map="auto"
)
target_tokenizer = AutoTokenizer.from_pretrained(target_name)

# Draft model: small, fast. It must use the same vocabulary as the target.
draft_name = "meta-llama/Llama-2-7b-hf"
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name, torch_dtype=torch.float16, device_map="auto"
)
draft_tokenizer = AutoTokenizer.from_pretrained(draft_name)

# Ensure shared tokenizer: token ids must mean the same thing to both models
assert target_tokenizer.get_vocab() == draft_tokenizer.get_vocab()
tokenizer = target_tokenizer
Step 2: The Draft Function
@torch.no_grad()
def generate_draft_tokens(model, tokenizer, input_ids, K, temperature=1.0):
    """Sample K tokens from the draft model.

    Returns the token ids and the (K, vocab) distributions they were
    sampled from; verification needs both.
    """
    generated, draft_probs = [], []
    current_ids = input_ids.clone()
    for _ in range(K):
        # No KV cache, for simplicity: each step re-reads the full context.
        outputs = model(current_ids)
        logits = outputs.logits[:, -1, :].float() / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).to(current_ids.device)
        generated.append(next_token.item())
        draft_probs.append(probs[0])
        current_ids = torch.cat([current_ids, next_token], dim=1)
    return generated, torch.stack(draft_probs)
Step 3: The Verification Function
@torch.no_grad()
def verify_and_accept(target_model, draft_tokens, draft_probs, input_ids, temperature=1.0):
    """
    Run target model once over all draft tokens.
    Return the accepted token count and one extra token: a resampled
    correction on rejection, or a bonus token if every draft is accepted.
    """
    K = len(draft_tokens)
    extended = torch.cat([
        input_ids,
        torch.tensor([draft_tokens], device=input_ids.device)
    ], dim=1)
    outputs = target_model(extended)
    # Position j predicts token j + 1, so the last K + 1 positions score the
    # K draft tokens plus one token beyond them.
    logits = outputs.logits[0, -(K + 1):, :].float() / temperature
    target_probs = torch.softmax(logits, dim=-1)  # shape (K + 1, vocab)
    draft_probs = draft_probs.to(target_probs.device)
    for i, draft_tok in enumerate(draft_tokens):
        p_target = target_probs[i, draft_tok].item()
        p_draft = draft_probs[i, draft_tok].item()
        acceptance_prob = min(1.0, p_target / max(p_draft, 1e-10))
        if torch.rand(1).item() < acceptance_prob:
            continue
        # Rejected: resample from the residual max(0, p_target - p_draft).
        # This step is what keeps the output distribution equal to the target's.
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        corrected = torch.multinomial(residual / residual.sum(), num_samples=1)
        return i, corrected.item()
    # All K accepted: sample a bonus token from the target's final position.
    next_token = torch.multinomial(target_probs[K], num_samples=1).item()
    return K, next_token
Step 4: The Full Speculative Decoding Loop
@torch.no_grad()
def speculative_decode(target_model, draft_model, tokenizer, prompt,
                       max_new_tokens=256, K=5, temperature=0.7):
    """Full speculative decoding with draft-verify cycles."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(target_model.device)
    original_length = input_ids.shape[1]
    while input_ids.shape[1] - original_length < max_new_tokens:
        draft_tokens, draft_probs = generate_draft_tokens(
            draft_model, tokenizer, input_ids, K, temperature
        )
        accepted, extra = verify_and_accept(
            target_model, draft_tokens, draft_probs, input_ids, temperature
        )
        new_tokens = draft_tokens[:accepted] + [extra]
        # Stop at the first EOS, even if it was accepted mid-cycle.
        hit_eos = tokenizer.eos_token_id in new_tokens
        if hit_eos:
            new_tokens = new_tokens[:new_tokens.index(tokenizer.eos_token_id) + 1]
        input_ids = torch.cat([
            input_ids,
            torch.tensor([new_tokens], device=input_ids.device)
        ], dim=1)
        if hit_eos:
            break
    # A cycle can overshoot the budget by up to K tokens, so trim before decoding.
    output_ids = input_ids[0, original_length:original_length + max_new_tokens]
    return tokenizer.decode(output_ids, skip_special_tokens=True)
Step 5: Run and Measure
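import time

prompt = "Explain how attention mechanisms work in transformer models."
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)

def sync():
    # CUDA kernels run asynchronously; wait for them before reading the clock
    if torch.cuda.is_available():
        torch.cuda.synchronize()

# Standard decoding (baseline). generate() reuses a KV cache, while the
# minimal loop above recomputes the full context on every pass, so the
# measured speedup can understate what a cached implementation achieves.
sync()
start = time.perf_counter()
baseline = target_model.generate(
    **inputs, max_new_tokens=256, temperature=0.7, do_sample=True
)
sync()
baseline_time = time.perf_counter() - start

# Speculative decoding
sync()
start = time.perf_counter()
result = speculative_decode(target_model, draft_model, tokenizer, prompt,
                            max_new_tokens=256, K=5, temperature=0.7)
sync()
speculative_time = time.perf_counter() - start

print(f"Baseline: {baseline_time:.2f}s")
print(f"Speculative: {speculative_time:.2f}s")
print(f"Speedup: {baseline_time / speculative_time:.2f}x")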
Understanding Acceptance Rates
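The effectiveness of speculative decoding depends largely on how well the draft model predicts what the target model would generate. The table below gives representative acceptance rates by scenario and draft length K; actual values vary with the model pair and sampling temperature: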
| Scenario | K=3 | K=5 | K=8 |
|---|---|---|---|
| Code completion | 0.82 | 0.68 | 0.51 |
| Technical writing | 0.78 | 0.61 | 0.44 |
| Creative writing | 0.55 | 0.38 | 0.25 |
| Mathematical reasoning | 0.60 | 0.42 | 0.28 |
Code and technical text have the highest acceptance rates because the vocabulary is constrained and predictable. Creative tasks, which require more "surprise," see lower rates but still benefit from the technique.
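Since acceptance rates vary this much by workload, it is worth measuring them on your own traffic. The helper below is a minimal sketch of such bookkeeping; `AcceptanceStats` is a hypothetical class, not part of any library, and the per-cycle counts in the usage lines are made up. It assumes each cycle emits the accepted drafts plus exactly one extra token from the target model.

```python
from dataclasses import dataclass


@dataclass
class AcceptanceStats:
    """Running totals over draft-verify cycles (hypothetical helper)."""
    cycles: int = 0
    drafted: int = 0
    accepted: int = 0

    def record(self, num_drafted: int, num_accepted: int) -> None:
        """Add the outcome of one draft-verify cycle."""
        self.cycles += 1
        self.drafted += num_drafted
        self.accepted += num_accepted

    @property
    def acceptance_rate(self) -> float:
        """Fraction of drafted tokens the target accepted."""
        return self.accepted / self.drafted if self.drafted else 0.0

    @property
    def tokens_per_target_pass(self) -> float:
        """Accepted tokens plus the one extra token each cycle emits."""
        if not self.cycles:
            return 0.0
        return (self.accepted + self.cycles) / self.cycles


stats = AcceptanceStats()
for accepted in (5, 2, 0, 4):  # made-up per-cycle results with K=5
    stats.record(num_drafted=5, num_accepted=accepted)
print(f"{stats.acceptance_rate:.2f} acceptance rate, "
      f"{stats.tokens_per_target_pass:.2f} tokens per target pass")
```

In a decoding loop, one would call `record` once per cycle with K and the accepted count returned by the verification step.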
Production Considerations
Several frameworks now support speculative decoding natively:
- llama.cpp: supports draft models with the --draft flag, compatible with GGUF quantized models.
- vLLM: integrated speculative decoding with batched verification for serving.
- Hugging Face TGI: native speculative decoding support for production deployments.
- Ollama: supports multi-model speculative decoding through its API.
For production deployments, consider these practical tips:
- Match training domains: A draft model trained on code works best for code tasks. Fine-tune a small model on your specific domain for maximum acceptance rates.
- Tune K dynamically: Start with K=5 and adjust based on observed acceptance rates. Higher K means more risk of rejection; lower K means less speedup.
- Use quantized draft models: The draft model can be heavily quantized (4-bit or even 3-bit) since occasional errors only reduce acceptance rate, not output quality.
- Cache reuse: KV cache from accepted tokens can be reused across cycles, reducing redundant computation.
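As one possible reading of the "tune K dynamically" tip, the sketch below raises or lowers K based on an exponential moving average of the fraction of drafts accepted per cycle. `AdaptiveK` is a hypothetical class, and every threshold and default in it is an illustrative assumption rather than a recommended setting.

```python
class AdaptiveK:
    """Adjust draft length from an exponential moving average of the
    per-cycle acceptance fraction. Thresholds are illustrative only."""

    def __init__(self, k=5, k_min=1, k_max=10, smoothing=0.2,
                 raise_above=0.8, lower_below=0.4):
        self.k, self.k_min, self.k_max = k, k_min, k_max
        self.smoothing = smoothing
        self.raise_above, self.lower_below = raise_above, lower_below
        self.ema = None  # no observations yet

    def update(self, num_accepted: int) -> int:
        """Record one cycle drafted with the current k; return the next k."""
        fraction = num_accepted / self.k
        if self.ema is None:
            self.ema = fraction
        else:
            self.ema += self.smoothing * (fraction - self.ema)
        if self.ema > self.raise_above:
            self.k = min(self.k + 1, self.k_max)
        elif self.ema < self.lower_below:
            self.k = max(self.k - 1, self.k_min)
        return self.k


controller = AdaptiveK(k=5)
for accepted in (5, 5, 1, 0, 0):  # made-up acceptance counts
    print("next K:", controller.update(accepted))
```

The moving average keeps a single unlucky cycle from shrinking K immediately; a production controller would likely also account for draft cost and batch size.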
Why This Matters
Speculative decoding is one of the rare ML optimizations that gives you genuine speedup without any quality loss — the output distribution is provably identical to running the target model alone. For production LLM services, this directly translates to lower latency, higher throughput, and reduced cost.
As open-source draft models improve and frameworks optimize the verification pipeline, speculative decoding is becoming the default inference strategy for serious LLM deployments. Understanding and implementing it is a valuable skill for any ML engineer.
Next Steps
Try these experiments to deepen your understanding:
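- Compare acceptance rates across different draft/target model pairs that share a tokenizer (e.g., Llama-3.2-1B drafting for Llama-3.1-8B). The acceptance rule compares probabilities token by token, so the two vocabularies must match.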
- Implement adaptive K selection based on real-time acceptance tracking.
- Benchmark speculative decoding against other acceleration methods like KV cache quantization and continuous batching.
- Explore Medusa heads — a related approach that adds lightweight prediction heads to the target model itself, eliminating the need for a separate draft model.
The code in this tutorial is minimal but complete. With it as a foundation, you can experiment with different model combinations, acceptance thresholds, and deployment strategies. Happy optimizing!