# SPDX-License-Identifier: Apache-2.0
"""vLLM v1 patch for overlap distiller (batch-friendly, eager-friendly)."""

from __future__ import annotations

import copy
import os
import threading
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.config import CUDAGraphMode
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.utils import round_up
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.worker.gpu_worker import Worker
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

from .vllm_overlap_patch import _DistillerRuntime, _RUNTIME as _BASE_RUNTIME
from .hidden_transition_distiller import (
    BatchedTransitionMLP,
    TransitionMLP,
    batched_transition_forward,
    batched_transition_forward_stacked,
    stack_transition_weights,
)

# Lock around the patch-installation state below. Its use is outside this view;
# presumably it serialises applying the vLLM patches — TODO confirm.
_PATCH_LOCK = threading.Lock()
# Whether the patches have been applied (presumably flipped once by the installer).
_PATCHED = False
# Slots for the original, unpatched callables. Presumably these are
# GPUModelRunner.execute_model and Worker.execute_model, saved so the patched
# versions can delegate to them — TODO confirm against the installer.
_ORIG_EXECUTE_MODEL = None
_ORIG_WORKER_EXECUTE = None


def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean.

    Accepted spellings are matched case-insensitively after trimming whitespace:
    ``1/true/yes/y/on`` map to True and ``0/false/no/n/off`` map to False.
    An unset variable, or any other value, yields ``bool(default)``.
    """
    raw = os.getenv(name)
    if raw is None:
        return bool(default)
    lookup = dict.fromkeys(("1", "true", "yes", "y", "on"), True)
    lookup.update(dict.fromkeys(("0", "false", "no", "n", "off"), False))
    parsed = lookup.get(str(raw).strip().lower())
    return bool(default) if parsed is None else parsed


def _env_str(name: str, default: str) -> str:
    """Return environment variable *name* with surrounding whitespace removed.

    If the variable is unset, ``str(default)`` is returned as-is (it is not stripped).
    """
    raw = os.getenv(name)
    return str(default) if raw is None else str(raw).strip()


def _env_int(name: str, default: int) -> int:
    """Read an integer from environment variable *name*.

    Returns ``int(default)`` when the variable is unset or is not a valid
    base-10 integer literal. Surrounding whitespace is accepted by ``int``.
    """
    raw = os.getenv(name)
    if raw is None:
        return int(default)
    try:
        return int(raw)
    except ValueError:
        # ``int(str)`` only raises ValueError; a malformed value falls back to
        # the default instead of failing at config-read time.
        return int(default)


class _OverlapRuntimeManager:
    def __init__(self) -> None:
        self._runtimes: Dict[str, _DistillerRuntime] = {}
        self._req_id_map: Dict[str, int] = {}
        self._next_req_id = 0
        self._active_groups: List[str] = []
        self._active_logit_rows: Dict[str, List[int]] = {}
        self._config_sig: Optional[tuple] = None
        self._final_norm_weight: Optional[torch.Tensor] = None
        self._final_norm_bias: Optional[torch.Tensor] = None
        self._final_norm_eps = 1e-6
        self._final_norm_type = "rms"
        self._reward_scale_inv_norms: Optional[torch.Tensor] = None
        self._reward_scale_device: Optional[torch.device] = None
        self._reward_scale_dtype: Optional[torch.dtype] = None
        self._lock = threading.Lock()
        self._batched_stream: Optional[torch.cuda.Stream] = None
        self._batched_weight_cache: Optional[dict] = None
        self._batched_weight_ids: Optional[List[int]] = None
        self._batched_weight_steps: Optional[List[int]] = None
        self._batched_train_stream: Optional[torch.cuda.Stream] = None
        self._batched_predictor: Optional[BatchedTransitionMLP] = None
        self._batched_optimizer: Optional[torch.optim.Optimizer] = None
        self._batched_group_map: Dict[str, int] = {}
        self._batched_param_sig: Optional[tuple] = None

    def _config_signature(self) -> tuple:
        """Snapshot the base runtime's tunables as a tuple.

        ``_sync_config`` compares successive snapshots to detect configuration
        changes that must be pushed to every per-group runtime.
        """
        fields = (
            "enabled",
            "beta",
            "topk",
            "buffer_slots",
            "first_layer",
            "mix_mode",
            "train_enabled",
            "train_sync_interval",
            "normalize_inputs",
            "reward_scale",
            "logits_rescale",
            "logit_loss_weight",
            "curiosity_only",
        )
        return tuple(getattr(_BASE_RUNTIME, field) for field in fields)

    def _batched_enabled(self) -> bool:
        """Whether batched-MLP mode (stacked per-runtime predictors) is requested via env."""
        enabled = _env_flag(name="ES_OVERLAP_BATCHED_MLP", default=False)
        return enabled

    def _batched_params_enabled(self) -> bool:
        """Whether batched-params mode (one shared BatchedTransitionMLP) is requested via env."""
        enabled = _env_flag(name="ES_OVERLAP_BATCHED_PARAMS", default=False)
        return enabled

    def _batched_max_groups(self) -> int:
        """Capacity of the shared batched predictor (``ES_OVERLAP_MAX_GROUPS``, default 64, floor 1)."""
        configured = _env_int("ES_OVERLAP_MAX_GROUPS", 64)
        return configured if configured > 1 else 1

    def _get_batched_stream(self, device: torch.device) -> torch.cuda.Stream:
        if self._batched_stream is None:
            self._batched_stream = torch.cuda.Stream(device=device)
        return self._batched_stream

    def _get_batched_train_stream(self, device: torch.device) -> torch.cuda.Stream:
        if self._batched_train_stream is None:
            self._batched_train_stream = torch.cuda.Stream(device=device)
        return self._batched_train_stream

    def _batched_predictor_signature(
        self,
        *,
        hidden_dim: int,
        mlp_hidden_dim: Optional[int],
        num_groups: int,
        device: torch.device,
    ) -> tuple:
        return (int(hidden_dim), int(mlp_hidden_dim or 0), int(num_groups), device)

    def _ensure_batched_predictor(
        self,
        *,
        hidden_dim: int,
        mlp_hidden_dim: Optional[int],
        num_groups: int,
        device: torch.device,
    ) -> None:
        sig = self._batched_predictor_signature(
            hidden_dim=hidden_dim,
            mlp_hidden_dim=mlp_hidden_dim,
            num_groups=num_groups,
            device=device,
        )
        if self._batched_predictor is not None and self._batched_param_sig == sig:
            return
        self._batched_predictor = BatchedTransitionMLP(
            num_groups=num_groups,
            input_dim=hidden_dim,
            output_dim=hidden_dim,
            mlp_hidden_dim=mlp_hidden_dim,
        ).to(device)
        self._batched_predictor.train(_BASE_RUNTIME.train_enabled)
        try:
            self._batched_optimizer = torch.optim.Adam(
                self._batched_predictor.parameters(),
                lr=5e-4,
                fused=True,
            )
        except Exception:
            self._batched_optimizer = torch.optim.Adam(
                self._batched_predictor.parameters(),
                lr=5e-4,
            )
        self._batched_param_sig = sig

    def _get_or_assign_batched_index(self, group_id: str) -> Optional[int]:
        idx = self._batched_group_map.get(group_id)
        if idx is not None:
            return idx
        max_groups = self._batched_max_groups()
        if len(self._batched_group_map) >= max_groups:
            return None
        idx = len(self._batched_group_map)
        self._batched_group_map[group_id] = idx
        return idx

    def _collect_active_runtimes(self) -> List[_DistillerRuntime]:
        runtimes: List[_DistillerRuntime] = []
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is None or not runtime.current_active:
                continue
            if runtime.current_slot is None or runtime.current_count <= 0:
                continue
            if self._batched_params_enabled() and getattr(runtime, "batched_index", -1) < 0:
                continue
            runtimes.append(runtime)
        return runtimes

    def _sync_config(self) -> None:
        """Re-apply the base runtime's configuration to every per-group runtime.

        This is a no-op when ``_config_signature()`` matches the last synced one.
        In batched-params mode each runtime is also flagged ``batched_params``.
        """
        current = self._config_signature()
        if current == self._config_sig:
            return
        field_names = (
            "enabled",
            "beta",
            "topk",
            "buffer_slots",
            "first_layer",
            "mix_mode",
            "train_enabled",
            "train_sync_interval",
            "normalize_inputs",
            "reward_scale",
            "logits_rescale",
            "logit_loss_weight",
            "curiosity_only",
        )
        for runtime in self._runtimes.values():
            runtime.configure(**{name: getattr(_BASE_RUNTIME, name) for name in field_names})
            if self._batched_params_enabled():
                runtime.batched_params = True
        self._config_sig = current

    def _apply_shared_cache(self, runtime: _DistillerRuntime) -> None:
        # If cached norm pointers change, invalidate any captured CUDA graphs in the runtime.
        if (
            runtime.final_norm_weight is not self._final_norm_weight
            or runtime.final_norm_bias is not self._final_norm_bias
            or runtime.final_norm_eps != self._final_norm_eps
            or runtime.final_norm_type != self._final_norm_type
        ):
            try:
                runtime._clear_cudagraph_cache()  # type: ignore[attr-defined]
                runtime._cudagraph_sig = runtime._cudagraph_signature()  # type: ignore[attr-defined]
            except Exception:
                pass
            self._batched_weight_cache = None
            self._batched_weight_ids = None
            self._batched_weight_steps = None

        runtime.final_norm_weight = self._final_norm_weight
        runtime.final_norm_bias = self._final_norm_bias
        runtime.final_norm_eps = self._final_norm_eps
        runtime.final_norm_type = self._final_norm_type
        if hasattr(runtime, "_sync_predictor_norm"):
            runtime._sync_predictor_norm()
        runtime.weight_inv_norms = self._reward_scale_inv_norms
        runtime.weight_inv_norms_device = self._reward_scale_device
        runtime.weight_inv_norms_dtype = self._reward_scale_dtype

    def _group_req_id(self, req_id: str) -> str:
        if not isinstance(req_id, str):
            return str(req_id)
        prefix, sep, rest = req_id.partition("_")
        if sep and prefix.isdigit() and rest:
            return rest
        return req_id

    def _get_req_id(self, req_id: str) -> int:
        if req_id not in self._req_id_map:
            self._req_id_map[req_id] = self._next_req_id
            self._next_req_id += 1
        return self._req_id_map[req_id]

    def _get_runtime(self, req_id: str) -> Optional[_DistillerRuntime]:
        """Return the runtime for group *req_id*, creating and configuring it on first use.

        A new runtime is configured from ``_BASE_RUNTIME``. In batched-params
        mode it is also assigned a shared-predictor slot; None is returned (and
        nothing is cached) when no slot is free.
        """
        cached = self._runtimes.get(req_id)
        if cached is not None:
            return cached

        runtime = _DistillerRuntime()
        settings = {
            name: getattr(_BASE_RUNTIME, name)
            for name in (
                "enabled",
                "beta",
                "topk",
                "buffer_slots",
                "first_layer",
                "mix_mode",
                "train_enabled",
                "train_sync_interval",
                "normalize_inputs",
                "reward_scale",
                "logits_rescale",
                "logit_loss_weight",
                "curiosity_only",
            )
        }
        runtime.configure(**settings)
        if self._batched_params_enabled():
            slot = self._get_or_assign_batched_index(req_id)
            if slot is None:
                return None
            runtime.batched_params = True
            runtime.batched_index = int(slot)
        self._apply_shared_cache(runtime)
        self._runtimes[req_id] = runtime
        return runtime

    def sync_norm_cache_from_base(self) -> None:
        """Refresh the shared final-norm / reward-scale cache from ``_BASE_RUNTIME``.

        The refreshed values are then pushed to every known runtime.
        """
        base = _BASE_RUNTIME
        (
            self._final_norm_weight,
            self._final_norm_bias,
            self._final_norm_eps,
            self._final_norm_type,
        ) = (
            base.final_norm_weight,
            base.final_norm_bias,
            base.final_norm_eps,
            base.final_norm_type,
        )
        (
            self._reward_scale_inv_norms,
            self._reward_scale_device,
            self._reward_scale_dtype,
        ) = (
            base.weight_inv_norms,
            base.weight_inv_norms_device,
            base.weight_inv_norms_dtype,
        )
        for runtime in self._runtimes.values():
            self._apply_shared_cache(runtime)

    def clear_active(self) -> None:
        self._active_groups = []
        self._active_logit_rows = {}

    def start_step(
        self,
        *,
        req_ids: List[str],
        logits_indices: Optional[torch.Tensor],
        scheduler_output,
        device: torch.device,
    ) -> None:
        """Select the groups that take part in this step and prime their runtimes.

        Requests with scheduled tokens are grouped via ``_group_req_id``. For
        each group that has (or can get) a runtime:

        * the entries of ``logits_indices`` at the requests' batch positions are
          handed to ``runtime.start_step_with_rows``;
        * those batch positions are remembered as the group's logit rows.

        Nothing is activated when the distiller is disabled, ``logits_indices``
        is not a tensor, or it has fewer entries than ``req_ids``.
        """
        self.clear_active()
        if not _BASE_RUNTIME.enabled:
            return
        if not isinstance(logits_indices, torch.Tensor):
            return
        total = len(req_ids)
        if total == 0 or logits_indices.numel() < total:
            return

        self._sync_config()
        scheduled = scheduler_output.num_scheduled_tokens
        members: Dict[str, List[int]] = {}
        for position, req_id in enumerate(req_ids):
            if scheduled.get(req_id, 0) > 0:
                members.setdefault(self._group_req_id(req_id), []).append(position)

        for group_id, positions in members.items():
            runtime = self._get_runtime(group_id)
            if runtime is None:
                continue
            req_int = self._get_req_id(group_id)
            selector = torch.tensor(positions, device=logits_indices.device, dtype=torch.long)
            runtime.start_step_with_rows(
                logits_indices.index_select(0, selector), device, req_id=req_int
            )
            self._active_groups.append(group_id)
            self._active_logit_rows[group_id] = list(positions)

    def has_active(self) -> bool:
        return bool(self._active_groups)

    def record_h1(self, output: torch.Tensor) -> None:
        if not self._active_groups:
            return
        if not isinstance(output, torch.Tensor):
            return
        if isinstance(output, (tuple, list)):
            output = output[0]
        if not isinstance(output, torch.Tensor):
            return
        if self._batched_params_enabled() and torch.cuda.is_available():
            runtimes = self._collect_active_runtimes()
            if not runtimes:
                return

            device = output.device
            batched_runtimes: List[_DistillerRuntime] = []
            row_indices: List[torch.Tensor] = []
            slots: List[int] = []
            counts: List[int] = []
            group_indices: List[int] = []
            max_rows = int(output.shape[0])
            for runtime in runtimes:
                row_idx = runtime._row_tensor(device)
                if row_idx is None or row_idx.numel() == 0:
                    continue
                if max_rows <= 0:
                    continue
                row_idx = row_idx[(row_idx >= 0) & (row_idx < max_rows)]
                if row_idx.numel() == 0:
                    continue
                runtime._ensure_buffers(output.shape[-1], output.shape[0], device)
                batched_runtimes.append(runtime)
                row_indices.append(row_idx)
                slots.append(int(runtime.current_slot))
                counts.append(int(row_idx.numel()))
                group_indices.append(int(runtime.batched_index))

            if not batched_runtimes:
                return

            stream = self._get_batched_stream(device)
            current_stream = torch.cuda.current_stream(device)
            with torch.cuda.stream(stream):
                stream.wait_stream(current_stream)
                sources: List[torch.Tensor] = []
                max_count = max(counts)
                hidden_dim = int(output.shape[-1])
                for runtime, row_idx, slot in zip(batched_runtimes, row_indices, slots):
                    source = output.index_select(0, row_idx).to(torch.float32)
                    runtime.h1_buf[slot, : source.size(0)].copy_(source)
                    runtime.events["h1_ready"][slot].record(stream)
                    sources.append(source)

                batch_src = torch.zeros(
                    (len(sources), max_count, hidden_dim),
                    device=device,
                    dtype=torch.float32,
                )
                for idx, src in enumerate(sources):
                    batch_src[idx, : src.size(0)].copy_(src)
                if batched_runtimes[0].normalize_inputs:
                    batch_src = F.layer_norm(batch_src, (batch_src.shape[-1],))

                group_idx = torch.tensor(group_indices, device=device, dtype=torch.long)
                max_groups = self._batched_max_groups()
                mlp_hidden_dim = batched_runtimes[0].mlp_hidden_dim
                self._ensure_batched_predictor(
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    num_groups=max_groups,
                    device=device,
                )
                if self._batched_predictor is None:
                    return
                pred_batch = self._batched_predictor(batch_src, group_indices=group_idx)
                pred_batch = batched_runtimes[0]._apply_final_norm(pred_batch)

                for idx, (runtime, slot, count) in enumerate(
                    zip(batched_runtimes, slots, counts)
                ):
                    runtime.pred_buf[slot, :count].copy_(pred_batch[idx, :count])
                    runtime.events["pred_ready"][slot].record(stream)
            return

        if not self._batched_enabled() or not torch.cuda.is_available():
            for group_id in list(self._active_groups):
                runtime = self._runtimes.get(group_id)
                if runtime is not None:
                    runtime.record_h1(output)
            return

        runtimes = self._collect_active_runtimes()
        if len(runtimes) <= 1:
            for runtime in runtimes:
                runtime.record_h1(output)
            return

        device = output.device
        batched_runtimes: List[_DistillerRuntime] = []
        predictors: List[TransitionMLP] = []
        row_indices: List[torch.Tensor] = []
        slots: List[int] = []
        for runtime in runtimes:
            row_idx = runtime._row_tensor(device)
            if row_idx is None or row_idx.numel() == 0:
                continue
            max_rows = int(output.shape[0])
            if max_rows <= 0:
                continue
            row_idx = row_idx[(row_idx >= 0) & (row_idx < max_rows)]
            if row_idx.numel() == 0:
                continue
            runtime._ensure_buffers(output.shape[-1], output.shape[0], device)
            if runtime.predictor is None or runtime.h1_buf is None or runtime.pred_buf is None:
                continue
            batched_runtimes.append(runtime)
            predictors.append(runtime.predictor)
            row_indices.append(row_idx)
            slots.append(int(runtime.current_slot))

        if len(predictors) <= 1:
            for runtime in runtimes:
                runtime.record_h1(output)
            return

        stream = self._get_batched_stream(device)
        current_stream = torch.cuda.current_stream(device)
        with torch.cuda.stream(stream):
            stream.wait_stream(current_stream)
            sources: List[torch.Tensor] = []
            counts: List[int] = []
            for runtime, row_idx, slot in zip(batched_runtimes, row_indices, slots):
                source = output.index_select(0, row_idx).to(torch.float32)
                runtime.h1_buf[slot, : source.size(0)].copy_(source)
                runtime.events["h1_ready"][slot].record(stream)
                sources.append(source)
                counts.append(int(source.size(0)))

            if not sources:
                return
            max_count = max(counts)
            hidden_dim = int(sources[0].size(1))
            batch_src = torch.zeros(
                (len(sources), max_count, hidden_dim),
                device=device,
                dtype=torch.float32,
            )
            for idx, src in enumerate(sources):
                batch_src[idx, : src.size(0)].copy_(src)

            if runtimes[0].normalize_inputs:
                batch_src = F.layer_norm(batch_src, (batch_src.shape[-1],))

            weight_ids = [id(p) for p in predictors]
            weight_steps = [int(rt.train_steps) for rt in batched_runtimes]
            if (
                self._batched_weight_cache is None
                or self._batched_weight_ids != weight_ids
                or self._batched_weight_steps != weight_steps
            ):
                self._batched_weight_cache = stack_transition_weights(predictors)
                self._batched_weight_ids = weight_ids
                self._batched_weight_steps = weight_steps
            pred_batch = batched_transition_forward_stacked(batch_src, self._batched_weight_cache)

            for idx, (runtime, slot, count) in enumerate(zip(batched_runtimes, slots, counts)):
                runtime.pred_buf[slot, :count].copy_(pred_batch[idx, :count])
                runtime.events["pred_ready"][slot].record(stream)

    def record_h1_from_buffer(
        self,
        source_buf: torch.Tensor,
        ready_event: Optional[torch.cuda.Event],
    ) -> None:
        """Like ``record_h1``, but read the early-layer hidden state from a staging buffer.

        Args:
            source_buf: 2-D tensor whose rows are indexed by each runtime's
                ``_row_tensor``. Inputs of any other rank (or non-tensors) are ignored.
            ready_event: optional CUDA event. When given, the batched side stream
                waits on it before reading ``source_buf``; the per-runtime
                fallback forwards it to ``runtime.record_h1_from_buffer``.

        Execution paths:

        * batched-params mode: one shared ``BatchedTransitionMLP`` indexed by
          ``runtime.batched_index``;
        * per-runtime fallback: each runtime handles itself;
        * batched-MLP mode: each runtime's predictor weights are stacked and
          evaluated in one call.

        Results go into each runtime's ``h1_buf`` / ``pred_buf`` slot and are
        signalled through the slot's ``h1_ready`` / ``pred_ready`` events.
        """
        if not self._active_groups:
            return
        if not isinstance(source_buf, torch.Tensor) or source_buf.dim() != 2:
            return
        # --- Path 1: batched-params mode (shared BatchedTransitionMLP) ---------
        if self._batched_params_enabled() and torch.cuda.is_available():
            runtimes = self._collect_active_runtimes()
            if not runtimes:
                return

            device = source_buf.device
            batched_runtimes: List[_DistillerRuntime] = []
            row_indices: List[torch.Tensor] = []
            slots: List[int] = []
            counts: List[int] = []
            group_indices: List[int] = []
            max_rows = int(source_buf.shape[0])
            for runtime in runtimes:
                row_idx = runtime._row_tensor(device)
                if row_idx is None or row_idx.numel() == 0:
                    continue
                if max_rows <= 0:
                    continue
                # Drop row indices that fall outside source_buf.
                row_idx = row_idx[(row_idx >= 0) & (row_idx < max_rows)]
                if row_idx.numel() == 0:
                    continue
                runtime._ensure_buffers(source_buf.shape[-1], source_buf.shape[0], device)
                batched_runtimes.append(runtime)
                row_indices.append(row_idx)
                slots.append(int(runtime.current_slot))
                counts.append(int(row_idx.numel()))
                group_indices.append(int(runtime.batched_index))

            if not batched_runtimes:
                return

            stream = self._get_batched_stream(device)
            current_stream = torch.cuda.current_stream(device)
            with torch.cuda.stream(stream):
                # Order the side stream after work already queued on the current
                # stream, and after ready_event when the caller supplies one.
                stream.wait_stream(current_stream)
                if ready_event is not None:
                    stream.wait_event(ready_event)
                sources: List[torch.Tensor] = []
                max_count = max(counts)
                hidden_dim = int(source_buf.shape[-1])
                for runtime, row_idx, slot in zip(batched_runtimes, row_indices, slots):
                    source = source_buf.index_select(0, row_idx).to(torch.float32)
                    runtime.h1_buf[slot, : source.size(0)].copy_(source)
                    runtime.events["h1_ready"][slot].record(stream)
                    sources.append(source)

                # Zero-padded (num_runtimes, max_count, hidden_dim) float32 batch.
                batch_src = torch.zeros(
                    (len(sources), max_count, hidden_dim),
                    device=device,
                    dtype=torch.float32,
                )
                for idx, src in enumerate(sources):
                    batch_src[idx, : src.size(0)].copy_(src)
                if batched_runtimes[0].normalize_inputs:
                    batch_src = F.layer_norm(batch_src, (batch_src.shape[-1],))

                group_idx = torch.tensor(group_indices, device=device, dtype=torch.long)
                max_groups = self._batched_max_groups()
                mlp_hidden_dim = batched_runtimes[0].mlp_hidden_dim
                self._ensure_batched_predictor(
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    num_groups=max_groups,
                    device=device,
                )
                if self._batched_predictor is None:
                    return
                pred_batch = self._batched_predictor(batch_src, group_indices=group_idx)
                pred_batch = batched_runtimes[0]._apply_final_norm(pred_batch)

                # Only the first ``count`` (non-padding) rows are copied out.
                for idx, (runtime, slot, count) in enumerate(
                    zip(batched_runtimes, slots, counts)
                ):
                    runtime.pred_buf[slot, :count].copy_(pred_batch[idx, :count])
                    runtime.events["pred_ready"][slot].record(stream)
            return

        # --- Path 2: per-runtime fallback --------------------------------------
        if not self._batched_enabled() or not torch.cuda.is_available():
            for group_id in list(self._active_groups):
                runtime = self._runtimes.get(group_id)
                if runtime is not None:
                    runtime.record_h1_from_buffer(source_buf, ready_event)
            return

        runtimes = self._collect_active_runtimes()
        if len(runtimes) <= 1:
            for runtime in runtimes:
                runtime.record_h1_from_buffer(source_buf, ready_event)
            return

        # --- Path 3: batched-MLP mode (stacked per-runtime predictor weights) --
        device = source_buf.device
        batched_runtimes: List[_DistillerRuntime] = []
        predictors: List[TransitionMLP] = []
        row_indices: List[torch.Tensor] = []
        slots: List[int] = []
        max_rows = int(source_buf.shape[0])
        for runtime in runtimes:
            row_idx = runtime._row_tensor(device)
            if row_idx is None or row_idx.numel() == 0:
                continue
            if max_rows <= 0:
                continue
            row_idx = row_idx[(row_idx >= 0) & (row_idx < max_rows)]
            if row_idx.numel() == 0:
                continue
            runtime._ensure_buffers(source_buf.shape[-1], source_buf.shape[0], device)
            if runtime.predictor is None or runtime.h1_buf is None or runtime.pred_buf is None:
                continue
            batched_runtimes.append(runtime)
            predictors.append(runtime.predictor)
            row_indices.append(row_idx)
            slots.append(int(runtime.current_slot))

        if len(predictors) <= 1:
            # Batching does not pay off; let each runtime handle itself.
            for runtime in runtimes:
                runtime.record_h1_from_buffer(source_buf, ready_event)
            return

        stream = self._get_batched_stream(device)
        current_stream = torch.cuda.current_stream(device)
        with torch.cuda.stream(stream):
            stream.wait_stream(current_stream)
            if ready_event is not None:
                stream.wait_event(ready_event)
            sources: List[torch.Tensor] = []
            counts: List[int] = []
            for runtime, row_idx, slot in zip(batched_runtimes, row_indices, slots):
                source = source_buf.index_select(0, row_idx).to(torch.float32)
                runtime.h1_buf[slot, : source.size(0)].copy_(source)
                runtime.events["h1_ready"][slot].record(stream)
                sources.append(source)
                counts.append(int(source.size(0)))

            if not sources:
                return
            max_count = max(counts)
            hidden_dim = int(sources[0].size(1))
            batch_src = torch.zeros(
                (len(sources), max_count, hidden_dim),
                device=device,
                dtype=torch.float32,
            )
            for idx, src in enumerate(sources):
                batch_src[idx, : src.size(0)].copy_(src)

            if runtimes[0].normalize_inputs:
                batch_src = F.layer_norm(batch_src, (batch_src.shape[-1],))

            # Re-stack weights only when the predictor set or any runtime's
            # train_steps changed since the cache was built.
            weight_ids = [id(p) for p in predictors]
            weight_steps = [int(rt.train_steps) for rt in batched_runtimes]
            if (
                self._batched_weight_cache is None
                or self._batched_weight_ids != weight_ids
                or self._batched_weight_steps != weight_steps
            ):
                self._batched_weight_cache = stack_transition_weights(predictors)
                self._batched_weight_ids = weight_ids
                self._batched_weight_steps = weight_steps
            pred_batch = batched_transition_forward_stacked(batch_src, self._batched_weight_cache)

            for idx, (runtime, slot, count) in enumerate(zip(batched_runtimes, slots, counts)):
                runtime.pred_buf[slot, :count].copy_(pred_batch[idx, :count])
                runtime.events["pred_ready"][slot].record(stream)

    def record_hl(self, hidden_states: torch.Tensor) -> None:
        if not self._active_groups:
            return
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is not None:
                runtime.record_hl(hidden_states)

    def mix_hidden(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self._active_groups:
            return hidden_states
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is not None:
                hidden_states = runtime.mix_hidden(hidden_states)
        return hidden_states

    def record_logits(
        self,
        logits: torch.Tensor,
        *,
        lm_head,
        embedding_bias: Optional[torch.Tensor],
    ) -> None:
        if not self._active_groups:
            return
        if logits is None or not isinstance(logits, torch.Tensor):
            return
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is None:
                continue
            logit_rows = self._active_logit_rows.get(group_id)
            if not logit_rows:
                continue
            runtime.current_rows = torch.tensor(logit_rows, device=logits.device, dtype=torch.long)
            runtime.current_count = len(logit_rows)
            runtime.record_logits(logits, lm_head=lm_head, embedding_bias=embedding_bias)

    def mix_logits(self, logits: torch.Tensor, *, lm_head, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        if not self._active_groups:
            return logits
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is None:
                continue
            logit_rows = self._active_logit_rows.get(group_id)
            if not logit_rows:
                continue
            runtime.current_rows = torch.tensor(logit_rows, device=logits.device, dtype=torch.long)
            runtime.current_count = len(logit_rows)
            logits = runtime.mix_logits(logits, lm_head=lm_head, embedding_bias=embedding_bias)
        return logits

    def maybe_train(self) -> None:
        if not self._active_groups:
            return
        if self._batched_params_enabled() and torch.cuda.is_available():
            self._batched_maybe_train()
            return
        for group_id in list(self._active_groups):
            runtime = self._runtimes.get(group_id)
            if runtime is not None:
                runtime.maybe_train()

    def _batched_maybe_train(self) -> None:
        runtimes = self._collect_active_runtimes()
        if not runtimes:
            return
        train_runtimes: List[_DistillerRuntime] = []
        skip_runtimes: List[_DistillerRuntime] = []
        slots: List[int] = []
        counts: List[int] = []
        group_indices: List[int] = []
        for runtime in runtimes:
            if not runtime.train_enabled:
                if runtime.current_slot is not None:
                    skip_runtimes.append(runtime)
                continue
            if runtime.current_slot is None or runtime.current_count <= 0:
                continue
            runtime.train_tick += 1
            if runtime.train_tick % runtime.train_sync_interval != 0:
                skip_runtimes.append(runtime)
                continue
            train_runtimes.append(runtime)
            slots.append(int(runtime.current_slot))
            counts.append(int(runtime.current_count))
            group_indices.append(int(runtime.batched_index))

        if not train_runtimes and not skip_runtimes:
            return

        device = None
        for runtime in train_runtimes:
            if runtime.hl_buf is not None:
                device = runtime.hl_buf.device
                break
        if device is None:
            for runtime in skip_runtimes:
                if runtime.hl_buf is not None:
                    device = runtime.hl_buf.device
                    break
        if device is None:
            return
        stream = self._get_batched_train_stream(device)
        current_stream = torch.cuda.current_stream(device)
        with torch.cuda.stream(stream):
            stream.wait_stream(current_stream)
            for runtime, slot in zip(train_runtimes, slots):
                stream.wait_event(runtime.events["pred_ready"][slot])
                stream.wait_event(runtime.events["hl_ready"][slot])
            for runtime in skip_runtimes:
                slot = runtime.current_slot
                if slot is None:
                    continue
                stream.wait_event(runtime.events["pred_ready"][slot])
                stream.wait_event(runtime.events["hl_ready"][slot])
                runtime.events["slot_free"][slot].record(stream)

            if not train_runtimes:
                return

            with torch.inference_mode(False), torch.enable_grad():
                max_count = max(counts)
                hidden_dim = int(
                    train_runtimes[0].hidden_dim or train_runtimes[0].hl_buf.size(-1)
                )
                batch_src = torch.zeros(
                    (len(train_runtimes), max_count, hidden_dim),
                    device=device,
                    dtype=torch.float32,
                )
                batch_tgt = torch.zeros_like(batch_src)
                for idx, (runtime, slot, count) in enumerate(zip(train_runtimes, slots, counts)):
                    batch_src[idx, :count].copy_(runtime.h1_buf[slot, :count])
                    batch_tgt[idx, :count].copy_(runtime.hl_buf[slot, :count])

                if train_runtimes[0].normalize_inputs:
                    batch_src = F.layer_norm(batch_src, (batch_src.shape[-1],))

                group_idx = torch.tensor(group_indices, device=device, dtype=torch.long)
                max_groups = self._batched_max_groups()
                mlp_hidden_dim = train_runtimes[0].mlp_hidden_dim
                self._ensure_batched_predictor(
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    num_groups=max_groups,
                    device=device,
                )
                if self._batched_predictor is None or self._batched_optimizer is None:
                    return

                self._batched_optimizer.zero_grad(set_to_none=False)
                pred = self._batched_predictor(batch_src, group_indices=group_idx)
                pred = train_runtimes[0]._apply_final_norm(pred)

                counts_i = torch.tensor(counts, device=device, dtype=torch.long)
                counts_f = counts_i.to(dtype=torch.float32)
                mask = (
                    torch.arange(max_count, device=device).unsqueeze(0)
                    < counts_i.unsqueeze(1)
                )
                mask = mask.unsqueeze(-1)
                diff = (pred - batch_tgt) * mask
                denom = counts_f * float(hidden_dim)
                denom = torch.clamp(denom, min=1.0)
                loss_per = diff.pow(2).sum(dim=(1, 2)) / denom
                loss = loss_per.mean()

                # Log per-group metrics before backward.
                for idx, runtime in enumerate(train_runtimes):
                    if counts[idx] <= 0:
                        continue
                    pred_slice = pred[idx, : counts[idx]]
                    tgt_slice = batch_tgt[idx, : counts[idx]]
                    cos_val = F.cosine_similarity(
                        pred_slice.detach(), tgt_slice.detach(), dim=-1
                    ).mean()
                    pred_norm = pred_slice.detach().norm(dim=-1).mean()
                    tgt_norm = tgt_slice.detach().norm(dim=-1).mean()
                    runtime._record_loss(
                        loss_per[idx],
                        cos_val=cos_val,
                        pred_norm=pred_norm,
                        target_norm=tgt_norm,
                        step=runtime.step,
                        train_step=runtime.train_steps + 1,
                        device=device,
                    )

                loss.backward()
                self._batched_optimizer.step()
                self._batched_optimizer.zero_grad(set_to_none=False)

            for runtime in train_runtimes:
                runtime.train_steps += 1
            for runtime, slot in zip(train_runtimes, slots):
                runtime.events["slot_free"][slot].record(stream)

    def dump_log(self, file_path: str) -> Optional[int]:
        if not file_path:
            return None
        rows: List[List[object]] = []
        topk = int(_BASE_RUNTIME.topk)
        for runtime in self._runtimes.values():
            data = runtime._collect_log_arrays()
            if data is None or data["entries"] <= 0:
                continue
            topk = data["topk"]
            entries = data["entries"]
            for idx in range(entries):
                row = [data["step"][idx], data["req"][idx], data["row"][idx]]
                row.extend(data["ids"][idx])
                row.extend(data["pre"][idx])
                row.extend(data["dist"][idx])
                row.extend(data["post"][idx])
                rows.append(row)
        header = (
            ["step", "req_id", "row"]
            + [f"tid_{i}" for i in range(topk)]
            + [f"pre_{i}" for i in range(topk)]
            + [f"dist_{i}" for i in range(topk)]
            + [f"post_{i}" for i in range(topk)]
        )
        rows.sort(key=lambda item: (item[0], item[1], item[2]))
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(",".join(header) + "\n")
            for row in rows:
                f.write(",".join(str(v) for v in row) + "\n")
        return len(rows)

    def dump_loss_log(self, file_path: str) -> Optional[int]:
        if not file_path:
            return None
        rows: List[List[object]] = []
        for runtime in self._runtimes.values():
            data = runtime._collect_loss_arrays()
            if data is None or data["entries"] <= 0:
                continue
            entries = data["entries"]
            for idx in range(entries):
                row = [
                    data["step"][idx],
                    data["req"][idx],
                    data["train"][idx],
                    data["loss"][idx],
                    data["cos"][idx],
                    data["pred_norm"][idx],
                    data["target_norm"][idx],
                ]
                rows.append(row)
        rows.sort(key=lambda item: (item[0], item[1]))
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("step,req_id,train_step,loss,cos,pred_norm,target_norm\n")
            for row in rows:
                f.write(",".join(str(v) for v in row) + "\n")
        return len(rows)

    def get_cudagraph_stats(self) -> dict:
        stats_list = []
        for runtime in self._runtimes.values():
            if hasattr(runtime, "get_cudagraph_stats"):
                stats_list.append(runtime.get_cudagraph_stats())

        if not stats_list:
            return {
                "pred_captures": 0,
                "train_captures": 0,
                "pred_capture_total_s": 0.0,
                "train_capture_total_s": 0.0,
                "pred_capture_last_s": 0.0,
                "train_capture_last_s": 0.0,
                "first_capture_s": None,
                "last_capture_s": None,
            }

        pred_captures = sum(s.get("pred_captures", 0) for s in stats_list)
        train_captures = sum(s.get("train_captures", 0) for s in stats_list)
        pred_total = sum(s.get("pred_capture_total_s", 0.0) for s in stats_list)
        train_total = sum(s.get("train_capture_total_s", 0.0) for s in stats_list)
        pred_last = max(s.get("pred_capture_last_s", 0.0) for s in stats_list)
        train_last = max(s.get("train_capture_last_s", 0.0) for s in stats_list)
        first_list = [s.get("first_capture_s") for s in stats_list if s.get("first_capture_s") is not None]
        last_list = [s.get("last_capture_s") for s in stats_list if s.get("last_capture_s") is not None]
        first_capture = min(first_list) if first_list else None
        last_capture = max(last_list) if last_list else None
        return {
            "pred_captures": int(pred_captures),
            "train_captures": int(train_captures),
            "pred_capture_total_s": float(pred_total),
            "train_capture_total_s": float(train_total),
            "pred_capture_last_s": float(pred_last),
            "train_capture_last_s": float(train_last),
            "first_capture_s": first_capture,
            "last_capture_s": last_capture,
        }


# Process-wide runtime manager. It is used by the first-layer forward hook,
# the patched GPUModelRunner.execute_model and the
# export_distiller_overlap_* helpers in this module.
_RUNTIME_GROUP = _OverlapRuntimeManager()


def _register_first_layer_hook_v1(model, layer_idx: int) -> bool:
    layers = getattr(model.model, "layers", None)
    if layers is None:
        decoder = getattr(model.model, "decoder", None)
        layers = getattr(decoder, "layers", None) if decoder is not None else None
    if layers is None:
        return False

    idx = layer_idx if layer_idx >= 0 else len(layers) + layer_idx
    if idx < 0 or idx >= len(layers):
        return False
    layer = layers[idx]

    def hook(_module, _inputs, output):
        _RUNTIME_GROUP.record_h1(output)

    handle = layer.register_forward_hook(hook)
    layer._distiller_overlap_v1_hook = handle
    return True


def _get_model_hidden_dim(model) -> Optional[int]:
    config = getattr(model, "config", None)
    hidden_dim = getattr(config, "hidden_size", None)
    if isinstance(hidden_dim, int) and hidden_dim > 0:
        return hidden_dim
    embed = getattr(model, "embed_tokens", None)
    if embed is None:
        embed = getattr(getattr(model, "model", None), "embed_tokens", None)
    weight = getattr(embed, "weight", None) if embed is not None else None
    if isinstance(weight, torch.Tensor) and weight.dim() == 2:
        return int(weight.shape[1])
    return None


def _ensure_first_layer_capture_v1(model, layer_idx: int) -> bool:
    if getattr(model, "_distiller_overlap_v1_capture_patched", False):
        return True
    layers = getattr(model.model, "layers", None)
    if layers is None:
        decoder = getattr(model.model, "decoder", None)
        layers = getattr(decoder, "layers", None) if decoder is not None else None
    if layers is None:
        return False

    idx = layer_idx if layer_idx >= 0 else len(layers) + layer_idx
    if idx < 0 or idx >= len(layers):
        return False
    layer = layers[idx]
    orig_forward = layer.forward

    def wrapped_forward(*args, **kwargs):
        output = orig_forward(*args, **kwargs)
        buf = getattr(model, "_distiller_h1_active_buf", None)
        if isinstance(buf, torch.Tensor):
            out_tensor = output[0] if isinstance(output, (tuple, list)) else output
            if isinstance(out_tensor, torch.Tensor) and buf.shape == out_tensor.shape:
                buf.copy_(out_tensor)
                ev = getattr(model, "_distiller_h1_ready_event", None)
                if isinstance(ev, torch.cuda.Event):
                    ev.record(torch.cuda.current_stream())
        return output

    layer.forward = wrapped_forward
    layer._distiller_overlap_v1_capture_forward = orig_forward
    model._distiller_overlap_v1_capture_patched = True
    return True


def _ensure_h1_buffer_v1(model, num_tokens: int, device: torch.device) -> Optional[torch.Tensor]:
    hidden_dim = _get_model_hidden_dim(model)
    if hidden_dim is None:
        return None
    buffers = getattr(model, "_distiller_h1_buffers", None)
    if buffers is None:
        buffers = {}
        model._distiller_h1_buffers = buffers
    key = int(num_tokens)
    buf = buffers.get(key)
    if buf is None or buf.shape != (num_tokens, hidden_dim) or buf.device != device:
        with torch.inference_mode(False):
            buf = torch.empty((num_tokens, hidden_dim), device=device, dtype=torch.float32)
        buffers[key] = buf
    model._distiller_h1_active_buf = buf
    if not hasattr(model, "_distiller_h1_ready_event"):
        model._distiller_h1_ready_event = torch.cuda.Event()
    return buf


def _ensure_overlap_hook(model, *, register_hook: bool = True) -> None:
    if getattr(model, "_distiller_overlap_v1_hooked", False):
        return
    applied = False
    if register_hook:
        applied = _register_first_layer_hook_v1(model, _BASE_RUNTIME.first_layer)
    norm = None
    for obj in (model, getattr(model, "model", None), getattr(model, "module", None)):
        candidate = getattr(obj, "norm", None)
        if candidate is not None:
            norm = candidate
            break
    if norm is not None:
        weight = getattr(norm, "weight", None)
        bias = getattr(norm, "bias", None)
        eps = getattr(norm, "eps", None)
        if eps is None:
            eps = getattr(norm, "variance_epsilon", 1e-6)
        norm_type = "other"
        if isinstance(norm, torch.nn.LayerNorm) or "LayerNorm" in type(norm).__name__:
            norm_type = "layer"
        elif "RMSNorm" in type(norm).__name__:
            norm_type = "rms"
        _BASE_RUNTIME.set_final_norm(weight=weight, bias=bias, eps=float(eps or 1e-6), norm_type=norm_type)
        applied = True
    if _BASE_RUNTIME.reward_scale:
        lm_head = getattr(model, "lm_head", None)
        weight = getattr(lm_head, "weight", None) if lm_head is not None else None
        if isinstance(weight, torch.Tensor):
            _BASE_RUNTIME.set_reward_scale_norms(weight=weight, dtype=weight.dtype)
    _RUNTIME_GROUP.sync_norm_cache_from_base()
    model._distiller_overlap_v1_hooked = applied or not register_hook


def apply_distiller_overlap_patch_v1() -> bool:
    """Monkey-patch vLLM v1 GPUModelRunner to enable overlap mixing.

    Replaces ``GPUModelRunner.execute_model`` and ``Worker.execute_model``
    with the closures defined below. The original methods are stashed in
    ``_ORIG_EXECUTE_MODEL`` / ``_ORIG_WORKER_EXECUTE``.

    NOTE(review): the replacement bodies appear to mirror upstream vLLM v1
    ``execute_model`` for one specific vLLM version, with distiller hooks
    inserted. Re-check them against upstream when upgrading vLLM.

    Returns:
        True if the patch was installed by this call, False if it had
        already been installed.
    """
    global _PATCHED, _ORIG_EXECUTE_MODEL, _ORIG_WORKER_EXECUTE
    # Cheap check without the lock; repeated under _PATCH_LOCK below.
    if _PATCHED:
        return False

    with _PATCH_LOCK:
        if _PATCHED:
            return False

        _ORIG_EXECUTE_MODEL = GPUModelRunner.execute_model
        _ORIG_WORKER_EXECUTE = Worker.execute_model

        def worker_execute_model(self, scheduler_output):
            # Replacement for Worker.execute_model: runs the (patched) model
            # runner and handles pipeline-parallel tensor exchange around it.
            def _run():
                # Non-first PP ranks receive the previous stage's tensors.
                intermediate_tensors = None
                if not get_pp_group().is_first_rank:
                    intermediate_tensors = IntermediateTensors(
                        get_pp_group().recv_tensor_dict(
                            all_gather_group=get_tp_group()))

                output = self.model_runner.execute_model(
                    scheduler_output, intermediate_tensors)

                # Non-last PP ranks (unless using the external launcher) send
                # their intermediate tensors on and only report KV-connector
                # progress back to the scheduler.
                parallel_config = self.vllm_config.parallel_config
                if (parallel_config.distributed_executor_backend
                        != "external_launcher"
                        and not get_pp_group().is_last_rank):
                    assert isinstance(output, IntermediateTensors)
                    get_pp_group().send_tensor_dict(
                        output.tensors, all_gather_group=get_tp_group())

                    kv_connector_output = output.kv_connector_output
                    if not kv_connector_output:
                        return None
                    if (not kv_connector_output.finished_sending
                            and not kv_connector_output.finished_recving):
                        return EMPTY_MODEL_RUNNER_OUTPUT

                    # Shallow-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT
                    # object is not mutated.
                    passthrough = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
                    passthrough.kv_connector_output = kv_connector_output
                    return passthrough

                assert isinstance(output, ModelRunnerOutput)
                return output

            # With training enabled the step runs under no_grad rather than
            # inference_mode, presumably so tensors produced during the step
            # stay usable by autograd later (inference-mode tensors cannot be
            # used in autograd) — TODO confirm intent.
            if _BASE_RUNTIME.enabled and _BASE_RUNTIME.train_enabled:
                with torch.no_grad():
                    return _run()
            with torch.inference_mode():
                return _run()

        def wrapped_execute_model(
            self,
            scheduler_output,
            intermediate_tensors: Optional[IntermediateTensors] = None,
        ) -> Union[ModelRunnerOutput, IntermediateTensors]:
            """Replacement for GPUModelRunner.execute_model with distiller hooks.

            Around the model forward it records first-layer (h1) and
            last-layer (hl) hidden states for the active distiller groups and
            mixes either hidden states or logits. After sampling it calls
            ``_RUNTIME_GROUP.maybe_train()``.
            """
            self._update_states(scheduler_output)
            if not scheduler_output.total_num_scheduled_tokens:
                if not has_kv_transfer_group():
                    return EMPTY_MODEL_RUNNER_OUTPUT
                return self.kv_connector_no_forward(scheduler_output,
                                                    self.vllm_config)

            (attn_metadata, logits_indices, spec_decode_metadata,
             num_scheduled_tokens_np, spec_decode_common_attn_metadata,
             max_query_len) = (self._prepare_inputs(scheduler_output))

            # Pad the token count up to a captured CUDA-graph size when
            # cudagraphs are on and the batch fits. Otherwise pad only for
            # sequence parallelism.
            num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                    and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_scheduled_tokens)
            else:
                tp_size = self.vllm_config.parallel_config.tensor_parallel_size
                if self.compilation_config.pass_config. \
                    enable_sequence_parallelism and tp_size > 1:
                    num_input_tokens = round_up(num_scheduled_tokens, tp_size)
                else:
                    num_input_tokens = num_scheduled_tokens

            num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
            num_input_tokens += num_pad

            if self.supports_mm_inputs:
                self._execute_mm_encoder(scheduler_output)
                mm_embeds = self._gather_mm_embeddings(scheduler_output)
            else:
                mm_embeds = []

            if self.supports_mm_inputs and get_pp_group().is_first_rank:
                inputs_embeds_scheduled = self.model.get_input_embeddings(
                    input_ids=self.input_ids[:num_scheduled_tokens],
                    multimodal_embeddings=mm_embeds or None,
                )
                self.inputs_embeds[:num_scheduled_tokens].copy_(
                    inputs_embeds_scheduled)
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_input_tokens]
                model_kwargs = {
                    **self._init_model_kwargs(num_scheduled_tokens),
                    **self._extract_mm_kwargs(scheduler_output),
                }
            else:
                input_ids = self.input_ids[:num_input_tokens]
                inputs_embeds = None
                model_kwargs = self._init_model_kwargs(num_input_tokens)
            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_input_tokens]
            else:
                positions = self.positions[:num_input_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                    num_input_tokens, intermediate_tensors, True)

            uniform_decode = (
                max_query_len == self.uniform_decode_query_len
            ) and (num_scheduled_tokens
                   == self.input_batch.num_reqs * max_query_len)
            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                               uniform_decode=uniform_decode)
            cudagraph_runtime_mode, batch_descriptor = \
                self.cudagraph_dispatcher.dispatch(batch_descriptor)

            # Distiller setup: decide where the first-layer hidden state (h1)
            # comes from.
            #  * ES_OVERLAP_EMBED_H1=1: the input embeddings are used as h1
            #    (read after the forward pass).
            #  * Cudagraphs on and not embed mode: the first layer's forward
            #    is wrapped to copy its output into a static per-token-count
            #    buffer. With ES_OVERLAP_H1_MODE="event" the buffer and its
            #    ready-event go to the runtimes before the forward; with
            #    "post" the buffer is read after the forward.
            #  * Otherwise: a forward hook on the first layer calls
            #    record_h1 during the forward.
            # NOTE(review): environment flags are re-read on every step.
            use_embed_h1 = False
            use_first_layer_buffer = False
            record_h1_after = True
            overlap_mode = _env_str("ES_OVERLAP_H1_MODE", "event").lower()
            if overlap_mode not in {"event", "post"}:
                overlap_mode = "event"
            if _BASE_RUNTIME.enabled:
                use_embed_h1 = _env_flag("ES_OVERLAP_EMBED_H1", False)
                use_first_layer_buffer = (
                    self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                    and not use_embed_h1
                )
                if use_first_layer_buffer:
                    _ensure_first_layer_capture_v1(self.model, _BASE_RUNTIME.first_layer)
                    _ensure_h1_buffer_v1(self.model, num_input_tokens, self.device)
                    record_h1_after = overlap_mode == "post"
                # NOTE(review): this clears ES_OVERLAP_CUDAGRAPH whenever the
                # distiller is enabled, regardless of
                # self.compilation_config.cudagraph_mode. The inner comment
                # suggests it was only meant for cudagraph runs — confirm.
                if _env_flag("ES_OVERLAP_CUDAGRAPH", False):
                    # Avoid nested CUDA graph capture when main model uses cudagraphs.
                    os.environ["ES_OVERLAP_CUDAGRAPH"] = "0"
                _ensure_overlap_hook(
                    self.model,
                    register_hook=not use_embed_h1 and not use_first_layer_buffer,
                )
                # Only pure decode steps (one query token per request, no
                # speculative decoding) activate distiller groups; any other
                # step clears them.
                if max_query_len == 1 and spec_decode_metadata is None:
                    _RUNTIME_GROUP.start_step(
                        req_ids=self.input_batch.req_ids,
                        logits_indices=logits_indices,
                        scheduler_output=scheduler_output,
                        device=self.device,
                    )
                    if (use_first_layer_buffer
                            and overlap_mode == "event"
                            and get_pp_group().is_first_rank):
                        buf = getattr(self.model, "_distiller_h1_active_buf", None)
                        ev = getattr(self.model, "_distiller_h1_ready_event", None)
                        if isinstance(buf, torch.Tensor):
                            _RUNTIME_GROUP.record_h1_from_buffer(buf, ev)
                else:
                    _RUNTIME_GROUP.clear_active()
            else:
                _RUNTIME_GROUP.clear_active()

            with torch.inference_mode():
                with set_forward_context(
                        attn_metadata,
                        self.vllm_config,
                        num_tokens=num_input_tokens,
                        num_tokens_across_dp=num_tokens_across_dp,
                        cudagraph_runtime_mode=cudagraph_runtime_mode,
                        batch_descriptor=batch_descriptor,
                ), self.maybe_get_kv_connector_output(
                        scheduler_output) as kv_connector_output:
                    model_output = self.model(
                        input_ids=input_ids,
                        positions=positions,
                        intermediate_tensors=intermediate_tensors,
                        inputs_embeds=inputs_embeds,
                        **model_kwargs,
                    )

                if self.use_aux_hidden_state_outputs:
                    hidden_states, aux_hidden_states = model_output
                else:
                    hidden_states = model_output
                    aux_hidden_states = None

                broadcast_pp_output = \
                    self.parallel_config.distributed_executor_backend \
                    == "external_launcher" and len(get_pp_group().ranks) > 0
                # NOTE(review): on a non-last PP rank with broadcast_pp_output,
                # sample_hidden_states is never assigned. The speculative
                # branch further down (propose_draft_token_ids) would then
                # raise UnboundLocalError — confirm that combination is
                # unreachable.
                if not get_pp_group().is_last_rank:
                    assert isinstance(hidden_states, IntermediateTensors)
                    if not broadcast_pp_output:
                        hidden_states.kv_connector_output = kv_connector_output
                        return hidden_states
                    get_pp_group().send_tensor_dict(hidden_states.tensors,
                                                    all_gather_group=get_tp_group())
                    logits = None
                else:
                    if self.input_batch.pooling_params:
                        return self._pool(hidden_states, num_scheduled_tokens,
                                          num_scheduled_tokens_np,
                                          kv_connector_output)

                    # Feed h1 (when it is read after the forward) and hl to the
                    # active groups. In "hidden" mix mode hidden_states is
                    # replaced by the mixed result before logits are computed.
                    if _BASE_RUNTIME.enabled and _RUNTIME_GROUP.has_active():
                        if record_h1_after and get_pp_group().is_first_rank:
                            h1_source = None
                            if use_embed_h1:
                                embed_source = inputs_embeds
                                if embed_source is None and isinstance(input_ids, torch.Tensor):
                                    # Best effort: on failure, h1 is simply
                                    # not recorded for this step.
                                    try:
                                        embed_source = self.model.get_input_embeddings(
                                            input_ids=input_ids,
                                            multimodal_embeddings=mm_embeds or None,
                                        )
                                    except Exception:
                                        embed_source = None
                                h1_source = embed_source
                            else:
                                h1_source = getattr(self.model, "_distiller_h1_active_buf", None)
                            if isinstance(h1_source, torch.Tensor):
                                _RUNTIME_GROUP.record_h1(h1_source)
                        _RUNTIME_GROUP.record_hl(hidden_states)
                        if _BASE_RUNTIME.mix_mode == "hidden":
                            hidden_states = _RUNTIME_GROUP.mix_hidden(hidden_states)

                    sample_hidden_states = hidden_states[logits_indices]
                    logits = self.model.compute_logits(sample_hidden_states, None)

                    # "logits" mix mode mixes the computed logits; any other
                    # mode only records them.
                    if _BASE_RUNTIME.enabled and _RUNTIME_GROUP.has_active():
                        embedding_bias = getattr(self.model, "embedding_bias", None)
                        if _BASE_RUNTIME.mix_mode == "logits":
                            logits = _RUNTIME_GROUP.mix_logits(
                                logits,
                                lm_head=self.model.lm_head,
                                embedding_bias=embedding_bias,
                            )
                        else:
                            _RUNTIME_GROUP.record_logits(
                                logits,
                                lm_head=self.model.lm_head,
                                embedding_bias=embedding_bias,
                            )

                if broadcast_pp_output:
                    model_output_broadcast_data = {
                        "logits": logits.contiguous(),
                    } if logits is not None else {}
                    model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
                        model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
                    assert model_output_broadcast_data is not None
                    logits = model_output_broadcast_data["logits"]

                if scheduler_output.grammar_bitmask is not None:
                    self.apply_grammar_bitmask(scheduler_output, logits)

                sampling_metadata = self.input_batch.sampling_metadata
                if spec_decode_metadata is None:
                    sampler_output = self.sampler(
                        logits=logits,
                        sampling_metadata=sampling_metadata,
                    )
                else:
                    assert logits is not None
                    bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
                    sampler_output = self.sampler(
                        logits=bonus_logits,
                        sampling_metadata=sampling_metadata,
                    )
                    bonus_token_ids = sampler_output.sampled_token_ids
                    target_logits = logits[spec_decode_metadata.target_logits_indices]
                    output_token_ids = self.rejection_sampler(
                        spec_decode_metadata,
                        None,
                        target_logits,
                        bonus_token_ids,
                        sampling_metadata,
                    )
                    sampler_output.sampled_token_ids = output_token_ids

                num_nans_in_logits = {}
                if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
                    num_nans_in_logits = self._get_nans_in_logits(logits)

                # Requests whose sequence is not yet fully computed (e.g. a
                # partial prefill) get their sampled token discarded. Their
                # generator offset is rewound, presumably to undo the
                # sampler's RNG advance (the -4 looks copied from upstream
                # vLLM) — TODO confirm.
                discard_sampled_tokens_req_indices = []
                for i, req_id in enumerate(self.input_batch.req_ids):
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    if seq_len < req_state.num_tokens:
                        generator = self.input_batch.generators.get(i)
                        if generator is not None:
                            generator.set_offset(generator.get_offset() - 4)
                        discard_sampled_tokens_req_indices.append(i)

                logprobs_tensors = sampler_output.logprobs_tensors
                logprobs_lists = logprobs_tensors.tolists() \
                    if logprobs_tensors is not None else None

                prompt_logprobs_dict = self._get_prompt_logprobs_dict(
                    hidden_states[:num_scheduled_tokens],
                    scheduler_output,
                )

                sampled_token_ids = sampler_output.sampled_token_ids
                max_gen_len = sampled_token_ids.shape[-1]
                if max_gen_len == 1:
                    valid_sampled_token_ids = sampled_token_ids.tolist()
                else:
                    valid_sampled_token_ids = self.rejection_sampler.parse_output(
                        sampled_token_ids,
                        self.input_batch.vocab_size,
                    )
                for i in discard_sampled_tokens_req_indices:
                    valid_sampled_token_ids[i].clear()

                # Append accepted tokens to the CPU-side input batch and the
                # per-request output history.
                for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
                    if not sampled_ids:
                        continue
                    start_idx = self.input_batch.num_tokens_no_spec[req_idx]
                    end_idx = start_idx + len(sampled_ids)
                    assert end_idx <= self.max_model_len, (
                        "Sampled token IDs exceed the max model length. "
                        f"Total number of tokens: {end_idx} > max_model_len: "
                        f"{self.max_model_len}")
                    self.input_batch.token_ids_cpu[req_idx,
                                                   start_idx:end_idx] = sampled_ids
                    self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                    self.input_batch.num_tokens[req_idx] = end_idx
                    req_id = self.input_batch.req_ids[req_idx]
                    req_state = self.requests[req_id]
                    req_state.output_token_ids.extend(sampled_ids)

                if not self.speculative_config:
                    spec_token_ids = None
                else:
                    assert spec_decode_common_attn_metadata is not None
                    spec_token_ids = self.propose_draft_token_ids(
                        scheduler_output,
                        valid_sampled_token_ids,
                        sampling_metadata,
                        hidden_states,
                        sample_hidden_states,
                        aux_hidden_states,
                        spec_decode_metadata,
                        spec_decode_common_attn_metadata,
                    )

                self.eplb_step()

                model_runner_output = ModelRunnerOutput(
                    req_ids=self.input_batch.req_ids,
                    req_id_to_index=self.input_batch.req_id_to_index,
                    sampled_token_ids=valid_sampled_token_ids,
                    spec_token_ids=spec_token_ids,
                    logprobs=logprobs_lists,
                    prompt_logprobs_dict=prompt_logprobs_dict,
                    pooler_output=[],
                    kv_connector_output=kv_connector_output,
                    num_nans_in_logits=num_nans_in_logits,
                )

            # Training runs outside the inference_mode block above. The early
            # returns above (non-last PP rank without broadcast, pooling)
            # skip it.
            if _BASE_RUNTIME.enabled and _RUNTIME_GROUP.has_active():
                _RUNTIME_GROUP.maybe_train()

            return model_runner_output

        # Install both replacements.
        GPUModelRunner.execute_model = wrapped_execute_model
        Worker.execute_model = worker_execute_model
        _PATCHED = True

    return True


def export_distiller_overlap_log_v1(file_path: str) -> Optional[int]:
    """Dump the per-token top-k overlap log of all runtimes to ``file_path`` (CSV).

    Returns the number of data rows written, or ``None`` for an empty path.
    """
    manager = _RUNTIME_GROUP
    return manager.dump_log(file_path)


def export_distiller_overlap_loss_log_v1(file_path: str) -> Optional[int]:
    """Dump the recorded training-loss entries of all runtimes to ``file_path`` (CSV).

    Returns the number of data rows written, or ``None`` for an empty path.
    """
    manager = _RUNTIME_GROUP
    return manager.dump_loss_log(file_path)


def export_distiller_overlap_cudagraph_stats_v1() -> dict:
    """Return CUDA-graph capture statistics aggregated over all distiller runtimes."""
    manager = _RUNTIME_GROUP
    return manager.get_cudagraph_stats()
