Speculative Decoding: 2x–3x Faster LLM Inference Without Quality Loss

If you’ve ever waited for a large language model to generate a response token by token, you know the frustration. Autoregressive decoding is inherently sequential — each token depends on all previous ones. But what if you could predict several tokens ahead, verify them in parallel, and dramatically speed up generation? That’s exactly what speculative decoding does, and in this tutorial, we’ll build it from scratch.

What Is Speculative Decoding?

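Speculative decoding (called assisted generation in Hugging Face Transformers) uses a small, fast draft model to propose several tokens ahead. A larger target model then verifies those proposals in a single forward pass. The key insight: at small batch sizes a forward pass is limited by reading the model weights from memory rather than by arithmetic, so verifying γ tokens in one pass costs roughly the same as generating one token with the target model. Medusa is a related but distinct approach that replaces the separate draft model with extra decoding heads on the target.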

The algorithm works in rounds:

  1. Draft phase: The draft model generates γ (gamma) tokens autoregressively — very fast because it’s small.
  2. Verification phase: The target model runs a single forward pass on the input + all γ draft tokens, producing its own probability distribution for each position.
  3. Acceptance check: Each draft token is accepted or rejected using a probabilistic criterion that guarantees the output distribution matches the target model exactly.
  4. Repeat: Continue from the last accepted position.

The beauty? The final output distribution is mathematically identical to what the target model would produce alone. No quality loss — just speed.
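To make the round structure concrete, here is a minimal sketch of the greedy (temperature-zero) special case, where the probabilistic check collapses to an equality test: a draft token is kept only if it is the target model's top choice. The function name greedy_verify and the toy token IDs are illustrative assumptions, not part of any library.

```python
def greedy_verify(draft_tokens, target_argmax):
    """
    Greedy special case of one verification round.

    draft_tokens:  k tokens proposed by the draft model.
    target_argmax: k tokens, where target_argmax[i] is the target model's
                   top choice at position i given the prefix plus
                   draft_tokens[:i] (all read from one forward pass).
    Returns the tokens to append to the sequence this round.
    """
    out = []
    for drafted, wanted in zip(draft_tokens, target_argmax):
        if drafted == wanted:
            out.append(drafted)   # accepted: the target agrees
        else:
            out.append(wanted)    # rejected: use the target's token instead
            break                 # later drafts were conditioned on a wrong token
    return out


# Toy token IDs: the third draft token disagrees with the target.
print(greedy_verify([11, 42, 7, 3], [11, 42, 9, 3]))  # [11, 42, 9]
```

Even in the worst case the round still emits one token chosen by the target, so it never makes less progress than a plain decoding step.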

The Acceptance Criterion

For each position i, let q(x) be the draft model’s probability and p(x) be the target model’s probability for the next token. We accept the draft token x with probability:

accept_prob = min(1, p(x) / q(x))

If the token is rejected, we discard it and every draft token after it, then sample a replacement from the residual distribution max(0, p(x) − q(x)), normalized to sum to 1. This resampling step is what makes the output distribution match the target exactly.
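The exactness claim can be checked numerically on a toy vocabulary. The sketch below computes, in closed form, the distribution of the token emitted by one accept-or-resample step and compares it with p. The three-token distributions and the function name are made up for illustration.

```python
def speculative_output_distribution(p, q):
    """Exact distribution of the token emitted by one accept/resample step.

    p: target probabilities, q: draft probabilities (same vocabulary order).
    """
    # A token x is drafted with prob q(x) and accepted with prob min(1, p/q),
    # so the mass that arrives via acceptance is min(p(x), q(x)).
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_prob = 1.0 - sum(accepted)

    # On rejection we resample from the normalized residual max(0, p - q).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total > 0:
        residual = [r / total for r in residual]
    else:
        residual = list(p)  # p == q everywhere, so rejection never happens

    return [a + reject_prob * r for a, r in zip(accepted, residual)]


p = [0.5, 0.3, 0.2]  # toy target distribution
q = [0.2, 0.6, 0.2]  # toy draft distribution
out = speculative_output_distribution(p, q)
assert all(abs(o - pi) < 1e-9 for o, pi in zip(out, p))
```

The draft distribution only changes how often a rejection happens (here 30% of the time), not what comes out.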

Implementation in Python with Hugging Face

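Let’s build a working implementation. We’ll use a small model (Llama-3.2-1B) as the drafter and a larger model (Llama-3.1-8B) as the target. The two must share a tokenizer and vocabulary, because the acceptance check compares their probabilities token by token. A drafter such as TinyLlama, which uses the 32k-token Llama 2 vocabulary, would not line up with Llama 3.1.

Step 1: Setup

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models (both use the Llama 3 tokenizer, so token IDs line up)
draft_name = "meta-llama/Llama-3.2-1B-Instruct"
target_name = "meta-llama/Llama-3.1-8B-Instruct"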

draft_tok = AutoTokenizer.from_pretrained(draft_name)
draft_tok.pad_token = draft_tok.eos_token

target_tok = AutoTokenizer.from_pretrained(target_name)
target_tok.pad_token = target_tok.eos_token

draft_model = AutoModelForCausalLM.from_pretrained(
    draft_name, torch_dtype=torch.float16, device_map=device
)
target_model = AutoModelForCausalLM.from_pretrained(
    target_name, torch_dtype=torch.float16, device_map=device
)

draft_model.eval()
target_model.eval()

Step 2: Core Speculative Decoding Loop

def speculative_decode(input_ids, draft_model, target_model,
                       draft_tok, target_tok, gamma=5,
                       max_new_tokens=100, temperature=0.7):
    """
    Speculative decoding with exact distribution matching.

    Args:
        input_ids: Initial token IDs (tensor)
        gamma: Number of draft tokens to generate per round
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature

    Returns:
        Generated token IDs
    """
    current_ids = input_ids.clone()
    tokens_generated = 0

    while tokens_generated < max_new_tokens:
        # --- DRAFT PHASE ---
        draft_tokens = []
        draft_probs = []
        draft_input = current_ids.clone()

        for _ in range(gamma):
            with torch.no_grad():
                out = draft_model(draft_input)
                logits = out.logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                token = torch.multinomial(probs, num_samples=1)

            draft_tokens.append(token.item())
            draft_probs.append(probs.squeeze())
            draft_input = torch.cat([draft_input, token], dim=1)

            if token.item() == draft_tok.eos_token_id:
                break

        k = len(draft_tokens)
        if k == 0:
            break

        # --- VERIFICATION PHASE ---
        with torch.no_grad():
            # Run target model on input + all draft tokens
            verify_input = torch.cat([
                current_ids,
                torch.tensor([draft_tokens], device=device)
            ], dim=1)

            target_out = target_model(verify_input)
            target_logits = target_out.logits

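        accepted = 0
        # Record the prefix length once: current_ids grows as tokens are accepted
        prefix_len = current_ids.size(1)
        for i in range(k):
            # Logits at index prefix_len - 1 + i score draft token i
            pos = prefix_len - 1 + i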

            q = draft_probs[i]  # draft probability distribution
            p = F.softmax(target_logits[0, pos, :] / temperature, dim=-1)

            draft_token = draft_tokens[i]

            # Acceptance probability
            accept_prob = min(1.0, p[draft_token].item() / q[draft_token].item())

            if torch.rand(1).item() < accept_prob:
                accepted += 1
                current_ids = torch.cat([
                    current_ids,
                    torch.tensor([[draft_token]], device=device)
                ], dim=1)
            else:
                # Rejection sampling from residual distribution
                residual = torch.clamp(p - q, min=0)
                residual_sum = residual.sum()
                if residual_sum > 0:
                    residual = residual / residual_sum
                    new_token = torch.multinomial(residual, num_samples=1)
                else:
                    new_token = torch.multinomial(p, num_samples=1)

                current_ids = torch.cat([
                    current_ids,
                    new_token.unsqueeze(0)
                ], dim=1)
                break  # Move to next round

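        new_count = accepted + (1 if accepted < k else 0)
        tokens_generated += new_count

        # Stop once an end-of-sequence token has been emitted
        if target_tok.eos_token_id in current_ids[0, -new_count:].tolist():
            break

        if tokens_generated >= max_new_tokens:
            break

    # A round can overshoot the budget by a few tokens, so trim to max_new_tokens
    start = input_ids.size(1)
    return current_ids[:, start:start + max_new_tokens]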

Step 3: Running the Generator

prompt = "Explain how transformers use attention mechanisms"
input_ids = target_tok(prompt, return_tensors="pt").input_ids.to(device)

# Standard autoregressive (baseline)
import time
start = time.time()
std_output = target_model.generate(
    input_ids, max_new_tokens=50, do_sample=True, temperature=0.7
)
std_time = time.time() - start
print(f"Standard: {std_time:.2f}s")

# Speculative decoding
start = time.time()
spec_output = speculative_decode(
    input_ids, draft_model, target_model,
    draft_tok, target_tok, gamma=5,
    max_new_tokens=50, temperature=0.7
)
spec_time = time.time() - start
print(f"Speculative: {spec_time:.2f}s")

speedup = std_time / spec_time
print(f"Speedup: {speedup:.2f}x")
print(target_tok.decode(spec_output[0], skip_special_tokens=True))

Optimizations That Matter

1. KV-Cache for the Target Model

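The naive implementation re-runs both models on the entire sequence each round, while the generate() baseline above already uses a KV-cache, so the Step 3 timing can easily come out slower than the baseline. In production, maintain a KV-cache to avoid the redundant computation:

# Use past_key_values to cache attention states.
# past_kv starts as None. On later rounds, pass only the tokens that are
# not yet in the cache, and drop cache entries for rejected draft tokens
# before the next round.
target_out = target_model(
    verify_input[:, cached_len:],   # cached_len: tokens already in past_kv
    past_key_values=past_kv,
    use_cache=True
)
logits = target_out.logits          # covers only the tokens passed in
past_kv = target_out.past_key_values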
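The fiddly part of caching is index bookkeeping: which tokens the target still has to process, and how much of the cache stays valid after a rejection. The sketch below works this out with plain integers under one assumed convention (the cache never covers the last prefix token, so its logits are always recomputed). The function names are hypothetical. For the trimming itself, recent transformers releases appear to offer a crop method on DynamicCache, but verify that against your installed version.

```python
def tokens_to_feed(cached_len, prefix_len, k):
    """Slice [start, end) of the sequence the target must process this round.

    cached_len: tokens currently held in the target's KV-cache
    prefix_len: tokens in the sequence before this round's drafts
    k:          number of draft tokens appended for verification
    """
    # The last prefix token stays outside the cache because its logits
    # score draft token 0. If start < cached_len, trim the cache to start.
    start = min(cached_len, prefix_len - 1)
    end = prefix_len + k
    return start, end


def valid_cache_len(prefix_len, accepted):
    """Cache entries worth keeping after verification."""
    # Keys/values computed for rejected drafts describe text that was
    # never emitted, so only the prefix plus accepted drafts survive.
    return prefix_len + accepted


# Round 1: empty cache, 10-token prompt, 5 drafts, 2 accepted.
print(tokens_to_feed(0, 10, 5))    # (0, 15)
print(valid_cache_len(10, 2))      # 12
# Round 2: sequence is now 13 tokens (2 accepted + 1 resampled).
print(tokens_to_feed(12, 13, 5))   # (12, 18)
```

With this convention, the logits row for sequence position pos sits at index pos - start in the output of the cached forward pass.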

2. Draft Model Selection Strategies

The choice of draft model dramatically affects speedup. Practical options:

  • Smaller model from same family: Llama-3.2-1B → Llama-3.1-8B (best accuracy match)
  • Earlier layers of the same model: Use the first N layers as a drafter (self-speculative decoding; almost no extra memory)
  • n-gram matching: Extract frequent n-grams from context as a free drafter (no model needed)
  • Medusa heads: Train lightweight decoding heads on top of the target model
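The n-gram option in the list above is simple enough to sketch in a few lines. The version below is one possible reading of it, assuming token IDs in a plain list: find the most recent earlier occurrence of the context's last few tokens and propose whatever followed it. The function name and default parameters are illustrative.

```python
def ngram_draft(context, max_ngram=3, num_draft=5):
    """Propose up to num_draft tokens by matching the context's suffix
    against an earlier part of the same context. Returns [] if no match."""
    for n in range(max_ngram, 0, -1):          # prefer longer matches
        if len(context) <= n:
            continue
        suffix = context[-n:]
        # Scan right to left so the most recent earlier occurrence wins;
        # the upper bound excludes the suffix matching itself.
        for start in range(len(context) - n - 1, -1, -1):
            if context[start:start + n] == suffix:
                continuation = context[start + n:start + n + num_draft]
                if continuation:
                    return continuation
    return []


# The suffix [1, 2, 3] appeared earlier, followed by 4, 5.
print(ngram_draft([1, 2, 3, 4, 5, 1, 2, 3], num_draft=2))  # [4, 5]
```

Because the proposal is deterministic, the draft distribution puts all its mass on the proposed token, so the acceptance probability min(1, p(x) / q(x)) reduces to p(x).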

3. Adaptive Gamma

Instead of a fixed γ, adapt dynamically based on the recent acceptance rate:

if acceptance_rate > 0.8:
    gamma = min(gamma + 1, max_gamma)
elif acceptance_rate < 0.4:
    gamma = max(gamma - 1, 1)
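A runnable version of that rule needs a definition of "recent acceptance rate". The sketch below assumes a sliding window over the last few rounds. The class name, the window size, and the max_gamma default are assumptions; the 0.8 and 0.4 thresholds are the ones used above.

```python
from collections import deque


class AdaptiveGamma:
    """Adjust the draft length from a sliding-window acceptance rate."""

    def __init__(self, gamma=5, min_gamma=1, max_gamma=10, window=20):
        self.gamma = gamma
        self.min_gamma = min_gamma
        self.max_gamma = max_gamma
        # One (accepted, proposed) pair per round; old rounds fall off.
        self.history = deque(maxlen=window)

    def update(self, accepted, proposed):
        """Record one round and return the gamma to use for the next."""
        self.history.append((accepted, proposed))
        total_proposed = sum(p for _, p in self.history)
        if total_proposed == 0:
            return self.gamma
        acceptance_rate = sum(a for a, _ in self.history) / total_proposed
        if acceptance_rate > 0.8:
            self.gamma = min(self.gamma + 1, self.max_gamma)
        elif acceptance_rate < 0.4:
            self.gamma = max(self.gamma - 1, self.min_gamma)
        return self.gamma


controller = AdaptiveGamma(gamma=5)
print(controller.update(accepted=5, proposed=5))  # 6
```

In the decoding loop, update would be called once per round with the accepted count and k. Note that accepted / proposed understates per-token agreement, since draft tokens after the first rejection are never actually tested.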

Expected Performance

With a well-matched draft/target pair on GPU:

Scenario                        Typical Speedup    Acceptance Rate
Code generation                 2.0x – 2.8x        70–85%
Creative writing                1.5x – 2.2x        50–70%
Math/STEM                       1.3x – 1.8x        40–60%
n-gram only (no draft model)    1.2x – 1.5x        30–50%

Code generation sees the highest speedups because code is highly predictable — the draft model gets most tokens right.
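The link between acceptance rate and speedup can be estimated with a simple cost model. The sketch below makes three assumptions: each draft token is accepted independently with probability alpha, a round that accepts every draft also emits one bonus token sampled from the target's final position (the Step 2 implementation skips this, so it yields slightly fewer tokens), and each draft forward pass costs a fixed fraction cost_ratio of a target pass. Function names are illustrative.

```python
def expected_tokens_per_round(alpha, gamma):
    """Expected tokens emitted per target forward pass.

    Sums alpha**i for i = 0..gamma: one guaranteed token plus each
    draft token that survives all earlier acceptance checks.
    """
    if alpha >= 1.0:
        return gamma + 1.0
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def estimated_speedup(alpha, gamma, cost_ratio):
    """Estimated speedup over plain decoding.

    cost_ratio: draft forward-pass time / target forward-pass time.
    A round costs gamma draft passes plus one target pass.
    """
    return expected_tokens_per_round(alpha, gamma) / (gamma * cost_ratio + 1.0)


tokens = expected_tokens_per_round(alpha=0.8, gamma=5)
assert abs(tokens - 3.68928) < 1e-6
print(f"{estimated_speedup(alpha=0.8, gamma=5, cost_ratio=0.05):.2f}x")
```

The model ignores sampling overhead, KV-cache handling, and correlation between consecutive acceptances, so treat its output as an upper-bound estimate rather than a prediction.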

Production-Ready Alternatives

If you don’t want to implement this yourself, several frameworks support speculative decoding out of the box:

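# Hugging Face transformers (v4.38+)
output = target_model.generate(
    input_ids,
    assistant_model=draft_model,  # triggers speculative (assisted) decoding
    do_sample=True,
    temperature=0.7,
    max_new_tokens=100,
)
# With a shared tokenizer, no tokenizer arguments are needed. If the two
# models use different tokenizers, newer releases accept both tokenizer=
# and assistant_tokenizer= to translate between them.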

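# vLLM (high-throughput serving)
# Older releases take CLI flags when launching the server:
# --speculative-model [draft_path] --num-speculative-tokens 5
# Newer releases group these under a single --speculative-config option;
# check the docs for your version.

# TensorRT-LLM
# Supports draft-model speculative decoding; the build and runtime options
# differ between releases, so follow the speculative decoding guide for yours.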

When NOT to Use Speculative Decoding

Speculative decoding isn’t always the right choice:

  • Small models (<3B): The overhead outweighs the benefit. Speculative decoding shines when the target model is large enough that a forward pass is the bottleneck.
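  • Batch size > 1: Acceptance rates vary across sequences, making batched speculative decoding complex, though vLLM handles this with per-sequence speculation. At large batch sizes the GPU is already compute-bound, so the spare capacity that makes verification nearly free shrinks and the speedup fades.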
  • Memory-constrained: You need to load two models. If you can’t fit both on your GPU, consider Medusa heads or n-gram drafting instead.

Key Takeaways

  • Speculative decoding gives exact target model distribution — zero quality tradeoff.
  • Speedup depends on draft model quality and task predictability. Code = fast, creative = moderate.
  • KV-caching and adaptive gamma are essential for production performance.
  • Frameworks like HF Transformers, vLLM, and TensorRT-LLM support it natively — no custom code needed in many cases.

Speculative decoding is one of those rare techniques that buys lower latency without touching the output distribution. The price is spare compute and the memory for a second model, not quality. If you’re serving a large model at low batch sizes and not using it, you may be leaving up to 2–3x performance on the table.