English · Español
Lab 02 — Static batching¶
🇪🇸 Junta peticiones en una cola. Cuando llegues a N (o pase un timeout), ejecuta el modelo sobre todo el batch en una sola pasada. Throughput sube; latencia de la última petición del batch sube también. Mídelo.
Objetivo¶
Implementar un scheduler de static batching: recoger peticiones, ejecutar el modelo una vez por batch. Medir throughput y tail latency vs el baseline sin batching del lab 01.
Setup¶
- Servicio FastAPI funcionando del lab 01 (variante C: async + to_thread).
src/miniserve/scheduler.py— módulo nuevo.- El script de loadtest del lab 01.
Tareas¶
- Modifica la API del modelo para aceptar input batched. El
forward()del Mini-GPT ya soporta una dimensión de batch (Fase 17). Añadeagent.correct_batch(sentences: list[str]) -> list[Correction]:
class GrammarTutorAgent:
def correct_batch(self, sentences: list[str], learner_ids: list[str | None]) -> list[Correction]:
# Run the agent loop for each sentence with a single batched model.forward.
# For now: simple — generate all responses to max_tokens with the same generation length.
...
Por simplicidad de la Fase 33, el agente generará a un max_tokens fijo para cada miembro del batch. Esto es intencional — hace visible el problema de tail latency del static batching.
- Escribe el scheduler en
src/miniserve/scheduler.py:
import asyncio
from dataclasses import dataclass
from typing import Callable
@dataclass
class PendingRequest:
payload: dict
future: asyncio.Future
class StaticBatchScheduler:
def __init__(self, batch_fn: Callable, max_batch: int, max_wait_ms: int):
self.batch_fn = batch_fn
self.max_batch = max_batch
self.max_wait_ms = max_wait_ms
self.queue: asyncio.Queue[PendingRequest] = asyncio.Queue()
self._loop_task: asyncio.Task | None = None
async def submit(self, payload: dict) -> dict:
fut = asyncio.get_event_loop().create_future()
await self.queue.put(PendingRequest(payload, fut))
return await fut
async def start(self):
self._loop_task = asyncio.create_task(self._loop())
async def _loop(self):
while True:
batch = await self._collect_batch()
if not batch:
await asyncio.sleep(0.001)
continue
# Run model in thread (CPU-bound)
results = await asyncio.to_thread(
self.batch_fn, [r.payload for r in batch]
)
for r, res in zip(batch, results):
r.future.set_result(res)
async def _collect_batch(self) -> list[PendingRequest]:
batch = []
try:
first = await asyncio.wait_for(self.queue.get(), timeout=1.0)
batch.append(first)
except asyncio.TimeoutError:
return batch
deadline = asyncio.get_event_loop().time() + self.max_wait_ms / 1000
while len(batch) < self.max_batch:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
break
try:
batch.append(await asyncio.wait_for(self.queue.get(), timeout=remaining))
except asyncio.TimeoutError:
break
return batch
- Conéctalo a la app FastAPI:
scheduler = StaticBatchScheduler(
batch_fn=lambda payloads: agent.correct_batch(
[p["sentence"] for p in payloads],
[p.get("learner_id") for p in payloads],
),
max_batch=8,
max_wait_ms=20,
)
@app.on_event("startup") # or use lifespan
async def _start():
await scheduler.start()
@app.post("/correct")
async def correct(req: CorrectRequest) -> CorrectResponse:
result = await scheduler.submit(req.model_dump())
return CorrectResponse(**result)
- Barre los parámetros del batch. Para cada
(max_batch, max_wait_ms) ∈ {(1,0), (2,10), (4,10), (8,20), (16,20), (32,50)}: - Ejecuta loadtest con
concurrency=50, total=500. -
Registra p50, p95, p99, throughput.
-
Dibuja los resultados.
- Eje X:
max_batch. Eje Y: throughput (req/s). Una línea pormax_wait_ms. - Eje X:
max_batch. Eje Y: p95 latency. Mismo setup de líneas. -
Deberías ver el throughput subir y la p95 latency también subir — el trade-off.
-
Compara con el baseline del lab 01. En los mismos ejes, dibuja el resultado sin batching (variante C) como un punto único.
Mediciones¶
Guarda en experiments/<date>-phase-33-lab-02/:
batch_sweep.csv— una fila por(max_batch, max_wait_ms): p50, p95, p99, throughput.throughput_vs_batch.pngp95_vs_batch.pnglatency_cdf_baseline_vs_batch8.png— sin batching vsmax_batch=8.manifest.json.
Aceptación¶
- Para
max_batch ≥ 4, el throughput supera el baseline sin batching en ≥ 50%. - Para
max_batch ≥ 4, la p95 latency se degrada vs el baseline sin batching (este es el trade-off esperado, no una regresión). - La curva de throughput se satura en algún punto del barrido — aumentar
max_batchmás allá de cierto punto no ayuda. - Todas las peticiones devuelven HTTP 200 (sin timeouts) bajo la carga.
Trampas¶
- Ejecutar el modelo con
asyncio.to_threadpero dentro del_loopdel scheduler. Esto es correcto — el bucle debe ceder mientras corre el modelo. Si llamas aself.batch_fn(...)directamente (no víato_thread), el bucle se bloquea y ninguna otra petición puede encolarse durante el forward pass. - Poner
max_wait_msdemasiado alto. Si esperas 100 ms para llenar el batch, la primera petición de cada batch se come 100 ms de espera pura en cola. Encuentra el valor correcto experimentalmente. - Olvidar el padding del batch. Todas las secuencias en un batch deben tener la misma longitud (o tienes que enmascarar). En el lab 02 esquivamos esto generando a
max_tokensfijo. El desperdicio de padding es el coste. - Crecimiento de memoria bajo carga. Si la cola se llena, vas a OOM. Para el lab 02, acota la cola: en
submit(), siqueue.qsize() > MAX_QUEUE, devuelve 503. - Ruido de medición. Ejecuta cada config 3 veces, reporta la mediana de la p95.
Stretch¶
- Añade un histograma de "cómo de lleno estaba cada batch" — muchos batches serán de tamaño 1 a baja carga, tamaño N a alta carga. Esto explica la saturación.
- ¿Qué pasa si pones
max_batch=1? Debería comportarse idénticamente al baseline sin batching (módulo el overhead del scheduler). Verifícalo.
Siguiente: 03-continuous-batching.md