English · Español
Lab 02 — Corrección: con-caché igual a sin-caché, byte a byte¶
Objetivo: demostrar que
generate(prompt, cache=True)produce los mismos tokens exactos quegenerate(prompt, cache=False)para prompts y semillas arbitrarios. Los bugs sutiles de caché son silenciosos; solo un test de igualdad exacta los hace aflorar.Tiempo estimado: 2–4 horas.
Prerrequisito:
lab/01-implement-cache.mdcompleto.src/miniinfer/generate.pyde la Fase 21 en su sitio.
Lo que produces¶
Un directorio experiments/22-cache-correctness/ que contenga:
property_test.py— tu test runner de propiedad.results.json— pass/fail por prompt, paso de divergencia si lo hay.manifest.json.README.md— 2–3 párrafos. Si algún test falló, documenta el bug que encontraste y cómo lo arreglaste.
Un segundo directorio experiments/22-yesterday-worked/ que contenga el volcado emblemático a nivel de slot:
dump.py— script que prefilla"Yesterday I", decodifica un token, luego por separado ejecuta un recálculo completo sobre"Yesterday I worked"(o la forma de pasado simple que el modelo emitiera), y vuelca la fila K y la fila V para el slot de la posición de"I"desde ambas ejecuciones.slots.npz— las filas K, V volcadas de ambos caminos.report.md— afirmación: cada byte de la fila"I"del camino cacheado es igual a cada byte de la fila"I"del camino recalculado, para K y para V, para cada capa y cabeza.manifest.json.
La propiedad¶
Para un modelo fijo (MiniGPT, Fase 17) y una semilla de sampling fija, lo siguiente debe cumplirse:
seed_everything(42)
tokens_cached = generate(prompt, max_new_tokens=64, cache=True)
seed_everything(42)
tokens_uncached = generate(prompt, max_new_tokens=64, cache=False)
assert tokens_cached == tokens_uncached # byte-identical token sequence
Para 50 prompts distintos muestreados de una distribución fija (defínela en tu property_test.py).
Nota de determinismo: seed_everything debe re-aplicarse antes de cada camino porque el sampling consume el RNG. Si cache=True llama al modelo menos veces (lo hace — ese es el punto), el estado del RNG diverge salvo que se reinicie. Esta es la fuente más común de reportes falsos positivos de "bug de corrección"; construye el test para manejarla desde el principio.
TODOs¶
Bloque A — escribir el test runner¶
- Cargar los pesos del MiniGPT de la Fase 17 una vez. El modelo está entrenado en el corpus de gramática de verbos §A13; sus tokens son formas verbales en inglés (y español).
- Muestrear 50 prompts: elegir de la distribución natural del corpus de gramática. Mezcla sugerida: (a) 20 prompts de la forma
"<adverbial-de-tiempo> <pronombre>"(p. ej."Yesterday I","Tomorrow he","Now you"), (b) 20 prompts de longitud 4–8 que sean frases parciales válidas (p. ej."I am going to"), © 10 prompts más largos que mezclen tiempos para forzar el enmascaramiento causal. Pon semilla al muestreador de prompts — semilla distinta de la de generación. - Para cada prompt:
seed_everything(gen_seed_for_this_prompt)t_cached = generate(prompt, max_new_tokens=64, cache=True)seed_everything(gen_seed_for_this_prompt)t_uncached = generate(prompt, max_new_tokens=64, cache=False)- Si
t_cached != t_uncached: registrar el primer índice de divergencia. - Cuenta pass/fail. Escribe
results.json.
Bloque B — interpretar fallos¶
Si algún prompt diverge, el test solo te dice dónde (índice de token) pero no por qué. Tu trabajo en este bloque:
- Re-ejecutar ese prompt con
cache=Truey volcar las salidas de atención por capa en el paso de divergencia. - Re-ejecutar el mismo prompt con
cache=False, mismos volcados. - Comparar: encontrar la capa (¿y cabeza?) donde difieren primero.
- Trazarlo de vuelta al código del caché. Culpables comunes:
- Off-by-one del cursor (guardar K, V del token actual antes de calcular la atención).
- Forma de máscara incorrecta para decode
q_len=1(no hace falta máscara, pero el código de la Fase 15 podría seguir aplicando una). - Índice de capa intercambiado (usar
cache.read(layer=0)en todas partes). - Desajuste de dtype (caché guardado en fp32, pero las lecturas hacen cast a fp64 a mitad de atención).
Documenta el bug + fix en README.md.
Bloque C — extender el test¶
Una vez pasen los 50 prompts:
- Probar una generación más larga: 256 tokens nuevos. ¿Sigue siendo byte a byte idéntico? (Nota: el modelo entrenado con vocabulario de 600 formas empezará a ciclar / repetir bastante antes de 256 tokens — eso está bien. La propiedad de equivalencia es lo que se testea.)
- Probar
batch=4secuencias paralelas. Cada una debe independientemente producir lo mismo con/sin caché. (Esto caza bugs de dim-batch que los tests de un solo stream se pierden.) - Probar
q_len > 1(reanudación de prefill multi-token). Caso límite: si alguna vez haces "warm-start decode desde un prompt largo + 5 tokens", ¿el camino de prefill usa el caché correctamente?
Bloque C-emblemático — el volcado a nivel de slot de "Yesterday I worked"¶
Este es el artefacto humano-visible que une §A13 con la mecánica del KV cache. Prodúcelo en experiments/22-yesterday-worked/:
- Ejecutar camino A:
prefill("Yesterday I")puebla el caché para los slots 0 y 1. Decodificar un token nuevo; registrar cuál fue (probablemente"worked"/"played"/ etc.). Guardarcache.read(layer=ℓ)[..., :2, :]para cada capa. - Ejecutar camino B: desde cero, correr el modelo sobre la secuencia completa de 3 tokens
"Yesterday I <decoded_token>", sin caché, tomando las proyecciones K y V en las posiciones 0 y 1. - Afirmar: la fila K del slot 1 del camino A == la fila K de la posición 1 del camino B, byte a byte idéntica. Igual para V. Igual para el slot 0. Repetir para todas las capas.
- Si algún byte difiere: es una fuga de codificación posicional (fase RoPE incorrecta en el camino de decode), o un bug de orden de layer-norm, o un off-by-one del cursor. El volcado localiza el bug a una tripla (capa, slot, cabeza).
- Guardar los volcados de K, V en
slots.npz. Escribir unreport.mdde 1 página.
Bloque D — manifest¶
{
"experiment": "22-cache-correctness",
"date": "YYYY-MM-DD",
"seed_prompt_sampler": 1,
"seed_generation_per_prompt": "deterministic_from_prompt_idx",
"versions": {"python": "3.11.x", "numpy": "X.Y.Z"},
"config": {
"model": "miniGPT-phase17",
"n_prompts": 50,
"prompt_len_range": [8, 32],
"max_new_tokens": 64,
"batch_size": 1
},
"results_summary": {
"passed": null,
"failed": null,
"first_divergence_step_min": null,
"first_divergence_step_max": null
}
}
Restricciones¶
- Sin fuzz. Los tests deben ser deterministas. Misma semilla → mismos prompts → mismas salidas.
- Sin
try/exceptpara "saltar fallos". Cada divergencia es un bug. Hazlos aflorar todos. - Reiniciar el RNG entre caminos. Como se notó arriba.
Condiciones de parada¶
Hecho cuando:
- 50/50 prompts pasan byte a byte idénticamente sobre 64 tokens nuevos, un solo stream.
- Los tests extendidos (256 tokens, batch=4) pasan.
manifest.jsoncommiteado conpassed: 50, failed: 0.README.mddocumenta o bien "no bugs encontrados" o el bug + fix.
Si algún test sigue fallando tras 4 horas de depuración, escribe el síntoma y para en /phase-checkpoint — no machaques.
Escollos (leer antes de depurar)¶
- "Off by one en el token 1". Casi siempre: guardar K, V para el token actual antes de calcular la atención, así que atiende a sí mismo con fuerza completa. Añade K, V después del cálculo de atención (o usa una máscara que excluya la fila actual — pero entonces el caché tiene bytes muertos; mejor añade después).
- "Off by hundreds en el token 30". Deriva lenta — error numérico acumulado de aritmética fp no asociativa. Aceptable hasta ~1e-6, pero causa una divergencia eventualmente cuando el sampling cruza un límite de token. O bien: (a) iguala el orden exacto de operaciones en los caminos con caché vs sin caché, o (b) acepta la divergencia en horizontes largos y documenta la cota.
- "Diverge solo con batch>1". La forma de
cache.read()de la capa es(B, H, S, d_h). El broadcasting en el matmul es implacable; doble check de los ejes.
Cuándo consultar solutions/¶
Después de que pasen 50/50. La referencia en solutions/02-correctness-test-ref.md documenta los bugs canónicos encontrados durante el desarrollo de la implementación de referencia.
Siguiente lab: lab/03-cost-curves.md.