Lab 01 — Break naive softmax, then implement the stable version¶
Goal: see fp32 softmax overflow on a tense-classification logit vector, then implement the stable version, and prove it survives adversarial inputs.
Estimated time: 60–90 minutes.
Prereq: theory 02-softmax-stability.md read.
What you produce¶
A directory experiments/02-softmax-stability/ containing:
- naive.py: naive exp/sum implementation.
- stable.py: your stable implementation of softmax, log_sum_exp, cross_entropy.
- compare.py: driver script that feeds a battery of adversarial inputs to both and produces a comparison table.
- results.json: the table.
- softmax_break.png: visualization of where naive softmax explodes (one row of NaN among otherwise valid outputs).
- manifest.json.
- README.md: interpretation.
No src/ module yet. Phase 2 stays in experiments/. These functions will be re-implemented in src/minigrad/numerics.py in Phase 7, when an autograd consumer exists for them.
The §A13 framing¶
Every test vector represents the model's logits for classifying the next-token tense among the five tenses defined in §A13:
Indices 0..4. The "true label" y is the integer index of the correct tense.
TODOs¶
Block A — naive implementation¶
Write naive.py with three functions:
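import numpy as np

def naive_softmax(x):
    # Exponentiate, then normalize; np.exp overflows to inf for large logits.
    e = np.exp(x)
    return e / e.sum()

def naive_log_sum_exp(x):
    return np.log(np.exp(x).sum())

def naive_cross_entropy(x, y):
    # y is the integer index of the true class.
    return -np.log(naive_softmax(x)[y])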
This is the implementation to break.
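To see the break concretely, the sketch below runs the same exp/sum formula on the "large positive" vector in float32 and in float64. It assumes NumPy's default overflow behavior (np.exp returns inf instead of raising); the variable names p32 and p64 are illustrative only.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum()

logits = [1.0, 92.0, 3.0, 0.0, 2.0]

# Silence the overflow/invalid RuntimeWarnings so the output stays readable.
with np.errstate(over="ignore", invalid="ignore"):
    p32 = naive_softmax(np.array(logits, dtype=np.float32))
    p64 = naive_softmax(np.array(logits, dtype=np.float64))

# float32: exp(92) exceeds the largest float32 (about 3.4e38) and becomes inf,
# so index 1 is inf / inf = NaN and every other entry is finite / inf = 0.
print(p32)

# float64: exp(92) is about 9e39, still finite, so the naive result is valid here.
print(p64)
```

The contrast suggests why the dtype of the test vectors matters: the same logits that break the naive version in float32 pass in float64.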
Block B — stable implementation¶
Write stable.py:
import numpy as np

def stable_softmax(x):
    # TODO: apply the -max trick from theory/02
    ...

def log_sum_exp(x):
    # TODO: stable log-sum-exp
    ...

def stable_cross_entropy(x, y):
    # TODO: compute directly from logits via log_sum_exp(x) - x[y]
    ...
Each function must:
- Handle 1D input of any length ≥ 1.
- Handle 2D batched input (x.shape = (batch, K)), with softmax over the last axis.
- Handle -inf entries (treat as effectively zero probability post-shift).
- Not depend on scipy; pure NumPy.
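The identities behind the three TODOs can be written out as follows. This is the standard derivation, sketched with m = max_j x_j as the shift; the symbols m and LSE are notation introduced for this sketch, not names from the lab.

```latex
\begin{aligned}
\operatorname{softmax}(x)_i
  &= \frac{e^{x_i}}{\sum_{j} e^{x_j}}
   = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_{j} e^{x_j}}
   = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}}, \\
\operatorname{LSE}(x)
  &= \log \sum_{j} e^{x_j}
   = m + \log \sum_{j} e^{x_j - m}, \\
\operatorname{CE}(x, y)
  &= -\log \operatorname{softmax}(x)_y
   = \operatorname{LSE}(x) - x_y .
\end{aligned}
```

With m = max_j x_j, every exponent x_j - m is at most 0, so each exponential lies in [0, 1] and cannot overflow, while the factor e^{-m} cancels exactly in the ratio.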
Block C — adversarial inputs¶
In compare.py, define and run all of the following:
test_cases = [
    ("small magnitudes", np.array([0.1, 0.2, 0.3, 0.4, 0.5])),
    ("mixed magnitudes", np.array([-3.0, 0.0, 1.0, 2.0, 5.0])),
    ("large positive", np.array([1.0, 92.0, 3.0, 0.0, 2.0])),  # adversarial
    ("large negative", np.array([-100.0, -200.0, -300.0, -400.0, -500.0])),
    ("all identical", np.array([5.0, 5.0, 5.0, 5.0, 5.0])),
    ("masked entry", np.array([1.0, -np.inf, 3.0, 0.0, 2.0])),
    ("single element", np.array([42.0])),
    ("verb vocabulary", np.zeros(600)),  # uniform over §A13 vocabulary
]
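For each case, run both naive_softmax and stable_softmax on the input cast to float32 (x.astype(np.float32)). np.array defaults to float64, where exp(92) ≈ 9e39 is still finite and the naive version does not overflow. Record: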
- Whether the output contains any NaN.
- The sum of the output (should be 1.0 for a valid distribution).
- The max element of the output.
- Element-wise relative difference between the two (where naive is valid).
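The four recorded quantities could be collected along these lines. This is a sketch only: the helper names summarize and max_rel_diff and the record keys are hypothetical, and the two arrays are hand-written stand-ins, not results from the lab. Non-finite values are stored as None because strict JSON has no NaN literal.

```python
import json
import numpy as np

def summarize(out):
    """Per-output record: NaN flag, sum, and max (None where not finite)."""
    total, peak = float(np.sum(out)), float(np.max(out))
    return {
        "has_nan": bool(np.isnan(out).any()),
        "sum": total if np.isfinite(total) else None,
        "max": peak if np.isfinite(peak) else None,
    }

def max_rel_diff(naive_out, stable_out):
    """Largest |naive - stable| / |stable| over entries where both are usable."""
    valid = np.isfinite(naive_out) & np.isfinite(stable_out) & (stable_out != 0)
    if not valid.any():
        return None
    diff = np.abs(naive_out[valid] - stable_out[valid]) / np.abs(stable_out[valid])
    return float(diff.max())

# Hand-written stand-ins for one pair of outputs (illustrative, not measured).
naive_out = np.array([0.0, np.nan, 0.0])
stable_out = np.array([0.25, 0.5, 0.25])

row = {
    "case": "illustrative example",
    "naive": summarize(naive_out),
    "stable": summarize(stable_out),
    "max_rel_diff": max_rel_diff(naive_out, stable_out),
}
print(json.dumps(row, indent=2))
```

Reducing the element-wise difference to its maximum is one possible design choice; storing the full per-element array in results.json would work as well.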
For the cross_entropy battery, fix y = 1 (present simple) and run both versions on each test case. The single-element case has no index 1, so use y = 0 there.
Output the table as results.json and a markdown table in README.md.
Block D — predict before running¶
In README.md, before pasting your results.json, write your predictions for each test case:
| test case | predict naive NaN? | predict stable NaN? | predict CE? |
|---|---|---|---|
| small magnitudes | No | No | ~1.72 (compute by hand) |
| mixed magnitudes | No | No | ... |
| large positive | YES (NaN) | No | ~0 (compute via stable: log_sum_exp - x[y]; y = 1 is the max logit, so log_sum_exp ≈ x[y] = 92) |
| large negative | YES (NaN: 0/0) | No | ~100 (max is at index 0, not at y = 1: log_sum_exp ≈ x[0] = -100, so CE ≈ -100 - (-200) = 100; check) |
| ... | ... | ... | ... |
Then run, then compare. Where prediction and reality diverged, write a sentence explaining why. This is the highest-leverage learning step of the lab.
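For the hand predictions, the arithmetic can be sketched as below using only the standard library. The helper rough_ce is hypothetical; it relies on the approximation that log-sum-exp is close to the largest logit when one logit dominates.

```python
import math

x = [0.1, 0.2, 0.3, 0.4, 0.5]   # "small magnitudes" logits
y = 1                            # true label fixed by the lab

# Small logits: exp cannot overflow, so the direct formula is safe by hand.
ce_small = math.log(sum(math.exp(v) for v in x)) - x[y]
print(round(ce_small, 2))  # 1.72

# When one logit dominates, log-sum-exp is close to the max logit,
# so a quick mental estimate of the loss is max(x) - x[y].
def rough_ce(x, y):
    return max(x) - x[y]

print(rough_ce([1.0, 92.0, 3.0, 0.0, 2.0], 1))                # 0.0
print(rough_ce([-100.0, -200.0, -300.0, -400.0, -500.0], 1))  # 100.0
```

The estimate is only as good as the gap between the top two logits; for vectors with several near-equal maxima, the correction term log(number of tied maxima) is no longer negligible.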
Block E — visualization¶
softmax_break.png: a heatmap or row-table visualization showing, for each test case, the naive and stable outputs side by side. NaN entries in red. The visual asymmetry on the "large positive" row is the headline plot of Phase 2.
Block F — gradcheck preview (optional)¶
Verify, at fp64, that stable_softmax(x) agrees with scipy.special.softmax(x) to within 1e-15 on all non-adversarial inputs, and that log_sum_exp(x) matches scipy.special.logsumexp(x). If scipy disagrees on the adversarial cases (it shouldn't — scipy is stable), note it.
Constraints¶
- Pure NumPy. No scipy except as a reference oracle in Block F.
- Predict first. Don't run before you've written down predictions. The whole point is to train the prediction muscle.
- Use a fixed seed for any random test inputs (np.random.default_rng(42)). State it in manifest.json.
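One way to keep the seed and the manifest in sync is sketched below. The manifest fields are hypothetical, since the lab does not prescribe a schema here, and the random logits are only an example of a seeded input.

```python
import json
import numpy as np

SEED = 42
rng = np.random.default_rng(SEED)

# Example of a seeded random test input: five logits drawn from N(0, 10^2).
random_logits = rng.normal(loc=0.0, scale=10.0, size=5)

# Hypothetical manifest fields; adapt them to whatever the project expects.
manifest = {
    "seed": SEED,
    "rng": "numpy.random.default_rng",
    "numpy_version": np.__version__,
}
print(json.dumps(manifest, indent=2))
```

Defining the seed once as a constant and writing that same constant into the manifest avoids the two drifting apart.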
Stop conditions¶
Done when:
- naive.py, stable.py, compare.py exist and run.
- results.json shows naive NaN on at least two cases and stable NaN on zero cases.
- README.md contains your predictions table before the results table, with explanations for any divergence.
- softmax_break.png is committed.
- You can recite, in one sentence, why the -max trick eliminates overflow without changing the mathematical result.
Pitfalls¶
- -inf handling. np.exp(-np.inf) = 0, correctly. But -np.inf - (-np.inf) = nan. If your -max trick subtracts max = -inf (because all entries are -inf), you get NaN everywhere. Detect "max is -inf" and return a sentinel (uniform? NaN? document the choice).
- Batched max. x.max() over a 2D array gives a scalar; you want x.max(axis=-1, keepdims=True). Lab is set up to catch this if you write a batched-aware stable softmax.
- np.log(np.exp(x).sum()) for shifted inputs. After the -max shift, np.exp(x_shifted).sum() includes the term exp(0) = 1, so it's ≥ 1, so np.log(...) is ≥ 0. Then add back m for the final value. If you forget to add m, your log_sum_exp will be wrong by exactly m.
- Cross-entropy on the masked case. If the true label y corresponds to a -inf logit (a masked position), x[y] = -inf, and log_sum_exp(x) - x[y] = +inf. That's correct: probability 0 of the truth means infinite loss. But it's a poison pill for training: the masking should never be on the true label. Note in README.md.
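The four pitfalls can each be reproduced in a line or two. The arrays and the value lse_value below are illustrative placeholders, not inputs from the lab's battery.

```python
import numpy as np

# Pitfall 1: subtracting a -inf max from a -inf entry is undefined.
print(-np.inf - (-np.inf))  # nan

# Pitfall 2: a global max shifts every row by the same amount.
x = np.array([[1.0, 2.0, 3.0],
              [10.0, 20.0, 30.0]])
global_shift = x - x.max()                     # row 0 now peaks at -27, not 0
row_shift = x - x.max(axis=-1, keepdims=True)  # every row peaks at 0
print(global_shift.max(axis=-1), row_shift.max(axis=-1))

# Pitfall 3: the log of the shifted sum is >= 0 and is off by m until m is added back.
v = np.array([1000.0, 1001.0, 1002.0])
m = v.max()
shifted_log = np.log(np.exp(v - m).sum())
print(shifted_log, m + shifted_log)

# Pitfall 4: a masked true label (x[y] = -inf) gives an infinite loss.
lse_value = 3.4  # placeholder finite log-sum-exp value, for illustration only
print(lse_value - (-np.inf))  # inf
```

In the second case the global shift still avoids overflow, but it leaves row 0 with needlessly negative exponents, which can underflow to zero for rows far below the global maximum.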
When to consult solutions/¶
After committing all seven files. Solution at solutions/01-softmax-stability-ref.md (written at phase open).
Hint of last resort¶
If your stable softmax keeps NaN-ing on the "all -inf" edge case, the safe thing is:
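A minimal sketch of such a fallback, assuming the intended behavior is a uniform distribution for rows whose entries are all -inf. The helper name shift_with_uniform_fallback is hypothetical, and this is one possible arrangement, not the reference solution.

```python
import numpy as np

def shift_with_uniform_fallback(x):
    """Return x minus its row max, with all -inf rows replaced by zeros."""
    x = np.asarray(x)
    m = x.max(axis=-1, keepdims=True)
    all_masked = np.isneginf(m)
    # Use 0 in place of a -inf max so the subtraction never computes -inf - (-inf).
    safe_m = np.where(all_masked, 0.0, m)
    shifted = x - safe_m
    # Zeros for fully masked rows: exp(0) = 1 everywhere, i.e. uniform after normalizing.
    return np.where(all_masked, 0.0, shifted)

x = np.array([[-np.inf, -np.inf, -np.inf],   # no signal at all
              [0.0, -np.inf, 0.0]])          # one masked entry
e = np.exp(shift_with_uniform_fallback(x))
p = e / e.sum(axis=-1, keepdims=True)
print(p)
```

The first row comes out uniform and the second keeps its masked entry at probability zero, so only the fully masked case is affected by the fallback.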
Discuss this choice in README.md — it's defensible (the input has no signal, so uniform is the principled fallback) but slightly hides the input pathology.
Next lab: lab/02-summation-experiments.md.