English · Español
Lab 01 — DPO sobre el tutor gramatical¶
🇪🇸 Implementación de DPO en ~50 líneas sobre el tutor LoRA de la Fase 28. Mide la tasa de victoria contra el baseline SFT en un test set retenido.
Objetivo¶
Implementa Direct Preference Optimization en ~50 líneas, ajusta finamente el tutor gramatical LoRA de la Fase 28 sobre el mismo dataset de 200 pares del Lab 00, y mide la tasa de victoria frente al baseline SFT en un test retenido de 50 pares. Objetivo: > 55% de tasa de victoria.
Por qué es corto¶
DPO no tiene modelo de recompensa, ni rollouts, ni función de valor, ni estimación de ventaja (capítulo 04 de teoría). El bucle de entrenamiento es estructuralmente idéntico al SFT — solo con una pérdida distinta. Una vez que tienes el data loader del Lab 00, el único código nuevo son ~50 líneas para la pérdida DPO y la evaluación.
La pérdida¶
Del capítulo 04 de teoría:
donde las diferencias de log-ratio por par son
El hiperparámetro \(\beta = 0{,}1\) para este lab.
Qué significa log π(y | x) en código¶
Para una secuencia \(y = (y_1, \dots, y_T)\) y un prompt \(x\):
es decir, la suma de las log-probs por token de los tokens de la respuesta solamente (no los tokens del prompt). La implementación suma las log-probs sobre una máscara de respuesta.
Implementación¶
# src/x3_rlhf/dpo.py
# DPO loss + training step. Total ~50 lines including utilities.
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
def sequence_logprob(
model: nn.Module,
input_ids: torch.Tensor, # (B, T)
response_mask: torch.Tensor, # (B, T), 1 for response tokens, 0 elsewhere
) -> torch.Tensor:
"""Sum of log-probs over response tokens. Shape: (B,)."""
logits = model(input_ids).logits[:, :-1, :] # (B, T-1, V)
targets = input_ids[:, 1:] # (B, T-1)
logp = F.log_softmax(logits, dim=-1)
token_logp = logp.gather(2, targets.unsqueeze(-1)).squeeze(-1) # (B, T-1)
mask = response_mask[:, 1:].float()
return (token_logp * mask).sum(dim=-1) # (B,)
def dpo_loss(
policy: nn.Module,
reference: nn.Module,
batch: dict[str, torch.Tensor],
beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, float]]:
"""DPO loss on (chosen, rejected) pairs. Reference is frozen."""
pi_w = sequence_logprob(policy, batch["chosen_ids"], batch["chosen_mask"])
pi_l = sequence_logprob(policy, batch["rejected_ids"], batch["rejected_mask"])
with torch.no_grad():
ref_w = sequence_logprob(reference, batch["chosen_ids"], batch["chosen_mask"])
ref_l = sequence_logprob(reference, batch["rejected_ids"], batch["rejected_mask"])
delta_w = pi_w - ref_w # implicit reward for chosen / β
delta_l = pi_l - ref_l # implicit reward for rejected / β
logits = beta * (delta_w - delta_l)
loss = -F.logsigmoid(logits).mean()
# Diagnostics
metrics = {
"loss": loss.item(),
"implicit_r_w": (beta * delta_w).mean().item(),
"implicit_r_l": (beta * delta_l).mean().item(),
"pair_accuracy": (logits > 0).float().mean().item(),
"kl_chosen": (pi_w - ref_w).mean().item(),
"kl_rejected": (pi_l - ref_l).mean().item(),
}
return loss, metrics
Script de entrenamiento (~30 líneas más, en scripts/train_dpo.py):
import copy, torch
from x3_rlhf.dpo import dpo_loss
from lynx_cortex.utils import seed_everything, save_manifest
from lynx_cortex.phase17 import load_sft_model
from lynx_cortex.phase28 import attach_lora
from lynx_cortex.data import load_pref_yaml, make_loader
seed_everything(42)
# Reference = frozen SFT model. Policy = SFT + trainable LoRA.
reference = load_sft_model("checkpoints/phase28-lora.pt").eval()
for p in reference.parameters():
p.requires_grad = False
policy = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"),
r=8, alpha=16, dropout=0.05)
trainable = [p for p in policy.parameters() if p.requires_grad]
train_pairs, eval_pairs = load_pref_yaml(
"data/preferences/grammar_v0.yaml", split=(150, 50)
)
train_loader = make_loader(train_pairs, batch_size=8, shuffle=True)
opt = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.0)
for epoch in range(3):
for batch in train_loader:
loss, m = dpo_loss(policy, reference, batch, beta=0.1)
opt.zero_grad(); loss.backward()
torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
opt.step()
print(f"loss={m['loss']:.4f} acc={m['pair_accuracy']:.3f} "
f"kl_w={m['kl_chosen']:+.3f} kl_l={m['kl_rejected']:+.3f}")
torch.save({k: v for k, v in policy.state_dict().items() if "lora" in k},
"checkpoints/dpo-lora-v0.pt")
save_manifest("experiments/2026-05-23-dpo-v0/manifest.json",
{"seed": 42, "beta": 0.1, "epochs": 3, "lr": 1e-4})
Ese es el bucle de entrenamiento DPO completo. Compáralo con lo que requiere PPO para lenguaje (capítulo 03 de teoría).
Evaluación: tasa de victoria vs baseline SFT¶
Para cada uno de los 50 prompts retenidos, genera una respuesta greedy del baseline SFT y otra del modelo DPO. Un juez (aquí, el modelo de recompensa del Lab 00) etiqueta cuál es mejor; reporta la fracción donde DPO gana.
# scripts/eval_winrate.py
from x3_rlhf.reward_model import RewardModel
from lynx_cortex.phase17 import load_sft_model, generate
from lynx_cortex.phase28 import attach_lora, load_lora
sft = load_sft_model("checkpoints/phase28-lora.pt").eval()
dpo = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"))
load_lora(dpo, "checkpoints/dpo-lora-v0.pt")
dpo.eval()
rm = RewardModel(load_sft_model("checkpoints/phase28-lora.pt"))
rm.head.load_state_dict(torch.load("checkpoints/rm-v0.pt"))
rm.eval()
_, eval_pairs = load_pref_yaml("data/preferences/grammar_v0.yaml", split=(150, 50))
wins = 0
for pair in eval_pairs:
x = pair["prompt"]
y_sft = generate(sft, x, max_new=32, temperature=0.0)
y_dpo = generate(dpo, x, max_new=32, temperature=0.0)
r_sft = rm.score(x, y_sft)
r_dpo = rm.score(x, y_dpo)
wins += int(r_dpo > r_sft)
print(f"DPO win rate: {wins / len(eval_pairs):.3f}")
Resultados esperados¶
| Métrica | Objetivo | Notas |
|---|---|---|
| Pérdida DPO de entrenamiento tras 3 épocas | < 0.50 | Empieza en \(\log 2 \approx 0.693\) |
| Precisión por par en entrenamiento (época final) | > 0.85 | |
| KL en elegidos (\(\Delta_w\), media) | -0.5 a +1.5 | Debe ser positivo — política prefiriendo los elegidos |
| KL en rechazados (\(\Delta_l\), media) | -2.0 a -0.5 | Debe ser negativo — política rechazando los rechazados |
| Tasa de victoria DPO vs SFT (50 pares retenidos) | > 0.55 | La métrica titular |
Si la tasa de victoria DPO es ≤ 0.50, depura:
- ¿La pérdida realmente decreció? (Bug común: máscara de respuesta con off-by-one en tokens shifteados.)
- ¿Se mueven \(\Delta_w\) y \(\Delta_l\) en las direcciones correctas?
- ¿Está el modelo de referencia realmente congelado y en modo eval?
Qué observar¶
- Curva de pérdida. Debe caer desde \(\sim 0.69\) (aleatoria) a \(\sim 0.3\)–\(0.4\) en la época 3. Plana o subiendo indica un bug.
- Brecha de recompensa implícita \(\beta(\Delta_w - \Delta_l)\). Sigue el logit BT. Debe subir; si se satura rápido, baja \(\beta\) o revisa ruido en etiquetas.
- Trayectoria KL. \(\mathbb{E}[\Delta_w]\) y \(\mathbb{E}[\Delta_l]\) conjuntamente. Si ambas derivan muy negativas, la política está olvidando el lenguaje — sube \(\beta\).
Cosas a romper¶
- Pon \(\beta = 0\). La pérdida queda \(-\log\sigma(0) = \log 2\) independientemente de la política; el gradiente es cero. Verifica que el entrenamiento se atasca.
- Pon \(\beta = 10\). La sigmoide se satura casi de inmediato; el gradiente se desvanece para pares fáciles; el aprendizaje se atasca de otra forma.
- Intercambia chosen/rejected en el 10% de los pares de entrenamiento (ruido en etiquetas). Observa cómo la precisión de entrenamiento se estanca en un valor más bajo; DPO es más sensible al ruido en etiquetas que el SFT porque depende del ranking.
- Usa la propia política como referencia (\(\pi_{\text{ref}} \leftarrow \pi_\theta\) en cada paso). Los términos de recompensa implícita se hacen cero; la pérdida es constante. Confirma el rol del modelo de referencia.
Enlaces cruzados¶
- Teoría 04 — DPO y métodos directos: la derivación en cuatro pasos.
- Lab 00 — Modelo de recompensa: produce el RM usado como juez aquí.
- Fase 28 — LoRA / QLoRA: entrenamos solo los parámetros LoRA; modelo base congelado.
DoD¶
-
train_dpo.pycorre de extremo a extremo en CPU en < 15 minutos. - La pérdida decrece monotónicamente.
- Tasa de victoria DPO vs SFT > 0.55.
- Los cuatro diagnósticos (\(\Delta_w\), \(\Delta_l\), precisión por par, pérdida) guardados por época.
- Manifest persistido con seed, \(\beta\), lr, tasa de victoria.
- Dos ejercicios "break it" intentados con párrafo de resumen.