
01 — Prefill vs Decode: Two Phases of One Forward

Autoregressive inference = one initial parallel pass (prefill) over the whole prompt, followed by many serial passes (decode), one per new token. The asymmetry between the two is what organizes the rest of LLM-serving infrastructure.

This page formalizes what "prefill" and "decode" actually are at the operator level — same model, different compute shape. Once the shape is named, the costs follow algebraically.


Setup

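Fix a causal transformer with \(L\) layers, \(H\) attention heads per layer, head dim \(d_h\), model dim \(d = H \cdot d_h\), hidden FFN dim \(d_\text{ffn} \approx 4d\), vocab size \(V\), batch size \(B\), and \(s\) bytes per stored element (\(s = 2\) for fp16). FLOP counts on this page count one multiply-accumulate as one FLOP.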

For the grammar MiniGPT (§A13 corpus): \(L = 4\), \(H = 4\), \(d_h = 16\), \(d = 64\), \(d_\text{ffn} = 256\), \(V \approx 600\), \(B = 1\). Trivial numbers — every formula on this page can be evaluated on the back of a napkin.

A user submits a prompt of length \(P\). The system must produce \(D\) new tokens. Total sequence at the end: \(S = P + D\). For the running example "Yesterday I worked and he": \(P = 2\) (prefill "Yesterday I"), \(D = 3\) ("worked", "and", "he"), \(S = 5\).
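The pass structure of the running example can be traced without a model. A minimal sketch, assuming (as in the pseudocode later on this page) that the first new token is sampled from the prefill output; `pass_schedule` is a hypothetical helper and the token strings stand in for real model outputs.

```python
def pass_schedule(prompt, generated):
    """List of (phase, tokens fed in, cache length after the pass, token sampled)."""
    passes = []
    cache_len = len(prompt)
    # Prefill: the whole prompt in one parallel pass; it also yields the first new token.
    passes.append(("prefill", list(prompt), cache_len, generated[0]))
    # Decode: each pass feeds back the previously sampled token and grows the cache by one.
    for prev, nxt in zip(generated, generated[1:]):
        cache_len += 1
        passes.append(("decode", [prev], cache_len, nxt))
    return passes


prompt = ["Yesterday", "I"]           # P = 2
generated = ["worked", "and", "he"]   # D = 3
for phase, fed, cache_len, out in pass_schedule(prompt, generated):
    print(f"{phase:8} in={fed} cache_len={cache_len} -> {out!r}")
```

The last sampled token ("he") is never fed back, so the cache ends at length 4 even though \(S = 5\) counts every token.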

Prefill phase

Input: \(P\) tokens (the prompt). Output: the hidden state for the last prompt token (used to sample the first new token) and a populated KV cache of length \(P\) at every layer.

Per layer, prefill computes:

  1. \(X \in \mathbb{R}^{P \times d}\) — embedded inputs.
  2. \(Q = X W_Q\), \(K = X W_K\), \(V = X W_V\). Each is \(P \times d\). Linear projections: \(P \cdot d^2\) FLOPs each, three times. Total: \(3 P d^2\) FLOPs.
  3. Attention: \(A = \text{softmax}(QK^\top / \sqrt{d_h} + M) V\), where \(M\) is the causal mask. The \(QK^\top\) is \(P \times P\) per head; \(P \cdot d_h \cdot P\) FLOPs per head; \(H\) heads \(\to\) \(P^2 \cdot d\) FLOPs. The \(\cdot V\) is another \(P^2 \cdot d\). Total attention: \(2 P^2 d\) FLOPs.
  4. Output projection \(A W_O\): \(P \cdot d^2\) FLOPs.
  5. FFN: two matmuls, \((P \times d)(d \times d_\text{ffn})\) followed by \((P \times d_\text{ffn})(d_\text{ffn} \times d)\). With \(d_\text{ffn} = 4d\): \(8 P d^2\) FLOPs.

Per-layer total: \(3 P d^2 + 2 P^2 d + P d^2 + 8 P d^2 = 12 P d^2 + 2 P^2 d\).

Across \(L\) layers: \(\boxed{F_\text{prefill} \approx 12 L P d^2 + 2 L P^2 d}\).
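The per-layer sum can be checked mechanically. A minimal sketch that counts each step above separately (one multiply-accumulate = one FLOP, as in the step counts; softmax and scaling ignored) and compares the result with the boxed closed form; the function names are hypothetical.

```python
def prefill_layer_flops(P, d, d_ffn=None):
    """Per-layer prefill cost, counting one multiply-accumulate as one FLOP."""
    d_ffn = 4 * d if d_ffn is None else d_ffn
    qkv = 3 * P * d * d          # step 2: three (P x d)(d x d) projections
    attn = 2 * P * P * d         # step 3: QK^T plus (softmax)V, summed over all H heads
    out_proj = P * d * d         # step 4: A W_O
    ffn = 2 * P * d * d_ffn      # step 5: (P x d)(d x d_ffn), then (P x d_ffn)(d_ffn x d)
    return qkv + attn + out_proj + ffn


def prefill_flops(L, P, d):
    """Closed form from the boxed equation: 12 L P d^2 + 2 L P^2 d."""
    return 12 * L * P * d**2 + 2 * L * P**2 * d


# Grammar MiniGPT: L = 4, d = 64, prompt "Yesterday I" (P = 2).
L, d, P = 4, 64, 2
assert L * prefill_layer_flops(P, d) == prefill_flops(L, P, d)
print(prefill_flops(L, P, d))   # 395264
```

For the grammar MiniGPT this is 393,216 FLOPs from the linear term plus 2,048 from attention.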

Two regimes (the two terms are equal at \(P = 6d\)):

  • Short prompt (\(P \ll 6d\)): dominated by \(12 L P d^2\), linear in \(P\) and quadratic in \(d\). Matmul-bound.
  • Long prompt (\(P \gtrsim 6d\)): the \(2 L P^2 d\) attention term dominates. Quadratic in \(P\).

For Llama-2-7B (\(L=32\), \(d=4096\)): the two terms cross at \(P = 6d = 24576\). Below that, the linear layers (QKV and output projections plus the FFN) are most of the cost; above it, attention is. For a 4096-token prompt (\(P = d\)), the split is \(12 : 2\): linear layers ~86%, attention ~14%. For our grammar MiniGPT (\(L=4\), \(d=64\), typical \(P=2\)): \(12 \cdot 4 \cdot 2 \cdot 64^2 \approx 393\)K FLOPs, i.e. nothing. The prefill of "Yesterday I" takes microseconds. Getting the cache machinery right still matters: the same code runs at 175 B parameters.
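The crossover and the split for Llama-2-7B follow from the ratio of the two terms, \(2P^2 d / (12 P d^2 + 2 P^2 d) = P / (6d + P)\). A sketch under this page's simplified cost model (\(d_\text{ffn} = 4d\)), not Llama-2's exact layer shapes; the helper names are hypothetical.

```python
def attention_share(P, d):
    """Fraction of prefill FLOPs in the attention term: P / (6d + P)."""
    return P / (6 * d + P)


def crossover_prompt_length(d):
    """Prompt length where 12 L P d^2 == 2 L P^2 d, i.e. P = 6d (L cancels)."""
    return 6 * d


d_llama = 4096  # Llama-2-7B model dim
print(crossover_prompt_length(d_llama))   # 24576
share = attention_share(4096, d_llama)
print(f"attention {share:.1%}, linear layers {1 - share:.1%}")  # attention 14.3%, linear layers 85.7%
```

Note that \(L\) drops out entirely: the split depends only on \(P/d\).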

Memory traffic. Prefill reads the weights (\(\sim 12 L d^2 \cdot s\) bytes) once and reads \(X\) once. The activations \(Q, K, V, A\) are \(P \times d\) each and either stay in on-chip memory or spill to DRAM. Arithmetic intensity = \(F_\text{prefill}\) / bytes_moved \(\approx P / s\), so it grows with \(P\). Prefill is compute-bound for any non-trivial prompt.

Decode phase

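Input at step \(t\): one new token, plus the cache populated by prefill and by decode steps \(1..t-1\). Current cache length: \(S = P + t - 1\) (from here on, \(S\) denotes the current cache length, not the final sequence length). Output: the hidden state of the new token, and the cache extended to length \(S+1\).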

Per layer, per decode step, we compute:

  1. \(x \in \mathbb{R}^{1 \times d}\) — the embedded new token.
  2. \(q = x W_Q\), \(k_\text{new} = x W_K\), \(v_\text{new} = x W_V\). Each is \(1 \times d\). Cost: \(3 d^2\) FLOPs per layer.
  3. Append \(k_\text{new}\) and \(v_\text{new}\) to the cache. The cache becomes \(K \in \mathbb{R}^{(S+1) \times d}\), \(V \in \mathbb{R}^{(S+1) \times d}\).
  4. Attention: \(a = \text{softmax}(q K^\top / \sqrt{d_h}) V\). Here \(q\) is \(1 \times d_h\) per head; \(K\) is \((S+1) \times d_h\) per head. The \(qK^\top\) is \(1 \times (S+1)\) per head; \(d_h \cdot (S+1)\) FLOPs per head; \(H\) heads \(\to (S+1) \cdot d\) FLOPs. The \(\cdot V\) adds another \((S+1) \cdot d\). No mask needed — the query is length 1, cache contains only past tokens (and the just-appended current one, which is fine: it's the diagonal of the causal mask). Total attention: \(2(S+1) d\) FLOPs.
  5. Output projection: \(d^2\) FLOPs.
  6. FFN: \(8 d^2\) FLOPs.
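The "no mask needed" argument in step 4 can be checked numerically: appending each new \(k, v\) and attending with a single query row reproduces causal attention row by row. A minimal pure-Python sketch for one head with random vectors (\(S = 5\), \(d_h = 4\)); the causal mask is applied by slicing instead of adding \(-\infty\) entries, which is equivalent. All names are hypothetical.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(q, K, V):
    """One query row against rows K, V: softmax(q K^T / sqrt(d_h)) V, no mask."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([dot(q, k) * scale for k in K])
    return [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]


def causal_attention(Q, K, V):
    """Prefill shape: row i may see rows 0..i only (the causal mask, applied by slicing)."""
    return [attend(Q[i], K[: i + 1], V[: i + 1]) for i in range(len(Q))]


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
S, d_h = 5, 4                      # five positions, one head of width 4
Q, K, V = rand_matrix(S, d_h), rand_matrix(S, d_h), rand_matrix(S, d_h)

# Decode shape: append this position's k, v to the cache, then attend with one query row.
K_cache, V_cache, decoded = [], [], []
for t in range(S):
    K_cache.append(K[t])
    V_cache.append(V[t])
    decoded.append(attend(Q[t], K_cache, V_cache))

full = causal_attention(Q, K, V)
print(decoded == full)             # True: same rows, same arithmetic
```

Because the decode path performs the same arithmetic on the same rows, the outputs match exactly, not just approximately.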

Per-layer per-step: \(12 d^2 + 2(S+1) d\).

Across \(L\) layers, per decode step: \(\boxed{F_\text{decode-step} \approx 12 L d^2 + 2 L S d}\).

Summing over \(D\) decode steps, with \(S\) growing from \(P\) to \(P+D-1\):

\[F_\text{decode-total} = \sum_{t=0}^{D-1} \left[ 12 L d^2 + 2 L (P+t) d \right] = 12 L D d^2 + 2 L d \left[ DP + \frac{D(D-1)}{2} \right]\]
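The closed form can be checked against a direct sum of the per-step cost. A sketch with hypothetical helper names, evaluated on the running example (\(L = 4\), \(d = 64\), \(P = 2\), \(D = 3\)); integer arithmetic keeps the comparison exact.

```python
def decode_step_flops(L, d, S):
    """One decode step with cache length S: 12 L d^2 + 2 L S d."""
    return 12 * L * d**2 + 2 * L * S * d


def decode_total_loop(L, d, P, D):
    """Sum the per-step cost while the cache length runs from P to P + D - 1."""
    return sum(decode_step_flops(L, d, P + t) for t in range(D))


def decode_total_closed(L, d, P, D):
    """Closed form: 12 L D d^2 + 2 L d [D P + D (D - 1) / 2]."""
    return 12 * L * D * d**2 + 2 * L * d * (D * P + D * (D - 1) // 2)


L, d, P, D = 4, 64, 2, 3          # grammar MiniGPT, running example
assert decode_total_loop(L, d, P, D) == decode_total_closed(L, d, P, D)
print(decode_total_closed(L, d, P, D))   # 594432
```

`D * (D - 1)` is always even, so the integer division is exact.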

For long generation (\(D \gg P\), \(D \gg d\)): the \(L D^2 d\) term dominates. Quadratic in tokens generated, even with the cache. The cache turned cubic (\(\Theta(D^3)\) without cache) into quadratic (\(\Theta(D^2)\) with cache).

The no-cache disaster (sanity check)

Without the cache, each decode step re-runs prefill on all \(P + t\) tokens so far. Step \(t\) costs \(F_\text{prefill}\) evaluated at prompt length \(P+t\), i.e. \(12 L (P+t) d^2 + 2 L (P+t)^2 d\). Summing:

\[F_\text{no-cache} \approx \sum_{t=0}^{D-1} \left[ 12 L (P+t) d^2 + 2 L (P+t)^2 d \right] = O\big(L D (P+D) d^2\big) + O\big(L D (P+D)^2 d\big)\]

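That's \(\Theta(D^3)\) in the long-generation regime (\(D \gg P\), \(D \gg d\)): cubic in generated tokens. The cache buys you a factor on the order of \(D\); comparing the sums term by term, the ratio lies between roughly \(D/2\) (linear layers dominant) and \(2D/3\) (attention dominant). For \(D = 1000\), that is 500–670×. This is not a minor optimization; this is the difference between "usable inference" and "no inference".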
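The factor-of-\(D\) claim can be checked by summing both per-step costs directly. A sketch using the two per-step formulas on this page with a Llama-2-7B-like shape (\(L = 32\), \(d = 4096\)) and \(P = 2\); the helper names are hypothetical.

```python
def cached_total(L, d, P, D):
    """D decode steps with a KV cache: sum of 12 L d^2 + 2 L (P + t) d."""
    return sum(12 * L * d**2 + 2 * L * (P + t) * d for t in range(D))


def uncached_total(L, d, P, D):
    """D steps that each re-run prefill on P + t tokens: 12 L (P+t) d^2 + 2 L (P+t)^2 d."""
    return sum(12 * L * (P + t) * d**2 + 2 * L * (P + t) ** 2 * d for t in range(D))


L, d, P = 32, 4096, 2               # Llama-2-7B-like shape, short prompt
for D in (10, 100, 1000):
    ratio = uncached_total(L, d, P, D) / cached_total(L, d, P, D)
    print(f"D = {D:5d}: no-cache / cache = {ratio:.0f}x")
```

For \(D = 1000\) the printed ratio is about 505×, close to \(D/2\): with \(d = 4096\), the linear-layer terms still dominate both totals, and their ratio is about \(D/2\).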

Memory traffic asymmetry

This is the most important table in Phase 22. Memorize it.

| Quantity | Prefill | Decode (per step) |
|---|---|---|
| FLOPs | \(12 L P d^2 + 2 L P^2 d\) | \(12 L d^2 + 2 L S d\) |
| Bytes moved (weights) | \(\sim 12 L d^2 \cdot s\) (once) | \(\sim 12 L d^2 \cdot s\) (every step!) |
| Bytes moved (cache) | \(\sim 2 L P d \cdot s\) written; nothing to read (cache starts empty) | \(\sim 2 L S d \cdot s\) read per step |
| Arithmetic intensity | \(\sim P / s\) (grows with prompt) | \(\sim 1/s\), i.e. ~0.5 FLOPs/byte at fp16 |
| Bottleneck | Compute (FLOPS) | Memory bandwidth (BW) |
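The intensity row can be reproduced from the FLOP and byte rows. A sketch under the table's assumptions (multiply-accumulate counting, \(s\) bytes per stored element, weight reads plus K/V cache traffic, activation traffic ignored); the helper names are hypothetical.

```python
def prefill_intensity(L, d, P, s):
    """FLOPs per byte for prefill: weights read once, K/V written once."""
    flops = 12 * L * P * d**2 + 2 * L * P**2 * d
    bytes_moved = 12 * L * d**2 * s + 2 * L * P * d * s
    return flops / bytes_moved


def decode_intensity(L, d, S, s):
    """FLOPs per byte for one decode step: all weights plus the whole cache are read."""
    flops = 12 * L * d**2 + 2 * L * S * d
    bytes_moved = 12 * L * d**2 * s + 2 * L * S * d * s
    return flops / bytes_moved


L, d, s = 32, 4096, 2            # Llama-2-7B shape, fp16 (s = 2 bytes per element)
print(prefill_intensity(L, d, P=4096, s=s))   # 2048.0, i.e. P / s
print(decode_intensity(L, d, S=4096, s=s))    # 0.5, i.e. 1 / s, whatever S is
```

Both ratios come out exact: prefill's numerator and denominator share the factor \(12d + 2P\) and decode's share \(12d + 2S\), leaving \(P/s\) and \(1/s\).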

Look at the weights row. Decode re-reads every weight matrix at every step, yet does only \(O(d^2)\) FLOPs per layer against them (not \(O(P d^2)\)). That pins decode's arithmetic intensity for the weight-bound layers at about one FLOP per weight loaded, i.e. \(\sim 1/s\) FLOPs per byte (~0.5 at fp16), independent of model size or context length.

This is why a 70 B-parameter model may generate only around 10 tokens/sec on A100-class hardware while each GPU's nominal 312 TFLOPS sit ~99% idle: memory bandwidth, not compute, is the rate limiter. (The grammar MiniGPT decoding "Yesterday I worked" is also memory-bound, but at a scale where you cannot see the bottleneck; its caches all fit in L1. The DoD's cost-curve experiment forces you to extrapolate to a regime where it would matter.)

It's also why batching helps decode dramatically. If 16 users are decoding concurrently against the same model, the weight read is amortized 16-fold: the GPU reads each weight matrix once per step and applies it to 16 queries. Arithmetic intensity for the weight-bound layers goes from \(\sim 1/s\) to \(\sim 16/s\). Each user's KV-cache reads are not shared, so that part of the traffic does not amortize. This is the entire business model of LLM-serving infrastructure (Phase 28).
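A sketch of how batching changes decode intensity under this page's cost model (multiply-accumulate counting, \(s\) bytes per element, activations ignored); `batched_decode_intensity` is a hypothetical helper. Weights are read once per step for the whole batch, but each sequence reads its own cache.

```python
def batched_decode_intensity(B, L, d, S, s):
    """FLOPs per byte for one decode step over B sequences, each with cache length S.

    Weights (12 L d^2 elements) are read once and shared by all B queries.
    Each sequence reads its own K/V cache (2 L S d elements), so that traffic scales with B.
    """
    flops = B * (12 * L * d**2 + 2 * L * S * d)
    bytes_moved = s * (12 * L * d**2 + B * 2 * L * S * d)
    return flops / bytes_moved


L, d, s = 32, 4096, 2          # Llama-2-7B-like shape, fp16
for B in (1, 16, 64):
    for S in (128, 4096):
        value = batched_decode_intensity(B, L, d, S, s)
        print(f"B={B:3d}  S={S:5d}  intensity={value:6.2f} FLOPs/byte")
```

With a short cache, \(B = 16\) gets close to the \(B/s = 8\) ceiling; at \(S = 4096\) the unshared cache reads pull it well below that.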

What the dichotomy tells you about every optimization you'll meet

Use this as a decoder ring:

  • Flash-Attention: restructures the prefill attention to avoid materializing the \(P \times P\) matrix. Wins for prefill on long \(P\). Doesn't help decode (no \(P \times P\) to avoid).
  • Flash-Decoding: restructures the decode attention to better parallelize the row-against-cache step. Wins for decode at long \(S\). Doesn't help prefill.
  • PagedAttention: restructures the cache layout to handle variable-length batched serving. Affects memory layout in decode. Wins when serving many users with different sequence lengths.
  • Continuous batching: decouples per-user prefill from a shared decode batch. Reorganizes the schedule, not the math.
  • GQA / MQA: reduces the number of K/V heads while keeping \(H\) query heads (MQA shares one K/V head across all query heads; GQA shares one per group). Linearly shrinks the cache. Linearly improves decode (less to read), barely touches prefill (compute-bound there anyway).
  • Quantization (int8 cache): halves \(s\) for the cached K and V (fp16 to int8). Linearly shrinks cache memory and decode's cache traffic. Barely touches prefill, which is compute-bound; cutting weight bytes requires quantizing the weights themselves.
  • Speculative decoding: does multiple decode steps "for free" by validating cheap-model guesses in parallel. Reorganizes decode to look more like prefill (parallel rather than serial).
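GQA/MQA and cache quantization both scale the cache linearly. A quick sketch using the table's cache size \(2 L S d \cdot s\), with \(d\) split into \(n_\text{kv} \cdot d_h\) so the number of K/V heads can shrink; `kv_cache_bytes` is a hypothetical helper. Llama-2-7B itself uses full multi-head attention (32 K/V heads), so the 8-head row below is a hypothetical GQA variant of the same shape.

```python
def kv_cache_bytes(L, S, n_kv_heads, d_h, s, B=1):
    """K and V for every layer, position and K/V head: 2 * L * S * n_kv_heads * d_h * s, times B."""
    return 2 * L * S * n_kv_heads * d_h * s * B


GiB = 2**30
L, S, H, d_h = 32, 4096, 32, 128    # Llama-2-7B-like shape (d = H * d_h = 4096)
print(kv_cache_bytes(L, S, H, d_h, s=2) / GiB)   # 2.0  full multi-head attention, fp16
print(kv_cache_bytes(L, S, 8, d_h, s=2) / GiB)   # 0.5  hypothetical GQA with 8 K/V heads
print(kv_cache_bytes(L, S, H, d_h, s=1) / GiB)   # 1.0  int8 cache
```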

Each of these targets a symptom, and every symptom is an entry in the table above.

Pseudo-pseudocode for Phase 22

```python
def generate(prompt, max_new_tokens):
    # Sketch only: embed, layers, sample and KVCache are assumed helpers.
    # Residual connections and layer norms are omitted for brevity.

    # Prefill: one parallel pass over the whole prompt.
    cache = KVCache.allocate(...)
    h = embed(prompt)                                    # shape: (P, d)
    for layer_idx, layer in enumerate(layers):
        q, k, v = layer.qkv(h)
        cache.append(layer_idx, k, v)                    # cache filled to length P
        h = layer.attention(q, k, v, causal_mask=True)
        h = layer.ffn(h)
    next_token = sample(h[-1])                           # h[-1] is the last prompt token
    yield next_token                                     # the first new token comes from prefill

    # Decode: one new token at a time.
    for _ in range(max_new_tokens - 1):
        h = embed(next_token).unsqueeze(0)               # shape: (1, d)
        for layer_idx, layer in enumerate(layers):
            q, k_new, v_new = layer.qkv(h)
            cache.append(layer_idx, k_new, v_new)        # cache grows by 1
            K, V = cache.read(layer_idx)                 # the full cached K, V
            h = layer.attention(q, K, V, causal_mask=False)  # q has length 1; no mask needed
            h = layer.ffn(h)
        next_token = sample(h[0])
        yield next_token
```

This is essentially what Borja will implement in lab/01-implement-cache.md. The two per-layer loops run the same layer code; they differ in the shape of \(Q\) (\(P\) rows vs 1 row) and in where attention reads \(K, V\) from (the fresh projections vs the cache). The cache is what makes that shape change work without quadratic-per-step cost.
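A minimal sketch of a cache with the `append` / `read` interface the pseudocode assumes. The class, its list-based storage and the `num_layers` argument to `allocate` are all hypothetical; it only makes the growth pattern (P rows at prefill, one row per decode step) visible.

```python
class KVCache:
    """Hypothetical list-backed KV cache: one list of K rows and one of V rows per layer."""

    def __init__(self, num_layers):
        self.k = [[] for _ in range(num_layers)]
        self.v = [[] for _ in range(num_layers)]

    @classmethod
    def allocate(cls, num_layers):
        return cls(num_layers)

    def append(self, layer_idx, k_rows, v_rows):
        # Prefill appends P rows at once; each decode step appends exactly one.
        assert len(k_rows) == len(v_rows)
        self.k[layer_idx].extend(k_rows)
        self.v[layer_idx].extend(v_rows)

    def read(self, layer_idx):
        # Returns the live lists (no copy): every row cached so far for this layer.
        return self.k[layer_idx], self.v[layer_idx]

    def length(self, layer_idx=0):
        return len(self.k[layer_idx])


# Running example on the grammar MiniGPT shape: L = 4 layers, d = 64, placeholder values.
num_layers, d = 4, 64
cache = KVCache.allocate(num_layers)
for layer_idx in range(num_layers):           # prefill "Yesterday I": P = 2 rows per layer
    rows = [[0.0] * d for _ in range(2)]
    cache.append(layer_idx, rows, [[0.0] * d for _ in range(2)])
for _ in range(2):                            # two decode steps: one row per layer each
    for layer_idx in range(num_layers):
        cache.append(layer_idx, [[0.0] * d], [[0.0] * d])
print(cache.length())                         # 4
```

A production cache would preallocate fixed-size buffers (or, with PagedAttention, fixed-size blocks) instead of growing Python lists.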

What this page does NOT cover

  • Bytes-of-cache derivation. Sketched in the table; full derivation in theory/02-memory-cost.md.
  • Arithmetic-intensity argument for decode. Asserted as ~0.5 FLOPs/byte (at fp16); derived in theory/03-decode-as-memory-bound.md.
  • GPU specifics. \(F\) and bytes formulas are hardware-independent. Phase 23 maps them onto SMs, warps, HBM.
  • Backward pass / training. Cache is decode-only. Training uses full prefill-shape attention every step.

Next: theory/02-memory-cost.md — derive the bytes formula and apply it to grammar MiniGPT, Llama-2-7B, and GPT-3.