Break 00: Naive INT4 round-to-nearest with no per-channel scale
We quantize the attention weights to INT4 with round(w / global_scale), with no per-channel scale and no outlier clipping. Accuracy collapses, and that collapse is the lesson.
This /break exercise targets the scale-granularity decision in quantization. The bug is one line; the failure is loud and shows up directly in perplexity (PPL).
Anchors: theory/02-scales-and-zeros.md, theory/03-gptq-and-nf4.md, .claude/commands/break.md.
Hypothesis
The learner predicts: "If I quantize attention weights to INT4 using a single global scale (max-abs over the whole tensor) instead of per-channel, the dynamic range collapses around outlier rows. Most weight rows resolve to {-1, 0, +1} on a 16-level grid. PPL will explode."
The break
In src/quant/quantize.py, replace the per-channel INT8 path with a naive INT4 cast for the attention W_q, W_k, W_v, W_o matrices:
 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
+        # /break: dropped per-channel scale + dropped clipping
+        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
+        scales = scales.expand(w.shape[0], 1)
-        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+        q = (w / scales).round()  # no clamp; rely on int dtype to wrap
     else:
         ...
     return QuantTensor(q.to(torch.int8), scales)
And call it with bits=4 from the attention quant wrapper:
-Wq_q = quantize_linear(Wq, bits=8, per_channel=True)
+Wq_q = quantize_linear(Wq, bits=4, per_channel=True) # naïvely re-using the (now broken) helper
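Two changes matter, and both are essential: the per-channel→global scale change is the real break, and bits=4 is what makes it observable on a tiny model. Dropping the clamp changes nothing by itself here, because a max-abs scale already keeps every w / scale within [-(2^(bits-1) - 1), 2^(bits-1) - 1].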
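The effect of the scale change can be reproduced without the model. The sketch below is framework-free and uses a hand-made 2×4 matrix with one outlier row. quantize_rows is a hypothetical stand-in for quantize_linear, not the repository's implementation, and it assumes no row is all zeros.

```python
def quantize_rows(w, bits=4, per_channel=True):
    """Symmetric round-to-nearest quantization of a list-of-rows matrix.

    Hypothetical stand-in for quantize_linear. Returns (q, scales) with one
    scale per row, so both modes produce scales of the same length.
    """
    qmax = 2 ** (bits - 1) - 1          # 7 for INT4
    qmin = -(2 ** (bits - 1))           # -8 for INT4
    if per_channel:
        # One max-abs scale per row.
        scales = [max(abs(x) for x in row) / qmax for row in w]
    else:
        # One max-abs scale for the whole matrix, repeated per row.
        global_max = max(abs(x) for row in w for x in row)
        scales = [global_max / qmax] * len(w)
    q = [[min(max(round(x / s), qmin), qmax) for x in row]
         for row, s in zip(w, scales)]
    return q, scales


w = [
    [0.40, -0.35, 0.10, 0.02],    # outlier-heavy row
    [0.04, -0.03, 0.015, 0.005],  # typical row
]

q_global, _ = quantize_rows(w, per_channel=False)
q_row, _ = quantize_rows(w, per_channel=True)

print(q_global[1])  # [1, -1, 0, 0]
print(q_row[1])     # [7, -5, 3, 1]
```

With the global scale the typical row uses only the levels -1, 0, and +1; with its own scale the same row spreads over the available levels.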
Predict, then run
A weight tensor W ∈ ℝ^(d, d) = ℝ^(64, 64) from Mini-GPT typically has max |w| ≈ 0.4 and most entries in [-0.05, +0.05]. With INT4 (16 levels) and a global scale s = 0.4 / 7 ≈ 0.057:
- A weight at w = 0.04 maps to round(0.04 / 0.057) = round(0.70) = 1 → dequant value 0.057. Error ≈ 0.017, over 40% of the original weight.
- A weight at w = 0.002 maps to round(0.002 / 0.057) = 0 → dequant value 0. Total information loss.
- Roughly 70–80% of the weight matrix collapses to 0.
The per-channel scale would give each row of W its own s_i ≈ max_j |w_ij| / 7, so the "typical" row scale would be ≈ 0.01 and the resolution for a typical row would be roughly 6× finer (0.057 / 0.01 ≈ 5.7).
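Round-to-nearest also implies a dead zone: every weight with |w| < s / 2 maps to level 0. A short sketch using the numbers from the text (max |w| = 0.4, INT4); the variable names are illustrative only.

```python
qmax = 2 ** (4 - 1) - 1          # 7 positive levels for symmetric INT4
s_global = 0.4 / qmax            # global max-abs scale, about 0.057
zero_band = s_global / 2         # any |w| below this rounds to level 0

print(f"scale = {s_global:.4f}, zero band = +/-{zero_band:.4f}")

for w in (0.04, 0.002):
    q = round(w / s_global)      # round-to-nearest level
    w_hat = q * s_global         # dequantized value
    print(f"w = {w}: q = {q}, dequant = {w_hat:.3f}, "
          f"abs error = {abs(w - w_hat):.3f}")
```

The zero band comes out at roughly ±0.029, which covers a large part of the typical [-0.05, +0.05] range quoted above.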
Predictions
- Final PPL on §A13 eval: > 10× baseline (e.g., from 5.2 → 50+).
- Attention output distribution: most heads' outputs collapse to a small set of distinct values (because most weights are now 0).
- Specific failure: the model outputs garbage tokens early in generation, often loops.
- Zero fraction of the quantized tensor: (q == 0).float().mean() ≈ 0.70.
Write your predictions in learners/borja/phase-26/notes/breaks.md before running.
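One way to sanity-check the zero-fraction prediction before touching the model is a synthetic stand-in. The sketch below assumes Gaussian weights with standard deviation 0.025 (so about 95% of entries fall in [-0.05, +0.05]) plus a single outlier at 0.4. The distribution, the seed, and the 64×64 size are assumptions for illustration, not measurements from Mini-GPT.

```python
import random

rng = random.Random(0)                        # fixed seed for repeatability
weights = [rng.gauss(0.0, 0.025) for _ in range(64 * 64)]
weights[0] = 0.4                              # one outlier sets the global max

qmax = 2 ** (4 - 1) - 1                       # 7 for symmetric INT4
scale = max(abs(w) for w in weights) / qmax   # single global max-abs scale
q = [round(w / scale) for w in weights]

zero_fraction = sum(1 for v in q if v == 0) / len(q)
print(f"zero fraction: {zero_fraction:.2f}")
```

Under these assumptions the fraction should land in the 0.7 to 0.8 range. A real weight matrix can differ, which is why the prediction is written down first and measured afterwards.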
Observe
Run the Phase 26 eval recipe with the broken quantizer.
Diagnostics to plot:
- PPL on §A13 eval set: fp32 vs int4-naive. The bar chart should have the int4 bar literally off the top.
- Histogram of q.flatten() for the broken W_q: should be heavily concentrated at 0.
- Sample three generations with prompt "I work" → expected "I work today"; observed: probably garbled.
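For the histogram, a text rendering is often enough to see the spike at 0. A minimal sketch, assuming the quantized values are available as a flat Python list of ints (for a tensor, something like q.flatten().tolist()); level_histogram is a hypothetical helper and the data is a toy example.

```python
from collections import Counter


def level_histogram(q_flat, bits=4):
    """Count occurrences of every representable level, including empty ones."""
    counts = Counter(q_flat)
    levels = range(-(2 ** (bits - 1)), 2 ** (bits - 1))   # -8..7 for INT4
    return {level: counts.get(level, 0) for level in levels}


# Toy data shaped like the broken case: almost everything at 0 or +/-1.
q_flat = [0] * 12 + [1] * 3 + [-1] * 3 + [7, -6]
hist = level_histogram(q_flat)

for level, count in hist.items():
    print(f"{level:+3d} | {'#' * count}")
```

Listing the empty levels explicitly makes the unused part of the 16-level grid visible, which is the point of the diagnostic.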
Symptom Borja will see
- PPL > 50 on the §A13 eval set (vs ~5.2 FP32).
- More than 60% of quantized weights are 0.
- Sample outputs incoherent.
- The training loop is not affected — this is post-training quantization, so the weights are correct in their FP32 source.
Hidden cause (one sentence)
A single global max-abs scale across the whole weight tensor, combined with INT4 (16 levels), makes the resolution for a typical row several times coarser than the per-channel default (about 6× with the numbers above), so most weight magnitudes round to zero.
Hint cascade
- Plot the distribution of q.flatten(). What fraction is zero? Is that plausible for a trained transformer?
- What is scales.shape? Trace through quantize_linear: is the scale being computed per-row or per-tensor?
- Compare your scales to what theory/02-scales-and-zeros.md derives as the per-channel formula. Where do they differ?
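The second hint has a trap: after the break, scales still has a per-row shape because of expand, so the shape alone proves nothing. One possible check is whether the per-row values actually differ. The helper below is hypothetical and works on a plain list of per-row scales built from a made-up matrix.

```python
def is_per_tensor_in_disguise(scales, rel_tol=1e-9):
    """Return True if every per-row scale is numerically the same value."""
    first = scales[0]
    return all(abs(s - first) <= rel_tol * abs(first) for s in scales)


w = [
    [0.40, -0.35, 0.10, 0.02],    # outlier-heavy row
    [0.04, -0.03, 0.015, 0.005],  # typical row
]
qmax = 7  # symmetric INT4

# Broken path: one global max-abs scale, repeated once per row.
global_max = max(abs(x) for row in w for x in row)
broken = [global_max / qmax] * len(w)

# Intended path: one max-abs scale per row.
fixed = [max(abs(x) for x in row) / qmax for row in w]

print(len(broken) == len(fixed))          # True: same "shape" either way
print(is_per_tensor_in_disguise(broken))  # True
print(is_per_tensor_in_disguise(fixed))   # False
```

On a real tensor, a comparable check might be scales.unique().numel() == 1.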
Fix diff
 def quantize_linear(w: Tensor, bits: int = 8, per_channel: bool = True) -> QuantTensor:
     if per_channel:
-        scales = w.abs().amax() / (2 ** (bits - 1) - 1)
-        scales = scales.expand(w.shape[0], 1)
+        scales = w.abs().amax(dim=1, keepdim=True) / (2 ** (bits - 1) - 1)
-        q = (w / scales).round()
+        q = (w / scales).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
     ...
And keep bits=8 for the attention path — INT4 is for lab/02-quant-curve.md, not the default.
Why this teaches the concept
theory/02-scales-and-zeros.md claims that per-channel scales bound the per-row error by a factor proportional to that row's max. This break makes the claim load-bearing: without per-channel scales, an outlier-heavy row poisons every other row in the tensor; with them, the bound is local.

The lesson generalizes. INT8 is forgiving enough that per-tensor scales sometimes work on small models, although the LLM.int8() paper shows that at scale even INT8 needs vector-wise scales plus separate handling of outlier features. INT4 is not forgiving: per-channel (or finer) is mandatory. GPTQ, AWQ, and NF4 are all sophisticated answers to the question this break asks bluntly: where do you put the unit of scale?
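The bound can be checked numerically. With round-to-nearest and a max-abs scale s, the round-trip error of any entry is at most s / 2. A row quantized with its own scale therefore has error at most max_j |w_ij| / (2 · 7) in INT4, while a global scale gives every row the bound of the worst row. A small sketch on a made-up matrix (the values are illustrative, not from Mini-GPT):

```python
def max_roundtrip_error(row, scale):
    """Largest |w - scale * round(w / scale)| over one row."""
    return max(abs(x - scale * round(x / scale)) for x in row)


w = [
    [0.40, -0.35, 0.10, 0.02],    # outlier-heavy row
    [0.04, -0.03, 0.015, 0.005],  # typical row
]
qmax = 2 ** (4 - 1) - 1  # 7 for symmetric INT4

global_scale = max(abs(x) for row in w for x in row) / qmax

for row in w:
    row_max = max(abs(x) for x in row)
    row_scale = row_max / qmax
    err_global = max_roundtrip_error(row, global_scale)
    err_row = max_roundtrip_error(row, row_scale)
    # Round-to-nearest guarantees error <= scale / 2 for in-range values.
    assert err_global <= global_scale / 2 + 1e-12
    assert err_row <= row_scale / 2 + 1e-12
    print(f"row max {row_max:.2f}: "
          f"global err {err_global:.4f}, per-row err {err_row:.4f}")
```

For the outlier row the two scales coincide, so nothing changes. For the typical row the per-row error is an order of magnitude smaller, which is the locality the theory note describes.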
Reference
- Dettmers et al., LLM.int8() (NeurIPS 2022) — the outlier discussion in §3.
- Frantar et al., GPTQ (arXiv:2210.17323) — what per-channel + Hessian-aware rounding buys you on top of this baseline.
Next: restore the per-channel scale, then run lab/02-quant-curve.md for the legitimate INT4 frontier (with group-wise scales).