

Lab 00 — Online Softmax in Pure Python

Goal: implement the online softmax recurrence and verify it matches batched softmax on synthetic data, at both fp32 and fp16.

Estimated time: 2–3 hours.

Prereq: theory 01 read and the recurrence re-derivable from memory.


What you produce

A directory experiments/27-online-softmax/ containing:

  • online_softmax.py — pure-Python (NumPy) implementation. Two functions: softmax_batched (the classical formulation) and softmax_online(chunks, V_chunks, m_init=-inf, ℓ_init=0, O_init=0).
  • test_equivalence.py — verifies that the two agree within floating-point tolerance on random inputs.
  • results.json — measurements of max-abs-error across {fp32, fp16, bf16} and across chunk sizes {1, 4, 16, 64}.
  • manifest.json.
  • README.md — interpretation.

No src/ deliverable for this lab; the implementation is intentionally throwaway pedagogical code.

TODOs

Block A — implement batched softmax for reference

  • softmax_batched(s: ndarray, V: ndarray) -> ndarray: returns softmax(s) @ V (1D s, 2D V of shape (N, d)). Use the numerically stable form (subtract max(s) before exponentiating).
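Here is a minimal NumPy sketch of the reference, assuming the shapes stated above. Subtracting max(s) keeps every exponent at or below 0, so exp cannot overflow. The shift cancels in the ratio, so the output is unchanged.

```python
import numpy as np

def softmax_batched(s: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Return softmax(s) @ V for s of shape (N,) and V of shape (N, d)."""
    # Shift by the max: every exponent is <= 0, so exp cannot overflow,
    # and the common factor exp(-max(s)) cancels when dividing by the sum.
    e = np.exp(s - s.max())
    return (e / e.sum()) @ V
```

Adding the same constant to every score should leave the output unchanged, which makes a quick sanity test.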

Block B — implement the online version

  • Signature: softmax_online(chunks: list[ndarray], V_chunks: list[ndarray]) -> ndarray.
  • Loop over chunks; maintain m, ℓ, O per the recurrence from theory 01.
  • Return final O / ℓ.
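A sketch of the loop follows. It uses the standard form of the recurrence, so check it against your own derivation from theory 01. The `state_dtype` parameter is an assumption of this sketch, not part of the lab's signature. It keeps m, ℓ and O in `state_dtype`. The per-chunk exponentials and the p @ V_c product stay in the input dtype, which is the split that Block F question 3 and the fp32-accumulator pitfall are about.

```python
import numpy as np

def softmax_online(chunks, V_chunks, state_dtype=np.float32):
    """softmax(s) @ V computed chunk by chunk with the online recurrence.

    chunks:      list of 1D score arrays s_c that together tile s
    V_chunks:    list of value blocks V_c with V_c.shape == (len(s_c), d)
    state_dtype: dtype of the running state m, ell, O (an assumption of this sketch)
    """
    st = np.dtype(state_dtype).type
    m = st(-np.inf)                                   # running max of scores seen so far
    ell = st(0)                                       # running sum of exp(s_j - m)
    O = np.zeros(V_chunks[0].shape[1], state_dtype)   # running sum of exp(s_j - m) * V_j
    for s_c, V_c in zip(chunks, V_chunks):
        m_new = max(m, st(s_c.max()))                 # m' = max(m, max_j s_j)
        alpha = np.exp(m - m_new)                     # alpha = exp(m - m'); 0 on the first chunk
        p = np.exp(s_c - s_c.dtype.type(m_new))       # chunk weights, computed in the input dtype
        ell = alpha * ell + st(p.sum())               # l' = alpha * l + sum_j p_j
        O = alpha * O + (p @ V_c).astype(state_dtype) # O' = alpha * O + sum_j p_j * V_j
        m = m_new
    return O / ell                                    # normalize once, at the end
```

For example, softmax_online(np.array_split(s, 4), np.array_split(V, 4)) splits both arrays along axis 0 into the same four pieces. The two chunk lists therefore stay aligned even when N is not a multiple of 4.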

Block C — verify equivalence

For each (N, d, dtype, chunk_size) in a small grid:

  • Generate random s ∈ ℝ^N, V ∈ ℝ^{N × d}.
  • Compute batched and online versions.
  • Record max_abs_error = max|O_batched - O_online|.

Expected ranges:

  • fp32: < 1e-6 (round-off only).
  • fp16: < 1e-3.
  • bf16: < 5e-3.

Block D — chunk-size sensitivity

For fp16:

  • Sweep chunk_size ∈ {1, 4, 16, 64, 256}. Plot max_abs_error vs chunk_size.
  • Expected: error is roughly constant across chunk sizes. The recurrence is algebraically exact, so only round-off matters, and the round-off bound grows like O(N·ε) (ε being the unit round-off of the dtype) regardless of chunk size.
  • If error varies sharply with chunk_size, your α rescaling has a bug.
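Here is a sketch of the fp16 sweep. It again takes your functions as arguments; chunk_sensitivity and plot_errors are hypothetical names, and the defaults N = 256, d = 64 and the output file name are illustrative. chunk_size = N makes a useful control: with a single chunk, α only ever multiplies the zero initial state, so that run never exercises the rescaling. If the smaller chunk sizes show much larger error than the control, look at the α updates first. matplotlib is imported inside plot_errors, so the sweep itself needs only NumPy.

```python
import numpy as np

# Hypothetical helpers for Block D; names and defaults are illustrative.
def chunk_sensitivity(softmax_batched, softmax_online, N=256, d=64,
                      chunk_sizes=(1, 4, 16, 64, 256), seed=0):
    """fp16 max_abs_error of online vs batched softmax, keyed by chunk size."""
    rng = np.random.default_rng(seed)
    s = rng.standard_normal(N).astype(np.float16)
    V = rng.standard_normal((N, d)).astype(np.float16)
    O_batched = np.asarray(softmax_batched(s, V), dtype=np.float64)
    errors = {}
    for cs in chunk_sizes:
        chunks = [s[i:i + cs] for i in range(0, N, cs)]
        V_chunks = [V[i:i + cs] for i in range(0, N, cs)]
        O_online = np.asarray(softmax_online(chunks, V_chunks), dtype=np.float64)
        errors[cs] = float(np.max(np.abs(O_batched - O_online)))
    return errors

def plot_errors(errors, path="chunk_sensitivity.png"):
    """Save max_abs_error vs chunk_size; matplotlib is imported only when plotting."""
    import matplotlib.pyplot as plt
    sizes = sorted(errors)
    fig, ax = plt.subplots()
    ax.plot(sizes, [errors[c] for c in sizes], marker="o")
    ax.set_xscale("log")
    ax.set_xlabel("chunk_size")
    ax.set_ylabel("max_abs_error (fp16)")
    fig.savefig(path)
    plt.close(fig)
```

A typical call is errors = chunk_sensitivity(softmax_batched, softmax_online) followed by plot_errors(errors).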

Block E — pathological inputs

Test the recurrence on:

  • s = [60, 0, 0, 0] (one large value followed by much smaller ones). At fp16, exp(60) overflows, but online softmax should handle it because of the running-max subtraction.
  • s = [-100, -100, -100, 0] (one near-zero, rest very negative). Verify softmax_online gives O ≈ V[3].
  • s = all zeros. Uniform attention; output is mean(V).
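The three checks above can be sketched as follows; check_pathological is a hypothetical helper that takes your softmax_online as an argument. The random V, d = 8 and the default chunk_size=1 are illustrative choices. With chunk_size=1 every score passes through the recurrence on its own. The expected outputs come straight from the bullets: V[0], V[3] and the row mean of V.

```python
import numpy as np

# Hypothetical helper for Block E; defaults are illustrative.
def check_pathological(softmax_online, dtype=np.float16, d=8, chunk_size=1, seed=0):
    """Run the three Block E inputs; return {case: max_abs_error vs expected output}."""
    rng = np.random.default_rng(seed)
    V = rng.standard_normal((4, d)).astype(dtype)
    cases = {
        # exp(60) overflows fp16, but after subtracting m = 60 the term is exp(0) = 1
        "one_huge": (np.array([60, 0, 0, 0], dtype=dtype), V[0]),
        # exp(-100) is 0 in fp16, so the last score takes all the weight
        "one_dominant": (np.array([-100, -100, -100, 0], dtype=dtype), V[3]),
        # equal scores give uniform weights: the output is the mean of the rows
        "all_zeros": (np.zeros(4, dtype=dtype), V.astype(np.float64).mean(axis=0)),
    }
    errors = {}
    for name, (s, expected) in cases.items():
        chunks = [s[i:i + chunk_size] for i in range(0, 4, chunk_size)]
        V_chunks = [V[i:i + chunk_size] for i in range(0, 4, chunk_size)]
        out = np.asarray(softmax_online(chunks, V_chunks), dtype=np.float64)
        errors[name] = float(np.max(np.abs(out - np.asarray(expected, dtype=np.float64))))
    return errors
```

All three errors should sit at fp16 round-off level or below. Rerunning with chunk_size=4 (a single chunk) gives a control that never rescales.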

Block E' — verb-corpus realistic inputs

The verb corpus's vocabulary is small (~600 forms), so attention logits over a 64-token sequence have a very peaked distribution after a few training epochs (the model is confident about each verb in context).

  • Generate s by sampling from N(0, 5) (a peaked-but-not-pathological distribution that mimics post-training attention logits on a small vocab).
  • Set V shape (64, 64) (N = 64 matches the verb-corpus sequence length; d = 64 is a typical head dim for MiniGPT).
  • Verify softmax_online at chunk_size=16 matches batched softmax within fp16 tolerance.
  • Comment in README.md: how does the peaked distribution help or hurt the online recurrence's stability?
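A sketch of the Block E' check follows, using a hypothetical helper, check_verb_like. It reads N(0, 5) as mean 0 and standard deviation 5, which is an assumption; pass scale=np.sqrt(5) if variance 5 was intended. Alongside the error it reports two peakedness measures you could quote in README.md. The first is the largest softmax weight. The second is the effective number of attended positions, 1 / Σ p_j², a common participation-ratio measure that the lab itself does not prescribe. At this scale, some logits can land above 11, so the run may also touch the fp16 exp limit from Block F.

```python
import numpy as np

# Hypothetical helper for Block E'; scale=5.0 assumes N(0, 5) means std 5.
def check_verb_like(softmax_batched, softmax_online, N=64, d=64, chunk_size=16,
                    scale=5.0, dtype=np.float16, seed=0):
    """Compare online vs batched on peaked logits s ~ N(0, scale^2); report peakedness."""
    rng = np.random.default_rng(seed)
    s = (scale * rng.standard_normal(N)).astype(dtype)
    V = rng.standard_normal((N, d)).astype(dtype)
    chunks = [s[i:i + chunk_size] for i in range(0, N, chunk_size)]
    V_chunks = [V[i:i + chunk_size] for i in range(0, N, chunk_size)]
    O_batched = np.asarray(softmax_batched(s, V), dtype=np.float64)
    O_online = np.asarray(softmax_online(chunks, V_chunks), dtype=np.float64)
    # Softmax weights in fp64, used only to describe how peaked s is
    p = np.exp(s.astype(np.float64) - float(s.max()))
    p /= p.sum()
    return {
        "max_abs_error": float(np.max(np.abs(O_batched - O_online))),
        "max_weight": float(p.max()),                 # 1.0 would be one-hot attention
        "effective_n": float(1.0 / np.sum(p ** 2)),   # 1 (one-hot) up to N (uniform)
        "max_logit": float(s.max()),                  # above ~11.09, raw exp(s) overflows fp16
    }
```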

Block F — interpret in README.md

Three questions:

  1. What's the worst-case fp16 error you observed? Is it under 1e-3? If not, where did the extra error come from?
  2. Does chunk size affect accuracy? If you observed sensitivity, why?
  3. What happens at fp16 when s contains a value > 11? (exp(11) ≈ 59874, just under the fp16 maximum of 65504, so exp overflows fp16 for inputs above about 11.09.) Does the online recurrence handle it, or do you need fp32 accumulators for m, ℓ, O?
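The threshold in question 3 can be illustrated in a few lines. The sketch reads the fp16 maximum from np.finfo rather than hard-coding it, and exp_fp16 is a hypothetical helper. The last line shows what the max subtraction buys: once the running max has been subtracted, every exponent is at most 0, so each exp term is at most 1.

```python
import numpy as np

HALF_MAX = float(np.finfo(np.float16).max)   # largest finite fp16 value
OVERFLOW_AT = float(np.log(HALF_MAX))          # exp(x) overflows fp16 for x above this

def exp_fp16(x):
    """Hypothetical helper: exp evaluated in fp16; overflow yields inf (warning silenced)."""
    with np.errstate(over="ignore"):
        return np.exp(np.float16(x))

print(HALF_MAX, round(OVERFLOW_AT, 2))   # 65504.0 11.09
print(exp_fp16(11.0), exp_fp16(12.0))    # finite (about 5.99e4), then inf
print(exp_fp16(12.0 - 12.0))             # 1.0: with the max subtracted, exponents are <= 0
```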

Stop conditions

  • All five files committed.
  • fp32 max_abs_error < 1e-6; fp16 < 1e-3.
  • README answers the three Block F questions.

Pitfalls

  • Off-by-one in chunk loop. If chunks don't perfectly tile s, the last chunk handling can drop entries. Use Python list-of-chunks rather than index arithmetic for clarity.
  • α rescaling applied to the wrong quantities. Both ℓ and O need α. If you forget one, errors accumulate quadratically with N.
  • exp(m_old - m_new) underflows. When m_new ≫ m_old, α can underflow. This is mathematically fine (α → 0 just means the old contributions are dwarfed by the new ones), but if you're computing ℓ * α in fp16 and ℓ was small, you may lose it entirely. Use fp32 accumulators for the running state in fp16 implementations.
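Here are two small fp16 illustrations of why the running state belongs in fp32. This is a sketch, and accumulate_ones is a hypothetical helper written only for the demonstration. The first shows α underflowing to exactly 0 after a jump of 20 in the running max. The second is a related fp16 hazard for ℓ: a running sum that stops growing at 2048. The fp16 spacing there is 2, so adding a term of 1.0 (the largest any exp(s - m) can be) rounds back down.

```python
import numpy as np

# alpha = exp(m_old - m_new) for a jump of 20 in the running max
alpha16 = np.exp(np.float16(-20.0))   # exp(-20) ~ 2.1e-9 is below the smallest fp16 subnormal (~6e-8)
alpha32 = np.exp(np.float32(-20.0))   # still representable in fp32

def accumulate_ones(n, dtype):
    """Hypothetical helper: add n terms equal to 1.0 into a running sum held in dtype."""
    total, one = dtype(0), dtype(1)
    for _ in range(n):
        total = total + one
    return total

print(alpha16, alpha32 > 0)                # 0.0 True
print(accumulate_ones(4096, np.float16))   # 2048.0: fp16 spacing at 2048 is 2, so +1 rounds away
print(accumulate_ones(4096, np.float32))   # 4096.0
```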

When to consult solutions/

After all stop conditions met. solutions/00-online-softmax-ref.md (phase open) compares structure and numbers.


Next lab: lab/01-flash-bytes.md.