02 — Memory Cost of the KV Cache

The formula bytes = 2 · L · H · d_h · S · B · s is not a recipe to memorize; it is a count of "what am I storing, and for how many". Deriving it each time protects you against the "I forgot the factor of 2" and "I confused d with d_h" mistakes, which are the two most common errors on this topic.

This is the algebraic centerpiece of Phase 22. We derive the cache size formula from first principles, then apply it to real models, then derive what changing each factor (L, H, d_h, S, B, s) buys you.


The formula, derived by counting

The cache stores, per transformer layer, two tensors: \(K\) and \(V\). After \(S\) tokens have been processed, each of these is shaped:

\[K, V \in \mathbb{R}^{B \times H \times S \times d_h}\]
  • \(B\): batch size (number of concurrent sequences sharing this cache instance).
  • \(H\): number of attention heads in this layer.
  • \(S\): number of tokens whose K, V are currently stored.
  • \(d_h\): head dimension (note: \(H \cdot d_h = d\), the model dim).

The number of elements in \(K\) alone is \(B \cdot H \cdot S \cdot d_h\). The number in \(V\) is the same. Per layer, \(K + V\) together hold \(2 B H S d_h\) elements.

We have \(L\) layers, each with its own independent cache. Layers do not share K, V — each layer's K, V are produced by that layer's \(W_K, W_V\) projections from that layer's input. Total elements:

\[\text{elements} = 2 \cdot L \cdot B \cdot H \cdot S \cdot d_h\]

Multiply by bytes per element \(s\):

\[\boxed{\text{bytes}_\text{cache} = 2 \cdot L \cdot H \cdot d_h \cdot S \cdot B \cdot s}\]

By convention, \(L, H, d_h\) come first because they are model-architecture constants, and \(S, B\) come last because they are runtime knobs. \(s\) depends on the dtype:

| dtype | \(s\) (bytes per element) |
|---|---|
| fp64 | 8 |
| fp32 | 4 |
| fp16 / bf16 | 2 |
| int8 | 1 |
| int4 (packed) | 0.5 |
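The boxed formula can be sketched as a small Python helper. This is a minimal illustration rather than Phase 22's actual code: the names `kv_cache_bytes` and `BYTES_PER_ELEMENT` are hypothetical, and the Llama-2-7B shape (L=32, H=32, d_h=128) is the one used elsewhere on this page.

```python
# Bytes per element for each dtype in the table above.
BYTES_PER_ELEMENT = {"fp64": 8, "fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}


def kv_cache_bytes(L, H, d_h, S, B, dtype="fp16"):
    """Return bytes = 2 * L * H * d_h * S * B * s for a KV cache."""
    s = BYTES_PER_ELEMENT[dtype]
    elements = 2 * L * H * d_h * S * B  # factor 2: one K and one V tensor per layer
    return elements * s


# Llama-2-7B shape, 4096 cached tokens, batch 1, fp16.
llama_7b_4k = kv_cache_bytes(L=32, H=32, d_h=128, S=4096, B=1, dtype="fp16")
print(llama_7b_4k / 2**30, "GiB")  # prints: 2.0 GiB
```

Passing the dtype as a string keeps the call site close to how the formula is read aloud; swapping `"fp16"` for `"int8"` halves the result without touching any other argument.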

Sanity check: an alternative derivation

Some references write the formula with model dim \(d\) instead of \(H \cdot d_h\):

\[\text{bytes}_\text{cache} = 2 \cdot L \cdot d \cdot S \cdot B \cdot s \quad \text{(since } d = H \cdot d_h\text{)}\]

Both are correct. The \(H \cdot d_h\) form is useful when reasoning about Grouped-Query Attention (Phase 27), where K and V are shared across head-groups: GQA changes \(H\) in the cache (to the number of key-value heads, \(H_{KV} < H\)) but does not change it in \(Q\) (still \(H\) heads). So the cache formula becomes \(2 L H_{KV} d_h S B s\), while the attention compute still uses \(H\) heads. The \(d\)-form hides this; the \(H \cdot d_h\) form makes it visible.
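As a worked instance of the point above, dividing the MHA cache size by the GQA cache size cancels every factor except the head counts. The Llama-2-70B numbers (\(H = 64\), \(H_{KV} = 8\)) are the ones used in the scaling table on this page.

\[\frac{\text{bytes}_\text{MHA}}{\text{bytes}_\text{GQA}} = \frac{2 \cdot L \cdot H \cdot d_h \cdot S \cdot B \cdot s}{2 \cdot L \cdot H_{KV} \cdot d_h \cdot S \cdot B \cdot s} = \frac{H}{H_{KV}} \quad\Rightarrow\quad \frac{64}{8} = 8 \text{ for Llama-2-70B}\]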

Per-token marginal cost

How many bytes does each additional generated token add to the cache?

Differentiate (well — take the difference, since \(S\) is discrete):

\[\Delta\text{bytes} = 2 \cdot L \cdot H \cdot d_h \cdot B \cdot s \quad \text{(constant in } S \text{)}\]

The cache grows by a constant number of bytes per token — not dependent on the current cache size. This is what "linear-in-context memory" means quantitatively.
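Written out, the difference step is a one-line cancellation using the boxed formula evaluated at \(S+1\) and at \(S\):

\[\Delta\text{bytes} = \text{bytes}(S+1) - \text{bytes}(S) = 2 \cdot L \cdot H \cdot d_h \cdot B \cdot s \cdot \big[(S+1) - S\big] = 2 \cdot L \cdot H \cdot d_h \cdot B \cdot s\]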

For Llama-2-7B (fp16, batch 1):

\[\Delta\text{bytes} = 2 \cdot 32 \cdot 32 \cdot 128 \cdot 1 \cdot 2 = 524288 \text{ bytes} = 512 \text{ KiB per token}\]

Half a megabyte per token. A 4096-token context = 2 GiB.

A scaling table: cache size for real models

Memorize the shape of this table; tables like it show up throughout the serving-systems literature.

| Model | \(L\) | \(H\) (KV heads) | \(d_h\) | \(d\) | dtype | Per-token | 4k ctx | 32k ctx | 128k ctx |
|---|---|---|---|---|---|---|---|---|---|
| Grammar MiniGPT (Phase 17 default, §A13) | 4 | 4 | 16 | 64 | fp32 | 2 KiB | 8 MiB | 64 MiB | 256 MiB |
| GPT-2 small | 12 | 12 | 64 | 768 | fp16 | 36 KiB | 144 MiB | 1.1 GiB | 4.5 GiB |
| Llama-2-7B | 32 | 32 | 128 | 4096 | fp16 | 512 KiB | 2 GiB | 16 GiB | 64 GiB |
| Llama-2-13B | 40 | 40 | 128 | 5120 | fp16 | 800 KiB | 3.1 GiB | 25 GiB | 100 GiB |
| Llama-2-70B (MHA, no GQA; counterfactual) | 80 | 64 | 128 | 8192 | fp16 | 2.5 MiB | 10 GiB | 80 GiB | 320 GiB |
| Llama-2-70B (GQA, \(H_{KV}=8\)) | 80 | 8 | 128 | 8192 | fp16 | 320 KiB | 1.25 GiB | 10 GiB | 40 GiB |
| GPT-3 175B | 96 | 96 | 128 | 12288 | fp16 | 4.5 MiB | 18 GiB | 144 GiB | 576 GiB |

(MiniGPT row: derivable from Phase 17's config — confirm at phase open in case the config changed.)
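The size columns can be recomputed from the \(L\), \(H\), \(d_h\), and dtype columns with a few lines of Python, which is a quick way to re-check a row if a config changes. This is a sketch at batch 1; `MODELS`, `per_token_bytes`, and `human` are hypothetical names introduced for illustration.

```python
# (name, L, KV heads, d_h, bytes per element); configs copied from the table above.
MODELS = [
    ("Grammar MiniGPT", 4, 4, 16, 4),
    ("GPT-2 small", 12, 12, 64, 2),
    ("Llama-2-7B", 32, 32, 128, 2),
    ("Llama-2-13B", 40, 40, 128, 2),
    ("Llama-2-70B (MHA, counterfactual)", 80, 64, 128, 2),
    ("Llama-2-70B (GQA)", 80, 8, 128, 2),
    ("GPT-3 175B", 96, 96, 128, 2),
]


def per_token_bytes(L, H_kv, d_h, s, B=1):
    """Bytes added to the cache by one token: 2 * L * H_kv * d_h * B * s."""
    return 2 * L * H_kv * d_h * B * s


def human(n_bytes):
    """Format a byte count with binary units (B, KiB, MiB, GiB)."""
    for unit in ("B", "KiB", "MiB"):
        if n_bytes < 1024:
            return f"{n_bytes:.4g} {unit}"
        n_bytes /= 1024
    return f"{n_bytes:.4g} GiB"


for name, L, H_kv, d_h, s in MODELS:
    tok = per_token_bytes(L, H_kv, d_h, s)
    # Per-token cost, then the cache size at 4k, 32k, and 128k tokens.
    cols = [human(tok)] + [human(tok * S) for S in (4096, 32768, 131072)]
    print(f"{name:36s}", " | ".join(cols))
```

Because the per-token cost is constant in \(S\), every context column is just that cost multiplied by the token count.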

Things to notice while reading the table:

  1. GQA is not a tweak. Going from 70B-MHA-counterfactual to 70B-GQA shrinks the cache 8×. Without GQA, a single 32k-token sequence would need 80 GiB of cache, more than an 80 GB H100 holds in total, before counting any weights. Cache size is one of three commonly cited reasons GQA exists; the other two are inference latency (less K, V to read at each decode step) and memory-bandwidth utilization.
  2. Long context is not "free with enough RAM". 128k context on a 70B model with GQA is 40 GiB per sequence. The fp16 weights alone (about 140 GB for 70B parameters) already exceed a single 80 GB A100; even with weights quantized to 4 bits (about 35 GB), one card holds the weights and one sequence's cache, barely. Concurrent users break this.
  3. Grammar MiniGPT's cache is trivial. That's the point: Phase 22 stays at a scale where Borja can verify every byte by hand. The longest realistic sentence in the §A13 corpus is ~10 tokens ("Tomorrow he is going to study and finish"); with the config in the table (2 · 4 · 4 · 16 · 4 = 2048 bytes per token), the cache at that length is about 20 KiB total. Phase 24 moves to a scale where measurement replaces enumeration.

Where the cache lives in memory

Implementation choices:

  1. One big tensor per layer, pre-allocated to S_max. Shape \((B, H, S_\text{max}, d_h)\). Write into slices [..., :S, :] as \(S\) grows. Memory: constant. Fragmentation: none (one contiguous block per layer). Wasted space: the \(S_\text{max} - S_\text{current}\) unused token rows per layer. This is what Phase 22 implements.
  2. A list of growing tensors. Each cache.append(k_new) does K = np.concatenate([K, k_new], axis=2). Memory: variable. Fragmentation: heap thrash. Cost per append: O(S), because the whole cache is copied every time, which destroys the linear-per-step decode. Do not do this.
  3. Paged: a list of fixed-size blocks per sequence, indexed by a "block table". Memory: constant per block, variable in #blocks. Fragmentation: only at block boundaries (small). This is PagedAttention; Phase 27.

Phase 22 uses (1). The wasted bytes are the price for keeping the implementation small and the math transparent. With \(S_\text{max} = 64\) on the grammar MiniGPT, the cache is 128 KiB total (2 · 4 · 4 · 16 · 64 · 4 bytes, or 32 KiB per layer), small enough to stay in CPU cache. We don't care.
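Option (1) in the list above can be sketched in NumPy as follows. This is an illustrative layout, not the Phase 22 implementation: the class name `PreallocatedKVCache` and its methods are hypothetical, and one instance is assumed per layer.

```python
import numpy as np


class PreallocatedKVCache:
    """One layer's K and V buffers, allocated once at S_max."""

    def __init__(self, B, H, S_max, d_h, dtype=np.float32):
        self.K = np.zeros((B, H, S_max, d_h), dtype=dtype)
        self.V = np.zeros((B, H, S_max, d_h), dtype=dtype)
        self.S = 0  # number of token positions currently filled

    def append(self, k_new, v_new):
        """Write one token's k, v of shape (B, H, d_h) into the next free row."""
        if self.S >= self.K.shape[2]:
            raise ValueError("cache is full: S has reached S_max")
        self.K[:, :, self.S, :] = k_new  # in-place write, no reallocation
        self.V[:, :, self.S, :] = v_new
        self.S += 1

    def view(self):
        """Return views (no copy) of the filled part, shape (B, H, S, d_h)."""
        return self.K[:, :, : self.S, :], self.V[:, :, : self.S, :]

    def nbytes(self):
        """Allocated bytes for this layer: 2 * B * H * S_max * d_h * s."""
        return self.K.nbytes + self.V.nbytes


# Grammar MiniGPT shape from the table: 4 layers, H=4, d_h=16, fp32, S_max=64.
layers = [PreallocatedKVCache(B=1, H=4, S_max=64, d_h=16) for _ in range(4)]
rng = np.random.default_rng(0)
for cache in layers:
    cache.append(rng.standard_normal((1, 4, 16)), rng.standard_normal((1, 4, 16)))
K, V = layers[0].view()
print(K.shape, sum(c.nbytes() for c in layers))
```

Each append touches one row per head and never moves existing data, which is what keeps the per-step cost independent of how the buffer was grown; `view()` returns slices of the same memory rather than copies.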

Two equations to internalize

Both follow trivially from the boxed formula above; both should be on the tip of your tongue.

1. Cache size doubles when you double context.

\[\text{bytes}(2S) = 2 \cdot \text{bytes}(S)\]

"Going from 4k to 8k context doubles the cache" is not an empirical observation; it's algebra. Anyone who says it casually as if it were a surprise has not internalized the formula.

2. Cache size at fixed bytes-budget gives a context ceiling.

\[S_\text{max} = \frac{\text{bytes}_\text{budget}}{2 \cdot L \cdot H \cdot d_h \cdot B \cdot s}\]

For Llama-2-7B fp16 on a 40 GB A100 with 14 GiB taken by weights, leaving 26 GiB for cache, batch 1:

\[S_\text{max} = \frac{26 \cdot 2^{30}}{2 \cdot 32 \cdot 32 \cdot 128 \cdot 1 \cdot 2} = \frac{27.9 \cdot 10^9}{524288} \approx 53200 \text{ tokens}\]

So a single A100 can serve Llama-2-7B at ~53k context, one user. Bump batch to 16, and \(S_\text{max}\) drops to ~3300 tokens. That's the exact tradeoff serving systems navigate.
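The context-ceiling calculation above is easy to script. A minimal sketch, assuming the same 26 GiB budget and Llama-2-7B shape; the function name `max_context` is hypothetical.

```python
def max_context(budget_bytes, L, H, d_h, B, s):
    """Largest S such that 2 * L * H * d_h * S * B * s fits within budget_bytes."""
    per_token = 2 * L * H * d_h * B * s
    return int(budget_bytes // per_token)  # floor: a partial token cannot be cached


budget = 26 * 2**30  # 26 GiB left for the cache after weights
for B in (1, 16):
    print(B, max_context(budget, L=32, H=32, d_h=128, B=B, s=2))
# prints:
# 1 53248
# 16 3328
```

The exact values (53248 and 3328) round to the ~53k and ~3300 figures quoted in the text.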

Drill problems

Solutions in solutions/02-memory-cost-ref.md (not visible during pre-write). Work these before lab.

  1. Llama-2-7B, fp16, batch=4, S=8192. Cache size in GiB?
  2. Mistral-7B uses GQA with \(H_{KV}=8\), otherwise same config as Llama-2-7B. Same situation (batch=4, S=8192): cache size in GiB?
  3. Quantize the cache to int8. Re-do (1). How does this affect accuracy (qualitative)? (Hint: the cached K is post-projection and post-rotary, and the cached V is post-projection; both are activations, not weights. Quantization noise compounds across layers.)
  4. Sliding-window attention keeps only the last \(W = 1024\) tokens of cache. Llama-2-7B fp16, batch=1, ctx=32k. Cache size?
  5. Reverse the formula. You have a 24 GiB GPU; 10 GiB are model weights and forward activations. Your model is GPT-2 small (config above), fp16, batch=8. What's the maximum context \(S_\text{max}\)?

If those five are mechanical for you, the formula has landed.

What this page does NOT cover

  • Why decode hits the memory ceiling on real hardware. The arithmetic-intensity argument is in theory/03-decode-as-memory-bound.md.
  • Paged cache memory layout. theory/04-toward-paged-attention.md previews; Phase 27 implements.
  • Int8 / fp16 cache numerics. Bytes formula scales linearly via \(s\); the accuracy impact is Phase 26 (quantization).
  • GPU memory-bandwidth ceilings (HBM vs SRAM vs registers). Phase 23. The size formula is hardware-independent; how fast you can read it isn't.

Next: theory/03-decode-as-memory-bound.md — the arithmetic intensity of the decode attention, and why decoding from cache is a memory bandwidth problem.