03 — Decode cost model: why we need the KV cache

The cost accounting for decoding without a cache is brutal: each token costs more than the previous one because the model recomputes attention over the entire prefix. This chapter quantifies the waste and motivates the Phase 22 KV cache.

The bare-bones cost

Generating \(T\) tokens given a prompt of length \(L\):

| Step | Forward sees length | Cost of forward |
|---|---|---|
| Generate token 0 | \(L\) | \(O(L \cdot d^2)\) |
| Generate token 1 | \(L + 1\) | \(O((L+1) \cdot d^2)\) |
| Generate token 2 | \(L + 2\) | \(O((L+2) \cdot d^2)\) |
| ... | ... | ... |
| Generate token \(T-1\) | \(L + T - 1\) | \(O((L+T-1) \cdot d^2)\) |

Total: \(\sum_{t=0}^{T-1} O((L+t) \cdot d^2) = O\big(T \cdot L \cdot d^2 + T^2 \cdot d^2 / 2\big)\).

For our Mini-GPT (\(d = 64, L = 8, T = 20\)):

  • Linear term: \(20 \cdot 8 \cdot 64^2 = 655{,}360\) ops.
  • Quadratic term: \(20^2 \cdot 64^2 / 2 = 819{,}200\) ops.

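The two are roughly comparable here. In general the ratio of the quadratic term to the linear term is \(T / (2L)\), so the quadratic term dominates only when the generation is much longer than the prompt. With a long prompt (e.g., \(L = 1024, T = 200, d = 4096\)) the linear term is about ten times larger than the quadratic one. Either way, every step pays again for the whole prefix.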
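The arithmetic above can be checked with a few lines of Python. This is a minimal sketch, not part of the course code: `naive_decode_cost` is a hypothetical helper, and an "op" here just means one unit of the \(d^2\) cost per position used in the table.

```python
def naive_decode_cost(L: int, T: int, d: int) -> dict[str, int]:
    """Count d^2-sized 'ops' for cache-less decoding of T tokens after a prompt of length L."""
    # Step t runs a forward pass over L + t positions, each costing about d^2 ops.
    exact = sum((L + t) * d * d for t in range(T))
    linear = T * L * d * d          # the T * L * d^2 term
    quadratic = T * T * d * d // 2  # the T^2 * d^2 / 2 term
    return {"exact": exact, "linear": linear, "quadratic": quadratic}


mini = naive_decode_cost(L=8, T=20, d=64)
big = naive_decode_cost(L=1024, T=200, d=4096)

print(mini["linear"], mini["quadratic"])   # 655360 819200
print(big["quadratic"] / big["linear"])    # ratio T / (2L) for the larger configuration
```

The `exact` value is slightly below `linear + quadratic` because the sum over \(t\) contributes \(T(T-1)/2\) rather than \(T^2/2\); the big-O estimate ignores that difference.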

Where the redundant work is

Look at the forward pass for the Mini-GPT (Phase 17). At each transformer block, attention computes:

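\[Q = X W_Q, \quad K = X W_K, \quad V = X W_V \quad \text{each of shape } (n, d)\]

\[A = \text{softmax}(QK^\top / \sqrt{d_h}) \quad \text{shape } (n, n)\]

\[\text{output} = A V \quad \text{shape } (n, d)\]

Here \(n\) is the length of the sequence passed to the forward pass, not the number of generated tokens \(T\) from the previous section.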

When generating token \(t+1\):

  • We pass the full sequence of length \(L + t + 1\) as input.
  • We compute \(Q, K, V\) for all \(L + t + 1\) positions.
  • But: the \(Q, K, V\) for positions \(0, 1, \ldots, L + t - 1\) are the same as on the previous step. Only position \(L + t\) is new.

So on step \(t+1\) we recompute \(K, V\) for \(L + t\) positions we already computed on step \(t\). This is the waste we want to eliminate.
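The redundancy can be made concrete with a toy projection. The sketch below is illustrative only: it uses plain Python lists instead of the Mini-GPT's tensors, `project` is a hypothetical helper, and the sizes are made up. It covers only the first layer's \(K\) projection; in deeper layers the same reuse holds because causal masking keeps the hidden states of earlier positions independent of later tokens.

```python
import random


def project(X: list[list[float]], W: list[list[float]]) -> list[list[float]]:
    """Row-wise matrix product X @ W using plain lists."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]


random.seed(0)
d, n = 4, 5  # toy sizes (assumed for illustration, not the Mini-GPT's)
W_K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n + 1)]

K_prev = project(X[:n], W_K)  # previous step: n positions
K_next = project(X, W_K)      # current step: n + 1 positions

reused_rows = K_next[:n] == K_prev     # rows for old positions are recomputed unchanged
new_rows = len(K_next) - len(K_prev)   # only one row is actually new
print(reused_rows, new_rows)           # True 1
```

Every row of `K_next` except the last was already available in `K_prev`, which is exactly the work a cache would skip.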

The KV cache idea (forward reference)

Phase 22 will implement this: store \(K^{(l)}, V^{(l)}\) for every layer \(l\) and every position seen so far. On step \(t+1\), only compute the \(K, V\) for the new position and concatenate to the cache.

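```python
# After Phase 22 (sketch only: KVCache, sample, and forward_one are not defined in this chapter):
def decode_step(model, last_token, cache: KVCache) -> tuple[int, KVCache]:
    """Process one new token, updating the cache."""
    # Only the new position is projected; K and V for earlier positions come from the cache.
    logits, new_cache = model.forward_one(last_token, cache)
    next_token = sample(logits[-1])
    return next_token, new_cache
```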

Cost per step drops from \(O((L + t) \cdot d^2)\) to \(O(d^2)\), which is constant in \(t\) (ignoring the attention's linear-in-cached-length scan, which is far cheaper than the matmul). Counting the one-off pass over the \(L\) prompt positions, total cost drops to \(O((L + T) \cdot d^2)\).

For our example (\(L = 8, T = 20, d = 64\)): \(28 \cdot 4096 = 114{,}688\) ops. That's ~13× cheaper than the bare-bones \(1{,}474{,}560\) above.
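The comparison can be reproduced, and extended to longer generations, with a short script. This is a sketch under the same simplified cost model as the text (one "op" per \(d^2\) unit, attention scan ignored); the helper names are hypothetical.

```python
def bare_bones_estimate(L: int, T: int, d: int) -> int:
    """Cache-less estimate from the text: T*L*d^2 + T^2*d^2/2."""
    return T * L * d * d + T * T * d * d // 2


def cached_estimate(L: int, T: int, d: int) -> int:
    """Cached estimate from the text: each of the L + T positions is projected once."""
    return (L + T) * d * d


naive = bare_bones_estimate(L=8, T=20, d=64)
cached = cached_estimate(L=8, T=20, d=64)
speedup = naive / cached
print(naive, cached, round(speedup, 1))  # 1474560 114688 12.9

# Under this model the gap widens as the generation gets longer.
for T in (20, 50, 200):
    ratio = bare_bones_estimate(8, T, 64) / cached_estimate(8, T, 64)
    print(T, round(ratio, 1))
```

The `d * d` factor cancels in the ratio, so under this model the speedup depends only on \(L\) and \(T\).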

Why Phase 21 doesn't use the cache

Pedagogical reasons:

  1. Feel the cost. Lab 03's decode benchmark on \(T = 50\) tokens will be visibly slow. You should see the per-step cost climb as \(t\) grows.
  2. Cleanliness. Without the cache, the decode loop is the same shape as the training forward. No new state machinery. Phase 22 introduces the state and tests it against the cache-less reference.
  3. Correctness first. The cache is an optimisation; it must produce the same output as the cache-less decode: the same tokens, with logits equal up to floating-point reordering. Phase 22's first test will be: cache vs no-cache, do the outputs agree?
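One way to phrase the "do the outputs agree?" check from point 3 is a tolerance-based comparison of per-step logits. The sketch below is an assumption about what such a check could look like, not Phase 22's actual test: `outputs_agree` and its tolerances are hypothetical, and logits are represented as plain lists of floats.

```python
import math


def outputs_agree(ref_logits: list[list[float]], cached_logits: list[list[float]],
                  rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """True if both runs pick the same argmax at every step and logits match within tolerance."""
    if len(ref_logits) != len(cached_logits):
        return False
    for ref, got in zip(ref_logits, cached_logits):
        if len(ref) != len(got):
            return False
        # Greedy decoding would emit the same token only if the argmax matches.
        if ref.index(max(ref)) != got.index(max(got)):
            return False
        # Allow small differences caused by a different order of floating-point operations.
        if not all(math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                   for a, b in zip(ref, got)):
            return False
    return True


print(outputs_agree([[0.1, 2.0]], [[0.1000001, 2.0]]))  # True
print(outputs_agree([[0.1, 2.0]], [[2.0, 0.1]]))        # False
```

Comparing with a tolerance rather than `==` reflects the "up to floating-point reordering" caveat: the cached path may sum the same terms in a different order.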

Memory cost (preview)

The KV cache stores \(K^{(l)}, V^{(l)}\) for every layer \(l\) and every position. Size:

\[\text{KV cache size} = 2 \cdot n_\text{layers} \cdot (L + T) \cdot d \cdot \text{bytes-per-float}\]

For Mini-GPT (\(n_\text{layers} = 2, d = 64, L + T = 28\), float32 = 4 bytes): \(2 \cdot 2 \cdot 28 \cdot 64 \cdot 4 = 28{,}672\) bytes = 28 KB. Tiny.

For GPT-3 (\(n_\text{layers} = 96, d = 12{,}288, L + T = 2048\), fp16 = 2 bytes): \(2 \cdot 96 \cdot 2048 \cdot 12{,}288 \cdot 2 \approx 9.7 \times 10^9\) bytes, about 9.7 GB per sequence. Enormous.

This is why Phase 27 covers KV-cache optimisations: paged attention, multi-query attention, grouped-query attention. At scale, with long contexts and many concurrent requests, the cache can become the dominant memory cost of LLM inference.
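The size formula from this section translates directly into code. This is a minimal sketch: `kv_cache_bytes` is a hypothetical helper, and it assumes one sequence with full-width \(K\) and \(V\) stored per layer, as in the formula.

```python
def kv_cache_bytes(n_layers: int, seq_len: int, d: int, bytes_per_float: int) -> int:
    """KV cache size for one sequence: 2 * n_layers * (L + T) * d * bytes-per-float."""
    # The factor 2 accounts for storing both K and V at every layer.
    return 2 * n_layers * seq_len * d * bytes_per_float


mini_gpt = kv_cache_bytes(n_layers=2, seq_len=28, d=64, bytes_per_float=4)
gpt3 = kv_cache_bytes(n_layers=96, seq_len=2048, d=12288, bytes_per_float=2)

print(mini_gpt)        # 28672
print(gpt3 / 1e9)      # size in GB (10^9 bytes)
```

Because the size is linear in `seq_len`, doubling the context length doubles the cache, and serving several sequences at once multiplies it by the number of sequences.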

A note on continuous batching (forward reference)

If you have many requests, you can batch their decode steps: at step \(t\), run a forward pass that processes (request A's token \(t_A\) + request B's token \(t_B\) + ...) together. The catch: requests have different prompt lengths and different generation lengths. Phase 33 ("Inference Serving: From FastAPI to Continuous Batching") covers this in detail.

Phase 21 does not batch. Single-request decode only. The infrastructure for batching is layered on top in Phase 33.

What this file does NOT cover

  • KV cache implementation. Phase 22.
  • Paged attention, MQA, GQA. Phase 27.
  • Continuous batching. Phase 33.

Next: ../lab/00-greedy.md