Skip to content

English · Español

Break 00 — Redondeo INT4 naif al más cercano sin escala por-canal

Cuantizamos los pesos de attention a INT4 con round(w / global_scale) sin escala por canal y sin clipping de outliers. La precisión se desploma — esa caída es la lección.

Este ejercicio /break apunta a la decisión de granularidad de escala en cuantización. El bug es una línea; el fallo es ruidoso y observable en PPL.

Anclas: theory/02-scales-and-zeros.md, theory/03-gptq-and-nf4.md, .claude/commands/break.md.


Hipótesis

El learner predice: "Si cuantizo los pesos de attention a INT4 usando una única escala global (max-abs sobre todo el tensor) en vez de por-canal, el rango dinámico colapsa alrededor de las filas con outliers. La mayoría de las filas de pesos resuelven a {-1, 0, +1} sobre un grid de 16 niveles. La PPL explotará."

El break

En src/quant/quantize.py, reemplaza la ruta INT8 por-canal con un cast INT4 naif para las matrices W_q, W_k, W_v, W_o de attention:

 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
+        # /break: dropped per-channel scale + dropped clipping
+        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
+        scales = scales.expand(w.shape[0], 1)
-        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+        q = (w / scales).round()                 # no clamp; rely on int dtype to wrap
     else:
         ...
     return QuantTensor(q.to(torch.int8), scales)

Y llámalo con bits=4 desde el wrapper de cuantización de attention:

-Wq_q = quantize_linear(Wq, bits=8, per_channel=True)
+Wq_q = quantize_linear(Wq, bits=4, per_channel=True)   # naïvely re-using the (now broken) helper

Edición de dos líneas. Las dos son esenciales — el cambio de escala por-canal→global es el break real; el bits=4 es lo que lo hace observable sobre un modelo diminuto.

Predice, luego ejecuta

Un tensor de pesos W ∈ ℝ^(d, d) = ℝ^(64, 64) de Mini-GPT típicamente tiene max |w| ≈ 0.4 y la mayoría de entradas en [-0.05, +0.05]. Con INT4 (16 niveles) y una escala global s = 0.4 / 7 ≈ 0.057:

  • Un peso en w = 0.04 mapea a round(0.04 / 0.057) = round(0.70) = 1 → valor dequantizado 0.057. Error ≈ 0.017 — mayor que el peso original.
  • Un peso en w = 0.002 mapea a round(0.002 / 0.057) = 0 → valor dequantizado 0. Pérdida total de información.
  • Aproximadamente el 70–80% de la matriz de pesos colapsa a 0.

La escala por-canal daría a cada fila de W su propia s_i ≈ max_j |w_ij| / 7, así que la escala "típica" de fila sería ≈ 0.01 y la resolución sería 50× más fina para la fila típica.

Predicciones

  • PPL final en el eval §A13: > 10× del baseline (p. ej., de 5.2 → 50+).
  • Distribución de salida de attention: la mayoría de las salidas de las heads colapsan a un pequeño conjunto de valores distintos (porque la mayoría de los pesos ahora son 0).
  • Fallo específico: el modelo emite tokens basura temprano en la generación, a menudo entra en bucle.
  • np.count_nonzero(q == 0) / q.numel() ≈ 0.70.

Escribe tus predicciones en learners/borja/phase-26/notes/breaks.md antes de ejecutar.

Observa

Ejecuta la receta de evaluación de Fase 26 con el cuantizador roto:

just exp 26-quant --variant int4-naive

Diagnósticos a dibujar:

  1. PPL en el eval set §A13: fp32 vs int4-naive. El gráfico de barras debería tener la barra de int4 literalmente fuera del tope.
  2. Histograma de q.flatten() para el W_q roto: debería estar fuertemente concentrado en 0.
  3. Muestra tres generaciones con prompt "I work" → esperada "I work today"; observada: probablemente incoherente.

Síntoma que verá Borja

  • PPL > 50 en el eval set §A13 (vs ~5.2 FP32).
  • 60% de los pesos cuantizados == 0.

  • Salidas de muestra incoherentes.
  • El bucle de entrenamiento no se ve afectado — esto es cuantización post-entrenamiento, así que los pesos son correctos en su origen FP32.

Causa oculta (una frase)

Una única escala global max-abs sobre todo el tensor de pesos combinada con INT4 (16 niveles) hace que la resolución por fila sea ~50× más gruesa que el default por-canal; la mayoría de las magnitudes de peso redondean a cero.

Cascada de pistas

  1. Dibuja la distribución de q.flatten(). ¿Qué fracción es cero? ¿Es plausible para un transformer entrenado?
  2. ¿Cuál es scales.shape? Traza por quantize_linear — ¿se está calculando la escala por-fila o por-tensor?
  3. Compara tus scales con lo que theory/02-scales-and-zeros.md deriva como la fórmula por-canal. ¿Dónde difieren?

Diff del fix

 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
-        scales = scales.expand(w.shape[0], 1)
+        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
-        q = (w / scales).round()
+        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
     ...

Y mantén bits=8 para la ruta de attention — INT4 es para lab/02-quant-curve.md, no el default.

Por qué esto enseña el concepto

theory/02-scales-and-zeros.md afirma que las escalas por-canal acotan el error por-fila por un factor proporcional al máximo de la fila. Este break hace esa afirmación load-bearing. Sin escalas por-canal, una fila pesada en outliers envenena a todas las demás filas del tensor. Con por-canal, la cota es local. La lección generaliza: INT8 es lo suficientemente tolerante como para que por-tensor a veces funcione (el paper LLM.int8() explota esto); INT4 no lo es — por-canal es obligatorio. GPTQ, AWQ, NF4 son todas respuestas sofisticadas a la misma pregunta que este break plantea sin rodeos: ¿dónde pones la unidad de escala?

Referencia

  • Dettmers et al., LLM.int8() (NeurIPS 2022) — la discusión sobre outliers en §3.
  • Frantar et al., GPTQ (arXiv:2210.17323) — lo que por-canal + redondeo consciente de la Hessiana te compra encima de este baseline.

Siguiente: restaura la escala por-canal, luego ejecuta lab/02-quant-curve.md para la frontera INT4 legítima (con escalas por-grupo).