Skip to content

English · Español

Lab 02 — Round-trip de checkpoint: guardar, recargar, forward, asegurar equivalencia byte-a-byte

Objetivo: demostrar que el formato del checkpoint basta para una reanudación byte-equivalente de las cinco máquinas de estado.

Tiempo estimado: 60–90 minutos.

Requisito previo: el lab 01 produjo un checkpoint entrenado.


Lo que produces

  • src/minitrain/checkpoint.py — guardado atómico + carga para pesos, optimizer, scheduler, iterador de datos, RNG.
  • tests/minitrain/test_checkpoint.py — tests de round-trip.
  • experiments/18-checkpoint-roundtrip/:
  • manifest.json
  • roundtrip_results.json — aserciones de equivalencia byte-a-byte por máquina de estado + mediciones de tamaño de archivo

Antecedentes que debes haber leído

  • theory/04-checkpoints-and-mlflow.md — safetensors, el manifest, weight tying, escrituras atómicas.
  • theory/00-motivation.md §"Las cinco máquinas de estado" — qué debe hacer round-trip.

TODOs

Bloque A — el writer de safetensors

def save_checkpoint(
    model, optimizer, scheduler, data_iter, train_rng,
    out_dir: Path,
    step: int, epoch: int,
    extra_metrics: dict,
) -> str:
    """Returns the checkpoint hash (sha256 of weights file)."""
  • Escribe los pesos en <out_dir>/weights.safetensors.tmp, luego haz rename atómico a weights.safetensors.
  • Detecta weight tying: si id(model.embed.weight.data) == id(model.lm_head.weight.data), guarda el tensor una sola vez (bajo embed.weight) y registra lm_head.weight: ALIAS(embed.weight) en los metadatos del header de safetensors.
  • Escribe el optimizer state (m, v, t por parámetro) en optim.safetensors.
  • Escribe el state del scheduler, del iterador de datos y del RNG en extra.json.
  • Escribe manifest.json con el volcado completo del state (seed, versions, config_hash, data_manifest_hash, git_sha, hardware, step, epoch, extra_metrics).
  • Calcula y devuelve sha256(weights.safetensors) como hash del checkpoint.

Bloque B — el loader de safetensors

def load_checkpoint(in_dir: Path) -> CheckpointState:
    """Returns a CheckpointState with model_weights, optim_state, scheduler_state,
       data_iter_state, train_rng_state, manifest."""
  • Lee manifest.json primero; valida schema_version.
  • Lee weights.safetensors. Resuelve cualquier referencia ALIAS(...) — los pesos atados apuntan al mismo objeto tensor tras la carga.
  • Lee optim.safetensors y extra.json.
  • Devuelve todo como un CheckpointState estructurado (dataclass).

Bloque C — aplicar el checkpoint a los objetos vivos

def apply_checkpoint(state: CheckpointState, model, optimizer, scheduler, data_iter, train_rng) -> None:
    """Restores every state machine in-place."""
  • model.load_state_dict(state.model_weights).
  • optimizer.load_state_dict(state.optim_state).
  • scheduler.load_state_dict(state.scheduler_state).
  • data_iter.load_state_dict(state.data_iter_state).
  • train_rng.bit_generator.state = state.train_rng_state.

Bloque D — test de round-trip

El test clave:

def test_byte_equivalent_roundtrip():
    # Live state
    model, optim, sched, data_iter, rng = build_fresh()
    train_n_steps(model, optim, sched, data_iter, rng, n=10)

    state_before = capture_full_state(model, optim, sched, data_iter, rng)
    save_checkpoint(model, optim, sched, data_iter, rng, out_dir, step=10, epoch=0, extra_metrics={})

    # Fresh objects, reload
    model2, optim2, sched2, data_iter2, rng2 = build_fresh()
    state = load_checkpoint(out_dir)
    apply_checkpoint(state, model2, optim2, sched2, data_iter2, rng2)
    state_after = capture_full_state(model2, optim2, sched2, data_iter2, rng2)

    for key in state_before:
        assert byte_equal(state_before[key], state_after[key]), f"mismatch at {key}"
  • byte_equal(a, b): para ndarrays usa np.array_equal (no np.allclose); para ints/strings usa ==; para states de RNG compara dicts.
  • Las cinco máquinas de estado deben pasar.

Bloque E — test de continuación (la comprobación más profunda)

El test de recarga real: continúa el entrenamiento, verifica que el run reanudado coincide step a step con el run que nunca se checkpointeó.

def test_resume_matches_continuous():
    # Reference: train 20 steps without checkpointing
    m_ref, o_ref, s_ref, d_ref, r_ref = build_fresh_with_seed(42)
    losses_ref = train_n_steps(m_ref, o_ref, s_ref, d_ref, r_ref, n=20, return_losses=True)

    # Test: train 10, checkpoint, fresh-load, train 10 more
    m, o, s, d, r = build_fresh_with_seed(42)
    losses_a = train_n_steps(m, o, s, d, r, n=10, return_losses=True)
    save_checkpoint(m, o, s, d, r, tmp_dir, step=10, epoch=0, extra_metrics={})
    m2, o2, s2, d2, r2 = build_fresh_with_seed(42)
    apply_checkpoint(load_checkpoint(tmp_dir), m2, o2, s2, d2, r2)
    losses_b = train_n_steps(m2, o2, s2, d2, r2, n=10, return_losses=True)

    # Concatenate and assert byte-equality
    np.testing.assert_array_equal(losses_ref, losses_a + losses_b)

Este test es el test duro. Pasar el test de igualdad de state del Bloque D es necesario pero no suficiente — el Bloque E captura máquinas de estado que son iguales en forma serializada pero que se comportan de modo distinto tras la recarga (p. ej., un RNG cuyo state se restaura a la posición 7 pero cuya siguiente extracción es desde la posición 0 porque el generator se reconstruyó con la seed en vez de con el state).

  • Si el Bloque E falla pero el Bloque D pasa, tienes un bug de restauración de máquina de estado. Causa común: reconstruir el RNG con la seed en vez de restaurar el state del bit_generator.

Bloque F — test de weight tying

def test_weight_tying_preserved():
    model = build_minigpt(tie_weights=True)
    assert id(model.embed.weight.data) == id(model.lm_head.weight.data)
    save_checkpoint(model, ..., out_dir)
    # File size sanity
    weight_file = out_dir / "weights.safetensors"
    expected_size = sum(p.numel() * 4 for p in unique_params(model))
    assert abs(weight_file.stat().st_size - expected_size) < 4096, \
        f"file size {weight_file.stat().st_size} != expected {expected_size}; tying broken?"
    # Reload and re-verify
    model2 = build_minigpt(tie_weights=True)
    apply_checkpoint(load_checkpoint(out_dir), model2, ...)
    assert id(model2.embed.weight.data) == id(model2.lm_head.weight.data)
  • El tamaño en disco debe reflejar almacenamiento atado (no duplicado).
  • Tras la recarga, ambos parámetros atados apuntan al mismo array subyacente (igualdad de id(...)).

Bloque G — test de escritura atómica

def test_atomic_write_failure_resilience():
    # Simulate an interrupted write
    with patch("os.rename", side_effect=KeyboardInterrupt):
        with pytest.raises(KeyboardInterrupt):
            save_checkpoint(model, ..., out_dir)

    # The final file should not exist
    assert not (out_dir / "weights.safetensors").exists()
    # The temp file may exist; that's fine, it's our diagnostic
    # On next save attempt, the temp must be cleaned up.
    save_checkpoint(model, ..., out_dir)  # this time os.rename works
    assert (out_dir / "weights.safetensors").exists()
    assert not (out_dir / "weights.safetensors.tmp").exists()
  • Save() detecta *.tmp residual de un save interrumpido previo y lo elimina.
  • Save() escribe a *.tmp y luego hace rename atómico — nunca deja un weights.safetensors a medio escribir.

Bloque H — registra los resultados

experiments/18-checkpoint-roundtrip/roundtrip_results.json:

{
  "byte_equivalent_state_machines": ["model", "optimizer", "scheduler", "data_iter", "train_rng"],
  "all_byte_equal": true,
  "continuation_test_passed": true,
  "weight_tying_preserved": true,
  "atomic_write_resilient": true,
  "weights_file_size_bytes": 612408,
  "expected_size_bytes_no_tying": 612408,
  "size_delta_pct": 0.0,
  "manifest_schema_version": "1.0"
}

Restricciones

  • np.array_equal, no np.allclose. Idéntico bit a bit, sin excepciones.
  • Sin pickle. Usa safetensors para ndarrays, JSON para todo lo demás.
  • Todos los saves son atómicos vía *.tmp + os.rename.

Condiciones de parada

Hecho cuando:

  1. Los seis tests pasan.
  2. roundtrip_results.json muestra all_byte_equal: true y continuation_test_passed: true.
  3. El tamaño del archivo de pesos coincide con la suma de parámetros únicos (es decir, weight tying preservado).
  4. Puedes describir la diferencia entre el Bloque D (igualdad de state) y el Bloque E (coincidencia en continuación) y por qué hacen falta ambos.

Escollos

  • np.allclose en vez de np.array_equal. Allclose oculta un error de redondeo de 1-ULP; vas a pensar que tienes recarga byte-equivalente cuando no la tienes.
  • Restaurar el RNG re-sembrando. rng.bit_generator.state = saved_state es la restauración correcta; rng = np.random.default_rng(seed) reinicia el RNG a la posición 0.
  • Serializar a JSON el state del RNG directamente. El state de np.random.Generator contiene ints de numpy que no pasan por json.dumps. Conviértelos vía el dict state y base64 los campos relevantes.
  • Olvidar guardar el step counter t del optimizer. Al recargar, \(\hat m / (1 - \beta_1^t)\) usa el t equivocado; la corrección de sesgo está mal; la actualización del siguiente step está mal.

Cuándo consultar solutions/

Después de que todos los tests pasen. Solución en solutions/02-checkpoint-roundtrip-ref.md (escrita al abrir la fase).


Siguiente lab: lab/03-mp-drift.md.