Skip to content

English · Español

Lab 02 — Corrección: con-caché igual a sin-caché, byte a byte

Objetivo: demostrar que generate(prompt, cache=True) produce los mismos tokens exactos que generate(prompt, cache=False) para prompts y semillas arbitrarios. Los bugs sutiles de caché son silenciosos; solo un test de igualdad exacta los hace aflorar.

Tiempo estimado: 2–4 horas.

Prerrequisito: lab/01-implement-cache.md completo. src/miniinfer/generate.py de la Fase 21 en su sitio.


Lo que produces

Un directorio experiments/22-cache-correctness/ que contenga:

  • property_test.py — tu test runner de propiedad.
  • results.json — pass/fail por prompt, paso de divergencia si lo hay.
  • manifest.json.
  • README.md — 2–3 párrafos. Si algún test falló, documenta el bug que encontraste y cómo lo arreglaste.

Un segundo directorio experiments/22-yesterday-worked/ que contenga el volcado emblemático a nivel de slot:

  • dump.py — script que prefilla "Yesterday I", decodifica un token, luego por separado ejecuta un recálculo completo sobre "Yesterday I worked" (o la forma de pasado simple que el modelo emitiera), y vuelca la fila K y la fila V para el slot de la posición de "I" desde ambas ejecuciones.
  • slots.npz — las filas K, V volcadas de ambos caminos.
  • report.md — afirmación: cada byte de la fila "I" del camino cacheado es igual a cada byte de la fila "I" del camino recalculado, para K y para V, para cada capa y cabeza.
  • manifest.json.

La propiedad

Para un modelo fijo (MiniGPT, Fase 17) y una semilla de sampling fija, lo siguiente debe cumplirse:

seed_everything(42)
tokens_cached = generate(prompt, max_new_tokens=64, cache=True)

seed_everything(42)
tokens_uncached = generate(prompt, max_new_tokens=64, cache=False)

assert tokens_cached == tokens_uncached  # byte-identical token sequence

Para 50 prompts distintos muestreados de una distribución fija (defínela en tu property_test.py).

Nota de determinismo: seed_everything debe re-aplicarse antes de cada camino porque el sampling consume el RNG. Si cache=True llama al modelo menos veces (lo hace — ese es el punto), el estado del RNG diverge salvo que se reinicie. Esta es la fuente más común de reportes falsos positivos de "bug de corrección"; construye el test para manejarla desde el principio.

TODOs

Bloque A — escribir el test runner

  • Cargar los pesos del MiniGPT de la Fase 17 una vez. El modelo está entrenado en el corpus de gramática de verbos §A13; sus tokens son formas verbales en inglés (y español).
  • Muestrear 50 prompts: elegir de la distribución natural del corpus de gramática. Mezcla sugerida: (a) 20 prompts de la forma "<adverbial-de-tiempo> <pronombre>" (p. ej. "Yesterday I", "Tomorrow he", "Now you"), (b) 20 prompts de longitud 4–8 que sean frases parciales válidas (p. ej. "I am going to"), © 10 prompts más largos que mezclen tiempos para forzar el enmascaramiento causal. Pon semilla al muestreador de prompts — semilla distinta de la de generación.
  • Para cada prompt:
  • seed_everything(gen_seed_for_this_prompt)
  • t_cached = generate(prompt, max_new_tokens=64, cache=True)
  • seed_everything(gen_seed_for_this_prompt)
  • t_uncached = generate(prompt, max_new_tokens=64, cache=False)
  • Si t_cached != t_uncached: registrar el primer índice de divergencia.
  • Cuenta pass/fail. Escribe results.json.

Bloque B — interpretar fallos

Si algún prompt diverge, el test solo te dice dónde (índice de token) pero no por qué. Tu trabajo en este bloque:

  1. Re-ejecutar ese prompt con cache=True y volcar las salidas de atención por capa en el paso de divergencia.
  2. Re-ejecutar el mismo prompt con cache=False, mismos volcados.
  3. Comparar: encontrar la capa (¿y cabeza?) donde difieren primero.
  4. Trazarlo de vuelta al código del caché. Culpables comunes:
  5. Off-by-one del cursor (guardar K, V del token actual antes de calcular la atención).
  6. Forma de máscara incorrecta para decode q_len=1 (no hace falta máscara, pero el código de la Fase 15 podría seguir aplicando una).
  7. Índice de capa intercambiado (usar cache.read(layer=0) en todas partes).
  8. Desajuste de dtype (caché guardado en fp32, pero las lecturas hacen cast a fp64 a mitad de atención).

Documenta el bug + fix en README.md.

Bloque C — extender el test

Una vez pasen los 50 prompts:

  • Probar una generación más larga: 256 tokens nuevos. ¿Sigue siendo byte a byte idéntico? (Nota: el modelo entrenado con vocabulario de 600 formas empezará a ciclar / repetir bastante antes de 256 tokens — eso está bien. La propiedad de equivalencia es lo que se testea.)
  • Probar batch=4 secuencias paralelas. Cada una debe independientemente producir lo mismo con/sin caché. (Esto caza bugs de dim-batch que los tests de un solo stream se pierden.)
  • Probar q_len > 1 (reanudación de prefill multi-token). Caso límite: si alguna vez haces "warm-start decode desde un prompt largo + 5 tokens", ¿el camino de prefill usa el caché correctamente?

Bloque C-emblemático — el volcado a nivel de slot de "Yesterday I worked"

Este es el artefacto humano-visible que une §A13 con la mecánica del KV cache. Prodúcelo en experiments/22-yesterday-worked/:

  • Ejecutar camino A: prefill("Yesterday I") puebla el caché para los slots 0 y 1. Decodificar un token nuevo; registrar cuál fue (probablemente "worked" / "played" / etc.). Guardar cache.read(layer=ℓ)[..., :2, :] para cada capa.
  • Ejecutar camino B: desde cero, correr el modelo sobre la secuencia completa de 3 tokens "Yesterday I <decoded_token>", sin caché, tomando las proyecciones K y V en las posiciones 0 y 1.
  • Afirmar: la fila K del slot 1 del camino A == la fila K de la posición 1 del camino B, byte a byte idéntica. Igual para V. Igual para el slot 0. Repetir para todas las capas.
  • Si algún byte difiere: es una fuga de codificación posicional (fase RoPE incorrecta en el camino de decode), o un bug de orden de layer-norm, o un off-by-one del cursor. El volcado localiza el bug a una tripla (capa, slot, cabeza).
  • Guardar los volcados de K, V en slots.npz. Escribir un report.md de 1 página.

Bloque D — manifest

{
  "experiment": "22-cache-correctness",
  "date": "YYYY-MM-DD",
  "seed_prompt_sampler": 1,
  "seed_generation_per_prompt": "deterministic_from_prompt_idx",
  "versions": {"python": "3.11.x", "numpy": "X.Y.Z"},
  "config": {
    "model": "miniGPT-phase17",
    "n_prompts": 50,
    "prompt_len_range": [8, 32],
    "max_new_tokens": 64,
    "batch_size": 1
  },
  "results_summary": {
    "passed": null,
    "failed": null,
    "first_divergence_step_min": null,
    "first_divergence_step_max": null
  }
}

Restricciones

  • Sin fuzz. Los tests deben ser deterministas. Misma semilla → mismos prompts → mismas salidas.
  • Sin try/except para "saltar fallos". Cada divergencia es un bug. Hazlos aflorar todos.
  • Reiniciar el RNG entre caminos. Como se notó arriba.

Condiciones de parada

Hecho cuando:

  1. 50/50 prompts pasan byte a byte idénticamente sobre 64 tokens nuevos, un solo stream.
  2. Los tests extendidos (256 tokens, batch=4) pasan.
  3. manifest.json commiteado con passed: 50, failed: 0.
  4. README.md documenta o bien "no bugs encontrados" o el bug + fix.

Si algún test sigue fallando tras 4 horas de depuración, escribe el síntoma y para en /phase-checkpoint — no machaques.

Escollos (leer antes de depurar)

  • "Off by one en el token 1". Casi siempre: guardar K, V para el token actual antes de calcular la atención, así que atiende a sí mismo con fuerza completa. Añade K, V después del cálculo de atención (o usa una máscara que excluya la fila actual — pero entonces el caché tiene bytes muertos; mejor añade después).
  • "Off by hundreds en el token 30". Deriva lenta — error numérico acumulado de aritmética fp no asociativa. Aceptable hasta ~1e-6, pero causa una divergencia eventualmente cuando el sampling cruza un límite de token. O bien: (a) iguala el orden exacto de operaciones en los caminos con caché vs sin caché, o (b) acepta la divergencia en horizontes largos y documenta la cota.
  • "Diverge solo con batch>1". La forma de cache.read() de la capa es (B, H, S, d_h). El broadcasting en el matmul es implacable; doble check de los ejes.

Cuándo consultar solutions/

Después de que pasen 50/50. La referencia en solutions/02-correctness-test-ref.md documenta los bugs canónicos encontrados durante el desarrollo de la implementación de referencia.


Siguiente lab: lab/03-cost-curves.md.