00 — Why a KV Cache Exists
The KV cache exists because autoregressive generation is structurally repetitive: at each step, the new token's attention looks at exactly the same keys and values the previous token already looked at, plus one new row. Recomputing all of them at every step is redundant work that multiplies the cost by a factor of \(O(n)\).
This is the motivation page. Read it before the derivation pages — without the why, the formulas are dry. With the why, they're inevitable.
The worked example
Throughout Phase 22 we anchor every concept to a concrete prompt drawn from the §A13 verb-grammar corpus:
Prompt: "Yesterday I" (2 tokens — past time-adverbial + 1st-person singular pronoun)
Decode → "Yesterday I worked" (one new token)
Decode → "Yesterday I worked and" (another new token)
Decode → "Yesterday I worked and he" (and another)
The grammar MiniGPT (Phase 17, vocab ≈ 600 forms) is trained to assign high probability to past-simple verbs after "Yesterday I", so the first decoded token is plausibly a regular past-simple verb such as "worked", "played", or "talked". The model's preference is irrelevant for this phase; what matters is that on step \(t+1\) the keys and values for "Yesterday" and "I" are the same as on step \(t\). That redundancy is what the cache exists to remove.
The thing autoregressive decoding does
A causal LM generates one token at a time. After t steps the model has produced tokens \(x_1, x_2, \ldots, x_t\). To produce \(x_{t+1}\):
- Embed all \(t\) tokens.
- Run them through \(L\) transformer layers. In each layer, attention computes \(\text{softmax}(QK^\top / \sqrt{d_k}) V\) — a \(t \times t\) matrix of attention weights, times \(V\).
- Take the final token's hidden state, project to vocabulary, sample.
Now produce \(x_{t+2}\). The naive way:
- Embed all \(t+1\) tokens.
- Run them through \(L\) layers. Each attention now computes a \((t+1) \times (t+1)\) matrix.
- Take the final token's hidden state, sample.
Notice: the top-left \(t \times t\) block of the \((t+1) \times (t+1)\) attention matrix computed for \(x_{t+2}\) is identical to the entire attention matrix from the previous step, because the causal mask keeps the first \(t\) rows from ever seeing token \(t+1\). We recomputed it. We will recompute it again next step, as the top-left \((t+1) \times (t+1)\) block of a \((t+2) \times (t+2)\) matrix. And again. And again.
Total attention work across \(n\) generated tokens, nearly all of it redundant: \(\sum_{t=1}^{n} t^2 d = \Theta(n^3 d)\). Per-token attention work that should grow linearly in \(t\) grows quadratically, and the total becomes cubic. That is the disease.
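A minimal NumPy sketch of that naive loop, to make the redundancy visible. It assumes a toy stack of attention-only layers (random weights, a residual connection, no MLP or layer norm) and random stand-in embeddings instead of the real MiniGPT; `causal_attention_weights` and `naive_forward` are hypothetical helpers. At every step it re-runs the whole prefix and asserts, in every layer, that the previous step's attention matrix reappears as the top-left block of the new one.

```python
import numpy as np

def causal_attention_weights(H, W_Q, W_K):
    """softmax(Q K^T / sqrt(d_k)) with a causal mask, for a (t, d) input H."""
    Q, K = H @ W_Q, H @ W_K
    t, d_k = K.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((t, t), dtype=bool), k=1)   # entries with j > i are masked
    scores = np.where(future, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    return w / w.sum(axis=1, keepdims=True)

def naive_forward(X, layers):
    """Re-run every layer over the whole prefix; return each layer's attention weights."""
    H, weights = X, []
    for W_Q, W_K, W_V in layers:
        A = causal_attention_weights(H, W_Q, W_K)
        weights.append(A)
        H = H + A @ (H @ W_V)          # toy layer: attention plus residual, no MLP
    return weights

rng = np.random.default_rng(0)
d, L, n = 8, 2, 5
layers = [tuple(rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3)) for _ in range(L)]
X = rng.normal(size=(n, d))            # stand-in embeddings for n tokens

prev = None
for t in range(1, n + 1):              # naive decoding: step t re-processes tokens 0..t-1
    weights = naive_forward(X[:t], layers)
    if prev is not None:
        for A_prev, A in zip(prev, weights):
            # The whole previous matrix reappears as the top-left block: recomputed work.
            assert np.allclose(A[:t - 1, :t - 1], A_prev)
    prev = weights
```

The check passes in the second layer as well as the first: because of the causal mask, the hidden states of earlier tokens do not change when a token is appended, which is the same fact the cache relies on.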
The remedy
In any transformer layer, attention reads three matrices:
- \(Q = X W_Q\) — query, one row per current token.
- \(K = X W_K\) — key.
- \(V = X W_V\) — value.
At step \(t+1\), \(K\) and \(V\) for tokens \(1..t\) are byte-for-byte identical to what they were at step \(t\): the inputs \(x_1..x_t\) haven't changed, the weights haven't changed, and because of the causal mask the hidden states of tokens \(1..t\) in every layer depend only on tokens \(1..t\). Nothing has changed. Only one new row needs computing: the row for \(x_{t+1}\).
The cache stores the rows of \(K\) and \(V\) from prior steps. On step \(t+1\) we:
- Compute the query \(q\) and the new rows of \(K\) and \(V\) for \(x_{t+1}\) only. Cost: \(O(d^2)\).
- Append that new row to the cache.
- Compute attention as \(Q_\text{new} \cdot K_\text{cached}^\top\) — a \(1 \times (t+1)\) row, not a \((t+1) \times (t+1)\) block. Cost: \(O((t+1) d)\).
- Softmax that row, multiply by \(V_\text{cached}\), output.
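A minimal single-head, single-layer sketch of those four steps in NumPy, which the page says Phase 22 uses. `KVCache`, `decode_step` and `full_causal_attention` are hypothetical names, not the lab's API. It feeds tokens one at a time through `decode_step` and checks the outputs against full causal attention recomputed from scratch.

```python
import numpy as np

class KVCache:
    """Hypothetical single-layer, single-head KV cache: append-only rows of K and V."""
    def __init__(self, d):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def append(self, k_row, v_row):
        self.K = np.vstack([self.K, k_row])   # vstack promotes the (d,) row to (1, d)
        self.V = np.vstack([self.V, v_row])

def decode_step(x_new, cache, W_Q, W_K, W_V):
    """Attend from one new token over the cache; only its own q, k, v are computed."""
    q, k, v = x_new @ W_Q, x_new @ W_K, x_new @ W_V   # O(d^2)
    cache.append(k, v)                                # cache now holds t+1 rows
    scores = cache.K @ q / np.sqrt(q.shape[0])        # 1 x (t+1) row, O((t+1) d)
    w = np.exp(scores - scores.max())
    w /= w.sum()                                      # softmax over the row
    return w @ cache.V                                # weighted sum of cached values

def full_causal_attention(X, W_Q, W_K, W_V):
    """Reference: recompute the whole (n, n) causal attention in one shot."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    t, d = X.shape
    scores = Q @ K.T / np.sqrt(d)
    scores[np.triu_indices(t, k=1)] = -np.inf
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ V

rng = np.random.default_rng(0)
d, n = 8, 6
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))
X = rng.normal(size=(n, d))          # stand-in embeddings

cache = KVCache(d)
outs = np.stack([decode_step(X[i], cache, W_Q, W_K, W_V) for i in range(n)])
assert np.allclose(outs, full_causal_attention(X, W_Q, W_K, W_V))
```

Appending with `np.vstack` copies the whole cache on every step, which keeps the sketch short but adds \(O(t\,d)\) memory traffic per token; a real implementation would preallocate and write each row in place.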
Attention compute for the new token: \(O(t \cdot d)\), plus the \(O(d^2)\) projections, which don't grow with \(t\). Total attention compute for the whole sequence: \(\sum_{t=1}^{n} t \cdot d = O(n^2 d)\). Cubic becomes quadratic.
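A quick count of the score work (the \(QK^\top\) multiply-adds in one layer) summed over \(n\) steps, naive versus cached; the product with \(V\) scales the same way and is left out. `attention_work` is a hypothetical helper, and \(d = 64\) is just the MiniGPT width (4 heads × 16) used for illustration.

```python
def attention_work(n, d):
    """Multiply-adds in QK^T across n decode steps (one layer): naive vs cached."""
    naive = sum(t * t * d for t in range(1, n + 1))   # full t x t score matrix every step
    cached = sum(t * d for t in range(1, n + 1))      # one 1 x t score row per step
    return naive, cached

for n in (32, 1024, 4096):
    naive, cached = attention_work(n, d=64)
    # The ratio is exactly (2n + 1) / 3: the saving itself grows with sequence length.
    print(n, naive / cached)
```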
Concretely, on "Yesterday I worked": prefill writes the \(K\) and \(V\) rows for "Yesterday" (cache index 0, zero-based) and "I" (index 1), and the output at the last prompt position yields "worked". The first decode step then processes the third token, "worked", on its own: it computes \(q\) for position 2 only, projects \(x_{\text{"worked"}}\) into a single new row of \(K\) and \(V\), and writes that row at cache index 2. The rows for "Yesterday" and "I" are touched only as reads. No multiplication by \(W_K\) or \(W_V\) is repeated for them. That is the entire optimization.
The price
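Cached \(K, V\) must live somewhere. When the sequence holds \(S\) tokens (prompt plus generated), the cache occupies

\[ \text{bytes} = 2 \cdot L \cdot H \cdot d_h \cdot S \cdot B \cdot s \]

where: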
- \(2\): one for K, one for V.
- \(L\): layers (cache per layer; layers don't share K, V).
- \(H \cdot d_h = d\): heads times head-dim equals model dim.
- \(S\): current sequence length in tokens.
- \(B\): batch size.
- \(s\): bytes per element (4 for fp32, 2 for fp16/bf16, 1 for int8).
This formula is derived in detail in 02-memory-cost.md. Two consequences worth absorbing now:
- The cache grows linearly in \(S\). Doubling context doubles cache memory. Long context is linearly expensive in memory.
- The cache is enormous for real models. Llama-2-7B at 4096 context, fp16, batch 1: \(2 \times 32 \times 32 \times 128 \times 4096 \times 1 \times 2 = 2^{31} \approx 2.15 \cdot 10^9\) bytes = 2 GiB. Per sequence. Per GPU. This is why serving a 7B model with 16 concurrent users on a single A100 is hard even though the model weights take only about 14 GB (7B parameters × 2 bytes in fp16). By contrast, the grammar MiniGPT (Phase 17 default: 4 layers, 4 heads, \(d_h\) = 16, fp32) at our longest plausible prompt (ctx = 32) holds only \(2 \times 4 \times 4 \times 16 \times 32 \times 1 \times 4 = 65{,}536\) bytes = 64 KiB, which fits comfortably in on-chip CPU cache. Cache size doesn't matter at our scale; the formula does, because every system you read about later assumes you can apply it.
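The formula as code, so it can be pointed at any config. `kv_cache_bytes` is a hypothetical helper whose arguments mirror the symbols \(L, H, d_h, S, B, s\) defined earlier in this section; like Phase 22, it assumes full multi-head attention, where every head stores its own K and V. Applied to the two configurations above, plus a 16-sequence batch of the 7B model:

```python
def kv_cache_bytes(L, H, d_h, S, B=1, s=2):
    """2 (K and V) * layers * heads * head_dim * tokens * batch * bytes per element."""
    return 2 * L * H * d_h * S * B * s

llama2_7b = kv_cache_bytes(L=32, H=32, d_h=128, S=4096, B=1, s=2)   # fp16
users16 = kv_cache_bytes(L=32, H=32, d_h=128, S=4096, B=16, s=2)    # 16 sequences, full context
minigpt = kv_cache_bytes(L=4, H=4, d_h=16, S=32, B=1, s=4)          # fp32

print(llama2_7b / 2**30, "GiB")   # 2.0 GiB
print(users16 / 2**30, "GiB")     # 32.0 GiB
print(minigpt / 2**10, "KiB")     # 64.0 KiB
```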
The dichotomy: prefill vs decode
The cache also restructures inference into two distinct phases:
- Prefill (a.k.a. "context encoding", "prompt processing"). The user submits a prompt of length \(P\). We compute \(K, V\) for all \(P\) tokens in one parallel pass — a single \(P \times P\) attention per layer. Compute: \(O(L P^2 d)\). Memory: \(O(L H d_h P) = O(L d P)\). The prefill is compute-bound for moderate \(P\) — there's plenty of arithmetic per byte loaded.
- Decode (a.k.a. "generation", "incremental decoding"). After prefill, we produce tokens one at a time. Each step is one \(1 \times S\) attention row per layer, reading the entire cache. Compute per step: \(O(L S d)\). Memory traffic per step: O(cache size) = \(O(L d S)\). Arithmetic intensity: \(O(L S d) / O(L S d) = O(1)\). Concretely, each cached element read feeds about two FLOPs (one multiply, one add), so attention runs at roughly \(2/s\) FLOPs per byte: about 0.5 for fp32 and 1 for fp16. The decode is memory-bound, deeply so.
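A back-of-the-envelope sketch of the two arithmetic intensities, for attention in one layer only. Assumptions not taken from the page: a multiply and an add each count as one FLOP; softmax, the \(q\) vector and the weight matmuls are ignored; and the prefill kernel is assumed fused, so it reads \(Q, K, V\) and writes the output once without storing the \(P \times P\) score matrix. Both function names are hypothetical.

```python
def prefill_attention_intensity(P, d, s):
    """FLOPs per byte for one layer's P x P attention (assumes a fused kernel)."""
    flops = 2 * P * P * d + 2 * P * P * d        # QK^T, then weights @ V
    bytes_moved = 4 * P * d * s                  # read Q, K, V; write the output
    return flops / bytes_moved                   # = P / s: grows with prompt length

def decode_attention_intensity(S, d, s):
    """FLOPs per byte for one decode step: a 1 x S attention row over the cache."""
    flops = 2 * S * d + 2 * S * d                # q . K^T, then weights @ V
    bytes_moved = 2 * S * d * s                  # read every cached K and V element
    return flops / bytes_moved                   # = 2 / s: constant, set by element size

print(prefill_attention_intensity(P=512, d=64, s=2))   # 256.0
print(decode_attention_intensity(S=4096, d=64, s=4))   # 0.5  (fp32)
print(decode_attention_intensity(S=4096, d=64, s=2))   # 1.0  (fp16)
```

The contrast is the point: prefill intensity rises with \(P\), while decode intensity is a small constant no matter how long the context gets.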
These two phases want different hardware, different schedulers, different optimizations. Every serving system you'll read about — vLLM, TensorRT-LLM, SGLang — is at heart a way to keep the GPU busy with prefill work while waiting on memory for decode work. The cache is the artifact that creates the asymmetry.
What the cache is not
A few clarifications that prevent confusion later:
- The cache stores \(K\) and \(V\), not \(Q\). \(Q\) is recomputed each step because the new token has its own query row.
- The cache is per-layer. Cached \(K, V\) for layer \(\ell\) are not reusable in layer \(\ell+1\). The cache is a list of \(L\) (K, V) tensor pairs.
- The cache is per-head. Within a layer, each of the \(H\) attention heads has its own \((S, d_h)\) K and V. (Some architectures share K, V across heads — grouped query attention, multi-query attention. Those are Phase 26 / 27 topics. Phase 22 assumes full multi-head with per-head K, V.)
- The cache is per-sequence in the batch. Batched serving with sequences of different lengths is hard precisely because each sequence's cache is a different size. This is the fragmentation problem PagedAttention solves (preview in 04-toward-paged-attention.md).
- The cache is not "memoization" in the Lisp sense. It's a fixed-structure, position-indexed array of tensors, not a hash map keyed by inputs. Append-only during generation; cleared on the next prompt.
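A sketch of the layout those bullets describe, assuming a simple contiguous allocator that reserves `S_max` slots per sequence up front. `MultiLayerKVCache` and its methods are hypothetical, not the lab's API, and this is not how PagedAttention allocates. It is a list of \(L\) (K, V) pairs, each holding every head of every sequence in the batch, plus a per-sequence length counter; writes go to the next position, never to a keyed lookup.

```python
import numpy as np

class MultiLayerKVCache:
    """Hypothetical contiguous KV cache: a list of L (K, V) pairs, each shaped
    (B, H, S_max, d_h), plus one length counter per sequence in the batch."""

    def __init__(self, L, B, H, d_h, S_max, dtype=np.float32):
        self.S_max = S_max
        # Per-layer: layer l's K, V are never reused by layer l+1.
        self.layers = [
            (np.zeros((B, H, S_max, d_h), dtype=dtype),
             np.zeros((B, H, S_max, d_h), dtype=dtype))
            for _ in range(L)
        ]
        # Per-sequence: each sequence in the batch has its own current length.
        self.lengths = np.zeros(B, dtype=np.int64)

    def append(self, layer, b, k, v):
        """Write one token's per-head rows, each shaped (H, d_h), for sequence b."""
        pos = self.lengths[b]
        assert pos < self.S_max, "sequence b has used all of its reserved slots"
        K, V = self.layers[layer]
        K[b, :, pos] = k   # per-head: row `pos` of every head's (S_max, d_h) block
        V[b, :, pos] = v

    def advance(self, b):
        """Count one more cached token for sequence b once every layer has appended."""
        self.lengths[b] += 1

    def nbytes(self):
        return sum(K.nbytes + V.nbytes for K, V in self.layers)

# MiniGPT-sized config from the page, with two sequences of different lengths.
rng = np.random.default_rng(0)
cache = MultiLayerKVCache(L=4, B=2, H=4, d_h=16, S_max=32)
for b, n_tokens in ((0, 3), (1, 1)):
    for _ in range(n_tokens):
        for layer in range(4):
            k, v = rng.normal(size=(2, 4, 16))   # stand-in per-head K and V rows
            cache.append(layer, b, k, v)
        cache.advance(b)

print(cache.lengths)                           # [3 1]
print(cache.nbytes() / 2**10, "KiB reserved")  # 128.0 KiB reserved
print(cache.lengths.sum() / (2 * 32))          # 0.0625 of reserved token slots in use
```

Two sequences of different lengths share a buffer sized for the longest allowed context, so most reserved slots sit empty; that reserved-but-unused memory is the kind of waste the PagedAttention preview is about.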
What this page does NOT cover
- The bytes formula derivation. Sketched here; rigorous in theory/02-memory-cost.md.
- Why decode is memory-bound on real hardware. Named here; arithmetic-intensity argument in theory/03-decode-as-memory-bound.md.
- PagedAttention / variable-length batched serving. Named here; problem-statement-only in theory/04-toward-paged-attention.md, full derivation in Phase 27.
- GPU-specific cache layout. Phase 22 runs in DRAM via NumPy; HBM / SRAM / register layout is Phase 23–24.
- Cache quantization (int8 / fp16). Phase 26.
What you should be able to do after this phase
- Sketch the prefill/decode dichotomy on a whiteboard, with the right asymptotic costs labeled, using "Yesterday I worked" as the running example.
- Derive the bytes formula from first principles: count what's stored, don't memorize the formula.
- Predict the cache size for any model from its config (Llama, Mistral, GPT-class) without running anything.
- Explain in one sentence why decode is memory-bound, and what that implies for which optimization (Flash-decoding, paging, GQA, quantization) attacks which symptom.
If any of those four is shaky, the lab will catch it. If they all land, the lab is mostly mechanical.
Next: theory/01-prefill-vs-decode.md — formalize the two phases and the FLOP accounting that justifies the dichotomy.