

01 — The Online Softmax Recurrence

The mathematical key to Flash is being able to compute softmax(s) @ V chunk by chunk, without having seen all of s first. You keep a running maximum and a running sum, and when a new chunk arrives, you rescale what came before by exp(m_old - m_new). It is one line of algebra, and everything else rests on it.


The classical softmax

For a vector s ∈ ℝ^N, the numerically stable softmax is:

m = max(s)
p = exp(s - m) / sum(exp(s - m))

Subtracting m keeps every exp argument ≤ 0, so nothing overflows. The naive version (just exp(s) / sum(exp(s))) overflows for any s_i > log(float_max) ≈ 88 (fp32) or > log(half_max) ≈ 11 (fp16).

For attention with Q, K of head dim d, a single pre-softmax score s_j = (Q[i, :] · K[j, :]) / √d can easily exceed 11, the fp16 limit. Stable softmax (i.e., subtracting the max) is mandatory, not optional, in fp16 attention.

The catch: the standard formulation requires the full vector s to compute m = max(s). That precludes streaming computation.
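A minimal two-pass sketch in plain Python makes the catch concrete: pass one has to see all of s to find m before pass two can exponentiate anything. It uses only the stdlib `math` module, and the function names are illustrative rather than taken from any library. Python floats are fp64, so the naive version overflows only once an argument exceeds about 709 (instead of 88 or 11), but the failure mode is the same.

```python
import math

def naive_softmax(s):
    # Exponentiate directly: math.exp raises OverflowError for arguments above ~709.78 (fp64).
    exps = [math.exp(x) for x in s]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(s):
    # Pass 1 must see every element of s before pass 2 can start.
    m = max(s)
    # Pass 2: every exp argument is <= 0, so nothing overflows.
    exps = [math.exp(x - m) for x in s]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1000.0, 1001.0]
print(stable_softmax(scores))  # roughly [0.269, 0.731]
try:
    naive_softmax(scores)
except OverflowError as err:
    print("naive softmax overflowed:", err)
```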

The setup: streaming attention

In Flash Attention, we don't have the full row of s = Q[i, :] @ K^T ∈ ℝ^N at once. We have it tile-by-tile: chunks s_1, s_2, ..., s_{N/B_c} each of length B_c. For each chunk s_k, we want to update a running O[i, :] ∈ ℝ^d (the partial attention output) such that, after all chunks are consumed, O[i, :] = softmax(s) @ V.

The question: can we update O correctly using only the current chunk and a small amount of running state?

Yes. Here's how.

The recurrence

Maintain three pieces of running state for output row i, initialized to m = -∞, ℓ = 0, O = 0 before the first chunk:

  • m ∈ ℝ — running maximum of s seen so far.
  • ℓ ∈ ℝ — running denominator sum(exp(s_seen - m)).
  • O ∈ ℝ^d — running unnormalized output sum_j exp(s_j - m) · V_j.

When chunk s_new ∈ ℝ^{B_c} (with corresponding V_new ∈ ℝ^{B_c × d}) arrives:

  1. New max: $$ m' = \max(m, \max(s_{\text{new}})) $$

  2. Rescale the old state to the new max: $$ \alpha = \exp(m - m') $$ The old ℓ and O were computed relative to the old m. To put them on the same footing as the new chunk (which we'll compute relative to m'), multiply both by α: $$ ℓ \leftarrow \alpha \cdot ℓ \qquad O \leftarrow \alpha \cdot O $$

  3. Add the new chunk's contribution: $$ p_{\text{new}} = \exp(s_{\text{new}} - m') \in \mathbb{R}^{B_c} $$ $$ ℓ \leftarrow ℓ + \sum p_{\text{new}} $$ $$ O \leftarrow O + p_{\text{new}}^{\top} V_{\text{new}} \quad \text{(vector-matrix product: a weighted sum of the } B_c \text{ rows of } V_{\text{new}}\text{)} $$

  4. Update m: $$ m \leftarrow m' $$

At the end, divide once:

$$ O_{\text{final}} = O / ℓ $$

That's the entire recurrence. Six lines of pseudo-code; O(N·d) work per output row with only O(d) running state; mathematically identical to the all-at-once softmax up to fp round-off.
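The four steps translate almost line for line into code. Below is a minimal single-row sketch in plain Python, assuming s is one row of scores (already scaled by 1/√d) and V is a list of N rows of length d. The name `online_softmax_attention_row` is hypothetical, and a real kernel would keep this state in fp32 registers on a GPU rather than in Python lists.

```python
import math

def online_softmax_attention_row(s, V, block_size):
    """Compute softmax(s) @ V for one query row, consuming s and V in chunks.

    s: list of N scores; V: list of N rows, each a list of d floats.
    """
    d = len(V[0])
    m = -math.inf   # running max
    ell = 0.0       # running denominator
    O = [0.0] * d   # running unnormalized output
    for start in range(0, len(s), block_size):
        s_new = s[start:start + block_size]
        V_new = V[start:start + block_size]
        # 1. New max.
        m_new = max(m, max(s_new))
        # 2. Rescale the old state; on the first chunk alpha = exp(-inf) = 0.
        alpha = math.exp(m - m_new)
        ell *= alpha
        O = [alpha * o for o in O]
        # 3. Add this chunk's contribution: p_new^T V_new, a weighted sum of rows.
        p_new = [math.exp(x - m_new) for x in s_new]
        ell += sum(p_new)
        for p, v_row in zip(p_new, V_new):
            for c in range(d):
                O[c] += p * v_row[c]
        # 4. Commit the new max.
        m = m_new
    # Divide once at the end.
    return [o / ell for o in O]
```

With `block_size = len(s)` this reduces to the ordinary stable softmax in a single chunk; smaller block sizes should give the same output up to round-off.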

Proof of correctness

Claim: after processing all chunks, O / ℓ = softmax(s) @ V.

Let s = [s_1, ..., s_K] be the concatenated chunks (each of length B_c). Let m_global = max(s). By construction, after all chunks, m = m_global (running max accumulates correctly).

After all chunks:

  • ℓ = sum_{j=1..N} exp(s_j - m_global).
  • O = sum_{j=1..N} exp(s_j - m_global) · V_j.

The all-at-once result is O_all = (sum_j exp(s_j - m_global) · V_j) / sum_j exp(s_j - m_global) = O / ℓ. Same value. □

The only subtle step is the rescaling. Suppose we've processed chunks 1..k and are about to process chunk k+1. Just before processing k+1, our state is:

  • m = max(s_1, ..., s_k). Call this m_k.
  • ℓ = sum_{j ∈ first k chunks} exp(s_j - m_k).
  • O = sum_{j ∈ first k chunks} exp(s_j - m_k) · V_j.

After observing chunk k+1, the true max is m_{k+1} = max(m_k, max(s_{k+1})). To rebase to the new max:

ℓ_rebased = sum_{j ∈ first k chunks} exp(s_j - m_{k+1})
          = sum_{j ∈ first k chunks} exp(s_j - m_k + m_k - m_{k+1})
          = exp(m_k - m_{k+1}) × sum_{j ∈ first k chunks} exp(s_j - m_k)
          = α × ℓ

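where α = exp(m_k - m_{k+1}) ≤ 1. The same algebra applies to O. Adding the chunk k+1 contributions (which are already computed relative to m_{k+1}) then restores the invariant for k+1 chunks. Starting from m = -∞, ℓ = 0, O = 0 (empty sums), induction over the chunks gives the expressions for ℓ and O used above. □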
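The induction can also be checked numerically. The hypothetical helper below (plain Python, fp64) runs the recurrence and, after every chunk, compares the running ℓ and O against a from-scratch computation over the prefix seen so far, relative to the current max. It is a test harness for the invariant, not something a kernel would do.

```python
import math

def check_invariant(s, V, block_size, tol=1e-12):
    """Hypothetical harness: after each chunk, assert that the running state
    equals the direct sums over the prefix, relative to the current max m_k."""
    d = len(V[0])
    m, ell, O = -math.inf, 0.0, [0.0] * d
    for start in range(0, len(s), block_size):
        end = start + block_size
        s_new, V_new = s[start:end], V[start:end]
        m_new = max(m, max(s_new))
        alpha = math.exp(m - m_new)
        p_new = [math.exp(x - m_new) for x in s_new]
        ell = alpha * ell + sum(p_new)
        O = [alpha * O[c] + sum(p * v[c] for p, v in zip(p_new, V_new))
             for c in range(d)]
        m = m_new
        # Invariant for the first k chunks: m = m_k, and ell, O are direct sums.
        prefix_s, prefix_V = s[:end], V[:end]
        assert m == max(prefix_s)
        direct_ell = sum(math.exp(x - m) for x in prefix_s)
        direct_O = [sum(math.exp(x - m) * v[c] for x, v in zip(prefix_s, prefix_V))
                    for c in range(d)]
        assert abs(ell - direct_ell) <= tol * direct_ell
        assert all(abs(a - b) <= tol * max(1.0, abs(b)) for a, b in zip(O, direct_O))
    return True
```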

Numerical properties

Three observations:

  1. No overflow. Every exp argument is ≤ 0 by construction (we always subtract the current or new max before exponentiating). Safe in fp16.
  2. No catastrophic cancellation in the rescaling. α = exp(m_k - m_{k+1}) ≤ 1 is bounded; multiplying ℓ and O by it can underflow to zero only if m jumps by a lot: roughly 104 in fp32 or 17 in fp16 (α already drops into the subnormal range past a jump of about 87 in fp32 or 10 in fp16). For attention values, this is rare but possible; the Flash paper's stability discussion handles it via fp32 accumulators for m, ℓ, O.
  3. Order of chunks doesn't matter. The final result is invariant to chunk ordering (modulo round-off). This is what makes Flash work inside a tiled kernel: tiles can be processed in any order the scheduler prefers.
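Observations 2 and 3 are easy to poke at with a small sketch. The hypothetical `online_row` below takes a list of (scores, V rows) chunks and consumes them in whatever order it is given. In fp64, exp underflows to 0.0 once the max jumps by roughly 745, which the last call triggers on purpose. The answer is still right because the discarded old terms are negligible next to the new chunk's exp(0) = 1.

```python
import math

def online_row(chunks):
    """Hypothetical helper: chunks is a list of (scores, V_rows) pairs, consumed in order."""
    d = len(chunks[0][1][0])
    m, ell, O = -math.inf, 0.0, [0.0] * d
    for s_new, V_new in chunks:
        m_new = max(m, max(s_new))
        alpha = math.exp(m - m_new)  # may underflow to 0.0 after a large jump
        p_new = [math.exp(x - m_new) for x in s_new]
        ell = alpha * ell + sum(p_new)
        O = [alpha * O[c] + sum(p * v[c] for p, v in zip(p_new, V_new))
             for c in range(d)]
        m = m_new
    return [o / ell for o in O]

chunks = [([1.0, 2.0], [[1.0], [2.0]]), ([3.0, 10.0], [[3.0], [4.0]])]
forward = online_row(chunks)
backward = online_row(chunks[::-1])  # same chunks, reverse order

# A jump of 1000 makes alpha = exp(-1000) underflow to 0.0 in fp64, yet the
# result is still the V row of the dominant score, as exact softmax says.
print(online_row([([0.0], [[5.0]]), ([1000.0], [[7.0]])]))  # [7.0]
```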

A worked example

Vector s = [1, 2, 3, 10], V = [[1], [1], [1], [1]] (so d = 1 and each chunk's V is a 2×1 matrix; since every V entry is 1, the answer should come out to softmax(s) @ V = 1.0).

Process in two chunks: s_1 = [1, 2], s_2 = [3, 10].

After chunk 1:

  • m = 2
  • p_1 = exp([1-2, 2-2]) = [exp(-1), 1] ≈ [0.368, 1]
  • ℓ = 1.368
  • O = 0.368 + 1 = 1.368

After chunk 2:

  • max(s_2) = 10, so m' = max(2, 10) = 10.
  • α = exp(2 - 10) = exp(-8) ≈ 3.35e-4
  • Rescale: ℓ = O = 1.368 × 3.35e-4 ≈ 0.000459.
  • p_2 = exp([3-10, 10-10]) = [exp(-7), 1] ≈ [9.12e-4, 1]
  • ℓ = 0.000459 + 0.000912 + 1 ≈ 1.00137
  • O = 0.000459 + 0.000912 + 1 ≈ 1.00137

Final: O / ℓ = 1.00137 / 1.00137 = 1.0. ✓

Compare to all-at-once: softmax([1, 2, 3, 10]) ≈ [1.23e-4, 3.35e-4, 9.11e-4, 0.9986]. Dot with [1, 1, 1, 1] = 1.0. Same answer.
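A short trace in plain Python (fp64, with d = 1 so O is a scalar) reproduces the numbers above. The variable names mirror the text and are otherwise arbitrary; note that α for the first chunk is exp(-∞) = 0 because m starts at -∞.

```python
import math

s_chunks = [[1.0, 2.0], [3.0, 10.0]]
V_chunks = [[1.0, 1.0], [1.0, 1.0]]  # d = 1, so each V row is a single number

m, ell, O = -math.inf, 0.0, 0.0
trace = []
for s_new, v_new in zip(s_chunks, V_chunks):
    m_new = max(m, max(s_new))
    alpha = math.exp(m - m_new)
    ell, O = alpha * ell, alpha * O  # rescale the old state to the new max
    p_new = [math.exp(x - m_new) for x in s_new]
    ell += sum(p_new)
    O += sum(p * v for p, v in zip(p_new, v_new))
    m = m_new
    trace.append((m, alpha, ell, O))

for m_k, alpha_k, ell_k, O_k in trace:
    print(f"m={m_k:g}  alpha={alpha_k:.3e}  ell={ell_k:.5f}  O={O_k:.5f}")
print("O / ell =", O / ell)
```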

What this enables

Once we have the online softmax recurrence, the rest of Flash is "just" tiling. We process the attention matrix S = QK^T one tile at a time, never materializing the whole thing in HBM. For each output row of O, we maintain a tiny amount of state (m, ℓ, O) and update it as K, V tiles flow in.
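To make the per-row bookkeeping concrete, here is a deliberately naive sketch in plain Python, not the Flash kernel: every query row stays in memory and only K and V are consumed in tiles of `block_size` rows. The hypothetical `tiled_attention` never forms a full row of S, only up to B_c scores per row per tile, and it keeps (m, ℓ, O) for each row exactly as in the recurrence.

```python
import math

def tiled_attention(Q, K, V, block_size):
    """Hypothetical sketch. Q: n_q x d, K: N x d, V: N x d_v (lists of lists).

    Keeps per-row (m, ell, O) state and updates it as K, V tiles arrive.
    """
    d, d_v = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    m = [-math.inf] * len(Q)
    ell = [0.0] * len(Q)
    O = [[0.0] * d_v for _ in Q]
    for start in range(0, len(K), block_size):
        K_tile = K[start:start + block_size]
        V_tile = V[start:start + block_size]
        for i, q in enumerate(Q):
            # Scores of row i against this tile only.
            s_new = [scale * sum(a * b for a, b in zip(q, k)) for k in K_tile]
            m_new = max(m[i], max(s_new))
            alpha = math.exp(m[i] - m_new)
            p_new = [math.exp(x - m_new) for x in s_new]
            ell[i] = alpha * ell[i] + sum(p_new)
            O[i] = [alpha * O[i][c] + sum(p * v[c] for p, v in zip(p_new, V_tile))
                    for c in range(d_v)]
            m[i] = m_new
    # One division per row at the very end.
    return [[o / ell[i] for o in O[i]] for i in range(len(Q))]
```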

The recurrence is the mathematical reason Flash is correct. The tiling (next theory file) is the algorithmic reason it's fast.

Drill problems

Solutions at phase open in solutions/01-online-softmax-ref.md. Try without code.

  1. Compute the online softmax on s = [0, 5] with V = [[2], [3]] chunked as s_1 = [0], s_2 = [5]. Show all four steps.
  2. The recurrence requires α = exp(m_k - m_{k+1}). Under what conditions does α underflow to 0 in fp16? What goes wrong if it does? (Hint: nothing goes wrong; the math is still correct in the limit.)
  3. Show that processing chunks in reverse order produces the same final O/ℓ (modulo fp round-off). Sketch the argument; don't simulate.
  4. Suppose you parallelize the recurrence across P threads, each handling 1/P of the chunks, then reduce at the end. What's the reduction operation? Show it's associative.

One-paragraph recap

The online softmax recurrence maintains a running max m, running denominator ℓ, and running unnormalized output O while processing chunks of s and V one at a time. When a new chunk arrives, rescale the old ℓ, O by α = exp(m_old - m_new) to align them with the new max, then add the new chunk's contributions. The final answer is O / ℓ. The algebra is one line, the computation is O(N·d) per output row without ever materializing the full softmax, and the result is exact up to floating-point round-off. This is the mathematical key that makes Flash Attention possible: without it, you couldn't process attention in tiles.

Next: theory/02-flash-attention.md.