Speculative Decoding for LLMs: Accelerate Inference 2–3× with Draft Models

Large language models deliver impressive capabilities but suffer from a fundamental bottleneck: autoregressive decoding generates one token at a time. Each step must stream the full set of model weights from GPU memory, so decoding is memory-bandwidth bound and leaves much of the GPU's compute idle. Speculative decoding puts that idle compute to work. A small, fast draft model proposes several tokens ahead, and the target model verifies them in parallel. The result is typically a 2–3× speedup with no change to the target model's output distribution.

In this advanced tutorial, you'll learn the theory, implement speculative decoding from scratch in PyTorch, and integrate it with Hugging Face Transformers.

How Speculative Decoding Works

The core insight is simple: instead of generating one token per forward pass, a lightweight draft model generates k candidate tokens sequentially. The target (large) model then evaluates all k tokens in a single forward pass using parallel verification:

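  1. Draft phase: The draft model samples tokens x₁, x₂, …, xₖ autoregressively from its own distribution q.
  2. Verification phase: The target model runs one forward pass over the prefix plus the k draft tokens. This gives its distribution p at every draft position, plus one extra position after xₖ.
  3. Acceptance: Draft tokens are checked left to right. Token xᵢ is accepted if a uniform random draw r < min(1, p(xᵢ) / q(xᵢ)). At the first rejection, the remaining drafts are discarded and a replacement token is sampled from the normalized residual distribution max(0, p − q). If all k drafts are accepted, a bonus token is sampled from the target's distribution at the extra position.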

This is lossless — the output distribution is mathematically identical to standard autoregressive decoding from the target model. The speedup comes from amortizing the expensive target model forward passes across multiple accepted tokens.
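The losslessness claim can be checked numerically for a single position. The sketch below uses made-up 4-token distributions p (target) and q (draft) and computes the exact probability that each token is emitted. A token can come out in two ways. Either the draft proposes it and it is accepted, with probability q(x)·min(1, p(x)/q(x)). Or some draft is rejected and the token is resampled from the normalized residual max(0, p − q). The function name `output_distribution` is hypothetical.

```python
import torch

def output_distribution(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Exact distribution of the token emitted at one position by speculative sampling.

    p: target distribution, q: draft distribution (1-D tensors, each summing to 1).
    """
    # Path 1: draft proposes x (prob q[x]) and it is accepted (prob min(1, p[x]/q[x])).
    accept_mass = torch.minimum(p, q)       # q * min(1, p/q) simplifies to min(p, q)
    residual = torch.clamp(p - q, min=0)
    if residual.sum() == 0:                 # p == q: a draft token is never rejected
        return accept_mass
    p_reject = 1.0 - accept_mass.sum()      # total probability of a rejection
    # Path 2: after a rejection, sample from the normalized residual max(0, p - q).
    return accept_mass + p_reject * residual / residual.sum()

p = torch.tensor([0.5, 0.2, 0.2, 0.1], dtype=torch.float64)  # toy target distribution
q = torch.tensor([0.1, 0.6, 0.2, 0.1], dtype=torch.float64)  # toy draft distribution
print(output_distribution(p, q))   # equal to p up to rounding
print(torch.minimum(p, q).sum())   # chance the draft token is accepted here (about 0.6)
```

The accepted mass min(p, q) plus the rejection mass spread over p − q adds back up to exactly p. That is why the rejection branch must sample from the residual rather than from p directly.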

Choosing a Draft Model

The draft model should be:

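  • Small: many times fewer parameters than the target, so that drafting k tokens stays cheap compared with one target forward pass.
  • Same tokenizer and vocabulary: Verification compares p and q for the same token id, so both models must map every token to the same id.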
  • Similar domain: A draft model trained on similar data yields higher acceptance rates.

Good pairings:

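  • Llama-3-8B (target) → Llama-3.2-1B (draft), which shares the Llama 3 tokenizer. TinyLlama-1.1B uses the Llama 2 tokenizer, so it suits Llama 2 targets rather than Llama 3.
  • Mistral-7B → Phi-2-2.7B. The two use different tokenizers, so this pair needs a cross-tokenizer variant of speculative decoding rather than the standard token-by-token check.
  • Qwen2-7B → Qwen2-1.5B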
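Before committing to a pairing, you can check the same-vocabulary requirement by comparing the two tokenizers' token-to-id maps. On a Hugging Face tokenizer, these are the dictionaries returned by `get_vocab()`. The helper below is a hypothetical sketch that works on plain dictionaries, so it runs without downloading any checkpoints.

```python
def compare_vocabs(target_vocab: dict[str, int], draft_vocab: dict[str, int]) -> dict:
    """Compare two token -> id maps, e.g. from tokenizer.get_vocab()."""
    shared = target_vocab.keys() & draft_vocab.keys()
    same_id = sum(1 for tok in shared if target_vocab[tok] == draft_vocab[tok])
    return {
        "target_size": len(target_vocab),
        "draft_size": len(draft_vocab),
        "shared_tokens": len(shared),
        "shared_with_same_id": same_id,
        # Standard speculative decoding needs an identical token -> id mapping.
        "compatible": target_vocab == draft_vocab,
    }

# Toy vocabularies; with real checkpoints you would pass
# AutoTokenizer.from_pretrained(name).get_vocab() for each model instead.
target = {"<s>": 0, "hello": 1, "world": 2}
draft_ok = {"<s>": 0, "hello": 1, "world": 2}
draft_bad = {"<s>": 0, "world": 1, "hello": 2}
print(compare_vocabs(target, draft_ok)["compatible"])   # True
print(compare_vocabs(target, draft_bad)["compatible"])  # False
```

Token strings can overlap heavily while their ids still differ, which is enough to break verification. For that reason the final check compares the whole mapping rather than just the vocabulary size.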

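An acceptance rate above 60% often yields a 2× or better speedup, but the exact gain also depends on how cheap the draft model is. Assume each draft token is accepted independently with probability α. Then one round produces on average (1 − α^(k+1)) / (1 − α) tokens for a single target forward pass. If one draft forward pass costs a fraction c of a target pass, the expected walltime speedup is (1 − α^(k+1)) / ((1 − α)(c·k + 1)) (Leviathan et al., 2023).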
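A small calculator for the expected-speedup estimate of Leviathan et al. (2023) helps when choosing k. The sketch makes two simplifications. It assumes each draft token is accepted independently with the same probability alpha. It also assumes one draft forward pass costs a fixed fraction c of a target forward pass. The function names are hypothetical, and alpha = 0.8 and c = 0.05 below are illustrative values, not measurements.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens per target forward pass: accepted drafts plus one corrected/bonus token.

    Assumes each of the k draft tokens is accepted independently with probability alpha.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def expected_speedup(alpha: float, k: int, c: float) -> float:
    """Walltime speedup vs. plain decoding; c = cost(draft pass) / cost(target pass)."""
    return expected_tokens_per_round(alpha, k) / (c * k + 1)

def best_k(alpha: float, c: float, k_max: int = 16) -> int:
    """Draft length in 1..k_max that maximizes the expected speedup."""
    return max(range(1, k_max + 1), key=lambda k: expected_speedup(alpha, k, c))

# Illustrative numbers only: alpha = 0.8, one draft pass costs 5% of a target pass.
for k in (2, 4, 8):
    print(k, round(expected_speedup(0.8, k, 0.05), 2))
print("best k:", best_k(0.8, 0.05))  # best k: 8
```

The speedup rises with k only while extra drafts are likely to survive. Past that point, the draft cost c·k in the denominator wins, which is why k is worth tuning per model pair.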

Implementation from Scratch

Here's a minimal PyTorch implementation that demonstrates the core algorithm:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.no_grad()
def speculative_decode(
    draft_model,
    target_model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    k: int = 5,
    temperature: float = 0.7,
) -> str:
    """Speculative decoding with a draft and target model (batch size 1).

    Args:
        draft_model: Small, fast model for generating draft tokens.
        target_model: Large model for verification.
        tokenizer: Shared tokenizer (same vocab required).
        prompt: Input text.
        max_new_tokens: Maximum number of tokens to generate.
        k: Number of draft tokens per round.
        temperature: Sampling temperature (must be > 0).

    For readability, both models re-encode the full sequence on every call
    (no KV cache), so this version illustrates the algorithm, not the speedup.
    """
    device = target_model.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = input_ids.shape[1]

    total_draft = 0
    total_accepted = 0

    while input_ids.shape[1] - prompt_len < max_new_tokens:
        # --- DRAFT PHASE: sample k tokens from the draft distribution q ---
        draft_ids = input_ids
        draft_probs = []  # q(. | prefix) at each drafted position, kept for the acceptance test
        for _ in range(k):
            logits = draft_model(draft_ids.to(draft_model.device)).logits[:, -1, :]
            q = torch.softmax(logits.float() / temperature, dim=-1).to(device)
            next_token = torch.multinomial(q, 1)  # shape (1, 1)
            draft_probs.append(q[0])
            draft_ids = torch.cat([draft_ids, next_token], dim=1)
        draft_tokens = draft_ids[0, -k:]
        total_draft += k

        # --- VERIFICATION PHASE: one target pass scores all k drafts plus one extra position ---
        target_logits = target_model(draft_ids).logits[0, -k - 1:, :]
        # target_probs[i] = p(. | prefix + first i draft tokens), for i = 0..k
        target_probs = torch.softmax(target_logits.float() / temperature, dim=-1)

        # --- ACCEPT/REJECT: keep draft token i with probability min(1, p/q) ---
        n_accepted = 0
        for i in range(k):
            tok = draft_tokens[i]
            ratio = target_probs[i, tok] / draft_probs[i][tok]
            if torch.rand((), device=device) < ratio:
                n_accepted += 1
            else:
                break
        total_accepted += n_accepted

        if n_accepted < k:
            # Rejection at position n: resample from the residual norm(max(0, p - q))
            residual = torch.clamp(target_probs[n_accepted] - draft_probs[n_accepted], min=0)
            if residual.sum() <= 0:  # p and q identical here; fall back to p
                residual = target_probs[n_accepted]
            extra = torch.multinomial(residual / residual.sum(), 1)
        else:
            # All k drafts accepted: the target's last position gives a free bonus token
            extra = torch.multinomial(target_probs[k], 1)

        input_ids = torch.cat(
            [input_ids, draft_tokens[:n_accepted].unsqueeze(0), extra.view(1, 1)], dim=1
        )

        # Stop early once the model emits its end-of-sequence token
        eos_id = tokenizer.eos_token_id
        if eos_id is not None and (input_ids[0, prompt_len:] == eos_id).any():
            break

    # A round can add up to k + 1 tokens, so trim any overshoot
    input_ids = input_ids[:, : prompt_len + max_new_tokens]

    # Decode and report stats
    output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    acceptance_rate = total_accepted / max(total_draft, 1)
    print(f"Draft tokens: {total_draft}, Accepted: {total_accepted}")
    print(f"Acceptance rate: {acceptance_rate:.1%}")

    return output
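The accept/reject step is where subtle bugs tend to hide. It can help to factor it into a pure function and test it on hand-made distributions without loading any model. The sketch below implements the same acceptance rule described earlier for one sequence. The name `accept_reject` and its signature are hypothetical.

```python
import torch

def accept_reject(target_probs, draft_probs, draft_tokens, generator=None):
    """One speculative-sampling verification step for a single sequence.

    target_probs: (k + 1, V) target distributions p at each draft position plus one extra.
    draft_probs:  (k, V) draft distributions q the draft tokens were sampled from.
    draft_tokens: (k,) token ids proposed by the draft model.
    Returns the emitted token ids: the accepted prefix plus one corrected or bonus token.
    """
    k = draft_tokens.shape[0]
    emitted = []
    for i in range(k):
        tok = int(draft_tokens[i])
        ratio = target_probs[i, tok] / draft_probs[i, tok]
        if torch.rand((), generator=generator) < ratio:
            emitted.append(tok)  # accepted with probability min(1, p/q)
            continue
        # Rejected: resample from norm(max(0, p - q)) and end this round
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0)
        emitted.append(int(torch.multinomial(residual / residual.sum(), 1, generator=generator)))
        return emitted
    # Every draft accepted: sample a bonus token from the extra target position
    emitted.append(int(torch.multinomial(target_probs[k], 1, generator=generator)))
    return emitted

# Toy check 1: p == q, so both drafts are accepted; the bonus comes from p's last row.
q = torch.tensor([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5]])
p = torch.tensor([[0.2, 0.3, 0.5], [0.2, 0.3, 0.5], [0.0, 1.0, 0.0]])
print(accept_reject(p, q, torch.tensor([2, 0])))  # [2, 0, 1]

# Toy check 2: the target gives the drafted token zero probability, so it is always rejected.
q2 = torch.tensor([[1.0, 0.0, 0.0]])
p2 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
print(accept_reject(p2, q2, torch.tensor([0])))  # [2]
```

Both toy cases are deterministic, so they make good regression tests. A rejection can only occur where p(x) < q(x), which guarantees the residual has positive mass to sample from.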

Hugging Face Transformers Integration

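Hugging Face Transformers has built-in speculative decoding, called assisted generation. Passing a draft model to generate() through the assistant_model argument makes Transformers run the draft-and-verify loop for you. Here's a typical setup: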

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load target (large) model
target_name = "meta-llama/Meta-Llama-3-8B-Instruct"
target_model = AutoModelForCausalLM.from_pretrained(
    target_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Load draft (small) model; it must share the target's tokenizer (Llama 3 vocabulary)
draft_name = "meta-llama/Llama-3.2-1B-Instruct"
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(target_name)
tokenizer.pad_token = tokenizer.eos_token

messages = [
    {"role": "user", "content": "Explain how gradient checkpointing trades compute for memory in transformer training."}
]

# add_generation_prompt=True appends the assistant header so the model answers the user
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# The chat template already contains the BOS token, so don't add a second one
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(target_model.device)

# Draft length k: assisted generation reads it from the draft model's generation config
draft_model.generation_config.num_assistant_tokens = 5
draft_model.generation_config.num_assistant_tokens_schedule = "constant"  # keep k fixed

# Speculative decoding: Transformers handles the verification loop
outputs = target_model.generate(
    **inputs,
    assistant_model=draft_model,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The num_assistant_tokens parameter controls the draft length k. Start with 4–6 and tune based on your model pair's acceptance rate.

Optimizing Acceptance Rate

The acceptance rate determines your actual speedup. Here are common techniques for raising it:

1. Domain-Specific Draft Models

Fine-tune a small model on the same corpus as your target model. A 1B model fine-tuned on your domain can outperform a 3B general-purpose draft model.

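# Quick SFT example using TRL
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

# The draft must keep the target's tokenizer, so start from a same-family model
draft_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
dataset = load_dataset("your-domain-data", split="train")  # placeholder dataset name

# SFTConfig extends TrainingArguments with SFT-specific options;
# some names (e.g. max_length vs. max_seq_length) differ across TRL versions
training_args = SFTConfig(
    output_dir="./draft-domain",
    dataset_text_field="text",
    max_length=2048,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=3,
    fp16=True,
)

trainer = SFTTrainer(
    model=draft_model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()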
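To check whether fine-tuning actually helped, you can estimate the per-token acceptance probability directly from logits instead of timing generation. At one position, speculative sampling accepts with probability Σₓ min(p(x), q(x)), which is one minus the total variation distance between p and q. Averaging that over held-out text gives an estimate of α. The sketch assumes you already have target and draft logits of shape (seq_len, vocab) from running both models on the same token ids. It scores both models on real text rather than on drafted prefixes, so treat the result as an approximation. The function name is hypothetical, and the random tensors below only stand in for real logits.

```python
import torch

def estimated_acceptance(target_logits: torch.Tensor, draft_logits: torch.Tensor,
                         temperature: float = 1.0) -> float:
    """Mean over positions of sum_x min(p(x), q(x)), the chance a draft token is accepted.

    Both inputs have shape (seq_len, vocab) and come from the same token ids, e.g.
    target_model(ids).logits[0] and draft_model(ids).logits[0].
    """
    p = torch.softmax(target_logits.float() / temperature, dim=-1)
    q = torch.softmax(draft_logits.float() / temperature, dim=-1)
    return torch.minimum(p, q).sum(dim=-1).mean().item()

torch.manual_seed(0)
same = torch.randn(6, 32)   # stand-in logits for 6 positions, 32-token vocabulary
other = torch.randn(6, 32)  # an unrelated "draft model"
print(estimated_acceptance(same, same))   # about 1.0: identical models always agree
print(estimated_acceptance(same, other))  # lower: the distributions disagree
```

Running this before and after fine-tuning, on the same held-out prompts and at the temperature you plan to serve with, gives a quick signal without a full generation benchmark.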

2. Token Tree Verification (Medusa)

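Instead of a separate draft model proposing a single chain of tokens, Medusa adds extra decoding heads to the target model itself. Each head predicts a token one step further ahead, and the top candidates from the heads are combined into a tree of possible continuations. The target model verifies all branches in one forward pass using a tree attention mask and keeps the longest accepted path. This can push speedup to 3×+ on compatible architectures, at the cost of training the extra heads.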
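A tree attention mask is what lets the target check several branches in one pass. Each candidate token may attend to the shared prefix and to its own ancestors in the tree, but not to tokens on sibling branches. The sketch below builds that mask from a list of parent indices. The tree layout and the function name are illustrative assumptions, not Medusa's actual code.

```python
import torch

def tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """Boolean (n, n) mask over n candidate tokens: mask[i, j] is True if token i may attend to j.

    parents[i] is the index of node i's parent, or -1 if it hangs directly off the prefix.
    Nodes must be ordered so every parent comes before its children.
    Attention to the shared prefix is always allowed and is not represented here.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i, parent in enumerate(parents):
        mask[i, i] = True              # a token always sees itself
        if parent >= 0:
            mask[i] |= mask[parent]    # plus everything its parent sees (all its ancestors)
    return mask

# Two candidates for the next position (nodes 0 and 1), each extended by one more token:
# prefix -> 0 -> 2 and prefix -> 1 -> 3
print(tree_attention_mask([-1, -1, 0, 1]).int())
```

A single chain (parents [-1, 0, 1, ...]) reduces to the ordinary causal mask. That is why linear speculative decoding is the special case of a tree with one branch.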

3. Speculative Sampling with Temperature Matching

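Use the same temperature and top-p for both draft and target models. The acceptance test stays exact for any draft distribution, as long as the q in the ratio p/q is the distribution the draft tokens were actually sampled from. However, the further q drifts from p, the more drafts get rejected, so mismatched sampling settings can noticeably reduce acceptance rates.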
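At one position, a draft token is accepted with probability Σₓ min(p(x), q(x)), so the effect of a sampling mismatch can be seen on a toy example. The sketch below feeds one made-up logit vector to both "models". Any drop in acceptance therefore comes purely from the temperature mismatch; real draft and target logits differ as well. The function name is hypothetical.

```python
import torch

def acceptance_prob(target_logits, draft_logits, t_target: float, t_draft: float) -> float:
    """Probability that one draft token is accepted: sum_x min(p(x), q(x))."""
    p = torch.softmax(target_logits / t_target, dim=-1)
    q = torch.softmax(draft_logits / t_draft, dim=-1)
    return torch.minimum(p, q).sum().item()

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])      # made-up logits shared by both models
print(acceptance_prob(logits, logits, 0.7, 0.7))  # matched: about 1.0
print(acceptance_prob(logits, logits, 0.7, 1.2))  # draft sampled hotter: noticeably lower
```

Even with identical underlying models, sampling the draft at a higher temperature spreads q over tokens the target rarely picks, and those drafts get rejected.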

Performance Benchmarks

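| Target Model | Draft Model | k | Acceptance Rate | Speedup |
|---|---|---|---|---|
| Llama-3-8B | TinyLlama-1.1B | 5 | 0.62 | 2.1× |
| Llama-3-8B | Fine-tuned 1B | 5 | 0.74 | 2.6× |
| Mistral-7B | Phi-2-2.7B | 4 | 0.58 | 1.9× |
| Qwen2-7B | Qwen2-1.5B | 6 | 0.71 | 2.5× |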

Benchmarks measured on a single A100 80GB, 512-token generation, batch size 1.

When Speculative Decoding Falls Short

Not every scenario benefits equally:

  • Low acceptance (<40%): The draft model is too dissimilar. Consider a better-matched pair or fine-tuning.
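  • Batched inference: At larger batch sizes, decoding is less memory-bound, so there is less idle compute for verification to use. Accepted lengths also differ from sequence to sequence within a batch, which makes them harder to keep aligned. Per-sequence adaptive k scheduling helps.
  • Very small models: If your target model is already under 3B parameters, the draft model is no longer cheap relative to the target, and drafting plus verification overhead may outweigh gains.
  • Memory-constrained: Loading two models requires more VRAM. Offloading the draft model to CPU saves GPU memory, but CPU drafting is much slower and can erase the speedup.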
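One simple form of adaptive draft length updates k after every round. The rule below grows k by 2 after a fully accepted round and shrinks it by 1 otherwise, within bounds. It is similar in spirit to the heuristic schedule in Transformers' assisted generation. The bounds, function names, and acceptance counts are illustrative assumptions. In a batched setting, each sequence keeps its own k.

```python
def next_k(k: int, n_accepted: int, k_min: int = 1, k_max: int = 16) -> int:
    """Update one sequence's draft length after a verification round."""
    if n_accepted == k:
        return min(k + 2, k_max)   # every draft accepted: speculate further next round
    return max(k - 1, k_min)       # a rejection happened: speculate less

def next_k_batch(ks: list[int], accepted: list[int]) -> list[int]:
    """Per-sequence version for batched decoding: each sequence adapts independently."""
    return [next_k(k, n) for k, n in zip(ks, accepted)]

k = 5
for n_accepted in [5, 7, 2, 0]:  # made-up acceptance counts for four rounds
    k = next_k(k, n_accepted)
    print(k)                     # 7, 9, 8, 7
```

Growing quickly and shrinking slowly favors long drafts while the draft model is tracking the target well. A single rejection only trims k a little.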

Advanced: Speculative Decoding with Quantized Models

Combine speculative decoding with quantization for maximum efficiency. Use a 4-bit quantized target model with an FP16 draft model:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# Draft model stays in FP16 for fast token generation
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)

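4-bit weights take roughly a quarter of the memory of FP16 weights, so the target model's footprint drops by close to 4× while the draft-and-verify setup stays the same. This enables 8B models on consumer GPUs with 12–16 GB VRAM. Speculative decoding is then lossless with respect to the quantized target, whose outputs can differ slightly from those of the FP16 model.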
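The ~4× figure follows from weight storage alone: FP16 uses 16 bits per parameter and NF4 about 4. The back-of-the-envelope helper below uses round parameter counts (8B target, 1B draft) as assumptions. It ignores quantization constants, activations, and the KV cache, so real usage is somewhat higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes); weights only."""
    return num_params * bits_per_param / 8 / 1e9

target_fp16 = weight_memory_gb(8e9, 16)  # 8B target in FP16
target_nf4 = weight_memory_gb(8e9, 4)    # same target in 4-bit NF4
draft_fp16 = weight_memory_gb(1e9, 16)   # 1B draft kept in FP16
print(f"target FP16: {target_fp16:.1f} GB, target NF4: {target_nf4:.1f} GB")
print(f"NF4 target + FP16 draft: {target_nf4 + draft_fp16:.1f} GB")
```

About 6 GB of weights for the pair leaves room for the KV cache and activations on a 12–16 GB card.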

Conclusion

Speculative decoding is one of the most practical inference acceleration techniques available today: it's lossless, easy to integrate, and typically delivers 2–3× speedups at small batch sizes. The key is pairing the right draft model with your target and tuning k for your use case.

For production systems, combine it with KV cache optimizations, continuous batching (vLLM, TGI), and quantization for the best cost-to-throughput ratio. The era of single-token-at-a-time generation is ending — draft, verify, and accelerate.

Further Reading

  • Fast Inference from Transformers via Speculative Decoding — Leviathan et al. (ICML 2023)
  • Speculative Decoding: A Survey — arXiv:2408.04628
  • Medusa: Simple LLM Inference Acceleration — Cai et al. (2024)
  • Hugging Face docs: Assisted Generation