Mastering KV Cache Compression and Speculative Decoding: Accelerate Your LLM Inference by 3x
If you have ever served a large language model in production, you know the bottleneck: autoregressive token generation. Every output token requires a full forward pass, and the Key-Value (KV) cache grows linearly with sequence length, consuming memory and slowing down decoding. In this advanced tutorial, you will learn four production-grade techniques that together can deliver 2–3x throughput gains with little or no loss in output quality.
Understanding the KV Cache Problem
During inference, a transformer model stores the Key and Value tensors for every token in the context window. For a model with n_layers, n_heads, and hidden_dim, the KV cache size per token is:
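cache_bytes = 2 * n_layers * n_heads * (hidden_dim / n_heads) * 2  # K and V, 2 bytes per value in FP16
For a 70B-class model with 80 layers, 64 heads, and hidden_dim 8192, that is 2.5 MiB per token, or 10 GiB for a single 4096-token sequence under full multi-head attention. LLaMA-2 70B itself ships with grouped-query attention (8 KV heads), so n_heads becomes the KV head count: about 320 KiB per token, or 1.25 GiB per 4096-token sequence. The cache also scales linearly with batch size, so at high concurrency it can rival the model weights themselves.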
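As a quick check on the formula above, the following sketch evaluates it for the published LLaMA-2 70B shape (80 layers, 64 query heads, 8 KV heads, head dimension 128). The helper name `kv_cache_bytes` and its parameters are illustrative and not part of any library.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Bytes of KV cache: K and V (factor 2) for every layer, KV head, and token."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem
    return per_token * seq_len * batch_size


GIB = 1024 ** 3

# Same model shape, with and without sharing K/V heads across query heads
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=4096)

print(f"MHA: {mha / GIB:.2f} GiB, GQA: {gqa / GIB:.2f} GiB")
# MHA: 10.00 GiB, GQA: 1.25 GiB
```

Multiplying by `batch_size` shows why the cache dominates memory at high concurrency even when a single sequence looks cheap.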
Technique 1: Multi-Query Attention with Grouped-Query KV Cache
The simplest architectural win: instead of storing separate K/V heads for each query head, share K/V projections across groups of query heads. This is the idea behind MQA (Multi-Query Attention) and GQA (Grouped-Query Attention).
import torch
from transformers import AutoModelForCausalLM, AutoConfig

# Llama-2-70B was trained with GQA, so its config already specifies 8 KV heads.
# num_key_value_heads must match the checkpoint: overriding it on an MHA
# checkpoint changes the K/V projection shapes and the weights no longer load.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
print(config.num_attention_heads, config.num_key_value_heads)  # 64 query heads, 8 KV heads

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    config=config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# KV cache relative to full MHA: 8 / 64 = 12.5%, an 8x saving
print(f"KV cache memory ratio: {config.num_key_value_heads / config.num_attention_heads:.2%}")
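Trade-off: reducing the number of KV heads can cost some output quality relative to MHA, but models trained with GQA from the start (LLaMA-2 70B and the LLaMA-3 family) largely close that gap. An existing MHA checkpoint cannot simply be loaded with fewer KV heads; it has to be converted and trained further. For already-deployed MHA models, you can apply KV cache quantization instead.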
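To make the head sharing concrete, here is a minimal sketch of how query heads map onto shared KV heads, assuming consecutive query heads form a group. The function name `kv_head_for` is hypothetical; MQA is the special case with a single KV head.

```python
def kv_head_for(query_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Index of the shared K/V head that a given query head attends with."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per KV head
    return query_head // group_size


# GQA with 64 query heads and 8 KV heads: groups of 8 consecutive query heads
gqa_mapping = [kv_head_for(h, 64, 8) for h in range(64)]
print(gqa_mapping[:10])  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# MQA: every query head shares KV head 0
mqa_mapping = [kv_head_for(h, 64, 1) for h in range(64)]
```

Only the K/V tensors for the 8 shared heads are cached; the query heads in a group read the same cached entries.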
Technique 2: KV Cache Quantization (INT8 / INT4)
Quantize the KV cache from FP16 to INT8 or even INT4. To preserve attention quality, scales are computed dynamically at fine granularity; the class below keeps one scale per token and per KV head:
import torch

class QuantizedKVCache:
    """INT8 KV cache with one dynamic scale per token and per KV head (batch size 1)."""

    def __init__(self, max_seq_len: int, num_layers: int,
                 num_kv_heads: int, head_dim: int, device: str = "cuda"):
        shape = (num_layers, max_seq_len, num_kv_heads, head_dim)
        scale_shape = (num_layers, max_seq_len, num_kv_heads, 1)
        self.cache_k = torch.zeros(shape, dtype=torch.int8, device=device)
        self.cache_v = torch.zeros(shape, dtype=torch.int8, device=device)
        self.scales_k = torch.zeros(scale_shape, dtype=torch.float16, device=device)
        self.scales_v = torch.zeros(scale_shape, dtype=torch.float16, device=device)
        self.num_layers = num_layers
        self.pos = 0  # Current cache position, shared by all layers

    @staticmethod
    def _quantize(x: torch.Tensor):
        """Symmetric absmax quantization over head_dim."""
        x = x.float()
        # clamp_min avoids a 0/0 division when a head is all zeros
        scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-5)
        return (x / scale).round().clamp(-127, 127).to(torch.int8), scale

    def update(self, layer: int, k: torch.Tensor, v: torch.Tensor):
        """Store one token's K/V, each of shape (num_kv_heads, head_dim), as INT8."""
        k_int8, k_scale = self._quantize(k)
        v_int8, v_scale = self._quantize(v)
        self.cache_k[layer, self.pos] = k_int8
        self.cache_v[layer, self.pos] = v_int8
        self.scales_k[layer, self.pos] = k_scale
        self.scales_v[layer, self.pos] = v_scale
        # Advance once per token, after the last layer has written
        # (assumes layers call update in order 0 .. num_layers - 1)
        if layer == self.num_layers - 1:
            self.pos += 1

    def get(self, layer: int, seq_len: int) -> tuple:
        """Retrieve and dequantize the KV cache up to seq_len."""
        k = self.cache_k[layer, :seq_len].to(torch.float16)
        v = self.cache_v[layer, :seq_len].to(torch.float16)
        k = k * self.scales_k[layer, :seq_len]
        v = v * self.scales_v[layer, :seq_len]
        return k, v
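Memory savings: FP16 → INT8 gives just under 2x reduction once the per-head FP16 scales are counted (130 bytes instead of 256 per head for head_dim = 128). INT4 gives close to 4x with somewhat more quality loss. For most generative tasks, INT8 is nearly lossless.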
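The round trip can be illustrated without a GPU. The sketch below applies symmetric absmax INT8 quantization to one small head vector in plain Python and checks that the reconstruction error stays within half a quantization step. The helper names and the sample values are illustrative only.

```python
def quantize_int8(values):
    """Symmetric absmax quantization of one head vector to INT8 codes."""
    max_abs = max(abs(x) for x in values)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid dividing by zero
    codes = [max(-127, min(127, round(x / scale))) for x in values]
    return codes, scale


def dequantize(codes, scale):
    """Map INT8 codes back to floating point."""
    return [c * scale for c in codes]


def int8_cache_ratio(head_dim, scale_bytes=2):
    """INT8 cache size (codes plus one scale per head) relative to FP16."""
    return (head_dim + scale_bytes) / (2 * head_dim)


head = [0.5, -1.27, 0.003, 0.9]  # made-up values for one head
codes, scale = quantize_int8(head)
restored = dequantize(codes, scale)
max_err = max(abs(a - b) for a, b in zip(head, restored))

print(codes)                                # [50, -127, 0, 90]
print(max_err <= scale / 2 + 1e-12)         # True: error is at most half a step
print(f"{int8_cache_ratio(128):.3f}")       # 0.508
```

The small value 0.003 rounds to zero here, which is why finer-grained scales (per head rather than per tensor) help: one large outlier only degrades the precision of its own head.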
Technique 3: Speculative Decoding (Draft-Verify)
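Speculative decoding is one of the most effective inference acceleration techniques in 2025–2026. The idea: use a small draft model to propose multiple tokens, then use the large target model to verify them in a single forward pass. A rejection-sampling rule accepts or corrects each proposal so that the output distribution matches what the target model alone would produce.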
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load target (large) and draft (small) models; both must share the same vocabulary
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype=torch.float16, device_map="auto"
)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
def speculative_decode(input_ids, max_new_tokens=100, gamma=5, temperature=0.7):
    """
    Speculative decoding with the draft-verify paradigm (batch size 1).

    For clarity this version recomputes the full sequence on every call
    instead of reusing KV caches. A production implementation keeps both
    caches and rolls them back to the last accepted token.

    Args:
        input_ids: Input token IDs, shape (1, prompt_len)
        max_new_tokens: Maximum tokens to generate
        gamma: Number of draft tokens per verification step
        temperature: Sampling temperature (must be > 0)
    """
    generated = input_ids.clone()
    prompt_len = input_ids.shape[1]

    while generated.shape[1] - prompt_len < max_new_tokens:
        # Step 1: Draft model proposes gamma tokens autoregressively
        draft_tokens = []
        draft_probs = []  # full draft distribution q at each step
        draft_input = generated
        for _ in range(gamma):
            with torch.no_grad():
                logits = draft_model(draft_input).logits[0, -1, :].float()
            q = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(q, num_samples=1)
            draft_tokens.append(next_token.item())
            draft_probs.append(q)
            draft_input = torch.cat(
                [draft_input, next_token.view(1, 1).to(draft_input.device)], dim=1
            )

        # Step 2: Target model scores ALL gamma tokens in ONE forward pass.
        # Rows 0..gamma-1 predict the draft tokens; the last row predicts a bonus token.
        with torch.no_grad():
            target_logits = target_model(draft_input).logits[0, -gamma - 1:, :].float()
        target_probs = torch.softmax(target_logits / temperature, dim=-1)

        # Step 3: Accept/reject each draft token
        new_tokens = []
        for i, token in enumerate(draft_tokens):
            q = draft_probs[i]
            p = target_probs[i].to(q.device)
            # Accept with probability min(1, p(x) / q(x))
            if torch.rand(1).item() < min(1.0, (p[token] / q[token]).item()):
                new_tokens.append(token)
            else:
                # Rejected: resample from the residual max(0, p - q), renormalized
                residual = torch.clamp(p - q, min=0)
                residual = residual / residual.sum()
                new_tokens.append(torch.multinomial(residual, num_samples=1).item())
                break
        else:
            # All gamma tokens accepted: sample one bonus token from the target
            new_tokens.append(torch.multinomial(target_probs[-1], num_samples=1).item())

        generated = torch.cat([
            generated,
            torch.tensor(new_tokens, device=generated.device).unsqueeze(0)
        ], dim=1)

    # The last step may overshoot, so trim to max_new_tokens
    return generated[:, prompt_len:prompt_len + max_new_tokens]
# Usage
input_text = "Explain how attention mechanisms work in transformers."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
output_ids = speculative_decode(inputs.input_ids, max_new_tokens=200, gamma=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
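The claim that speculative decoding preserves output quality rests on the standard speculative-sampling rule: accept a draft token x with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution, and otherwise resample from the normalized residual max(0, p - q). The sketch below computes the exact distribution of the emitted token on a toy three-token vocabulary and shows that it equals p. The function name and the probabilities are made up for illustration.

```python
def speculative_output_distribution(p, q):
    """Exact distribution of the token emitted by one draft-then-verify step.

    p: target probabilities, q: draft probabilities over the same vocabulary.
    """
    # Probability of drafting x and accepting it: q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    accept_mass = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_prob = 1.0 - sum(accept_mass)
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0.0:  # p == q, so no draft token is ever rejected
        return accept_mass
    return [a + reject_prob * r / total for a, r in zip(accept_mass, residual)]


p = [0.6, 0.3, 0.1]  # toy target distribution
q = [0.2, 0.5, 0.3]  # toy draft distribution
out = speculative_output_distribution(p, q)
acceptance_rate = sum(min(pi, qi) for pi, qi in zip(p, q))

print([round(x, 6) for x in out])   # [0.6, 0.3, 0.1], identical to p
print(round(acceptance_rate, 6))    # 0.6
```

The acceptance rate is the overlap between the two distributions, which is why a draft model that tracks the target closely is what makes the technique pay off.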
How Speculative Decoding Achieves Speedup
The key metric is the acceptance rate — how many draft tokens the target model agrees with. With a well-matched draft model:
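- 70B → 7B pairing: ~60-70% acceptance rate at gamma=5
- Expected tokens per target forward pass: (1 - a^(gamma+1)) / (1 - a) for a per-token acceptance rate a, assuming independent acceptances; about 2.6 at a = 0.65 and gamma = 5
- Real-world: 2.0–2.8x tokens/sec improvement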
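The relationship between acceptance rate and speedup can be sketched numerically. The code below uses the usual analysis in which each draft token is accepted independently with probability alpha and verification stops at the first rejection. The `draft_cost` parameter (cost of one draft pass relative to one target pass) is a hypothetical input that you would need to measure on your own hardware.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens emitted per target forward pass: sum of alpha**k for k = 0..gamma."""
    if alpha == 1.0:
        return gamma + 1.0
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, gamma: int, draft_cost: float) -> float:
    """Tokens per step divided by the cost of a step, measured in target passes."""
    step_cost = gamma * draft_cost + 1.0  # gamma draft passes plus one target pass
    return expected_tokens_per_step(alpha, gamma) / step_cost


for alpha in (0.5, 0.65, 0.8):
    tokens = expected_tokens_per_step(alpha, gamma=5)
    speedup = estimated_speedup(alpha, gamma=5, draft_cost=0.1)  # 0.1 is an assumed ratio
    print(f"alpha={alpha:.2f}  tokens/step={tokens:.2f}  est. speedup={speedup:.2f}x")
```

Because the draft passes are not free, raising gamma helps only while the extra accepted tokens outweigh the added draft cost, so gamma is worth tuning per model pair.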
Technique 4: PagedAttention (vLLM-Style Memory Management)
Traditional KV cache allocates contiguous memory, wasting space due to internal and external fragmentation. PagedAttention treats KV cache like OS virtual memory — using non-contiguous blocks managed through a page table.
import torch

class PagedKVCache:
    """Simplified PagedAttention-style block allocator."""

    def __init__(self, block_size: int = 16, num_blocks: int = 1024,
                 num_layers: int = 80, num_kv_heads: int = 8, head_dim: int = 128):
        self.block_size = block_size
        self.num_blocks = num_blocks
        # Preallocate one pool of fixed-size blocks; the blocks assigned to a
        # sequence do not need to be adjacent within the pool
        self.blocks_k = torch.zeros(
            (num_blocks, num_layers, block_size, num_kv_heads, head_dim),
            dtype=torch.float16, device="cuda"
        )
        self.blocks_v = torch.zeros(
            (num_blocks, num_layers, block_size, num_kv_heads, head_dim),
            dtype=torch.float16, device="cuda"
        )
        # Free list plus per-request block tables (logical block index -> physical block id)
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # request_id -> [block_id, ...]

    def allocate_blocks(self, request_id: str, num_tokens: int) -> list:
        """Allocate enough blocks to hold num_tokens more tokens for a request."""
        num_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < num_needed:
            raise MemoryError("OOM: no free blocks")
        allocated = self.free_blocks[:num_needed]
        self.free_blocks = self.free_blocks[num_needed:]
        self.block_tables.setdefault(request_id, []).extend(allocated)
        return allocated

    def append_tokens(self, request_id: str, k: torch.Tensor, v: torch.Tensor):
        """Append tokens, allocating new blocks as needed."""
        if request_id not in self.block_tables:
            self.allocate_blocks(request_id, k.shape[1])
        block_ids = self.block_tables[request_id]
        # Write to appropriate blocks (simplified)
        # In production: use CUDA kernels for block-table indexed attention
        ...
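Why it matters: the vLLM authors report 2–4x higher throughput than earlier serving systems such as FasterTransformer and Orca at comparable latency, and larger gains over naive HuggingFace pipelines, because PagedAttention nearly eliminates KV cache memory waste and so allows larger batches.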
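The page-table analogy comes down to a small address translation. The sketch below maps a token's logical position to a physical block and an offset within it, and counts the slots wasted in the last, partially filled block. The function names and the example block table are hypothetical; a real engine performs this lookup inside its attention kernels.

```python
def physical_slot(block_table, position, block_size=16):
    """Map a token's logical position to (physical_block_id, offset_within_block)."""
    logical_block, offset = divmod(position, block_size)
    return block_table[logical_block], offset


def blocks_needed(num_tokens, block_size=16):
    """Ceiling division: number of blocks required to hold num_tokens."""
    return (num_tokens + block_size - 1) // block_size


# Hypothetical block table: physical blocks in the order the allocator handed them out
block_table = [7, 2, 9]

print(physical_slot(block_table, 0))    # (7, 0)
print(physical_slot(block_table, 20))   # (2, 4)
print(physical_slot(block_table, 47))   # (9, 15)

# A 37-token sequence occupies 3 blocks and wastes at most one partial block
wasted_slots = blocks_needed(37) * 16 - 37
print(wasted_slots)                     # 11
```

Waste is bounded by one block per sequence regardless of how long the sequence grows, in contrast to reserving a contiguous region sized for the maximum context length.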
Putting It All Together: Production Stack
For a real production deployment, combine these techniques:
# Optimal stack for LLaMA-3 70B serving:
# 1. Use vLLM (PagedAttention + continuous batching)
# 2. 8-bit KV cache quantization (vLLM exposes this through kv_cache_dtype;
#    supported formats depend on the GPU and the vLLM version)
# 3. Speculative decoding with LLaMA-3 8B as draft model
# 4. Tensor parallelism across 4x A100 80GB
# Argument names change between vLLM releases; check the docs for your version.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
    tensor_parallel_size=4,
    # quantization="awq" applies only when `model` points at an AWQ-quantized checkpoint
    enable_prefix_caching=True,  # KV cache sharing for shared prompts
    max_model_len=8192,
    gpu_memory_utilization=0.95,
    # Speculative decoding is an engine option, not a sampling parameter.
    # Older releases took speculative_model= and num_speculative_tokens= here instead.
    speculative_config={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "num_speculative_tokens": 5,  # gamma = 5
    },
)
sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=512,
)
outputs = llm.generate(
    ["Explain quantum computing", "Write a Python quicksort"],
    sampling_params=sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)
Benchmarks: What to Expect
| Configuration | Tokens/sec | Memory (GB) | Relative Speed |
|---|---|---|---|
| Baseline (HF, FP16, no opt) | 12 | 160 | 1.0x |
| + INT8 KV cache | 12 | 120 | 1.0x (memory only) |
| + Speculative decoding (γ=5) | 28 | 125 | 2.3x |
| + PagedAttention (vLLM) | 45 | 110 | 3.7x |
| All combined | 52 | 110 | 4.3x |
Key Takeaways
- Start with vLLM — PagedAttention + continuous batching alone gives 2–3x improvement over naive serving.
- Speculative decoding is the biggest free lunch — with a good draft model, 2–3x token throughput with zero quality loss.
- KV cache quantization unlocks longer contexts — INT8 is nearly lossless and halves KV memory.
- Profile before optimizing: use nvidia-smi and torch.profiler to identify whether you're compute-bound or memory-bound.
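Before reaching for a profiler, a back-of-the-envelope bound can tell you what a memory-bound decode loop could achieve at best. The sketch below assumes every decode step streams all weights once plus each sequence's KV cache, and ignores compute and communication entirely. The function name and the bandwidth figure (roughly 2 TB/s per A100 80GB, summed over 4 GPUs) are assumptions for illustration, not measurements.

```python
def decode_tokens_per_sec_bound(weight_bytes, kv_bytes_per_seq,
                                mem_bandwidth_bytes_per_sec, batch_size=1):
    """Upper bound on decode throughput if memory traffic is the only cost."""
    bytes_per_step = weight_bytes + batch_size * kv_bytes_per_seq
    steps_per_sec = mem_bandwidth_bytes_per_sec / bytes_per_step
    return steps_per_sec * batch_size  # each step emits one token per sequence


WEIGHTS = 140e9        # 70B parameters in FP16
KV_PER_SEQ = 1.34e9    # about 1.25 GiB: 4096 tokens with 8 KV heads
BANDWIDTH = 4 * 2.0e12  # assumed aggregate HBM bandwidth of 4 GPUs

for batch in (1, 8, 32):
    bound = decode_tokens_per_sec_bound(WEIGHTS, KV_PER_SEQ, BANDWIDTH, batch)
    print(f"batch={batch:>2}  upper bound ~ {bound:,.0f} tokens/sec")
```

If measured throughput sits far below this bound at small batch sizes, the gap is more likely overhead (kernel launches, Python, communication) than raw memory bandwidth, and a profiler trace will show where it goes.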
The era of "just run it through HuggingFace pipeline" is over. If you are serving LLMs at scale, these techniques are not optional — they are the foundation of efficient inference.