01 - The Online Softmax Recurrence
The mathematical key to Flash: being able to compute softmax(s) @ V in chunks without having seen all of s beforehand. You maintain the running max and the running sum, and when a new chunk arrives you rescale the earlier state by exp(m_old - m_new). One line of algebra; everything else follows from it.
The classical softmax
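For a vector s ∈ ℝ^N, the numerically stable softmax is:
$$ m = \max_j s_j, \qquad \text{softmax}(s)_i = \frac{\exp(s_i - m)}{\sum_{j=1}^{N} \exp(s_j - m)} $$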
Subtracting m = max(s) keeps every exp argument ≤ 0, so nothing overflows. The unstable version, exp(s) / sum(exp(s)), overflows for any s_i > log(float_max) ≈ 88 (fp32) or s_i > log(half_max) ≈ 11 (fp16).
For attention with Q, K of head dim d and a fixed query row i, a single pre-softmax value s_j = (Q[i, :] · K[j, :]) / √d can easily exceed 11. Stable softmax (i.e., subtracting the max) is therefore mandatory, not optional, in fp16 attention.
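The thresholds quoted above can be checked with a few lines of plain Python. This is a minimal sketch: the constants and the `naive_softmax` / `stable_softmax` names are illustrative, and because Python floats are fp64 the naive version overflows near 709 here rather than near 88 or 11.

```python
import math

FP32_MAX = 3.4028234663852886e38  # largest finite float32 value
FP16_MAX = 65504.0                # largest finite float16 value

# exp(x) overflows a format once x exceeds log(largest finite value).
fp32_limit = math.log(FP32_MAX)   # about 88.7
fp16_limit = math.log(FP16_MAX)   # about 11.1


def naive_softmax(s):
    """exp(s_i) / sum_j exp(s_j), with no max subtraction."""
    e = [math.exp(x) for x in s]
    total = sum(e)
    return [x / total for x in e]


def stable_softmax(s):
    """exp(s_i - m) / sum_j exp(s_j - m) with m = max(s)."""
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


scores = [1000.0, 1001.0]  # far past the fp64 overflow point for exp

try:
    naive_softmax(scores)
    naive_overflowed = False
except OverflowError:      # math.exp raises instead of returning inf
    naive_overflowed = True

stable = stable_softmax(scores)  # exp arguments are -1 and 0: no overflow
```

The same pair of scores that breaks the naive version is harmless once the max is subtracted, because only the differences between scores reach `exp`.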
The catch: the standard formulation requires the full vector s to compute m = max(s). That precludes streaming computation.
The setup: streaming attention
In Flash Attention, we don't have the full row of scores s = (Q[i, :] @ K^T) / √d ∈ ℝ^N at once. We receive it tile by tile: chunks s_1, s_2, ..., s_{N/B_c}, each of length B_c. For each chunk s_k, we want to update a running O[i, :] ∈ ℝ^d (the partial attention output) such that, after all chunks are consumed, O[i, :] = softmax(s) @ V.
The question: can we update O correctly using only the current chunk and a small amount of running state?
Yes. Here's how.
The recurrence
Maintain three pieces of running state for output row i:
- m ∈ ℝ: running maximum of s seen so far.
- ℓ ∈ ℝ: running denominator sum(exp(s_seen - m)).
- O ∈ ℝ^d: running unnormalized output sum_j exp(s_j - m) · V_j.
When chunk s_new ∈ ℝ^{B_c} (with corresponding V_new ∈ ℝ^{B_c × d}) arrives:
1. New max:
   $$ m' = \max(m, \max(s_{\text{new}})) $$
2. Rescale the old state to the new max:
   $$ \alpha = \exp(m - m') $$
   The old ℓ and O were computed relative to the old m. To put them on the same footing as the new chunk (which we'll compute relative to m'), multiply both by α:
   $$ \ell \leftarrow \alpha \cdot \ell \qquad O \leftarrow \alpha \cdot O $$
3. Add the new chunk's contribution:
   $$ p_{\text{new}} = \exp(s_{\text{new}} - m') \in \mathbb{R}^{B_c} $$
   $$ \ell \leftarrow \ell + \sum_{j} p_{\text{new},j} $$
   $$ O \leftarrow O + p_{\text{new}}^\top V_{\text{new}} \quad \text{(vector-matrix product summing over the } B_c \text{ rows of } V_{\text{new}}\text{)} $$
4. Update m:
   $$ m \leftarrow m' $$
At the end, divide once:
$$ \text{softmax}(s) \, V = \frac{O}{\ell} $$
That's the entire recurrence. Six lines of pseudo-code; an O(N·d) running computation; mathematically identical to the all-at-once softmax up to fp round-off.
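The recurrence can be sketched in plain Python for a single output row. This is an illustrative version, not a kernel: the function names are hypothetical, lists stand in for tensors, and the starting state (m = -inf, ℓ = 0, O = 0) is an assumption chosen so that the first chunk's α evaluates to exp(-inf) = 0.

```python
import math


def online_softmax_attention(s_chunks, v_chunks):
    """Stream over (scores, values) chunks; return softmax(s) @ V for one row."""
    d = len(v_chunks[0][0])
    m = -math.inf        # running max (assumed initial value)
    ell = 0.0            # running denominator
    o = [0.0] * d        # running unnormalized output

    for s_new, v_new in zip(s_chunks, v_chunks):
        m_new = max(m, max(s_new))                 # new max
        alpha = math.exp(m - m_new)                # rescale factor, <= 1
        ell *= alpha
        o = [alpha * x for x in o]
        p = [math.exp(x - m_new) for x in s_new]   # new chunk's weights
        ell += sum(p)
        for p_j, v_j in zip(p, v_new):             # O += p^T V_new
            for c in range(d):
                o[c] += p_j * v_j[c]
        m = m_new                                  # update m

    return [x / ell for x in o]                    # divide once at the end


def reference_attention(s, v):
    """All-at-once stable softmax(s) @ V, for comparison."""
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    return [sum(p_j * v_j[c] for p_j, v_j in zip(p, v)) / total
            for c in range(len(v[0]))]


s = [0.3, -1.2, 4.0, 2.5, 0.0, 7.1]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-3.0, 2.0], [1.0, 1.0]]

streamed = online_softmax_attention([s[0:2], s[2:4], s[4:6]],
                                    [v[0:2], v[2:4], v[4:6]])
full = reference_attention(s, v)
```

Only `m`, `ell`, and `o` survive from one chunk to the next, which is the "small amount of running state" the setup asked for.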
Proof of correctness
Claim: after processing all chunks, O / ℓ = softmax(s) @ V.
Let s = [s_1, ..., s_K] be the concatenated chunks (each of length B_c). Let m_global = max(s). By construction, after all chunks, m = m_global (running max accumulates correctly).
After all chunks:
- ℓ = sum_{j=1..N} exp(s_j - m_global).
- O = sum_{j=1..N} exp(s_j - m_global) · V_j.
The all-at-once result is O_all = (sum_j exp(s_j - m_global) · V_j) / sum_j exp(s_j - m_global) = O / ℓ. Same value, provided the two closed forms for ℓ and O hold.
They hold because of the rescaling, which is the only subtle step. Suppose we've processed chunks 1..k and are about to process chunk k+1. Just before processing k+1, our state is:
- m = max(s_1, ..., s_k). Call this m_k.
- ℓ = sum_{j ∈ first k chunks} exp(s_j - m_k).
- O = sum_{j ∈ first k chunks} exp(s_j - m_k) · V_j.
After observing chunk k+1, the true max is m_{k+1} = max(m_k, max(s_{k+1})). To rebase ℓ to the new max:
ℓ_rebased = sum_{j ∈ first k chunks} exp(s_j - m_{k+1})
= sum_{j ∈ first k chunks} exp(s_j - m_k + m_k - m_{k+1})
= exp(m_k - m_{k+1}) × sum_{j ∈ first k chunks} exp(s_j - m_k)
= α × ℓ
where α = exp(m_k - m_{k+1}) ≤ 1. Same algebra for O. Then add the chunk k+1 contributions (which are already computed relative to m_{k+1}). □
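The rebasing identity is easy to confirm numerically. The sketch below uses made-up scores and scalar values (d = 1) and compares one multiplication by α against recomputing every term relative to the new max; all variable names and numbers are hypothetical.

```python
import math

seen_scores = [0.5, -2.0, 3.0, 1.5]   # hypothetical scores from chunks 1..k
seen_values = [1.0, 4.0, -2.0, 0.5]   # matching V entries, with d = 1
next_scores = [4.2, 2.0]              # hypothetical chunk k+1

m_k = max(seen_scores)
m_k1 = max(m_k, max(next_scores))

# State as the recurrence holds it: relative to the old max m_k.
ell_k = sum(math.exp(x - m_k) for x in seen_scores)
o_k = sum(math.exp(x - m_k) * v for x, v in zip(seen_scores, seen_values))

# Rebase with a single multiply by alpha ...
alpha = math.exp(m_k - m_k1)
ell_rebased = alpha * ell_k
o_rebased = alpha * o_k

# ... versus recomputing every term against the new max m_{k+1}.
ell_direct = sum(math.exp(x - m_k1) for x in seen_scores)
o_direct = sum(math.exp(x - m_k1) * v
               for x, v in zip(seen_scores, seen_values))
```

The two routes agree to round-off, which is why the recurrence never needs to revisit old scores.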
Numerical properties
Three observations:
- No overflow. Every exp argument is ≤ 0 by construction (we always subtract the current or new max before exponentiating). Safe in fp16.
- No catastrophic cancellation in the rescaling. α = exp(m_k - m_{k+1}) mathematically lies in (0, 1], so multiplying ℓ and O by it only shrinks them. It can underflow to zero only if m jumps by more than roughly 87 to 104 in fp32, or roughly 10 to 17 in fp16 (the lower figure applies when subnormals are flushed to zero). For attention values this is rare but possible; Flash implementations typically handle it by keeping m, ℓ, O in fp32 accumulators.
- Order of chunks doesn't matter. The final result is invariant to chunk ordering (modulo round-off). This is what makes Flash work inside a tiled kernel: tiles can be processed in any order the scheduler prefers.
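The order-invariance claim can be checked empirically by running the recurrence over every permutation of a small set of chunks. This is a sketch with hypothetical data and scalar values (d = 1); the `stream` helper is illustrative and starts from the assumed state m = -inf, ℓ = 0, O = 0.

```python
import itertools
import math


def stream(chunks):
    """Run the recurrence over (scores, values) chunks with scalar values."""
    m, ell, o = -math.inf, 0.0, 0.0
    for s_new, v_new in chunks:
        m_new = max(m, max(s_new))
        alpha = math.exp(m - m_new)                  # always <= 1
        p = [math.exp(x - m_new) for x in s_new]     # exp arguments <= 0
        ell = alpha * ell + sum(p)
        o = alpha * o + sum(p_j * v_j for p_j, v_j in zip(p, v_new))
        m = m_new
    return o / ell


chunks = [
    ([1.0, 2.0], [0.5, -1.0]),
    ([3.0, 10.0], [2.0, 4.0]),
    ([-5.0, 0.0], [7.0, 1.0]),
]

# One result per chunk ordering; they should differ only by round-off.
results = [stream(order) for order in itertools.permutations(chunks)]
spread = max(results) - min(results)
```

Any spread that remains comes from floating-point addition not being associative, not from the recurrence itself.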
A worked example
Vector s = [1, 2, 3, 10], V = [[1], [1], [1], [1]] (so d = 1 and each chunk's V is a 2×1 column; the answer should be softmax(s) @ V = 1.0 since all V entries are 1).
Process in two chunks: s_1 = [1, 2], s_2 = [3, 10].
After chunk 1:
- m = 2
- p_1 = exp([1-2, 2-2]) = [exp(-1), 1] ≈ [0.368, 1]
- ℓ = 1.368
- O = 0.368 + 1 = 1.368
After chunk 2:
- max(s_2) = 10, so m' = max(2, 10) = 10.
- α = exp(2 - 10) = exp(-8) ≈ 3.35e-4
- Rescale: ℓ = 0.000459, O = 0.000459.
- p_2 = exp([3-10, 10-10]) = [exp(-7), 1] ≈ [9.12e-4, 1]
- ℓ = 0.000459 + 0.000912 + 1 = 1.00137
- O = 0.000459 + 0.000912 + 1 = 1.00137
Final: O / ℓ = 1.00137 / 1.00137 = 1.0. ✓
Compare to all-at-once: softmax([1, 2, 3, 10]) ≈ [1.23e-4, 3.35e-4, 9.11e-4, 0.9986]. Dot with [1, 1, 1, 1] = 1.0. Same answer.
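The hand calculation can be replayed in a few lines. This sketch records the state after each chunk in a hypothetical `trace` list so the intermediate values can be compared with the numbers in the worked example; values are scalars because d = 1.

```python
import math

s_chunks = [[1.0, 2.0], [3.0, 10.0]]
v_chunks = [[1.0, 1.0], [1.0, 1.0]]   # d = 1, so each V row is a scalar

m, ell, o = -math.inf, 0.0, 0.0       # assumed initial state
trace = []                            # (m, alpha, ell, o) after each chunk

for s_new, v_new in zip(s_chunks, v_chunks):
    m_new = max(m, max(s_new))
    alpha = math.exp(m - m_new)
    p = [math.exp(x - m_new) for x in s_new]
    ell = alpha * ell + sum(p)
    o = alpha * o + sum(p_j * v_j for p_j, v_j in zip(p, v_new))
    m = m_new
    trace.append((m, alpha, ell, o))

result = o / ell
```

After the first chunk `trace[0]` holds m = 2 and ℓ = O ≈ 1.368; after the second, α ≈ 3.35e-4 and ℓ = O ≈ 1.00137, so `result` comes out as 1.0.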
What this enables
Once we have the online softmax recurrence, the rest of Flash is "just" tiling. We process the attention matrix S = QK^T one tile at a time, never materializing the whole thing in HBM. For each output row of O, we maintain a tiny amount of state (m, ℓ, O) and update it as K, V tiles flow in.
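To make the "tiny amount of state per output row" concrete, here is a rough sketch that streams K, V tiles past every query row and compares the result with attention computed from full score rows. It is plain Python with nested lists and says nothing about HBM, SRAM, or scheduling; `tiled_attention`, `full_attention`, and the `block` parameter are hypothetical names for illustration.

```python
import math


def tiled_attention(Q, K, V, block):
    """Attention over K/V tiles of `block` rows, one (m, ell, o) per query row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m, ell, o = -math.inf, 0.0, [0.0] * len(V[0])
        for start in range(0, len(K), block):
            k_tile = K[start:start + block]
            v_tile = V[start:start + block]
            # Only this tile's scores exist; the full row of S never does.
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            ell = alpha * ell + sum(p)
            o = [alpha * o_c + sum(p_j * v[c] for p_j, v in zip(p, v_tile))
                 for c, o_c in enumerate(o)]
            m = m_new
        out.append([x / ell for x in o])
    return out


def full_attention(Q, K, V):
    """Reference: build each full score row, then softmax and weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(p_j * v[c] for p_j, v in zip(p, V)) / total
                    for c in range(len(V[0]))])
    return out


Q = [[1.0, 0.5], [-0.3, 2.0]]
K = [[0.2, 1.0], [1.5, -0.5], [0.0, 0.0], [2.0, 2.0], [-1.0, 0.7]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]

tiled = tiled_attention(Q, K, V, block=2)   # last tile has a single row
full = full_attention(Q, K, V)
```

The tile size changes how much of the score row exists at once, but not the answer.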
The recurrence is the mathematical reason Flash is correct. The tiling (next theory file) is the algorithmic reason it's fast.
Drill problems
Solutions at phase open in solutions/01-online-softmax-ref.md. Try without code.
- Compute the online softmax on s = [0, 5] with V = [[2], [3]] chunked as s_1 = [0], s_2 = [5]. Show all four steps.
- The recurrence requires α = exp(m_k - m_{k+1}). Under what conditions does α underflow to 0 in fp16? What goes wrong in ℓ if it does? (Hint: nothing goes wrong; the math is still correct in the limit.)
- Show that processing chunks in reverse order produces the same final O/ℓ (modulo fp round-off). Sketch the argument; don't simulate.
- Suppose you parallelize the recurrence across P threads, each handling 1/P of the chunks, then reduce at the end. What's the reduction operation? Show it's associative.
One-paragraph recap
The online softmax recurrence maintains a running max m, running denominator ℓ, and running unnormalized output O while processing chunks of s and V one at a time. When a new chunk arrives, rescale the old ℓ, O by α = exp(m_old - m_new) to align them with the new max, then add the new chunk's contributions. Final answer is O / ℓ. The algebra is one line, the computation is O(N·d) without ever materializing the full softmax, and the result is exact up to floating-point round-off. This is the mathematical key that makes Flash Attention possible — without it, you couldn't process attention in tiles.
Next: theory/02-flash-attention.md.