Skip to content

English · Español

Lab 03 — Continuous batching: scheduling a nivel de iteración

🇪🇸 La diferencia con static batching: en lugar de "ejecuta todo el decode hasta que todos terminen", ejecuta un paso del decode por iteración. Las peticiones rápidas salen pronto; las nuevas se unen entre pasos. Es la técnica que usa vLLM, TGI, Triton.

Objetivo

Reemplazar el static batcher del lab 02 con un batcher continuous (a nivel de iteración). Mostrar que la p95 latency cae ≥ 30% en una workload de longitud mixta, con throughput igual o mejor.

Setup

  • Scheduler del lab 02 (src/miniserve/scheduler.py).
  • KV cache de la Fase 22 (src/minimodel/kv_cache.py) — estrictamente requerido: continuous batching sin KV cache hace que cada paso sea cuadrático. Si la Fase 22 todavía no está hecha, posponer este lab.
  • Una workload de longitud mixta: prompts donde algunas respuestas son de 1-2 tokens, otras de 8-12 tokens.

Tareas

  1. Extiende el agente con una API por pasos:
class GrammarTutorAgent:
    def begin_correction(self, sentence: str, learner_id: str | None) -> "InflightCorrection":
        """Run the prefill, return a handle with the initial KV cache and first token."""
        ...

    def step_corrections(self, inflight: list["InflightCorrection"]) -> list[tuple[int, bool]]:
        """One decode step on a batch of in-flight corrections. Returns [(new_token_id, is_done), ...]"""
        ...

Esbozo de implementación para step_corrections: - Recoger el "último token" de cada in-flight → batch de tamaño \(B\). - Ejecutar model.forward_one_step(tokens, kv_caches) — un paso de decode. - Para cada uno: muestrear el siguiente token, comprobar si es EOS o alcanza max_tokens. - Añadir el nuevo token al estado de cada in-flight.

  1. Escribe el continuous batcher en src/miniserve/scheduler.py:
class ContinuousBatchScheduler:
    def __init__(self, agent, max_inflight: int, max_queue: int):
        self.agent = agent
        self.max_inflight = max_inflight
        self.max_queue = max_queue
        self.ready: asyncio.Queue[PendingRequest] = asyncio.Queue(maxsize=max_queue)
        self.inflight: list[InflightCorrection] = []

    async def submit(self, payload):
        fut = asyncio.get_event_loop().create_future()
        try:
            self.ready.put_nowait(PendingRequest(payload, fut))
        except asyncio.QueueFull:
            raise HTTPException(503, "Server queue full")
        return await fut

    async def _loop(self):
        while True:
            # Admit
            while len(self.inflight) < self.max_inflight and not self.ready.empty():
                req = self.ready.get_nowait()
                inflight_corr = await asyncio.to_thread(
                    self.agent.begin_correction, req.payload["sentence"], req.payload.get("learner_id")
                )
                inflight_corr.future = req.future
                self.inflight.append(inflight_corr)

            if not self.inflight:
                await asyncio.sleep(0.001)
                continue

            # Step
            step_results = await asyncio.to_thread(
                self.agent.step_corrections, self.inflight
            )

            # Reap finished
            still_inflight = []
            for corr, (tok, done) in zip(self.inflight, step_results):
                if done:
                    corr.future.set_result(corr.to_response_dict())
                else:
                    still_inflight.append(corr)
            self.inflight = still_inflight
  1. Conéctalo. Reemplaza el static scheduler en app.py:
scheduler = ContinuousBatchScheduler(
    agent=agent,
    max_inflight=8,
    max_queue=200,
)
  1. Diseña la workload de longitud mixta. Desde el corpus de verbos, construye 200 prompts:
  2. 60% son cortos: "He" → 1-2 tokens esperados (solo la forma del verbo).
  3. 30% son medios: "Yesterday I" → 2-3 tokens.
  4. 10% son largos: prompts que provocan futuros con "going to" o explicaciones completas — 6-10 tokens.

Mezcla. Guarda como data/mixed_workload.json.

  1. Haz load-test con los tres schedulers sobre esta workload, todos con concurrency=50, total=500:
  2. A: Sin scheduler (lab 01, variante C).
  3. B: Static batching (max_batch=8, max_wait_ms=20).
  4. C: Continuous batching (max_inflight=8).

  5. Dibuja la comparación:

  6. CDF de latencia: tres curvas en los mismos ejes.
  7. Gráfico de barras p50, p95, p99: tres grupos de tres barras.
  8. Gráfico de barras de throughput (req/s).

  9. Verifica la propiedad:

  10. La p95 de C es ≥ 30% mejor que la p95 de B.
  11. La p50 de C es similar o mejor que la p50 de B.
  12. El throughput de C es similar o mejor que B (continuous batching es throughput-neutral en el caso fácil; este es el win en latencia).

  13. Barre max_inflight ∈ {1, 2, 4, 8, 16} para el continuous scheduler. Dibuja throughput vs max_inflight y p95 vs max_inflight. Forma esperada: el throughput sube y se satura; la p95 se queda aproximadamente plana y luego sube en el extremo alto (la espera en cola entra en juego).

Mediciones

Guarda en experiments/<date>-phase-33-lab-03/:

  • latencies_nobatch.json, latencies_static.json, latencies_continuous.json.
  • latency_cdf_3way.png.
  • pNN_bars.png.
  • inflight_sweep.csv e inflight_sweep.png.
  • manifest.json — composición de la workload (el split 60/30/10), seeds, checkpoint del modelo.

Aceptación (DoD relevante)

  • p95 de continuous batching ≥ 30% mejor que static batching en la workload mixta. (DoD de la Fase 33.)
  • Throughput de continuous batching ≥ static batching (dentro del 10%).
  • Todas las peticiones devuelven HTTP 200 O HTTP 503 (sin errores 500, sin timeouts).
  • Bajo carga sostenida por encima de capacidad (concurrency=200), el scheduler rechaza con 503 cuando se excede max_queue — no OOMs ni stalls.

Trampas

  • Bookkeeping del KV cache por petición. Cada in-flight tiene su propio cache. No compartas buffers entre peticiones — tienen longitudes distintas. (La PagedAttention de la Fase 27 resuelve la versión "desperdicio de memoria" de esto; para nuestro lab, simplemente asigna por petición.)
  • Hacer batch del step, no del forward completo. Es el punto entero. Si accidentalmente haces batch de una generación multi-token, has reinventado static batching.
  • Determinismo del tokenizer. Cuando muestras un token en el paso t, tiene que retroalimentarse como input del paso t+1. Si tu forward batched tiene outputs sutilmente dependientes del padding, obtendrás tokens distintos que la versión sin batch. Añade un property test: el output de continuous-batch para una sola petición debe coincidir con el output sin batch (módulo reordenamiento en coma flotante).
  • Cancelación. Si el cliente se desconecta, la petición debería salir del conjunto in-flight. FastAPI expone request.is_disconnected() — manéjalo. Si no, desperdicias cómputo en un cliente muerto.
  • Equidad. Admisión FIFO es lo más sencillo. Con políticas más sofisticadas (prioridades, fairness), puedes matar de hambre algunas peticiones. Deja esto para la Fase 34.

Stretch (fuera de scope para el DoD)

  • Continuous batching + prefill batching. Los sistemas reales hacen batch del prefill por separado porque el prefill es caro. Investígalo y haz un gesto hacia ello en el lab 04.
  • Attention de longitud variable en un solo forward. Implementa masking para que peticiones con cached-lengths distintos compartan el kernel del step. O: padding hasta la longitud más larga (más sencillo, lo que hace el lab 03).

Siguiente: 04-vllm-and-tgi-survey.md