Lab 02: Checkpoint roundtrip (save, reload, forward, assert byte-equivalence)
Goal: prove the checkpoint format is sufficient for byte-equivalent resume of all five state machines.
Estimated time: 60–90 minutes.
Prereq: a trained checkpoint from lab 01.
What you produce
- `src/minitrain/checkpoint.py`: atomic save + load for weights, optimizer, scheduler, data iterator, and RNG.
- `tests/minitrain/test_checkpoint.py`: round-trip tests.
- `experiments/18-checkpoint-roundtrip/`:
  - `manifest.json`
  - `roundtrip_results.json`: byte-equivalence assertions per state machine, plus file-size measurements.
Background you must have read
- `theory/04-checkpoints-and-mlflow.md`: safetensors, the manifest, weight tying, atomic writes.
- `theory/00-motivation.md` §"The five state machines": what must round-trip.
TODOs
Block A: the safetensors writer
```python
from pathlib import Path


def save_checkpoint(
    model, optimizer, scheduler, data_iter, train_rng,
    out_dir: Path,
    step: int, epoch: int,
    extra_metrics: dict,
) -> str:
    """Returns the checkpoint hash (sha256 of weights file)."""
    ...  # TODO (Block A)
```
- Write weights to `<out_dir>/weights.safetensors.tmp`, then atomically rename to `weights.safetensors`.
- Detect weight tying: if `id(model.embed.weight.data) == id(model.lm_head.weight.data)`, save the tensor only once (under `embed.weight`) and record `lm_head.weight: ALIAS(embed.weight)` in the safetensors header metadata.
- Write optimizer state (`m`, `v`, `t` per param) to `optim.safetensors`.
- Write scheduler state, data-iterator state, and RNG state to `extra.json`.
- Write `manifest.json` with the full state dump (seed, versions, config_hash, data_manifest_hash, git_sha, hardware, step, epoch, extra_metrics).
- Compute and return `sha256(weights.safetensors)` as the checkpoint hash.
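The tying-detection step can be sketched as a pure function, assuming the weights arrive as a name-to-array dict in which tied names reference the *same* array object (the function name and dict layout are hypothetical, not the lab's reference code). The first name seen keeps the data, so `embed.weight` must come before `lm_head.weight` in the dict for the alias to point the way the lab specifies.

```python
import numpy as np


def dedup_tied(weights: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, str]]:
    """Split a name->array dict into unique tensors plus alias metadata.

    Tying is detected by object identity (the same idea as the lab's id() check):
    the first name bound to an array keeps the data; later names bound to the
    same object are recorded as "ALIAS(<first name>)" instead of stored again.
    Equal-valued but distinct arrays are NOT deduplicated.
    """
    unique: dict[str, np.ndarray] = {}
    aliases: dict[str, str] = {}
    owner_by_id: dict[int, str] = {}  # id(array) -> name that owns the data
    for name, arr in weights.items():
        owner = owner_by_id.get(id(arr))
        if owner is None:
            owner_by_id[id(arr)] = name
            unique[name] = arr
        else:
            aliases[name] = f"ALIAS({owner})"
    return unique, aliases
```

With safetensors, `unique` and `aliases` could then be passed to `safetensors.numpy.save_file(unique, path, metadata=aliases)`; header metadata must be a `str`-to-`str` dict, which `aliases` is.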
Block B: the safetensors loader
```python
def load_checkpoint(in_dir: Path) -> CheckpointState:
    """Returns a CheckpointState with model_weights, optim_state, scheduler_state,
    data_iter_state, train_rng_state, manifest."""
    ...  # TODO (Block B)
```
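One possible shape for the `CheckpointState` dataclass, with field names taken from the docstring above; the field types and the empty-dict default for `manifest` are assumptions of this sketch.

```python
from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass
class CheckpointState:
    """Hypothetical layout: one field per state machine, plus the manifest."""

    model_weights: dict[str, np.ndarray]  # ALIAS(...) entries already resolved
    optim_state: dict[str, Any]           # m, v, t per param
    scheduler_state: dict[str, Any]
    data_iter_state: dict[str, Any]
    train_rng_state: dict[str, Any]       # a bit_generator.state dict
    manifest: dict[str, Any] = field(default_factory=dict)
```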
- Read `manifest.json` first; validate `schema_version`.
- Read `weights.safetensors`. Resolve any `ALIAS(...)` references so that, after load, the tied weights point to the same tensor object.
- Read `optim.safetensors` and `extra.json`.
- Return everything as a structured `CheckpointState` (dataclass).
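The schema check and alias resolution can be sketched as follows, assuming the safetensors header metadata has already been read into a `str`-to-`str` dict (for example via `safe_open(path, framework="np").metadata()`) and that the supported schema set is `{"1.0"}`, matching `manifest_schema_version` in Block H. All names here are hypothetical. The key detail is that an aliased name receives the *same* array object, not a copy.

```python
import re

import numpy as np

SUPPORTED_SCHEMA_VERSIONS = {"1.0"}  # assumption: mirrors Block H's results file
_ALIAS = re.compile(r"ALIAS\((.+)\)")


def check_schema(manifest: dict) -> None:
    """Fail fast before touching any tensor file."""
    version = manifest.get("schema_version")
    if version not in SUPPORTED_SCHEMA_VERSIONS:
        raise ValueError(f"unsupported checkpoint schema_version: {version!r}")


def resolve_aliases(unique: dict[str, np.ndarray], metadata: dict[str, str]) -> dict[str, np.ndarray]:
    """Rebuild the full name->array dict from stored tensors plus ALIAS metadata."""
    weights = dict(unique)
    for name, value in metadata.items():
        match = _ALIAS.fullmatch(value)
        if match is None:
            continue  # ordinary metadata entry, not an alias
        target = match.group(1)
        if target not in unique:
            raise KeyError(f"{name} aliases missing tensor {target}")
        weights[name] = unique[target]  # share the object so tying survives reload
    return weights
```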
Block C: apply checkpoint to live objects
```python
def apply_checkpoint(state: CheckpointState, model, optimizer, scheduler, data_iter, train_rng) -> None:
    """Restores every state machine in-place."""
    ...  # TODO (Block C)
```
- `model.load_state_dict(state.model_weights)`
- `optimizer.load_state_dict(state.optim_state)`
- `scheduler.load_state_dict(state.scheduler_state)`
- `data_iter.load_state_dict(state.data_iter_state)`
- `train_rng.bit_generator.state = state.train_rng_state`
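Why "in-place" matters for the RNG: other objects (for example a data iterator) may hold a reference to the same `Generator`. A minimal sketch with a hypothetical `ToyDataIter` shows that assigning `bit_generator.state` rewinds the shared object, whereas rebinding the name to a new `Generator` leaves every other holder untouched.

```python
import numpy as np


class ToyDataIter:
    """Hypothetical iterator that draws batch indices from a shared Generator."""

    def __init__(self, rng: np.random.Generator):
        self.rng = rng  # a reference to the caller's Generator, not a copy

    def next_batch(self) -> np.ndarray:
        return self.rng.integers(0, 100, size=4)


train_rng = np.random.default_rng(0)
data_iter = ToyDataIter(train_rng)
data_iter.next_batch()
saved = train_rng.bit_generator.state  # snapshot: a fresh plain dict
expected = data_iter.next_batch()

# In-place restore: the Generator that data_iter holds is rewound too.
train_rng.bit_generator.state = saved
assert np.array_equal(data_iter.next_batch(), expected)

# Rebinding the name creates a new Generator that data_iter never sees.
train_rng = np.random.default_rng(0)
train_rng.bit_generator.state = saved
assert data_iter.rng is not train_rng
```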
Block D: round-trip test
The key test:
```python
def test_byte_equivalent_roundtrip(tmp_path):
    out_dir = tmp_path  # pytest's per-test temporary directory

    # Live state
    model, optim, sched, data_iter, rng = build_fresh()
    train_n_steps(model, optim, sched, data_iter, rng, n=10)
    state_before = capture_full_state(model, optim, sched, data_iter, rng)
    save_checkpoint(model, optim, sched, data_iter, rng, out_dir, step=10, epoch=0, extra_metrics={})

    # Fresh objects, reload
    model2, optim2, sched2, data_iter2, rng2 = build_fresh()
    state = load_checkpoint(out_dir)
    apply_checkpoint(state, model2, optim2, sched2, data_iter2, rng2)
    state_after = capture_full_state(model2, optim2, sched2, data_iter2, rng2)

    # Same set of state machines on both sides, then exact equality per key
    assert state_before.keys() == state_after.keys()
    for key in state_before:
        assert byte_equal(state_before[key], state_after[key]), f"mismatch at {key}"
```
- `byte_equal(a, b)`: for ndarrays, uses `np.array_equal` (not `np.allclose`); for ints/strings, uses `==`; for RNG states, compares dicts.
- All five state machines must pass.
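A minimal sketch of `byte_equal` following those rules. The recursion into lists/tuples and the dtype check are additions of this sketch (not part of the stated rules): `np.array_equal` checks shape and values but not dtype, so an `int32` array would otherwise match an `int64` one. Recursing into dicts also handles RNG states that contain arrays, such as MT19937's key.

```python
import numpy as np


def byte_equal(a, b) -> bool:
    """Exact comparison for captured state (sketch of the Block D helper)."""
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        if not (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)):
            return False
        # Exact value comparison (never allclose), plus a dtype check.
        return a.dtype == b.dtype and np.array_equal(a, b)
    if isinstance(a, dict) and isinstance(b, dict):
        # RNG states are nested dicts; compare keys, then every value recursively.
        return a.keys() == b.keys() and all(byte_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(byte_equal(x, y) for x, y in zip(a, b))
    return bool(a == b)  # ints, strings, other scalars
```

Note that `np.array_equal` compares values, so `0.0` and `-0.0` compare equal and `NaN` never equals itself. If you need literal bit identity, compare `a.tobytes() == b.tobytes()` after the dtype/shape check.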
Block E: continuation test (the deeper check)
The real reload test: continue training, verify the resumed run matches the never-checkpointed run step-for-step.
```python
import numpy as np


def test_resume_matches_continuous(tmp_path):
    # Reference: train 20 steps without checkpointing
    m_ref, o_ref, s_ref, d_ref, r_ref = build_fresh_with_seed(42)
    losses_ref = train_n_steps(m_ref, o_ref, s_ref, d_ref, r_ref, n=20, return_losses=True)

    # Test: train 10, checkpoint, fresh-load, train 10 more
    m, o, s, d, r = build_fresh_with_seed(42)
    losses_a = train_n_steps(m, o, s, d, r, n=10, return_losses=True)
    save_checkpoint(m, o, s, d, r, tmp_path, step=10, epoch=0, extra_metrics={})
    m2, o2, s2, d2, r2 = build_fresh_with_seed(42)
    apply_checkpoint(load_checkpoint(tmp_path), m2, o2, s2, d2, r2)
    losses_b = train_n_steps(m2, o2, s2, d2, r2, n=10, return_losses=True)

    # Concatenate (np.concatenate works for lists or arrays; `+` would add
    # arrays elementwise) and assert exact equality
    np.testing.assert_array_equal(losses_ref, np.concatenate([losses_a, losses_b]))
```
This is the hard test. Passing Block D's state-equality test is necessary but not sufficient: Block E catches state machines that are equal in serialized form but behave differently after reload (e.g., an RNG whose state is restored to position 7 but whose next draw comes from position 0 because the generator was rebuilt from the seed instead of from the saved state).
- If Block E fails but Block D passes, you have a state-machine restore bug. Common cause: rebuilding the RNG with the seed instead of restoring the bit_generator state.
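The continuation check can be seen in miniature with a self-contained toy. Everything in it is hypothetical (a linear-regression "model", plain SGD, an RNG that picks minibatches, and a JSON string standing in for checkpoint files). Restoring `bit_generator.state` reproduces the uninterrupted loss curve exactly; reseeding replays minibatches 1 to 10 and the curve diverges.

```python
import json
from typing import Optional

import numpy as np

# Hypothetical toy problem: fixed linear-regression data shared by every run.
_data_rng = np.random.default_rng(0)
X = _data_rng.normal(size=(64, 3))
Y = X @ np.array([1.0, -2.0, 0.5])


def build_fresh_with_seed(seed: int) -> dict:
    """Toy stand-in for the lab helper: weights plus the minibatch RNG."""
    return {"w": np.zeros(3), "rng": np.random.default_rng(seed)}


def train_n_steps(state: dict, n: int, lr: float = 0.1) -> list:
    losses = []
    for _ in range(n):
        idx = state["rng"].integers(0, len(X), size=8)  # RNG-driven minibatch
        xb, yb = X[idx], Y[idx]
        err = xb @ state["w"] - yb
        losses.append(float(np.mean(err ** 2)))
        state["w"] = state["w"] - lr * (2.0 * xb.T @ err / len(idx))
    return losses


def save(state: dict) -> str:
    # Python floats round-trip exactly through json; PCG64's state dict holds plain ints.
    return json.dumps({"w": state["w"].tolist(), "rng": state["rng"].bit_generator.state})


def load_into(state: dict, blob: str, reseed: Optional[int] = None) -> None:
    saved = json.loads(blob)
    state["w"] = np.array(saved["w"], dtype=np.float64)
    if reseed is None:
        state["rng"].bit_generator.state = saved["rng"]  # correct: restore position
    else:
        state["rng"] = np.random.default_rng(reseed)     # bug: back to position 0


def resume_losses(seed: int, reseed: Optional[int] = None) -> list:
    s = build_fresh_with_seed(seed)
    first = train_n_steps(s, 10)
    blob = save(s)
    s2 = build_fresh_with_seed(seed)
    load_into(s2, blob, reseed=reseed)
    return first + train_n_steps(s2, 10)


reference = train_n_steps(build_fresh_with_seed(42), 20)
np.testing.assert_array_equal(reference, resume_losses(42))  # exact match
```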
Block F: weight tying test
```python
def test_weight_tying_preserved(tmp_path):
    out_dir = tmp_path
    model = build_minigpt(tie_weights=True)
    assert id(model.embed.weight.data) == id(model.lm_head.weight.data)
    save_checkpoint(model, ..., out_dir)

    # File size sanity: tied storage counted once; the slack allows for the header
    weight_file = out_dir / "weights.safetensors"
    expected_size = sum(p.numel() * 4 for p in unique_params(model))  # float32 = 4 bytes
    actual_size = weight_file.stat().st_size
    assert abs(actual_size - expected_size) < 4096, \
        f"file size {actual_size} != expected {expected_size}; tying broken?"

    # Reload and re-verify
    model2 = build_minigpt(tie_weights=True)
    apply_checkpoint(load_checkpoint(out_dir), model2, ...)
    assert id(model2.embed.weight.data) == id(model2.lm_head.weight.data)
```
- On-disk size must reflect tied (not double) storage.
- After reload, both tied parameters point to the same underlying array (`id(...)` equality).
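The size check is plain arithmetic. A safetensors file is an 8-byte header length, a JSON header, then the raw tensor bytes, so the file should exceed the distinct-tensor byte count only by the header, which is presumably what the 4096-byte slack absorbs. For a float32 embedding of shape (V, d), tying saves V·d·4 bytes. A sketch of the count, with a hypothetical helper that, like `unique_params`, counts each array object once (it uses `nbytes`, so non-float32 dtypes are also handled):

```python
import numpy as np


def unique_param_bytes(named_params: dict[str, np.ndarray]) -> int:
    """Sum storage over distinct array objects; a tied array counts once."""
    seen: set[int] = set()
    total = 0
    for arr in named_params.values():
        if id(arr) not in seen:
            seen.add(id(arr))
            total += arr.nbytes
    return total


emb = np.zeros((100, 16), dtype=np.float32)  # V=100, d=16 -> 6400 bytes
ln = np.zeros(16, dtype=np.float32)          # 64 bytes
tied = {"embed.weight": emb, "lm_head.weight": emb, "ln.weight": ln}
untied = {"embed.weight": emb, "lm_head.weight": emb.copy(), "ln.weight": ln}
print(unique_param_bytes(tied), unique_param_bytes(untied))  # 6464 12864
```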
Block G: atomic write test
```python
from unittest.mock import patch

import pytest


def test_atomic_write_failure_resilience(tmp_path):
    out_dir = tmp_path
    model = build_minigpt(tie_weights=True)
    # Simulate an interrupted write. Assumes save_checkpoint calls os.rename through
    # the os module; patch() cannot intercept `from os import rename`.
    with patch("os.rename", side_effect=KeyboardInterrupt):
        with pytest.raises(KeyboardInterrupt):
            save_checkpoint(model, ..., out_dir)
    # The final file should not exist
    assert not (out_dir / "weights.safetensors").exists()
    # The temp file may exist; that's fine, it's our diagnostic.
    # On the next save attempt, the temp file must be cleaned up.
    save_checkpoint(model, ..., out_dir)  # this time os.rename works
    assert (out_dir / "weights.safetensors").exists()
    assert not (out_dir / "weights.safetensors.tmp").exists()
```
- `save_checkpoint()` detects a leftover `*.tmp` from a previous interrupted save and removes it.
- `save_checkpoint()` writes to `*.tmp`, then atomically renames, so it never leaves a half-written `weights.safetensors`.
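Both rules fit in a small standard-library helper. This is a sketch with a hypothetical name, assuming POSIX semantics, where `os.rename` atomically replaces an existing file on the same filesystem; on Windows `os.rename` refuses to overwrite and `os.replace` is the portable choice. A fully crash-safe version would also fsync the containing directory after the rename.

```python
import hashlib
import os
from pathlib import Path


def atomic_write_bytes(path: Path, data: bytes) -> str:
    """Write data to path via <path>.tmp + os.rename; return the sha256 hex digest."""
    tmp = path.with_name(path.name + ".tmp")
    if tmp.exists():
        tmp.unlink()  # leftover from an earlier interrupted save
    with open(tmp, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())  # bytes reach the disk before the rename publishes them
    os.rename(tmp, path)      # readers see the old file or the new one, never half
    return hashlib.sha256(data).hexdigest()
```

The returned digest equals `sha256` of the file on disk, so for `weights.safetensors` it can serve as the checkpoint hash that Block A returns.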
Block H: record results
`experiments/18-checkpoint-roundtrip/roundtrip_results.json`:
```json
{
  "byte_equivalent_state_machines": ["model", "optimizer", "scheduler", "data_iter", "train_rng"],
  "all_byte_equal": true,
  "continuation_test_passed": true,
  "weight_tying_preserved": true,
  "atomic_write_resilient": true,
  "weights_file_size_bytes": 612408,
  "expected_size_bytes_no_tying": 612408,
  "size_delta_pct": 0.0,
  "manifest_schema_version": "1.0"
}
```
Constraints
- `np.array_equal`, not `np.allclose`. Bit-identical only.
- No pickle. Use `safetensors` for ndarrays, JSON for everything else.
- All saves are atomic via `*.tmp` + `os.rename`.
Stop conditions
Done when:
- All six tests pass.
- `roundtrip_results.json` shows `all_byte_equal: true` and `continuation_test_passed: true`.
- The weights file size matches the unique-param sum (i.e., weight tying is preserved).
- You can describe the difference between Block D (state equality) and Block E (continuation match) and why both are needed.
Pitfalls
- `np.allclose` instead of `np.array_equal`. `allclose` hides a 1-ULP rounding error; you'll think you have a byte-equivalent reload, but you don't.
- Restoring the RNG by reseeding. `rng.bit_generator.state = saved_state` is the correct restore; `rng = np.random.default_rng(seed)` resets the RNG to position 0.
- JSON-serializing the RNG state directly. Whether `json.dumps(rng.bit_generator.state)` works depends on the bit generator: MT19937 stores its key as a numpy `uint32` array, which `json.dumps` rejects, and PCG64 (the `default_rng` default) holds 128-bit integers that many JSON readers outside Python cannot represent exactly. Convert via the `state` dict and base64 (or stringify) the relevant fields.
- Forgetting to save the optimizer step counter `t`. On reload, the bias corrections \(\hat m_t = m_t / (1 - \beta_1^t)\) and \(\hat v_t = v_t / (1 - \beta_2^t)\) use the wrong `t`, so the next step's update is wrong.
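A minimal sketch of the "convert and base64" advice, assuming NumPy's `Generator`/`BitGenerator` API. It walks the `state` dict, base64-encodes arrays (with dtype and shape) and stores integers as decimal strings, then restores the decoded dict in place. The helper names and the `__ndarray_b64__`/`__int__` tags are hypothetical conventions of this sketch.

```python
import base64
import json

import numpy as np


def _to_jsonable(obj):
    """Recursively encode a bit_generator.state dict into JSON-safe values."""
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, np.ndarray):
        # e.g. MT19937's 624-word key: raw bytes + dtype + shape
        return {"__ndarray_b64__": base64.b64encode(obj.tobytes()).decode("ascii"),
                "dtype": obj.dtype.str, "shape": list(obj.shape)}
    if isinstance(obj, (int, np.integer)) and not isinstance(obj, bool):
        # e.g. PCG64's 128-bit state/inc: decimal strings survive any JSON reader
        return {"__int__": str(int(obj))}
    return obj  # strings such as "PCG64"


def _from_jsonable(obj):
    if isinstance(obj, dict):
        if "__ndarray_b64__" in obj:
            raw = base64.b64decode(obj["__ndarray_b64__"])
            arr = np.frombuffer(raw, dtype=np.dtype(obj["dtype"]))
            return arr.reshape(obj["shape"]).copy()  # copy: frombuffer is read-only
        if "__int__" in obj:
            return int(obj["__int__"])
        return {k: _from_jsonable(v) for k, v in obj.items()}
    return obj


def rng_state_to_json(rng: np.random.Generator) -> str:
    return json.dumps(_to_jsonable(rng.bit_generator.state))


def restore_rng_from_json(rng: np.random.Generator, text: str) -> None:
    rng.bit_generator.state = _from_jsonable(json.loads(text))  # in place, not a reseed
```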
When to consult solutions/
After all tests pass. Solution at solutions/02-checkpoint-roundtrip-ref.md (written at phase open).
Next lab: lab/03-mp-drift.md.