English · Español
Lab 03 — Deriva en precisión mixta (solo preview; sin entrenar en mp)¶
Objetivo: medir la deriva numérica fp16 por capa en un forward pass; producir un gráfico que informe la Fase 26.
Tiempo estimado: 45–60 minutos.
Requisito previo: lab 02 hecho (tienes un checkpoint de la Fase 18 recargable).
Lo que produces¶
src/minitrain/mp_preview.py— round-trip de cast fp16 + instrumentación del forward pass.experiments/18-mp-drift/:manifest.jsondrift_results.json— error relativo por capa, conteo de cambios de argmaxdrift_per_layer.png— gráfico de barras: error relativo de las activaciones en cada capaargmax_flips.md— nota corta sobre qué posiciones cambiaron bajo pesos fp16
Antecedentes que debes haber leído¶
theory/03-mixed-precision-preview.md— fp16 vs fp32 vs bf16, el límite de error relativo \(2^{-10}\), la regla del acumulador.
TODOs¶
Bloque A — src/minitrain/mp_preview.py¶
Implementa:
def cast_weights_fp32_to_fp16_back(weights: dict[str, ndarray]) -> dict[str, ndarray]:
"""For each weight tensor: cast to fp16 and back to fp32. Returns the
fp32 tensor with fp16's rounding shadow applied."""
return {k: v.astype(np.float16).astype(np.float32) for k, v in weights.items()}
- Idempotente: un segundo round-trip es un no-op.
- Salta los tensores enteros de índice del embedding (sin
dtype.kind == 'f'→ salta).
Bloque B — forward instrumentado¶
def forward_with_layer_outputs(model, input_ids, attn_mask) -> tuple[ndarray, dict[str, ndarray]]:
"""Returns (logits, {layer_name: activation_at_layer_output})."""
- Captura las activaciones en:
- la salida del embedding (
embed_out) - la salida residual de cada bloque transformer (
block_{i}_outpara \(i = 0, 1\) —n_layers = 2fijado en la Fase 17) - la salida final de LayerNorm (
final_ln_out) - los logits finales (
logits) - Total: 4 snapshots de activación + los logits.
Bloque C — ejecuta la comparación¶
# Load Phase-18 final checkpoint
state = load_checkpoint(phase18_dir)
model_fp32 = build_minigpt(config)
apply_weights(model_fp32, state.model_weights)
# Build a copy with fp16-rounded weights
model_fp16 = build_minigpt(config)
apply_weights(model_fp16, cast_weights_fp32_to_fp16_back(state.model_weights))
# Pick a representative input: a single verb-conjugation prompt
input_ids, attn_mask = tokenize_prompt("yo trabajo / I ___") # batch of 1, 7 tokens
# Two forward passes
logits_fp32, acts_fp32 = forward_with_layer_outputs(model_fp32, input_ids, attn_mask)
logits_fp16, acts_fp16 = forward_with_layer_outputs(model_fp16, input_ids, attn_mask)
# Per-layer relative error
errors = {}
for name in acts_fp32:
a, b = acts_fp32[name], acts_fp16[name]
errors[name] = np.linalg.norm(a - b) / (np.linalg.norm(a) + 1e-12)
# Argmax-flip count
argmax_fp32 = logits_fp32.argmax(axis=-1)
argmax_fp16 = logits_fp16.argmax(axis=-1)
flipped = int((argmax_fp32 != argmax_fp16).sum())
- Ejecuta sobre 5 prompts representativos (uno por tiempo verbal), promedia los errores por capa.
- Registra los flips de argmax por prompt.
Bloque D — el gráfico¶
drift_per_layer.png:
- eje x: nombre de la capa (
embed_out,block_0_out,block_1_out, ...,final_ln_out,logits). - eje y: error relativo (
||fp16 - fp32|| / ||fp32||), escala logarítmica. -
Patrón esperado: el error crece de forma monótona con la profundidad, de ~\(10^{-3}\) en
embed_outa ~\(10^{-2}\) enlogits. -
Anota el límite teórico \(2^{-10} \approx 10^{-3}\) como línea horizontal discontinua.
- Si tus errores medidos no crecen monótonamente con la profundidad, algo está mal — investiga antes de seguir.
Bloque E — el informe¶
argmax_flips.md:
- Para cada uno de los 5 prompts, lista qué posiciones de argmax (si las hay) cambiaron bajo pesos fp16.
- Para cada flip, lista el top token fp32, el top token fp16 y el margen del logit fp32.
- Conclusión: a este tamaño de modelo, la cuantización de pesos fp16 es en su mayoría segura para argmax (≤ N flips de M posiciones a lo largo de 5 prompts), pero los flips marginales en la última posición de cada prompt indican dónde el sampling de la Fase 21 será más sensible.
Bloque F — results.json¶
{
"num_prompts": 5,
"per_layer_relative_error_mean": {
"embed_out": 0.0009,
"block_0_out": 0.0014,
"block_1_out": 0.0025,
"final_ln_out": 0.0040,
"logits": 0.0051
},
"argmax_flip_count_total": 2,
"argmax_total_positions": 35,
"argmax_flip_rate": 0.057,
"monotonic_growth_with_depth": true,
"max_error_below_5pct": true
}
Restricciones¶
- Sin backward pass. La Fase 18 no entrena en mp. Solo forward.
- Sin loss scaling. Eso es la Fase 26.
- NumPy puro. Sin el dtype fp16 de PyTorch.
Condiciones de parada¶
Hecho cuando:
drift_per_layer.pngmuestra crecimiento monótono del error por capa.drift_results.jsonestá commiteado.argmax_flips.mdlista cada posición flipped con su margen de logit.- Puedes formular, en una frase, la tasa esperada de crecimiento del error por capa y el límite teórico.
Escollos¶
- Castear incorrectamente el tensor de lookup del embedding. El embedding es
(V, d_model)floats; castéalo. Los input_ids son enteros; no los castees. - Olvidar que LayerNorm tiene pequeñas estadísticas running. El MiniGPT de la Fase 17 puede o no tener stats running de RMSNorm; si las tiene, trátalas como pesos.
- Los errores no crecen con la profundidad. Causas posibles: (a) el cast fp16 no se está aplicando a todas las capas, (b) estás calculando el error relativo sobre el residual stream que está dominado por el embedding de entrada (no cambiado), enmascarando errores más profundos. Arregla calculando el error relativo sobre la salida de la capa, no sobre el residual acumulado.
- Todos los argmaxes cambian. Probablemente un bug; el test implicaría que el modelo fp16 está haciendo algo cualitativamente diferente. Comprueba que los pesos están realmente casteados y no se están pasando silenciosamente como fp32.
Cuándo consultar solutions/¶
Después de que el gráfico esté commiteado. Solución en solutions/03-mp-drift-ref.md (escrita al abrir la fase).
Siguiente lab: lab/04-mlflow-wiring.md.