Break 00 — Naive INT4 round-to-nearest with no per-channel scale

We quantize the attention weights to INT4 with round(w / global_scale), with no per-channel scale and no outlier clipping. Accuracy collapses, and that drop is the lesson.

This /break exercise targets the scale granularity decision in quantization. The bug is one line; the failure is loud and observable in PPL.

Anchors: theory/02-scales-and-zeros.md, theory/03-gptq-and-nf4.md, .claude/commands/break.md.


Hypothesis

The learner predicts: "If I quantize attention weights to INT4 using a single global scale (max-abs over the whole tensor) instead of per-channel, the dynamic range collapses around outlier rows. Most weight rows resolve to {-1, 0, +1} on a 16-level grid. PPL will explode."

The break

In src/quant/quantize.py, replace the per-channel INT8 path with a naive INT4 cast for the attention W_q, W_k, W_v, W_o matrices:

 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
+        # /break: dropped per-channel scale + dropped clipping
+        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
+        scales = scales.expand(w.shape[0], 1)
-        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+        q = (w / scales).round()                 # no clamp; rely on int dtype to wrap
     else:
         ...
     return QuantTensor(q.to(torch.int8), scales)

And call it with bits=4 from the attention quant wrapper:

-Wq_q = quantize_linear(Wq, bits=8, per_channel=True)
+Wq_q = quantize_linear(Wq, bits=4, per_channel=True)   # naïvely re-using the (now broken) helper

Two edits, and both are essential: the per-channel→global scale change is the real break; bits=4 is what makes it observable on a tiny model.

Predict, then run

A weight tensor W ∈ ℝ^(d, d) = ℝ^(64, 64) from Mini-GPT typically has max |w| ≈ 0.4 and most entries in [-0.05, +0.05]. With INT4 (16 levels) and a global scale s = 0.4 / 7 ≈ 0.057:

  • A weight at w = 0.04 maps to round(0.04 / 0.057) = round(0.70) = 1 → dequant value 0.057. Error ≈ 0.017, about 43% of the original weight.
  • A weight at w = 0.002 maps to round(0.002 / 0.057) = 0 → dequant value 0. Total information loss.
  • Every weight with |w| < s / 2 ≈ 0.029 rounds to 0, so roughly 70–80% of the weight matrix collapses to 0.

The per-channel scale would give each row of W its own s_i ≈ max_j |w_ij| / 7, so the "typical" row scale would be ≈ 0.01, and the resolution for the typical row would be roughly 6× finer (0.057 / 0.01 ≈ 5.7).
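The arithmetic above can be checked in a few lines of plain Python. This is a minimal sketch: `fake_quant` is a hypothetical helper, not part of src/quant, and the per-row scale of 0.01 is the "typical" value assumed in the text rather than a measured one.

```python
# Worked arithmetic for the two example weights, using the scales from the text.
BITS = 4
QMAX = 2 ** (BITS - 1) - 1  # 7 for signed INT4


def fake_quant(w: float, scale: float) -> tuple[int, float]:
    """Round-to-nearest quantize, clamp to the INT4 range, then dequantize."""
    code = round(w / scale)
    code = max(-QMAX - 1, min(QMAX, code))
    return code, code * scale


global_scale = 0.4 / QMAX  # tensor-wide max |w| = 0.4
row_scale = 0.01           # assumed "typical" per-row scale

for w in (0.04, 0.002):
    for name, s in (("global", global_scale), ("per-row", row_scale)):
        code, w_hat = fake_quant(w, s)
        print(f"w={w:<6} {name:<8} code={code:>2} "
              f"w_hat={w_hat:.4f} err={abs(w - w_hat):.4f}")

print(f"global / per-row step ratio: {global_scale / row_scale:.1f}")
```

With the global scale, 0.04 lands on code 1 and 0.002 on code 0. With the assumed per-row scale, 0.04 lands on code 4 and is recovered almost exactly, while 0.002 still rounds to 0: a finer grid shrinks the dead zone around zero but does not remove it.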

Predictions

  • Final PPL on §A13 eval: > 10× baseline (e.g., from 5.2 → 50+).
  • Attention output distribution: most heads' outputs collapse to a small set of distinct values (because most weights are now 0).
  • Specific failure: the model outputs garbage tokens early in generation, often loops.
  • (q == 0).sum().item() / q.numel() ≈ 0.70.
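The zero-fraction prediction can be sanity-checked before touching the model. The sketch below uses only the standard library and a synthetic matrix: a Gaussian bulk with σ = 0.025 plus one outlier at 0.4. Both numbers are assumptions chosen to match the description of Mini-GPT's weights, not values read from a checkpoint.

```python
import random

random.seed(0)
D, BITS = 64, 4
QMAX = 2 ** (BITS - 1) - 1

# Synthetic stand-in for W_q (assumed distribution, not real weights).
W = [[random.gauss(0.0, 0.025) for _ in range(D)] for _ in range(D)]
W[0][0] = 0.4  # a single outlier sets the tensor-wide max


def zero_fraction(rows, scales):
    """Fraction of entries whose round-to-nearest code is 0."""
    zeros = sum(round(w / s) == 0 for row, s in zip(rows, scales) for w in row)
    return zeros / (len(rows) * len(rows[0]))


global_scale = max(abs(w) for row in W for w in row) / QMAX
row_scales = [max(abs(w) for w in row) / QMAX for row in W]

frac_global = zero_fraction(W, [global_scale] * D)
frac_per_row = zero_fraction(W, row_scales)
print(f"zeros with global scale:  {frac_global:.2%}")
print(f"zeros with per-row scale: {frac_per_row:.2%}")
```

The exact percentages depend on the assumed distribution; the point is the gap between the two numbers, not their values.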

Write your predictions in learners/borja/phase-26/notes/breaks.md before running.

Observe

Run the Phase 26 eval recipe with the broken quantizer:

just exp 26-quant --variant int4-naive

Diagnostics to plot:

  1. PPL on §A13 eval set: fp32 vs int4-naive. The bar chart should have the int4 bar literally off the top.
  2. Histogram of q.flatten() for the broken W_q: should be heavily concentrated at 0.
  3. Sample three generations with prompt "I work" → expected "I work today"; observed: probably garbled.

Symptom Borja will see

  • PPL > 50 on the §A13 eval set (vs ~5.2 FP32).
  • Roughly 70% of quantized weights == 0 (the predicted zero fraction).
  • Sample outputs incoherent.
  • The training loop is not affected — this is post-training quantization, so the weights are correct in their FP32 source.

Hidden cause (one sentence)

A single global max-abs scale across the whole weight tensor, combined with INT4 (16 levels), makes the resolution of a typical row several times coarser than the per-channel default (about 6× for the numbers above), so most weight magnitudes round to zero.

Hint cascade

  1. Plot the distribution of q.flatten(). What fraction is zero? Is that plausible for a trained transformer?
  2. What is scales.shape? Trace through quantize_linear — is the scale being computed per-row or per-tensor?
  3. Compare your scales to what theory/02-scales-and-zeros.md derives as the per-channel formula. Where do they differ?
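The first two hints can be turned into a small diagnostic. The sketch below reproduces the broken scale computation on a synthetic tensor (an assumed Gaussian bulk plus one outlier, not real weights). One thing to watch for: because of the `expand` call, `scales.shape` looks exactly like a per-channel result, so counting distinct values is more revealing than inspecting the shape.

```python
import torch

torch.manual_seed(0)
bits = 4
qmax = 2 ** (bits - 1) - 1

# Synthetic stand-in for W_q (assumed): Gaussian bulk plus a single outlier.
w = torch.randn(64, 64) * 0.025
w[0, 0] = 0.4

# The broken path from the /break diff.
scales = w.abs().amax() / qmax
scales = scales.expand(w.shape[0], 1)
q = (w / scales).round()

# Hint 1: what fraction of codes is zero?
zero_frac = (q == 0).float().mean().item()

# Hint 2: the shape looks per-channel, but every row shares one value.
distinct_scales = scales.unique().numel()

# Text histogram of the integer codes actually used.
codes, counts = torch.unique(q.to(torch.int64), return_counts=True)

print(f"scales.shape    = {tuple(scales.shape)}")
print(f"distinct scales = {distinct_scales}")
print(f"zero fraction   = {zero_frac:.2%}")
for c, n in zip(codes.tolist(), counts.tolist()):
    print(f"code {c:+d}: {n}")
```

A genuinely per-channel quantizer would report close to one distinct scale per row on the same tensor.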

Fix diff

 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
-        scales = scales.expand(w.shape[0], 1)
+        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
-        q = (w / scales).round()
+        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
     ...

And keep bits=8 for the attention path — INT4 is for lab/02-quant-curve.md, not the default.
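To confirm the fix, a round trip through quantize and dequantize can be compared for both granularities. This is a self-contained sketch under stated assumptions: `QuantTensor` here is a hypothetical `NamedTuple` standing in for the real class, `dequantize` is an illustrative helper, the `else` branch is a per-tensor stand-in (the source elides the real one), and the weights are synthetic.

```python
from typing import NamedTuple

import torch
from torch import Tensor


class QuantTensor(NamedTuple):
    """Hypothetical container; the real class in src/quant may differ."""
    q: Tensor       # integer codes stored as int8
    scales: Tensor  # one scale per output row, shape (rows, 1)


def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
    qmax = 2 ** (bits - 1) - 1
    if per_channel:
        absmax = w.abs().amax(dim=1, keepdim=True)
    else:
        # Assumed per-tensor path, used here only for comparison.
        absmax = w.abs().amax().expand(w.shape[0], 1)
    scales = absmax.clamp(min=1e-8) / qmax  # guard against all-zero rows
    q = (w / scales).round().clamp(-(qmax + 1), qmax)
    return QuantTensor(q.to(torch.int8), scales)


def dequantize(t: QuantTensor) -> Tensor:
    return t.q.to(t.scales.dtype) * t.scales


torch.manual_seed(0)
w = torch.randn(64, 64) * 0.025
w[0, 0] = 0.4  # single outlier, as in the worked example

results = {}
for per_channel in (True, False):
    qt = quantize_linear(w, bits=4, per_channel=per_channel)
    err = (w - dequantize(qt)).abs()
    # Round-to-nearest keeps every error within half a quantization step.
    assert bool((err <= qt.scales / 2 + 1e-6).all())
    results[per_channel] = err.mean().item()
    print(f"per_channel={per_channel!s:<5} mean |error| = {results[per_channel]:.5f}")
```

Both variants satisfy the half-step bound; the difference is that the per-channel step is set by each row's own maximum, so the mean reconstruction error should come out markedly lower on this synthetic tensor.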

Why this teaches the concept

theory/02-scales-and-zeros.md claims per-channel scales bound the per-row error by a factor proportional to the row's max. This break makes that claim load-bearing. Without per-channel scales, a single outlier-heavy row sets the step size for every other row in the tensor. With per-channel scales, the bound is local to each row. The lesson generalizes: INT8 is forgiving enough that per-tensor scales sometimes work (although LLM.int8() itself uses vector-wise scales plus a higher-precision path for outlier features); INT4 is not, so per-channel or finer granularity is mandatory. GPTQ, AWQ, and NF4 are all sophisticated answers to the same question this break asks bluntly: where do you put the unit of scale?
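The bound referred to here can be written out explicitly. The following is a sketch that assumes symmetric round-to-nearest with b bits and the max-abs scale from the diff; the notation may differ from the derivation in theory/02-scales-and-zeros.md.

```latex
\begin{aligned}
s_i &= \frac{\max_j |w_{ij}|}{2^{b-1}-1}, \qquad
q_{ij} = \operatorname{round}\!\left(\frac{w_{ij}}{s_i}\right), \\
|w_{ij} - s_i\, q_{ij}| &\le \frac{s_i}{2}
  = \frac{\max_j |w_{ij}|}{2\,(2^{b-1}-1)}
  \quad \text{(per-channel: depends only on row } i\text{)}, \\
|w_{ij} - s\, \tilde q_{ij}| &\le \frac{s}{2}
  = \frac{\max_{k,l} |w_{kl}|}{2\,(2^{b-1}-1)}
  \quad \text{(global: set by the largest entry anywhere)}.
\end{aligned}
```

For b = 4 the denominator is 14. With a tensor-wide max of 0.4, the global bound is about 0.029 for every row, whereas a row whose own max is 0.07 gets a bound of 0.005.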

Reference

  • Dettmers et al., LLM.int8() (NeurIPS 2022) — the outlier discussion in §3.
  • Frantar et al., GPTQ (arXiv:2210.17323) — what per-channel + Hessian-aware rounding buys you on top of this baseline.

Next: restore the per-channel scale, then run lab/02-quant-curve.md for the legitimate INT4 frontier (with group-wise scales).