English · Español
Lab 00 — Hooks de forward / backward; presupuesto de overhead¶
Objetivo: instrumentar el modelo con hooks de forward y backward no intrusivos que capturen estadísticas en streaming por capa. Verificar overhead ≤ 30%.
Tiempo estimado: 90-120 minutos.
Prerrequisito: bucle de entrenamiento de la Fase 18 commiteado y reproducible.
Lo que produces¶
Un fichero nuevo:
src/minitrain/inspect.py— el registro de hooks, la claseInspectory los helpers de estadísticas en streaming.
Un segundo helper:
src/minitrain/per_class_loss.py— la lógica de partición regulares-vs-irregulares usada por el Panel 7.
Un test nuevo:
tests/minitrain/test_inspect.py.
Una nota de medición de overhead:
experiments/19-overhead/results.md— nota corta (≤ 1 página) que registra el overhead medido.
TODOs¶
Bloque A — diseñar el registro de hooks¶
Un "hook" es una función f(module, inputs, outputs) -> None invocada en un punto concreto. Necesitamos:
- Forward hooks sobre cada
Module: invocados después de que retorne el__call__del módulo, con(module, args, kwargs, output). - Backward hooks sobre cada
Parameter: invocados después de que se rellene el grad del parámetro, con(param, grad).
El enfoque de registro (preferido según la revisión del BLUEPRINT.md):
# src/minitrain/inspect.py
class HookHandle:
def __init__(self, target, hook_fn, kind):
self.target = target
self.hook_fn = hook_fn
self.kind = kind # 'forward' | 'backward'
def remove(self) -> None: ...
class Inspector:
def __init__(self, model, params):
self.handles: list[HookHandle] = []
self.stats: dict[str, dict] = {} # name -> streaming stats
def register_forward(self, name: str, module) -> HookHandle: ...
def register_backward(self, name: str, param) -> HookHandle: ...
def snapshot(self) -> dict: ... # current streaming stats
def reset(self) -> None: ...
def remove_all(self) -> None: ...
El hook para una llamada de forward calcula updates de Welford sobre la media/std/max/norma-L2 corriente del módulo. Para una llamada de backward, lo mismo pero sobre el gradiente.
Bloque B — implementar el streaming de Welford¶
En src/minitrain/inspect.py, helper:
def welford_update(state: dict, x: ndarray) -> None:
"""Update streaming stats with a new sample tensor x (flattened mean treatment)."""
flat = x.ravel()
n_old = state.get('n', 0)
n_new = n_old + flat.size
if n_old == 0:
state['mean'] = flat.mean()
state['m2'] = ((flat - state['mean'])**2).sum()
else:
delta = flat - state['mean']
state['mean'] = state['mean'] + delta.sum() / n_new
delta2 = flat - state['mean']
state['m2'] = state['m2'] + (delta * delta2).sum()
state['n'] = n_new
state['max'] = max(state.get('max', -np.inf), flat.max())
state['min'] = min(state.get('min', np.inf), flat.min())
state['l2'] = float(np.linalg.norm(flat))
Bloque C — verificar el overhead¶
Escribe experiments/19-overhead/measure.py:
- Ejecuta el entrenamiento de la Fase 18 durante 200 pasos sin hooks. Registra wall-clock por paso.
- Ejecuta el entrenamiento de la Fase 18 durante 200 pasos con el Inspector adjunto (capturando las estadísticas de los seis paneles). Registra wall-clock por paso.
- Overhead =
(t_hooked - t_baseline) / t_baseline. - Guarda en
results.md.
Si el overhead > 30%: - Reduce la frecuencia de estadísticas (calcula la norma espectral cada 10 pasos de logging en lugar de cada paso). - Saca los cálculos de Welford de Python (usa NumPy vectorizado). - Descarta temporalmente el panel espectral y vuelve a medir.
Si el overhead > 50% incluso tras la optimización, el diseño está mal. Consulta la pista en solutions/ al abrir la fase.
Bloque D — helper de loss por clase (Panel 7)¶
src/minitrain/per_class_loss.py particiona un batch en ejemplos de verbo regular e irregular y devuelve las dos medias:
REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
"watch", "study", "finish", "start",
"look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
"eat", "write"})
def partition_batch_loss(per_example_loss: np.ndarray,
verb_labels: list[str]) -> tuple[float, float]:
"""Return (mean_loss_regular, mean_loss_irregular).
If a class has zero examples in the batch, return np.nan for that class
(the dashboard should skip the update, not log a zero).
"""
...
La etiqueta de verbo de cada ejemplo es el lema — derivado una vez por ejemplo por el iterador de datos (el iterador de la Fase 18 ya lo expone como campo de metadatos a nivel de ejemplo; si no lo hace, añádelo antes de seguir).
Bloque E — cuatro tests de corrección¶
En tests/minitrain/test_inspect.py:
test_welford_matches_numpy— alimenta a Welford con 1000 muestras aleatorias; comprueba quemeanym2/(n-1)coinciden connp.meanynp.var(ddof=1)a 1e-10.test_hook_handle_removes_cleanly— registra un forward hook, llama al modelo, comprueba que las stats se actualizan. Llama ahandle.remove(), vuelve a llamar al modelo, comprueba que las stats no cambian.test_snapshot_serializable— toma un snapshot de un Inspector con tres módulos registrados; comprueba que el resultado esjson.dumps-eable (no se cuelan scalars de numpy; convierte a floats de Python).test_partition_batch_loss— alimenta un batch sintético con etiquetas mixtas de regulares e irregulares; comprueba que las dos medias devueltas coinciden con una referencia calculada a mano; comprueba que un batch sólo con regulares devuelvenanpara la media de irregulares.
Restricciones¶
- NumPy puro. Sin hooks de PyTorch (la Fase 24 los introduce; los nuestros replican el patrón).
- Sin estado global. Todo el estado de hooks vive en instancias de
Inspector. Dos Inspectors sobre el mismo modelo producen stats independientes. - Presupuesto de overhead: 30%. No negociable.
Condiciones de parada¶
Hecho cuando:
pytest tests/minitrain/test_inspect.py -vpasa los cuatro tests.experiments/19-overhead/results.mdregistra overhead ≤ 30%.- El
Inspectorse puede activar o desactivar con un único flag de config enexperiments/19-healthy/train.py. partition_batch_losses importado por el renderer del dashboard (Lab 01) sin cableado adicional.
Trampas¶
- NumPy devuelve arrays 0-d para
.max()sobre un scalar. Castea afloatde Python al guardar en el dict del snapshot, ojson.dumpsfallará. - El
m2en streaming desborda en fp32 para tensores enormes. Para los tamaños de la Fase 18 no es un problema, pero si escalas más adelante, usa un acumulador en fp64. - Forward hook disparándose dos veces. Si tu módulo envuelve
__call__yforward, podrías doble-registrar. Documenta la cadena de llamadas en el docstring deinspect.py.
Cuándo consultar solutions/¶
Tras overhead ≤ 30% y tests pasando. La solución en solutions/00-instrument-hooks-ref.md (escrita al abrir la fase) discute el truco de cachear la norma espectral que suele marcar la diferencia.
Siguiente lab: lab/01-build-dashboard.md.