Lab 00 — One-day cloud pretraining¶
The real run: 1× A100 80GB for ~24 h on Lambda or RunPod, 50M parameters, ~5B tokens of FineWeb-Edu, hard budget of $35. The goal is reproducibility and measured numbers, not beating a benchmark. One configuration derived from theory, one run, and a post-mortem if it breaks.
Goal¶
Produce one reproducible pretraining run on a real cloud GPU. Train a 50M-parameter decoder-only transformer for ~24 hours on a single A100 80GB. Predicted final val-loss: ~3.35 ± 0.2 nats/token (from the Chinchilla fit in theory/01-scaling-laws.md).
Prerequisites checklist¶
- Phase 18 training loop shipped (`src/minitrain/loop.py`). The X1 trainer is its big sibling.
- `mlflow` working locally. The X1 trainer logs every 10 steps.
- `safetensors` available for checkpoint I/O.
- `src/distributed/budget_guard.py` from Phase 35 importable. Required to launch.
- A Lambda Labs or RunPod account, billing card on file, alerts at $30 set.
- SSH key registered with the chosen cloud.
Hard budget¶
| Line | $-cost | Notes |
|---|---|---|
| 1× A100 80GB spot, Lambda, ~26 h × $1.10/hr | $28.60 | Primary compute |
| Storage egress (data download) | $1 | FineWeb-Edu sample from HF Hub |
| Persistent-disk attached | $2 | 200 GB × 24 h |
| Buffer (1 restart) | $5 | If spike forces a re-launch |
| Ceiling | $35 | budget_guard.py refuses if exceeded |
If actual spot price is >$1.40/hr at launch time, do not launch. Wait 4 hours and re-check. RunPod community at $0.79/hr is an acceptable fallback.
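The launch rule above can be written down as a small pre-flight check. This is a minimal sketch, not part of `budget_guard.py`: the function names are hypothetical, and the only numbers used are the $1.40/hr gate and the 26 billed hours from the budget table.

```python
MAX_LAUNCH_RATE = 1.40  # $/hr, from the launch rule above
PLANNED_HOURS = 26      # billed hours assumed in the budget table


def compute_cost(rate_usd_per_hr: float, hours: float = PLANNED_HOURS) -> float:
    """Primary compute line of the budget: hourly rate times billed hours."""
    return round(rate_usd_per_hr * hours, 2)


def launch_decision(lambda_rate: float, runpod_rate: float) -> str:
    """Prefer Lambda, fall back to RunPod, otherwise wait and re-check."""
    if lambda_rate <= MAX_LAUNCH_RATE:
        return "lambda"
    if runpod_rate <= MAX_LAUNCH_RATE:
        return "runpod"
    return "wait"


print(compute_cost(1.10))            # 28.6
print(launch_decision(1.10, 0.79))   # lambda
print(launch_decision(1.55, 0.79))   # runpod
print(launch_decision(1.55, 1.45))   # wait
```

Feeding it the rates quoted by each provider at launch time gives a record of why a given provider was chosen, which can go into the run notes.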
The cluster recipe¶
Provider: Lambda Labs (primary) or RunPod (fallback)¶
- Why Lambda primary: simpler billing, good A100 80GB availability mid-week, hardware is consistent.
- Why RunPod fallback: ~30% cheaper spot, but the community-host hardware varies; the trainer should not care, but it sometimes does (older drivers, slower NVMe).
Instance specs¶
- GPU: 1× A100 80GB (SXM4 preferred; PCIe accepted with 5-10% throughput hit).
- vCPUs: ≥16.
- RAM: ≥64 GB (we hold a tokenized dataset partly in RAM).
- Disk: 500 GB NVMe attached.
- OS: Ubuntu 22.04 (Lambda default) or 20.04 (RunPod default); both work.
Docker image¶
Use the official NVIDIA PyTorch image, `nvcr.io/nvidia/pytorch:24.10-py3`, which ships FlashAttention-2, a torch.compile-friendly Triton, and CUDA 12.x.
This image is ~14 GB and pulls in ~3-5 minutes on cloud. It includes:
- PyTorch 2.5 with CUDA 12.6
- FlashAttention-2 (pre-built for SM80 / A100 and SM90 / H100)
- Triton
- TransformerEngine (we don't use it, but it's there)
- NCCL, cuDNN, cuBLAS
Pin the tag explicitly. `latest` will bite you when the upstream changes.
One-time host setup (~10 min)¶
# On the cloud host, after SSH:
docker pull nvcr.io/nvidia/pytorch:24.10-py3
# Verify GPU visible:
nvidia-smi
# Expected: 1× A100-SXM4-80GB or A100-PCIE-80GB
# Pre-create persistent dirs:
mkdir -p /workspace/{data,checkpoints,mlruns,logs}
The dataset: FineWeb-Edu sample¶
We use FineWeb-Edu (Penedo 2024) — the LLM-classifier-filtered subset of CommonCrawl described in theory/03.
Option A (preferred): sample-10BT subset from HF¶
# Inside the running container
huggingface-cli login # paste read-only HF token
mkdir -p /workspace/data/fineweb-edu
cd /workspace/data/fineweb-edu
# Pull the 10B-token sample shard set (~22 GB on disk)
huggingface-cli download \
HuggingFaceFW/fineweb-edu \
--repo-type dataset \
--include "sample/10BT/*.parquet" \
--local-dir .
# Verify count of shards (~96 files of ~230 MB each)
ls sample/10BT/ | wc -l
Option B (fallback if HF rate-limited): Pile-CC slice¶
# EleutherAI Pile, CommonCrawl subset; same shape, lower quality
huggingface-cli download \
monology/pile-uncopyrighted \
--repo-type dataset \
--include "test/*.jsonl.zst" \
--local-dir /workspace/data/pile-cc
Tokenize: GPT-2 BPE, save as uint16 binary¶
The lab uses the nanoGPT format (train.bin, val.bin as flat uint16 arrays).
cd /workspace
python -m x1_pretrain.tokenize_data \
--input-dir /workspace/data/fineweb-edu/sample/10BT \
--output-dir /workspace/data/tokenized \
--tokenizer gpt2 \
--val-fraction 0.001 \
--workers 16
Expected: ~10B tokens × 2 bytes = ~20 GB on disk. Wall time: ~30 min on 16 vCPUs.
budget_guard.py snapshot: tokenization runs on the CPU only, but the instance still bills at its hourly rate while the GPU sits idle. Cost so far: ~$0.50 (about 30 min of instance time).
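To make the on-disk format concrete, here is a minimal sketch of writing and reading a flat uint16 token file using only the standard library. It is not the `x1_pretrain.tokenize_data` implementation: the helper names are hypothetical, the token IDs are arbitrary, and the file uses native byte order.

```python
import os
import tempfile
from array import array

GPT2_VOCAB_SIZE = 50_257  # fits in uint16, whose maximum is 65,535


def write_tokens(path: str, token_ids: list[int]) -> int:
    """Write token IDs as a flat uint16 array; return the bytes written."""
    if any(t < 0 or t >= GPT2_VOCAB_SIZE for t in token_ids):
        raise ValueError("token ID outside GPT-2 vocab range")
    buf = array("H", token_ids)  # "H" = unsigned 16-bit integer
    with open(path, "wb") as f:
        buf.tofile(f)
    return len(buf) * buf.itemsize


def read_tokens(path: str) -> list[int]:
    """Read a flat uint16 file back into a list of token IDs."""
    buf = array("H")
    n_tokens = os.path.getsize(path) // buf.itemsize
    with open(path, "rb") as f:
        buf.fromfile(f, n_tokens)
    return buf.tolist()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "val.bin")
    n_bytes = write_tokens(path, [50256, 11, 262, 1024])
    print(n_bytes, read_tokens(path))

# Disk estimate from the text: 10B tokens at 2 bytes each.
print(10e9 * 2 / 1e9, "GB")
```

Two bytes per token is what makes the ~20 GB estimate work; a vocabulary above 65,535 entries would need uint32 and double the footprint.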
The model: 50M-param decoder transformer¶
| HP | Value | Justification |
|---|---|---|
| `n_layer` | 8 | Mid-2024 default for 50M-class |
| `d_model` | 768 | Matches HF gpt2-base width |
| `n_head` | 12 | Head dim 64, standard |
| `n_kv_head` | 4 | GQA-3 (saves K/V cache, 2024 default) |
| `d_ff` | 2048 | ~2.7× d_model (SwiGLU formula) |
| `seq_len` | 1024 | Cheap; doubles to 2048 only at >100M |
| `vocab_size` | 50,257 | GPT-2 BPE |
| `norm` | RMSNorm | Modern default |
| `act` | SwiGLU | Modern default |
| `pos_enc` | RoPE base=10000 | Modern default |
| `init_std` | 0.02 | GPT-2 / Llama default |
| Total params (non-embed) | ~50M | computed: 8 layers × (attention 768² × 8/3 ≈ 1.57M + SwiGLU 3 × 768 × 2048 ≈ 4.72M) ≈ 50.3M |
(The 50M figure excludes vocab embeddings. With embeddings counted, total ~89M. We quote the non-embedding figure to match Chinchilla convention.)
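The parameter figures can be checked directly from the table. The sketch below assumes no bias terms and a tied input/output embedding, and it ignores the RMSNorm gains (about 13k parameters); the function name is illustrative.

```python
def count_params(n_layer=8, d_model=768, n_head=12, n_kv_head=4,
                 d_ff=2048, vocab_size=50_257):
    """Return (non_embedding, embedding, total) parameter counts."""
    head_dim = d_model // n_head          # 64
    kv_dim = n_kv_head * head_dim         # 256 with grouped-query attention
    attn = d_model * d_model              # Q projection
    attn += 2 * d_model * kv_dim          # K and V projections
    attn += d_model * d_model             # output projection
    mlp = 3 * d_model * d_ff              # SwiGLU: gate, up, and down matrices
    non_embed = n_layer * (attn + mlp)
    embed = vocab_size * d_model          # one matrix, shared with the LM head
    return non_embed, embed, non_embed + embed


non_embed, embed, total = count_params()
print(f"{non_embed:,} {embed:,} {total:,}")
# 50,331,648 38,597,376 88,929,024
```

The three numbers correspond to the ~50M non-embedding figure, the ~39M embedding matrix, and the ~89M total quoted above.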
Optimizer + schedule¶
| HP | Value | Source |
|---|---|---|
| Optimizer | AdamW | Loshchilov 2017 |
| LR (peak) | 3e-4 | Llama-2 7B uses 3e-4; we follow |
| LR schedule | cosine, warmup 1000 steps, decay to 3e-5 | Standard |
| β₁ | 0.9 | Default |
| β₂ | 0.95 | Lowered from 0.999; spike resistance (PaLM) |
| weight_decay | 0.1 | Modern default |
| grad_clip | 1.0 | Universal |
| Effective batch tokens | 512K = 512 batch × 1024 seq | Tuned to fit A100 80GB |
| Precision | bf16 mixed (fp32 master) | A100 native, no loss-scaling |
Compile: torch.compile(model, mode="max-autotune").
Attention kernel: flash_attn_func from flash-attn package.
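For reference, the learning-rate schedule in the table can be sketched as a pure function of the step. The table only says "warmup 1000 steps", so the linear warmup shape here is an assumption, as is the use of 84,000 total steps from the training command; the trainer's own scheduler may differ in detail.

```python
import math

PEAK_LR = 3e-4
MIN_LR = 3e-5
WARMUP_STEPS = 1_000
TOTAL_STEPS = 84_000


def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)  # hold at MIN_LR past the last step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return MIN_LR + (PEAK_LR - MIN_LR) * cosine


for step in (0, 500, 1_000, 42_500, 84_000):
    print(step, f"{lr_at(step):.2e}")
```

Halfway through the decay (step 42,500) the rate sits at the midpoint between peak and floor, 1.65e-4.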
Expected throughput¶
- Sustained TFLOP/s: ~150 (MFU 0.48 of A100's 312 bf16 peak).
- Tokens/s: \(1.5 \times 10^{14} / (6 \times 5 \times 10^7) \approx 500{,}000\).
- Steps/s: \(500{,}000 / 512{,}000 \approx 0.98\) → ~3,500 steps/hour.
- Total steps in 24 h: ~84,000.
- Total tokens: ~43B.
If observed throughput at hour 1 is < 300k tokens/s (MFU < 0.30), stop and diagnose. Common causes: torch.compile failed to fuse (check warnings), FlashAttention-2 not detected (check nvidia-smi for SM utilization), dataloader CPU-bound (check iostat).
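The throughput arithmetic above can be reproduced in a few lines, which is handy for converting an observed tokens/s figure back into MFU at hour 1. This sketch uses the same 6N FLOPs-per-token approximation and the 512,000 tokens per step used in the estimate; the function names are illustrative.

```python
A100_BF16_PEAK_FLOPS = 312e12
N_PARAMS = 50e6            # non-embedding parameters
TOKENS_PER_STEP = 512_000  # as used in the estimate above


def tokens_per_sec(sustained_flops: float, n_params: float = N_PARAMS) -> float:
    """Tokens/s implied by a sustained FLOP rate at ~6N FLOPs per token."""
    return sustained_flops / (6 * n_params)


def mfu(observed_tokens_per_sec: float, n_params: float = N_PARAMS) -> float:
    """Model FLOPs utilization for an observed token throughput."""
    return 6 * n_params * observed_tokens_per_sec / A100_BF16_PEAK_FLOPS


def steps_per_hour(tok_per_sec: float) -> float:
    return tok_per_sec / TOKENS_PER_STEP * 3600


target = tokens_per_sec(150e12)
print(f"target: {target:,.0f} tok/s, MFU {mfu(target):.2f}")
print(f"steps/hour: {steps_per_hour(target):,.0f}")
print(f"tokens in 24 h: {target * 86_400 / 1e9:.1f}B")
print(f"MFU at 300k tok/s: {mfu(300_000):.2f}")
```

At 300k tokens/s the computed MFU is about 0.29, which is the stop-and-diagnose threshold stated above.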
Training command¶
# Inside container, with all the data and code mounted
cd /workspace
python -m x1_pretrain.train \
--config configs/x1-50m-a100.yaml \
--data-dir /workspace/data/tokenized \
--ckpt-dir /workspace/checkpoints \
--mlflow-uri file:///workspace/mlruns \
--total-steps 84000 \
--ckpt-every 1800 \
--eval-every 1800 \
--log-every 10 \
--resume-from-latest \
--budget-cap-usd 35.0 \
--budget-curr-cost-uri /workspace/budget.json
The flag --budget-cap-usd 35.0 invokes budget_guard.py from Phase 35 in periodic-check mode: every 30 minutes it consults /workspace/budget.json (kept current by a sidecar process polling the cloud billing API or its proxy), and if the projected total exceeds $35, it forces a graceful checkpoint-and-exit.
Expected loss curve¶
With bf16 and seed=42, this config should reproduce approximately:
| Hour | Steps | Tokens (cumulative) | Train loss | Val loss |
|---|---|---|---|---|
| 0 | 0 | 0 | ~10.8 | — |
| 1 | 3.5k | 1.8B | 5.1 | 5.2 |
| 6 | 21k | 11B | 3.9 | 3.95 |
| 12 | 42k | 21B | 3.5 | 3.55 |
| 18 | 63k | 32B | 3.4 | 3.42 |
| 24 | 84k | 43B | 3.32 | 3.35 |
The 3.35 final val loss matches the Hoffmann fit prediction in theory/01-scaling-laws.md to within 0.05 nats. DoD check 1 is met if final val loss is within [3.20, 3.50].
Outside that band → consult theory/04, check for spikes, write the post-mortem.
Logs to record (DoD check 1)¶
The trainer logs metrics to mlflow every 10 steps. After the run, export them along with the artifacts listed below:
python -m x1_pretrain.export_run \
--mlflow-uri file:///workspace/mlruns \
--output experiments/X1-pretraining/run-cloud/
Produces:
- manifest.json — seed, versions (torch, flash-attn, transformers, numpy), config YAML hash, cluster spec, total $-spent.
- metrics.csv — long-format with columns (step, name, value).
- loss-curve.png — train + val loss vs steps.
- gradnorm-curve.png — grad-norm and param-norm vs steps.
- throughput.png — tokens/s vs wall-clock hour.
- final.safetensors — the last checkpoint, in safetensors format.
- mlflow-run-uri.txt — for cross-reference.
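Because `metrics.csv` is in long format, each metric has to be regrouped into its own series before plotting or comparing against the expected loss table. A minimal sketch follows; the column names come from the list above, while the metric names `train_loss` and `val_loss` are assumptions, and the sample rows simply reuse values from the expected loss curve.

```python
import csv
import io
from collections import defaultdict


def load_long_metrics(fileobj):
    """Group long-format (step, name, value) rows into per-metric series."""
    series = defaultdict(list)
    for row in csv.DictReader(fileobj):
        series[row["name"]].append((int(row["step"]), float(row["value"])))
    for points in series.values():
        points.sort()  # order each series by step
    return dict(series)


# Illustrative rows only; real files come from x1_pretrain.export_run.
SAMPLE = """step,name,value
21000,train_loss,3.9
3500,train_loss,5.1
3500,val_loss,5.2
21000,val_loss,3.95
"""

metrics = load_long_metrics(io.StringIO(SAMPLE))
print(metrics["val_loss"])  # [(3500, 5.2), (21000, 3.95)]
```

For a real run, pass an open file handle for `metrics.csv` instead of the in-memory sample.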
Loss-spike injection (DoD check 3)¶
The training script accepts --inject-spike-at-step N which:
- At step N, replaces the next 5 batches with a "rare-token" synthetic batch (sequences sampled from the bottom-1% of the unigram distribution, weighted to produce a high cross-entropy loss).
- Logs the injection clearly in `mlflow`.
For lab 00, run with --inject-spike-at-step 12000 (~3.5 hours in). Observe the response (grad clip should catch it; β₂=0.95 should dampen quickly), then write the post-mortem in spike-postmortem.md.
If a real spike happens naturally before step 12000, write that post-mortem instead and skip the injection.
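One way to picture the synthetic batch is sketched below. It is not the trainer's implementation: it reads "bottom-1% of the unigram distribution" as the 1% least frequent token IDs, samples uniformly from them rather than applying the weighting described above, and uses hypothetical function names and toy counts.

```python
import random


def rare_token_ids(unigram_counts: dict[int, int], fraction: float = 0.01) -> list[int]:
    """Return the least frequent `fraction` of token IDs (at least one)."""
    ranked = sorted(unigram_counts, key=lambda tok: (unigram_counts[tok], tok))
    k = max(1, round(len(ranked) * fraction))
    return ranked[:k]


def synthetic_spike_batch(rare_ids: list[int], batch_size: int,
                          seq_len: int, seed: int = 0) -> list[list[int]]:
    """Build a batch whose sequences contain only rare tokens."""
    rng = random.Random(seed)  # seeded so the injection is reproducible
    return [[rng.choice(rare_ids) for _ in range(seq_len)]
            for _ in range(batch_size)]


# Toy unigram counts: token i appears i + 1 times, so low IDs are rarest.
toy_counts = {tok: tok + 1 for tok in range(1000)}
rare = rare_token_ids(toy_counts)
batch = synthetic_spike_batch(rare, batch_size=4, seq_len=16)
print(rare)
print(len(batch), len(batch[0]))
```

Seeding the sampler keeps the injected batches identical across reruns, so the post-mortem can compare like with like.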
Watchdog and budget alarm¶
A sidecar process polls the Lambda / RunPod billing API every 5 min and writes to /workspace/budget.json:
{
"spent_usd": 14.20,
"rate_usd_per_hr": 1.10,
"hours_elapsed": 12.9,
"projected_total_usd": 28.4,
"last_update_utc": "2026-05-23T15:42:01Z"
}
The trainer reads this every 30 min and:
- if spent_usd > 30: email + Slack ping, no action yet.
- if projected_total_usd > 35: graceful checkpoint and exit.
- if spent_usd > 35: immediate exit, no checkpoint write (already over).
This is the contract with budget_guard.py. Test it before launch (--dry-run-budget flag).
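The three rules above reduce to a small decision function. The sketch below is an illustration of the contract, not the actual `budget_guard.py` API: the function name and return strings are hypothetical, and the snapshot is the example JSON shown earlier.

```python
import json

ALERT_USD = 30.0
CAP_USD = 35.0


def budget_action(snapshot: dict, cap_usd: float = CAP_USD,
                  alert_usd: float = ALERT_USD) -> str:
    """Map a budget.json snapshot to the action the trainer should take."""
    spent = float(snapshot["spent_usd"])
    projected = float(snapshot["projected_total_usd"])
    if spent > cap_usd:
        return "exit_now"             # already over: no checkpoint write
    if projected > cap_usd:
        return "checkpoint_and_exit"  # graceful stop before the cap is hit
    if spent > alert_usd:
        return "alert"                # email + Slack ping, keep training
    return "continue"


snapshot = json.loads("""
{
  "spent_usd": 14.20,
  "rate_usd_per_hr": 1.10,
  "hours_elapsed": 12.9,
  "projected_total_usd": 28.4,
  "last_update_utc": "2026-05-23T15:42:01Z"
}
""")
print(budget_action(snapshot))  # continue
```

The checks run from most to least severe, so a snapshot that trips several thresholds at once always gets the strictest action.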
Shutdown checklist¶
- Final checkpoint saved (`final.safetensors` + `final-optimizer.pt`).
- `mlflow` export to `experiments/X1-pretraining/run-cloud/` done.
- Final `budget.json` confirms spend ≤ $35.
- Terminate the instance. A forgotten A100 at $1.10/hr is $26/day.
- Confirm termination in the provider dashboard (screenshot saved to `experiments/...`).
- Detach + delete the persistent disk if not needed for lab 01 (it is; keep it for ~3 days).
DoD checks (this lab)¶
- `experiments/X1-pretraining/run-cloud/manifest.json` exists, contains seed/versions/cluster/$-spent.
- Final val loss in [3.20, 3.50].
- `final.safetensors` exists and reloads byte-equivalently (round-trip test).
- Total $-spent ≤ $35.
- `spike-postmortem.md` written (real or injected spike).
- Instance terminated (screenshot proof).
Common failure modes and what to do¶
- `flash-attn` import error. Older driver. Either `pip install flash-attn==2.6.3 --no-build-isolation` or downgrade to `nvcr.io/nvidia/pytorch:24.08-py3`.
- `torch.compile` hangs at first step. Set `mode="default"` instead of `"max-autotune"`. Re-run.
- Throughput halves intermittently. Dataloader. Check `num_workers=4` and confirm `pin_memory=True`. NVMe disk should sustain 1 GB/s reads; confirm with `iostat`.
- OOM at step 0. Lower `batch_size` from 32 to 16, raise `grad_accum_steps` from 16 to 32. Effective batch is unchanged.
- Loss goes NaN. Almost certainly bf16-related: check that fp32 master weights are enabled. If still NaN, check input data for token IDs out of vocab range.
- Instance pre-empted at hour 17. Watchdog auto-relaunches in <2 min. Resume picks up last checkpoint. Total downtime: ~3 min. Cost of preemption: $0.
Next: lab/01-scaling-laws-experiment.md.