Building a Production-Grade KV Cache for LLM Inference: From Naive to PagedAttention

If you have ever deployed a large language model behind an API, you already know that the bottleneck is rarely the model weights. It is the key-value (KV) cache — the hidden memory beast that grows linearly with context length and batch size. In this tutorial, we will build a KV cache from scratch, identify why the naive implementation fails in production, and then implement a simplified version of PagedAttention, the memory management breakthrough that powers vLLM and modern inference engines.

1. What Is the KV Cache?

During autoregressive decoding, each new token attends to every previous token. Without caching, the model would recompute the key and value projections for the whole prefix at every step, so generating n tokens costs O(n²) projection work on top of the attention itself. The KV cache stores the key and value projections of all prior tokens, so step t only computes projections for the single new token and reuses the rest.
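To make the saving concrete, here is a minimal pure-Python sketch of single-head attention decoded with and without a cache. The `project` helper is a hypothetical stand-in for the learned query, key, and value projections (it only scales its input), so the bookkeeping is realistic but the numbers are not.

```python
import math

def project(x, scale):
    """Hypothetical stand-in for a learned linear projection."""
    return [scale * xi for xi in x]

def attend(q, keys, values):
    """Scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]

def decode_without_cache(tokens):
    """Recompute K and V for the whole prefix at every step."""
    outputs, kv_projections = [], 0
    for t in range(1, len(tokens) + 1):
        keys = [project(x, 0.5) for x in tokens[:t]]
        values = [project(x, 2.0) for x in tokens[:t]]
        kv_projections += 2 * t
        outputs.append(attend(project(tokens[t - 1], 1.0), keys, values))
    return outputs, kv_projections

def decode_with_cache(tokens):
    """Project each token once and keep the results."""
    keys, values, outputs, kv_projections = [], [], [], 0
    for x in tokens:
        keys.append(project(x, 0.5))
        values.append(project(x, 2.0))
        kv_projections += 2
        outputs.append(attend(project(x, 1.0), keys, values))
    return outputs, kv_projections

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
uncached, n_uncached = decode_without_cache(tokens)
cached, n_cached = decode_with_cache(tokens)
print(uncached == cached, n_uncached, n_cached)  # True 20 8
```

Both paths produce identical outputs; the cached path simply performs 2 key/value projections per token instead of 2t at step t.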

For a model with L layers, H key-value heads, head dimension D, and a sequence of T tokens (the factor of 2 counts keys and values):

KV cache size per layer = 2 × H × D × T × sizeof(float16)
Total KV cache = L × (above) × batch_size

For Llama-3-70B (L=80, H=8 KV heads under GQA, D=128), a single sequence of 32K tokens needs about 10 GiB (roughly 10.7 GB) of KV cache in float16. Multiply that by the number of concurrent requests and you see the problem.
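The sizing formula can be checked with a few lines of Python. This sketch assumes float16 storage (2 bytes per element) and takes the number of key-value heads, which under grouped-query attention is smaller than the number of query heads. The function name is illustrative.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Total KV cache size in bytes; the leading 2 counts keys and values."""
    per_token_per_layer = 2 * num_kv_heads * head_dim * bytes_per_elem
    return per_token_per_layer * num_layers * seq_len * batch_size

# Llama-3-70B-style configuration, one 32K-token sequence
size = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128,
                      seq_len=32 * 1024)
print(f"{size / 2**30:.1f} GiB")  # 10.0 GiB
```

Each token costs 320 KiB across all layers in this configuration, which is why long contexts and large batches exhaust GPU memory so quickly.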

2. The Naive Implementation

Here is the straightforward approach — pre-allocate a contiguous tensor for each sequence:

import torch

class NaiveKVCache:
    """Simple per-sequence KV cache with contiguous allocation."""

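    def __init__(self, num_layers: int, num_heads: int,
                 head_dim: int, max_seq_len: int,
                 dtype: torch.dtype = torch.float16,
                 device: str = "cuda"):
        shape = (num_layers, max_seq_len, num_heads, head_dim)
        self.cache_k = torch.zeros(shape, dtype=dtype, device=device)
        self.cache_v = torch.zeros(shape, dtype=dtype, device=device)
        self.max_seq_len = max_seq_len
        # update() runs once per layer per step, so each layer tracks its
        # own length. A single shared counter would advance after the
        # first layer and misplace the writes of every later layer.
        self.layer_lens = [0] * num_layers

    def update(self, k: torch.Tensor, v: torch.Tensor,
               layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Append k, v of shape [new_tokens, num_heads, head_dim] for one layer."""
        start = self.layer_lens[layer_idx]
        end = start + k.shape[0]
        if end > self.max_seq_len:
            raise ValueError(
                f"Sequence length {end} exceeds max_seq_len {self.max_seq_len}"
            )

        self.cache_k[layer_idx, start:end] = k
        self.cache_v[layer_idx, start:end] = v
        self.layer_lens[layer_idx] = end

        # Return full context for attention
        return self.cache_k[layer_idx, :end], self.cache_v[layer_idx, :end]

    def clear(self):
        self.layer_lens = [0] * len(self.layer_lens)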

This works for a single sequence. But it has three fatal flaws in production:

  1. Over-allocation (internal fragmentation): Reserving max_seq_len slots for every sequence wastes GPU memory whenever sequences are shorter than the maximum. With a 32K max context and an average sequence of 2K tokens, about 94% of the allocated memory is never used.
  2. No sharing: When multiple requests share a common prompt prefix (system prompt, few-shot examples), each request duplicates the KV cache for that prefix.
  3. OOM under load (external fragmentation): Every sequence needs one contiguous region, so regions freed by finished requests leave gaps. A new request can fail to allocate even when the total free memory would be enough.
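The third flaw is easy to reproduce without a GPU. The sketch below models memory as a list of token slots with a first-fit contiguous allocator; the arena size and helper names are made up for illustration.

```python
def first_fit(arena, size):
    """Return the start of the first free run of `size` slots, or None."""
    run_len = 0
    for i, used in enumerate(arena):
        run_len = 0 if used else run_len + 1
        if run_len == size:
            return i - size + 1
    return None

def allocate(arena, size):
    """Mark a contiguous region as used and return its start, or None."""
    start = first_fit(arena, size)
    if start is not None:
        arena[start:start + size] = [True] * size
    return start

def release(arena, start, size):
    arena[start:start + size] = [False] * size

arena = [False] * 16                             # 16 free token slots
starts = [allocate(arena, 4) for _ in range(4)]  # four 4-token sequences
release(arena, starts[0], 4)                     # sequences 0 and 2 finish
release(arena, starts[2], 4)

free_slots = arena.count(False)    # 8 slots are free in total
big_request = allocate(arena, 6)   # but no run of 6 contiguous slots exists
print(free_slots, big_request)     # 8 None
```

Half of the arena is free, yet a 6-token request cannot be placed because the free slots sit in two separate 4-slot gaps.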

3. Measuring the Problem

Before optimizing, let us write a profiler to quantify waste:

def measure_kv_waste(num_seqs: int, max_len: int,
                     actual_lengths: list[int],
                     num_layers: int = 80,
                     num_heads: int = 8,
                     head_dim: int = 128) -> dict:
    """Compare contiguous vs actual KV memory usage."""
    contiguous_bytes = (
        num_seqs * max_len * num_layers *
        2 * num_heads * head_dim * 2  # float16
    )
    actual_bytes = sum(
        length * num_layers * 2 * num_heads * head_dim * 2
        for length in actual_lengths
    )
    waste_pct = (1 - actual_bytes / contiguous_bytes) * 100
    return {
        "allocated_mb": contiguous_bytes / 1e6,
        "used_mb": actual_bytes / 1e6,
        "waste_pct": round(waste_pct, 1)
    }

# Example: 32 concurrent requests, 32K max, avg 2K tokens
import numpy as np

lengths = np.random.poisson(2000, 32).clip(256, 32000)
result = measure_kv_waste(32, 32768, lengths.tolist())
print(f"Allocated: {result['allocated_mb']:.0f} MB")
print(f"Used: {result['used_mb']:.0f} MB")
print(f"Waste: {result['waste_pct']}%")
# Typical output: Allocated: 343597 MB, Used: ~20970 MB, Waste: 93.9%

This confirms the fragmentation problem. Let us fix it.

4. Paged KV Cache: The Core Idea

PagedAttention borrows from OS virtual memory. Instead of one contiguous block per sequence, the KV cache is split into blocks (like pages). Each block holds KV states for a fixed number of tokens. A block table maps logical token positions to physical blocks.
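The block table lookup is the same arithmetic as a page table: divide the token position by the block size to get the logical block, and keep the remainder as the offset inside it. A minimal sketch, with a hypothetical `locate` helper and made-up block IDs:

```python
def locate(block_table, block_size, token_pos):
    """Map a logical token position to (physical_block_id, offset_in_block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_table[logical_block], offset

# Logical blocks 0, 1, 2 of this sequence live in physical blocks 7, 2, 9
block_table = [7, 2, 9]
physical, offset = locate(block_table, block_size=16, token_pos=35)
print(physical, offset)  # 9 3
```

Token 35 falls in logical block 2 (positions 32 to 47), which the table maps to physical block 9. The physical blocks need not be adjacent or in order.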

Key benefits:

  • Near-zero fragmentation: blocks are allocated on demand, so the only waste is the unfilled tail of each sequence's last block
  • Prefix sharing: multiple sequences reference the same blocks for shared prompts
  • Flexible scheduling: blocks can be swapped to CPU when the GPU is full
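The first benefit can be quantified: with paging, each sequence wastes at most block_size - 1 token slots. The sketch below computes allocated and wasted slots for a few example lengths; the lengths and the 16-token block size are illustrative choices.

```python
def paged_allocation(lengths, block_size):
    """Return (allocated_slots, wasted_slots) when sequences get whole blocks."""
    # -(-n // b) is ceiling division: blocks needed for n tokens
    allocated = sum(-(-n // block_size) * block_size for n in lengths)
    return allocated, allocated - sum(lengths)

lengths = [2000, 100, 17]
allocated, wasted = paged_allocation(lengths, block_size=16)
print(allocated, wasted, f"{100 * wasted / allocated:.1f}%")  # 2144 27 1.3%
```

Reserving 32768 slots for each of the same three sequences would allocate 98304 slots and leave about 97.8% of them unused.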

5. Implementing a Simplified Paged KV Cache

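Here is a simplified implementation of the core data structures. The KV tensors passed in are expected to have shape [num_tokens, num_layers, num_heads, head_dim], so one write covers every layer of a token: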

import torch
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class KVBlock:
    """A single block in the paged KV cache."""
    block_id: int
    num_tokens: int = 0

@dataclass
class SequenceState:
    """Tracking for one sequence/request."""
    seq_id: str
    block_table: list[int] = field(default_factory=list)
    current_len: int = 0

class PagedKVCacher:
    """Paged KV cache with block allocation and prefix sharing."""

    def __init__(self, num_layers: int, num_heads: int,
                 head_dim: int, block_size: int,
                 max_num_blocks: int,
                 dtype: torch.dtype = torch.float16,
                 device: str = "cuda"):
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Block pool: [num_blocks, num_layers, block_size, num_heads, head_dim]
        self.block_k = torch.zeros(
            (max_num_blocks, num_layers, block_size,
             num_heads, head_dim),
            dtype=dtype, device=device
        )
        self.block_v = torch.zeros(
            (max_num_blocks, num_layers, block_size,
             num_heads, head_dim),
            dtype=dtype, device=device
        )

        # Allocator: track free blocks
        self.free_blocks: list[int] = list(range(max_num_blocks))
        self.sequences: dict[str, SequenceState] = {}
        self._next_id = 0

    def allocate_block(self) -> Optional[int]:
        """Allocate a free block. Returns None if OOM."""
        if not self.free_blocks:
            return None
        # pop() from the end is O(1); which free block we get does not matter
        return self.free_blocks.pop()

    def free_block(self, block_id: int):
        """Return a block to the free pool."""
        self.free_blocks.append(block_id)

    def append_tokens(self, seq_id: str,
                      k: torch.Tensor, v: torch.Tensor):
        """Append new token KV states, allocating blocks as needed.

        k, v: [num_tokens, num_layers, num_heads, head_dim]
        """
        if seq_id not in self.sequences:
            self.sequences[seq_id] = SequenceState(seq_id=seq_id)
        seq = self.sequences[seq_id]

        for i in range(k.shape[0]):
            # A new block is needed whenever every allocated block is full.
            # Comparing against capacity (rather than the loop index) also
            # covers a previous call that ended exactly on a block boundary.
            if seq.current_len == len(seq.block_table) * self.block_size:
                bid = self.allocate_block()
                if bid is None:
                    raise RuntimeError(
                        f"OOM during append to {seq_id}"
                    )
                seq.block_table.append(bid)

            block_idx = seq.block_table[-1]
            token_offset = seq.current_len % self.block_size

            # One write stores this token's states for all layers
            self.block_k[block_idx, :, token_offset] = k[i]
            self.block_v[block_idx, :, token_offset] = v[i]
            seq.current_len += 1

    def get_context(self, seq_id: str,
                    layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather full KV context for a sequence from its blocks."""
        seq = self.sequences[seq_id]
        total = seq.current_len
        k_parts = []
        v_parts = []

        for i, block_id in enumerate(seq.block_table):
            # Last block may be partially filled
            if i == len(seq.block_table) - 1:
                remaining = total % self.block_size
                if remaining == 0:
                    remaining = self.block_size
                k_parts.append(
                    self.block_k[block_id, layer_idx, :remaining]
                )
                v_parts.append(
                    self.block_v[block_id, layer_idx, :remaining]
                )
            else:
                k_parts.append(
                    self.block_k[block_id, layer_idx]
                )
                v_parts.append(
                    self.block_v[block_id, layer_idx]
                )

        return torch.cat(k_parts, dim=0), torch.cat(v_parts, dim=0)

    def release(self, seq_id: str):
        """Free all blocks for a completed sequence."""
        if seq_id in self.sequences:
            for bid in self.sequences[seq_id].block_table:
                self.free_block(bid)
            del self.sequences[seq_id]

6. Adding Prefix Sharing

The real power of paging emerges when sequences share prefixes. A common pattern: every request starts with the same system prompt. Instead of duplicating those KV blocks, we let multiple sequences reference the same physical blocks.

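class SharedPrefixKVCacher(PagedKVCacher):
    """Extends PagedKVCacher with prefix block sharing."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._prefix_blocks: list[int] = []
        self._prefix_len: int = 0

    def set_shared_prefix(self, k: torch.Tensor,
                          v: torch.Tensor):
        """Store KV for a shared prefix (e.g., system prompt).

        k, v: [num_tokens, num_layers, num_heads, head_dim]
        """
        # Free old prefix blocks if re-setting. This is only safe once
        # no live sequence still references them.
        for bid in self._prefix_blocks:
            self.free_block(bid)
        self._prefix_blocks = []

        num_tokens = k.shape[0]
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size

        for i in range(num_blocks):
            bid = self.allocate_block()
            if bid is None:
                raise RuntimeError("OOM while allocating prefix blocks")
            self._prefix_blocks.append(bid)

            start = i * self.block_size
            end = min(start + self.block_size, num_tokens)
            actual_len = end - start

            # Blocks are stored layer-major, so swap the token and layer
            # axes before writing.
            self.block_k[bid, :, :actual_len] = k[start:end].transpose(0, 1)
            self.block_v[bid, :, :actual_len] = v[start:end].transpose(0, 1)

        self._prefix_len = num_tokens

    def fork_sequence(self, seq_id: str) -> SequenceState:
        """Create a new sequence that shares the prefix blocks."""
        if not self._prefix_blocks:
            raise ValueError("No shared prefix set. Call set_shared_prefix first.")

        # Completely filled prefix blocks are shared read-only.
        num_full = self._prefix_len // self.block_size
        block_table = self._prefix_blocks[:num_full]

        # A partially filled tail block would be written to by the first
        # append, so each fork gets a private copy of it.
        if num_full < len(self._prefix_blocks):
            bid = self.allocate_block()
            if bid is None:
                raise RuntimeError(f"OOM while forking {seq_id}")
            tail = self._prefix_blocks[-1]
            self.block_k[bid].copy_(self.block_k[tail])
            self.block_v[bid].copy_(self.block_v[tail])
            block_table.append(bid)

        seq = SequenceState(
            seq_id=seq_id,
            block_table=block_table,
            current_len=self._prefix_len
        )
        self.sequences[seq_id] = seq
        return seq

    def release(self, seq_id: str):
        """Free a sequence's private blocks; shared prefix blocks stay allocated."""
        seq = self.sequences.pop(seq_id, None)
        if seq is None:
            return
        shared = set(self._prefix_blocks)
        for bid in seq.block_table:
            if bid not in shared:
                self.free_block(bid)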

Now 100 concurrent requests with a 500-token system prompt share one copy of the prefix KV states, about ceil(500 / block_size) blocks (32 with 16-token blocks), instead of storing 100 copies. The memory saved grows linearly with the number of concurrent requests.
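Sharing block IDs also raises an ownership question: a shared block must not return to the free pool until every sequence using it has finished. One common way to handle this is reference counting. The sketch below is a minimal pure-Python allocator; the class and method names are hypothetical and are not part of the cache classes in this tutorial.

```python
class RefCountedAllocator:
    """Block allocator that frees a block only when its last user releases it."""

    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))
        self.refs = {}  # block_id -> number of sequences using it

    def allocate(self):
        """Take a free block with one owner, or return None if exhausted."""
        if not self.free:
            return None
        block_id = self.free.pop()
        self.refs[block_id] = 1
        return block_id

    def share(self, block_id):
        """Register one more sequence as a user of an allocated block."""
        self.refs[block_id] += 1
        return block_id

    def release(self, block_id):
        """Drop one reference; return True if the block was actually freed."""
        self.refs[block_id] -= 1
        if self.refs[block_id] == 0:
            del self.refs[block_id]
            self.free.append(block_id)
            return True
        return False

alloc = RefCountedAllocator(num_blocks=4)
prefix_block = alloc.allocate()   # owned by the shared prefix
alloc.share(prefix_block)         # request A forks the prefix
alloc.share(prefix_block)         # request B forks the prefix

freed = [alloc.release(prefix_block) for _ in range(3)]
print(freed, len(alloc.free))     # [False, False, True] 4
```

The block only rejoins the free pool on the third release, when the last holder lets go of it.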

7. Block Swapping (CPU Offloading)

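When GPU memory is exhausted, a production system can move blocks to CPU RAM and bring them back when they are needed again. The class below covers only the copy mechanics; choosing which blocks to evict (for example, the least recently used) is left to the scheduler: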

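class SwappablePagedKVCacher(SharedPrefixKVCacher):
    """Adds CPU offloading for KV blocks."""

    def __init__(self, *args, cpu_reserve_blocks: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.cpu_k: Optional[torch.Tensor] = None
        self.cpu_v: Optional[torch.Tensor] = None
        # Maps a GPU block id to the CPU slot holding its contents. The
        # CPU pool can be smaller than the GPU pool, so block ids cannot
        # index it directly.
        self._swapped: dict[int, int] = {}
        self._free_cpu_slots: list[int] = list(range(cpu_reserve_blocks))

        if cpu_reserve_blocks > 0:
            shape = (cpu_reserve_blocks, self.num_layers,
                     self.block_size, self.num_heads, self.head_dim)
            # Allocate both buffers as pinned memory explicitly rather
            # than relying on zeros_like to keep the pinning.
            self.cpu_k = torch.zeros(shape, dtype=self.block_k.dtype,
                                     device="cpu", pin_memory=True)
            self.cpu_v = torch.zeros(shape, dtype=self.block_k.dtype,
                                     device="cpu", pin_memory=True)

    def swap_to_cpu(self, block_id: int):
        """Copy a GPU block into a free CPU slot."""
        if block_id in self._swapped:
            return
        if not self._free_cpu_slots:
            raise RuntimeError("No free CPU slots; increase cpu_reserve_blocks")
        slot = self._free_cpu_slots.pop()
        self.cpu_k[slot].copy_(self.block_k[block_id])
        self.cpu_v[slot].copy_(self.block_v[block_id])
        self._swapped[block_id] = slot
        # Simplification: the GPU block stays reserved. To reclaim memory,
        # a scheduler would free it here and remap the block table when
        # swapping back in.

    def swap_from_cpu(self, block_id: int):
        """Restore a swapped block's contents to its GPU block."""
        slot = self._swapped.pop(block_id, None)
        if slot is None:
            return
        self.block_k[block_id].copy_(self.cpu_k[slot], non_blocking=True)
        self.block_v[block_id].copy_(self.cpu_v[slot], non_blocking=True)
        self._free_cpu_slots.append(slot)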
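The swapping code needs something to tell it which block to evict. A least-recently-used policy can be tracked separately from the tensors. This is a minimal sketch built on the standard library's OrderedDict; the `LRUBlockTracker` name and its methods are hypothetical.

```python
from collections import OrderedDict

class LRUBlockTracker:
    """Tracks block usage order so the coldest block can be evicted first."""

    def __init__(self):
        self._order = OrderedDict()  # oldest entry first

    def touch(self, block_id):
        """Record that a block was just read or written."""
        self._order[block_id] = None
        self._order.move_to_end(block_id)

    def pick_victim(self):
        """Remove and return the least recently used block, or None if empty."""
        if not self._order:
            return None
        block_id, _ = self._order.popitem(last=False)
        return block_id

tracker = LRUBlockTracker()
for block_id in (1, 2, 3):
    tracker.touch(block_id)
tracker.touch(1)                     # block 1 is used again

victims = [tracker.pick_victim() for _ in range(4)]
print(victims)                       # [2, 3, 1, None]
```

A scheduler would call `touch` whenever attention reads a block, then pass the result of `pick_victim` to the swap-out routine when the free pool runs dry.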

8. Putting It Together: A Mini Inference Loop

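Here is how you would integrate the paged cache into a decoding loop. The sketch assumes a Hugging Face-style causal LM whose past_key_values can be indexed as past_key_values[layer][0] and past_key_values[layer][1], with tensors shaped [batch, num_kv_heads, seq_len, head_dim], and which accepts the same tuple format as input. Recent transformers releases use Cache objects instead, so check your version. The cacher must be created with the model's number of KV heads, dtype, and device:

def build_kv_from_cache(cacher: PagedKVCacher, seq_id: str,
                        num_layers: int):
    """Rebuild a tuple of per-layer (k, v), each [1, heads, seq_len, head_dim]."""
    past = []
    for layer_idx in range(num_layers):
        # get_context returns [seq_len, num_heads, head_dim]
        k, v = cacher.get_context(seq_id, layer_idx)
        past.append((k.transpose(0, 1).unsqueeze(0),
                     v.transpose(0, 1).unsqueeze(0)))
    return tuple(past)

def store_new_kv(cacher: PagedKVCacher, seq_id: str, past_key_values,
                 num_layers: int, start: int):
    """Append the KV states from position `start` onward to the paged cache."""
    # Stack layers: [num_layers, num_heads, new_tokens, head_dim]
    k = torch.stack([past_key_values[i][0][0, :, start:]
                     for i in range(num_layers)])
    v = torch.stack([past_key_values[i][1][0, :, start:]
                     for i in range(num_layers)])
    # append_tokens expects [new_tokens, num_layers, num_heads, head_dim]
    cacher.append_tokens(seq_id, k.permute(2, 0, 1, 3),
                         v.permute(2, 0, 1, 3))

def generate_with_paged_cache(model, tokenizer, prompt: str,
                               max_new_tokens: int, cacher: PagedKVCacher,
                               seq_id: str = "req_0"):
    """Simplified autoregressive generation with paged KV cache."""
    num_layers = model.config.num_hidden_layers
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Prefill: process prompt, build initial KV cache
    with torch.no_grad():
        outputs = model(inputs.input_ids, use_cache=True)

    # Store prefill KV states for all layers in one call
    store_new_kv(cacher, seq_id, outputs.past_key_values, num_layers, 0)

    generated = inputs.input_ids.clone()
    # The prefill logits already predict the first new token
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)

    for step in range(max_new_tokens):
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break

        cached_len = cacher.sequences[seq_id].current_len
        with torch.no_grad():
            # In a real implementation, you pass the block table
            # to a custom attention kernel that reads from paged cache
            outputs = model(next_token, use_cache=True,
                            past_key_values=build_kv_from_cache(
                                cacher, seq_id, num_layers))

        # Store only the KV states of the token just processed
        store_new_kv(cacher, seq_id, outputs.past_key_values,
                     num_layers, cached_len)
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)

    cacher.release(seq_id)
    return tokenizer.decode(generated[0], skip_special_tokens=True)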

9. Key Takeaways

Technique          Memory Efficiency        Complexity
Naive contiguous   Poor (90%+ waste)        Low
Paged blocks       Good (near-zero waste)   Medium
Prefix sharing     Excellent                Medium
Block swapping     Maximum (GPU + CPU)      High

Production inference engines like vLLM, SGLang, and TensorRT-LLM all use variations of paged KV caching. Understanding these internals helps you:

  • Choose the right batch sizes and context limits for your hardware
  • Debug OOM errors during serving
  • Optimize prompt design to maximize prefix sharing
  • Contribute to or customize open-source inference engines
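As an example of the first point, the block arithmetic from this tutorial turns a memory budget into a concurrency estimate. The sketch below assumes float16 storage and a hypothetical budget of 40 GiB left for the KV cache after weights and activations; the function names are illustrative.

```python
def blocks_that_fit(kv_budget_bytes, num_layers, num_kv_heads, head_dim,
                    block_size, bytes_per_elem=2):
    """Number of KV blocks that fit in a given memory budget."""
    bytes_per_block = (2 * num_layers * num_kv_heads * head_dim
                       * block_size * bytes_per_elem)
    return kv_budget_bytes // bytes_per_block

def max_concurrent_sequences(num_blocks, block_size, tokens_per_seq):
    """Sequences of a given length that the block pool can hold at once."""
    blocks_per_seq = -(-tokens_per_seq // block_size)  # ceiling division
    return num_blocks // blocks_per_seq

# Llama-3-70B-style layers and heads, 16-token blocks, assumed 40 GiB budget
num_blocks = blocks_that_fit(40 * 2**30, num_layers=80, num_kv_heads=8,
                             head_dim=128, block_size=16)
concurrent = max_concurrent_sequences(num_blocks, block_size=16,
                                      tokens_per_seq=2048)
print(num_blocks, concurrent)  # 8192 64
```

Under these assumptions each block takes 5 MiB, so the pool holds 8192 blocks, enough for 64 concurrent 2K-token sequences. A contiguous 32K reservation of 10 GiB per sequence would fit only 4.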

The next step? Dive into the vLLM source code and study how they implement the PagedAttention CUDA kernel — that is where the real performance magic happens.