Skip to content

English · Español

04 — Checkpoints (safetensors) + mlflow como envoltorio del manifest

Aquí decidimos cómo persistir el modelo y cómo recordar qué hicimos. Spoiler: safetensors (no pickle, es RCE), manifest.json sigue siendo la fuente de verdad, mlflow es una capa de navegación encima.


Dos preguntas que este fichero responde:

  1. ¿Cómo persistimos un checkpoint de modo que recargarlo sea bit-idéntico al estado vivo?
  2. ¿Cómo seguimos la pista de muchos checkpoints así a lo largo de muchos runs sin renunciar a la disciplina de manifest de la Fase 0?

Pickle es RCE. Safetensors no. Fin.

El pickle de Python es un formato de serialización que ejecuta código arbitrario al cargar. Los bytes de un fichero pickle son un programa; pickle.load(f) ejecuta ese programa. Un fichero pickle malicioso puede ejecutar os.system("rm -rf $HOME") en el instante en que lo cargas. Esto no es hipotético: el default de torch.load ha sido pickle durante años, y hay casos documentados de checkpoints de modelos descargados del registro público que secuestran el cargador.

security/supply-chain.md ya establece la regla. La Fase 18 es donde la regla se aplica por primera vez:

Borja nunca persiste pesos de modelo vía pickle. Nunca. Solo safetensors.

safetensors (Hugging Face) es un formato binario plano: una cabecera JSON describiendo formas/dtypes/offsets de tensores, seguida de los bytes en bruto de los datos del tensor. Cargarlo es parse + memmap. No hay ejecución de código, no hay invocación de metaclase, no hay trucos de __reduce__. El cargador está en Rust puro + Python mínimo; los modos de fallo están bien entendidos (truncación, magic-number no coincide).

El checkpoint de gramática verbal en la Fase 18 se ve así:

models/minigpt-phase18-<hash>.safetensors    # weights only
models/minigpt-phase18-<hash>.manifest.json  # side-car: config + seed + versions + git_sha
models/minigpt-phase18-<hash>.optim.safetensors  # optimizer state (m, v, t)

Tres ficheros por checkpoint. Escritura atómica: escribe a *.tmp, luego os.rename(tmp, final) una vez completada la escritura. Una escritura parcial nunca debe dejar un fichero *.safetensors que parezca válido; el os.rename es atómico en POSIX, así que un guardado interrumpido deja un *.tmp que puedes detectar y limpiar.

Qué va en manifest.json

Este fichero es la verdad sobre el run. Vive junto al safetensors y contiene todo lo necesario para reproducir el modelo:

{
  "schema_version": "1.0",
  "phase": 18,
  "git_sha": "abcd1234...",
  "config_hash": "sha256 of the resolved config",
  "data_manifest_hash": "sha256 of data/processed/MANIFEST.json",
  "seed": 42,
  "versions": {
    "python": "3.11.x",
    "numpy": "2.x.x",
    "safetensors": "0.4.x",
    "mlflow": "2.x.x"
  },
  "config": { ... resolved training config ... },
  "step": 2000,
  "epoch": 50,
  "metrics": {
    "train_ppl": 4.5,
    "val_ppl": 9.0,
    "ngram_baseline_val_ppl": 13.0
  },
  "hardware": {
    "cpu": "Intel i5-8250U",
    "ram_gb": 62
  },
  "mlflow_run_uri": "mlruns/0/<run_id>"
}

El data_manifest_hash es no negociable: fija qué versión del corpus de la Fase 12 entrenó este checkpoint. Si el corpus se regenera con una seed distinta, el hash cambia; los checkpoints viejos pueden ser reconocidos como entrenados sobre un dataset distinto.

El contrato de recarga

El código de recarga debe garantizar:

state_before = capture_full_state()    # weights + optim + scheduler + RNG
save_checkpoint(state_before, path)
state_after = load_checkpoint(path)

for k in state_before:
    assert np.array_equal(state_before[k], state_after[k])

Esta es la equivalencia byte a byte que pide la DoD. Nota: np.array_equal, no np.allclose. Bit a bit. Sin redondeo, sin drift fp32-vs-fp64. Si no puedes pasar este test, tu checkpoint no es un checkpoint; es un "punto de partida".

Las cinco máquinas de estado de theory/00:

  1. Pesos del modelo — safetensors.
  2. Estado del optimizador (\(m_t, v_t, t\)) — fichero safetensors separado.
  3. Estado del scheduler (\(t\)) — JSON, en el manifest.
  4. Estado del iterador de datos (epoch, posición, seed del RNG) — JSON, en el manifest.
  5. Estado del RNG de control de entrenamiento — JSON, np.random.Generator.bit_generator.state codificado en base64.

Los cinco hacen round-trip. El Lab 02 lo testea.

Un test sutil para weight tying

MiniGPT (Fase 17) comparte la matriz de embedding de entrada con la proyección de salida. Dos objetos Parameter en el grafo del Module apuntan al mismo tensor subyacente. Código naive de checkpoint hace:

for name, param in model.named_parameters():
    save_tensor(name, param.data)

Esto escribe el tensor atado dos veces, bajo dos nombres. Al recargar, ambos nombres reciben el mismo valor — bien. Pero al guardar, el fichero es de 2 MB en lugar de 1 MB. Peor, si más tarde modificas una de las dos copias durante la recarga (bug), ya no están atadas.

El arreglo: detectar el tying antes del guardado. Mantén un mapa id_to_name; si id(param.data) ya está en el mapa, guarda solo una referencia (el otro nombre). Al recargar, la referencia se resuelve al mismo objeto tensor.

El lab 02 de la Fase 18 testea esto. El tamaño esperado del fichero es 1 MB, no 2 MB. Si es 2 MB, el weight tying está roto en el round-trip del checkpoint.

mlflow — lo que te da y lo que no

mlflow es una herramienta de tracking de runs. Según A8 aterriza en la Fase 18, la primera fase que se beneficia de la navegación entre runs. Da:

  • Una UI central (mlflow ui) mostrando cada run, su config, sus métricas, sus artefactos.
  • Plots de timeline por run (loss, LR, gradient norm) según progresa el run.
  • Almacenamiento de artefactos — cualquier fichero registrado con mlflow.log_artifact(...) queda asociado al run.
  • Vistas de comparación — selecciona dos runs, ve configs y curvas de métricas en paralelo.

Lo que mlflow no reemplaza:

  • manifest.json. Sigue siendo la fuente de verdad en disco. mlflow guarda sus propios metadatos en una DB SQLite (mlruns.db); si esa DB se corrompe, el run es encontrable vía los ficheros manifest.json en models/.
  • La seed. Poner una seed dentro de mlflow.start_run() no hace el run reproducible; la seed debe estar en el manifest.
  • Garantías de determinismo. mlflow registra métricas; no fuerza reproducibilidad.

La regla del pulgar: el manifest es la verdad; mlflow es el índice. Si no coinciden, gana el manifest.

Cableado mínimo de mlflow

with mlflow.start_run() as run:
    mlflow.log_params(config.to_flat_dict())
    mlflow.log_artifact("manifest.json")  # the source of truth
    for step in range(total_steps):
        ...
        if step % 10 == 0:
            mlflow.log_metric("train_loss", loss, step=step)
            mlflow.log_metric("lr", current_lr, step=step)
            mlflow.log_metric("grad_norm", g_norm, step=step)
    mlflow.log_metric("val_ppl", val_ppl)
    mlflow.log_artifact(checkpoint_path)  # safetensors file
    mlflow.log_artifact(loss_curve_png)

Ese es todo el cableado. ~10 líneas. Sin lógica de config específica de mlflow, sin decoradores sobre el paso de entrenamiento, sin magia de auto-logging. El "envoltorio fino" del lab 04 se construye sobre esto.

Tracking URI: SQLite vs file-store

El default de mlflow es un URI file:// que escribe un fichero JSON por métrica por paso. Con 2000 pasos × 5 métricas × ~100 bytes/fichero, obtienes 1 MB repartido en 10000 ficheros en un único run de entrenamiento. El sistema de ficheros lo odia, la UI va lenta, y las escrituras concurrentes corrompen.

El arreglo: MLFLOW_TRACKING_URI=sqlite:///mlruns.db. Una base de datos SQLite, queries rápidas, sin corrupción por escrituras concurrentes. El Lab 04 lo fija.

Por checkpoint vs por run

Un run puede producir muchos checkpoints (cada epoch, más el best-val). La relación:

  • Un run = una invocación de proceso del entrenamiento. Un id de run de mlflow. Una config.
  • Muchos checkpoints = guardados por epoch + guardado best-val + guardado final.
  • Cada checkpoint tiene su propio manifest.json referenciando de vuelta al URI de mlflow del run.
  • Cada checkpoint es recargable independientemente. Reanudar desde checkpoint = restaurar el estado completo y continuar el loop.

La cadencia por defecto de la Fase 18: guardar cada epoch + guardado best-val + guardado final, con un buffer rotativo de 3 últimos checkpoints por epoch (el más viejo se desaloja para acotar el disco).

Problemas de drill

  1. Un fichero pickle se descarga de un registro público de modelos. Haces pickle.load(f). ¿Cuál es el peor caso? ¿Cuál es la mínima defensa que en realidad no resuelve el problema (es decir, la idea errónea)?
  2. Weight tying: la matriz de embedding es (V, d_model) = (512, 64) = 32768 floats = 128 KB en fp32. Si el checkpoint accidentalmente escribe el tensor atado dos veces, ¿cuál es la diferencia en tamaño de fichero?
  3. El manifest.json registra seed: 42. Recargas, pones la seed a 42, y continúas entrenando. La pérdida del siguiente paso no coincide con la pérdida del run-sin-recarga en el mismo paso. Da tres razones.
  4. El tracking URI de mlflow está puesto al file store. Tras 5 runs con 2000 pasos cada uno, el directorio mlruns/ tiene ~50000 ficheros. Da dos dolores operacionales que esto causa.

Recap de un párrafo

Persiste los pesos vía safetensors, nunca pickle (pickle es RCE). El side-car manifest.json guarda la verdad: seed, versions, git_sha, config_hash, data_manifest_hash, más métricas por paso. Escritura atómica vía *.tmp + rename. Guarda las cinco máquinas de estado (modelo, optimizador, scheduler, iterador de datos, training-RNG) para recarga byte a byte. mlflow es un índice fino encima — registra params + métricas + artefactos en un bloque with mlflow.start_run(), con MLFLOW_TRACKING_URI=sqlite:///mlruns.db. El manifest es la verdad; mlflow es el índice. El weight tying debe detectarse en el guardado para evitar almacenamiento doble.

Lo que esta sección NO cubre

  • Checkpoints encriptados. No relevante a esta escala.
  • Deduplicación de checkpoints. Fase 38 (MLOps).
  • Checkpointing distribuido. Fase 35.
  • Registro de modelos / linaje. Fase 38.

Teoría de la Fase 18 completa. Continúa con lab/00-batch-and-mask.md.