Lab 01: DPO on the grammar-tutor
Implementing DPO in ~50 lines on top of the Phase 28 LoRA tutor. Measures the win rate against the SFT baseline on a held-out test set.
Goal
Implement Direct Preference Optimization in ~50 lines, fine-tune the Phase 28 LoRA grammar-tutor on 150 of the 200 preference pairs from Lab 00, and measure win rate against the SFT baseline on the remaining 50 held-out pairs. Target: > 55% win rate.
Why this is short
DPO has no reward model, no rollouts, no value function, no advantage estimation (theory chapter 04). The training loop is structurally identical to SFT — just with a different loss. Once you have the data loader from Lab 00, the only new code is ~50 lines for the DPO loss and the eval.
The loss
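From theory chapter 04:
\[
\mathcal{L}_{\text{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[\log \sigma\big(\beta\,(\Delta_w - \Delta_l)\big)\right]
\]
where the per-pair log-ratio differences are
\[
\Delta_w = \log \pi_\theta(y_w \mid x) - \log \pi_{\text{ref}}(y_w \mid x), \qquad
\Delta_l = \log \pi_\theta(y_l \mid x) - \log \pi_{\text{ref}}(y_l \mid x).
\]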
The hyperparameter \(\beta = 0.1\) for this lab.
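As a quick numeric check, the sketch below evaluates the per-pair loss \(-\log\sigma(\beta(\Delta_w - \Delta_l))\) from four scalar sequence log-probabilities, using only the standard library. The function names and the example numbers are illustrative assumptions, not part of the lab code.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_pair_loss(pi_w: float, pi_l: float,
                  ref_w: float, ref_l: float,
                  beta: float = 0.1) -> float:
    """Per-pair DPO loss from sequence log-probs (hypothetical helper)."""
    delta_w = pi_w - ref_w   # log-ratio on the chosen response
    delta_l = pi_l - ref_l   # log-ratio on the rejected response
    return -log_sigmoid(beta * (delta_w - delta_l))


# Policy identical to the reference: both log-ratios are 0, loss is log 2.
init_loss = dpo_pair_loss(-12.0, -15.0, -12.0, -15.0)

# Policy moved chosen up by 2 nats and rejected down by 3 nats.
moved_loss = dpo_pair_loss(-10.0, -18.0, -12.0, -15.0)

print(f"init={init_loss:.4f} moved={moved_loss:.4f}")
```

With \(\beta = 0.1\) a 5-nat margin only moves the logit to 0.5, which is why the loss falls slowly from \(\log 2\) rather than collapsing to zero.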
What log π(y | x) means in code
For a sequence \(y = (y_1, \dots, y_T)\) and a prompt \(x\):
\[
\log \pi(y \mid x) = \sum_{t=1}^{T} \log \pi(y_t \mid x, y_{<t})
\]
i.e., the sum of the per-token log-probs of the response tokens only (not the prompt tokens). The implementation sums log-probs over a response mask.
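The masked sum is easy to get wrong by one position, because the log-prob at index `t` scores the token at index `t + 1`. The following sketch uses plain Python lists and made-up numbers to show the alignment; `response_logprob` is a hypothetical stand-in for the tensor version in the lab.

```python
def response_logprob(token_logp: list[float], response_mask: list[int]) -> float:
    """Sum log-probs of response tokens only.

    token_logp[t] is the log-prob of token t+1 given tokens 0..t, so it has
    length T-1, while response_mask has length T (one entry per token).
    """
    shifted_mask = response_mask[1:]          # align mask with the targets
    assert len(shifted_mask) == len(token_logp)
    return sum(lp * m for lp, m in zip(token_logp, shifted_mask))


# 3 prompt tokens followed by 2 response tokens (T = 5).
response_mask = [0, 0, 0, 1, 1]
# Log-probs of tokens 1..4 (illustrative values).
token_logp = [-2.0, -1.5, -0.5, -0.25]

correct = response_logprob(token_logp, response_mask)

# Off-by-one variant: dropping the LAST mask entry instead of the first
# skips the first response token.
buggy = sum(lp * m for lp, m in zip(token_logp, response_mask[:-1]))

print(correct, buggy)
```

The correct sum covers both response tokens; the misaligned variant silently drops the first one, which is the kind of bug that leaves the loss flat.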
Implementation
```python
# src/x3_rlhf/dpo.py
# DPO loss + training step. Total ~50 lines including utilities.
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


def sequence_logprob(
    model: nn.Module,
    input_ids: torch.Tensor,      # (B, T)
    response_mask: torch.Tensor,  # (B, T), 1 for response tokens, 0 elsewhere
) -> torch.Tensor:
    """Sum of log-probs over response tokens. Shape: (B,)."""
    # Logits at position t predict token t+1, so drop the last position.
    logits = model(input_ids).logits[:, :-1, :]                     # (B, T-1, V)
    targets = input_ids[:, 1:]                                      # (B, T-1)
    logp = F.log_softmax(logits, dim=-1)
    token_logp = logp.gather(2, targets.unsqueeze(-1)).squeeze(-1)  # (B, T-1)
    # Shift the mask the same way as the targets.
    mask = response_mask[:, 1:].float()
    return (token_logp * mask).sum(dim=-1)                          # (B,)


def dpo_loss(
    policy: nn.Module,
    reference: nn.Module,
    batch: dict[str, torch.Tensor],
    beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, float]]:
    """DPO loss on (chosen, rejected) pairs. Reference is frozen."""
    pi_w = sequence_logprob(policy, batch["chosen_ids"], batch["chosen_mask"])
    pi_l = sequence_logprob(policy, batch["rejected_ids"], batch["rejected_mask"])
    with torch.no_grad():
        ref_w = sequence_logprob(reference, batch["chosen_ids"], batch["chosen_mask"])
        ref_l = sequence_logprob(reference, batch["rejected_ids"], batch["rejected_mask"])

    delta_w = pi_w - ref_w  # implicit reward for chosen / β
    delta_l = pi_l - ref_l  # implicit reward for rejected / β
    logits = beta * (delta_w - delta_l)
    loss = -F.logsigmoid(logits).mean()

    # Diagnostics. The "kl_*" entries are mean log-ratios (delta_w, delta_l),
    # used as a cheap drift signal rather than a true KL divergence.
    metrics = {
        "loss": loss.item(),
        "implicit_r_w": (beta * delta_w).mean().item(),
        "implicit_r_l": (beta * delta_l).mean().item(),
        "pair_accuracy": (logits > 0).float().mean().item(),
        "kl_chosen": delta_w.mean().item(),
        "kl_rejected": delta_l.mean().item(),
    }
    return loss, metrics
```
Training script (~30 more lines, in scripts/train_dpo.py):
```python
import torch

from x3_rlhf.dpo import dpo_loss
from lynx_cortex.utils import seed_everything, save_manifest
from lynx_cortex.phase17 import load_sft_model
from lynx_cortex.phase28 import attach_lora
from lynx_cortex.data import load_pref_yaml, make_loader

seed_everything(42)

# Reference = frozen SFT model. Policy = SFT + trainable LoRA.
reference = load_sft_model("checkpoints/phase28-lora.pt").eval()
for p in reference.parameters():
    p.requires_grad = False

policy = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"),
                     r=8, alpha=16, dropout=0.05)
trainable = [p for p in policy.parameters() if p.requires_grad]

train_pairs, eval_pairs = load_pref_yaml(
    "data/preferences/grammar_v0.yaml", split=(150, 50)
)
train_loader = make_loader(train_pairs, batch_size=8, shuffle=True)
opt = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.0)

for epoch in range(3):
    for batch in train_loader:
        loss, m = dpo_loss(policy, reference, batch, beta=0.1)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
        opt.step()
    # Once per epoch: metrics of the last batch seen.
    print(f"loss={m['loss']:.4f} acc={m['pair_accuracy']:.3f} "
          f"kl_w={m['kl_chosen']:+.3f} kl_l={m['kl_rejected']:+.3f}")

# Save only the LoRA weights; the base model is unchanged.
torch.save({k: v for k, v in policy.state_dict().items() if "lora" in k},
           "checkpoints/dpo-lora-v0.pt")
save_manifest("experiments/2026-05-23-dpo-v0/manifest.json",
              {"seed": 42, "beta": 0.1, "epochs": 3, "lr": 1e-4})
```
That is the entire DPO training loop. Compare to what PPO-for-language requires (theory chapter 03).
Evaluation: win rate vs SFT baseline
For each of the 50 held-out prompts, generate one greedy response from the SFT baseline and one from the DPO model. A judge (here, the Lab 00 reward model) scores both responses; report the fraction of prompts on which the DPO response scores strictly higher.
```python
# scripts/eval_winrate.py
import torch

from x3_rlhf.reward_model import RewardModel
from lynx_cortex.data import load_pref_yaml
from lynx_cortex.phase17 import load_sft_model, generate
from lynx_cortex.phase28 import attach_lora, load_lora

sft = load_sft_model("checkpoints/phase28-lora.pt").eval()

# Same LoRA configuration as in train_dpo.py so the saved weights fit.
dpo = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"),
                  r=8, alpha=16, dropout=0.05)
load_lora(dpo, "checkpoints/dpo-lora-v0.pt")
dpo.eval()

# Judge: the Lab 00 reward model.
rm = RewardModel(load_sft_model("checkpoints/phase28-lora.pt"))
rm.head.load_state_dict(torch.load("checkpoints/rm-v0.pt"))
rm.eval()

_, eval_pairs = load_pref_yaml("data/preferences/grammar_v0.yaml", split=(150, 50))

wins = 0
for pair in eval_pairs:
    x = pair["prompt"]
    y_sft = generate(sft, x, max_new=32, temperature=0.0)
    y_dpo = generate(dpo, x, max_new=32, temperature=0.0)
    r_sft = rm.score(x, y_sft)
    r_dpo = rm.score(x, y_dpo)
    wins += int(r_dpo > r_sft)  # strict: a tie counts against DPO

print(f"DPO win rate: {wins / len(eval_pairs):.3f}")
```
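With only 50 held-out prompts, the win rate is a noisy estimate, so it can help to report an interval alongside the point value. The sketch below computes a Wilson score interval for a binomial proportion; `wilson_interval` and the example count of 28 wins are assumptions for illustration, not part of the lab scripts or a measured result.

```python
from __future__ import annotations

import math


def wilson_interval(wins: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a win rate (z = 1.96 gives roughly 95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = wins / n
    denom = 1.0 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return center - half, center + half


# Hypothetical outcome: 28 wins out of 50 prompts (win rate 0.56).
low, high = wilson_interval(28, 50)
print(f"win rate 0.560, approx. 95% interval [{low:.3f}, {high:.3f}]")
```

For this hypothetical count the interval still contains 0.5, so a win rate just above the 0.55 target on 50 prompts is suggestive rather than conclusive. Because the script uses a strict `>`, prompts where both models produce the same greedy output count as losses for DPO.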
Expected results
| Metric | Target | Notes |
|---|---|---|
| DPO train loss after 3 epochs | < 0.50 | Starts at \(\log 2 \approx 0.693\) |
| Train pair accuracy (final epoch) | > 0.85 | |
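| Mean log-ratio on chosen (\(\Delta_w\), logged as `kl_chosen`) | -0.5 to +1.5 | Ideally positive: the policy raises the probability of chosen responses. Slightly negative values are tolerable |
| Mean log-ratio on rejected (\(\Delta_l\), logged as `kl_rejected`) | -2.0 to -0.5 | Should be negative: the policy lowers the probability of rejected responses |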
| DPO win rate vs SFT (held-out 50 pairs) | > 0.55 | The headline metric |
If DPO win rate is ≤ 0.50, debug:
- Did the loss actually decrease? (Common bug: response mask off-by-one with shifted-tokens.)
- Are \(\Delta_w\) and \(\Delta_l\) moving in the right directions?
- Is the reference model actually frozen and in eval mode?
What to look at
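- Loss curve. Should drop from \(\sim 0.69\) (that is, \(\log 2\): with the usual zero-initialized LoRA update the policy starts equal to the reference, so every logit is 0) to \(\sim 0.3\)–\(0.4\) by epoch 3. Flat or rising means a bug.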
- Implicit reward gap \(\beta(\Delta_w - \Delta_l)\). Tracks the BT logit. Should rise; if it saturates fast, lower \(\beta\) or check for label noise.
- KL trajectory. \(\mathbb{E}[\Delta_w]\) and \(\mathbb{E}[\Delta_l]\) jointly. If both drift far negative, the policy is forgetting language — raise \(\beta\).
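The curves above are easier to read as per-epoch averages than as single-batch snapshots. Below is a minimal sketch of a running mean over the metric dictionaries returned by `dpo_loss`; `EpochMeter` is a hypothetical helper, not part of `lynx_cortex`, and the numbers fed to it are placeholders.

```python
from __future__ import annotations

import json
from collections import defaultdict


class EpochMeter:
    """Running mean of per-batch metric dicts (hypothetical helper)."""

    def __init__(self) -> None:
        self._sums: dict[str, float] = defaultdict(float)
        self._batches = 0

    def update(self, metrics: dict[str, float]) -> None:
        """Add one batch's metrics to the running totals."""
        for key, value in metrics.items():
            self._sums[key] += float(value)
        self._batches += 1

    def summary(self) -> dict[str, float]:
        """Unweighted mean over all batches seen so far."""
        if self._batches == 0:
            raise ValueError("no batches recorded")
        return {key: total / self._batches for key, total in self._sums.items()}


# Placeholder values standing in for two batches of one epoch.
meter = EpochMeter()
meter.update({"loss": 0.5, "pair_accuracy": 0.5})
meter.update({"loss": 0.25, "pair_accuracy": 1.0})
epoch_summary = meter.summary()
print(json.dumps({"epoch": 0, **epoch_summary}))
```

In the training loop this would mean creating one meter per epoch, calling `meter.update(m)` after each step, and writing `meter.summary()` to the experiment folder. The mean is unweighted, so a smaller final batch counts as much as a full one.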
Things to break
- Set \(\beta = 0\). The loss becomes \(-\log\sigma(0) = \log 2\) regardless of the policy; gradient is zero. Verify training stalls.
- Set \(\beta = 10\). Sigmoid saturates almost immediately; gradient vanishes for easy pairs; learning stalls in a different way.
- Swap chosen/rejected in 10% of training pairs (label noise). Watch train accuracy plateau lower; DPO is more sensitive to label noise than SFT because it relies on the ranking.
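- Use the policy itself as the reference (\(\pi_{\text{ref}} \leftarrow \pi_\theta\) at every step). The implicit reward terms become zero (up to dropout noise), so the logged loss stays at \(\log 2\). The gradient is not zero, though, because the reference pass runs under `no_grad`: the policy keeps moving with nothing anchoring it. Confirms the reference-model role.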
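The two \(\beta\) experiments above can be previewed on paper. For a single pair with margin \(m = \Delta_w - \Delta_l\), the loss \(-\log\sigma(\beta m)\) has gradient magnitude \(\beta\,\sigma(-\beta m)\) with respect to \(m\). The sketch below tabulates that factor; `grad_weight` is an illustrative helper and the margins are made-up values.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function; adequate for the moderate |z| used here."""
    return 1.0 / (1.0 + math.exp(-z))


def grad_weight(beta: float, margin: float) -> float:
    """|d/d margin| of -log(sigmoid(beta * margin))."""
    return beta * sigmoid(-beta * margin)


# Margin in nats: 0 = untrained pair, 1 = pair already ranked correctly.
for beta in (0.0, 0.1, 10.0):
    at_init = grad_weight(beta, 0.0)
    after_one_nat = grad_weight(beta, 1.0)
    print(f"beta={beta:<4} init={at_init:.4f} margin_1={after_one_nat:.6f}")
```

With \(\beta = 0\) the factor is exactly zero everywhere. With \(\beta = 10\) it starts large (5.0 at zero margin) and falls below \(10^{-3}\) once the margin reaches one nat, which matches the "saturates almost immediately" behavior. With \(\beta = 0.1\) it stays close to 0.05 over the same range.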
Cross-links
- Theory 04 — DPO and direct methods: the four-step derivation.
- Lab 00 — Reward Model: produces the RM used as the judge here.
- Phase 28 — LoRA / QLoRA: we train LoRA params only; base model frozen.
DoD
- `train_dpo.py` runs end-to-end on CPU in < 15 minutes.
- Loss decreases monotonically.
- DPO win rate vs SFT > 0.55.
- All four diagnostics (\(\Delta_w\), \(\Delta_l\), pair accuracy, loss) saved per epoch.
- Manifest persisted with seed, \(\beta\), lr, win rate.
- Two "break it" exercises attempted with one-paragraph writeups.