Skip to content

English · Español

Lab 03 — Deriva en precisión mixta (solo preview; sin entrenar en mp)

Objetivo: medir la deriva numérica fp16 por capa en un forward pass; producir un gráfico que informe la Fase 26.

Tiempo estimado: 45–60 minutos.

Requisito previo: lab 02 hecho (tienes un checkpoint de la Fase 18 recargable).


Lo que produces

  • src/minitrain/mp_preview.py — round-trip de cast fp16 + instrumentación del forward pass.
  • experiments/18-mp-drift/:
  • manifest.json
  • drift_results.json — error relativo por capa, conteo de cambios de argmax
  • drift_per_layer.png — gráfico de barras: error relativo de las activaciones en cada capa
  • argmax_flips.md — nota corta sobre qué posiciones cambiaron bajo pesos fp16

Antecedentes que debes haber leído

  • theory/03-mixed-precision-preview.md — fp16 vs fp32 vs bf16, el límite de error relativo \(2^{-10}\), la regla del acumulador.

TODOs

Bloque A — src/minitrain/mp_preview.py

Implementa:

def cast_weights_fp32_to_fp16_back(weights: dict[str, ndarray]) -> dict[str, ndarray]:
    """For each weight tensor: cast to fp16 and back to fp32. Returns the
    fp32 tensor with fp16's rounding shadow applied."""
    return {k: v.astype(np.float16).astype(np.float32) for k, v in weights.items()}
  • Idempotente: un segundo round-trip es un no-op.
  • Salta los tensores enteros de índice del embedding (sin dtype.kind == 'f' → salta).

Bloque B — forward instrumentado

def forward_with_layer_outputs(model, input_ids, attn_mask) -> tuple[ndarray, dict[str, ndarray]]:
    """Returns (logits, {layer_name: activation_at_layer_output})."""
  • Captura las activaciones en:
  • la salida del embedding (embed_out)
  • la salida residual de cada bloque transformer (block_{i}_out para \(i = 0, 1\)n_layers = 2 fijado en la Fase 17)
  • la salida final de LayerNorm (final_ln_out)
  • los logits finales (logits)
  • Total: 4 snapshots de activación + los logits.

Bloque C — ejecuta la comparación

# Load Phase-18 final checkpoint
state = load_checkpoint(phase18_dir)
model_fp32 = build_minigpt(config)
apply_weights(model_fp32, state.model_weights)

# Build a copy with fp16-rounded weights
model_fp16 = build_minigpt(config)
apply_weights(model_fp16, cast_weights_fp32_to_fp16_back(state.model_weights))

# Pick a representative input: a single verb-conjugation prompt
input_ids, attn_mask = tokenize_prompt("yo trabajo / I ___")  # batch of 1, 7 tokens

# Two forward passes
logits_fp32, acts_fp32 = forward_with_layer_outputs(model_fp32, input_ids, attn_mask)
logits_fp16, acts_fp16 = forward_with_layer_outputs(model_fp16, input_ids, attn_mask)

# Per-layer relative error
errors = {}
for name in acts_fp32:
    a, b = acts_fp32[name], acts_fp16[name]
    errors[name] = np.linalg.norm(a - b) / (np.linalg.norm(a) + 1e-12)

# Argmax-flip count
argmax_fp32 = logits_fp32.argmax(axis=-1)
argmax_fp16 = logits_fp16.argmax(axis=-1)
flipped = int((argmax_fp32 != argmax_fp16).sum())
  • Ejecuta sobre 5 prompts representativos (uno por tiempo verbal), promedia los errores por capa.
  • Registra los flips de argmax por prompt.

Bloque D — el gráfico

drift_per_layer.png:

  • eje x: nombre de la capa (embed_out, block_0_out, block_1_out, ..., final_ln_out, logits).
  • eje y: error relativo (||fp16 - fp32|| / ||fp32||), escala logarítmica.
  • Patrón esperado: el error crece de forma monótona con la profundidad, de ~\(10^{-3}\) en embed_out a ~\(10^{-2}\) en logits.

  • Anota el límite teórico \(2^{-10} \approx 10^{-3}\) como línea horizontal discontinua.

  • Si tus errores medidos no crecen monótonamente con la profundidad, algo está mal — investiga antes de seguir.

Bloque E — el informe

argmax_flips.md:

  • Para cada uno de los 5 prompts, lista qué posiciones de argmax (si las hay) cambiaron bajo pesos fp16.
  • Para cada flip, lista el top token fp32, el top token fp16 y el margen del logit fp32.
  • Conclusión: a este tamaño de modelo, la cuantización de pesos fp16 es en su mayoría segura para argmax (≤ N flips de M posiciones a lo largo de 5 prompts), pero los flips marginales en la última posición de cada prompt indican dónde el sampling de la Fase 21 será más sensible.

Bloque F — results.json

{
  "num_prompts": 5,
  "per_layer_relative_error_mean": {
    "embed_out": 0.0009,
    "block_0_out": 0.0014,
    "block_1_out": 0.0025,
    "final_ln_out": 0.0040,
    "logits": 0.0051
  },
  "argmax_flip_count_total": 2,
  "argmax_total_positions": 35,
  "argmax_flip_rate": 0.057,
  "monotonic_growth_with_depth": true,
  "max_error_below_5pct": true
}

Restricciones

  • Sin backward pass. La Fase 18 no entrena en mp. Solo forward.
  • Sin loss scaling. Eso es la Fase 26.
  • NumPy puro. Sin el dtype fp16 de PyTorch.

Condiciones de parada

Hecho cuando:

  1. drift_per_layer.png muestra crecimiento monótono del error por capa.
  2. drift_results.json está commiteado.
  3. argmax_flips.md lista cada posición flipped con su margen de logit.
  4. Puedes formular, en una frase, la tasa esperada de crecimiento del error por capa y el límite teórico.

Escollos

  • Castear incorrectamente el tensor de lookup del embedding. El embedding es (V, d_model) floats; castéalo. Los input_ids son enteros; no los castees.
  • Olvidar que LayerNorm tiene pequeñas estadísticas running. El MiniGPT de la Fase 17 puede o no tener stats running de RMSNorm; si las tiene, trátalas como pesos.
  • Los errores no crecen con la profundidad. Causas posibles: (a) el cast fp16 no se está aplicando a todas las capas, (b) estás calculando el error relativo sobre el residual stream que está dominado por el embedding de entrada (no cambiado), enmascarando errores más profundos. Arregla calculando el error relativo sobre la salida de la capa, no sobre el residual acumulado.
  • Todos los argmaxes cambian. Probablemente un bug; el test implicaría que el modelo fp16 está haciendo algo cualitativamente diferente. Comprueba que los pesos están realmente casteados y no se están pasando silenciosamente como fp32.

Cuándo consultar solutions/

Después de que el gráfico esté commiteado. Solución en solutions/03-mp-drift-ref.md (escrita al abrir la fase).


Siguiente lab: lab/04-mlflow-wiring.md.