Skip to content

English · Español

Lab 00 — Pretraining en la nube de un día

🇪🇸 La corrida real. 1× A100 80GB durante ~24 h en Lambda o RunPod, 50M parámetros, ~5B tokens de FineWeb-Edu, presupuesto duro $35. El objetivo es reproducibilidad y números medidos, no batir un benchmark. Una sola configuración derivada de la teoría, una sola corrida, post-mortem si se rompe.

Objetivo

Producir una corrida reproducible de pretraining (pretraining) sobre una GPU real en la nube. Entrenar un transformer decoder-only de 50M parámetros durante ~24 horas en un único A100 80GB. Val-loss final predicho: ~3.35 ± 0.2 nats/token (del ajuste Chinchilla en theory/01-scaling-laws.md).

Checklist de prerrequisitos

  • Training loop de la Fase 18 entregado (src/minitrain/loop.py). El trainer de X1 es su hermano mayor.
  • mlflow funcionando localmente. El trainer de X1 loguea cada 10 steps.
  • safetensors disponible para I/O de checkpoint.
  • src/distributed/budget_guard.py de la Fase 35 importable. Requerido para lanzar.
  • Una cuenta de Lambda Labs o RunPod, tarjeta de facturación, alertas a $30 configuradas.
  • Clave SSH registrada con la nube elegida.

Presupuesto duro

Línea $-coste Notas
1× A100 80GB spot, Lambda, ~26 h × $1.10/h $28.60 Cómputo primario
Egress de almacenamiento (descarga de datos) $1 Muestra de FineWeb-Edu desde HF Hub
Disco persistente adjunto $2 200 GB × 24 h
Buffer (1 reinicio) $5 Si un spike fuerza un re-lanzamiento
Tope $35 budget_guard.py rechaza si se excede

Si el precio spot real es >$1.40/h en el momento del lanzamiento, no lances. Espera 4 horas y re-comprueba. RunPod community a $0.79/h es un fallback aceptable.

La receta del cluster

Proveedor: Lambda Labs (primario) o RunPod (fallback)

  • Por qué Lambda primario: facturación más simple, buena disponibilidad de A100 80GB a mitad de semana, el hardware es consistente.
  • Por qué RunPod fallback: ~30% más barato en spot, pero el hardware community-host varía; al trainer no debería importarle, pero a veces le importa (drivers más viejos, NVMe más lento).

Especificaciones de la instancia

  • GPU: 1× A100 80GB (SXM4 preferida; PCIe aceptable con golpe de throughput del 5-10%).
  • vCPUs: ≥16.
  • RAM: ≥64 GB (mantenemos un dataset tokenizado parcialmente en RAM).
  • Disco: 500 GB NVMe adjunto.
  • OS: Ubuntu 22.04 (default de Lambda) o 20.04 (default de RunPod — ambos funcionan).

Imagen Docker

Usa la imagen oficial de NVIDIA PyTorch, que trae FlashAttention-2, Triton compatible con torch.compile, y CUDA 12.x:

nvcr.io/nvidia/pytorch:24.10-py3

Esta imagen es ~14 GB, descarga en ~3-5 minutos en la nube. Incluye: - PyTorch 2.5 con CUDA 12.6 - FlashAttention-2 (pre-construido para SM80 / A100, SM90 / H100) - Triton - TransformerEngine (no lo usamos, pero está ahí) - NCCL, cuDNN, cuBLAS

Fija el tag explícitamente. latest te morderá cuando el upstream cambie.

Setup de host único (~10 min)

# En el host de la nube, tras SSH:
docker pull nvcr.io/nvidia/pytorch:24.10-py3

# Verificar GPU visible:
nvidia-smi
# Esperado: 1× A100-SXM4-80GB o A100-PCIE-80GB

# Pre-crear dirs persistentes:
mkdir -p /workspace/{data,checkpoints,mlruns,logs}

El dataset: muestra de FineWeb-Edu

Usamos FineWeb-Edu (Penedo 2024) — el subset de CommonCrawl filtrado por clasificador LLM descrito en theory/03.

Opción A (preferida): subset sample-10BT de HF

# Dentro del contenedor en ejecución
huggingface-cli login  # pega el token HF de solo lectura
mkdir -p /workspace/data/fineweb-edu
cd /workspace/data/fineweb-edu

# Descarga el set de shards de muestra de 10B-tokens (~22 GB en disco)
huggingface-cli download \
  HuggingFaceFW/fineweb-edu \
  --repo-type dataset \
  --include "sample/10BT/*.parquet" \
  --local-dir .

# Verifica el conteo de shards (~96 ficheros de ~230 MB cada uno)
ls sample/10BT/ | wc -l

Opción B (fallback si HF rate-limita): slice de Pile-CC

# Pile de EleutherAI, subset CommonCrawl; misma forma, calidad más baja
huggingface-cli download \
  monology/pile-uncopyrighted \
  --include "test/*.jsonl.zst" \
  --local-dir /workspace/data/pile-cc

Tokenizar: BPE de GPT-2, guardar como binario uint16

El lab usa el formato nanoGPT (train.bin, val.bin como arrays planos uint16).

cd /workspace
python -m x1_pretrain.tokenize_data \
  --input-dir /workspace/data/fineweb-edu/sample/10BT \
  --output-dir /workspace/data/tokenized \
  --tokenizer gpt2 \
  --val-fraction 0.001 \
  --workers 16

Esperado: ~10B tokens × 2 bytes = ~20 GB en disco. Tiempo de reloj: ~30 min en 16 vCPUs.

Snapshot de budget_guard.py: la tokenización es solo-CPU; no se factura GPU. Coste hasta ahora: ~$0.50 (el tiempo de CPU durante este paso está incluido en el horario).

El modelo: transformer decoder de 50M parámetros

HP Valor Justificación
n_layer 8 Default mid-2024 para clase 50M
d_model 768 Coincide con el ancho de HF gpt2-base
n_head 12 Dim de cabeza 64, estándar
n_kv_head 4 GQA-3 (ahorra K/V cache, default 2024)
d_ff 2048 ~2.7× d_model (fórmula SwiGLU)
seq_len 1024 Barato; se dobla a 2048 solo a >100M
vocab_size 50,257 BPE de GPT-2
norm RMSNorm Default moderno
act SwiGLU Default moderno
pos_enc RoPE base=10000 Default moderno
init_std 0.02 Default GPT-2 / Llama
Parámetros totales (no-embed) ~50M calculado: 8 × 12 × 768² × 4 ≈ 226M; menos embeddings; neto ~50M

(La cifra de 50M excluye embeddings de vocab. Con embeddings contados, total ~89M. Citamos la cifra no-embedding para coincidir con la convención de Chinchilla.)

Optimizador + schedule

HP Valor Fuente
Optimizador AdamW Loshchilov 2017
LR (pico) 3e-4 Llama-2 7B usa 3e-4; seguimos
Schedule de LR cosine, warmup 1000 steps, decae a 3e-5 Estándar
β₁ 0.9 Default
β₂ 0.95 Bajado desde 0.999; resistencia a spike (PaLM)
weight_decay 0.1 Default moderno
grad_clip 1.0 Universal
Tokens efectivos de batch 512K = 512 batch × 1024 seq Afinado para caber en A100 80GB
Precisión bf16 mixto (maestro fp32) Nativo de A100, sin loss-scaling

Compile: torch.compile(model, mode="max-autotune").

Kernel de atención (attention): flash_attn_func del paquete flash-attn.

Throughput esperado

  • TFLOP/s sostenidos: ~150 (MFU 0.48 del pico bf16 de 312 del A100).
  • Tokens/s: \(1.5 \times 10^{14} / (6 \times 5 \times 10^7) \approx 500{,}000\).
  • Steps/s: \(500{,}000 / 512{,}000 \approx 0.98\)~3,500 steps/hora.
  • Steps totales en 24 h: ~84,000.
  • Tokens totales: ~43B.

Si el throughput observado en la hora 1 es < 300k tokens/s (MFU < 0.30), para y diagnostica. Causas comunes: torch.compile falló al hacer fuse (revisa warnings), FlashAttention-2 no detectado (revisa nvidia-smi para utilización SM), dataloader CPU-bound (revisa iostat).

Comando de entrenamiento

# Dentro del contenedor, con todos los datos y código montados
cd /workspace
python -m x1_pretrain.train \
  --config configs/x1-50m-a100.yaml \
  --data-dir /workspace/data/tokenized \
  --ckpt-dir /workspace/checkpoints \
  --mlflow-uri file:///workspace/mlruns \
  --total-steps 84000 \
  --ckpt-every 1800 \
  --eval-every 1800 \
  --log-every 10 \
  --resume-from-latest \
  --budget-cap-usd 35.0 \
  --budget-curr-cost-uri /workspace/budget.json

El flag --budget-cap-usd 35.0 invoca budget_guard.py de la Fase 35 en modo periodic-check: cada 30 minutos consulta /workspace/budget.json (mantenido al día por un proceso sidecar que sondea la API de facturación de la nube o su proxy), y si el total proyectado excede $35, fuerza un checkpoint-and-exit limpio.

Curva de loss esperada

Reproducible bf16 + seed=42 sobre esta config debería producir:

Hora Steps Tokens (acumulado) Train loss Val loss
0 0 0 ~10.8
1 3.5k 1.8B 5.1 5.2
6 21k 11B 3.9 3.95
12 42k 21B 3.5 3.55
18 63k 32B 3.4 3.42
24 84k 43B 3.32 3.35

El val loss final de 3.35 coincide con la predicción del ajuste de Hoffmann en theory/01-scaling-laws.md dentro de 0.05 nats. El check 1 de DoD se cumple si el val loss final está dentro de [3.20, 3.50].

Fuera de esa banda → consulta theory/04, comprueba spikes, escribe el post-mortem.

Logs a registrar (DoD check 1)

mlflow loguea automáticamente la lista de métricas de abajo cada 10 steps. Tras la corrida, vuelca:

python -m x1_pretrain.export_run \
  --mlflow-uri file:///workspace/mlruns \
  --output experiments/X1-pretraining/run-cloud/

Produce: - manifest.json — seed, versiones (torch, flash-attn, transformers, numpy), hash del YAML de config, especificación del cluster, $-gastados totales. - metrics.csv — formato largo con columnas (step, name, value). - loss-curve.png — train + val loss vs steps. - gradnorm-curve.png — grad-norm y param-norm vs steps. - throughput.png — tokens/s vs hora de reloj. - final.safetensors — el último checkpoint, en formato safetensors. - mlflow-run-uri.txt — para referencia cruzada.

Inyección de loss-spike (DoD check 3)

El script de entrenamiento acepta --inject-spike-at-step N que:

  1. En el step N, reemplaza los siguientes 5 batches con un batch sintético de "token raro" (secuencias muestreadas del 1% inferior de la distribución unigrama, ponderadas para producir un loss alto de cross-entropy).
  2. Loguea la inyección claramente en mlflow.

Para el lab 00, ejecuta con --inject-spike-at-step 12000 (~3.5 horas dentro). Observa la respuesta (grad clip debería atraparlo; β₂=0.95 debería amortiguarlo rápido), luego escribe el post-mortem en spike-postmortem.md.

Si ocurre un spike real naturalmente antes del step 12000, escribe ese post-mortem en su lugar y salta la inyección.

Watchdog y alarma de presupuesto

Un proceso sidecar sondea la API de facturación de Lambda / RunPod cada 5 min y escribe a /workspace/budget.json:

{
  "spent_usd": 14.20,
  "rate_usd_per_hr": 1.10,
  "hours_elapsed": 12.9,
  "projected_total_usd": 28.4,
  "last_update_utc": "2026-05-23T15:42:01Z"
}

El trainer lee esto cada 30 min y: - si spent_usd > 30: email + ping de Slack, sin acción todavía. - si projected_total_usd > 35: checkpoint limpio y exit. - si spent_usd > 35: exit inmediato, sin escritura de checkpoint (ya está por encima).

Este es el contrato con budget_guard.py. Pruébalo antes del lanzamiento (flag --dry-run-budget).

Checklist de apagado

  • Checkpoint final guardado (final.safetensors + final-optimizer.pt).
  • Export de mlflow a experiments/X1-pretraining/run-cloud/ hecho.
  • budget.json final confirma gasto ≤ $35.
  • Termina la instancia. Un A100 olvidado a $1.10/h son $26/día.
  • Confirma terminación en el dashboard del proveedor (screenshot guardado en experiments/...).
  • Desadjunta + borra el disco persistente si no se necesita para el lab 01 (lo necesita — guárdalo por ~3 días).

Checks de DoD (este lab)

  1. experiments/X1-pretraining/run-cloud/manifest.json existe, contiene seed/versiones/cluster/$-gastados.
  2. Val loss final en [3.20, 3.50].
  3. final.safetensors existe y recarga byte-equivalentemente (test de round-trip).
  4. Total $-gastado ≤ $35.
  5. spike-postmortem.md escrito (spike real o inyectado).
  6. Instancia terminada (prueba con screenshot).

Modos de fallo comunes y qué hacer

  • Error de import de flash-attn. Driver más viejo. O bien pip install flash-attn==2.6.3 --no-build-isolation o baja a nvcr.io/nvidia/pytorch:24.08-py3.
  • torch.compile se cuelga en el primer step. Pon mode="default" en vez de "max-autotune". Re-ejecuta.
  • El throughput se reduce a la mitad intermitentemente. Dataloader. Comprueba num_workers=4 y confirma pin_memory=True. El disco NVMe debería sostener 1 GB/s de reads — confirma con iostat.
  • OOM en el step 0. Baja batch_size de 32 a 16, sube grad_accum_steps de 16 a 32. El batch efectivo no cambia.
  • El loss se va a NaN. Casi seguro relacionado con bf16 — comprueba que los pesos maestros fp32 están habilitados. Si sigue NaN, comprueba los datos de entrada por IDs de token fuera del rango de vocab.
  • Instancia desalojada en la hora 17. El watchdog auto-relanza en <2 min. El resume recoge el último checkpoint. Downtime total: ~3 min. Coste de la preemption: $0.

Siguiente: lab/01-scaling-laws-experiment.md.