
04 — Checkpoints (safetensors) + mlflow as a manifest wrapper

Here we decide how to persist the model and how to keep a record of what we did. Spoiler: safetensors (not pickle, which is RCE); manifest.json stays the source of truth; mlflow is a navigation layer on top.


Two questions this file answers:

  1. How do we persist a checkpoint such that reloading is bit-identical to the live state?
  2. How do we track many such checkpoints across many runs without giving up the manifest discipline of Phase 0?

Pickle is RCE. Safetensors is not. End.

Python's pickle is a serialization format that can execute arbitrary code on load. The bytes of a pickle file are a program for a small stack machine, and pickle.load(f) runs that program. A malicious pickle file can call os.system("rm -rf $HOME") the instant you load it. This is not hypothetical: for years torch.load defaulted to full pickle deserialization, and there are documented cases of model checkpoints on public model registries that hijack the loader.
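A harmless sketch of the mechanism: any class can define `__reduce__`, pickle records the callable and arguments it returns, and the unpickler calls them at load time. Here the callable is `print` instead of `os.system`; the class name `Payload` is purely illustrative.

```python
import io
import pickle
from contextlib import redirect_stdout


class Payload:
    """Any class can tell pickle how to rebuild it by returning (callable, args)."""

    def __reduce__(self):
        # An attacker would return (os.system, ("rm -rf $HOME",)).
        # print is a harmless stand-in; the mechanism is identical.
        return (print, ("code executed during unpickling",))


blob = pickle.dumps(Payload())   # the stream records "call builtins.print with these args"

buf = io.StringIO()
with redirect_stdout(buf):
    result = pickle.loads(blob)  # the loader itself calls print(...)

# The "loaded object" is whatever the callable returned (None for print),
# and the loading side never needed the Payload class at all.
assert buf.getvalue() == "code executed during unpickling\n"
assert result is None
```

Note that nothing about the file looks suspicious from the outside: the payload is just a reference to an importable callable plus its arguments.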

security/supply-chain.md already states the rule. Phase 18 is where the rule first applies:

Borja never persists model weights via pickle. Never. Safetensors only.

safetensors (Hugging Face) is a flat binary format: an 8-byte header length, a JSON header describing each tensor's name, dtype, shape, and byte offsets, then the raw bytes of the tensor data. Loading it is a parse plus a memmap. There is no code execution, no metaclass invocation, no __reduce__ shenanigans. The core is written in Rust with thin Python bindings, and the failure modes are well understood: a truncated file, or a corrupt header length or JSON header.
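To make "a parse plus a memmap" concrete, here is a hand-rolled encoder/decoder following the published layout (8-byte little-endian header length, JSON header with `dtype`/`shape`/`data_offsets`, then the raw buffer). It is a teaching sketch under simplifying assumptions: a deliberately tiny dtype table, copies instead of memmap, and hypothetical function names `encode`/`decode`. Real code should use the library's `safetensors.numpy.save_file` / `load_file`.

```python
import json
import struct

import numpy as np

# Minimal dtype table for this sketch; the real format defines more (F16, BF16, I32, ...).
TO_TAG = {np.dtype("<f4"): "F32", np.dtype("<f8"): "F64", np.dtype("<i8"): "I64"}
FROM_TAG = {tag: dt for dt, tag in TO_TAG.items()}


def encode(tensors: dict) -> bytes:
    """Serialize {name: array} as: u64 header length | JSON header | raw bytes."""
    header, chunks, offset = {}, [], 0
    for name, arr in tensors.items():
        arr = np.ascontiguousarray(arr, dtype=arr.dtype.newbyteorder("<"))
        raw = arr.tobytes()
        header[name] = {
            "dtype": TO_TAG[arr.dtype],
            "shape": list(arr.shape),
            "data_offsets": [offset, offset + len(raw)],  # relative to the data section
        }
        chunks.append(raw)
        offset += len(raw)
    head = json.dumps(header).encode("utf-8")
    head += b" " * (-len(head) % 8)  # pad with spaces so the data starts 8-byte aligned
    return struct.pack("<Q", len(head)) + head + b"".join(chunks)


def decode(blob: bytes) -> dict:
    """Parse the header, then slice the buffer. Nothing in the file is executed."""
    if len(blob) < 8:
        raise ValueError("truncated: missing header length")
    (n,) = struct.unpack("<Q", blob[:8])
    if 8 + n > len(blob):
        raise ValueError("truncated file or corrupt header length")
    header = json.loads(blob[8 : 8 + n])
    base = 8 + n
    out = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional str -> str map, not a tensor
            continue
        begin, end = info["data_offsets"]
        if base + end > len(blob):
            raise ValueError(f"truncated data for {name!r}")
        dtype = FROM_TAG[info["dtype"]]
        out[name] = np.frombuffer(blob[base + begin : base + end], dtype=dtype).reshape(info["shape"])
    return out


weights = {
    "wte": np.arange(6, dtype=np.float32).reshape(2, 3),
    "step": np.array([2000], dtype=np.int64),
}
blob = encode(weights)
restored = decode(blob)
```

The loader's only decisions are integer comparisons and a JSON parse, which is why a damaged file fails with an error instead of doing something.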

The verb-grammar checkpoint at Phase 18 looks like:

models/minigpt-phase18-<hash>.safetensors    # weights only
models/minigpt-phase18-<hash>.manifest.json  # side-car: config + seed + versions + git_sha
models/minigpt-phase18-<hash>.optim.safetensors  # optimizer state (m, v, t)

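Three files per checkpoint. Atomic write: write each file to *.tmp, flush and fsync it, then os.rename(tmp, final) (or os.replace, which also overwrites an existing target on Windows, where os.rename fails). A partial write must never leave a *.safetensors file that looks valid. The rename is atomic on POSIX as long as tmp and final live on the same filesystem, so an interrupted save leaves only a *.tmp you can detect and clean up.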
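A minimal sketch of the write-then-rename pattern, assuming the tmp file sits next to the final file so both are on the same filesystem. The helper names `atomic_write_bytes` and `stale_tmp_files` and the demo file name are hypothetical.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(final_path, data: bytes) -> None:
    """Write to <final>.tmp, fsync, then rename over <final> in one step."""
    final_path = Path(final_path)
    tmp_path = final_path.with_name(final_path.name + ".tmp")  # same dir => same filesystem
    with open(tmp_path, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())  # bytes are durable before the final name points at them
    os.replace(tmp_path, final_path)  # atomic on POSIX; overwrites an existing target


def stale_tmp_files(directory):
    """Leftovers from an interrupted save: safe to inspect and delete."""
    return sorted(Path(directory).glob("*.tmp"))


with tempfile.TemporaryDirectory() as d:
    target = Path(d) / "minigpt-phase18-demo.safetensors"
    atomic_write_bytes(target, b"payload")
    ok = target.read_bytes() == b"payload"
    leftovers = stale_tmp_files(d)
```

A reader that only ever opens names without the .tmp suffix therefore sees either the old complete file or the new complete file, never a half-written one.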

What goes in manifest.json

This file is the truth about the run. It lives next to the safetensors and contains everything needed to reproduce the model:

{
  "schema_version": "1.0",
  "phase": 18,
  "git_sha": "abcd1234...",
  "config_hash": "sha256 of the resolved config",
  "data_manifest_hash": "sha256 of data/processed/MANIFEST.json",
  "seed": 42,
  "versions": {
    "python": "3.11.x",
    "numpy": "2.x.x",
    "safetensors": "0.4.x",
    "mlflow": "2.x.x"
  },
  "config": { ... resolved training config ... },
  "step": 2000,
  "epoch": 50,
  "metrics": {
    "train_ppl": 4.5,
    "val_ppl": 9.0,
    "ngram_baseline_val_ppl": 13.0
  },
  "hardware": {
    "cpu": "Intel i5-8250U",
    "ram_gb": 62
  },
  "mlflow_run_uri": "mlruns/0/<run_id>"
}

The data_manifest_hash is non-negotiable: it pins which version of the Phase-12 corpus this checkpoint was trained on. If the corpus regenerates with a different seed, the hash changes; old checkpoints can be recognized as trained on a different dataset.
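A sketch of how the hash fields could be filled. It assumes "sha256 of the resolved config" means sha256 over canonical JSON (sorted keys, no whitespace) and that the data manifest is hashed byte-for-byte; all helper names are hypothetical, and hardware and mlflow_run_uri are left out for brevity.

```python
import hashlib
import json
import os
import platform
import tempfile
from importlib import metadata


def sha256_file(path) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB at a time
            h.update(chunk)
    return h.hexdigest()


def config_hash(config: dict) -> str:
    # Canonical JSON: sorted keys, no whitespace, so key order cannot change the hash.
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def pkg_version(name: str):
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None  # record "not installed" rather than guessing


def build_manifest(config, data_manifest_path, seed, step, epoch, metrics, git_sha):
    return {
        "schema_version": "1.0",
        "phase": 18,
        "git_sha": git_sha,  # e.g. the output of `git rev-parse HEAD`
        "config_hash": config_hash(config),
        "data_manifest_hash": sha256_file(data_manifest_path),
        "seed": seed,
        "versions": {"python": platform.python_version(),
                     **{p: pkg_version(p) for p in ("numpy", "safetensors", "mlflow")}},
        "config": config,
        "step": step,
        "epoch": epoch,
        "metrics": metrics,
    }


with tempfile.TemporaryDirectory() as d:
    data_manifest = os.path.join(d, "MANIFEST.json")
    with open(data_manifest, "w", encoding="utf-8") as f:
        f.write('{"corpus": "toy", "seed": 7}')  # toy stand-in for data/processed/MANIFEST.json
    manifest = build_manifest({"lr": 3e-4, "d_model": 64}, data_manifest, seed=42,
                              step=2000, epoch=50, metrics={"val_ppl": 9.0}, git_sha="abcd1234")
```

Canonicalizing before hashing matters: two runs with the same config written in a different key order should get the same config_hash.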

The reload contract

The reload code must guarantee:

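import numpy as np

state_before = capture_full_state()    # weights + optim + scheduler + data iterator + RNG, as arrays
save_checkpoint(state_before, path)
state_after = load_checkpoint(path)

assert state_before.keys() == state_after.keys()
for k in state_before:
    a, b = np.asarray(state_before[k]), np.asarray(state_after[k])
    assert a.dtype == b.dtype and a.shape == b.shape   # array_equal alone ignores dtype
    assert a.tobytes() == b.tobytes()                  # bit-for-bit

This is the byte-equivalence that the DoD calls for. Note: an exact comparison, not np.allclose, and not np.array_equal alone either: array_equal returns True for an fp32 array and its fp64 upcast, treats -0.0 as equal to 0.0, and treats NaN as unequal to itself. Comparing dtype, shape, and raw bytes is bit-for-bit. No rounding, no fp32-vs-fp64 drift. If you cannot pass this test, your checkpoint is not a checkpoint; it's a "starting point".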

The five state machines from theory/00:

  1. Model weights — safetensors.
  2. Optimizer state (\(m_t, v_t, t\)) — separate safetensors file.
  3. Scheduler state (\(t\)) — JSON, in the manifest.
  4. Data-iterator state (epoch, position, RNG seed) — JSON, in the manifest.
  5. Training-control RNG state — JSON, base64-encoded np.random.Generator.bit_generator.state.

All five round-trip. Lab 02 tests it.
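A sketch of item 5, assuming "base64-encoded" means base64 over the JSON form of `bit_generator.state` (a plain dict of strings and ints for PCG64, the default bit generator). The function names are hypothetical; the check is that the restored generator produces exactly the draws the live one would have produced next.

```python
import base64
import json

import numpy as np


def encode_rng_state(rng: np.random.Generator) -> str:
    state = rng.bit_generator.state  # e.g. {"bit_generator": "PCG64", "state": {...}, ...}
    return base64.b64encode(json.dumps(state).encode("utf-8")).decode("ascii")


def restore_rng(encoded: str) -> np.random.Generator:
    state = json.loads(base64.b64decode(encoded))
    bitgen = getattr(np.random, state["bit_generator"])()  # e.g. np.random.PCG64()
    bitgen.state = state  # overwrite the fresh seed with the saved stream position
    return np.random.Generator(bitgen)


rng = np.random.default_rng(42)
rng.random(10)                     # advance the stream, as a training run would
encoded = encode_rng_state(rng)    # goes into the manifest
expected = rng.random(5)           # what the live run draws next
restored_draws = restore_rng(encoded).random(5)
```

Restoring the seed alone would replay the stream from the start; restoring the state resumes it at the exact position where the checkpoint was taken.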

A subtle test for weight tying

MiniGPT (Phase 17) shares the input embedding matrix with the output projection. Two Parameter objects in the Module graph point to the same underlying tensor. Naive checkpoint code does:

for name, param in model.named_parameters():
    save_tensor(name, param.data)

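This writes the tied tensor twice, under two names, so the file carries a redundant copy of the whole matrix: in lab 02's example, 2 MB instead of 1 MB. Worse, if the loader assigns each name its own freshly loaded array, the two parameters receive equal values but no longer share storage; the first optimizer update then makes them diverge silently, and the model is no longer tied.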

The fix: detect tying before save. Maintain an id_to_name map; if id(param.data) is already in the map, save only a reference to the first name (for example as a string entry in the safetensors __metadata__ header, or in the manifest). On reload, resolve the reference to the same array object, not a copy.

The Phase 18 lab 02 tests for this. The expected file size is 1 MB, not 2 MB. If it's 2 MB, weight tying is broken in the checkpoint round-trip.
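A sketch of the id_to_name approach over plain (name, array) pairs. The parameter names `tok_emb.weight`, `lm_head.weight`, `ln_f.weight` and the helpers `dedupe_tied` / `retie` are hypothetical; the alias map is str to str, so it could travel in the safetensors `__metadata__` header.

```python
import numpy as np


def dedupe_tied(named_params):
    """Split (name, array) pairs into unique tensors plus an alias map for tied ones."""
    id_to_name, tensors, aliases = {}, {}, {}
    for name, arr in named_params:
        # id() catches the same array object used twice; two distinct views of one
        # buffer would need np.shares_memory instead.
        first = id_to_name.get(id(arr))
        if first is not None:
            aliases[name] = first  # tied: store a reference, not a second copy
        else:
            id_to_name[id(arr)] = name
            tensors[name] = arr
    return tensors, aliases


def retie(tensors, aliases):
    """After loading, point every alias at the *same* array object as its target."""
    params = dict(tensors)
    for alias, target in aliases.items():
        params[alias] = params[target]
    return params


# Hypothetical MiniGPT-style parameters: embedding tied to the output projection.
emb = np.zeros((512, 64), dtype=np.float32)
named = [("tok_emb.weight", emb), ("lm_head.weight", emb),
         ("ln_f.weight", np.ones(64, dtype=np.float32))]

naive_bytes = sum(a.nbytes for _, a in named)
tensors, aliases = dedupe_tied(named)
unique_bytes = sum(a.nbytes for a in tensors.values())

loaded = {k: v.copy() for k, v in tensors.items()}  # stand-in for reading the file back
restored = retie(loaded, aliases)
```

The identity check on `restored` is the property that matters for training: an in-place update through either name is visible through the other.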

mlflow — what it gives you and what it doesn't

mlflow is a run-tracking tool. Per A8 it lands in Phase 18, the first phase that benefits from cross-run navigation. It gives:

  • A central UI (mlflow ui) showing every run, its config, its metrics, its artifacts.
  • Per-run timeline plots (loss, LR, gradient norm) as the run progresses.
  • Artifact storage — any file logged with mlflow.log_artifact(...) is associated with the run.
  • Comparison views — select two runs, see side-by-side configs and metric curves.

What mlflow does not replace:

  • manifest.json. It's still the source of truth on disk. With the tracking URI used here, mlflow keeps its own metadata in a SQLite DB (mlruns.db); if that DB is corrupted, every run is still findable via the manifest.json files in models/.
  • The seed. Logging the seed with mlflow.log_param inside mlflow.start_run() only records it in mlflow's store; it doesn't make the run reproducible. The seed must be in the manifest, and the training code must actually use it.
  • Determinism guarantees. mlflow records metrics; it doesn't enforce reproducibility.

The rule of thumb: manifest is the truth; mlflow is the index. If they disagree, manifest wins.
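Because the manifests are the truth, an index can be rebuilt from them alone if the mlflow DB is lost. A sketch assuming the `models/minigpt-phase18-<hash>.manifest.json` naming shown earlier; `index_manifests` and `best_by_val_ppl` are hypothetical names, and the demo manifests are toy data.

```python
import json
import tempfile
from pathlib import Path


def index_manifests(models_dir):
    """Rebuild a run index from *.manifest.json side-cars only (no mlflow DB needed)."""
    index = []
    for path in sorted(Path(models_dir).glob("*.manifest.json")):
        m = json.loads(path.read_text(encoding="utf-8"))
        index.append({
            "manifest": path.name,
            "checkpoint": path.name.replace(".manifest.json", ".safetensors"),
            "git_sha": m.get("git_sha"),
            "config_hash": m.get("config_hash"),
            "step": m.get("step"),
            "val_ppl": m.get("metrics", {}).get("val_ppl"),
            "mlflow_run_uri": m.get("mlflow_run_uri"),  # may point at a lost DB; still useful
        })
    return index


def best_by_val_ppl(index):
    scored = [row for row in index if row["val_ppl"] is not None]
    return min(scored, key=lambda row: row["val_ppl"]) if scored else None


with tempfile.TemporaryDirectory() as d:
    for h, ppl in (("aaaa", 9.0), ("bbbb", 10.5)):
        (Path(d) / f"minigpt-phase18-{h}.manifest.json").write_text(
            json.dumps({"git_sha": "abcd1234", "step": 2000, "metrics": {"val_ppl": ppl}}),
            encoding="utf-8")
    index = index_manifests(d)
    best = best_by_val_ppl(index)
```

The glob only matches finished manifests, so leftover *.tmp files from an interrupted save never enter the index.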

Minimal mlflow wiring

import mlflow

with mlflow.start_run() as run:                  # run.info.run_id goes into each manifest
    mlflow.log_params(config.to_flat_dict())
    for step in range(total_steps):
        ...
        if step % 10 == 0:
            mlflow.log_metric("train_loss", float(loss), step=step)
            mlflow.log_metric("lr", float(current_lr), step=step)
            mlflow.log_metric("grad_norm", float(g_norm), step=step)
    mlflow.log_metric("val_ppl", float(val_ppl))
    mlflow.log_artifact(checkpoint_path)          # safetensors file
    mlflow.log_artifact(manifest_path)            # the source of truth, logged once it is final
    mlflow.log_artifact(loss_curve_png)

That's the entire wiring. ~10 lines. No mlflow-specific config logic, no decorators on the training step, no auto-logging magic. The "thin wrapper" in lab 04 builds on this.

Tracking URI: SQLite vs file-store

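In MLflow 2.x the default tracking URI is a local file store (./mlruns): each run is a directory tree of small plain-text files, one per param and per tag, plus one file per metric key with a line appended for every logged step. Listing or comparing runs means walking and parsing all of those files, so the UI gets slow as runs accumulate, and the store has no transactions to protect it against concurrent writers.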

The fix: MLFLOW_TRACKING_URI=sqlite:///mlruns.db. One SQLite database file, fast queries, and transactional writes (SQLite serializes concurrent writers instead of letting them interleave). Lab 04 pins this.

Per-checkpoint vs per-run

A run can produce many checkpoints (every epoch, plus best-val). The relationship:

  • One run = one process invocation of training. One mlflow run id. One config.
  • Many checkpoints = per-epoch saves + best-val save + final save.
  • Each checkpoint has its own manifest.json referring back to the run's mlflow URI.
  • Each checkpoint is independently reloadable. Resuming from checkpoint = restoring the full state and continuing the loop.

Phase 18's default cadence: a save every epoch, plus a best-val save and a final save, with a rolling buffer that keeps only the 3 most recent per-epoch checkpoints (the oldest is evicted to bound disk use).
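A sketch of the rolling buffer, assuming the best-val and final saves are written under their own names and never enter it, so eviction cannot delete them. `RollingBuffer` and the `epochNNN` file names are hypothetical; each push registers the three files of one checkpoint.

```python
import tempfile
from collections import deque
from pathlib import Path


class RollingBuffer:
    """Keep only the `maxlen` most recent per-epoch checkpoints on disk."""

    def __init__(self, maxlen: int = 3):
        self.maxlen = maxlen
        self.checkpoints = deque()  # each entry: the file paths of one checkpoint

    def push(self, files):
        """Register a freshly written checkpoint; delete and return evicted files."""
        self.checkpoints.append(list(files))
        evicted = []
        while len(self.checkpoints) > self.maxlen:
            for p in self.checkpoints.popleft():  # oldest checkpoint goes first
                Path(p).unlink(missing_ok=True)
                evicted.append(p)
        return evicted


SUFFIXES = (".safetensors", ".manifest.json", ".optim.safetensors")

with tempfile.TemporaryDirectory() as d:
    ring = RollingBuffer(maxlen=3)
    all_evicted = []
    for epoch in range(5):
        files = [Path(d) / f"minigpt-phase18-epoch{epoch:03d}{s}" for s in SUFFIXES]
        for f in files:
            f.write_bytes(b"")  # stand-in for the real atomic saves
        all_evicted += ring.push(files)
    remaining = sorted(p.name for p in Path(d).iterdir())
```

Evicting only after the new checkpoint is fully on disk means there are briefly four checkpoints, never two.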

Drill problems

  1. A pickle file is downloaded from a public model registry. You pickle.load(f). What's the worst case? What's the minimum defense that doesn't actually solve the problem (i.e., the misconception)?
  2. Weight tying: the embedding matrix is (V, d_model) = (512, 64) = 32768 floats = 128 KB at fp32. If the checkpoint accidentally writes the tied tensor twice, what's the file-size delta?
  3. The manifest.json records seed: 42. You reload, set the seed to 42, and continue training. The next step's loss does not match the run-without-reload's loss at the same step. Name three reasons.
  4. mlflow's tracking URI is left at the file store. After 5 runs with 2000 steps each, mlruns/ is a tree of many small files (one per param, tag, and metric key in every run, plus run metadata). Name two operational pains this causes.

One-paragraph recap

Persist weights via safetensors, never pickle (pickle is RCE). The side-car manifest.json holds the truth: seed, versions, git_sha, config_hash, data_manifest_hash, plus the checkpoint's step and metrics. Write atomically via *.tmp + rename. Save all five state machines (model, optimizer, scheduler, data iterator, training RNG) for byte-equivalent reload. mlflow is a thin index on top: log params, metrics, and artifacts inside a with mlflow.start_run() block, with MLFLOW_TRACKING_URI=sqlite:///mlruns.db. Manifest is the truth; mlflow is the index. Weight tying must be detected at save time to avoid storing the tied matrix twice.

What this section does NOT cover

  • Encrypted checkpoints. Not relevant at this scale.
  • Checkpoint deduplication. Phase 38 (MLOps).
  • Distributed checkpointing. Phase 35.
  • Model registry / lineage. Phase 38.

Phase 18 theory complete. Continue with lab/00-batch-and-mask.md.