03 — Decode cost model: why we need the KV cache
The cost accounting for cache-less decoding is brutal: each token costs more than the previous one because the model recomputes attention over the entire prefix. This chapter quantifies the waste and motivates the Phase 22 KV cache.
The bare-bones cost
Generating \(T\) tokens given a prompt of length \(L\):
| Step | Forward sees length | Cost of forward |
|---|---|---|
| Generate token 0 | \(L\) | \(O(L \cdot d^2)\) |
| Generate token 1 | \(L + 1\) | \(O((L+1) \cdot d^2)\) |
| Generate token 2 | \(L + 2\) | \(O((L+2) \cdot d^2)\) |
| ... | ... | ... |
| Generate token \(T-1\) | \(L + T - 1\) | \(O((L+T-1) \cdot d^2)\) |
Total: \(\sum_{t=0}^{T-1} O((L+t) \cdot d^2) = O\big(T \cdot L \cdot d^2 + T^2 \cdot d^2 / 2\big)\).
For our Mini-GPT (\(d = 64, L = 8, T = 20\)):
- Linear term: \(20 \cdot 8 \cdot 64^2 = 655{,}360\) ops.
- Quadratic term: \(20^2 \cdot 64^2 / 2 = 819{,}200\) ops.
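The two are roughly comparable here. Which term dominates depends on how \(T\) compares with \(2L\): the quadratic term overtakes the linear one only when \(T > 2L\). At larger scale with a long prompt (e.g., \(L = 1024, T = 200, d = 4096\)), the linear term is actually about 10× larger (\(T \cdot L = 204{,}800\) versus \(T^2 / 2 = 20{,}000\), each times \(d^2\)). Either way, both terms are far above the \((L + T) \cdot d^2\) that a cached decode needs.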
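The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch under the same simplified cost model (one forward over \(n\) positions costs \(n \cdot d^2\) operations, with constants and layer count dropped). The function name `uncached_decode_cost` is made up for illustration. Note that the exact sum uses \(T(T-1)/2\), so it comes out slightly below the sum of the two approximate terms.

```python
def uncached_decode_cost(L: int, T: int, d: int) -> dict[str, int]:
    """Count operations for T cache-less decode steps.

    Assumes one forward over n positions costs n * d^2 operations,
    the same simplified model as the table above.
    """
    exact = sum((L + t) * d * d for t in range(T))  # steps t = 0 .. T-1
    linear = T * L * d * d                          # T * L * d^2
    quadratic = T * T * d * d // 2                  # T^2 * d^2 / 2 (approximation)
    return {"exact": exact, "linear": linear, "quadratic": quadratic}


mini = uncached_decode_cost(L=8, T=20, d=64)
print(mini)  # {'exact': 1433600, 'linear': 655360, 'quadratic': 819200}
```

Calling the function with other values of `L` and `T` shows how the balance between the two terms shifts.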
Where the redundant work is
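Look at the forward pass for the Mini-GPT (Phase 17). At each transformer block, attention computes:
\[ Q = X W_Q, \quad K = X W_K, \quad V = X W_V, \qquad \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]
where \(X\) is the block's input (one row per position), \(d_k\) is the per-head key dimension, and the causal mask is applied to the scores before the softmax.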
When generating token \(t+1\):
- We pass the full sequence of length \(L + t + 1\) as input.
- We compute \(Q, K, V\) for all \(L + t + 1\) positions.
- But the \(Q, K, V\) for positions \(0, 1, \ldots, L + t - 1\) are the same as on the previous step, because under the causal mask no position depends on later ones. Only position \(L + t\) is new.
So on step \(t+1\) we recompute \(K, V\) for \(L + t\) positions we already computed on step \(t\). This is the waste we want to eliminate.
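A small sketch makes the redundancy concrete. It uses a toy width of 4 and random numbers in place of real embeddings and weights (all assumptions for illustration, not the Mini-GPT's actual values), and only shows a single projection: appending one position and re-projecting everything reproduces the earlier rows of \(K\) exactly. The same holds for \(V\), and for deeper layers as long as attention is causal.

```python
import random


def project(X, W):
    """Row-wise projection X @ W on lists of lists (one row per position)."""
    return [
        [sum(x[i] * W[i][j] for i in range(len(W))) for j in range(len(W[0]))]
        for x in X
    ]


random.seed(0)
d = 4  # toy width, not the Mini-GPT's 64
W_K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
prefix = [[random.gauss(0, 1) for _ in range(d)] for _ in range(5)]  # input at step t
new_row = [random.gauss(0, 1) for _ in range(d)]                     # the one new position

K_prev = project(prefix, W_K)              # computed on step t
K_next = project(prefix + [new_row], W_K)  # recomputed from scratch on step t+1

assert K_next[:-1] == K_prev               # every earlier row is recomputed unchanged
print(len(K_next) - len(K_prev))           # 1: only one row is genuinely new
```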
The KV cache idea (forward reference)
Phase 22 will implement this: store \(K^{(l)}, V^{(l)}\) for every layer \(l\) and every position seen so far. On step \(t+1\), only compute the \(K, V\) for the new position and concatenate to the cache.
```python
# After Phase 22:
def decode_step(model, last_token, cache: KVCache) -> tuple[int, KVCache]:
    """Process one new token, updating the cache."""
    logits, new_cache = model.forward_one(last_token, cache)
    next_token = sample(logits[-1])
    return next_token, new_cache
```
Cost per step drops from \(O((L + t) \cdot d^2)\) to \(O(d^2)\) — constant in \(t\) (ignoring the attention's linear-in-cached-length scan, which is far cheaper than the matmul). Total cost drops to \(O((L + T) \cdot d^2)\).
For our example (\(L = 8, T = 20, d = 64\)): \(28 \cdot 4096 = 114{,}688\) ops. That's ~13× cheaper than the bare-bones \(1{,}474{,}560\) above.
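The comparison can be checked numerically. The sketch below uses the two approximate formulas from this chapter as they stand. `decode_costs` is a hypothetical helper name, and the attention scan over the cached positions is ignored, as in the text.

```python
def decode_costs(L: int, T: int, d: int) -> tuple[int, int]:
    """Return (uncached, cached) operation counts under this chapter's approximations."""
    uncached = T * L * d * d + T * T * d * d // 2  # linear term + quadratic term
    cached = (L + T) * d * d                       # each position is projected once
    return uncached, cached


uncached, cached = decode_costs(L=8, T=20, d=64)
print(uncached, cached, round(uncached / cached, 1))  # 1474560 114688 12.9
```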
Why Phase 21 doesn't use the cache
Pedagogical reasons:
- Feel the cost. Lab 03's decode benchmark on \(T = 50\) tokens will be visibly slow. You should see the per-token latency climb as \(t\) grows.
- Cleanliness. Without the cache, the decode loop is the same shape as the training forward. No new state machinery. Phase 22 introduces the state and tests it against the cache-less reference.
- Correctness first. The cache is an optimisation; it must produce the same output as the cache-less decode, up to differences caused by floating-point reordering. Phase 22's first test will be: cache vs no-cache, do the outputs agree?
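The cache-versus-reference check in the last bullet can be sketched on a toy single-head attention in pure Python. This is not the Phase 22 test. The helper names, the dict-based cache, the width of 4, and the random "embeddings" are all assumptions for illustration. The point is only the shape of the test: run both paths step by step and compare the last position's output within a tolerance.

```python
import math
import random


def matvec(x, W):
    """Row vector x times matrix W."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Softmax attention for a single query over the given keys and values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def last_output_no_cache(X, W_q, W_k, W_v):
    """Reference path: re-project every position, then attend from the last one."""
    keys = [matvec(x, W_k) for x in X]
    values = [matvec(x, W_v) for x in X]
    return attend(matvec(X[-1], W_q), keys, values)


def last_output_with_cache(x_new, cache, W_q, W_k, W_v):
    """Cached path: project only the new position and append it to the cache."""
    cache["k"].append(matvec(x_new, W_k))
    cache["v"].append(matvec(x_new, W_v))
    return attend(matvec(x_new, W_q), cache["k"], cache["v"])


def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d = 4
W_q, W_k, W_v = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
X = rand_matrix(6, d)  # six toy "token embeddings"

cache = {"k": [], "v": []}
max_diff = 0.0
for t in range(len(X)):
    cached_out = last_output_with_cache(X[t], cache, W_q, W_k, W_v)
    reference = last_output_no_cache(X[: t + 1], W_q, W_k, W_v)
    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(cached_out, reference)))

assert max_diff < 1e-9  # compare with a tolerance rather than exact equality
```

No explicit mask is needed here because only the last position's output is compared, and the last position may attend to every earlier one.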
Memory cost (preview)
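The KV cache stores \(K^{(l)}, V^{(l)}\) for every layer \(l\) and every position. Size:
\[ \text{bytes} = 2 \cdot n_\text{layers} \cdot (L + T) \cdot d \cdot \text{bytes per element} \]
where the leading 2 counts \(K\) and \(V\).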
For Mini-GPT (\(n_\text{layers} = 2, d = 64, L + T = 28\), float32 = 4 bytes): \(2 \cdot 2 \cdot 28 \cdot 64 \cdot 4 = 28{,}672\) bytes = 28 KB. Tiny.
For GPT-3 (\(n_\text{layers} = 96, d = 12{,}288, L + T = 2048\), fp16): \(2 \cdot 96 \cdot 2048 \cdot 12{,}288 \cdot 2 = 9{,}663{,}676{,}416\) bytes \(\approx 9.7\) GB for a single sequence. Enormous.
This is why Phase 27 covers KV-cache optimisations: paged attention, multi-query attention, grouped-query attention. With long contexts and many concurrent sequences, the cache becomes a dominant memory cost of LLM inference.
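The two size calculations above can be reproduced with a one-line helper. `kv_cache_bytes` is a hypothetical name for this sketch. It assumes full multi-head attention, with \(K\) and \(V\) each of width \(d\) per layer, and a single sequence.

```python
def kv_cache_bytes(n_layers: int, d: int, seq_len: int, bytes_per_elem: int) -> int:
    """Bytes needed to store K and V for every layer and position of one sequence."""
    return 2 * n_layers * seq_len * d * bytes_per_elem  # leading 2 = K and V


mini = kv_cache_bytes(n_layers=2, d=64, seq_len=28, bytes_per_elem=4)
gpt3 = kv_cache_bytes(n_layers=96, d=12_288, seq_len=2048, bytes_per_elem=2)
print(mini, gpt3)  # 28672 9663676416
```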
A note on continuous batching (forward reference)
If you have many requests, you can batch their decode steps: at step \(t\), run a forward pass that processes (request A's token \(t_A\) + request B's token \(t_B\) + ...) together. The catch: requests have different prompt lengths and different generation lengths. Phase 33 ("Inference Serving: From FastAPI to Continuous Batching") covers this in detail.
Phase 21 does not batch. Single-request decode only. The infrastructure for batching is layered on top in Phase 33.
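To make the scheduling idea concrete, here is a toy simulation. It is not the Phase 33 design, which this chapter does not describe. The function name, the `max_batch` parameter, and the first-come-first-served admission rule are all assumptions, and prompt processing is ignored. It only shows requests of different generation lengths sharing decode steps, with a finished request freeing its slot for a waiting one.

```python
from collections import deque


def simulate_continuous_batching(gen_lengths, max_batch):
    """Return, for each decode step, the ids of the requests batched together.

    gen_lengths[i] is how many tokens request i must generate (assumed >= 1).
    """
    waiting = deque(range(len(gen_lengths)))
    remaining = {}  # active request id -> tokens still to generate
    schedule = []
    while waiting or remaining:
        # Before each step, fill free slots with waiting requests.
        while waiting and len(remaining) < max_batch:
            rid = waiting.popleft()
            remaining[rid] = gen_lengths[rid]
        schedule.append(sorted(remaining))
        # One batched decode step: every active request emits one token.
        for rid in list(remaining):
            remaining[rid] -= 1
            if remaining[rid] <= 0:
                del remaining[rid]  # a finished request frees its slot
    return schedule


schedule = simulate_continuous_batching(gen_lengths=[3, 1, 2], max_batch=2)
print(schedule)  # [[0, 1], [0, 2], [0, 2]]
```

Request 1 finishes after one step, so request 2 joins on the second step instead of waiting for request 0 to complete.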
What this file does NOT cover
- KV cache implementation. Phase 22.
- Paged attention, MQA, GQA. Phase 27.
- Continuous batching. Phase 33.
Next: ../lab/00-greedy.md