Skip to content

English · Español

Lab 00 — Hooks de forward / backward; presupuesto de overhead

Objetivo: instrumentar el modelo con hooks de forward y backward no intrusivos que capturen estadísticas en streaming por capa. Verificar overhead ≤ 30%.

Tiempo estimado: 90-120 minutos.

Prerrequisito: bucle de entrenamiento de la Fase 18 commiteado y reproducible.


Lo que produces

Un fichero nuevo:

  • src/minitrain/inspect.py — el registro de hooks, la clase Inspector y los helpers de estadísticas en streaming.

Un segundo helper:

  • src/minitrain/per_class_loss.py — la lógica de partición regulares-vs-irregulares usada por el Panel 7.

Un test nuevo:

  • tests/minitrain/test_inspect.py.

Una nota de medición de overhead:

  • experiments/19-overhead/results.md — nota corta (≤ 1 página) que registra el overhead medido.

TODOs

Bloque A — diseñar el registro de hooks

Un "hook" es una función f(module, inputs, outputs) -> None invocada en un punto concreto. Necesitamos:

  • Forward hooks sobre cada Module: invocados después de que retorne el __call__ del módulo, con (module, args, kwargs, output).
  • Backward hooks sobre cada Parameter: invocados después de que se rellene el grad del parámetro, con (param, grad).

El enfoque de registro (preferido según la revisión del BLUEPRINT.md):

# src/minitrain/inspect.py
class HookHandle:
    def __init__(self, target, hook_fn, kind):
        self.target = target
        self.hook_fn = hook_fn
        self.kind = kind  # 'forward' | 'backward'
    def remove(self) -> None: ...

class Inspector:
    def __init__(self, model, params):
        self.handles: list[HookHandle] = []
        self.stats: dict[str, dict] = {}  # name -> streaming stats

    def register_forward(self, name: str, module) -> HookHandle: ...
    def register_backward(self, name: str, param) -> HookHandle: ...
    def snapshot(self) -> dict: ...   # current streaming stats
    def reset(self) -> None: ...
    def remove_all(self) -> None: ...

El hook para una llamada de forward calcula updates de Welford sobre la media/std/max/norma-L2 corriente del módulo. Para una llamada de backward, lo mismo pero sobre el gradiente.

Bloque B — implementar el streaming de Welford

En src/minitrain/inspect.py, helper:

def welford_update(state: dict, x: ndarray) -> None:
    """Update streaming stats with a new sample tensor x (flattened mean treatment)."""
    flat = x.ravel()
    n_old = state.get('n', 0)
    n_new = n_old + flat.size
    if n_old == 0:
        state['mean'] = flat.mean()
        state['m2'] = ((flat - state['mean'])**2).sum()
    else:
        delta = flat - state['mean']
        state['mean'] = state['mean'] + delta.sum() / n_new
        delta2 = flat - state['mean']
        state['m2'] = state['m2'] + (delta * delta2).sum()
    state['n'] = n_new
    state['max'] = max(state.get('max', -np.inf), flat.max())
    state['min'] = min(state.get('min', np.inf), flat.min())
    state['l2'] = float(np.linalg.norm(flat))

Bloque C — verificar el overhead

Escribe experiments/19-overhead/measure.py:

  1. Ejecuta el entrenamiento de la Fase 18 durante 200 pasos sin hooks. Registra wall-clock por paso.
  2. Ejecuta el entrenamiento de la Fase 18 durante 200 pasos con el Inspector adjunto (capturando las estadísticas de los seis paneles). Registra wall-clock por paso.
  3. Overhead = (t_hooked - t_baseline) / t_baseline.
  4. Guarda en results.md.

Si el overhead > 30%: - Reduce la frecuencia de estadísticas (calcula la norma espectral cada 10 pasos de logging en lugar de cada paso). - Saca los cálculos de Welford de Python (usa NumPy vectorizado). - Descarta temporalmente el panel espectral y vuelve a medir.

Si el overhead > 50% incluso tras la optimización, el diseño está mal. Consulta la pista en solutions/ al abrir la fase.

Bloque D — helper de loss por clase (Panel 7)

src/minitrain/per_class_loss.py particiona un batch en ejemplos de verbo regular e irregular y devuelve las dos medias:

REGULAR_VERBS = frozenset({"work", "play", "walk", "talk", "listen",
                           "watch", "study", "finish", "start",
                           "look", "want", "like"})
IRREGULAR_VERBS = frozenset({"be", "have", "do", "go", "come", "see",
                             "eat", "write"})

def partition_batch_loss(per_example_loss: np.ndarray,
                        verb_labels: list[str]) -> tuple[float, float]:
    """Return (mean_loss_regular, mean_loss_irregular).

    If a class has zero examples in the batch, return np.nan for that class
    (the dashboard should skip the update, not log a zero).
    """
    ...

La etiqueta de verbo de cada ejemplo es el lema — derivado una vez por ejemplo por el iterador de datos (el iterador de la Fase 18 ya lo expone como campo de metadatos a nivel de ejemplo; si no lo hace, añádelo antes de seguir).

Bloque E — cuatro tests de corrección

En tests/minitrain/test_inspect.py:

  1. test_welford_matches_numpy — alimenta a Welford con 1000 muestras aleatorias; comprueba que mean y m2/(n-1) coinciden con np.mean y np.var(ddof=1) a 1e-10.
  2. test_hook_handle_removes_cleanly — registra un forward hook, llama al modelo, comprueba que las stats se actualizan. Llama a handle.remove(), vuelve a llamar al modelo, comprueba que las stats no cambian.
  3. test_snapshot_serializable — toma un snapshot de un Inspector con tres módulos registrados; comprueba que el resultado es json.dumps-eable (no se cuelan scalars de numpy; convierte a floats de Python).
  4. test_partition_batch_loss — alimenta un batch sintético con etiquetas mixtas de regulares e irregulares; comprueba que las dos medias devueltas coinciden con una referencia calculada a mano; comprueba que un batch sólo con regulares devuelve nan para la media de irregulares.

Restricciones

  • NumPy puro. Sin hooks de PyTorch (la Fase 24 los introduce; los nuestros replican el patrón).
  • Sin estado global. Todo el estado de hooks vive en instancias de Inspector. Dos Inspectors sobre el mismo modelo producen stats independientes.
  • Presupuesto de overhead: 30%. No negociable.

Condiciones de parada

Hecho cuando:

  1. pytest tests/minitrain/test_inspect.py -v pasa los cuatro tests.
  2. experiments/19-overhead/results.md registra overhead ≤ 30%.
  3. El Inspector se puede activar o desactivar con un único flag de config en experiments/19-healthy/train.py.
  4. partition_batch_loss es importado por el renderer del dashboard (Lab 01) sin cableado adicional.

Trampas

  • NumPy devuelve arrays 0-d para .max() sobre un scalar. Castea a float de Python al guardar en el dict del snapshot, o json.dumps fallará.
  • El m2 en streaming desborda en fp32 para tensores enormes. Para los tamaños de la Fase 18 no es un problema, pero si escalas más adelante, usa un acumulador en fp64.
  • Forward hook disparándose dos veces. Si tu módulo envuelve __call__ y forward, podrías doble-registrar. Documenta la cadena de llamadas en el docstring de inspect.py.

Cuándo consultar solutions/

Tras overhead ≤ 30% y tests pasando. La solución en solutions/00-instrument-hooks-ref.md (escrita al abrir la fase) discute el truco de cachear la norma espectral que suele marcar la diferencia.


Siguiente lab: lab/01-build-dashboard.md.