Lab 00: Forward/backward hooks and the overhead budget
Goal: instrument the model with non-intrusive forward and backward hooks that capture per-layer streaming statistics, and verify that the hooks add at most 30% wall-clock overhead per training step.
Estimated time: 90-120 minutes.
Prereq: Phase 18 training loop committed and reproducible.
What you produce
- A new file, `src/minitrain/inspect.py`: the hook registry, the `Inspector` class, and the streaming-statistics helpers.
- A second helper, `src/minitrain/per_class_loss.py`: the regular-vs-irregular partition logic used by Panel 7.
- A new test, `tests/minitrain/test_inspect.py`.
- An overhead-measurement note, `experiments/19-overhead/results.md`: a short note (≤ 1 page) recording the measured overhead.
TODOs
Block A: design the hook registry
A "hook" is a callback that returns `None` and is invoked at a specific point in the forward or backward pass. We need two kinds:
- Forward hooks on each `Module`: called after the module's `__call__` returns, with `(module, args, kwargs, output)`.
- Backward hooks on each `Parameter`: called after the parameter's grad is filled, with `(param, grad)`.
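The two call points can be sketched in pure NumPy as follows. `Module`, `Parameter`, `Linear`, and the `_forward_hooks` / `_grad_hooks` lists are hypothetical stand-ins. Phase 18's classes may expose different names, and your `HookHandle.remove()` would detach a callable from whichever list it was appended to.

```python
import numpy as np


class Module:
    """Hypothetical minimal base class; Phase 18's Module may differ."""

    def __init__(self):
        self._forward_hooks = []  # callables f(module, args, kwargs, output) -> None

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        output = self.forward(*args, **kwargs)
        # Hooks fire here, once, after forward returns. forward() never calls
        # them itself, so a hook cannot fire twice per call.
        for hook in tuple(self._forward_hooks):
            hook(self, args, kwargs, output)
        return output


class Parameter:
    """Hypothetical parameter wrapper; Phase 18 may store grads differently."""

    def __init__(self, value):
        self.value = value
        self._grad = None
        self._grad_hooks = []  # callables f(param, grad) -> None

    @property
    def grad(self):
        return self._grad

    @grad.setter
    def grad(self, g):
        self._grad = g
        # Fires on every write, so a backward pass that accumulates into
        # .grad several times will fire the hook several times.
        for hook in tuple(self._grad_hooks):
            hook(self, g)


class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.w = Parameter(np.ones((n_in, n_out)))

    def forward(self, x):
        return x @ self.w.value


seen = []
layer = Linear(3, 2)
layer._forward_hooks.append(lambda m, a, kw, out: seen.append(("fwd", out.shape)))
layer.w._grad_hooks.append(lambda p, g: seen.append(("bwd", g.shape)))

out = layer(np.ones((4, 3)))
# dL/dW for L = sum(out), as a stand-in for a real backward pass.
layer.w.grad = np.ones((4, 3)).T @ np.ones_like(out)
print(seen)  # [('fwd', (4, 2)), ('bwd', (3, 2))]
```

Routing every hook through `__call__` and a single grad setter gives each kind exactly one firing site, which is the property the double-firing pitfall below warns about.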
The registry approach (preferred per BLUEPRINT.md revision):
```python
# src/minitrain/inspect.py


class HookHandle:
    """Returned by each register_* call; remove() detaches exactly one hook."""

    def __init__(self, target, hook_fn, kind: str):
        self.target = target    # the Module or Parameter the hook is attached to
        self.hook_fn = hook_fn
        self.kind = kind        # 'forward' | 'backward'

    def remove(self) -> None: ...


class Inspector:
    def __init__(self, model, params):
        self.handles: list[HookHandle] = []
        self.stats: dict[str, dict] = {}  # name -> streaming stats

    def register_forward(self, name: str, module) -> HookHandle: ...
    def register_backward(self, name: str, param) -> HookHandle: ...
    def snapshot(self) -> dict: ...  # current streaming stats, as plain Python types
    def reset(self) -> None: ...
    def remove_all(self) -> None: ...
```
On each forward call, the hook applies a Welford update to the running mean, standard deviation, max, and L2 norm of the module's output. A backward hook does the same for the parameter's gradient.
Block B: implement Welford streaming
In `src/minitrain/inspect.py`, add this helper:
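```python
import numpy as np
from numpy import ndarray


def welford_update(state: dict, x: ndarray) -> None:
    """Fold every element of x into the streaming stats held in `state`.

    Uses a batched form of Welford's update, so a whole tensor is merged in
    one vectorized step. Stored values are Python floats, which keeps
    snapshots JSON-serializable.
    """
    flat = np.asarray(x).ravel()
    if flat.size == 0:
        return  # nothing to merge; also avoids max()/min() on an empty array
    n_old = state.get('n', 0)
    n_new = n_old + flat.size
    if n_old == 0:
        mean = float(flat.mean())
        state['m2'] = float(((flat - mean) ** 2).sum())
    else:
        delta = flat - state['mean']          # deviations from the old mean
        mean = state['mean'] + float(delta.sum()) / n_new
        delta2 = flat - mean                  # deviations from the new mean
        state['m2'] = state['m2'] + float((delta * delta2).sum())
    state['mean'] = mean
    state['n'] = n_new
    state['max'] = max(state.get('max', -np.inf), float(flat.max()))
    state['min'] = min(state.get('min', np.inf), float(flat.min()))
    # L2 norm of the most recent tensor only (not accumulated across calls).
    state['l2'] = float(np.linalg.norm(flat))
```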
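The vectorized `m2` line is exact, not an approximation. Let $n_a$, $\bar x_a$, $M_{2,a}$ be the running count, mean, and `m2` before the update, and let $n_b$, $\bar x_b$, $M_{2,b}$ be those of the new tensor's elements. Write $n = n_a + n_b$ and $d = \bar x_b - \bar x_a$. The sum the code adds to `m2` expands as follows:

```latex
\begin{aligned}
\bar x &= \bar x_a + \frac{n_b\, d}{n},\\
\sum_{i=1}^{n_b} (x_i - \bar x_a)(x_i - \bar x)
  &= \sum_{i=1}^{n_b} (x_i - \bar x_a)^2 - \frac{n_b\, d}{n} \sum_{i=1}^{n_b} (x_i - \bar x_a)\\
  &= \left(M_{2,b} + n_b d^2\right) - \frac{n_b^2 d^2}{n}
   = M_{2,b} + \frac{n_a n_b}{n}\, d^2 .
\end{aligned}
```

The update therefore yields $M_2 = M_{2,a} + M_{2,b} + \frac{n_a n_b}{n} d^2$, the standard pairwise variance merge (Chan, Golub and LeVeque). $M_2/(n-1)$ is then the sample variance that `test_welford_matches_numpy` checks.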
Block C: verify overhead
Write `experiments/19-overhead/measure.py`:
- Run Phase-18 training for 200 steps without hooks. Record wall-clock time per step.
- Run Phase-18 training for 200 steps with the `Inspector` attached (capturing the statistics for all six panels). Record wall-clock time per step.
- Compute overhead = `(t_hooked - t_baseline) / t_baseline`.
- Save the result to `results.md`.
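A minimal timing harness for `measure.py` might look like this. Several choices here are assumptions: the names `median_step_time` and `overhead`, the 10 warm-up steps, and using the median rather than the mean to damp GC and OS jitter. `baseline_step` and `hooked_step` are toy stand-ins for one Phase-18 training step without and with the `Inspector`.

```python
import time

import numpy as np


def median_step_time(step_fn, n_steps=200, warmup=10):
    """Return the median wall-clock seconds per call of step_fn()."""
    for _ in range(warmup):  # exclude one-off costs (allocation, cold caches)
        step_fn()
    times = np.empty(n_steps)
    for i in range(n_steps):
        t0 = time.perf_counter()
        step_fn()
        times[i] = time.perf_counter() - t0
    return float(np.median(times))


def overhead(t_hooked, t_baseline):
    """Relative overhead: (t_hooked - t_baseline) / t_baseline."""
    return (t_hooked - t_baseline) / t_baseline


# Toy stand-ins for one Phase-18 step without and with instrumentation.
rng = np.random.default_rng(0)
W = rng.standard_normal((256, 256))
x = rng.standard_normal((64, 256))
stats = {}


def baseline_step():
    return x @ W


def hooked_step():
    out = x @ W
    stats["mean"] = float(out.mean())  # stand-in for the Inspector's work
    return out


t_base = median_step_time(baseline_step)
t_hook = median_step_time(hooked_step)
print(f"overhead = {overhead(t_hook, t_base):.1%}")
```

With real training steps the warm-up also advances the model. Either give both runs the same warm-up or start each run from the same checkpoint so the comparison stays fair.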
If overhead > 30%:
- Reduce statistics frequency (compute the spectral norm every 10 logging steps instead of every step).
- Move Welford computations out of Python loops (use vectorized NumPy).
- Drop the spectral panel temporarily and re-measure.
If overhead > 50% even after optimization, the design is wrong. Consult the solutions hint at phase open.
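The first remedy, reducing statistics frequency, can be sketched as a small cache. It recomputes the spectral norm only on every `every`-th call and otherwise returns the last value. `SpectralNormCache` and its `every=10` default are hypothetical. This is plain frequency reduction and not necessarily the caching trick the reference solution discusses. `np.linalg.norm(W, ord=2)` returns the largest singular value of a 2-D array, computing all singular values on every call.

```python
import numpy as np


class SpectralNormCache:
    """Recompute ||W||_2 only on every `every`-th call; reuse it in between."""

    def __init__(self, every=10):
        self.every = every
        self.calls = 0
        self.value = None

    def __call__(self, W):
        if self.value is None or self.calls % self.every == 0:
            # ord=2 on a 2-D array is the largest singular value.
            self.value = float(np.linalg.norm(W, ord=2))
        self.calls += 1
        return self.value


cache = SpectralNormCache(every=10)
print(cache(np.diag([3.0, 1.0])))  # computed on the first call
```

Keep one cache per weight matrix on the `Inspector` (for example, a dict keyed by layer name), so that two Inspectors on the same model stay independent.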
Block D: per-class loss helper (Panel 7)
`src/minitrain/per_class_loss.py` partitions a batch into regular-verb and irregular-verb examples and returns the mean loss of each group:
```python
# src/minitrain/per_class_loss.py
import numpy as np

REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
                           "watch", "study", "finish", "start",
                           "look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
                             "eat", "write"})


def partition_batch_loss(per_example_loss: np.ndarray,
                         verb_labels: list[str]) -> tuple[float, float]:
    """Return (mean_loss_regular, mean_loss_irregular).

    If a class has zero examples in the batch, return np.nan for that class
    (the dashboard should skip the update, not log a zero).
    """
    ...
```
The verb label of each example is its lemma, derived once per example by the data iterator. Phase 18's iterator already exposes it as an example-level metadata field; if it doesn't, add it before continuing.
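One way to fill in the `...` body, sketched under two assumptions the lab does not state. First, `per_example_loss` is a 1-D array aligned with `verb_labels`. Second, lemmas found in neither set are left out of both means.

```python
import numpy as np

# Copied from the per_class_loss.py listing above so this block runs on its own.
REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
                           "watch", "study", "finish", "start",
                           "look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
                             "eat", "write"})


def partition_batch_loss(per_example_loss, verb_labels):
    """Return (mean_loss_regular, mean_loss_irregular); np.nan for an empty class."""
    loss = np.asarray(per_example_loss, dtype=np.float64)
    if loss.shape != (len(verb_labels),):
        raise ValueError("need exactly one loss value per verb label")
    # Boolean masks; labels in neither set fall into neither mask.
    is_regular = np.array([v in REGULAR_VERBS for v in verb_labels], dtype=bool)
    is_irregular = np.array([v in IRREGULAR_VERBS for v in verb_labels], dtype=bool)

    def masked_mean(mask):
        return float(loss[mask].mean()) if mask.any() else np.nan

    return masked_mean(is_regular), masked_mean(is_irregular)


print(partition_batch_loss([1.0, 2.0, 3.0, 4.0], ["work", "go", "play", "eat"]))
# (2.0, 3.0)
```

Returning `np.nan` rather than `0.0` for an empty class lets the dashboard tell "no irregular verbs in this batch" apart from "irregular loss is zero".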
Block E: four correctness tests
In `tests/minitrain/test_inspect.py`:
- `test_welford_matches_numpy`: feed Welford 1000 random samples; assert `mean` and `m2 / (n - 1)` match `np.mean` and `np.var(ddof=1)` to 1e-10.
- `test_hook_handle_removes_cleanly`: register a forward hook, call the model, and assert the stats updated. Call `handle.remove()`, call the model again, and assert the stats are unchanged.
- `test_snapshot_serializable`: take a snapshot from an Inspector with three modules registered; assert the result is `json.dumps`-able (no NumPy scalars leak through; convert them to Python floats).
- `test_partition_batch_loss`: feed a synthetic batch with mixed regular and irregular labels; assert the two returned means match a hand-computed reference, and that a batch containing only regular verbs returns `nan` for the irregular mean.
Constraints
- Pure NumPy. No PyTorch hooks (Phase 24 introduces those; ours mirror the pattern).
- No global state. All hook state lives on `Inspector` instances, so two Inspectors on the same model produce independent stats.
- Overhead budget: 30%. Non-negotiable.
Stop conditions
Done when:
- `pytest tests/minitrain/test_inspect.py -v` passes all four tests.
- `experiments/19-overhead/results.md` records overhead ≤ 30%.
- The `Inspector` can be enabled or disabled with one config flag in `experiments/19-healthy/train.py`.
- `partition_batch_loss` is imported by the dashboard renderer (Lab 01) without further wiring needed.
Pitfalls
- NumPy reductions such as `.max()` return NumPy scalar types (e.g. `np.float32`), not Python `float`. Cast to `float` when storing values in the snapshot dict, or `json.dumps` will fail.
- Streaming `m2` can lose precision (or overflow) in fp32 for huge tensors. For Phase-18 sizes this isn't a concern, but if you scale up later, use an fp64 accumulator.
- Forward hook firing twice. If your module base class invokes hooks from both `__call__` and `forward`, every hook fires twice per call. Lay out the call chain in `inspect.py`'s docstring.
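For the snapshot pitfall, one option is a recursive sanitizer applied inside `Inspector.snapshot()`. `to_jsonable` is a hypothetical helper name, not part of the lab's API. `np.float64` already subclasses Python `float` and serializes fine, so the failures come from types such as `np.float32`, `np.int64`, and 0-d arrays.

```python
import json

import numpy as np


def to_jsonable(obj):
    """Recursively convert NumPy scalars and arrays into plain Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()  # a 0-d array becomes a bare Python scalar
    if isinstance(obj, np.generic):
        return obj.item()    # np.float32, np.int64, np.bool_ -> float, int, bool
    return obj


snap = {"fc1": {"mean": np.float32(0.5), "max": np.array(2.0), "n": np.int64(3)}}
print(json.dumps(to_jsonable(snap)))  # {"fc1": {"mean": 0.5, "max": 2.0, "n": 3}}
```

By default `json.dumps` writes a `nan` (for example, from `partition_batch_loss`) as the bare token `NaN`. Python's `json.loads` accepts that token, but strict JSON parsers reject it.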
When to consult solutions/
After overhead ≤ 30% and tests pass. The solution at solutions/00-instrument-hooks-ref.md (written at phase open) discusses the spectral-norm caching trick that often makes the difference.
Next lab: lab/01-build-dashboard.md.