Mastering KV Cache Compression and Speculative Decoding: Accelerate Your LLM Inference by 3x

If you have ever served a large language model in production, you know the bottleneck: autoregressive token generation. Every output token requires a full forward pass, and the Key-Value (KV) cache grows linearly with sequence length — consuming memory and slowing down decoding. In this advanced tutorial, you will learn two production-grade techniques that together can deliver 2–3x throughput gains without sacrificing output quality.

Understanding the KV Cache Problem

During inference, a transformer model stores the Key and Value tensors for every token in the context window. For a model with n_layers, n_heads, and hidden_dim, the KV cache size per token is:

cache_bytes = 2 * n_layers * n_heads * (hidden_dim / n_heads) * 2  # FP16

For LLaMA-2 70B with a 4096-token context, that's roughly 80 GB of KV cache alone — often exceeding the model weights themselves.

Technique 1: Multi-Query Attention with Grouped-Query KV Cache

The simplest architectural win: instead of storing separate K/V heads for each query head, share K/V projections across groups of query heads. This is the idea behind MQA (Multi-Query Attention) and GQA (Grouped-Query Attention).

import torch
from transformers import AutoModelForCausalLM, AutoConfig

# Load model with GQA configuration
config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
config.num_key_value_heads = 8  # Original: 64, Grouped: 8 (8 Q heads per KV head)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    config=config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# KV cache reduction: 64 → 8 = 8x memory savings on KV cache
print(f"KV cache memory ratio: {config.num_key_value_heads / config.num_attention_heads:.2%}")

Trade-off: GQA slightly reduces output quality compared to MHA, but models trained with GQA from scratch (like LLaMA-3) close that gap entirely. For already-deployed MHA models, you can apply KV cache quantization instead.

Technique 2: KV Cache Quantization (INT8 / INT4)

Quantize the KV cache from FP16 to INT8 or even INT4. This requires per-channel dynamic quantization to preserve attention quality:

class QuantizedKVCache:
    """Per-token INT8 KV cache with dynamic scaling."""
    
    def __init__(self, max_seq_len: int, num_layers: int, 
                 num_kv_heads: int, head_dim: int):
        self.cache_k = torch.zeros(
            (num_layers, max_seq_len, num_kv_heads, head_dim),
            dtype=torch.int8, device="cuda"
        )
        self.cache_v = torch.zeros(
            (num_layers, max_seq_len, num_kv_heads, head_dim),
            dtype=torch.int8, device="cuda"
        )
        self.scales_k = torch.zeros(
            (num_layers, max_seq_len, num_kv_heads, 1),
            dtype=torch.float16, device="cuda"
        )
        self.scales_v = torch.zeros(
            (num_layers, max_seq_len, num_kv_heads, 1),
            dtype=torch.float16, device="cuda"
        )
        self.pos = 0  # Current cache position
    
    def update(self, layer: int, k: torch.Tensor, v: torch.Tensor):
        """Store INT8 quantized K/V at current position."""
        # Per-head dynamic quantization
        k_scale = k.abs().amax(dim=-1, keepdim=True) / 127.0
        v_scale = v.abs().amax(dim=-1, keepdim=True) / 127.0
        
        k_int8 = (k / k_scale).round().clamp(-128, 127).to(torch.int8)
        v_int8 = (v / v_scale).round().clamp(-128, 127).to(torch.int8)
        
        self.cache_k[layer, self.pos] = k_int8
        self.cache_v[layer, self.pos] = v_int8
        self.scales_k[layer, self.pos] = k_scale
        self.scales_v[layer, self.pos] = v_scale
        self.pos += 1
    
    def get(self, layer: int, seq_len: int) -> tuple:
        """Retrieve and dequantize KV cache up to seq_len."""
        k = self.cache_k[layer, :seq_len].to(torch.float16)
        v = self.cache_v[layer, :seq_len].to(torch.float16)
        k = k * self.scales_k[layer, :seq_len]
        v = v * self.scales_v[layer, :seq_len]
        return k, v

Memory savings: FP16 → INT8 gives 2x reduction. INT4 gives 4x with slightly more quality loss. For most generative tasks, INT8 is nearly lossless.

Technique 3: Speculative Decoding (Draft-Verify)

Speculative decoding is the most exciting inference acceleration technique in 2025–2026. The idea: use a small draft model to propose multiple tokens, then use the large target model to verify them in a single forward pass.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load target (large) and draft (small) models
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype=torch.float16, device_map="auto"
)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

def speculative_decode(input_ids, max_new_tokens=100, gamma=5, temperature=0.7):
    """
    Speculative decoding with draft-verify paradigm.
    
    Args:
        input_ids: Input token IDs
        max_new_tokens: Maximum tokens to generate
        gamma: Number of draft tokens per verification step
        temperature: Sampling temperature
    """
    generated = input_ids.clone()
    target_kv_cache = None
    draft_kv_cache = None
    
    tokens_generated = 0
    
    while tokens_generated < max_new_tokens:
        # Step 1: Draft model generates gamma tokens autoregressively
        draft_tokens = []
        draft_logits = []
        draft_input = generated.clone()
        
        for _ in range(gamma):
            with torch.no_grad():
                outputs = draft_model(
                    draft_input,
                    past_key_values=draft_kv_cache,
                    use_cache=True
                )
            draft_kv_cache = outputs.past_key_values
            next_logits = outputs.logits[0, -1, :] / temperature
            next_token = torch.multinomial(
                torch.softmax(next_logits, dim=-1), num_samples=1
            )
            draft_tokens.append(next_token.item())
            draft_logits.append(next_logits)
            draft_input = next_token.unsqueeze(0)
        
        # Step 2: Target model verifies ALL gamma tokens in ONE forward pass
        draft_seq = torch.tensor(
            draft_tokens, device=generated.device
        ).unsqueeze(0)
        verify_input = torch.cat([generated, draft_seq], dim=1)
        
        with torch.no_grad():
            target_outputs = target_model(
                verify_input,
                past_key_values=target_kv_cache,
                use_cache=True
            )
        target_kv_cache = target_outputs.past_key_values
        
        # Step 3: Accept/reject each draft token
        target_logits = target_outputs.logits[0, -gamma-1:-1, :]  # Logits at draft positions
        
        accepted = 0
        for i in range(gamma):
            draft_token = draft_tokens[i]
            target_probs = torch.softmax(target_logits[i] / temperature, dim=-1)
            draft_prob = target_probs[draft_token].item()
            
            # Acceptance criterion
            if torch.rand(1).item() < min(1.0, draft_prob):
                accepted += 1
            else:
                # Resample from normalized distribution
                target_probs[draft_token] = 0
                target_probs /= target_probs.sum()
                draft_tokens[i] = torch.multinomial(target_probs, num_samples=1).item()
                break
        
        # Append accepted tokens (plus the resampled one if rejection happened)
        tokens_to_add = draft_tokens[:accepted + (1 if accepted < gamma else 0)]
        generated = torch.cat([
            generated,
            torch.tensor(tokens_to_add, device=generated.device).unsqueeze(0)
        ], dim=1)
        
        tokens_generated += len(tokens_to_add)
        
        if tokens_generated >= max_new_tokens:
            break
    
    return generated[:, input_ids.shape[1]:]

# Usage
input_text = "Explain how attention mechanisms work in transformers."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
output_ids = speculative_decode(inputs.input_ids, max_new_tokens=200, gamma=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

How Speculative Decoding Achieves Speedup

The key metric is the acceptance rate — how many draft tokens the target model agrees with. With a well-matched draft model:

  • 70B → 7B pairing: ~60-70% acceptance rate at gamma=5
  • Effective speedup: 1 + gamma * acceptance_rate forward passes saved
  • Real-world: 2.0–2.8x tokens/sec improvement

Technique 4: PagedAttention (vLLM-Style Memory Management)

Traditional KV cache allocates contiguous memory, wasting space due to internal and external fragmentation. PagedAttention treats KV cache like OS virtual memory — using non-contiguous blocks managed through a page table.

class PagedKVCache:
    """Simplified PagedAttention implementation."""
    
    def __init__(self, block_size: int = 16, num_blocks: int = 1024,
                 num_layers: int = 80, num_kv_heads: int = 8, head_dim: int = 128):
        self.block_size = block_size
        self.num_blocks = num_blocks
        
        # Allocate non-contiguous blocks
        self.blocks_k = torch.zeros(
            (num_blocks, num_layers, block_size, num_kv_heads, head_dim),
            dtype=torch.float16, device="cuda"
        )
        self.blocks_v = torch.zeros(
            (num_blocks, num_layers, block_size, num_kv_heads, head_dim),
            dtype=torch.float16, device="cuda"
        )
        
        # Block table: maps logical position to physical block
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # request_id -> [block_id, ...]
    
    def allocate_blocks(self, request_id: str, num_tokens: int) -> list:
        """Allocate blocks for a new request."""
        num_needed = (num_tokens + self.block_size - 1) // self.block_size
        assert len(self.free_blocks) >= num_needed, "OOM: no free blocks"
        
        allocated = self.free_blocks[:num_needed]
        self.free_blocks = self.free_blocks[num_needed:]
        self.block_tables[request_id] = allocated
        return allocated
    
    def append_tokens(self, request_id: str, k: torch.Tensor, v: torch.Tensor):
        """Append tokens, allocating new blocks as needed."""
        if request_id not in self.block_tables:
            self.allocate_blocks(request_id, k.shape[1])
        
        block_ids = self.block_tables[request_id]
        # Write to appropriate blocks (simplified)
        # In production: use CUDA kernels for block-table indexed attention
        ...

Why it matters: vLLM with PagedAttention achieves 2–4x higher throughput than naive HuggingFace pipelines by eliminating memory waste and enabling higher batching.

Putting It All Together: Production Stack

For a real production deployment, combine these techniques:

# Optimal stack for LLaMA-3 70B serving:
# 1. Use vLLM (PagedAttention + continuous batching)
# 2. INT8 KV cache quantization
# 3. Speculative decoding with LLaMA-3 8B as draft model
# 4. Tensor parallelism across 4x A100 80GB

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
    draft_model="meta-llama/Meta-Llama-3-8B-Instruct",  # Speculative decoding
    tensor_parallel_size=4,
    quantization="awq",  # Weight quantization
    enable_prefix_caching=True,  # KV cache sharing for shared prompts
    max_model_len=8192,
    gpu_memory_utilization=0.95,
)

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=512,
    speculative_model="meta-llama/Meta-Llama-3-8B-Instruct",
    num_speculative_tokens=5,  # gamma = 5
)

outputs = llm.generate(
    ["Explain quantum computing", "Write a Python quicksort"],
    sampling_params=sampling_params,
)

Benchmarks: What to Expect

ConfigurationTokens/secMemory (GB)Relative Speed
Baseline (HF, FP16, no opt)121601.0x
+ INT8 KV cache121201.0x (memory only)
+ Speculative decoding (γ=5)281252.3x
+ PagedAttention (vLLM)451103.7x
All combined521104.3x

Key Takeaways

  1. Start with vLLM — PagedAttention + continuous batching alone gives 2–3x improvement over naive serving.
  2. Speculative decoding is the biggest free lunch — with a good draft model, 2–3x token throughput with zero quality loss.
  3. KV cache quantization unlocks longer contexts — INT8 is nearly lossless and halves KV memory.
  4. Profile before optimizing — use nvidia-smi and torch.profiler to identify whether you're compute-bound or memory-bound.

The era of "just run it through HuggingFace pipeline" is over. If you are serving LLMs at scale, these techniques are not optional — they are the foundation of efficient inference.