Lab 00 — Forward / backward hooks; overhead budget

Goal: instrument the model with non-intrusive forward and backward hooks that capture per-layer streaming statistics. Verify overhead ≤ 30%.

Estimated time: 90-120 minutes.

Prereq: Phase 18 training loop committed and reproducible.


What you produce

A new file:

  • src/minitrain/inspect.py — the hook registry, the Inspector class, and the streaming-statistics helpers.

A second helper:

  • src/minitrain/per_class_loss.py — the regular-vs-irregular partition logic used by Panel 7.

A new test:

  • tests/minitrain/test_inspect.py.

An overhead-measurement note:

  • experiments/19-overhead/results.md — short note (≤ 1 page) recording the measured overhead.

TODOs

Block A — design the hook registry

A "hook" is a callback that returns None and is invoked at a fixed point in the forward or backward pass. Its arguments depend on the kind of hook. We need:

  • Forward hooks on each Module: called after the module's __call__ returns, with (module, args, kwargs, output).
  • Backward hooks on each Parameter: called after the parameter's grad is filled, with (param, grad).
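Both hook kinds assume that Module and Parameter expose a dispatch point. Here is a minimal sketch of one way a pure-NumPy Module and Parameter could fire them. The class bodies, the `_forward_hooks` / `_backward_hooks` attribute names, and the `grad` property are assumptions for illustration, not Phase 18's actual classes.

```python
from typing import Any, Callable


class Module:
    """Hypothetical pure-NumPy base class; Phase 18's real Module may differ."""

    def __init__(self) -> None:
        # Each entry is fn(module, args, kwargs, output) -> None.
        self._forward_hooks: list[Callable[..., None]] = []

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        output = self.forward(*args, **kwargs)
        # The only dispatch point: hooks fire once, after forward() returns.
        for hook in list(self._forward_hooks):
            hook(self, args, kwargs, output)
        return output


class Parameter:
    """Hypothetical parameter whose grad setter fires backward hooks."""

    def __init__(self, data: Any) -> None:
        self.data = data
        self._grad = None
        # Each entry is fn(param, grad) -> None.
        self._backward_hooks: list[Callable[..., None]] = []

    @property
    def grad(self) -> Any:
        return self._grad

    @grad.setter
    def grad(self, value: Any) -> None:
        self._grad = value
        # Fire only after the gradient has been stored.
        for hook in list(self._backward_hooks):
            hook(self, value)
```

Because dispatch lives only in `__call__`, `forward()` never fires hooks itself. Iterating over a copy of the list lets a hook remove itself without skipping its neighbours.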

The registry approach (preferred per BLUEPRINT.md revision):

# src/minitrain/inspect.py
class HookHandle:
    def __init__(self, target, hook_fn, kind):
        self.target = target
        self.hook_fn = hook_fn
        self.kind = kind  # 'forward' | 'backward'
    def remove(self) -> None: ...

class Inspector:
    def __init__(self, model, params):
        self.handles: list[HookHandle] = []
        self.stats: dict[str, dict] = {}  # name -> streaming stats

    def register_forward(self, name: str, module) -> HookHandle: ...
    def register_backward(self, name: str, param) -> HookHandle: ...
    def snapshot(self) -> dict: ...   # current streaming stats
    def reset(self) -> None: ...
    def remove_all(self) -> None: ...
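One possible way to fill in the registry skeleton follows. It assumes targets keep their hooks in `_forward_hooks` / `_backward_hooks` lists (hypothetical names). `count_update` is a hypothetical stand-in for the Welford helper from Block B, so the sketch runs on its own. Hooks look up their state by name on every call, so `reset()` can clear `stats` without re-registering anything.

```python
from typing import Any, Callable

import numpy as np


class HookHandle:
    def __init__(self, target: Any, hook_fn: Callable[..., None], kind: str) -> None:
        self.target = target
        self.hook_fn = hook_fn
        self.kind = kind  # 'forward' | 'backward'

    def remove(self) -> None:
        # Assumes the target stores hooks in `_forward_hooks` / `_backward_hooks`.
        # Removing an already-removed hook is a no-op.
        hooks = getattr(self.target, f"_{self.kind}_hooks")
        if self.hook_fn in hooks:
            hooks.remove(self.hook_fn)


def count_update(state: dict, x: Any) -> None:
    """Stand-in for welford_update: counts calls and elements only."""
    state["calls"] = state.get("calls", 0) + 1
    state["n"] = state.get("n", 0) + int(np.size(x))


class Inspector:
    def __init__(self, model: Any, params: Any,
                 update_fn: Callable[[dict, Any], None] = count_update) -> None:
        self.model = model
        self.params = params
        self.update_fn = update_fn
        self.handles: list[HookHandle] = []
        self.stats: dict[str, dict] = {}  # per instance: no global state

    def _track(self, target: Any, hook: Callable[..., None], kind: str) -> HookHandle:
        getattr(target, f"_{kind}_hooks").append(hook)
        handle = HookHandle(target, hook, kind)
        self.handles.append(handle)
        return handle

    def register_forward(self, name: str, module: Any) -> HookHandle:
        def hook(mod: Any, args: tuple, kwargs: dict, output: Any) -> None:
            self.update_fn(self.stats.setdefault(name, {}), output)
        return self._track(module, hook, "forward")

    def register_backward(self, name: str, param: Any) -> HookHandle:
        def hook(prm: Any, grad: Any) -> None:
            self.update_fn(self.stats.setdefault(name, {}), grad)
        return self._track(param, hook, "backward")

    def snapshot(self) -> dict:
        # .item() turns NumPy scalars into Python numbers so json.dumps works.
        return {name: {k: (v.item() if hasattr(v, "item") else v) for k, v in s.items()}
                for name, s in self.stats.items()}

    def reset(self) -> None:
        self.stats.clear()  # hooks look state up by name, so they keep working

    def remove_all(self) -> None:
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
```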

The forward hook applies a Welford update to the running mean, std, and max of the module's output, and records its L2 norm. The backward hook does the same for the parameter's gradient.
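For reference, here is the batched form of the update. It folds all k elements x_1, ..., x_k of one tensor into the running state (n_a, μ_a, M2_a) in a single step, with n = n_a + k. In exact arithmetic it gives the same result as applying single-sample Welford k times, and it is the form the Block B helper computes.

```latex
\mu = \mu_a + \frac{1}{n}\sum_{i=1}^{k} (x_i - \mu_a),
\qquad
M_2 = M_{2,a} + \sum_{i=1}^{k} (x_i - \mu_a)(x_i - \mu)

% Equivalent pairwise form, with batch mean \mu_b, batch M_{2,b}, and \delta = \mu_b - \mu_a:
M_2 = M_{2,a} + M_{2,b} + \delta^2 \, \frac{n_a k}{n}

% Sample standard deviation (matches np.std(ddof=1)):
\sigma = \sqrt{M_2 / (n - 1)}
```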

Block B — implement Welford streaming

In src/minitrain/inspect.py, helper:

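import numpy as np

def welford_update(state: dict, x: np.ndarray) -> None:
    """Fold every element of x into the streaming stats in `state`, in place.

    x is flattened, so each element counts as one sample. This is the batched
    Welford update: one mean update per call, not one per element.
    """
    flat = np.asarray(x).ravel()
    if flat.size == 0:
        return  # nothing to fold in; also avoids max()/min() on an empty array
    n_old = state.get('n', 0)
    n_new = n_old + flat.size
    if n_old == 0:
        state['mean'] = flat.mean()
        state['m2'] = ((flat - state['mean'])**2).sum()
    else:
        delta = flat - state['mean']          # deviations from the old mean
        state['mean'] = state['mean'] + delta.sum() / n_new
        delta2 = flat - state['mean']         # deviations from the new mean
        state['m2'] = state['m2'] + (delta * delta2).sum()
    state['n'] = n_new
    state['max'] = max(state.get('max', -np.inf), float(flat.max()))
    state['min'] = min(state.get('min', np.inf), float(flat.min()))
    # L2 norm of this call's tensor only; it is not accumulated across calls.
    state['l2'] = float(np.linalg.norm(flat))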
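The state dict can hold NumPy scalars (for example `mean` and `m2` stay `np.float32` when the input is fp32), while `snapshot()` must return plain Python numbers. Here is a sketch of one conversion that also derives std from `m2`. The name `state_to_snapshot` and the `ddof` parameter are hypothetical, not part of the lab's required API.

```python
import json
import math

import numpy as np


def state_to_snapshot(state: dict, ddof: int = 1) -> dict:
    """Hypothetical helper: convert one welford_update state into JSON-safe values.

    std = sqrt(m2 / (n - ddof)); ddof=1 gives the sample std, matching
    np.var(ddof=1). std is NaN when n <= ddof.
    """
    n = int(state["n"])
    if n > ddof:
        # Clamp at 0 to absorb tiny negative values from rounding.
        std = math.sqrt(max(float(state["m2"]) / (n - ddof), 0.0))
    else:
        std = math.nan
    return {
        "n": n,
        "mean": float(state["mean"]),
        "std": std,
        "min": float(state["min"]),
        "max": float(state["max"]),
        "l2": float(state["l2"]),
    }


if __name__ == "__main__":
    # fp32 scalars like these would make json.dumps raise TypeError if left as-is.
    raw = {"n": np.int64(4), "mean": np.float32(2.5), "m2": np.float32(5.0),
           "min": np.float32(1.0), "max": np.float32(4.0), "l2": np.float32(5.477)}
    print(json.dumps(state_to_snapshot(raw)))
```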

Block C — verify overhead

Write experiments/19-overhead/measure.py:

  1. Run Phase-18 training for 200 steps without hooks. Record wall-clock per step.
  2. Run Phase-18 training for 200 steps with the Inspector attached (capturing all six panels' stats). Record wall-clock per step.
  3. Overhead = (t_hooked - t_baseline) / t_baseline.
  4. Save to results.md.
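Steps 1 to 3 could be timed roughly as follows. `step_fn` is a hypothetical zero-argument wrapper around one Phase-18 training step (with or without the Inspector attached). The untimed warm-up calls and the use of the median instead of the mean are assumptions meant to damp one-off startup costs and scheduler noise.

```python
import statistics
import time
from typing import Callable


def time_steps(step_fn: Callable[[], None], n_steps: int = 200, warmup: int = 10) -> float:
    """Median wall-clock seconds per step, after `warmup` untimed calls."""
    for _ in range(warmup):
        step_fn()
    times = []
    for _ in range(n_steps):
        t0 = time.perf_counter()
        step_fn()
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


def overhead(t_baseline: float, t_hooked: float) -> float:
    """Relative overhead: (t_hooked - t_baseline) / t_baseline."""
    return (t_hooked - t_baseline) / t_baseline
```

Usage, with hypothetical step wrappers: `ov = overhead(time_steps(baseline_step), time_steps(hooked_step))`, then record `ov` in results.md and check `ov <= 0.30`.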

If overhead > 30%:

  • Reduce statistics frequency (compute the spectral norm every 10 logging steps instead of every step).
  • Keep Welford computations out of Python loops (vectorize them with NumPy).
  • Drop the spectral panel temporarily and re-measure.
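The first fix, reducing frequency, can be sketched as a small cache that recomputes the spectral norm only on every `every`-th call and returns the cached value in between. `CachedSpectralNorm` is a hypothetical name. This is a plain frequency gate and not necessarily the caching trick discussed in solutions/.

```python
import numpy as np


class CachedSpectralNorm:
    """Hypothetical helper: recompute ||W||_2 only on every `every`-th call."""

    def __init__(self, every: int = 10) -> None:
        self.every = every
        self._calls = 0
        self._cached = float("nan")

    def __call__(self, weight: np.ndarray) -> float:
        if self._calls % self.every == 0:
            # ord=2 on a 2-D array is the largest singular value (SVD-based, expensive).
            self._cached = float(np.linalg.norm(weight, ord=2))
        self._calls += 1
        return self._cached
```

The dashboard then shows a value that is up to `every - 1` logging steps stale, which is usually acceptable for a slowly changing quantity like the spectral norm.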

If overhead is still > 50% after these optimizations, the design is wrong. Consult the solutions hint released at phase open.

Block D — per-class loss helper (Panel 7)

src/minitrain/per_class_loss.py partitions a batch into regular-verb and irregular-verb examples and returns the two means:

import numpy as np

REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
                           "watch", "study", "finish", "start",
                           "look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
                             "eat", "write"})

def partition_batch_loss(per_example_loss: np.ndarray,
                        verb_labels: list[str]) -> tuple[float, float]:
    """Return (mean_loss_regular, mean_loss_irregular).

    If a class has zero examples in the batch, return np.nan for that class
    (the dashboard should skip the update, not log a zero).
    """
    ...

The verb label of each example is the lemma — derived once per example by the data iterator (Phase 18's iterator already exposes it as an example-level metadata field; if it doesn't, add it before continuing).
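Here is a sketch of one way to fill in the `...` body, using boolean masks. The verb sets are repeated so the block runs on its own. Treating labels that belong to neither set as excluded from both means is an assumption, since the lab does not say how such labels should be handled.

```python
import numpy as np

REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
                           "watch", "study", "finish", "start",
                           "look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
                             "eat", "write"})


def partition_batch_loss(per_example_loss: np.ndarray,
                         verb_labels: list[str]) -> tuple[float, float]:
    """Return (mean_loss_regular, mean_loss_irregular); np.nan for an empty class."""
    loss = np.asarray(per_example_loss, dtype=np.float64)
    if loss.shape != (len(verb_labels),):
        raise ValueError("per_example_loss must be 1-D with one entry per label")
    # Assumption: labels in neither set are ignored.
    is_regular = np.array([v in REGULAR_VERBS for v in verb_labels], dtype=bool)
    is_irregular = np.array([v in IRREGULAR_VERBS for v in verb_labels], dtype=bool)

    def masked_mean(mask: np.ndarray) -> float:
        # np.nan tells the dashboard to skip the update rather than log a zero.
        return float(loss[mask].mean()) if mask.any() else np.nan

    return masked_mean(is_regular), masked_mean(is_irregular)
```

Check for the empty class with `math.isnan` or `np.isnan`, never with `== np.nan`, which is always False.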

Block E — four correctness tests

In tests/minitrain/test_inspect.py:

  1. test_welford_matches_numpy — feed Welford 1000 random samples; assert mean and m2/(n-1) match np.mean and np.var(ddof=1) to 1e-10.
  2. test_hook_handle_removes_cleanly — register a forward hook, call the model, assert stats updated. Call handle.remove(), call the model again, assert stats unchanged.
  3. test_snapshot_serializable — take a snapshot from an Inspector with three modules registered; assert the result is json.dumps-able (no numpy scalars leak through; convert to Python floats).
  4. test_partition_batch_loss — feed a synthetic batch with mixed regular and irregular labels; assert the two returned means match a hand-computed reference; assert that a batch with all-regulars returns nan for the irregular mean.

Constraints

  • Pure NumPy. No PyTorch hooks (Phase 24 introduces those; ours mirror the pattern).
  • No global state. All hook state lives on Inspector instances. Two Inspectors on the same model produce independent stats.
  • Overhead budget: 30%. Non-negotiable.

Stop conditions

Done when:

  1. pytest tests/minitrain/test_inspect.py -v passes all four tests.
  2. experiments/19-overhead/results.md records overhead ≤ 30%.
  3. The Inspector can be enabled or disabled with one config flag in experiments/19-healthy/train.py.
  4. partition_batch_loss is imported by the dashboard renderer (Lab 01) without further wiring needed.

Pitfalls

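  • NumPy reductions such as .max() return NumPy scalars (np.float32, np.int64, ...), not Python numbers. json.dumps rejects most of them (np.float64 happens to subclass Python float, but np.float32 and np.int64 do not), so cast to Python float or int when storing values in the snapshot dict.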
  • Accumulating m2 in fp32 loses precision (and can eventually overflow) for huge tensors. At Phase-18 sizes this is not a concern, but if you scale up later, use an fp64 accumulator.
  • Forward hook firing twice. If hooks are dispatched from both __call__ and forward, every call fires them twice. Dispatch from exactly one place, and lay out the call chain in inspect.py's docstring.

When to consult solutions/

Only after overhead is ≤ 30% and all tests pass. The solution at solutions/00-instrument-hooks-ref.md (written at phase open) discusses the spectral-norm caching trick that often makes the difference.


Next lab: lab/01-build-dashboard.md.