English · Español
05 — Matemáticas del routing en MoE + pérdida de balanceo de carga; intuición de memoria constante en Mamba¶
🇪🇸 Mixture-of-Experts es un router + N expertos: cada token elige top-k expertos. Si el router colapsa (todo a un experto), pierdes la ventaja entera. El "auxiliary loss" es un regulador suave que mantiene los expertos cargados de forma pareja. Mamba/SSM: la memoria por paso es constante en la longitud de secuencia porque mantiene un estado oculto de tamaño fijo, no una caché que crece.
Parte 1 — Routing de Mixture-of-Experts¶
La arquitectura en una frase¶
Una capa MoE (Mixture of Experts) reemplaza un único bloque feed-forward por \(E\) experts paralelos más un router que, para cada token, elige los top-\(k\) experts a evaluar (típicamente \(k = 1\) o \(k = 2\)). El router es un lineal diminuto: gates = softmax(W_r · x), shape (seq_len, E).
Para cada token \(t\) con hidden \(x_t\), la salida de la capa es:
La ganancia: solo \(k\) de \(E\) experts corren por token, así que el cómputo es \(k/E\) del denso, mientras que el conteo de parámetros es \(\sim E \cdot\) denso. Más parámetros, FLOPs similares. Ese intercambio es el juego entero de los MoE de frontera (Mixtral, Switch, GShard).
Las matemáticas del routing, explícitas¶
Para tamaño de batch \(B\) y longitud de secuencia \(L\), con \(T = B \cdot L\) tokens:
- Logits. \(G \in \mathbb{R}^{T \times E} = X W_r\) donde \(W_r \in \mathbb{R}^{d \times E}\).
- Softmax del gate. \(\hat{G} = \text{softmax}(G)\) a lo largo del eje de experts.
- Máscara top-k. Para cada fila \(t\), conservar las \(k\) entradas más grandes; poner el resto a cero. Renormalizar (típicamente no re-softmax, solo reescalar para que sumen 1 sobre las entradas conservadas; esta es la convención "Switch", pero hay variantes).
- Dispatch. Agrupar tokens por expert: el expert \(e\) recibe todos los tokens cuya máscara no es cero en la columna \(e\).
- Procesar. Cada expert corre su FFN sobre sus tokens asignados.
- Combinar. Scatter de las salidas hacia atrás, ponderar por el valor del gate, sumar.
El par dispatch/combine es donde vive la ingeniería de sistemas: implementarlo como una permutación + ungather es el colectivo MoE_all-to-all estándar en los MoE distribuidos (Switch Transformer §4).
La patología: colapso del router¶
El router es una función aprendida. No hay nada estructural que le impida aprender a elegir siempre el expert 0. Si eso ocurre:
- Un expert ve el 100% de los tokens (desborda su capacidad, puede dropear tokens).
- \(E - 1\) experts ven 0 tokens, reciben gradiente cero, nunca entrenan.
- El conteo efectivo de parámetros colapsa al equivalente de una sola FFN densa.
Esto es colapso del router y es el modo de fallo que existe la pérdida auxiliar para prevenir.
La pérdida auxiliar¶
Para cada expert \(e\), definir dos estadísticas sobre el batch:
- \(f_e\) = fracción de tokens routeados a \(e\) vía top-k (una medida sobre máscara discreta).
- \(P_e\) = media de \(\hat{G}_{:,e}\) a lo largo del batch (una medida soft del gate).
La pérdida auxiliar del Switch Transformer es:
Intuitivamente, \(f_e \cdot P_e\) se minimiza cuando ambos términos son iguales a \(1/E\) (routing uniforme), y la suma \(\sum_e f_e P_e \ge 1/E\) por Cauchy-Schwarz con igualdad en uniforme. El coeficiente \(\alpha\) es pequeño (~0.01); justo lo suficiente para romper el colapso del router sin dominar la pérdida principal.
¿Por qué tanto \(f\) como \(P\)? \(P\) solo es diferenciable (softmax del gate), pero es gameable por el modelo — puede mantener \(P\) uniforme mientras \(f\) se mantiene puntiagudo. \(f\) solo es no diferenciable (es un indicador top-k). El producto penaliza la correlación entre asignación de expert y confianza del gate — exactamente el modo de fallo del colapso del router.
Capacidad de expert y tokens dropeados¶
Cada expert tiene una capacidad finita \(C = \lceil \kappa \cdot T \cdot k / E \rceil\) donde \(\kappa\) es el "capacity factor" (~1.0 a 1.25). Si se routean más de \(C\) tokens al expert \(e\), los sobrantes se dropean — se saltan esta capa MoE (la conexión residual sigue preservando el input). Esto evita latencia ilimitada de straggler en el dispatch distribuido.
La capacidad es una restricción de sistema, no aprendida. Poner \(\kappa\) demasiado bajo pierde tokens; demasiado alto desperdicia memoria. La receta estándar es \(\kappa = 1.25\) para entrenamiento, \(\kappa = 2.0\) para inferencia (sin tolerancia a stragglers).
Intuición a escala §A13 (no es un MoE real)¶
No añadimos un MoE real al tutor de gramática (el modelo tiene 500k parámetros; la arquitectura estaría sobreingeniería). Pero el lab 00-moe-on-grammar-tutor.md recorre un MoE de juguete de 4 experts sobre el corpus de 600 formas para hacer las matemáticas del routing viscerales. El modo de fallo del colapso del router es el ejercicio /break.
Parte 2 — Mamba / SSMs y la afirmación de memoria constante¶
Por qué la attention escala O(N²) en memoria¶
La self-attention computa una matriz de attention \(N \times N\); incluso con la KV cache que convierte la inferencia en \(O(N)\) en cómputo por paso, la cache en sí crece linealmente: en el paso \(n\) almacenas \(n\) keys y \(n\) values por capa por head. Para contextos largos, la cache domina la memoria de GPU.
Modelos de espacio de estados en un diagrama¶
Un modelo de espacio de estados (SSM) mantiene un estado oculto de tamaño fijo \(h_t \in \mathbb{R}^{d_\text{state}}\) y lo evoluciona vía una recurrencia lineal:
La salida en el paso \(t\) depende solo de \(x_t\) y \(h_{t-1}\) — sin historia más allá del estado actual. Coste de memoria: \(O(d_\text{state})\) independientemente de la longitud de la secuencia. Esa es la afirmación de memoria constante.
¿Qué les pasaba a las RNN clásicas?¶
Las RNN clásicas tenían la misma forma recurrente pero: (a) la recurrencia era no lineal (tanh), por lo que no podía paralelizarse en el tiempo; (b) la matriz \(A\) no tenía estructura, así que las dependencias a largo plazo se desvanecían/explotaban.
Mamba (y S4) arreglan ambas:
- La recurrencia es lineal; la aplicación secuencial es matemáticamente equivalente a una convolución, que tiene algoritmos paralelos rápidos (scan asociativo, FFT para la forma de coeficientes fijos).
- \(A\) se parametriza para que sus eigenvalues se porten bien (inicialización HiPPO, estructura diagonal-más-bajo-rango de S4). Este es el "selective scan" que hace Mamba — y la matemática es lo que hace que las dependencias de largo alcance realmente entrenen.
Comparación de memoria de inferencia¶
Para un modelo con \(L\) capas, \(d_\text{state} = 16\) (default de Mamba), \(d_\text{model} = 1024\):
| Arquitectura | Memoria por paso (cache) | A seq_len = 8192 (bytes, fp16) |
|---|---|---|
| Transformer (KV) | \(L \cdot 2 \cdot d_\text{model} \cdot n\) tokens | \(L \cdot 2 \cdot 1024 \cdot 8192 \cdot 2 = 33L\) MB |
| Mamba (estado SSM) | \(L \cdot d_\text{state} \cdot d_\text{model}\) | \(L \cdot 16 \cdot 1024 \cdot 2 = 32L\) KB |
Tres órdenes de magnitud de diferencia a esta longitud de contexto. Eso es lo que compra "memoria constante en la longitud de secuencia".
¿Qué cede Mamba?¶
No es un free lunch:
- El acceso aleatorio a la historia se va. Un transformer puede atender a la posición 17 en detalle; Mamba solo puede ver lo que el estado oculto haya codificado de la posición 17. Para tareas que requieren recall preciso (needle-in-haystack), Mamba rinde peor; para tareas que solo necesitan contexto comprimido (modelado de lenguaje), iguala o supera a los transformers por FLOP.
- El in-context learning es más débil. El few-shot prompting depende en parte de la capacidad de recall preciso que tienen los transformers.
- Las arquitecturas híbridas (Mamba + capas de attention, p. ej. Jamba) conservan lo mejor de ambos — la mayoría de las capas son Mamba (baratas), unas pocas son attention (recall preciso).
Intuición a escala §A13¶
Para el tutor de gramática, los contextos son ≤ 64 tokens; el argumento de memoria constante es irrelevante. Lo cubrimos igualmente porque la decisión arquitectónica (cuándo usar qué) es parte del vocabulario del ingeniero, y el lab del survey recorre un pequeño bloque de Mamba para hacer concreta la recurrencia.
Lo que este capítulo NO cubre¶
- Implementaciones de kernel all-to-all de MoE (Tutel, FasterMoE). Territorio de producción / X1.
- MLA de DeepSeek (Multi-head Latent Attention). Cubierto por separado en
02-mla.md. - La teoría HiPPO detrás de la matriz \(A\) de S4. La derivación cerrada vive en el paper de S4.
- Scheduling híbrido Mamba+Attention. El paper de Jamba lo cubre.
Referencias¶
- Fedus et al., "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" (JMLR 2022). La fórmula de la pérdida auxiliar y la disciplina del capacity factor vienen de aquí.
- Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (2023). La derivación del selective-scan que hace competitivos a los SSM.
- Lieber et al., "Jamba: A Hybrid Transformer-Mamba Language Model" (2024). La receta Mamba+attention en la práctica.
Siguiente: ../lab/00-moe-on-grammar-tutor.md o ../lab/02-mamba-walkthrough.md.