The KV cache is the model's short-term memory, and like all short-term memory, it grows without bound until something crashes. The entire field of memory optimization is basically teaching transformers to forget gracefully.
The hidden memory bottleneck: while the model weights occupy a fixed amount of GPU memory, the KV cache grows dynamically with every token generated. For a Llama 3 70B model serving a batch of 32 sequences at 8K context length, the KV cache alone requires over 80 GB of memory, often exceeding the weight memory itself. Optimizing how this cache is stored, shared, and evicted is the single most important lever for increasing throughput in LLM serving systems. This section covers the data structures and algorithms that make high-throughput serving possible.
This section builds on the self-attention mechanism from Module 04 (query, key, value projections and attention computation). GQA architecture details were introduced in Section 7.2; here we focus on the memory savings. Understanding of quantization from Section 8.1 provides context for KV cache quantization discussed later.
1. The KV Cache Explained
You quantized your model to 4-bit and slashed its memory footprint by 4x. But inference is still slow, and GPU memory is still filling up. Why? Because the KV cache, not the model weights, has become your bottleneck.
In autoregressive generation, each new token attends to all previous tokens. The attention mechanism computes queries, keys, and values for each token. Without caching, generating the t-th token would require recomputing the keys and values for all t−1 previous tokens, making generation O(t²) in computation. The KV cache stores the key and value projections from all previous positions so that each new token only needs to compute its own query, key, and value, then attend against the cached keys and values.
For a model with L layers, n_kv key/value heads, head dimension d_h, and sequence length s, the KV cache memory per sequence is:

KV cache bytes = 2 x L x n_kv x d_h x s x (bytes per element)

The factor of 2 accounts for both keys and values. For a batch of b sequences, multiply by b.
```python
# Example 1: Calculate KV cache size for various models
def kv_cache_size_gb(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2,  # FP16 = 2 bytes
) -> float:
    """Calculate KV cache memory in GB."""
    # 2 for K and V, both stored
    total_bytes = (
        2 * num_layers * num_kv_heads * seq_len * head_dim * dtype_bytes * batch_size
    )
    return total_bytes / (1024 ** 3)

# Llama 3.1 8B: 32 layers, 8 KV heads (GQA), head_dim=128
# Llama 3.1 70B: 80 layers, 8 KV heads (GQA), head_dim=128
models = {
    "Llama 3.1 8B": {"layers": 32, "kv_heads": 8, "head_dim": 128},
    "Llama 3.1 70B": {"layers": 80, "kv_heads": 8, "head_dim": 128},
    "Llama 3.1 405B": {"layers": 126, "kv_heads": 8, "head_dim": 128},
}

print(f"{'Model':<20} {'Context':>8} {'Batch':>6} {'KV Cache (GB)':>14}")
print("-" * 52)
for name, cfg in models.items():
    for seq_len in [4096, 8192, 32768, 131072]:
        for batch in [1, 16, 64]:
            mem = kv_cache_size_gb(
                cfg["layers"], cfg["kv_heads"], cfg["head_dim"],
                seq_len, batch
            )
            # Print all single-sequence rows plus batched rows at 8K context
            if batch == 1 or seq_len == 8192:
                print(f"{name:<20} {seq_len:>8} {batch:>6} {mem:>13.2f}")
    print()
```
For Llama 3.1 70B (which requires about 140 GB for weights in FP16), serving a batch of 64 sequences at 8K context requires 160 GB just for the KV cache. This means the KV cache uses more memory than the model weights. This is why KV cache optimization is critical for high-throughput serving.
2. Why Inference Is Memory-Bandwidth-Bound
During the decode phase (generating one token at a time), each token requires reading all model weights and the full KV cache from GPU memory. The actual floating-point operations are minimal: for a single-token decode step, the arithmetic intensity (FLOPs per byte of memory accessed) is approximately 1. Modern GPUs have a ratio of compute to memory bandwidth of 100:1 or more (the H100 delivers 1979 TFLOPS of dense FP8 compute but only 3.35 TB/s of HBM bandwidth). This means the GPU spends most of its time waiting for data from memory, not computing. Reducing the amount of memory to read (through quantization and cache optimization) directly translates to higher throughput.
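As a rough sanity check on the bandwidth argument, the memory-bound latency floor for a decode step can be sketched in a few lines. The byte counts and the H100-like bandwidth figure below are illustrative assumptions, and `decode_time_lower_bound_ms` is a hypothetical helper, not a profiler:

```python
# Back-of-envelope roofline bound for single-token decode (not a benchmark).
def decode_time_lower_bound_ms(model_bytes: float, kv_bytes: float,
                               bandwidth_tb_s: float = 3.35) -> float:
    """Lower bound on per-token latency: every decode step must stream the
    weights plus the KV cache from HBM at least once."""
    total_bytes = model_bytes + kv_bytes
    return total_bytes / (bandwidth_tb_s * 1e12) * 1e3

# Llama 3.1 70B in FP16: ~140 GB of weights; ~2.7 GB of KV per sequence at 8K.
weights = 140e9
kv = 2.7e9
print(f"Min decode latency, batch 1: {decode_time_lower_bound_ms(weights, kv):.1f} ms/token")
# Halving the bytes read (e.g., FP8 weights + FP8 KV cache) halves this bound.
print(f"With 2x compression:        {decode_time_lower_bound_ms(weights / 2, kv / 2):.1f} ms/token")
```

The bound ignores compute entirely, which is the point: in the memory-bound regime, bytes moved, not FLOPs, set the latency.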
3. PagedAttention
Traditional serving systems pre-allocate a contiguous block of GPU memory for the KV cache of each request, sized for the maximum possible sequence length. This leads to massive internal fragmentation: if the max length is 8192 tokens but most sequences use only 500 tokens, over 90% of the allocated cache memory is wasted.
PagedAttention (Kwon et al., 2023), introduced in vLLM, borrows the concept of virtual memory and paging from operating systems. Instead of contiguous allocation, the KV cache is divided into fixed-size blocks (typically 16 tokens per block). A block table maps each sequence's logical KV positions to physical GPU memory blocks. Blocks are allocated on demand as the sequence grows.
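A minimal sketch of the bookkeeping this implies follows. It is illustrative only, not vLLM's actual implementation; `BlockTable`, `append_token`, and `release` are hypothetical names:

```python
# PagedAttention-style block-table bookkeeping (toy sketch).
BLOCK_SIZE = 16  # tokens per physical block

class BlockTable:
    def __init__(self, num_physical_blocks: int):
        self.free = list(range(num_physical_blocks))  # free physical block ids
        self.tables = {}  # seq_id -> list of physical block ids

    def append_token(self, seq_id: int, position: int) -> int:
        """Ensure a physical block backs `position`; return its block id."""
        table = self.tables.setdefault(seq_id, [])
        if position // BLOCK_SIZE >= len(table):
            table.append(self.free.pop())  # allocate on demand
        return table[position // BLOCK_SIZE]

    def release(self, seq_id: int) -> None:
        """Sequence finished: return its blocks to the free pool."""
        self.free.extend(self.tables.pop(seq_id, []))

bt = BlockTable(num_physical_blocks=64)
for pos in range(40):  # a 40-token sequence needs ceil(40/16) = 3 blocks
    bt.append_token(seq_id=0, position=pos)
print(len(bt.tables[0]), "blocks for 40 tokens")
bt.release(0)
print(len(bt.free), "blocks free after release")
```

The only wasted memory is the unused tail of each sequence's last block, at most 15 tokens, versus thousands of tokens per sequence under max-length pre-allocation.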
4. MHA, MQA, and GQA
The original transformer uses Multi-Head Attention (MHA), where each attention head has its own separate set of key, value, and query projections. For a model with 32 attention heads, this means 32 sets of K and V tensors must be stored in the cache.
Multi-Query Attention (MQA), introduced by Shazeer (2019), shares a single set of key and value heads across all query heads. This reduces the KV cache by a factor equal to the number of query heads (e.g., 32x for a 32-head model). The tradeoff is a small quality reduction.
Grouped-Query Attention (GQA), introduced by Ainslie et al. (2023) and adopted in Llama 2, is the compromise: query heads are divided into groups, and each group shares one set of KV heads. Llama 3 uses 32 query heads with 8 KV heads, reducing the cache by 4x compared to MHA while preserving nearly all of MHA's quality. (GQA's architectural design and its role in models like Llama 3 were introduced in Section 7.2; here we focus on its memory optimization implications.)
| Attention Type | Q Heads | KV Heads | KV Cache Ratio | Used By |
|---|---|---|---|---|
| MHA | 32 | 32 | 1x (baseline) | GPT-2, GPT-3, OPT |
| GQA (8 groups) | 32 | 8 | 0.25x | Llama 2/3, Mistral, Gemma |
| GQA (2 groups) | 32 | 2 | 0.0625x | DeepSeek V2 (partially) |
| MQA | 32 | 1 | 0.03125x | Falcon, PaLM, StarCoder |
MHA scenario: 32 layers, 32 KV heads, d_head = 128, 4096 tokens, FP16. Per-token KV size = 32 layers x 32 heads x 128 dims x 2 (K+V) x 2 bytes = 524,288 bytes. Total at 4K tokens = 524,288 x 4096 = 2.15 GB. For a batch of 32: 68.7 GB, several times the FP16 weight memory of a 7B-class model.
GQA scenario (actual Llama 3): 32 layers, 8 KV heads, d_head = 128. Per-token KV size = 32 x 8 x 128 x 2 x 2 = 131,072 bytes. Total at 4K = 0.54 GB. Batch of 32: 17.2 GB, a 4x reduction. This is why GQA enables much higher batch sizes on the same hardware.
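The two scenarios above can be verified in a few lines:

```python
def per_token_kv_bytes(layers: int, kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    # The leading 2 accounts for storing both K and V
    return 2 * layers * kv_heads * head_dim * dtype_bytes

mha = per_token_kv_bytes(32, 32, 128)  # MHA: 32 KV heads
gqa = per_token_kv_bytes(32, 8, 128)   # GQA: 8 KV heads
print(f"MHA: {mha:,} bytes/token -> {mha * 4096 / 1024**3:.2f} GiB per 4K sequence")
print(f"GQA: {gqa:,} bytes/token -> {gqa * 4096 / 1024**3:.2f} GiB per 4K sequence")
print(f"Reduction: {mha // gqa}x")
```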
5. Prefix Caching and RadixAttention
In many serving scenarios, multiple requests share a common system prompt or prefix. For example, a chatbot may prepend the same 2000-token system prompt to every user message. Without prefix caching, the KV cache for those 2000 tokens is recomputed for every request.
Prefix caching stores the KV cache for common prefixes and reuses it across requests. SGLang implements this via RadixAttention, which organizes cached prefixes in a radix tree (trie) data structure. Each node in the tree corresponds to a segment of tokens, and the associated KV cache blocks are stored on the GPU. When a new request arrives, the serving system traverses the tree to find the longest matching prefix, reuses its cache, and only computes the KV for the new tokens.
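The longest-prefix lookup can be sketched with a plain token trie. SGLang's radix tree additionally compresses runs of tokens into single edges and tracks KV blocks per node; the class and method names here are hypothetical:

```python
# Toy prefix-matching structure in the spirit of RadixAttention.
class TrieNode:
    def __init__(self):
        self.children = {}  # token id -> child node

class PrefixCache:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, tokens) -> None:
        """Record a request's token ids as a cached prefix."""
        node = self.root
        for t in tokens:
            node = node.children.setdefault(t, TrieNode())

    def longest_prefix(self, tokens) -> int:
        """Number of leading tokens whose KV is already cached."""
        node, n = self.root, 0
        for t in tokens:
            if t not in node.children:
                break
            node, n = node.children[t], n + 1
        return n

cache = PrefixCache()
cache.insert([1, 2, 3, 4, 5])  # e.g., a cached system prompt
hit = cache.longest_prefix([1, 2, 3, 9, 9])
print(f"Cache hit on first {hit} tokens; compute KV only for the rest")
```

In a real system each trie node also owns the physical KV blocks for its token segment, so a hit means those blocks are referenced directly rather than recomputed.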
```python
# Example 2: Demonstrating prefix caching benefit with vLLM
import time

from vllm import LLM, SamplingParams

# Initialize vLLM with prefix caching enabled
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enable_prefix_caching=True,
    gpu_memory_utilization=0.9,
)

# Long system prompt shared across requests
system_prompt = """You are a helpful AI assistant specializing in Python
programming. You provide clear, well-documented code examples with
explanations. Always include error handling and type hints in your
responses. Follow PEP 8 style guidelines."""

# Multiple requests sharing the same prefix
requests = [
    f"{system_prompt}\n\nUser: Write a function to merge two sorted lists.",
    f"{system_prompt}\n\nUser: Write a decorator that caches function results.",
    f"{system_prompt}\n\nUser: Write an async function to fetch multiple URLs.",
]

sampling = SamplingParams(temperature=0.7, max_tokens=200)

# First batch: system prompt cache is populated
start = time.perf_counter()
outputs_1 = llm.generate(requests[:1], sampling)
time_first = time.perf_counter() - start

# Second batch: system prompt cache is HIT
start = time.perf_counter()
outputs_2 = llm.generate(requests[1:], sampling)
time_cached = time.perf_counter() - start

print(f"First request (cache cold): {time_first:.3f}s")
print(f"Next 2 requests (cache hit): {time_cached:.3f}s")
# Rough per-request comparison: one cold request vs. two warm requests
print(f"Speedup from prefix caching: {time_first*2/time_cached:.1f}x")
```
6. Continuous Batching
Traditional static batching groups multiple requests into a batch and processes them together. The entire batch must wait for the longest sequence to finish before any results are returned. This wastes GPU cycles: if one sequence in the batch finishes after 20 tokens while another needs 500, the GPU sits idle for that short sequence's remaining "slots" in the batch.
Continuous batching (also called iteration-level scheduling) allows new requests to join the batch and completed requests to leave at every iteration (every single token step). When sequence A completes, its slot is immediately filled by a new request from the queue, keeping the GPU fully utilized at all times.
Continuous batching can improve throughput by 2x to 10x compared to static batching, depending on the variance in output sequence lengths. The improvement is largest when output lengths are highly variable (common in real chat workloads). vLLM, TGI, SGLang, and TensorRT-LLM all implement continuous batching.
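A toy simulation makes the gap concrete. This is an idealized model that ignores prefill and scheduling overhead; `static_steps` and the continuous-batching lower bound are illustrative constructions, not measurements from any serving framework:

```python
import math

def static_steps(lengths, slots):
    """Static batching: requests are grouped in arrival order, and each
    batch occupies the GPU until its longest sequence finishes."""
    return sum(max(lengths[i:i + slots]) for i in range(0, len(lengths), slots))

def continuous_steps_lower_bound(lengths, slots):
    """Continuous batching with a full queue: every slot does useful work
    every step, so total steps approach total tokens / slots."""
    return math.ceil(sum(lengths) / slots)

lengths = [20, 500, 35, 480, 15, 490, 25, 510]  # highly variable output lengths
s = static_steps(lengths, slots=4)
c = continuous_steps_lower_bound(lengths, slots=4)
print(f"static:     {s} decode steps")
print(f"continuous: {c} decode steps (ideal)")
print(f"speedup:    {s / c:.2f}x")
```

With uniform output lengths the two schedules converge, which is why the benefit scales with length variance.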
7. KV Cache Compression Techniques
Beyond architectural choices (GQA) and memory management (PagedAttention), the KV cache can be further compressed:
7.1 KV Cache Quantization
KV cache quantization compresses the cached key and value tensors from FP16 to FP8, INT8, or INT4. This is distinct from weight quantization (covered in Section 8.1): weight quantization compresses the model itself, while KV cache quantization compresses the dynamically growing cache. Since the KV cache can exceed model weight memory at high batch sizes and long contexts, cache quantization is equally important.
vLLM supports FP8 KV cache natively with the flag --kv-cache-dtype fp8. This halves cache memory with minimal quality impact (typically less than 0.5% perplexity increase). INT4 KV cache quantization is more aggressive, reducing memory by 4x, but requires calibration to maintain quality. The combination of GQA (4x reduction) plus FP8 KV cache (2x reduction) yields an 8x total reduction compared to standard MHA with FP16 cache.
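One way to see what the stacked reductions buy is to compute how many tokens fit in a fixed cache budget under each configuration. The 40 GB budget and the `max_cached_tokens` helper below are hypothetical, chosen to resemble an 8B-class model:

```python
def max_cached_tokens(hbm_budget_gib: int, layers: int, kv_heads: int,
                      head_dim: int, dtype_bytes: int) -> int:
    """Total tokens (across all sequences) that fit in the cache budget."""
    per_token = 2 * layers * kv_heads * head_dim * dtype_bytes
    return int(hbm_budget_gib * 1024**3 // per_token)

budget = 40  # GiB of HBM left over for the KV cache (hypothetical)
configs = {
    "MHA + FP16": (32, 32, 128, 2),
    "GQA + FP16": (32, 8, 128, 2),
    "GQA + FP8 ": (32, 8, 128, 1),
}
for name, cfg in configs.items():
    print(f"{name}: {max_cached_tokens(budget, *cfg):,} tokens cacheable")
```

The 8x jump from the first row to the last is exactly the GQA (4x) times FP8 (2x) stacking described above, and it translates directly into larger batches or longer contexts on the same GPU.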
7.2 Eviction and Sparsification
- H2O (Heavy-Hitter Oracle): An eviction policy that keeps only the most "important" KV entries (those with the highest cumulative attention scores) plus a small window of recent tokens. Reduces cache size by 80%+ for long contexts.
- Sliding window attention: Limit attention to a fixed window of recent tokens (e.g., 4096). Used by Mistral. Reduces cache to a fixed size regardless of context length.
- StreamingLLM (attention sinks): Keeps the first few "sink" tokens (which accumulate disproportionate attention) plus a sliding window. Enables unbounded context length with fixed cache.
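The StreamingLLM policy above reduces to a few lines if we track only which positions are kept. This is a toy sketch (real implementations evict GPU cache blocks, not Python list entries), and `evict` is a hypothetical name:

```python
# Attention-sinks + sliding-window eviction (toy sketch).
def evict(positions, num_sinks=4, window=8):
    """Keep the first `num_sinks` positions plus the most recent `window`."""
    if len(positions) <= num_sinks + window:
        return positions  # nothing to evict yet
    return positions[:num_sinks] + positions[-window:]

kept = evict(list(range(100)), num_sinks=4, window=8)
print(kept)  # sinks 0-3 plus the last 8 positions; size is fixed at 12
```

However long the stream grows, the cache never exceeds `num_sinks + window` entries, which is what makes unbounded-length generation possible with constant memory.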
```python
# Example 3: Profiling KV cache memory in vLLM
import time

import torch
from vllm import LLM, SamplingParams

# Load model and examine memory allocation
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    gpu_memory_utilization=0.9,
    max_model_len=8192,
    kv_cache_dtype="fp8",  # FP8 KV cache for 2x compression
)

# Check the KV cache configuration
# (internal attribute; the exact path varies across vLLM versions)
cache_config = llm.llm_engine.cache_config
print(f"Block size: {cache_config.block_size} tokens")
print(f"KV cache dtype: {cache_config.cache_dtype}")

# Observe current GPU memory usage
gpu_mem_allocated = torch.cuda.memory_allocated() / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved() / 1e9
print(f"\nGPU memory allocated: {gpu_mem_allocated:.2f} GB")
print(f"GPU memory reserved: {gpu_mem_reserved:.2f} GB")

# Profile throughput at different batch sizes
sampling = SamplingParams(max_tokens=100, temperature=0.0)
prompts = ["Explain the concept of attention in transformers."] * 32

for batch_size in [1, 4, 16, 32]:
    batch = prompts[:batch_size]
    start = time.perf_counter()
    outputs = llm.generate(batch, sampling)
    elapsed = time.perf_counter() - start
    total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    print(f"Batch {batch_size:>2}: {total_tokens/elapsed:>7.1f} tok/s "
          f"({elapsed:.2f}s for {total_tokens} tokens)")
```
8. Research Frontiers
8.1 Test-Time Training (TTT)
Test-Time Training (Sun et al., 2024) proposes a radical alternative to the KV cache. Instead of storing explicit key-value pairs for all past tokens, TTT layers compress the context into updated model weights. During inference, when processing a long context, a TTT layer performs a mini training step: it updates a small set of internal parameters via gradient descent on a next-token-prediction loss over the recent context. These updated parameters implicitly encode the contextual information that would otherwise require an explicit KV cache.
The result is dramatic: TTT achieves up to 35x speedup over full attention at 2 million token context length, because the "cache" is a fixed-size set of model parameters rather than a linearly-growing tensor. However, the approach blurs the traditional boundary between training and inference, since gradient computation occurs at every forward pass. Unlike fine-tuning, which permanently updates model weights for reuse across many requests, TTT creates temporary weight updates for a single inference request. The model compresses long-context information into these ephemeral weights, then discards them entirely once generation is complete. This makes TTT a form of adaptive inference rather than a training procedure.
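The flavor of the idea can be shown with a toy fast-weight layer: compress a "context" into a small matrix by gradient steps at inference time, then discard it. This is an illustration of the concept only, not the published TTT architecture; the dimensions, loss, and learning rate are arbitrary choices:

```python
import numpy as np

# Toy test-time training: encode a context into ephemeral weights W.
rng = np.random.default_rng(0)
d = 16
context = rng.normal(size=(32, d))           # 32 context "token" vectors
targets = context @ rng.normal(size=(d, d))  # what the layer should reproduce

W = np.zeros((d, d))  # ephemeral fast weights, one per request
init_loss = float(np.mean((context @ W - targets) ** 2))

lr = 0.3
for _ in range(100):  # mini training loop at inference time
    pred = context @ W
    grad = context.T @ (pred - targets) / len(context)  # gradient of MSE loss
    W -= lr * grad

loss = float(np.mean((context @ W - targets) ** 2))
print(f"reconstruction MSE: {init_loss:.2f} -> {loss:.4f}")
# W now implicitly encodes the context (no per-token KV storage);
# it is thrown away when the request completes.
```

The cache-replacement property is visible in the shapes: the "memory" is the fixed d x d matrix W, independent of context length, whereas a KV cache would grow linearly with the 32 context vectors.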
8.2 DeepSeek Sparse Attention (DSA)
Introduced in DeepSeek V3.2, DeepSeek Sparse Attention addresses long-context inference through a hierarchical two-stage pipeline. The first stage, called the Lightning indexer, performs a coarse scan over the full context to identify which segments are most relevant to the current query. The second stage applies fine-grained token-level attention only within the selected segments. This two-stage approach reduces inference compute by approximately 70% for long contexts while maintaining quality comparable to full attention.
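A toy rendering of the two-stage idea follows. It is purely illustrative: DSA's actual indexer is a learned lightweight module, not a mean over keys, and the segment sizes here are arbitrary:

```python
import numpy as np

# Two-stage sparse attention sketch: coarse segment selection, then
# exact attention only inside the selected segments.
rng = np.random.default_rng(1)
d, seg, n_segs, k = 32, 64, 16, 4
keys = rng.normal(size=(n_segs * seg, d))
values = rng.normal(size=(n_segs * seg, d))
query = rng.normal(size=d)

# Stage 1: coarse "indexer" -- one score per segment, not per token.
seg_means = keys.reshape(n_segs, seg, d).mean(axis=1)
top = np.argsort(seg_means @ query)[-k:]  # top-k segments for this query

# Stage 2: exact softmax attention over tokens in the selected segments only.
idx = np.concatenate([np.arange(s * seg, (s + 1) * seg) for s in top])
scores = keys[idx] @ query / np.sqrt(d)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
out = weights @ values[idx]

print(f"attended to {len(idx)} of {n_segs * seg} tokens "
      f"({100 * (1 - len(idx) / (n_segs * seg)):.0f}% skipped)")
```

Stage 1 costs one dot product per segment instead of per token, so the expensive stage 2 touches only k/n_segs of the context, which is where the reported compute savings come from.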
TTT and DeepSeek Sparse Attention are active research topics that are not yet widely available in standard serving frameworks. They represent the frontier of KV cache optimization and may become standard in future systems. For production deployments today, PagedAttention with GQA and FP8 KV cache quantization provides the best combination of throughput and compatibility.
Check Your Understanding
1. For Llama 3.1 70B with 80 layers, 8 KV heads, and head_dim=128, how much KV cache memory (in GB, FP16) does a single sequence at 32K context require?
Answer: 2 x 80 layers x 8 heads x 128 dims x 32,768 tokens x 2 bytes = 10.7 GB (10 GiB).
2. How does PagedAttention eliminate memory fragmentation?
Answer: It allocates the cache in fixed-size blocks on demand and uses a block table to map logical token positions to physical blocks, so no memory is pre-reserved for a sequence's unused maximum-length capacity.
3. If a model uses GQA with 32 query heads and 8 KV heads, by what factor is the KV cache reduced compared to standard MHA?
Answer: By 4x, since the cache scales with the 8 KV heads rather than the 32 query heads.
4. Why does continuous batching provide the largest throughput improvement when output sequence lengths are highly variable?
Answer: Static batching holds every slot until the batch's longest sequence finishes, so high length variance means many slots sit idle for long stretches; continuous batching refills each slot as soon as its sequence completes, and the reclaimed idle time grows with the variance.
Key Takeaways
- The KV cache can exceed model weight memory at large batch sizes and long contexts. For Llama 3.1 70B, a batch of 64 at 8K context needs 160 GB of KV cache alone.
- PagedAttention uses block-based allocation (like OS virtual memory) to eliminate fragmentation and enable cross-request memory sharing.
- GQA reduces the KV cache by sharing KV heads across query head groups. Llama 3 uses 8 KV heads for 32 query heads, a 4x reduction.
- Prefix caching (RadixAttention) avoids recomputing KV for shared system prompts, providing 2x to 5x speedup for prefix-heavy workloads.
- Continuous batching fills freed slots immediately with new requests, achieving 2x to 10x throughput over static batching.
- KV cache quantization to FP8 halves cache memory with minimal quality impact.
- TTT and DeepSeek Sparse Attention represent research frontiers that compress or sparsify the cache for extreme context lengths.