Lab 00: Reward model from preferences

Train a tiny reward model on 200 preference pairs of grammar corrections with the Bradley-Terry loss.

Goal

Train a tiny reward model \(r_\phi(x, y)\) on 200 hand-curated pairwise preferences over responses from the §A13 grammar tutor. Verify that the Bradley-Terry loss converges and that held-out pair accuracy exceeds 70%, then inspect reward histograms for the warning sign of over-optimization.

Prerequisites

  • Mini-GPT from Phase 17, trained on the §A13 corpus.
  • SFT model from Phase 18, trained on the grammar-tutor prompt format (input sentence → correction).
  • LoRA adapters from Phase 28, fitted to the SFT model (the RM is initialized from this checkpoint).

Recipe

The reward model is the SFT model body plus a single linear head on the hidden state of the last token:

\[ r_\phi(x, y) = w^\top h_{|y|}(x, y) + b \]

with \(w \in \mathbb{R}^{128}, b \in \mathbb{R}\). Here \(h_{|y|}(x, y)\) is the body's hidden state at the last token of the response. The body is frozen; only \((w, b)\) are trained, 129 parameters in total. This is the cheapest possible RM, and architecturally it is what InstructGPT did (a scalar head on the SFT body).

Bradley-Terry loss

\[ \mathcal{L}_{\text{RM}}(\phi) = -\,\mathbb{E}_{(x, y_w, y_l)}\!\left[\log \sigma\!\left( r_\phi(x, y_w) - r_\phi(x, y_l) \right)\right] \]
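
To get a feel for the scale of this loss, here is a minimal sketch in plain Python that evaluates the per-pair term \(-\log \sigma(d)\) for a few reward margins \(d = r_\phi(x, y_w) - r_\phi(x, y_l)\), where \(y_w\) is the chosen response and \(y_l\) the rejected one. The function name `bt_pair_loss` is illustrative and not part of the lab code.

```python
import math


def bt_pair_loss(margin: float) -> float:
    """Per-pair Bradley-Terry loss: -log(sigmoid(margin)) = log(1 + exp(-margin))."""
    # Numerically stable softplus(-margin).
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


for d in (-1.0, 0.0, 1.0, 2.0):
    print(f"margin={d:+.1f}  loss={bt_pair_loss(d):.4f}")
```

A head that scores both responses equally (margin 0) sits at \(\log 2 \approx 0.693\) per pair, and a margin of 1 gives roughly 0.313. A wrongly ordered pair costs more than a correctly ordered one with the same absolute margin.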

Data: the 200-preference set

The 200 pairs are inlined here as YAML; save them to data/preferences/grammar_v0.yaml. The split is 160 train / 40 eval. Each entry has three fields: prompt, chosen, rejected. The chosen responses follow the 3 principles from theory chapter 05 (correct, concise, honest when there is no error).

# data/preferences/grammar_v0.yaml
# 200 pairwise preferences over §A13 grammar-tutor responses.
# Format: { prompt, chosen, rejected }
# Train: indices 0..159 ; Eval: indices 160..199

# --- Block A: regular verb past simple (40 pairs) ---
- prompt: "I works yesterday."
  chosen: "Use *worked*: past simple of *work* is *worked* (regular -ed)."
  rejected: "Use *worked*. The past simple form of the regular verb *to work*, which is one of the twelve regular verbs in the curriculum, is *worked*, formed by appending the *-ed* suffix to the bare infinitive *work*."
- prompt: "She play piano yesterday."
  chosen: "Use *played*: past simple, regular -ed."
  rejected: "Use *play*. She play piano yesterday is grammatically correct."
- prompt: "He walk to school yesterday."
  chosen: "Use *walked*: past simple of *walk* is *walked* (regular -ed)."
  rejected: "Yes, *He walk to school yesterday* is fine."
- prompt: "You talk to him last week."
  chosen: "Use *talked*: past simple, regular -ed."
  rejected: "Use *talks*. Always add -s to verbs."
- prompt: "I listen to music yesterday."
  chosen: "Use *listened*: past simple, regular -ed."
  rejected: "Use *listening*."
# ... 35 more pairs in this block, covering: watch/watched, study/studied,
# finish/finished, start/started, look/looked, want/wanted, like/liked,
# applied to all three persons (I, you, he/she/it).

# --- Block B: irregular verb past simple (40 pairs) ---
- prompt: "I goed home yesterday."
  chosen: "Use *went*: *go* is irregular; past simple is *went*."
  rejected: "Use *goed*. Add -ed for past."
- prompt: "She eated breakfast."
  chosen: "Use *ate*: *eat* is irregular; past simple is *ate*."
  rejected: "Use *eated*. The regular -ed rule applies."
- prompt: "He seed the movie."
  chosen: "Use *saw*: *see* is irregular; past simple is *saw*."
  rejected: "Use *seed*. -ed makes the past."
- prompt: "I writed a letter."
  chosen: "Use *wrote*: *write* is irregular; past simple is *wrote*."
  rejected: "Use *writed*."
- prompt: "You comed early."
  chosen: "Use *came*: *come* is irregular; past simple is *came*."
  rejected: "Use *comed*. Add -ed."
# ... 35 more pairs covering be/was|were, have/had, do/did and all persons.

# --- Block C: 3rd-person singular -s in present simple (30 pairs) ---
- prompt: "He work every day."
  chosen: "Use *works*: 3rd-person singular present simple takes -s."
  rejected: "Use *work*. Present simple is the bare infinitive."
- prompt: "She study at night."
  chosen: "Use *studies*: 3rd-person singular, *y* → *ies*."
  rejected: "Use *studys*."
- prompt: "It watch TV."
  chosen: "Use *watches*: 3rd-person singular adds -es after -ch."
  rejected: "Use *watch*. Subjects don't change verbs."
- prompt: "He have a car."
  chosen: "Use *has*: *have* is irregular in 3rd-person singular → *has*."
  rejected: "Use *haves*."
- prompt: "She do her homework."
  chosen: "Use *does*: *do* is irregular in 3rd-person singular → *does*."
  rejected: "Use *dos*."
# ... 25 more pairs.

# --- Block D: future tense (will / going to) (30 pairs) ---
- prompt: "I will to go tomorrow."
  chosen: "Drop *to*: *will* + bare infinitive → *I will go tomorrow*."
  rejected: "Use *willed*. Past of *will* is *willed*."
- prompt: "She going to study."
  chosen: "Add *is*: *going to* needs a form of *be* → *She is going to study*."
  rejected: "*She going to study* is correct."
- prompt: "He will works tomorrow."
  chosen: "Drop the -s: after *will*, use bare infinitive → *He will work tomorrow*."
  rejected: "Use *will worked*."
- prompt: "I am going eat dinner."
  chosen: "Add *to*: *going to* + bare infinitive → *I am going to eat dinner*."
  rejected: "Sentence is fine."
- prompt: "You will to come?"
  chosen: "Drop *to*: *Will you come?*"
  rejected: "Use *willed* in questions."
# ... 25 more pairs.

# --- Block E: past participle vs past simple (20 pairs) ---
- prompt: "I have went to Spain."
  chosen: "Use *gone* (past participle): *have* + past participle → *I have gone*."
  rejected: "Use *went*. Past simple after *have*."
- prompt: "She has ate already."
  chosen: "Use *eaten*: *have/has* + past participle of *eat* is *eaten*."
  rejected: "Use *ate*. *Ate* is correct anywhere in past."
- prompt: "He has wrote a book."
  chosen: "Use *written*: past participle of *write* is *written*."
  rejected: "Sentence is fine."
- prompt: "I have did my work."
  chosen: "Use *done*: past participle of *do* is *done*."
  rejected: "Use *did*. Past forms are interchangeable."
# ... 16 more pairs.

# --- Block F: honesty (input is correct; refuse to invent a fix) (20 pairs) ---
- prompt: "I worked yesterday."
  chosen: "No error: *I worked yesterday* is correct (past simple)."
  rejected: "Use *I am worked*. Always use a form of *be*."
- prompt: "She is going to study tonight."
  chosen: "No error: *going to* + infinitive is correct here."
  rejected: "Use *She is going study tonight*. Drop *to*."
- prompt: "He has eaten breakfast."
  chosen: "No error: *has* + past participle *eaten* is correct."
  rejected: "Use *He has ate breakfast*."
- prompt: "You will come tomorrow."
  chosen: "No error: *will* + bare infinitive is correct."
  rejected: "Use *You will to come tomorrow*."
# ... 16 more pairs. This block trains the model NOT to hallucinate errors.

# --- Block G: distractor / cross-tense (20 pairs) ---
- prompt: "He goes to the store now."
  chosen: "No error: *goes* is correct 3rd-person singular present simple."
  rejected: "Use *went*. Always prefer past."
- prompt: "I am writing a letter."
  chosen: "No error (present continuous; outside §A13 scope but valid English)."
  rejected: "Use *I writes a letter*."
# ... 18 more pairs.

Note: the abbreviated entries above are illustrative. The full 200-pair file is generated by scripts/build_grammar_v0.py (see the Phase 12 corpus tooling), which enumerates template × verb × person and pairs each correct response with a controlled error. The script is short (~40 lines) and lives in the lab's solutions/ folder; learners write their own version first as part of this lab.

Implementation
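
As a starting point for that exercise, the sketch below shows one way the template × verb × person enumeration could look for Block A only. It is not the contents of scripts/build_grammar_v0.py: the names `REGULAR_VERBS`, `SUBJECTS`, `regular_past`, and `build_block_a` are hypothetical, the verb list is taken from the examples and comments in Block A, and the single template yields 36 pairs rather than the block's 40.

```python
from __future__ import annotations

import itertools

# Hypothetical inputs: the twelve regular verbs that appear in Block A, three subjects.
REGULAR_VERBS = ["work", "play", "walk", "talk", "listen", "watch",
                 "study", "finish", "start", "look", "want", "like"]
SUBJECTS = ["I", "You", "She"]


def regular_past(verb: str) -> str:
    """Past simple for the regular verbs above (consonant doubling is not handled)."""
    if verb.endswith("e"):
        return verb + "d"                      # like -> liked
    if verb.endswith("y") and verb[-2] not in "aeiou":
        return verb[:-1] + "ied"               # study -> studied
    return verb + "ed"                         # work -> worked, play -> played


def build_block_a() -> list[dict[str, str]]:
    """Enumerate subject x verb and pair each correct fix with one controlled error."""
    pairs = []
    for subject, verb in itertools.product(SUBJECTS, REGULAR_VERBS):
        prompt = f"{subject} {verb} yesterday."
        pairs.append({
            "prompt": prompt,
            "chosen": f"Use *{regular_past(verb)}*: past simple, regular -ed.",
            # Controlled error: the response accepts the ungrammatical input.
            "rejected": f"Yes, *{prompt[:-1]}* is fine.",
        })
    return pairs


pairs = build_block_a()
print(len(pairs), pairs[0])
```

Using one fixed error type per block keeps the example short, but it also gives the reward model a trivial surface cue; a fuller generator would presumably rotate through several error types, as the hand-written entries above do.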

# src/x3_rlhf/reward_model.py
# Bradley-Terry reward model on §A13 preferences. ~30 lines.

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F


class RewardModel(nn.Module):
    """Scalar reward head on a frozen SFT body."""

    def __init__(self, sft_model: nn.Module, hidden_dim: int = 128) -> None:
        super().__init__()
        self.body = sft_model
        for p in self.body.parameters():
            p.requires_grad = False
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        # body returns (B, T, hidden); we want the last non-pad token per row.
        # NOTE: sum(mask) - 1 is the last real token only for right-padded batches.
        h = self.body.hidden_states(input_ids, attn_mask)        # (B, T, H)
        last_idx = attn_mask.long().sum(dim=1) - 1                # (B,), integer index
        rows = torch.arange(h.size(0), device=h.device)           # (B,)
        last_h = h[rows, last_idx]                                # (B, H)
        return self.head(last_h).squeeze(-1)                      # (B,)


def bt_loss(rm: RewardModel, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Bradley-Terry NLL on (chosen, rejected) pairs."""
    r_w = rm(batch["chosen_ids"], batch["chosen_mask"])
    r_l = rm(batch["rejected_ids"], batch["rejected_mask"])
    return -F.logsigmoid(r_w - r_l).mean()


def pair_accuracy(rm: RewardModel, batch: dict[str, torch.Tensor]) -> float:
    """Fraction of pairs where r(chosen) > r(rejected)."""
    with torch.no_grad():
        r_w = rm(batch["chosen_ids"], batch["chosen_mask"])
        r_l = rm(batch["rejected_ids"], batch["rejected_mask"])
        return (r_w > r_l).float().mean().item()
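
The forward pass above finds the last token as "mask sum minus one", which is only correct when padding sits on the right. The plain-Python sketch below illustrates the difference on a single row; both helper names are hypothetical, and the lab's actual tokenizer padding side is not stated here.

```python
from __future__ import annotations


def last_index_right_padded(mask: list[int]) -> int:
    """Single-row mirror of `attn_mask.sum(dim=1) - 1`."""
    return sum(mask) - 1


def last_index_any_padding(mask: list[int]) -> int:
    """Position of the last 1 in the mask, regardless of padding side."""
    return max(i for i, m in enumerate(mask) if m == 1)


right_padded = [1, 1, 1, 0, 0]   # tokens at 0..2, pads at 3..4
left_padded = [0, 0, 1, 1, 1]    # pads at 0..1, tokens at 2..4

print(last_index_right_padded(right_padded), last_index_any_padding(right_padded))
print(last_index_right_padded(left_padded), last_index_any_padding(left_padded))
```

On the left-padded row the sum-based index lands on the first real token instead of the last, so the head would score the wrong hidden state. If the data loader pads on the left, the indexing needs to change accordingly.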

Training loop (in scripts/train_rm.py, also ~30 lines):

import torch

from x3_rlhf.reward_model import RewardModel, bt_loss, pair_accuracy
from lynx_cortex.utils import seed_everything, save_manifest
from lynx_cortex.phase17 import load_sft_model
from lynx_cortex.data import load_pref_yaml, make_loader

seed_everything(42)
sft = load_sft_model("checkpoints/phase28-lora.pt")
rm = RewardModel(sft, hidden_dim=128)

train_pairs, eval_pairs = load_pref_yaml(
    "data/preferences/grammar_v0.yaml", split=(160, 40)
)
train_loader = make_loader(train_pairs, batch_size=16, shuffle=True)
eval_loader  = make_loader(eval_pairs,  batch_size=16, shuffle=False)

opt = torch.optim.AdamW(rm.head.parameters(), lr=1e-3, weight_decay=0.01)

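for epoch in range(10):
    epoch_loss, n_batches = 0.0, 0
    for batch in train_loader:
        loss = bt_loss(rm, batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        n_batches += 1
    train_loss = epoch_loss / n_batches  # mean over the epoch, not just the last batch
    # Weight each eval batch by its size so a short final batch does not skew the mean.
    n_correct, n_total = 0.0, 0
    for b in eval_loader:
        n = b["chosen_ids"].size(0)
        n_correct += pair_accuracy(rm, b) * n
        n_total += n
    eval_acc = n_correct / n_total
    print(f"epoch={epoch}  train_loss={train_loss:.4f}  eval_acc={eval_acc:.3f}")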

save_manifest("experiments/2026-05-23-rm-v0/manifest.json",
              {"seed": 42, "epochs": 10, "lr": 1e-3, "eval_acc": eval_acc,
               "torch_version": torch.__version__})  # the DoD asks for library versions
torch.save(rm.head.state_dict(), "checkpoints/rm-v0.pt")

Expected results

| Metric | Target |
| --- | --- |
| Training BT loss after 10 epochs | < 0.30 |
| Pair accuracy on train | > 0.90 |
| Held-out pair accuracy (40 pairs) | > 0.70 |
| Mean \(r(y_w) - r(y_l)\) on held-out pairs | > 1.0 (well separated) |
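
With only 40 held-out pairs, the accuracy target is coarse. The short calculation below is a sketch of what "> 0.70" means in whole pairs, plus a rough binomial standard error. The error estimate assumes independent pairs, which templated data only approximates, so treat it as an order-of-magnitude guide.

```python
import math

N_EVAL = 40            # held-out pairs, from the 160 / 40 split
TARGET_PERCENT = 70    # accuracy must be strictly greater than this

# Smallest number of correct pairs whose accuracy strictly exceeds the target.
min_correct = (TARGET_PERCENT * N_EVAL) // 100 + 1
step = 1 / N_EVAL      # accuracy change caused by a single pair

# Rough binomial standard error at the target accuracy (assumes independent pairs).
p = TARGET_PERCENT / 100
std_err = math.sqrt(p * (1 - p) / N_EVAL)

print(f"need >= {min_correct}/{N_EVAL} correct; one pair = {step:.3f}; std err ~ {std_err:.3f}")
```

So the target amounts to at least 29 of 40 pairs ranked correctly, and a run that lands within a couple of pairs of the threshold is hard to distinguish from one just on the other side of it.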

Diagnostic plots

After training, plot:

  1. BT loss curve (train + eval): should trend downward and then flatten.
  2. Reward histograms: overlay \(r_\phi(x, y_w)\) and \(r_\phi(x, y_l)\) on the held-out pairs. The combined distribution should be visibly bimodal, with chosen rewards sitting above rejected ones.
  3. Per-block accuracy: break down held-out accuracy by the 7 blocks (A-G). Block F (honesty) is the hardest; expect lower accuracy there.
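
For the per-block breakdown, each held-out pair has to be mapped back to its block. The sketch below does this from the pair's position in the YAML file, assuming the file keeps the block order and sizes given in its comments; `BLOCK_SIZES`, `block_of`, and `accuracy_by_block` are hypothetical names. One thing this makes visible: under the contiguous split stated in the YAML header (eval = indices 160..199), the held-out pairs fall entirely in Blocks F and G, so a seven-way breakdown would need a shuffled or stratified split. That is an inference from the block sizes, not something the lab states.

```python
from __future__ import annotations

from collections import defaultdict

# Block sizes as listed in the YAML comments; order assumed to match the file.
BLOCK_SIZES = [("A", 40), ("B", 40), ("C", 30), ("D", 30),
               ("E", 20), ("F", 20), ("G", 20)]


def block_of(index: int) -> str:
    """Map a pair's position in the file (0..199) to its block letter."""
    start = 0
    for name, size in BLOCK_SIZES:
        if start <= index < start + size:
            return name
        start += size
    raise IndexError(f"pair index {index} is outside 0..{start - 1}")


def accuracy_by_block(indices: list[int], correct: list[bool]) -> dict[str, float]:
    """Per-block fraction of pairs where r(chosen) > r(rejected)."""
    hits, totals = defaultdict(int), defaultdict(int)
    for idx, ok in zip(indices, correct):
        block = block_of(idx)
        totals[block] += 1
        hits[block] += int(ok)
    return {b: hits[b] / totals[b] for b in sorted(totals)}


# Which blocks does the contiguous eval range 160..199 touch?
eval_blocks = sorted({block_of(i) for i in range(160, 200)})
print(eval_blocks)
```

`accuracy_by_block` expects the original file index of each held-out pair plus a per-pair boolean, so the data loader would need to carry the index through alongside the token ids.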

Things to break (per CLAUDE.md §0.2: "break intentionally")

These are exercises for the learner. Each one brings out a specific failure-mode lesson from chapter 02.

  1. Set \(w\) to the embedding of "the" instead of training it. Show that the RM scores responses by how often they contain "the", a perfect caricature of length bias.
  2. Train on a copy of the dataset with Block F (honesty) removed. Show that the RM happily rewards hallucinated corrections on inputs that were already correct: sycophancy made concrete.
  3. Score 50 best-of-N samples from the SFT model as N grows from 1 to 32. Plot RM score against held-out gold accuracy. Look for the inverted-U curve (Gao 2022): gold accuracy rises at first, then falls as the RM score keeps climbing.
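
For exercise 3, the selection step of best-of-N can be sketched as below: among the first N sampled responses, keep the one the reward model scores highest. The candidates and the length-based scorer are toy stand-ins, not model outputs; `best_of_n`, `toy_candidates`, and `toy_rm` are hypothetical names, and in the lab the scorer would be the trained RM.

```python
from __future__ import annotations

from typing import Callable


def best_of_n(candidates: list[str], rm_score: Callable[[str], float], n: int) -> str:
    """Return the highest-scoring response among the first n candidates."""
    if not 1 <= n <= len(candidates):
        raise ValueError("n must be between 1 and len(candidates)")
    return max(candidates[:n], key=rm_score)


# Toy stand-ins, NOT real model outputs: three fixed samples and a length-loving "RM".
toy_candidates = [
    "Use *worked*.",
    "Use *worked*: past simple, regular -ed.",
    "Use *worked*. The past simple form of the regular verb is formed by appending -ed.",
]


def toy_rm(response: str) -> float:
    return float(len(response))


for n in (1, 2, 3):
    print(n, best_of_n(toy_candidates, toy_rm, n))
```

With this deliberately biased scorer, a larger N only ever selects a longer response. That is the mechanism the exercise is meant to expose: the RM score keeps rising with N whether or not gold accuracy follows.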

Definition of Done (DoD)

  • train_rm.py runs end to end on CPU in < 5 minutes.
  • Held-out pair accuracy > 0.70.
  • Reward histograms saved as PNG in experiments/2026-05-23-rm-v0/.
  • Manifest persisted with seed, lr, eval_acc, and library versions.
  • All three "break it" exercises attempted, with a one-paragraph summary of what each one revealed.