Speculative Decoding: 2-3x Faster LLM Inference Without Quality Loss

If you have ever stared at a terminal waiting for an LLM to finish generating a response, you have felt the fundamental bottleneck of modern language models: autoregressive decoding. At every token, the model runs a full forward pass. Generating 500 tokens means 500 sequential GPU calls. No batching can hide that latency.

Speculative decoding breaks this constraint. It lets a small, fast "draft" model propose multiple tokens in parallel, then uses the large "target" model to verify them in a single forward pass. The result: 2-3x throughput improvement with zero quality degradation — the output distribution is mathematically identical to greedy or nucleus sampling from the target model alone.

In this tutorial, we will build a working speculative decoder from scratch using Hugging Face Transformers and PyTorch.

How Speculative Decoding Works

The algorithm has two phases per step:

  1. Draft Phase: A small model (e.g., Llama-3.2-1B) generates γ tokens autoregressively — fast, cheap, parallelizable.
  2. Verification Phase: The large model (e.g., Llama-3.1-70B) runs a single forward pass over all γ draft tokens, computing probabilities for each position simultaneously. Tokens that pass a probabilistic acceptance test are kept; rejected tokens trigger a correction step.

The key insight: verifying γ tokens costs roughly the same as generating one token with the large model, because the KV cache is shared and the forward pass is parallel across positions.

The Acceptance Criterion

At position i, let the draft model produce token xᵢ with probability q(xᵢ), and the target model assign probability p(xᵢ). We accept xᵢ with probability:

α = min(1, p(xᵢ) / q(xᵢ))

If accepted, we move to position i+1. If rejected, we sample from the normalized residual distribution: norm(max(0, p − q)), which guarantees the output distribution matches the target model exactly.

Implementation from Scratch

1. Setup and Model Loading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load draft model (small, fast)
draft_name = "meta-llama/Llama-3.2-1B"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_name)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name,
    torch_dtype=torch.float16,
    device_map="auto"
).eval()

# Load target model (large, accurate)
target_name = "meta-llama/Llama-3.1-8B"
target_tokenizer = AutoTokenizer.from_pretrained(target_name)
target_model = AutoModelForCausalLM.from_pretrained(
    target_name,
    torch_dtype=torch.float16,
    device_map="auto"
).eval()

2. Core Speculative Decoding Loop

def speculative_decode(draft_model, target_model, draft_tok, target_tok,
                       prompt, max_new_tokens=256, gamma=5):
    """
    Speculative decoding with a single draft model.

    Args:
        gamma: Number of draft tokens per verification step.
    Returns:
        Generated tokens and generation statistics.
    """
    # Encode prompt
    input_ids = draft_tok(prompt, return_tensors="pt").input_ids.to(device)
    seq_len = input_ids.shape[1]

    accepted_total = 0
    drafted_total = 0
    steps = 0

    while input_ids.shape[1] < seq_len + max_new_tokens:
        steps += 1

        # --- DRAFT PHASE ---
        draft_ids = input_ids.clone()
        for _ in range(gamma):
            with torch.no_grad():
                out = draft_model(draft_ids)
                next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
                draft_ids = torch.cat([draft_ids, next_token], dim=1)

        draft_tokens = draft_ids[:, seq_len:]  # γ new tokens

        # --- VERIFICATION PHASE ---
        with torch.no_grad():
            # Run target model on the full sequence (draft tokens included)
            target_out = target_model(draft_ids)
            target_logits = target_out.logits[:, seq_len - 1 : seq_len - 1 + gamma, :]
            target_probs = torch.softmax(target_logits, dim=-1)

        # Get draft probabilities for each position
        draft_probs_list = []
        for i in range(gamma):
            with torch.no_grad():
                draft_out = draft_model(
                    draft_ids[:, : seq_len + i]
                )
                d_prob = torch.softmax(draft_out.logits[:, -1, :], dim=-1)
                draft_probs_list.append(d_prob)

        # Acceptance loop
        accepted_count = 0
        for i in range(gamma):
            draft_token = draft_tokens[0, i].item()
            p_val = target_probs[0, i, draft_token].item()
            q_val = draft_probs_list[i][0, draft_token].item()

            acceptance_prob = min(1.0, p_val / (q_val + 1e-12))

            if torch.rand(1).item() < acceptance_prob:
                accepted_count += 1
                # Append the verified token
                input_ids = torch.cat([
                    input_ids,
                    draft_tokens[:, i:i+1]
                ], dim=1)
            else:
                # Rejection: sample from residual distribution
                residual = torch.clamp(target_probs[0, i] - draft_probs_list[i][0], min=0)
                residual_sum = residual.sum()
                if residual_sum > 0:
                    residual = residual / residual_sum
                    new_token = torch.multinomial(residual, 1)
                    input_ids = torch.cat([input_ids, new_token.unsqueeze(0)], dim=1)
                else:
                    new_token = target_probs[0, i].multinomial(1)
                    input_ids = torch.cat([input_ids, new_token.unsqueeze(0)], dim=1)
                break  # Stop this draft round

        accepted_total += accepted_count
        drafted_total += gamma
        seq_len = input_ids.shape[1]

        # Early exit if no tokens accepted
        if accepted_count == 0:
            break

    acceptance_rate = accepted_total / max(drafted_total, 1)
    print(f"Steps: {steps} | Acceptance rate: {acceptance_rate:.2%} | "
          f"Speedup: ~{1 + accepted_total / steps:.1f}x")

    return draft_tok.decode(input_ids[0], skip_special_tokens=True)

3. Running It

prompt = "Explain how attention mechanisms work in transformer models, including"

result = speculative_decode(
    draft_model, target_model,
    draft_tokenizer, target_tokenizer,
    prompt, max_new_tokens=200, gamma=4
)

print(result)

Expected Performance

With a 1B draft model and an 8B target model on a single A100 GPU:

MetricBaselineSpeculative Decoding
Tokens/sec~45~95-120
Acceptance rate (γ=4)~75-85%
Effective speedup1x2.0-2.5x

Acceptance rate depends on how well the draft model matches the target. Fine-tuning the draft on the target's outputs (a process called speculative decoding with a self-drafted model) can push rates above 90%.

Optimization Tips

  • Choose γ wisely: Too small → underutilized verification. Too large → diminishing returns as acceptance drops. Empirically, γ ∈ [3, 6] works best for 1B→8B pairs.
  • Shared tokenizer: Use the same tokenizer for both models to avoid encoding mismatches.
  • Batch KV cache reuse: In the verification phase, the target model only needs to process positions from seq_len − 1 onward. Pre-computed KV caches make this near-free.
  • Multi-token draft: Advanced implementations (e.g., Medusa, Eagle) replace the draft model with lightweight heads on the target model itself, achieving 3-4x speedups.
  • Speculative sampling: Replace greedy draft generation with temperature-scaled sampling to increase diversity and acceptance rates for creative tasks.

When to Use Speculative Decoding

Great for:

  • Real-time chat applications where latency matters
  • Code completion (high acceptance rates due to deterministic syntax)
  • Long-form generation (blog posts, summaries, reports)
  • Any scenario where the target model is GPU-memory-bound

Less effective for:

  • Very short responses (< 20 tokens) — overhead outweighs gains
  • Highly creative/divergent generation where draft-target alignment is low
  • When the draft model is too small relative to the target (> 50x parameter gap)

Production Implementations

Don't roll your own for production. These frameworks support speculative decoding out of the box:

  • vLLM--speculative-model flag with n-gram or model-based draft
  • TGI (Text Generation Inference) — Medusa and draft model support
  • Hugging Face TransformersAssistedGenerationMixin since v4.38
  • llama.cpp — speculative decoding with small draft models

Summary

Speculative decoding is one of the few inference optimizations that improves throughput without any trade-off in output quality. By letting a cheap draft model propose tokens and a large model verify them in parallel, you get a free 2-3x speedup. The math guarantees identical distributions. The implementation is straightforward. There is no reason not to use it.

Start with Hugging Face's built-in assisted generation, profile your acceptance rates, tune γ, and watch your tokens/sec climb. 🦊