English · Español
Lab 02 — Round-trip de checkpoint: guardar, recargar, forward, asegurar equivalencia byte-a-byte¶
Objetivo: demostrar que el formato del checkpoint basta para una reanudación byte-equivalente de las cinco máquinas de estado.
Tiempo estimado: 60–90 minutos.
Requisito previo: el lab 01 produjo un checkpoint entrenado.
Lo que produces¶
src/minitrain/checkpoint.py— guardado atómico + carga para pesos, optimizer, scheduler, iterador de datos, RNG.tests/minitrain/test_checkpoint.py— tests de round-trip.experiments/18-checkpoint-roundtrip/:manifest.jsonroundtrip_results.json— aserciones de equivalencia byte-a-byte por máquina de estado + mediciones de tamaño de archivo
Antecedentes que debes haber leído¶
theory/04-checkpoints-and-mlflow.md— safetensors, el manifest, weight tying, escrituras atómicas.theory/00-motivation.md§"Las cinco máquinas de estado" — qué debe hacer round-trip.
TODOs¶
Bloque A — el writer de safetensors¶
def save_checkpoint(
model, optimizer, scheduler, data_iter, train_rng,
out_dir: Path,
step: int, epoch: int,
extra_metrics: dict,
) -> str:
"""Returns the checkpoint hash (sha256 of weights file)."""
- Escribe los pesos en
<out_dir>/weights.safetensors.tmp, luego haz rename atómico aweights.safetensors. - Detecta weight tying: si
id(model.embed.weight.data) == id(model.lm_head.weight.data), guarda el tensor una sola vez (bajoembed.weight) y registralm_head.weight: ALIAS(embed.weight)en los metadatos del header de safetensors. - Escribe el optimizer state (m, v, t por parámetro) en
optim.safetensors. - Escribe el state del scheduler, del iterador de datos y del RNG en
extra.json. - Escribe
manifest.jsoncon el volcado completo del state (seed, versions, config_hash, data_manifest_hash, git_sha, hardware, step, epoch, extra_metrics). - Calcula y devuelve
sha256(weights.safetensors)como hash del checkpoint.
Bloque B — el loader de safetensors¶
def load_checkpoint(in_dir: Path) -> CheckpointState:
"""Returns a CheckpointState with model_weights, optim_state, scheduler_state,
data_iter_state, train_rng_state, manifest."""
- Lee
manifest.jsonprimero; validaschema_version. - Lee
weights.safetensors. Resuelve cualquier referenciaALIAS(...)— los pesos atados apuntan al mismo objeto tensor tras la carga. - Lee
optim.safetensorsyextra.json. - Devuelve todo como un
CheckpointStateestructurado (dataclass).
Bloque C — aplicar el checkpoint a los objetos vivos¶
def apply_checkpoint(state: CheckpointState, model, optimizer, scheduler, data_iter, train_rng) -> None:
"""Restores every state machine in-place."""
-
model.load_state_dict(state.model_weights). -
optimizer.load_state_dict(state.optim_state). -
scheduler.load_state_dict(state.scheduler_state). -
data_iter.load_state_dict(state.data_iter_state). -
train_rng.bit_generator.state = state.train_rng_state.
Bloque D — test de round-trip¶
El test clave:
def test_byte_equivalent_roundtrip():
# Live state
model, optim, sched, data_iter, rng = build_fresh()
train_n_steps(model, optim, sched, data_iter, rng, n=10)
state_before = capture_full_state(model, optim, sched, data_iter, rng)
save_checkpoint(model, optim, sched, data_iter, rng, out_dir, step=10, epoch=0, extra_metrics={})
# Fresh objects, reload
model2, optim2, sched2, data_iter2, rng2 = build_fresh()
state = load_checkpoint(out_dir)
apply_checkpoint(state, model2, optim2, sched2, data_iter2, rng2)
state_after = capture_full_state(model2, optim2, sched2, data_iter2, rng2)
for key in state_before:
assert byte_equal(state_before[key], state_after[key]), f"mismatch at {key}"
-
byte_equal(a, b): para ndarrays usanp.array_equal(nonp.allclose); para ints/strings usa==; para states de RNG compara dicts. - Las cinco máquinas de estado deben pasar.
Bloque E — test de continuación (la comprobación más profunda)¶
El test de recarga real: continúa el entrenamiento, verifica que el run reanudado coincide step a step con el run que nunca se checkpointeó.
def test_resume_matches_continuous():
# Reference: train 20 steps without checkpointing
m_ref, o_ref, s_ref, d_ref, r_ref = build_fresh_with_seed(42)
losses_ref = train_n_steps(m_ref, o_ref, s_ref, d_ref, r_ref, n=20, return_losses=True)
# Test: train 10, checkpoint, fresh-load, train 10 more
m, o, s, d, r = build_fresh_with_seed(42)
losses_a = train_n_steps(m, o, s, d, r, n=10, return_losses=True)
save_checkpoint(m, o, s, d, r, tmp_dir, step=10, epoch=0, extra_metrics={})
m2, o2, s2, d2, r2 = build_fresh_with_seed(42)
apply_checkpoint(load_checkpoint(tmp_dir), m2, o2, s2, d2, r2)
losses_b = train_n_steps(m2, o2, s2, d2, r2, n=10, return_losses=True)
# Concatenate and assert byte-equality
np.testing.assert_array_equal(losses_ref, losses_a + losses_b)
Este test es el test duro. Pasar el test de igualdad de state del Bloque D es necesario pero no suficiente — el Bloque E captura máquinas de estado que son iguales en forma serializada pero que se comportan de modo distinto tras la recarga (p. ej., un RNG cuyo state se restaura a la posición 7 pero cuya siguiente extracción es desde la posición 0 porque el generator se reconstruyó con la seed en vez de con el state).
- Si el Bloque E falla pero el Bloque D pasa, tienes un bug de restauración de máquina de estado. Causa común: reconstruir el RNG con la seed en vez de restaurar el state del bit_generator.
Bloque F — test de weight tying¶
def test_weight_tying_preserved():
model = build_minigpt(tie_weights=True)
assert id(model.embed.weight.data) == id(model.lm_head.weight.data)
save_checkpoint(model, ..., out_dir)
# File size sanity
weight_file = out_dir / "weights.safetensors"
expected_size = sum(p.numel() * 4 for p in unique_params(model))
assert abs(weight_file.stat().st_size - expected_size) < 4096, \
f"file size {weight_file.stat().st_size} != expected {expected_size}; tying broken?"
# Reload and re-verify
model2 = build_minigpt(tie_weights=True)
apply_checkpoint(load_checkpoint(out_dir), model2, ...)
assert id(model2.embed.weight.data) == id(model2.lm_head.weight.data)
- El tamaño en disco debe reflejar almacenamiento atado (no duplicado).
- Tras la recarga, ambos parámetros atados apuntan al mismo array subyacente (igualdad de
id(...)).
Bloque G — test de escritura atómica¶
def test_atomic_write_failure_resilience():
# Simulate an interrupted write
with patch("os.rename", side_effect=KeyboardInterrupt):
with pytest.raises(KeyboardInterrupt):
save_checkpoint(model, ..., out_dir)
# The final file should not exist
assert not (out_dir / "weights.safetensors").exists()
# The temp file may exist; that's fine, it's our diagnostic
# On next save attempt, the temp must be cleaned up.
save_checkpoint(model, ..., out_dir) # this time os.rename works
assert (out_dir / "weights.safetensors").exists()
assert not (out_dir / "weights.safetensors.tmp").exists()
- Save() detecta
*.tmpresidual de un save interrumpido previo y lo elimina. - Save() escribe a
*.tmpy luego hace rename atómico — nunca deja unweights.safetensorsa medio escribir.
Bloque H — registra los resultados¶
experiments/18-checkpoint-roundtrip/roundtrip_results.json:
{
"byte_equivalent_state_machines": ["model", "optimizer", "scheduler", "data_iter", "train_rng"],
"all_byte_equal": true,
"continuation_test_passed": true,
"weight_tying_preserved": true,
"atomic_write_resilient": true,
"weights_file_size_bytes": 612408,
"expected_size_bytes_no_tying": 612408,
"size_delta_pct": 0.0,
"manifest_schema_version": "1.0"
}
Restricciones¶
np.array_equal, nonp.allclose. Idéntico bit a bit, sin excepciones.- Sin pickle. Usa
safetensorspara ndarrays, JSON para todo lo demás. - Todos los saves son atómicos vía
*.tmp+os.rename.
Condiciones de parada¶
Hecho cuando:
- Los seis tests pasan.
roundtrip_results.jsonmuestraall_byte_equal: trueycontinuation_test_passed: true.- El tamaño del archivo de pesos coincide con la suma de parámetros únicos (es decir, weight tying preservado).
- Puedes describir la diferencia entre el Bloque D (igualdad de state) y el Bloque E (coincidencia en continuación) y por qué hacen falta ambos.
Escollos¶
np.allcloseen vez denp.array_equal. Allclose oculta un error de redondeo de 1-ULP; vas a pensar que tienes recarga byte-equivalente cuando no la tienes.- Restaurar el RNG re-sembrando.
rng.bit_generator.state = saved_statees la restauración correcta;rng = np.random.default_rng(seed)reinicia el RNG a la posición 0. - Serializar a JSON el state del RNG directamente. El state de
np.random.Generatorcontiene ints de numpy que no pasan porjson.dumps. Conviértelos vía el dictstatey base64 los campos relevantes. - Olvidar guardar el step counter
tdel optimizer. Al recargar, \(\hat m / (1 - \beta_1^t)\) usa eltequivocado; la corrección de sesgo está mal; la actualización del siguiente step está mal.
Cuándo consultar solutions/¶
Después de que todos los tests pasen. Solución en solutions/02-checkpoint-roundtrip-ref.md (escrita al abrir la fase).
Siguiente lab: lab/03-mp-drift.md.