Lab 00: Online Softmax in Pure Python
Goal: implement the online softmax recurrence and verify it matches batched softmax on synthetic data, at both fp32 and fp16.
Estimated time: 2–3 hours.
Prereq: theory 01 read and the recurrence re-derivable from memory.
What you produce
A directory `experiments/27-online-softmax/` containing:
- `online_softmax.py`: pure-Python (NumPy) implementation. Two functions: `softmax_batched` (the classical formulation) and `softmax_online_chunked(chunks, V_chunks, m_init=-inf, ℓ_init=0, O_init=0)`.
- `test_equivalence.py`: verifies that the two agree within the per-dtype tolerance on random inputs.
- `results.json`: measurements of max-abs-error across {fp32, fp16, bf16} and across chunk sizes {1, 4, 16, 64}.
- `manifest.json`.
- `README.md`: interpretation.
No src/ deliverable for this lab; the implementation is intentionally throwaway pedagogical code.
TODOs
Block A: implement batched softmax for reference
- `softmax_batched(s: ndarray, V: ndarray) -> ndarray`: returns `softmax(s) @ V` (1D `s`, 2D `V` of shape `(N, d)`). Use the numerically stable form.
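One possible shape for the reference function is sketched here. It assumes that "numerically stable form" means subtracting the global max before exponentiating; the usage values at the bottom are illustrative only.

```python
import numpy as np

def softmax_batched(s: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Return softmax(s) @ V for 1D s of length N and 2D V of shape (N, d)."""
    m = np.max(s)                # global max; subtracting it keeps every exp argument <= 0
    p = np.exp(s - m)            # unnormalized weights in [0, 1]
    return (p @ V) / np.sum(p)   # weighted sum of the rows of V, normalized once

# With V = identity, the output is just the softmax weights themselves.
s = np.array([1.0, 2.0, 3.0])
weights = softmax_batched(s, np.eye(3))
```

Normalizing once at the end, rather than dividing `p` first, mirrors what the online version does with its final `O / ℓ`.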
Block B: implement the online version
- Signature: `softmax_online(chunks: list[ndarray], V_chunks: list[ndarray]) -> ndarray`.
- Loop over chunks; maintain `m, ℓ, O` per the recurrence from theory 01.
- Return final `O / ℓ`.
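The loop can be sketched as follows, assuming the recurrence in theory 01 is the standard running-max form: `α = exp(m - m_new)`, then `ℓ ← αℓ + Σp` and `O ← αO + p @ V_chunk`. The variable `l` stands in for `ℓ`, and the float64 running state is a simplification of this sketch, not a lab requirement.

```python
import numpy as np

def softmax_online(chunks, V_chunks):
    """Online softmax(s) @ V over matching lists of score chunks and value chunks."""
    m = -np.inf                                   # running max of all scores seen so far
    l = 0.0                                       # running sum of exp(s - m)
    O = np.zeros(V_chunks[0].shape[1])            # running unnormalized output (float64)
    for s_c, V_c in zip(chunks, V_chunks):
        m_new = max(m, float(np.max(s_c)))        # updated running max
        alpha = np.exp(m - m_new)                 # rescale factor; exp(-inf) = 0 on the first chunk
        p = np.exp(s_c - m_new)                   # chunk weights relative to the new max
        l = alpha * l + p.sum()                   # rescale old sum, add new mass
        O = alpha * O + p @ V_c                   # rescale old output, add new contribution
        m = m_new
    return O / l

# Illustrative usage: N = 10 does not divide evenly into chunks of 4.
rng = np.random.default_rng(0)
s = rng.normal(size=10)
V = rng.normal(size=(10, 3))
chunk = 4
chunks = [s[i:i + chunk] for i in range(0, len(s), chunk)]
V_chunks = [V[i:i + chunk] for i in range(0, len(V), chunk)]
out = softmax_online(chunks, V_chunks)
```

Starting from `m = -inf` makes the first `alpha` exactly zero, so the initial `l` and `O` drop out without a special case for the first chunk.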
Block C: verify equivalence
For each (N, d, dtype, chunk_size) in a small grid:
- Generate random `s ∈ ℝ^N`, `V ∈ ℝ^{N × d}`.
- Compute batched and online versions.
- Record `max_abs_error = max|O_batched - O_online|`.
Expected ranges:
- fp32: < 1e-6 (round-off only).
- fp16: < 1e-3.
- bf16: < 5e-3.
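A minimal harness for this grid might look like the sketch below. The helper names `split` and `max_abs_error`, the grid values, and the choice to keep the running state in the input dtype are all assumptions of the sketch. bf16 is left out because core NumPy has no bfloat16 dtype; covering it would need a third-party dtype package.

```python
import numpy as np

def split(x, c):
    """Slice x into consecutive chunks of length c (the last may be shorter)."""
    return [x[i:i + c] for i in range(0, len(x), c)]

def softmax_batched(s, V):
    p = np.exp(s - s.max())
    return (p @ V) / p.sum()

def softmax_online(chunks, V_chunks):
    dtype = V_chunks[0].dtype                     # keep the running state in the input dtype
    m = dtype.type(-np.inf)
    l = dtype.type(0)
    O = np.zeros(V_chunks[0].shape[1], dtype=dtype)
    for s_c, V_c in zip(chunks, V_chunks):
        m_new = np.maximum(m, s_c.max())
        alpha = np.exp(m - m_new)
        p = np.exp(s_c - m_new)
        l = alpha * l + p.sum()
        O = alpha * O + p @ V_c
        m = m_new
    return O / l

def max_abs_error(N, d, dtype, chunk_size, seed=0):
    rng = np.random.default_rng(seed)
    s = rng.normal(size=N).astype(dtype)
    V = rng.normal(size=(N, d)).astype(dtype)
    ref = softmax_batched(s, V)
    out = softmax_online(split(s, chunk_size), split(V, chunk_size))
    # Compare in float64 so the subtraction itself adds no rounding.
    return float(np.max(np.abs(ref.astype(np.float64) - out.astype(np.float64))))

# Keys are (dtype name, chunk size); values are the recorded max-abs-error.
results = {
    (np.dtype(dt).name, c): max_abs_error(N=64, d=8, dtype=dt, chunk_size=c)
    for dt in (np.float32, np.float16)
    for c in (1, 4, 16, 64)
}
```

Note that this compares two low-precision computations against each other. Comparing both against a float64 reference is another reasonable design and separates "online disagrees with batched" from "both are far from the true value".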
Block D: chunk-size sensitivity
For fp16:
- Sweep chunk_size ∈ {1, 4, 16, 64, 256}. Plot max_abs_error vs chunk_size.
- Expected: error is roughly constant across chunk sizes. The recurrence is algebraically exact, so only round-off matters, and accumulated round-off grows at most as `O(N·ε)` regardless of chunk size, where `ε` is the unit round-off of the dtype.
- If error varies sharply with chunk_size, your `α` rescaling has a bug.
Block E: pathological inputs
Test the recurrence on:
- `s = [60, 0, 0, 0]` (one large logit followed by much smaller ones). At fp16, `exp(60)` overflows, but online softmax should handle it because of the running-max subtraction.
- `s = [-100, -100, -100, 0]` (one zero, the rest very negative). Verify `softmax_online` gives `O ≈ V[3]`.
- `s` = all zeros. Uniform attention; the output is the row-wise mean of `V`, i.e. `V.mean(axis=0)`.
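The three cases can be checked with a sketch like the one below. It assumes fp16 inputs with the running state held in float32 (one of the options the lab asks you to evaluate, not a prescribed design); the 4×2 matrix `V` and the chunk size of 2 are made up for illustration.

```python
import numpy as np

def softmax_online(chunks, V_chunks):
    # Running state in float32: an assumption of this sketch.
    m = np.float32(-np.inf)
    l = np.float32(0.0)
    O = np.zeros(V_chunks[0].shape[1], dtype=np.float32)
    for s_c, V_c in zip(chunks, V_chunks):
        s32 = s_c.astype(np.float32)              # upcast the fp16 inputs per chunk
        V32 = V_c.astype(np.float32)
        m_new = np.maximum(m, s32.max())
        alpha = np.exp(m - m_new)
        p = np.exp(s32 - m_new)                   # arguments are <= 0, so no overflow
        l = alpha * l + p.sum()
        O = alpha * O + p @ V32
        m = m_new
    return O / l

def split(x, c):
    return [x[i:i + c] for i in range(0, len(x), c)]

V = np.arange(8, dtype=np.float16).reshape(4, 2)  # rows [0,1], [2,3], [4,5], [6,7]

def run(scores):
    s = np.array(scores, dtype=np.float16)
    return softmax_online(split(s, 2), split(V, 2))

out_big = run([60, 0, 0, 0])            # dominated by row 0
out_neg = run([-100, -100, -100, 0])    # dominated by row 3
out_flat = run([0, 0, 0, 0])            # uniform weights
```

In the second case the first chunk contains only `-100` entries, so the running max jumps by 100 on the second chunk and `alpha` becomes tiny; this is the situation the underflow pitfall describes.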
Block E': verb-corpus realistic inputs
The verb corpus's vocabulary is small (~600 forms), so attention logits over a 64-token sequence have a very peaked distribution after a few training epochs (the model is confident about each verb in context).
- Generate `s` by sampling from `N(0, 5)` (a peaked-but-not-pathological distribution that mimics post-training attention logits on a small vocab).
- Set `V` shape `(64, 64)` (`N = 64` matches the verb-corpus sequence length; `d = 64` is a typical head dim for MiniGPT).
- Verify `softmax_online` at chunk_size=16 matches batched softmax within fp16 tolerance.
- Comment in `README.md`: how does the peaked distribution help or hurt the online recurrence's stability?
Block F: interpret in README.md
Three questions:
- What's the worst-case fp16 error you observed? Is it under `1e-3`? If not, where did the extra error come from?
- Does chunk size affect accuracy? If you observed sensitivity, why?
- What happens at fp16 when `s` contains a value > 11? (`exp(11) ≈ 6.0e4`, close to the fp16 maximum of 65504.) Does the online recurrence handle it, or do you need fp32 accumulators for `m, ℓ, O`?
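The threshold in the third question can be located directly in NumPy. The sketch below only probes `exp` in fp16; it does not answer whether fp32 accumulators are needed. The sample scores are made up.

```python
import numpy as np

fp16_max = float(np.finfo(np.float16).max)       # largest finite fp16 value
threshold = float(np.log(fp16_max))              # approximate overflow threshold for exp

with np.errstate(over="ignore"):                 # silence the expected overflow warning
    e11 = np.exp(np.float16(11.0))               # still finite in fp16
    e12 = np.exp(np.float16(12.0))               # overflows to inf

# Subtracting the max first keeps the largest exponent argument at 0.
s = np.array([12.0, 3.0, -1.0], dtype=np.float16)
shifted = np.exp(s - s.max())
```

The shifted form avoids overflow but pushes the small entries toward the fp16 subnormal range, which is worth keeping in mind when answering the accumulator question.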
Stop conditions
- All five files committed.
- fp32 max_abs_error < `1e-6`; fp16 < `1e-3`.
- README answers the three Block F questions.
Pitfalls
- Off-by-one in chunk loop. If chunks don't perfectly tile `s`, the last-chunk handling can drop entries. Use a Python list of chunks rather than index arithmetic for clarity.
- `α` rescaling applied to the wrong things. Both `ℓ` and `O` need `α`. If you forget one, errors accumulate quadratically with N.
- `exp(m - m')` underflows. When the new max `m'` far exceeds the old max `m`, `α = exp(m - m')` can underflow. This is mathematically fine (`α → 0` just means old contributions are dwarfed by the new), but if you're computing `ℓ * α` in fp16 and `ℓ` was small, you may lose it entirely. Use fp32 accumulators for the running state in fp16 implementations.
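Two of these pitfalls are easy to reproduce in isolation. The sketch below uses a hypothetical helper `make_chunks` and made-up numbers: it first shows slicing that keeps a ragged last chunk, then shows `α` itself flushing to zero in fp16 when the max jumps by 20 while it survives in fp32.

```python
import numpy as np

def make_chunks(x, chunk_size):
    """Consecutive slices of x; the last is shorter when chunk_size does not divide len(x)."""
    return [x[i:i + chunk_size] for i in range(0, len(x), chunk_size)]

s = np.arange(10, dtype=np.float32)
chunks = make_chunks(s, 4)                    # lengths 4, 4, 2
covered = sum(len(c) for c in chunks)         # should equal len(s)

# alpha = exp(m - m') for an old max m = 0 and a new max m' = 20.
alpha_fp16 = np.exp(np.float16(0) - np.float16(20))   # below the smallest fp16 subnormal
alpha_fp32 = np.exp(np.float32(0) - np.float32(20))   # about 2.1e-9, representable

l_old = 3.0
l_fp16 = np.float16(l_old) * alpha_fp16       # old contribution flushed to zero
l_fp32 = np.float32(l_old) * alpha_fp32       # old contribution kept, though tiny
```

Whether the lost contribution matters depends on how large the new chunk's mass is by comparison; the point here is only that fp16 state cannot represent it at all.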
When to consult solutions/
Only after all stop conditions are met. `solutions/00-online-softmax-ref.md` (phase open) compares structure and numbers.
Next lab: lab/01-flash-bytes.md.