Skip to content

English · Español

Lab 02 — Static batching

🇪🇸 Junta peticiones en una cola. Cuando llegues a N (o pase un timeout), ejecuta el modelo sobre todo el batch en una sola pasada. Throughput sube; latencia de la última petición del batch sube también. Mídelo.

Objetivo

Implementar un scheduler de static batching: recoger peticiones, ejecutar el modelo una vez por batch. Medir throughput y tail latency vs el baseline sin batching del lab 01.

Setup

  • Servicio FastAPI funcionando del lab 01 (variante C: async + to_thread).
  • src/miniserve/scheduler.py — módulo nuevo.
  • El script de loadtest del lab 01.

Tareas

  1. Modifica la API del modelo para aceptar input batched. El forward() del Mini-GPT ya soporta una dimensión de batch (Fase 17). Añade agent.correct_batch(sentences: list[str]) -> list[Correction]:
class GrammarTutorAgent:
    def correct_batch(self, sentences: list[str], learner_ids: list[str | None]) -> list[Correction]:
        # Run the agent loop for each sentence with a single batched model.forward.
        # For now: simple — generate all responses to max_tokens with the same generation length.
        ...

Por simplicidad de la Fase 33, el agente generará a un max_tokens fijo para cada miembro del batch. Esto es intencional — hace visible el problema de tail latency del static batching.

  1. Escribe el scheduler en src/miniserve/scheduler.py:
import asyncio
from dataclasses import dataclass
from typing import Callable

@dataclass
class PendingRequest:
    payload: dict
    future: asyncio.Future

class StaticBatchScheduler:
    def __init__(self, batch_fn: Callable, max_batch: int, max_wait_ms: int):
        self.batch_fn = batch_fn
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        self.queue: asyncio.Queue[PendingRequest] = asyncio.Queue()
        self._loop_task: asyncio.Task | None = None

    async def submit(self, payload: dict) -> dict:
        fut = asyncio.get_event_loop().create_future()
        await self.queue.put(PendingRequest(payload, fut))
        return await fut

    async def start(self):
        self._loop_task = asyncio.create_task(self._loop())

    async def _loop(self):
        while True:
            batch = await self._collect_batch()
            if not batch:
                await asyncio.sleep(0.001)
                continue
            # Run model in thread (CPU-bound)
            results = await asyncio.to_thread(
                self.batch_fn, [r.payload for r in batch]
            )
            for r, res in zip(batch, results):
                r.future.set_result(res)

    async def _collect_batch(self) -> list[PendingRequest]:
        batch = []
        try:
            first = await asyncio.wait_for(self.queue.get(), timeout=1.0)
            batch.append(first)
        except asyncio.TimeoutError:
            return batch
        deadline = asyncio.get_event_loop().time() + self.max_wait_ms / 1000
        while len(batch) < self.max_batch:
            remaining = deadline - asyncio.get_event_loop().time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self.queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch
  1. Conéctalo a la app FastAPI:
scheduler = StaticBatchScheduler(
    batch_fn=lambda payloads: agent.correct_batch(
        [p["sentence"] for p in payloads],
        [p.get("learner_id") for p in payloads],
    ),
    max_batch=8,
    max_wait_ms=20,
)

@app.on_event("startup")  # or use lifespan
async def _start():
    await scheduler.start()

@app.post("/correct")
async def correct(req: CorrectRequest) -> CorrectResponse:
    result = await scheduler.submit(req.model_dump())
    return CorrectResponse(**result)
  1. Barre los parámetros del batch. Para cada (max_batch, max_wait_ms) ∈ {(1,0), (2,10), (4,10), (8,20), (16,20), (32,50)}:
  2. Ejecuta loadtest con concurrency=50, total=500.
  3. Registra p50, p95, p99, throughput.

  4. Dibuja los resultados.

  5. Eje X: max_batch. Eje Y: throughput (req/s). Una línea por max_wait_ms.
  6. Eje X: max_batch. Eje Y: p95 latency. Mismo setup de líneas.
  7. Deberías ver el throughput subir y la p95 latency también subir — el trade-off.

  8. Compara con el baseline del lab 01. En los mismos ejes, dibuja el resultado sin batching (variante C) como un punto único.

Mediciones

Guarda en experiments/<date>-phase-33-lab-02/:

  • batch_sweep.csv — una fila por (max_batch, max_wait_ms): p50, p95, p99, throughput.
  • throughput_vs_batch.png
  • p95_vs_batch.png
  • latency_cdf_baseline_vs_batch8.png — sin batching vs max_batch=8.
  • manifest.json.

Aceptación

  • Para max_batch ≥ 4, el throughput supera el baseline sin batching en ≥ 50%.
  • Para max_batch ≥ 4, la p95 latency se degrada vs el baseline sin batching (este es el trade-off esperado, no una regresión).
  • La curva de throughput se satura en algún punto del barrido — aumentar max_batch más allá de cierto punto no ayuda.
  • Todas las peticiones devuelven HTTP 200 (sin timeouts) bajo la carga.

Trampas

  • Ejecutar el modelo con asyncio.to_thread pero dentro del _loop del scheduler. Esto es correcto — el bucle debe ceder mientras corre el modelo. Si llamas a self.batch_fn(...) directamente (no vía to_thread), el bucle se bloquea y ninguna otra petición puede encolarse durante el forward pass.
  • Poner max_wait_ms demasiado alto. Si esperas 100 ms para llenar el batch, la primera petición de cada batch se come 100 ms de espera pura en cola. Encuentra el valor correcto experimentalmente.
  • Olvidar el padding del batch. Todas las secuencias en un batch deben tener la misma longitud (o tienes que enmascarar). En el lab 02 esquivamos esto generando a max_tokens fijo. El desperdicio de padding es el coste.
  • Crecimiento de memoria bajo carga. Si la cola se llena, vas a OOM. Para el lab 02, acota la cola: en submit(), si queue.qsize() > MAX_QUEUE, devuelve 503.
  • Ruido de medición. Ejecuta cada config 3 veces, reporta la mediana de la p95.

Stretch

  • Añade un histograma de "cómo de lleno estaba cada batch" — muchos batches serán de tamaño 1 a baja carga, tamaño N a alta carga. Esto explica la saturación.
  • ¿Qué pasa si pones max_batch=1? Debería comportarse idénticamente al baseline sin batching (módulo el overhead del scheduler). Verifícalo.

Siguiente: 03-continuous-batching.md