Skip to content

English · Español

03 — Tied embeddings y la cabeza LM

🇪🇸 La cabeza del modelo lingüístico reutiliza la misma matriz que el embedding de entrada. Un truco de un solo carácter en el código que ahorra \(|V| \cdot d_\text{model}\) parámetros y unifica la entrada y salida en un mismo espacio. Aquí no salvamos mucha memoria — el corpus es pequeño — pero entendemos el principio.

La cabeza LM

Tras el LayerNorm final, el residual stream queda en \(\mathbb{R}^{T \times d_\text{model}}\). Para convertir esto en una distribución de probabilidad sobre los tokens del vocabulario en cada posición, aplicamos una proyección lineal al tamaño del vocabulario, luego softmax:

\[\text{logits}_t = h_t \cdot W_\text{LM}^\top \quad \in \mathbb{R}^{|V|}$$ $$p_t = \text{softmax}(\text{logits}_t)\]

donde \(W_\text{LM} \in \mathbb{R}^{|V| \times d_\text{model}}\). Ingenuamente, esto es una matriz aprendible nueva con \(|V| \cdot d_\text{model}\) parámetros.

Pero fíjate: el embedding de entrada \(E \in \mathbb{R}^{|V| \times d_\text{model}}\) tiene la misma forma. Un vector "embebe" un token id; la transpuesta "desembebe" un residual de vuelta a un token id. Las dos operaciones son inversas la una de la otra en un sentido significativo.

Atado de pesos (weight tying)

Tied embeddings (Press & Wolf 2017, "Using the Output Embedding to Improve Language Models"; concurrente en Inan et al. 2017): fijar \(W_\text{LM} = E\). La matriz de embedding de entrada es la matriz de proyección de salida. El forward queda:

\[\text{logits}_t = h_t \cdot E^\top\]

Eso es todo — la misma matriz, usada dos veces. Patrón:

class MiniGPT:
    def __init__(self, vocab_size, d_model, ...):
        self.E = Parameter(np.random.randn(vocab_size, d_model) * 0.02)
        # ... bloques, LNs, etc.
        # NO self.W_LM. La cabeza LM reutiliza self.E.

    def forward(self, tokens):
        h = self.E[tokens]                # (T, d_model)  — token embed
        for block in self.blocks:
            h = block(h)
        h = self.ln_final(h)
        logits = h @ self.E.T             # (T, vocab_size) — cabeza LM, atada
        return logits

Dos beneficios

1. Ahorro de parámetros

Atar elimina una matriz \(|V| \cdot d_\text{model}\). Para Mini-GPT (\(|V| = 64, d_\text{model} = 64\)), eso son 4096 parámetros — poco. Pero para GPT-2 (\(|V| = 50257, d_\text{model} = 768\)), son 38,5M de parámetros — de un total de 124M, ~31% del modelo. A escala LLaMA-2 (\(|V| = 32000, d_\text{model} = 4096\)), son 131M de parámetros. Los ahorros son sustanciales a escala real.

Para Mini-GPT, los ahorros son pedagógicos, no prácticos. Atamos de todas formas porque:

  • Enseña el principio.
  • Hace significativa la dirección de la matriz de embedding — las direcciones que mejoran el embedding de entrada también mejoran la proyección de salida.
  • Es lo que hace cada transformer moderno.

2. Simetría conceptual

El atado dice: "la misma noción de qué es un token (el embedding) determina lo que un estado oculto significa sobre un token (el unembedding)". Hay un único espacio de vocabulario, usado dos veces. Esto es satisfactorio conceptualmente y resulta ser cierto empíricamente — los embeddings de entrada y salida entrenados de forma independiente acaban cerca entre ellos de todos modos. Atar simplemente impone ese prior.

La interpretabilidad mecanística descansa en esto: la dirección de unembed para el token \(w\) es la fila \(E[w]\), y puedes preguntar "¿qué direcciones residuales apuntan a \(E[w]\)?" — esta es la técnica de "logit lens" (nostalgebraist 2020). El atado hace la lente bien definida.

El forward completo de Mini-GPT, con cabeza atada

tokens (T,)   →   E[tokens]              shape (T, d_model)
             block_0                     shape (T, d_model)
             block_1                     shape (T, d_model)
             LN_final                    shape (T, d_model)
             @ E.T                       shape (T, vocab_size)
             logits → (softmax en la pérdida; no forma parte del modelo propiamente)

¿Por qué no hay softmax dentro de MiniGPT.forward?

El forward devuelve logits, no probabilidades. El softmax ocurre:

  • Dentro de la función de pérdida (Fase 18) — fusionado con la pérdida por estabilidad numérica (el truco cross_entropy_from_logits de la Fase 05).
  • Dentro del muestreo (Fase 21) — posiblemente con escalado de temperatura.

Desacoplar forward del softmax significa:

  • Estabilidad numérica: la pérdida puede usar el truco log-sum-exp en lugar de calcular log(softmax(...)) directamente.
  • Flexibilidad: en inferencia, puedes aplicar temperatura, top-k, top-p sin recomputar.

.forward() de Mini-GPT devuelve logits: (T, V). Lo que venga después elige qué hacer con ellos.

Un punto sutil: flujo de gradiente con atado

Cuando atas \(W_\text{LM} = E\), el gradiente \(\partial \mathcal{L} / \partial E\) tiene dos contribuciones: una del lookup del embedding de entrada, otra de la proyección de salida. El autograd lo gestiona automáticamente si cableas \(E\) como un único Parameter referenciado dos veces — exactamente lo que hace el código de arriba. Si por accidente haces dos copias (p. ej., self.W_LM = E.copy()), las has desatado y has perdido la propiedad.

La Fase 18 examinará el flujo de gradiente hacia el \(E\) atado como comprobación de cordura; las dos contribuciones se suman y la dirección de actualización resultante es significativa.

Inicialización

Para tied embeddings, inicializa \(E\) una vez con la inicialización estilo embedding (gaussiana pequeña, típicamente \(\mathcal{N}(0, 0{,}02^2)\)). La proyección de salida hereda esta inicialización — no hace falta inicialización separada. Este es el default de GPT-2.

Para cabezas LM no atadas, típicamente inicializarías la cabeza con la misma escala, ya que efectivamente es otra capa lineal.

El softmax final — dónde vive en inferencia

En la Fase 21 (muestreo), los logits del modelo se convierten en probabilidades mediante:

\[q_t = \text{softmax}(\text{logits}_t / \tau)\]

donde \(\tau\) es la temperatura (\(\tau = 1\) es "crudo", \(\tau < 1\) afila, \(\tau > 1\) aplana). La Fase 21 cubre temperatura, top-k, top-p, nucleus. Nada de esto forma parte de la Fase 17. El trabajo de la Fase 17 termina en los logits.

Qué NO cubre este archivo

  • Estrategias de muestreo. Fase 21.
  • La pérdida de entropía cruzada usando logits. Fase 18 (con el truco de estabilidad numérica de la Fase 05).
  • Cabezas LM no atadas. Mencionadas por completitud; no se usan.

Siguiente: ../lab/00-block-by-hand.md