
Lab 01 — Break naive softmax, then implement the stable version

Goal: see fp32 softmax overflow on a tense-classification logit vector, then implement the stable version, and prove it survives adversarial inputs.

Estimated time: 60–90 minutes.

Prereq: read the theory note 02-softmax-stability.md.


What you produce

A directory experiments/02-softmax-stability/ containing:

  • naive.py — naive exp/sum implementation.
  • stable.py — your stable implementation of softmax, log_sum_exp, cross_entropy.
  • compare.py — driver script that feeds a battery of adversarial inputs to both and produces a comparison table.
  • results.json — the table.
  • softmax_break.png — visualization of where naive softmax explodes (one row of NaN among otherwise valid outputs).
  • manifest.json.
  • README.md — interpretation.

No src/ module yet. Phase 2 stays in experiments/. These functions will be re-implemented in src/minigrad/numerics.py in Phase 7, when an autograd consumer exists for them.

The §A13 framing

Every test vector represents the model's logits for classifying the next-token tense among the five tenses defined in §A13:

[infinitive, present simple, past simple, past participle, simple future]

Indices 0..4. The "true label" y is the integer index of the correct tense.
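Here is a small sketch of the label encoding. It assumes the list order above is the index order. `TENSES` and `tense_label` are hypothetical names, not part of the lab's required files.

```python
# Hypothetical constant; the order follows the §A13 list above (indices 0..4).
TENSES = ["infinitive", "present simple", "past simple", "past participle", "simple future"]


def tense_label(name):
    """Integer label y for a tense name: its position in TENSES."""
    return TENSES.index(name)


print(tense_label("present simple"))  # 1, the y used in the cross-entropy battery below
```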

TODOs

Block A — naive implementation

Write naive.py with three functions:

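```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)          # becomes inf once any x_i exceeds the dtype's exp limit
    return e / e.sum()     # inf / inf = nan; 0 / 0 = nan if every exp underflows

def naive_log_sum_exp(x):
    return np.log(np.exp(x).sum())

def naive_cross_entropy(x, y):
    return -np.log(naive_softmax(x)[y])   # -log(0) = inf if p_y underflows to 0
```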

This is the implementation to break.
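You can predict where naive.py breaks without running it, as long as you know the input at which `exp` leaves the representable range. The sketch below reads those limits from `np.finfo`. `exp_limits` is a hypothetical helper name.

```python
import math

import numpy as np


def exp_limits(dtype):
    """Return (overflow_at, normal_underflow_at) for np.exp in a float dtype.

    exp(x) is inf once x > log(finfo.max); below log(finfo.tiny) the result
    is no longer a normal number (it goes subnormal, then rounds to 0).
    """
    info = np.finfo(dtype)
    return math.log(float(info.max)), math.log(float(info.tiny))


for dtype in (np.float32, np.float64):
    over, under = exp_limits(dtype)
    print(f"{np.dtype(dtype).name}: overflow above {over:.2f}, subnormal below {under:.2f}")
# float32: overflow above 88.72, subnormal below -87.34
# float64: overflow above 709.78, subnormal below -708.40
```

A logit of 92 is therefore enough to overflow `exp` in float32 but not in float64. Below the second limit, `exp` does not drop straight to 0. Its results are first subnormal, and in float32 they only round to 0 below about -103.97.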

Block B — stable implementation

Write stable.py:

```python
import numpy as np

def stable_softmax(x):
    # TODO: apply the -max trick from theory/02
    ...

def log_sum_exp(x):
    # TODO: stable log-sum-exp
    ...

def stable_cross_entropy(x, y):
    # TODO: compute directly from logits via log_sum_exp(x) - x[y]
    ...
```

Each function must:

  • Handle 1D input of any length ≥ 1.
  • Handle 2D batched input (x.shape = (batch, K)) — softmax over last axis.
  • Handle -inf entries (treat as effectively zero probability post-shift).
  • Not depend on scipy; pure NumPy.
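The three TODOs rest on the standard identities below, written out here from first principles rather than quoted from theory/02. Let m be the largest logit. Every shifted exponent x_j - m is then ≤ 0, so no exp can overflow. The term at the argmax equals e^0 = 1, so the denominator is at least 1.

```latex
\begin{aligned}
m &= \max_j x_j \\
\mathrm{softmax}(x)_i &= \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_j e^{x_j}}
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} \\
\mathrm{LSE}(x) &= \log \sum_j e^{x_j} = m + \log \sum_j e^{x_j - m} \\
\mathrm{CE}(x, y) &= -\log \mathrm{softmax}(x)_y = \mathrm{LSE}(x) - x_y
\end{aligned}
```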

Block C — adversarial inputs

In compare.py, define and run all of the following:

```python
import numpy as np

F32 = np.float32  # the lab studies fp32 overflow; np.array defaults to float64

test_cases = [
    ("small magnitudes",     np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=F32)),
    ("mixed magnitudes",     np.array([-3.0, 0.0, 1.0, 2.0, 5.0], dtype=F32)),
    ("large positive",       np.array([1.0, 92.0, 3.0, 0.0, 2.0], dtype=F32)),  # adversarial: 92 > log(fp32 max) ≈ 88.7
    ("large negative",       np.array([-100.0, -200.0, -300.0, -400.0, -500.0], dtype=F32)),
    ("all identical",        np.array([5.0, 5.0, 5.0, 5.0, 5.0], dtype=F32)),
    ("masked entry",         np.array([1.0, -np.inf, 3.0, 0.0, 2.0], dtype=F32)),
    ("single element",       np.array([42.0], dtype=F32)),
    ("verb vocabulary",      np.zeros(600, dtype=F32)),  # uniform over §A13 vocabulary
]
```

For each case, run both naive_softmax and stable_softmax. Record:

  • Whether the output contains any NaN.
  • The sum of the output (should be 1.0, up to rounding, for a valid distribution).
  • The max element of the output.
  • Element-wise relative difference between the two (where naive is valid).
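Here is a minimal sketch of the per-output metrics listed above, assuming each output is a NumPy array. `summarize` and `relative_diff` are hypothetical helper names. Two choices are assumptions you may want to change: NaN marks "not comparable" wherever the naive output is invalid, and the absolute difference stands in for the relative one wherever the stable value is exactly 0 (masked entries).

```python
import numpy as np


def summarize(p):
    """NaN flag, sum and max of one softmax output, as plain Python values."""
    p = np.asarray(p)
    return {
        "has_nan": bool(np.isnan(p).any()),
        "sum": float(p.sum()),
        "max": float(p.max()),
    }


def relative_diff(naive, stable):
    """Element-wise |naive - stable| / |stable|, computed only where naive is finite."""
    naive = np.asarray(naive, dtype=np.float64)
    stable = np.asarray(stable, dtype=np.float64)
    diff = np.abs(naive - stable)
    rel = np.full(naive.shape, np.nan)  # NaN = not comparable (naive output invalid)
    comparable = np.isfinite(naive)
    nonzero = comparable & (stable != 0)
    rel[nonzero] = diff[nonzero] / np.abs(stable[nonzero])
    zero = comparable & (stable == 0)   # e.g. masked entries: fall back to absolute diff
    rel[zero] = diff[zero]
    return rel
```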

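For the cross_entropy battery, fix y = 1 (present simple) and run both versions on each test case. The "single element" case has no index 1, so either skip it there or use y = 0, and say which in README.md.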

Output the table as results.json and a markdown table in README.md.
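By default, Python's `json` module writes NaN and infinity as the bare tokens `NaN` and `Infinity`, which strict JSON parsers reject. The sketch below stores them as strings instead and renders the same rows as a markdown table. `write_results`, `markdown_table` and the flat-dict row layout are hypothetical.

```python
import json
import math


def json_safe(value):
    """Strict JSON has no NaN/Infinity, so store non-finite floats as strings."""
    if isinstance(value, float) and not math.isfinite(value):
        return str(value)  # "nan", "inf" or "-inf"
    return value


def write_results(rows, path="results.json"):
    """rows: list of flat dicts holding plain Python values (hypothetical layout)."""
    clean = [{key: json_safe(val) for key, val in row.items()} for row in rows]
    with open(path, "w") as f:
        json.dump(clean, f, indent=2, allow_nan=False)  # raise if a NaN slips through


def markdown_table(rows):
    """Render the same rows as a markdown table for README.md."""
    headers = list(rows[0])
    lines = ["| " + " | ".join(headers) + " |", "|" + "---|" * len(headers)]
    for row in rows:
        lines.append("| " + " | ".join(str(row[h]) for h in headers) + " |")
    return "\n".join(lines)
```

Passing `allow_nan=False` makes `json.dump` raise `ValueError` instead of silently emitting non-standard JSON, so a missed conversion shows up immediately.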

Block D — predict before running

In README.md, before pasting your results.json, write your predictions for each test case:

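| test case | predict naive NaN? | predict stable NaN? | predict CE? |
|---|---|---|---|
| small magnitudes | No | No | ~1.72 (compute by hand) |
| mixed magnitudes | No | No | ... |
| large positive | YES (NaN) | No | ~0 (max is at y = 1, x[1] = 92, so log_sum_exp ≈ 92 and CE = log_sum_exp - x[y] ≈ 92 - 92 ≈ 0) |
| large negative | YES if every exp underflows to 0 (NaN: 0/0); check exp(-100) against fp32's smallest subnormal, ~1.4e-45 | No | ~100 (max is at index 0, so log_sum_exp ≈ x[0] = -100; CE ≈ -100 - x[1] = -100 - (-200) = 100) |
| ... | ... | ... | ... |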
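Here is the small-magnitudes row worked by hand with y = 1, using the identity CE = log_sum_exp(x) - x[y] from the stable.py TODO:

```latex
\begin{aligned}
\mathrm{CE}(x, 1) &= \log \sum_{j=0}^{4} e^{x_j} - x_1 \\
&= \log\left(e^{0.1} + e^{0.2} + e^{0.3} + e^{0.4} + e^{0.5}\right) - 0.2 \\
&\approx \log(6.817) - 0.2 \approx 1.919 - 0.2 \approx 1.72
\end{aligned}
```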

Then run, then compare. Where prediction and reality diverged, write a sentence explaining why. This is the highest-leverage learning step of the lab.

Block E — visualization

softmax_break.png: a heatmap or row-table visualization showing, for each test case, the naive and stable outputs side by side. NaN entries in red. The visual asymmetry on the "large positive" row is the headline plot of Phase 2.
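Here is a minimal matplotlib sketch, assuming matplotlib is installed (the "pure NumPy" rule covers the implementations, not the plotting). `plot_softmax_break` is a hypothetical name. It masks NaN and inf with `np.ma.masked_invalid` and paints the masked cells red through the colormap's `set_bad` colour. It assumes every row has the same length, so the "single element" and 600-entry "verb vocabulary" cases need padding or a panel of their own.

```python
import copy

import matplotlib
matplotlib.use("Agg")  # render straight to a file; no display needed
import matplotlib.pyplot as plt
import numpy as np


def plot_softmax_break(names, naive_rows, stable_rows, path="softmax_break.png"):
    """Naive and stable softmax outputs side by side, with NaN/inf cells in red.

    Assumes every row has the same length (the five-tense cases).
    """
    cmap = copy.copy(plt.get_cmap("viridis"))  # copy so the global colormap is untouched
    cmap.set_bad("red")                         # colour for masked (invalid) cells
    fig, axes = plt.subplots(1, 2, figsize=(10, 1.5 + 0.5 * len(names)), sharey=True)
    for ax, rows, title in zip(axes, (naive_rows, stable_rows), ("naive", "stable")):
        data = np.ma.masked_invalid(np.asarray(rows, dtype=np.float64))  # mask NaN and inf
        image = ax.imshow(data, cmap=cmap, vmin=0.0, vmax=1.0, aspect="auto")
        ax.set_title(title)
        ax.set_xlabel("tense index")
    axes[0].set_yticks(range(len(names)))
    axes[0].set_yticklabels(names)
    fig.colorbar(image, ax=list(axes), label="probability")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
```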

Block F — gradcheck preview (optional)

Verify, at fp64 (cast each input with x.astype(np.float64)), that stable_softmax(x) agrees with scipy.special.softmax(x) to within 1e-15 on all non-adversarial inputs, and that log_sum_exp(x) matches scipy.special.logsumexp(x). If scipy disagrees on the adversarial cases (it shouldn't: scipy is stable), note it.
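Here is a sketch of the oracle check, assuming SciPy is installed. `check_against_scipy` is a hypothetical helper that returns the names of the cases that disagree. The 1e-15 absolute tolerance for softmax comes from the text above. The relative tolerance of 1e-14 for log_sum_exp is an assumption: an absolute 1e-15 is smaller than one ULP for results of 8 or more.

```python
import numpy as np
from scipy.special import logsumexp, softmax


def check_against_scipy(my_softmax, my_log_sum_exp, cases,
                        softmax_atol=1e-15, lse_rtol=1e-14):
    """Return the names of cases where your functions disagree with SciPy at fp64."""
    failures = []
    for name, x in cases:
        x64 = np.asarray(x, dtype=np.float64)  # Block F compares at fp64
        softmax_ok = np.allclose(my_softmax(x64), softmax(x64), rtol=0.0, atol=softmax_atol)
        lse_ok = np.isclose(my_log_sum_exp(x64), logsumexp(x64), rtol=lse_rtol, atol=0.0)
        if not (softmax_ok and lse_ok):
            failures.append(name)
    return failures
```

Call it with `stable_softmax`, `log_sum_exp` and the non-adversarial entries of `test_cases`. An empty list means agreement.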

Constraints

  • Pure NumPy. No scipy except as a reference oracle in Block F.
  • Predict first. Don't run before you've written down predictions. The whole point is to train the prediction muscle.
  • Use a fixed seed for any random test inputs (np.random.default_rng(42)) and record it in manifest.json.
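Here is a minimal sketch of writing manifest.json. The lab only requires that the seed be stated, so `write_manifest` and every field name besides the seed form a hypothetical layout.

```python
import json
import platform

import numpy as np

SEED = 42  # the fixed seed required by the constraints above


def write_manifest(path="manifest.json", seed=SEED, extra=None):
    """Record the seed and the software environment (hypothetical field names)."""
    manifest = {
        "seed": seed,
        "rng": "np.random.default_rng",
        "numpy_version": np.__version__,
        "python_version": platform.python_version(),
        "platform": platform.platform(),
    }
    manifest.update(extra or {})
    with open(path, "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest


rng = np.random.default_rng(SEED)  # draw any random test inputs from this generator
```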

Stop conditions

Done when:

  1. naive.py, stable.py, compare.py exist and run.
  2. results.json shows the naive version producing NaN or a wrong inf on at least two cases, and the stable version producing NaN on zero cases.
  3. README.md contains your predictions table before the results table, with explanations for any divergence.
  4. softmax_break.png is committed.
  5. You can recite, in one sentence, why the -max trick eliminates overflow without changing the mathematical result.

Pitfalls

  • -inf handling. np.exp(-np.inf) = 0 correctly. But -np.inf - (-np.inf) = nan. If your -max trick subtracts max = -inf (because all entries are -inf), you get NaN everywhere. Detect "max is -inf" and return a sentinel (uniform? NaN? document the choice).
  • Batched max. x.max() over a 2D array gives a scalar; you want x.max(axis=-1, keepdims=True). The lab is set up to catch this once you write a batched-aware stable softmax.
  • np.log(np.exp(x).sum()) for shifted inputs. After the -max shift, np.exp(x_shifted).sum() includes the term exp(0) = 1, so the sum is ≥ 1 and np.log(...) is ≥ 0. Then add back m for the final value; if you forget, your log_sum_exp will be off by exactly m.
  • Cross-entropy on the masked case. If the true label y corresponds to a -inf logit (a masked position), x[y] = -inf, and log_sum_exp(x) - x[y] = +inf. That's correct: probability 0 of the truth means infinite loss. But it's a poison pill for training — the masking should never be on the true label. Note in README.md.
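The first two pitfalls can be reproduced in a few lines before you write any fix. The sketch below only demonstrates the failure modes and deliberately stops short of a stable implementation. The variable names are illustrative.

```python
import numpy as np

with np.errstate(invalid="ignore"):
    # Pitfall 1: shifting an all -inf row by its own max gives -inf - (-inf) = nan.
    row = np.array([-np.inf, -np.inf])
    shifted = row - row.max()

# Pitfall 2: x.max() collapses a batch to one scalar; keepdims keeps one max per row.
batch = np.array([[1.0, 2.0, 3.0], [1000.0, 0.0, -5.0]])
global_max = batch.max()                       # a single number (1000.0)
row_max = batch.max(axis=-1, keepdims=True)    # shape (2, 1), broadcasts against batch

# Shifting by the global max pushes row 0 down by ~1000, so every exp in it underflows to 0.
with np.errstate(invalid="ignore"):
    e = np.exp(batch - global_max)
    bad = e / e.sum(axis=-1, keepdims=True)    # row 0 becomes 0/0 = nan
```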

When to consult solutions/

After committing all seven files. Solution at solutions/01-softmax-stability-ref.md (written at phase open).

Hint of last resort

If your stable softmax keeps NaN-ing on the "all -inf" edge case, the safe thing is:

```python
# m = x.max(), computed before the shift (1D input)
if np.isneginf(m):
    return np.full_like(x, 1.0 / x.size)  # uniform fallback
```

Discuss this choice in README.md — it's defensible (the input has no signal, so uniform is the principled fallback) but slightly hides the input pathology.


Next lab: lab/02-summation-experiments.md.