Lab 00 — Reward model from preferences

Train a tiny reward model on 200 preference pairs of grammar corrections using the Bradley-Terry loss.

Goal

Train a tiny reward model \(r_\phi(x, y)\) on 200 hand-curated pairwise preferences over grammar-tutor responses from §A13. Verify that the Bradley-Terry loss converges and that held-out pair accuracy exceeds 70%, then inspect the reward histograms for the over-optimization warning sign.

Prerequisites

  • Phase 17 mini-GPT trained on §A13 corpus.
  • Phase 18 SFT model on the grammar-tutor prompt format (input sentence → correction).
  • Phase 28 LoRA adapters fitted to the SFT model (we initialize the RM from this).

Recipe

The reward model is the SFT model body + a single linear head on the last-token hidden state:

\[ r_\phi(x, y) = w^\top h_{|y|}(x, y) + b \]

with \(w \in \mathbb{R}^{128}, b \in \mathbb{R}\). The body is frozen; only \((w, b)\) are trained, which is 129 parameters in total. This is about the cheapest RM possible, and it matches InstructGPT at the architectural level (a scalar head on the SFT body), although InstructGPT also trained the body rather than freezing it.

Bradley-Terry loss

\[ \mathcal{L}_{\text{RM}}(\phi) = -\,\mathbb{E}_{(x, y_w, y_l)}\!\left[\log \sigma\!\left( r_\phi(x, y_w) - r_\phi(x, y_l) \right)\right] \]
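
To get a feel for the scale of this loss, the following minimal sketch evaluates the per-pair term for a few reward margins in plain Python (no PyTorch). The helper names `bt_loss_from_margin` and `margin_for_loss` are illustrative and not part of the lab code.

```python
import math


def bt_loss_from_margin(margin: float) -> float:
    """Per-pair Bradley-Terry loss -log(sigmoid(margin)), margin = r(chosen) - r(rejected)."""
    # softplus(-margin) equals -log(sigmoid(margin)); the branch keeps exp() from overflowing.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


def margin_for_loss(loss: float) -> float:
    """Invert the per-pair loss: the margin at which -log(sigmoid(margin)) == loss."""
    p = math.exp(-loss)  # probability the model assigns to the chosen response
    return math.log(p / (1.0 - p))


for m in (-2.0, 0.0, 1.0, 3.0):
    print(f"margin={m:+.1f}  loss={bt_loss_from_margin(m):.4f}")

print(f"margin at loss 0.30: {margin_for_loss(0.30):.3f}")
```

A pair the model cannot separate (margin 0) costs log 2 ≈ 0.693, and a per-pair loss of 0.30 corresponds to a margin of roughly 1.05. Because the loss is convex in the margin, a mean loss below 0.30 implies a mean margin above that value, so the lab's loss and margin targets sit on a comparable scale.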

Data: the 200-pair preference set

An abbreviated version of the 200-pair set is inlined below as YAML; the full file belongs in data/preferences/grammar_v0.yaml. The split is 160 train / 40 eval. Each entry has three fields: prompt, chosen and rejected. The chosen responses follow the 3 principles from theory chapter 05 (correct, concise, honest-when-no-error).

# data/preferences/grammar_v0.yaml
# 200 pairwise preferences over §A13 grammar-tutor responses.
# Format: { prompt, chosen, rejected }
# Train: indices 0..159 ; Eval: indices 160..199

# --- Block A: regular verb past simple (40 pairs) ---
- prompt: "I works yesterday."
  chosen: "Use *worked*: past simple of *work* is *worked* (regular -ed)."
  rejected: "Use *worked*. The past simple form of the regular verb *to work*, which is one of the twelve regular verbs in the curriculum, is *worked*, formed by appending the *-ed* suffix to the bare infinitive *work*."
- prompt: "She play piano yesterday."
  chosen: "Use *played*: past simple, regular -ed."
  rejected: "Use *play*. She play piano yesterday is grammatically correct."
- prompt: "He walk to school yesterday."
  chosen: "Use *walked*: past simple of *walk* is *walked* (regular -ed)."
  rejected: "Yes, *He walk to school yesterday* is fine."
- prompt: "You talk to him last week."
  chosen: "Use *talked*: past simple, regular -ed."
  rejected: "Use *talks*. Always add -s to verbs."
- prompt: "I listen to music yesterday."
  chosen: "Use *listened*: past simple, regular -ed."
  rejected: "Use *listening*."
# ... 35 more pairs in this block, covering: watch/watched, study/studied,
# finish/finished, start/started, look/looked, want/wanted, like/liked,
# applied to all three persons (I, you, he/she/it).

# --- Block B: irregular verb past simple (40 pairs) ---
- prompt: "I goed home yesterday."
  chosen: "Use *went*: *go* is irregular; past simple is *went*."
  rejected: "Use *goed*. Add -ed for past."
- prompt: "She eated breakfast."
  chosen: "Use *ate*: *eat* is irregular; past simple is *ate*."
  rejected: "Use *eated*. The regular -ed rule applies."
- prompt: "He seed the movie."
  chosen: "Use *saw*: *see* is irregular; past simple is *saw*."
  rejected: "Use *seed*. -ed makes the past."
- prompt: "I writed a letter."
  chosen: "Use *wrote*: *write* is irregular; past simple is *wrote*."
  rejected: "Use *writed*."
- prompt: "You comed early."
  chosen: "Use *came*: *come* is irregular; past simple is *came*."
  rejected: "Use *comed*. Add -ed."
# ... 35 more pairs covering be/was|were, have/had, do/did and all persons.

# --- Block C: 3rd-person singular -s in present simple (30 pairs) ---
- prompt: "He work every day."
  chosen: "Use *works*: 3rd-person singular present simple takes -s."
  rejected: "Use *work*. Present simple is the bare infinitive."
- prompt: "She study at night."
  chosen: "Use *studies*: 3rd-person singular, *y* → *ies*."
  rejected: "Use *studys*."
- prompt: "It watch TV."
  chosen: "Use *watches*: 3rd-person singular adds -es after -ch."
  rejected: "Use *watch*. Subjects don't change verbs."
- prompt: "He have a car."
  chosen: "Use *has*: *have* is irregular in 3rd-person singular → *has*."
  rejected: "Use *haves*."
- prompt: "She do her homework."
  chosen: "Use *does*: *do* is irregular in 3rd-person singular → *does*."
  rejected: "Use *dos*."
# ... 25 more pairs.

# --- Block D: future tense (will / going to) (30 pairs) ---
- prompt: "I will to go tomorrow."
  chosen: "Drop *to*: *will* + bare infinitive → *I will go tomorrow*."
  rejected: "Use *willed*. Past of *will* is *willed*."
- prompt: "She going to study."
  chosen: "Add *is*: *going to* needs a form of *be* → *She is going to study*."
  rejected: "*She going to study* is correct."
- prompt: "He will works tomorrow."
  chosen: "Drop the -s: after *will*, use bare infinitive → *He will work tomorrow*."
  rejected: "Use *will worked*."
- prompt: "I am going eat dinner."
  chosen: "Add *to*: *going to* + bare infinitive → *I am going to eat dinner*."
  rejected: "Sentence is fine."
- prompt: "You will to come?"
  chosen: "Drop *to* → *Will you come?*"
  rejected: "Use *willed* in questions."
# ... 25 more pairs.

# --- Block E: past participle vs past simple (20 pairs) ---
- prompt: "I have went to Spain."
  chosen: "Use *gone* (past participle): *have* + past participle  *I have gone*."
  rejected: "Use *went*. Past simple after *have*."
- prompt: "She has ate already."
  chosen: "Use *eaten*: *have/has* + past participle of *eat* is *eaten*."
  rejected: "Use *ate*. *Ate* is correct anywhere in past."
- prompt: "He has wrote a book."
  chosen: "Use *written*: past participle of *write* is *written*."
  rejected: "Sentence is fine."
- prompt: "I have did my work."
  chosen: "Use *done*: past participle of *do* is *done*."
  rejected: "Use *did*. Past forms are interchangeable."
# ... 16 more pairs.

# --- Block F: honesty (input is correct; refuse to invent a fix) (20 pairs) ---
- prompt: "I worked yesterday."
  chosen: "No error: *I worked yesterday* is correct (past simple)."
  rejected: "Use *I am worked*. Always use a form of *be*."
- prompt: "She is going to study tonight."
  chosen: "No error: *going to* + infinitive is correct here."
  rejected: "Use *She is going study tonight*. Drop *to*."
- prompt: "He has eaten breakfast."
  chosen: "No error: *has* + past participle *eaten* is correct."
  rejected: "Use *He has ate breakfast*."
- prompt: "You will come tomorrow."
  chosen: "No error: *will* + bare infinitive is correct."
  rejected: "Use *You will to come tomorrow*."
# ... 16 more pairs. This block trains the model NOT to hallucinate errors.

# --- Block G: distractor / cross-tense (20 pairs) ---
- prompt: "He goes to the store now."
  chosen: "No error: *goes* is correct 3rd-person singular present simple."
  rejected: "Use *went*. Always prefer past."
- prompt: "I am writing a letter."
  chosen: "No error (present continuous, outside §A13 scope but valid English)."
  rejected: "Use *I writes a letter*."
# ... 18 more pairs.

Note: the abbreviated entries above are illustrative. The full 200-pair file is generated by scripts/build_grammar_v0.py (see Phase 12 corpus tools) by enumerating template × verb × person and pairing each correct response with a controlled mistake. The script is short (~40 lines) and lives in the lab's solutions/ folder; learners write their own version first as part of this lab.
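
One detail worth checking when writing the builder: if the file is written block by block in the order A to G with the sizes listed above, Blocks A to E occupy indices 0..159, so indices 160..199 cover exactly Blocks F and G. A contiguous split would then evaluate on those two blocks only. Whether `load_pref_yaml` shuffles before splitting is not stated here. The sketch below shows one way to get a 160/40 split with every block represented, assuming the block order and sizes above; `stratified_split` and `BLOCK_SIZES` are hypothetical names, not part of the lab's tooling.

```python
from __future__ import annotations

import random

# Assumed block sizes, in file order A..G, taken from the YAML section comments.
BLOCK_SIZES = {"A": 40, "B": 40, "C": 30, "D": 30, "E": 20, "F": 20, "G": 20}


def stratified_split(block_sizes: dict[str, int], eval_frac: float = 0.2, seed: int = 42):
    """Return (train_idx, eval_idx), holding out eval_frac of each block."""
    rng = random.Random(seed)
    train_idx, eval_idx = [], []
    start = 0
    for size in block_sizes.values():
        idx = list(range(start, start + size))  # indices of this block in the file
        rng.shuffle(idx)
        n_eval = round(size * eval_frac)
        eval_idx.extend(idx[:n_eval])
        train_idx.extend(idx[n_eval:])
        start += size
    return sorted(train_idx), sorted(eval_idx)


train_idx, eval_idx = stratified_split(BLOCK_SIZES)
print(len(train_idx), len(eval_idx))  # 160 40
```

With `eval_frac=0.2` the blocks contribute 8, 8, 6, 6, 4, 4 and 4 held-out pairs, which gives the per-block accuracy breakdown something to report for every block.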

Implementation

# src/x3_rlhf/reward_model.py
# Bradley-Terry reward model on §A13 preferences. ~30 lines.

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F


class RewardModel(nn.Module):
    """Scalar reward head on a frozen SFT body."""

    def __init__(self, sft_model: nn.Module, hidden_dim: int = 128) -> None:
        super().__init__()
        self.body = sft_model
        for p in self.body.parameters():
            p.requires_grad = False
        self.body.eval()  # frozen body: turn off dropout so rewards are deterministic
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        # body returns (B, T, hidden); we want the last non-pad token per row.
        # Assumes right padding, so the last real token sits at index sum(mask) - 1.
        with torch.no_grad():                                     # no graph through the frozen body
            h = self.body.hidden_states(input_ids, attn_mask)     # (B, T, H)
        last_idx = attn_mask.long().sum(dim=1) - 1                # (B,) integer indices
        rows = torch.arange(h.size(0), device=h.device)           # (B,) on the same device as h
        last_h = h[rows, last_idx]                                # (B, H)
        return self.head(last_h).squeeze(-1)                      # (B,)


def bt_loss(rm: RewardModel, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Bradley-Terry NLL on (chosen, rejected) pairs."""
    r_w = rm(batch["chosen_ids"], batch["chosen_mask"])
    r_l = rm(batch["rejected_ids"], batch["rejected_mask"])
    return -F.logsigmoid(r_w - r_l).mean()


def pair_accuracy(rm: RewardModel, batch: dict[str, torch.Tensor]) -> float:
    """Fraction of pairs where r(chosen) > r(rejected)."""
    with torch.no_grad():
        r_w = rm(batch["chosen_ids"], batch["chosen_mask"])
        r_l = rm(batch["rejected_ids"], batch["rejected_mask"])
        return (r_w > r_l).float().mean().item()
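
The functions above expect batches with `chosen_ids`, `chosen_mask`, `rejected_ids` and `rejected_mask`. How `make_loader` builds them is not shown in this lab, so the following is only a sketch of one plausible collation step, written with plain lists to stay dependency-free. It pads each sequence on the right, which is what `attn_mask.sum(dim=1) - 1` relies on to locate the last real token. The names `right_pad`, `collate_pairs` and `PAD_ID = 0` are assumptions.

```python
from __future__ import annotations

PAD_ID = 0  # assumed padding token id; use the tokenizer's real pad id


def right_pad(seqs: list[list[int]], pad_id: int = PAD_ID):
    """Pad token-id lists on the right; return (ids, mask) as nested lists."""
    width = max(len(s) for s in seqs)
    ids = [s + [pad_id] * (width - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    return ids, mask


def collate_pairs(pairs: list[dict[str, list[int]]]) -> dict[str, list[list[int]]]:
    """Turn tokenized pairs into the four fields bt_loss reads.

    Each pair holds prompt+chosen and prompt+rejected token ids, already concatenated.
    """
    chosen_ids, chosen_mask = right_pad([p["chosen"] for p in pairs])
    rejected_ids, rejected_mask = right_pad([p["rejected"] for p in pairs])
    return {
        "chosen_ids": chosen_ids, "chosen_mask": chosen_mask,
        "rejected_ids": rejected_ids, "rejected_mask": rejected_mask,
    }


# Toy token ids, for illustration only.
batch = collate_pairs([
    {"chosen": [5, 6, 7], "rejected": [5, 6, 7, 8, 9]},
    {"chosen": [5, 4], "rejected": [5, 4, 3]},
])
last_idx = [sum(row) - 1 for row in batch["chosen_mask"]]
print(last_idx)  # [2, 1]
```

In a real loader each field would be wrapped in `torch.tensor(..., dtype=torch.long)` before reaching the model. With left padding the same index arithmetic would point at the wrong position, so the padding side is worth asserting in a unit test.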

Training loop (in scripts/train_rm.py, also ~30 lines):

import torch

from x3_rlhf.reward_model import RewardModel, bt_loss, pair_accuracy
from lynx_cortex.utils import seed_everything, save_manifest
from lynx_cortex.phase17 import load_sft_model
from lynx_cortex.data import load_pref_yaml, make_loader

seed_everything(42)
sft = load_sft_model("checkpoints/phase28-lora.pt")
rm = RewardModel(sft, hidden_dim=128)

train_pairs, eval_pairs = load_pref_yaml(
    "data/preferences/grammar_v0.yaml", split=(160, 40)
)
train_loader = make_loader(train_pairs, batch_size=16, shuffle=True)
eval_loader  = make_loader(eval_pairs,  batch_size=16, shuffle=False)

# Only the head is optimized; the body stays frozen.
opt = torch.optim.AdamW(rm.head.parameters(), lr=1e-3, weight_decay=0.01)

for epoch in range(10):
    for batch in train_loader:
        loss = bt_loss(rm, batch)
        opt.zero_grad(); loss.backward(); opt.step()
    # Weight each batch by its size: 40 eval pairs at batch_size=16 arrive as 16/16/8,
    # so a plain mean of per-batch accuracies would over-weight the last batch.
    correct, n_eval = 0.0, 0
    for b in eval_loader:
        n = b["chosen_ids"].size(0)
        correct += pair_accuracy(rm, b) * n
        n_eval += n
    eval_acc = correct / n_eval
    # train_loss is the loss of the last minibatch in the epoch, not an epoch average.
    print(f"epoch={epoch}  train_loss={loss.item():.4f}  eval_acc={eval_acc:.3f}")

save_manifest("experiments/2026-05-23-rm-v0/manifest.json",
              {"seed": 42, "epochs": 10, "lr": 1e-3, "eval_acc": eval_acc,
               "torch_version": torch.__version__})
torch.save(rm.head.state_dict(), "checkpoints/rm-v0.pt")

Expected results

| Metric | Target |
|---|---|
| Train BT loss after 10 epochs | < 0.30 |
| Train pair accuracy | > 0.90 |
| Held-out pair accuracy (40 pairs) | > 0.70 |
| Mean \(r(y_w) - r(y_l)\) on held-out | > 1.0 (well-separated) |

Diagnostic plots

After training, plot:

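  1. BT loss curve (train + eval): should trend downward, then flatten. Small upticks from minibatch noise are normal; an eval loss that climbs while the train loss keeps falling signals overfitting to the 160 training pairs.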
  2. Reward histograms — overlay \(r_\phi(x, y_w)\) and \(r_\phi(x, y_l)\) on held-out. Should be visibly bimodal.
  3. Per-block accuracy — break out held-out accuracy by the 7 blocks (A-G). Block F (honesty) is the hardest; expect lower accuracy there.
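
For the third plot, a small helper can group held-out pairs by block before computing accuracy. The sketch assumes each evaluated pair carries a block label (A to G) and that rewards have already been converted to Python floats; `per_block_accuracy` is a hypothetical name and the numbers in the usage lines are toy values, not results.

```python
from collections import defaultdict


def per_block_accuracy(blocks, r_chosen, r_rejected):
    """Map block label -> fraction of pairs with r(chosen) > r(rejected)."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for block, rw, rl in zip(blocks, r_chosen, r_rejected):
        totals[block] += 1
        hits[block] += int(rw > rl)
    return {b: hits[b] / totals[b] for b in sorted(totals)}


# Toy rewards for five pairs, for illustration only.
acc = per_block_accuracy(
    blocks=["A", "A", "F", "F", "G"],
    r_chosen=[1.2, 0.4, 0.1, 0.9, 2.0],
    r_rejected=[0.3, 0.7, 0.5, 0.2, -1.0],
)
print(acc)  # {'A': 0.5, 'F': 0.5, 'G': 1.0}
```

With only 40 held-out pairs, each block contributes a handful of examples, so per-block figures move in coarse steps and are best read as rough indicators.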

Things to break (per CLAUDE.md §0.2 — "break intentionally")

These are exercises for the learner. Each one elicits a specific failure-mode lesson from chapter 02.

  1. Set \(w\) to be the embedding of "the" instead of training. Show that the RM scores responses by how often they contain "the" — a perfect length-bias caricature.
  2. Train on a copy of the dataset with Block F (honesty) removed. Show that the RM happily rewards hallucinated corrections on already-correct inputs: sycophancy materialized.
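  3. Score 50 best-of-N samples from the SFT model as N grows from 1 to 32. Plot both the RM score and the held-out gold accuracy of the selected samples against N. Look for the over-optimization curve (Gao 2022): the RM score keeps rising while gold accuracy peaks and then falls.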
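
Exercise 3 can be organized around a helper like the one below. It is a sketch under assumptions: for each prompt you already have a list of sampled responses with an RM score and a gold correctness flag, and best-of-N means keeping the highest-scoring response among the first N samples. The name `best_of_n_curve` and the toy inputs are illustrative.

```python
def best_of_n_curve(rm_scores, gold, ns):
    """For each N, pick the top-RM sample among the first N per prompt.

    rm_scores[i][j]: RM score of sample j for prompt i.
    gold[i][j]:      1 if that sample is actually correct, else 0.
    Returns {N: (mean RM score of picks, mean gold accuracy of picks)}.
    """
    curve = {}
    for n in ns:
        picked_rm, picked_gold = [], []
        for scores, labels in zip(rm_scores, gold):
            j = max(range(n), key=lambda k: scores[k])  # argmax over the first n samples
            picked_rm.append(scores[j])
            picked_gold.append(labels[j])
        curve[n] = (sum(picked_rm) / len(picked_rm), sum(picked_gold) / len(picked_gold))
    return curve


# Toy data: 2 prompts x 4 samples. The highest-scoring samples are wrong on purpose,
# to mimic an RM being exploited at large N.
rm_scores = [[0.1, 0.8, 0.5, 2.0], [0.3, 0.2, 0.9, 1.5]]
gold      = [[0,   1,   1,   0],   [0,   0,   1,   0]]
for n, (rm_mean, gold_acc) in best_of_n_curve(rm_scores, gold, ns=[1, 2, 4]).items():
    print(n, round(rm_mean, 2), gold_acc)
```

In the toy data the mean RM score rises with every N while gold accuracy goes up and then back down, which is the pattern the exercise asks you to look for in real samples.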

DoD

  • train_rm.py runs end-to-end on CPU in < 5 minutes.
  • Held-out pair accuracy > 0.70.
  • Reward histograms saved as PNG in experiments/2026-05-23-rm-v0/.
  • Manifest persisted with seed, lr, eval_acc, lib versions.
  • Three "break it" exercises attempted with one-paragraph writeup of what each revealed.
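
The manifest item asks for library versions. How `lynx_cortex.utils.save_manifest` works internally is not shown, so this is a standalone sketch using only the standard library; `build_manifest` is a hypothetical helper and the package list is an assumption.

```python
from __future__ import annotations

import json
import platform
from importlib import metadata


def build_manifest(run: dict, packages: tuple[str, ...] = ("torch", "numpy", "pyyaml")) -> dict:
    """Merge run settings with Python and package versions (None if not installed)."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return {**run, "python": platform.python_version(), "lib_versions": versions}


# eval_acc is a placeholder here; fill it in from the training run.
manifest = build_manifest({"seed": 42, "epochs": 10, "lr": 1e-3, "eval_acc": None})
print(json.dumps(manifest, indent=2))
```

The resulting dict is plain JSON, so it can be passed to whatever writer the project uses.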