Skip to content

English · Español

Lab 00 — Cuantización post-entrenamiento INT8 sobre MiniGPT

Objetivo: implementar PTQ INT8 solo-pesos (por-tensor y por-canal) y medir perplejidad vs FP32 sobre MiniGPT.

Tiempo estimado: 4–6 horas.

Prerrequisito: MiniGPT de Fase 17 con un forward pass model.eval() funcional; PyTorch de Fase 24; src/miniquant/BLUEPRINT.md leído.


Lo que produces

Un directorio experiments/26-int8-ptq/ que contiene:

  • quantize_minigpt.py — tu script (lo escribes tú).
  • results.json — mediciones (PPL FP32, PPL INT8 por-tensor, PPL INT8 por-canal, bytes en disco para cada uno, tamaño de calibration).
  • ppl_table.png — gráfico de barras o imagen de tabla.
  • manifest.json{seed, versions, config, hardware} según LYNX_CORTEX.md §5.
  • README.md — interpretación breve (2–4 párrafos).

También commiteas a src/miniquant/:

  • quantize.py — los cuantizadores simétricos por-tensor y por-canal y un módulo QuantizedLinear. Los tests pasan.

El kernel

El "kernel" de este lab es envolver cada nn.Linear de MiniGPT con un QuantizedLinear cuyo forward sea:

Linear(x) = (s_W * W_int8.float()) @ x + b

donde W_int8 = quantize_symmetric(W, scheme) se calcula una vez en calibration y se almacena como INT8 con una escala s_W (FP32) por-tensor o por-canal.

Esto es fake-quant: almacenamos los valores INT8 pero la matmul sigue ocurriendo en FP32. El objetivo es medir el efecto numérico de la cuantización, no la velocidad. (La velocidad requiere kernels INT8, que no tenemos en AVX2-sin-VNNI.)

TODOs

Bloque A — implementa el cuantizador en src/miniquant/quantize.py

El BLUEPRINT lista la API. Recapitulación:

  • quantize_symmetric_per_tensor(W: Tensor, bits: int = 8) -> (Tensor[int8], float). Devuelve (W_int, scale) con W_int ∈ [-127, 127] y scale = max(|W|) / 127.
  • quantize_symmetric_per_channel(W: Tensor, bits: int = 8, dim: int = 0) -> (Tensor[int8], Tensor[float]). Escalas por fila.
  • dequantize(W_int: Tensor[int8], scale: Tensor) -> Tensor. Hace broadcast correctamente de la escala.
  • QuantizedLinear(nn.Module). El constructor toma un nn.Linear existente + scheme; el forward hace matmul fake-quant; preserva el bias en FP32.
  • Tests en tests/test_quantization.py (Claude scaffolds los failing tests; Borja los hace pasar).

Bloque B — envuelve MiniGPT

  • Carga el MiniGPT de Fase 17 en modo eval.
  • Recorre el árbol de módulos, reemplaza cada nn.Linear con QuantizedLinear(orig_linear, scheme). Nota: no cuantices el embedding (es un nn.Embedding, no un nn.Linear; y cuantizar embeddings daña más que los pesos, ver theory 02).
  • Opcional: salta también el lm_head final (coincide con la convención de LLM.int8() — la capa de readout es sensible). Mide con y sin saltar; reporta ambos.

Bloque C — calibración

Para la cuantización solo-pesos por-canal, no hace falta calibración (los pesos son estáticos). Para cuantización de activaciones sí harías falta calibración — la saltamos en este lab y solo cuantizamos pesos.

  • Confirma: tu forward de QuantizedLinear pasa un tensor de la misma shape y dtype que el Linear original. Añade un assert en el test.

Bloque D — evalúa perplejidad

  • Usa la misma evaluación de perplejidad held-out de Fase 17 (scripts/eval_minigpt_ppl.py). Ejecuta sobre:
  • FP32 (baseline).
  • INT8 por-tensor.
  • INT8 por-canal.
  • Registra bytes en disco tras cada cuantización (suma de numel × dtype_size sobre todos los pesos, incluidas las escalas).

Bloque E — results.json

{
  "experiment": "26-int8-ptq",
  "date": "YYYY-MM-DD",
  "model": "minigpt-phase17",
  "model_params": null,
  "schemes": {
    "fp32":              { "ppl": null, "bytes": null },
    "int8_per_tensor":   { "ppl": null, "bytes": null, "ppl_gap_pct": null },
    "int8_per_channel":  { "ppl": null, "bytes": null, "ppl_gap_pct": null }
  },
  "notes": "..."
}

Bloque F — interpreta en README.md

Tres preguntas:

  1. ¿Cuál es la diferencia de PPL por-tensor vs por-canal? Por-canal debería ser ≤ la mitad de la diferencia por-tensor. Si no lo es, tu modelo es inusualmente libre de outliers o tus pesos ya están de algún modo en baja precisión.
  2. ¿Dónde está la mayor parte del ahorro de bytes? Calcula el % de bytes totales atribuible a (a) pesos de Linears, (b) la tabla de embedding, © parámetros de layer-norm. Las tablas de embedding suelen dominar la cuenta de bytes en modelos pequeños.
  3. ¿La cuantización de qué capa duele más? Pista: re-ejecuta con solo una capa cuantizada cada vez, mide la PPL cada vez, dibuja un gráfico de barras. Habitualmente la proyección de salida de attention o el lm_head final es la peor.

Restricciones

  • Nada de los wrappers de alto nivel de torch.quantization. Puedes usar utilidades de bajo nivel (torch.int8, tensor.to(torch.int8)), pero la matemática de cuantización es tuya.
  • Nada de bitsandbytes. Mismo motivo: caja negra.
  • Reproducibilidad: seed_everything(42) arriba de cada script.
  • Solo CPU. No hace falta gate de CUDA; asume que el dataset de calibration es suficientemente pequeño como para que los forward passes en FP32 terminen en minutos.

Condiciones de parada

Terminado cuando:

  1. Los tests en tests/test_quantization.py pasan todos.
  2. experiments/26-int8-ptq/ tiene los cinco archivos.
  3. La diferencia de PPL INT8 por-canal es < 5% (el umbral del DoD); si no, depura siguiendo las notas de pitfalls/ antes de consultar solutions.
  4. README.md responde las tres preguntas del Bloque F.

Trampas

  • Diferencia de PPL > 20%. Probablemente olvidaste hacer dequant antes de la matmul, o el broadcast de la escala está mal (la escala por-canal necesita shape (out, 1) no (out,) cuando multiplica una matriz (out, in)).
  • Diferencia de PPL sospechosamente pequeña (< 0.1%). Quizá hayas mantenido por accidente los pesos FP32 originales cacheados en el módulo. Imprime model.layers[0].mlp.fc1.weight.dtype tras envolver; debería ser FP32 (el resultado dequantizado), y model.layers[0].mlp.fc1.W_int8.dtype debería ser int8.
  • La memoria revienta. Estás manteniendo copias INT8 y FP32 a la vez. El peso dequantizado debería calcularse al vuelo por forward, no cachearse.
  • NaN en la salida. Escala por-tensor donde max(|W|) = 0 para alguna fila patológica. Añade un guard max(s, 1e-9).

Cuándo consultar solutions/

Tras commitear los cinco archivos y cumplir el umbral del DoD. La referencia en solutions/00-int8-ptq-ref.md (escrita al abrir la fase) compara tus números y la estructura de QuantizedLinear.


Siguiente lab: lab/01-gptq-toy.md.