Lab 03 — Mixed-precision drift (preview only; no training in mp)¶
Goal: measure per-layer fp16 numerical drift on a forward pass; produce a plot that informs Phase 26.
Estimated time: 45–60 minutes.
Prereq: lab 02 done (you have a reloadable Phase-18 checkpoint).
What you produce¶
- src/minitrain/mp_preview.py: fp16 round-trip cast + forward-pass instrumentation.
- experiments/18-mp-drift/:
  - manifest.json
  - drift_results.json: per-layer relative error, argmax-change count
  - drift_per_layer.png: bar chart of the relative error of activations at each layer
  - argmax_flips.md: short note on which positions flipped under fp16 weights
Background you must have read¶
- theory/03-mixed-precision-preview.md: fp16 vs fp32 vs bf16, the \(2^{-10}\) relative error bound, the accumulator rule.
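As a quick sanity check on that bound, the sketch below measures the fp16 round-trip error on random values. It assumes NumPy's default round-to-nearest casting and deliberately samples magnitudes inside fp16's normal range; the sample range and seed are illustrative choices, not part of the lab.

```python
import numpy as np

# fp16 stores 10 explicit mantissa bits, so the gap between 1.0 and the
# next representable value (machine epsilon) is 2**-10.
eps_fp16 = float(np.finfo(np.float16).eps)

# Illustrative sample: magnitudes kept well inside fp16's normal range so
# subnormals and overflow do not affect the measurement.
rng = np.random.default_rng(0)
x = rng.uniform(0.01, 10.0, size=10_000).astype(np.float32)

# Round trip: fp32 -> fp16 -> fp32.
x_rt = x.astype(np.float16).astype(np.float32)
rel_err = np.abs(x_rt - x) / np.abs(x)
max_rel_err = float(rel_err.max())

print(f"fp16 eps       = {eps_fp16:.3e}")
print(f"max rel. error = {max_rel_err:.3e}")
```

With round-to-nearest, the worst-case element-wise error is half of epsilon, \(2^{-11}\), so \(2^{-10}\) is a conservative per-element reference for a single cast.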
TODOs¶
Block A — src/minitrain/mp_preview.py¶
Implement:
import numpy as np
from numpy import ndarray

def cast_weights_fp32_to_fp16_back(weights: dict[str, ndarray]) -> dict[str, ndarray]:
    """For each float weight tensor: cast to fp16 and back to fp32. Returns the
    fp32 tensor with fp16's rounding shadow applied; non-float tensors pass through."""
    return {k: (v.astype(np.float16).astype(np.float32) if v.dtype.kind == "f" else v)
            for k, v in weights.items()}
- Idempotent: a second round-trip is a no-op.
- Skip the embedding's integer index tensors (dtype.kind != 'f' → skip).
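Both requirements can be checked on a toy state dict. This is a minimal sketch: `round_trip` is a local stand-in for the Block A function, and the tensor names and shapes are hypothetical, not the real MiniGPT layout.

```python
import numpy as np

def round_trip(weights):
    """fp32 -> fp16 -> fp32 for float tensors; everything else passes through."""
    return {
        k: v.astype(np.float16).astype(np.float32) if v.dtype.kind == "f" else v
        for k, v in weights.items()
    }

# Hypothetical toy state dict: names and shapes are illustrative only.
rng = np.random.default_rng(0)
weights = {
    "tok_embed": rng.normal(size=(16, 8)).astype(np.float32),
    "block_0.w_qkv": rng.normal(size=(8, 24)).astype(np.float32),
    "position_ids": np.arange(7, dtype=np.int64),
}

once = round_trip(weights)
twice = round_trip(once)

# Idempotence: values already on the fp16 grid survive a second round trip.
idempotent = all(np.array_equal(once[k], twice[k]) for k in weights)
# Integer tensors keep their dtype and their values.
ints_untouched = (once["position_ids"].dtype == np.int64
                  and np.array_equal(once["position_ids"], weights["position_ids"]))
# Float tensors come back as fp32.
floats_fp32 = all(once[k].dtype == np.float32 for k in ("tok_embed", "block_0.w_qkv"))

print(idempotent, ints_untouched, floats_fp32)
```

Turning the three flags into assertions gives a small unit test for `mp_preview.py`.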
Block B — instrumented forward¶
def forward_with_layer_outputs(model, input_ids, attn_mask) -> tuple[ndarray, dict[str, ndarray]]:
    """Returns (logits, {layer_name: activation_at_layer_output})."""
- Capture activations at:
  - the embedding output (embed_out)
  - each transformer block's residual output (block_{i}_out for \(i = 0, 1\), given Phase 17's locked n_layers = 2)
  - the final LayerNorm output (final_ln_out)
  - the final logits (logits)
- Total: 4 activation snapshots + the logits.
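The capture pattern can be sketched on a toy network. Everything here is a hypothetical stand-in: the parameter names, the attention-free residual blocks, and the tied unembedding are assumptions for illustration, not Phase 17's architecture. The sketch also assumes the logits are stored in the returned dict so that a loop over it covers all five names.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over the last axis."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def toy_forward_with_layer_outputs(params, input_ids):
    """Toy model: embedding, two residual MLP blocks, LayerNorm, tied unembedding."""
    acts = {}
    h = params["embed"][input_ids]                   # (batch, seq, d_model)
    acts["embed_out"] = h.copy()
    for i in range(2):
        h = h + np.tanh(h @ params[f"block_{i}_w"])  # residual update
        acts[f"block_{i}_out"] = h.copy()            # snapshot after the residual add
    h = layer_norm(h)
    acts["final_ln_out"] = h.copy()
    logits = h @ params["embed"].T                   # (batch, seq, vocab)
    acts["logits"] = logits.copy()
    return logits, acts

# Hypothetical sizes: vocab 11, d_model 8, one prompt of 7 tokens.
rng = np.random.default_rng(0)
params = {
    "embed": rng.normal(size=(11, 8)).astype(np.float32),
    "block_0_w": (0.1 * rng.normal(size=(8, 8))).astype(np.float32),
    "block_1_w": (0.1 * rng.normal(size=(8, 8))).astype(np.float32),
}
input_ids = np.array([[1, 2, 3, 4, 5, 6, 7]])
logits, acts = toy_forward_with_layer_outputs(params, input_ids)
print(list(acts), logits.shape)
```

Copying each snapshot keeps later in-place updates from silently changing an earlier capture. Python dicts preserve insertion order, so the keys come out in depth order for plotting.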
Block C — run the comparison¶
# Load Phase-18 final checkpoint
state = load_checkpoint(phase18_dir)
model_fp32 = build_minigpt(config)
apply_weights(model_fp32, state.model_weights)
# Build a copy with fp16-rounded weights
model_fp16 = build_minigpt(config)
apply_weights(model_fp16, cast_weights_fp32_to_fp16_back(state.model_weights))
# Pick a representative input: a single verb-conjugation prompt
input_ids, attn_mask = tokenize_prompt("yo trabajo / I ___") # batch of 1, 7 tokens
# Two forward passes
logits_fp32, acts_fp32 = forward_with_layer_outputs(model_fp32, input_ids, attn_mask)
logits_fp16, acts_fp16 = forward_with_layer_outputs(model_fp16, input_ids, attn_mask)
# Per-layer relative error
errors = {}
for name in acts_fp32:
    a, b = acts_fp32[name], acts_fp16[name]
    errors[name] = np.linalg.norm(a - b) / (np.linalg.norm(a) + 1e-12)
# Argmax-flip count
argmax_fp32 = logits_fp32.argmax(axis=-1)
argmax_fp16 = logits_fp16.argmax(axis=-1)
flipped = int((argmax_fp32 != argmax_fp16).sum())
- Run on 5 representative prompts (one per tense), average the per-layer errors.
- Record argmax flips per prompt.
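One possible way to do the averaging and the per-prompt flip bookkeeping is sketched below. The activations are synthetic random arrays standing in for real forward passes, so the printed numbers mean nothing; `mean_per_layer_error` is a hypothetical helper name.

```python
import numpy as np

def relative_error(a, b):
    """Same metric as the comparison loop: ||a - b|| / ||a||."""
    return float(np.linalg.norm(a - b) / (np.linalg.norm(a) + 1e-12))

def mean_per_layer_error(per_prompt_errors):
    """Average a list of {layer_name: error} dicts, preserving layer order."""
    names = list(per_prompt_errors[0])
    return {n: float(np.mean([e[n] for e in per_prompt_errors])) for n in names}

layer_names = ["embed_out", "block_0_out", "block_1_out", "final_ln_out", "logits"]
rng = np.random.default_rng(0)

per_prompt_errors, flips_per_prompt = [], []
for _ in range(5):  # one iteration per prompt
    # Synthetic stand-ins for the two forward passes (NOT real model outputs).
    acts_fp32 = {n: rng.normal(size=(1, 7, 8)) for n in layer_names}
    acts_fp16 = {n: a + 1e-3 * rng.normal(size=a.shape) for n, a in acts_fp32.items()}

    per_prompt_errors.append(
        {n: relative_error(acts_fp32[n], acts_fp16[n]) for n in layer_names}
    )
    flipped = int((acts_fp32["logits"].argmax(-1) != acts_fp16["logits"].argmax(-1)).sum())
    flips_per_prompt.append(flipped)

mean_errors = mean_per_layer_error(per_prompt_errors)
print(mean_errors)
print(flips_per_prompt)
```

Keeping the per-prompt list alongside the mean makes it easy to see whether one prompt dominates the average.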
Block D — the plot¶
drift_per_layer.png:
- x-axis: layer name (embed_out, block_0_out, block_1_out, ..., final_ln_out, logits).
- y-axis: relative error (||fp16 - fp32|| / ||fp32||), log scale.
- Expected pattern: error grows monotonically with depth, from ~\(10^{-3}\) at embed_out to ~\(10^{-2}\) at logits.
- Annotate the theoretical bound \(2^{-10} \approx 10^{-3}\) as a horizontal dashed line.
- If your measured errors don't grow monotonically with depth, something is wrong — investigate before moving on.
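A minimal plotting sketch follows, assuming Matplotlib is available (the lab does not name a plotting library). `plot_drift` and `is_monotonic` are hypothetical helper names, and the example values are the sample numbers from this lab's results file, used only to exercise the code.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the script runs without a display
import matplotlib.pyplot as plt

def plot_drift(mean_errors, out_path):
    """Bar chart of per-layer relative error on a log y-axis, with the 2^-10 line."""
    names = list(mean_errors)
    values = [mean_errors[n] for n in names]
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar(names, values)
    ax.set_yscale("log")
    ax.axhline(2.0 ** -10, linestyle="--", color="k", label=r"$2^{-10}$ bound")
    ax.set_xlabel("layer")
    ax.set_ylabel("relative error  ||fp16 - fp32|| / ||fp32||")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    return out_path

def is_monotonic(mean_errors):
    """True if the error strictly increases in dict (depth) order."""
    values = list(mean_errors.values())
    return all(a < b for a, b in zip(values, values[1:]))

example = {
    "embed_out": 0.0009,
    "block_0_out": 0.0014,
    "block_1_out": 0.0025,
    "final_ln_out": 0.0040,
    "logits": 0.0051,
}
out = plot_drift(example, os.path.join(tempfile.mkdtemp(), "drift_per_layer.png"))
print(out, is_monotonic(example))
```

Running `is_monotonic` before saving the figure turns the last bullet into an automatic check rather than a visual one.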
Block E — the report¶
argmax_flips.md:
- For each of the 5 prompts, list which (if any) argmax positions flipped under fp16 weights.
- For each flip, list the fp32 top token, the fp16 top token, and the fp32 logit margin.
- Conclude: at this model size, fp16 weight quantization is mostly safe for argmax (≤ N flips out of M positions across 5 prompts), but the marginal flips at the last position of each prompt indicate where Phase 21's sampling will be most sensitive.
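The rows of the report can be collected with a helper like the one sketched here. It assumes "logit margin" means the fp32 top logit minus the fp32 runner-up at that position, and it reports token ids; mapping ids back to token strings depends on the tokenizer and is left out. `describe_flips` is a hypothetical name.

```python
import numpy as np

def describe_flips(logits_fp32, logits_fp16):
    """List flipped positions for one prompt; logits_* have shape (seq_len, vocab)."""
    top_fp32 = logits_fp32.argmax(axis=-1)
    top_fp16 = logits_fp16.argmax(axis=-1)
    # fp32 margin: best logit minus runner-up logit at each position.
    sorted_fp32 = np.sort(logits_fp32, axis=-1)
    margin = sorted_fp32[:, -1] - sorted_fp32[:, -2]
    return [
        {
            "position": int(p),
            "fp32_top": int(top_fp32[p]),
            "fp16_top": int(top_fp16[p]),
            "fp32_margin": float(margin[p]),
        }
        for p in np.flatnonzero(top_fp32 != top_fp16)
    ]

# Hand-built logits: position 0 is stable, position 1 is a near-tie that flips.
logits_a = np.array([[1.0, 0.5, 0.0],
                     [2.000, 2.001, 0.0]])
logits_b = np.array([[1.0, 0.5, 0.0],
                     [2.001, 2.000, 0.0]])
flips = describe_flips(logits_a, logits_b)
print(flips)
```

A small fp32 margin on a flipped position is the expected case: near-ties are exactly where rounding in the weights can change the winner.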
Block F — drift_results.json¶
{
  "num_prompts": 5,
  "per_layer_relative_error_mean": {
    "embed_out": 0.0009,
    "block_0_out": 0.0014,
    "block_1_out": 0.0025,
    "final_ln_out": 0.0040,
    "logits": 0.0051
  },
  "argmax_flip_count_total": 2,
  "argmax_total_positions": 35,
  "argmax_flip_rate": 0.057,
  "monotonic_growth_with_depth": true,
  "max_error_below_5pct": true
}
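The derived fields follow directly from the measured ones. The sketch below recomputes them from the sample values above; `summarize` is a hypothetical helper, and the per-prompt split of the two flips is a made-up assumption, since only the total appears in the sample.

```python
import json

def summarize(mean_errors, flips_per_prompt, positions_per_prompt):
    """Assemble the results dict; the last three fields are derived, not measured."""
    values = list(mean_errors.values())
    total_flips = sum(flips_per_prompt)
    total_positions = sum(positions_per_prompt)
    return {
        "num_prompts": len(flips_per_prompt),
        "per_layer_relative_error_mean": mean_errors,
        "argmax_flip_count_total": total_flips,
        "argmax_total_positions": total_positions,
        "argmax_flip_rate": round(total_flips / total_positions, 3),
        "monotonic_growth_with_depth": all(a < b for a, b in zip(values, values[1:])),
        "max_error_below_5pct": max(values) < 0.05,
    }

mean_errors = {
    "embed_out": 0.0009,
    "block_0_out": 0.0014,
    "block_1_out": 0.0025,
    "final_ln_out": 0.0040,
    "logits": 0.0051,
}
flips_per_prompt = [0, 1, 0, 1, 0]   # assumed split; only the total (2) is given
positions_per_prompt = [7] * 5       # 5 prompts of 7 tokens each

summary = summarize(mean_errors, flips_per_prompt, positions_per_prompt)
print(json.dumps(summary, indent=2))
```

Deriving the flags in code rather than typing them by hand keeps the file consistent with the numbers it reports.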
Constraints¶
- No backward pass. Phase 18 does not train in mp. Forward-only.
- No loss scaling. That's Phase 26.
- Pure NumPy. No PyTorch fp16 dtype.
Stop conditions¶
Done when:
- drift_per_layer.png shows monotonic per-layer error growth.
- drift_results.json is committed.
- argmax_flips.md lists every flipped position with its logit margin.
- You can state, in one sentence, the expected per-layer error growth rate and the theoretical bound.
Pitfalls¶
- Casting the embedding lookup tensor incorrectly. The embedding is (V, d_model) floats; cast it. The input_ids are integers; do not cast them.
- Forgetting the normalization layers' small tensors. Standard LayerNorm and RMSNorm keep learned gain (and, for LayerNorm, bias) vectors rather than running statistics; if Phase 17's MiniGPT stores any extra float state for its norm layers, treat it like weights.
- Errors not growing with depth. Possible causes: (a) the fp16 cast isn't being applied to all layers, (b) you're computing relative error on the residual stream which is dominated by the (unchanged) input embedding, masking deeper errors. Fix by computing relative error on the layer's output, not its accumulated residual.
- All argmaxes flip. Probably a bug; that outcome would imply the fp16 model is doing something qualitatively different. Check that weights are actually cast and not silently passed through fp32.
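When debugging either of the last two pitfalls, it can help to confirm directly that the round trip reached every float tensor. A minimal sketch, with hypothetical helper names and a toy state dict in place of the real checkpoint:

```python
import numpy as np

def count_changed_tensors(original, candidate):
    """Return (changed, float_total): how many float tensors differ from the original."""
    float_keys = [k for k, v in original.items() if v.dtype.kind == "f"]
    changed = sum(not np.array_equal(original[k], candidate[k]) for k in float_keys)
    return changed, len(float_keys)

def on_fp16_grid(t):
    """True if every value in t is exactly representable in fp16."""
    return np.array_equal(t, t.astype(np.float16).astype(np.float32))

# Toy weights (illustrative names); the cast mirrors the Block A round trip.
rng = np.random.default_rng(0)
original = {
    "tok_embed": rng.normal(size=(16, 8)).astype(np.float32),
    "block_0.w": rng.normal(size=(8, 8)).astype(np.float32),
    "block_1.w": rng.normal(size=(8, 8)).astype(np.float32),
}
cast = {k: v.astype(np.float16).astype(np.float32) for k, v in original.items()}

changed, float_total = count_changed_tensors(original, cast)
all_on_grid = all(on_fp16_grid(v) for v in cast.values())
print(f"{changed}/{float_total} float tensors changed; on fp16 grid: {all_on_grid}")
```

A tensor whose values are already exactly representable in fp16 (for example an all-ones gain vector) will legitimately count as unchanged, so read a shortfall in `changed` together with the grid check.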
When to consult solutions/¶
After the plot is committed. Solution at solutions/03-mp-drift-ref.md (written at phase open).
Next lab: lab/04-mlflow-wiring.md.