English · Español
Break 00 — Redondeo INT4 naif al más cercano sin escala por-canal¶
Cuantizamos los pesos de attention a INT4 con
round(w / global_scale)sin escala por canal y sin clipping de outliers. La precisión se desploma — esa caída es la lección.
Este ejercicio /break apunta a la decisión de granularidad de escala en cuantización. El bug es una línea; el fallo es ruidoso y observable en PPL.
Anclas: theory/02-scales-and-zeros.md, theory/03-gptq-and-nf4.md, .claude/commands/break.md.
Hipótesis¶
El learner predice: "Si cuantizo los pesos de attention a INT4 usando una única escala global (max-abs sobre todo el tensor) en vez de por-canal, el rango dinámico colapsa alrededor de las filas con outliers. La mayoría de las filas de pesos resuelven a {-1, 0, +1} sobre un grid de 16 niveles. La PPL explotará."
El break¶
En src/quant/quantize.py, reemplaza la ruta INT8 por-canal con un cast INT4 naif para las matrices W_q, W_k, W_v, W_o de attention:
def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
if per_channel:
- scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
+ # /break: dropped per-channel scale + dropped clipping
+ scales = w.abs().amax() / (2 ** (bits - 1) - 1)
+ scales = scales.expand(w.shape[0], 1)
- q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+ q = (w / scales).round() # no clamp; rely on int dtype to wrap
else:
...
return QuantTensor(q.to(torch.int8), scales)
Y llámalo con bits=4 desde el wrapper de cuantización de attention:
-Wq_q = quantize_linear(Wq, bits=8, per_channel=True)
+Wq_q = quantize_linear(Wq, bits=4, per_channel=True) # naïvely re-using the (now broken) helper
Edición de dos líneas. Las dos son esenciales — el cambio de escala por-canal→global es el break real; el bits=4 es lo que lo hace observable sobre un modelo diminuto.
Predice, luego ejecuta¶
Un tensor de pesos W ∈ ℝ^(d, d) = ℝ^(64, 64) de Mini-GPT típicamente tiene max |w| ≈ 0.4 y la mayoría de entradas en [-0.05, +0.05]. Con INT4 (16 niveles) y una escala global s = 0.4 / 7 ≈ 0.057:
- Un peso en
w = 0.04mapea around(0.04 / 0.057) = round(0.70) = 1→ valor dequantizado0.057. Error ≈ 0.017 — mayor que el peso original. - Un peso en
w = 0.002mapea around(0.002 / 0.057) = 0→ valor dequantizado0. Pérdida total de información. - Aproximadamente el 70–80% de la matriz de pesos colapsa a 0.
La escala por-canal daría a cada fila de W su propia s_i ≈ max_j |w_ij| / 7, así que la escala "típica" de fila sería ≈ 0.01 y la resolución sería 50× más fina para la fila típica.
Predicciones¶
- PPL final en el eval §A13: > 10× del baseline (p. ej., de 5.2 → 50+).
- Distribución de salida de attention: la mayoría de las salidas de las heads colapsan a un pequeño conjunto de valores distintos (porque la mayoría de los pesos ahora son 0).
- Fallo específico: el modelo emite tokens basura temprano en la generación, a menudo entra en bucle.
np.count_nonzero(q == 0) / q.numel()≈ 0.70.
Escribe tus predicciones en learners/borja/phase-26/notes/breaks.md antes de ejecutar.
Observa¶
Ejecuta la receta de evaluación de Fase 26 con el cuantizador roto:
Diagnósticos a dibujar:
- PPL en el eval set §A13:
fp32vsint4-naive. El gráfico de barras debería tener la barra de int4 literalmente fuera del tope. - Histograma de
q.flatten()para elW_qroto: debería estar fuertemente concentrado en 0. - Muestra tres generaciones con prompt "I work" → esperada "I work today"; observada: probablemente incoherente.
Síntoma que verá Borja¶
- PPL > 50 en el eval set §A13 (vs ~5.2 FP32).
-
60% de los pesos cuantizados == 0.
- Salidas de muestra incoherentes.
- El bucle de entrenamiento no se ve afectado — esto es cuantización post-entrenamiento, así que los pesos son correctos en su origen FP32.
Causa oculta (una frase)¶
Una única escala global max-abs sobre todo el tensor de pesos combinada con INT4 (16 niveles) hace que la resolución por fila sea ~50× más gruesa que el default por-canal; la mayoría de las magnitudes de peso redondean a cero.
Cascada de pistas¶
- Dibuja la distribución de
q.flatten(). ¿Qué fracción es cero? ¿Es plausible para un transformer entrenado? - ¿Cuál es
scales.shape? Traza porquantize_linear— ¿se está calculando la escala por-fila o por-tensor? - Compara tus
scalescon lo quetheory/02-scales-and-zeros.mdderiva como la fórmula por-canal. ¿Dónde difieren?
Diff del fix¶
def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
if per_channel:
- scales = w.abs().amax() / (2 ** (bits - 1) - 1)
- scales = scales.expand(w.shape[0], 1)
+ scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
- q = (w / scales).round()
+ q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
...
Y mantén bits=8 para la ruta de attention — INT4 es para lab/02-quant-curve.md, no el default.
Por qué esto enseña el concepto¶
theory/02-scales-and-zeros.md afirma que las escalas por-canal acotan el error por-fila por un factor proporcional al máximo de la fila. Este break hace esa afirmación load-bearing. Sin escalas por-canal, una fila pesada en outliers envenena a todas las demás filas del tensor. Con por-canal, la cota es local. La lección generaliza: INT8 es lo suficientemente tolerante como para que por-tensor a veces funcione (el paper LLM.int8() explota esto); INT4 no lo es — por-canal es obligatorio. GPTQ, AWQ, NF4 son todas respuestas sofisticadas a la misma pregunta que este break plantea sin rodeos: ¿dónde pones la unidad de escala?
Referencia¶
- Dettmers et al., LLM.int8() (NeurIPS 2022) — la discusión sobre outliers en §3.
- Frantar et al., GPTQ (arXiv:2210.17323) — lo que por-canal + redondeo consciente de la Hessiana te compra encima de este baseline.
Siguiente: restaura la escala por-canal, luego ejecuta lab/02-quant-curve.md para la frontera INT4 legítima (con escalas por-grupo).