Lab 02 — Checkpoint roundtrip: save, reload, forward, assert byte-equivalence

Goal: prove the checkpoint format is sufficient for byte-equivalent resume of all five state machines.

Estimated time: 60–90 minutes.

Prereq: lab 01 produced a trained checkpoint.


What you produce

  • src/minitrain/checkpoint.py — atomic save + load for weights, optimizer, scheduler, data iterator, RNG.
  • tests/minitrain/test_checkpoint.py — round-trip tests.
  • experiments/18-checkpoint-roundtrip/:
      • manifest.json
      • roundtrip_results.json: byte-equivalence assertions per state machine, plus file-size measurements

Background you must have read

  • theory/04-checkpoints-and-mlflow.md — safetensors, the manifest, weight tying, atomic writes.
  • theory/00-motivation.md §"The five state machines" — what must round-trip.

TODOs

Block A — the safetensors writer

def save_checkpoint(
    model, optimizer, scheduler, data_iter, train_rng,
    out_dir: Path,
    step: int, epoch: int,
    extra_metrics: dict,
) -> str:
    """Returns the checkpoint hash (sha256 of weights file)."""
  • Write weights to <out_dir>/weights.safetensors.tmp, then atomic-rename to weights.safetensors.
  • Detect weight tying: if id(model.embed.weight.data) == id(model.lm_head.weight.data), save the tensor only once (under embed.weight) and record lm_head.weight: ALIAS(embed.weight) in the safetensors header metadata.
  • Write optimizer state (m, v, t per param) to optim.safetensors.
  • Write scheduler state, data-iterator state, and RNG state to extra.json.
  • Write manifest.json with the full state dump (seed, versions, config_hash, data_manifest_hash, git_sha, hardware, step, epoch, extra_metrics).
  • Compute and return sha256(weights.safetensors) for the checkpoint hash.
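A minimal sketch of the tying check in Block A. It assumes the model's parameters can be collected into an ordered name-to-array dict with embed.weight listed before lm_head.weight; split_tied is a hypothetical helper, not part of minitrain. It deduplicates by object identity, keeps the first name as the owner of the data, and records every later name as an ALIAS(...) string, which fits the string-to-string metadata that safetensors headers allow.

```python
from typing import Dict, Tuple

import numpy as np

def split_tied(named: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, str]]:
    """Keep the first name seen for each distinct array object; record later names as aliases."""
    unique: Dict[str, np.ndarray] = {}
    aliases: Dict[str, str] = {}
    owner_by_id: Dict[int, str] = {}
    for name, arr in named.items():            # dict order decides which name owns the data
        owner = owner_by_id.get(id(arr))
        if owner is None:
            owner_by_id[id(arr)] = name
            unique[name] = arr
        else:
            aliases[name] = f"ALIAS({owner})"  # str -> str, as safetensors metadata requires
    return unique, aliases

# Hypothetical tied model: lm_head.weight is the very same array object as embed.weight.
embed = np.zeros((4, 3), dtype=np.float32)
params = {"embed.weight": embed,
          "blocks.0.w": np.ones((3, 3), dtype=np.float32),
          "lm_head.weight": embed}
unique, aliases = split_tied(params)
```

The two dicts could then go to the safetensors library, for example safetensors.numpy.save_file(unique, out_dir / "weights.safetensors.tmp", metadata=aliases), before the atomic rename.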

Block B — the safetensors loader

def load_checkpoint(in_dir: Path) -> CheckpointState:
    """Returns a CheckpointState with model_weights, optim_state, scheduler_state,
       data_iter_state, train_rng_state, manifest."""
  • Read manifest.json first; validate schema_version.
  • Read weights.safetensors. Resolve any ALIAS(...) references so that both tied names map to the same tensor object after load.
  • Read optim.safetensors and extra.json.
  • Return everything as a structured CheckpointState (dataclass).
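On the load side, a sketch of the CheckpointState dataclass and of alias resolution. Field names follow the load_checkpoint docstring; the field types, the resolve_aliases helper, and the exact "ALIAS(<name>)" string format are assumptions. The tensors argument stands for whatever the safetensors loader returned, and metadata for the header metadata dict (with the safetensors library, safe_open(path, framework="np").metadata()).

```python
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

_ALIAS = re.compile(r"^ALIAS\((.+)\)$")   # assumed metadata format: "ALIAS(<target name>)"

@dataclass
class CheckpointState:
    # Field names from load_checkpoint's docstring; the types are assumptions.
    model_weights: Dict[str, np.ndarray]
    optim_state: Dict[str, Any]
    scheduler_state: Dict[str, Any]
    data_iter_state: Dict[str, Any]
    train_rng_state: Dict[str, Any]
    manifest: Dict[str, Any] = field(default_factory=dict)

def resolve_aliases(tensors: Dict[str, np.ndarray],
                    metadata: Optional[Dict[str, str]]) -> Dict[str, np.ndarray]:
    """Map every aliased name to the *same* array object as its target (no copies)."""
    out = dict(tensors)
    for name, value in (metadata or {}).items():
        match = _ALIAS.match(value)
        if match is None:
            continue                      # ordinary metadata entry, not an alias
        target = match.group(1)
        if target not in tensors:
            raise KeyError(f"{name!r} aliases missing tensor {target!r}")
        out[name] = tensors[target]       # share the object so id(...) equality holds
    return out
```

Resolution only makes the two dict entries share one array; whether the reloaded model's parameters end up sharing it also depends on how its load_state_dict assigns or copies them.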

Block C — apply checkpoint to live objects

def apply_checkpoint(state: CheckpointState, model, optimizer, scheduler, data_iter, train_rng) -> None:
    """Restores every state machine in-place."""
  • model.load_state_dict(state.model_weights).
  • optimizer.load_state_dict(state.optim_state).
  • scheduler.load_state_dict(state.scheduler_state).
  • data_iter.load_state_dict(state.data_iter_state).
  • train_rng.bit_generator.state = state.train_rng_state.
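The difference between restoring the bit-generator state and reseeding can be checked directly with NumPy's Generator API. This uses only real NumPy calls; the draw counts (7 consumed, 3 compared) are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)
rng.random(7)                           # advance the stream: 7 draws consumed
saved = rng.bit_generator.state         # snapshot, returned as a plain dict
expected_next = rng.random(3)           # what the uninterrupted run draws next

# Correct restore: any fresh generator, then overwrite its bit-generator state.
restored = np.random.default_rng(0)     # seed is irrelevant; the state is overwritten
restored.bit_generator.state = saved
resumed_next = restored.random(3)

# Wrong restore: reseeding rewinds to the start of the stream.
reseeded_next = np.random.default_rng(42).random(3)
```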

Block D — round-trip test

The key test:

def test_byte_equivalent_roundtrip(tmp_path):
    out_dir = tmp_path  # pytest's per-test temporary directory

    # Live state
    model, optim, sched, data_iter, rng = build_fresh()
    train_n_steps(model, optim, sched, data_iter, rng, n=10)

    state_before = capture_full_state(model, optim, sched, data_iter, rng)
    save_checkpoint(model, optim, sched, data_iter, rng, out_dir, step=10, epoch=0, extra_metrics={})

    # Fresh objects, reload
    model2, optim2, sched2, data_iter2, rng2 = build_fresh()
    state = load_checkpoint(out_dir)
    apply_checkpoint(state, model2, optim2, sched2, data_iter2, rng2)
    state_after = capture_full_state(model2, optim2, sched2, data_iter2, rng2)

    # Same set of captured state machines, then exact equality for each one
    assert state_before.keys() == state_after.keys()
    for key in state_before:
        assert byte_equal(state_before[key], state_after[key]), f"mismatch at {key}"
  • byte_equal(a, b): for ndarrays uses np.array_equal (not np.allclose); for ints/strings uses ==; for RNG states compares dicts.
  • All five state machines must pass.
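A sketch of byte_equal, assuming captured state is built from ndarrays, NumPy or Python scalars, strings, lists/tuples and dicts. It is a deliberately stricter variant of the np.array_equal rule above: it also compares dtype and raw bytes, because np.array_equal treats 0.0 and -0.0 as equal and reports NaN as unequal to itself. It also treats int vs np.int64 as a mismatch, so normalize types in capture_full_state if a JSON round trip turns NumPy ints into Python ints.

```python
from typing import Any

import numpy as np

def byte_equal(a: Any, b: Any) -> bool:
    """Recursive, bit-level equality for captured training state."""
    # Promote NumPy scalars and Python floats to 0-d arrays so they take the raw-bytes path.
    if isinstance(a, (np.generic, float)):
        a = np.asarray(a)
    if isinstance(b, (np.generic, float)):
        b = np.asarray(b)
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        if not (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)):
            return False
        # dtype + shape + raw bytes: distinguishes 0.0 from -0.0 and accepts identical NaNs
        return a.dtype == b.dtype and a.shape == b.shape and a.tobytes() == b.tobytes()
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(byte_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return (type(a) is type(b) and len(a) == len(b)
                and all(byte_equal(x, y) for x, y in zip(a, b)))
    # ints, strings, bools, None: exact type and value
    return type(a) is type(b) and a == b
```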

Block E — continuation test (the deeper check)

The real reload test: continue training, verify the resumed run matches the never-checkpointed run step-for-step.

import numpy as np

def test_resume_matches_continuous(tmp_path):
    # Reference: train 20 steps without checkpointing
    m_ref, o_ref, s_ref, d_ref, r_ref = build_fresh_with_seed(42)
    losses_ref = train_n_steps(m_ref, o_ref, s_ref, d_ref, r_ref, n=20, return_losses=True)

    # Test: train 10, checkpoint, fresh-load, train 10 more
    m, o, s, d, r = build_fresh_with_seed(42)
    losses_a = train_n_steps(m, o, s, d, r, n=10, return_losses=True)
    save_checkpoint(m, o, s, d, r, tmp_path, step=10, epoch=0, extra_metrics={})
    m2, o2, s2, d2, r2 = build_fresh_with_seed(42)
    apply_checkpoint(load_checkpoint(tmp_path), m2, o2, s2, d2, r2)
    losses_b = train_n_steps(m2, o2, s2, d2, r2, n=10, return_losses=True)

    # Concatenate explicitly (+ would add element-wise if the losses are ndarrays),
    # then assert exact equality step for step
    np.testing.assert_array_equal(losses_ref, np.concatenate([losses_a, losses_b]))

This is the hard test. Passing Block D's state-equality test is necessary but not sufficient: Block E catches state machines that are equal in serialized form but behave differently after reload. For example, the RNG state is restored to position 7, but the next draw comes from position 0 because some component rebuilt its generator from the seed instead of using the restored state.

  • If Block E fails but Block D passes, you have a state-machine restore bug. Common cause: rebuilding the RNG with the seed instead of restoring the bit_generator state.
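To see what Block E checks, here is a self-contained toy analogue: a hypothetical one-parameter trainer (fresh, train, checkpoint and restore are illustrative names, not minitrain's API) with Adam moments, a step counter t, and a NumPy RNG that supplies noisy targets. Its checkpoint is a JSON string holding the floats, t, and bit_generator.state; this assumes the default PCG64 bit generator, whose state dict holds plain Python ints that json.dumps accepts.

```python
import json
import math

import numpy as np

def fresh(seed: int) -> dict:
    # Hypothetical toy trainer state: one scalar weight, Adam moments and step, an RNG.
    return {"w": 3.0, "m": 0.0, "v": 0.0, "t": 0, "rng": np.random.default_rng(seed)}

def train(s: dict, n: int, lr=0.1, b1=0.9, b2=0.999, eps=1e-8) -> list:
    """Run n noisy Adam steps in place and return the per-step losses."""
    losses = []
    for _ in range(n):
        target = s["rng"].normal()          # stochastic "data" drawn from the RNG
        losses.append((s["w"] - target) ** 2)
        g = 2.0 * (s["w"] - target)
        s["t"] += 1
        s["m"] = b1 * s["m"] + (1 - b1) * g
        s["v"] = b2 * s["v"] + (1 - b2) * g * g
        m_hat = s["m"] / (1 - b1 ** s["t"])
        v_hat = s["v"] / (1 - b2 ** s["t"])
        s["w"] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return losses

def checkpoint(s: dict) -> str:
    # Python floats round-trip exactly through json (shortest repr).
    state = {k: s[k] for k in ("w", "m", "v", "t")}
    state["rng"] = s["rng"].bit_generator.state
    return json.dumps(state)

def restore(blob: str, s: dict) -> None:
    state = json.loads(blob)
    s["rng"].bit_generator.state = state.pop("rng")  # restore the position, do not reseed
    s.update(state)

ref = fresh(42)
losses_ref = train(ref, 20)                 # never checkpointed

a = fresh(42)
losses_a = train(a, 10)
blob = checkpoint(a)
b = fresh(42)
restore(blob, b)
losses_b = train(b, 10)                     # resumed run

print(losses_ref == losses_a + losses_b)    # True
```

Skipping the RNG part of restore (so the generator stays freshly seeded) reproduces the failure described above: w, m, v and t compare equal, but the next ten losses differ.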

Block F — weight tying test

def test_weight_tying_preserved(tmp_path):
    out_dir = tmp_path
    model = build_minigpt(tie_weights=True)
    assert id(model.embed.weight.data) == id(model.lm_head.weight.data)
    save_checkpoint(model, ..., out_dir)
    # File size sanity
    weight_file = out_dir / "weights.safetensors"
    # Data bytes of each distinct parameter (4 bytes per float32 element); the slack below covers the header
    expected_size = sum(p.numel() * 4 for p in unique_params(model))
    assert abs(weight_file.stat().st_size - expected_size) < 4096, \
        f"file size {weight_file.stat().st_size} != expected {expected_size}; tying broken?"
    # Reload and re-verify
    model2 = build_minigpt(tie_weights=True)
    apply_checkpoint(load_checkpoint(out_dir), model2, ...)
    assert id(model2.embed.weight.data) == id(model2.lm_head.weight.data)
  • On-disk size must reflect tied (not double) storage.
  • After reload, both tied parameters point to the same underlying array (id(...) equality).
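The expected data size can be computed from distinct array objects rather than from parameter names. A sketch assuming parameters are available as a name-to-array dict (unique_nbytes and the toy shapes are hypothetical); using nbytes instead of numel() * 4 also avoids hard-coding float32. A safetensors file adds an 8-byte header-length prefix and a JSON header on top of these data bytes, which is what the test's 4096-byte slack absorbs.

```python
from typing import Dict

import numpy as np

def unique_nbytes(named: Dict[str, np.ndarray]) -> int:
    """Sum nbytes over distinct array objects, so a tied tensor counts once."""
    seen: Dict[int, np.ndarray] = {}
    for arr in named.values():
        seen.setdefault(id(arr), arr)       # arrays stay referenced, so ids stay valid
    return sum(a.nbytes for a in seen.values())

vocab, d_model = 100, 16
embed = np.zeros((vocab, d_model), dtype=np.float32)
params = {"embed.weight": embed, "lm_head.weight": embed,
          "ln_f.weight": np.ones(d_model, dtype=np.float32)}
tied_bytes = unique_nbytes(params)                      # 100*16*4 + 16*4 = 6464
untied_bytes = sum(a.nbytes for a in params.values())   # 6464 + 6400 = 12864
```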

Block G — atomic write test

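import pytest
from unittest.mock import patch

def test_atomic_write_failure_resilience(tmp_path):
    out_dir = tmp_path
    model = build_minigpt(tie_weights=True)
    # Simulate an interrupted write. patch() only intercepts the rename if
    # checkpoint.py calls it as os.rename(...), not via "from os import rename".
    with patch("os.rename", side_effect=KeyboardInterrupt):
        with pytest.raises(KeyboardInterrupt):
            save_checkpoint(model, ..., out_dir)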

    # The final file should not exist
    assert not (out_dir / "weights.safetensors").exists()
    # The temp file may exist; that's fine, it's our diagnostic
    # On next save attempt, the temp must be cleaned up.
    save_checkpoint(model, ..., out_dir)  # this time os.rename works
    assert (out_dir / "weights.safetensors").exists()
    assert not (out_dir / "weights.safetensors.tmp").exists()
  • save_checkpoint() detects leftover *.tmp files from a previous interrupted save and removes them.
  • save_checkpoint() writes to *.tmp, then atomically renames; it never leaves a half-written weights.safetensors.
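A minimal sketch of that write path using only the standard library. atomic_write_bytes and save_files are hypothetical names; in save_checkpoint the bytes would come from the safetensors and JSON serializers. It calls os.rename through the os module, so the test's patch("os.rename", ...) intercepts it. On POSIX, os.rename within one filesystem atomically replaces an existing file; on Windows it raises if the destination exists, and os.replace is the portable alternative.

```python
import os
import tempfile
from pathlib import Path
from typing import Dict

def atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write to <path>.tmp, fsync, then rename over the final name."""
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())   # make the bytes durable before the rename publishes them
    os.rename(tmp, path)       # atomic replace on POSIX; raises on Windows if path exists

def save_files(out_dir: Path, files: Dict[str, bytes]) -> None:
    """Clear leftovers from an interrupted save, then write each file atomically."""
    out_dir.mkdir(parents=True, exist_ok=True)
    for stale in out_dir.glob("*.tmp"):
        stale.unlink()
    for name, data in files.items():
        atomic_write_bytes(out_dir / name, data)

demo_dir = Path(tempfile.mkdtemp())
(demo_dir / "optim.safetensors.tmp").write_bytes(b"half-written")  # leftover from a crash
save_files(demo_dir, {"weights.safetensors": b"\x00" * 16})
```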

Block H — record results

experiments/18-checkpoint-roundtrip/roundtrip_results.json:

{
  "byte_equivalent_state_machines": ["model", "optimizer", "scheduler", "data_iter", "train_rng"],
  "all_byte_equal": true,
  "continuation_test_passed": true,
  "weight_tying_preserved": true,
  "atomic_write_resilient": true,
  "weights_file_size_bytes": 612408,
  "expected_size_bytes_no_tying": 612408,
  "size_delta_pct": 0.0,
  "manifest_schema_version": "1.0"
}

Constraints

  • np.array_equal, not np.allclose. Bit-identical only.
  • No pickle. Use safetensors for ndarrays, JSON for everything else.
  • All saves are atomic via *.tmp + os.rename.
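Under the no-pickle rule, RNG state has to travel as JSON. A sketch assuming ndarray fields are tagged as {"__ndarray__": <base64 bytes>, "dtype": ..., "shape": ...}; that tag is an invented convention, not a NumPy format. NumPy integer scalars are converted to Python ints on the way out. The helpers (rng_state_to_json, rng_state_from_json) are hypothetical; they round-trip MT19937, whose state holds a uint32 key array, as well as PCG64.

```python
import base64
import json
from typing import Any

import numpy as np

def _encode(obj: Any) -> Any:
    # Recursively replace ndarrays with a tagged, JSON-safe dict.
    if isinstance(obj, np.ndarray):
        return {"__ndarray__": base64.b64encode(obj.tobytes()).decode("ascii"),
                "dtype": obj.dtype.str, "shape": list(obj.shape)}
    if isinstance(obj, dict):
        return {k: _encode(v) for k, v in obj.items()}
    if isinstance(obj, np.integer):
        return int(obj)
    return obj

def _decode(obj: Any) -> Any:
    if isinstance(obj, dict):
        if "__ndarray__" in obj:
            raw = base64.b64decode(obj["__ndarray__"])
            arr = np.frombuffer(raw, dtype=np.dtype(obj["dtype"]))
            return arr.reshape(obj["shape"]).copy()   # writable copy, original dtype
        return {k: _decode(v) for k, v in obj.items()}
    return obj

def rng_state_to_json(rng: np.random.Generator) -> str:
    return json.dumps(_encode(rng.bit_generator.state))

def rng_state_from_json(rng: np.random.Generator, blob: str) -> None:
    rng.bit_generator.state = _decode(json.loads(blob))
```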

Stop conditions

Done when:

  1. All six tests pass.
  2. roundtrip_results.json shows all_byte_equal: true and continuation_test_passed: true.
  3. The weights file size matches the unique-param sum (i.e., weight tying preserved).
  4. You can describe the difference between Block D (state equality) and Block E (continuation match) and why both are needed.

Pitfalls

  • np.allclose instead of np.array_equal. Allclose hides a 1-ULP rounding error; you'll think you have byte-equivalent reload but you don't.
  • Restoring the RNG by reseeding. rng.bit_generator.state = saved_state is the correct restore; rng = np.random.default_rng(seed) resets the RNG to position 0.
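  • JSON-serializing the RNG state naively. rng.bit_generator.state is a dict whose contents depend on the bit generator; it can contain ndarrays (MT19937 stores its key as a uint32 array), which json.dumps rejects. Encode any ndarray fields explicitly (e.g., base64 of the raw bytes plus dtype and shape) and decode them on load.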
  • Forgetting to save the optimizer step counter t. On reload, the bias corrections \(\hat m_t = m_t / (1 - \beta_1^t)\) and \(\hat v_t = v_t / (1 - \beta_2^t)\) use the wrong t, so the next step's update is wrong.
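A worked example of the step-counter pitfall on a scalar Adam update. adam_update is a hypothetical helper, and the usual defaults lr=1e-3, β1=0.9, β2=0.999 are assumed. With a constant gradient of 1, the bias-corrected moments stay at about 1 after any number of steps, so the correct 11th update is about -lr. Restoring m and v but resetting t to 0 divides by 1 - β1 = 0.1 and 1 - β2 = 0.001 instead, which makes the update roughly twice as large.

```python
import math

def adam_update(m: float, v: float, t: int, g: float,
                lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step on a scalar; returns (delta, m, v, t) after the step."""
    t += 1
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)      # bias correction depends on t
    v_hat = v / (1 - b2 ** t)
    return -lr * m_hat / (math.sqrt(v_hat) + eps), m, v, t

# Ten steps with a constant gradient, as if checkpointed at step 10.
m = v = 0.0
t = 0
for _ in range(10):
    _, m, v, t = adam_update(m, v, t, g=1.0)

good, *_ = adam_update(m, v, t, g=1.0)   # t restored: corrections use t = 11
bad, *_ = adam_update(m, v, 0, g=1.0)    # t forgotten: corrections use t = 1
```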

When to consult solutions/

After all tests pass. Solution at solutions/02-checkpoint-roundtrip-ref.md (written at phase open).


Next lab: lab/03-mp-drift.md.