# SPDX-License-Identifier: Apache-2.0
"""vLLM patch for overlap-friendly token-distiller (decode-only, top-k mixing)."""

from __future__ import annotations

import os
import time
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.logits_processor import LogitsProcessor as VLLMLogitsProcessor
from vllm.worker.model_runner import ModelRunner

from .hidden_transition_distiller import TransitionMLP, _layer_norm_batch

# NOTE(review): not referenced in the visible part of this file; presumably the
# number of distiller training steps treated as warmup - confirm against users.
TRAIN_WARMUP_STEPS = 100

# Module-level patch bookkeeping.
# NOTE(review): the code that reads/writes these lives outside the visible part
# of the file. Judging by the names they presumably guard a one-time patch of
# ModelRunner.execute_model and the vLLM LogitsProcessor forward and keep the
# original callables around - confirm against the patch/unpatch functions.
_PATCH_LOCK = threading.Lock()
_PATCHED = False
_ORIG_EXECUTE_MODEL = None
_ORIG_LOGITS_FORWARD = None


def _stage_from_attn(attn_metadata) -> str:
    """Classify the current step from the token counts on ``attn_metadata``.

    Reads ``num_prefill_tokens`` and ``num_decode_tokens``; a missing or falsy
    attribute counts as 0.

    Returns:
        ``"decode"`` when only decode tokens are present, ``"prefill"`` when
        only prefill tokens are present, ``"mixed"`` when both are present and
        ``"unknown"`` in every other case (e.g. both counts are 0).
    """
    n_prefill = getattr(attn_metadata, "num_prefill_tokens", 0) or 0
    n_decode = getattr(attn_metadata, "num_decode_tokens", 0) or 0
    has_prefill = n_prefill > 0
    has_decode = n_decode > 0
    if has_prefill and has_decode:
        return "mixed"
    if has_decode and n_prefill == 0:
        return "decode"
    if has_prefill and n_decode == 0:
        return "prefill"
    return "unknown"


def _collect_decode_rows(sampling_metadata) -> List[int] | torch.Tensor:
    """Pick the entries of ``selected_token_indices`` that belong to sampled tokens.

    The sequence groups are walked in order. For each group the cursor into
    ``selected_token_indices`` first skips ``len(prompt_logprob_indices)``
    entries and then claims ``len(sample_indices)`` entries; claims are clipped
    to the number of available entries.

    Returns:
        A 1-D tensor (same device as the input) when ``selected_token_indices``
        is a tensor, otherwise a list of ints. An empty list is returned when
        the metadata, the indices or the sequence groups are missing/empty or
        no position is claimed.
    """
    if sampling_metadata is None:
        return []
    selected = getattr(sampling_metadata, "selected_token_indices", None)
    if selected is None:
        return []
    seq_groups = getattr(sampling_metadata, "seq_groups", None) or []
    if not seq_groups:
        return []

    def _sample_positions(limit: int) -> List[int]:
        # Offsets (into the flattened selection) claimed by sampled tokens.
        claimed: List[int] = []
        offset = 0
        for grp in seq_groups:
            offset += len(getattr(grp, "prompt_logprob_indices", []) or [])
            samples = getattr(grp, "sample_indices", []) or []
            if not samples:
                continue
            room = limit - offset
            if room <= 0:
                break
            n_take = min(len(samples), room)
            if n_take <= 0:
                break
            claimed.extend(range(offset, offset + n_take))
            offset += n_take
        return claimed

    if isinstance(selected, torch.Tensor):
        flat = selected.reshape(-1)
        n_entries = int(flat.numel())
        if n_entries <= 0:
            return []
        picks = _sample_positions(n_entries)
        if not picks:
            return []
        gather_idx = torch.tensor(picks, device=flat.device, dtype=torch.long)
        return flat.index_select(0, gather_idx)

    values = selected.tolist() if hasattr(selected, "tolist") else list(selected)
    if not values:
        return []
    picks = _sample_positions(len(values))
    if not picks:
        return []
    return [int(values[p]) for p in picks]


def _env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean switch from the environment.

    Recognised spellings (case-insensitive, surrounding whitespace ignored):
    ``1/true/yes/y/on`` for True and ``0/false/no/n/off`` for False. An unset
    variable or any other value yields ``bool(default)``.
    """
    raw = os.getenv(name)
    if raw is not None:
        token = str(raw).strip().lower()
        if token in ("1", "true", "yes", "y", "on"):
            return True
        if token in ("0", "false", "no", "n", "off"):
            return False
    return bool(default)


def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment, falling back to ``int(default)``.

    The fallback is used when the variable is unset or its value cannot be
    converted with ``int()``.
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw)
        except Exception:
            pass
    return int(default)


@dataclass
class _CapturedPredGraph:
    """A captured CUDA graph of one predictor forward pass plus its static tensors.

    Instances are created by ``_DistillerRuntime._maybe_capture_pred_graph`` and
    cached there under a ``(slot, count)`` key.
    """

    # The captured graph; as captured it reads ``src_view`` and writes ``pred_view``.
    graph: "torch.cuda.CUDAGraph"
    # View ``h1_buf[slot, :count]`` used as the predictor input.
    src_view: torch.Tensor
    # View ``pred_buf[slot, :count]`` that receives the predictor output.
    pred_view: torch.Tensor


@dataclass
class _CapturedTrainGraph:
    """A captured CUDA graph of one distiller training step plus its static tensors.

    Instances are created by ``_DistillerRuntime._maybe_capture_train_graph`` and
    cached there under a ``(slot, count)`` key.
    """

    # The captured graph (zero_grad, forward, loss, backward, optimizer step).
    graph: "torch.cuda.CUDAGraph"
    # View ``h1_buf[slot, :count]`` used as the predictor input.
    src_view: torch.Tensor
    # View ``hl_buf[slot, :count]`` used as the regression target.
    target_view: torch.Tensor
    # Scalar float32 tensors written by the graph on every replay:
    # total loss, mean cosine similarity between prediction and target,
    # mean L2 norm of the prediction rows and mean L2 norm of the target rows.
    loss_out: torch.Tensor
    cos_out: torch.Tensor
    pred_norm_out: torch.Tensor
    target_norm_out: torch.Tensor


class _DistillerRuntime:
    """Runtime state for overlap distiller (GPU buffers, events, logging)."""

    def __init__(self) -> None:
        """Set every runtime field to its inert default.

        No tensor, stream, event, predictor or optimizer is created here; they
        all start out as ``None``/empty. The only external input read is the
        ``VERL_OVERLAP_CUDAGRAPH*`` family of environment variables.
        """
        # Config knobs (set via update_distiller_overlap_state).
        # configure() overwrites all of these except the two batched_* fields;
        # mlp_hidden_dim is refreshed there from VERL_OVERLAP_MLP_HIDDEN_DIM.
        self.enabled = False
        self.beta = 0.0
        self.topk = 8
        self.buffer_slots = 2
        self.first_layer = 0
        self.mix_mode = "logits"
        self.train_enabled = False
        self.train_sync_interval = 1
        self.normalize_inputs = True
        self.reward_scale = False
        self.logits_rescale = False
        self.logit_loss_weight = 0.0
        self.curiosity_only = False
        self.mlp_hidden_dim: Optional[int] = None
        # When batched_params is True, _ensure_buffers leaves predictor/optimizer unset.
        self.batched_params = False
        self.batched_index: int = -1

        # Optional CUDA Graph acceleration for the distiller.
        #
        # Enable with: VERL_OVERLAP_CUDAGRAPH=1
        #   - VERL_OVERLAP_CUDAGRAPH_PRED=1   (default: 1)
        #   - VERL_OVERLAP_CUDAGRAPH_TRAIN=1  (default: 1)
        #   - VERL_OVERLAP_CUDAGRAPH_WARMUP=3 (default: 3)
        self.cudagraph_enabled = _env_flag("VERL_OVERLAP_CUDAGRAPH", False)
        self.cudagraph_capture_pred = _env_flag("VERL_OVERLAP_CUDAGRAPH_PRED", True)
        self.cudagraph_capture_train = _env_flag("VERL_OVERLAP_CUDAGRAPH_TRAIN", True)
        # Negative warmup counts are clamped to 0.
        self.cudagraph_warmup_iters = max(0, _env_int("VERL_OVERLAP_CUDAGRAPH_WARMUP", 3))
        # Memory pool handle shared by captured graphs (created lazily).
        self._cudagraph_pool = None
        # Captured graphs, keyed by (slot, count).
        self._pred_graphs: Dict[Tuple[int, int], _CapturedPredGraph] = {}
        self._train_graphs: Dict[Tuple[int, int], _CapturedTrainGraph] = {}
        # Last value of _cudagraph_signature(); a change triggers a cache clear in configure().
        self._cudagraph_sig: Optional[Tuple[object, ...]] = None
        # Guards the graph caches. This is a plain, non-reentrant Lock.
        self._cudagraph_lock = threading.Lock()
        # Capture statistics (perf_counter timestamps, counts and durations in seconds),
        # reported by get_cudagraph_stats().
        self._capture_first_s: Optional[float] = None
        self._capture_last_s: Optional[float] = None
        self._capture_pred_count = 0
        self._capture_train_count = 0
        self._capture_pred_total_s = 0.0
        self._capture_train_total_s = 0.0
        self._capture_pred_last_s = 0.0
        self._capture_train_last_s = 0.0

        # Per-step state for the current decode micro-batch.
        self.step = 0
        self.current_slot: Optional[int] = None
        self.current_rows: Optional[torch.Tensor] | List[int] = None
        self.current_count = 0
        self.current_active = False
        self.current_stream: Optional[torch.cuda.Stream] = None
        self.current_req_id: int = -1

        # CUDA stream and per-slot events for overlap synchronization.
        # Both are created by configure() when CUDA is available.
        self.stream: Optional[torch.cuda.Stream] = None
        self.events = {}

        # Ring buffers on GPU: [slots, max_rows, hidden_dim].
        self.hidden_dim: Optional[int] = None
        self.max_rows = 0
        self.h1_buf: Optional[torch.Tensor] = None
        self.hl_buf: Optional[torch.Tensor] = None
        self.pred_buf: Optional[torch.Tensor] = None
        # Top-k logit targets: [slots, logit_max_rows, logit_topk] (see _ensure_logit_buffers).
        self.logit_id_buf: Optional[torch.Tensor] = None
        self.logit_val_buf: Optional[torch.Tensor] = None
        self.logit_topk = 0
        self.logit_max_rows = 0
        # One flag per slot; re-created by configure() when the slot count changes.
        self.logit_ready: List[bool] = []
        self.lm_head_weight: Optional[torch.Tensor] = None
        self.lm_head_bias: Optional[torch.Tensor] = None

        # Distiller model and optimizer live on GPU.
        self.predictor: Optional[TransitionMLP] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.train_steps = 0
        self.train_tick = 0

        # In-GPU ring buffer for mix logging (steps/rows/top-k values).
        self.log_capacity = 32768
        self.log_pos = 0
        self.log_topk = 0
        self.log_step_buf: Optional[torch.Tensor] = None
        self.log_row_buf: Optional[torch.Tensor] = None
        self.log_req_buf: Optional[torch.Tensor] = None
        self.log_id_buf: Optional[torch.Tensor] = None
        self.log_pre_buf: Optional[torch.Tensor] = None
        self.log_dist_buf: Optional[torch.Tensor] = None
        self.log_post_buf: Optional[torch.Tensor] = None

        # Ring buffer for training-loss logging.
        # NOTE(review): the allocation/writer code is outside the visible part of the file.
        self.loss_capacity = 32768
        self.loss_pos = 0
        self.loss_step_buf: Optional[torch.Tensor] = None
        self.loss_train_step_buf: Optional[torch.Tensor] = None
        self.loss_val_buf: Optional[torch.Tensor] = None
        self.loss_cos_buf: Optional[torch.Tensor] = None
        self.loss_pred_norm_buf: Optional[torch.Tensor] = None
        self.loss_target_norm_buf: Optional[torch.Tensor] = None
        self.loss_req_buf: Optional[torch.Tensor] = None

        # Cached lm_head inverse norms for reward-scale mixing.
        self.weight_inv_norms: Optional[torch.Tensor] = None
        self.weight_inv_norms_device: Optional[torch.device] = None
        self.weight_inv_norms_dtype: Optional[torch.dtype] = None

        # Cached final norm params (copied to FP32 for stability).
        self.final_norm_weight: Optional[torch.Tensor] = None
        self.final_norm_bias: Optional[torch.Tensor] = None
        self.final_norm_eps: float = 1e-6
        self.final_norm_type: str = "rms"

    def _sync_predictor_norm(self) -> None:
        if self.predictor is None:
            return
        self.predictor.set_output_norm(
            weight=self.final_norm_weight,
            bias=self.final_norm_bias,
            eps=self.final_norm_eps,
            norm_type=self.final_norm_type,
        )

    def configure(
        self,
        *,
        enabled: bool,
        beta: float,
        topk: int,
        buffer_slots: int,
        first_layer: int,
        mix_mode: str,
        train_enabled: bool,
        train_sync_interval: int,
        normalize_inputs: bool,
        reward_scale: bool,
        logits_rescale: bool,
        logit_loss_weight: float,
        curiosity_only: bool,
    ) -> None:
        """Apply a new configuration and make sure the stream and events exist.

        All arguments are coerced to their canonical type. ``topk``,
        ``buffer_slots`` and ``train_sync_interval`` are clamped to at least 1,
        and any ``mix_mode`` other than ``"logits"``/``"hidden"`` becomes
        ``"logits"``. The MLP width and the CUDA-graph switches are re-read
        from the environment; when the resulting graph signature differs from
        the previously stored one, all captured graphs are dropped.
        """
        # Plain knobs, coerced and clamped.
        self.enabled = bool(enabled)
        self.beta = float(beta)
        self.topk = max(1, int(topk))
        self.buffer_slots = max(1, int(buffer_slots))
        self.first_layer = int(first_layer)
        requested_mode = str(mix_mode or "logits").lower()
        self.mix_mode = requested_mode if requested_mode in ("logits", "hidden") else "logits"
        self.train_enabled = bool(train_enabled)
        self.train_sync_interval = max(1, int(train_sync_interval))
        self.normalize_inputs = bool(normalize_inputs)
        self.reward_scale = bool(reward_scale)
        self.logits_rescale = bool(logits_rescale)
        self.logit_loss_weight = float(logit_loss_weight)
        self.curiosity_only = bool(curiosity_only)

        # A non-positive env value means "no explicit MLP width".
        env_mlp_width = _env_int("VERL_OVERLAP_MLP_HIDDEN_DIM", 0)
        self.mlp_hidden_dim = int(env_mlp_width) if env_mlp_width > 0 else None
        if self.mix_mode != "logits":
            self.log_pos = 0

        # Re-read the CUDA graph switches; the current value is the fallback,
        # so an unset variable leaves the knob untouched.
        old_sig = self._cudagraph_sig
        for attr, env_name in (
            ("cudagraph_enabled", "VERL_OVERLAP_CUDAGRAPH"),
            ("cudagraph_capture_pred", "VERL_OVERLAP_CUDAGRAPH_PRED"),
            ("cudagraph_capture_train", "VERL_OVERLAP_CUDAGRAPH_TRAIN"),
        ):
            setattr(self, attr, _env_flag(env_name, getattr(self, attr)))
        warmup = _env_int("VERL_OVERLAP_CUDAGRAPH_WARMUP", self.cudagraph_warmup_iters)
        self.cudagraph_warmup_iters = max(0, warmup)

        # Captured graphs are only valid for the signature they were built under.
        fresh_sig = self._cudagraph_signature()
        if old_sig is not None and fresh_sig != old_sig:
            self._clear_cudagraph_cache()
        self._cudagraph_sig = fresh_sig

        if self.stream is None and torch.cuda.is_available():
            self.stream = torch.cuda.Stream()
        self._ensure_events()

        # A different slot count invalidates the per-slot logit-target state.
        slot_count_changed = len(self.logit_ready) != self.buffer_slots
        if slot_count_changed:
            self.logit_ready = [False] * self.buffer_slots
            self.logit_id_buf = None
            self.logit_val_buf = None
            self.logit_topk = 0
            self.logit_max_rows = 0

        predictor = self.predictor
        if predictor is not None:
            predictor.train(self.train_enabled)

        if not self.reward_scale:
            # Reward scaling is off: drop the cached inverse norms.
            self.weight_inv_norms = None
            self.weight_inv_norms_device = None
            self.weight_inv_norms_dtype = None

    def _cudagraph_signature(self) -> Tuple[object, ...]:
        """Return a signature of settings that affect captured graphs."""
        return (
            bool(self.cudagraph_enabled),
            bool(self.cudagraph_capture_pred),
            bool(self.cudagraph_capture_train),
            int(self.cudagraph_warmup_iters),
            bool(self.normalize_inputs),
            bool(self.train_enabled),
            int(self.topk),
            str(self.final_norm_type),
            float(self.final_norm_eps),
            bool(self.final_norm_weight is not None),
            float(self.logit_loss_weight),
            bool(self.curiosity_only),
        )

    def _clear_cudagraph_cache(self) -> None:
        # Drop all captured graphs (e.g. when buffers/model/config change).
        with self._cudagraph_lock:
            self._pred_graphs.clear()
            self._train_graphs.clear()
            self._cudagraph_pool = None

    def _note_capture(self, kind: str, duration_s: float) -> None:
        now = time.perf_counter()
        if self._capture_first_s is None:
            self._capture_first_s = now
        self._capture_last_s = now
        if kind == "pred":
            self._capture_pred_count += 1
            self._capture_pred_total_s += float(duration_s)
            self._capture_pred_last_s = float(duration_s)
        elif kind == "train":
            self._capture_train_count += 1
            self._capture_train_total_s += float(duration_s)
            self._capture_train_last_s = float(duration_s)

    def get_cudagraph_stats(self) -> dict:
        return {
            "pred_captures": int(self._capture_pred_count),
            "train_captures": int(self._capture_train_count),
            "pred_capture_total_s": float(self._capture_pred_total_s),
            "train_capture_total_s": float(self._capture_train_total_s),
            "pred_capture_last_s": float(self._capture_pred_last_s),
            "train_capture_last_s": float(self._capture_train_last_s),
            "first_capture_s": self._capture_first_s,
            "last_capture_s": self._capture_last_s,
        }

    def _ensure_cudagraph_pool(self) -> None:
        if self._cudagraph_pool is None:
            self._cudagraph_pool = torch.cuda.graph_pool_handle()

    def _maybe_init_buffers_for_cudagraph(
        self,
        *,
        device: torch.device,
        max_rows: int,
    ) -> None:
        """Best-effort pre-init of buffers so we can capture graphs before first hook."""
        if self.h1_buf is not None and self.predictor is not None:
            return
        hidden_dim = self.hidden_dim
        if hidden_dim is None and isinstance(self.final_norm_weight, torch.Tensor):
            hidden_dim = int(self.final_norm_weight.numel())
        if hidden_dim is None or hidden_dim <= 0:
            return
        self._ensure_buffers(hidden_dim, max_rows, device)

    def _maybe_capture_pred_graph(self, *, slot: int, count: int, device: torch.device) -> None:
        """Capture, at most once per ``(slot, count)``, a CUDA graph of the predictor forward.

        The captured work reads ``h1_buf[slot, :count]``, optionally applies
        ``_layer_norm_batch`` (when ``normalize_inputs`` is set), runs the
        predictor and copies the result into ``pred_buf[slot, :count]``.

        The call is a no-op when pred capture is disabled, CUDA is unavailable,
        the current stream is already capturing, required state is missing,
        ``count <= 0`` or the key is already cached.

        Side effect: warmup overwrites ``h1_buf[slot, :count]`` with random data.

        Any exception is swallowed (printed only when
        ``VERL_OVERLAP_CUDAGRAPH_DEBUG`` is set) and switches
        ``cudagraph_capture_pred`` off for this runtime.

        Args:
            slot: Index into the first dimension of the ring buffers.
            count: Number of leading rows of the slot covered by the graph.
            device: Device passed to ``torch.cuda.synchronize``.
        """
        if not (self.cudagraph_enabled and self.cudagraph_capture_pred):
            return
        if not torch.cuda.is_available():
            return
        if torch.cuda.is_current_stream_capturing():
            return
        if self.stream is None or self.predictor is None or self.h1_buf is None or self.pred_buf is None:
            return
        if count <= 0:
            return
        key = (int(slot), int(count))
        # Cheap unlocked check first; re-checked under the lock below.
        if key in self._pred_graphs:
            return
        with self._cudagraph_lock:
            if key in self._pred_graphs:
                return
            try:
                capture_start = time.perf_counter()
                self._ensure_cudagraph_pool()
                # IMPORTANT: capture requires a quiet device. We only do this once per key.
                torch.cuda.synchronize(device)

                # Static views: the graph is bound to these exact storage locations.
                src_view = self.h1_buf[slot, :count]
                pred_view = self.pred_buf[slot, :count]

                def _run_pred() -> None:
                    src = src_view
                    if self.normalize_inputs:
                        src = _layer_norm_batch(src)
                    pred = self.predictor(src)
                    pred_view.copy_(pred)

                # Warmup to populate allocator + kernel caches.
                with torch.cuda.stream(self.stream):
                    with torch.inference_mode(False):
                        # Give src some non-degenerate values (avoid all-zero LN corner cases).
                        src_view.copy_(torch.randn_like(src_view))
                        for _ in range(self.cudagraph_warmup_iters):
                            _run_pred()
                torch.cuda.synchronize(device)

                # Record one forward pass on the side stream into the shared pool.
                g = torch.cuda.CUDAGraph()
                with torch.cuda.stream(self.stream):
                    with torch.inference_mode(False):
                        with torch.cuda.graph(g, pool=self._cudagraph_pool):
                            _run_pred()
                self._pred_graphs[key] = _CapturedPredGraph(
                    graph=g,
                    src_view=src_view,
                    pred_view=pred_view,
                )
                self._note_capture("pred", time.perf_counter() - capture_start)
            except Exception as e:
                # Fall back to eager for this runtime if capture fails.
                if _env_flag("VERL_OVERLAP_CUDAGRAPH_DEBUG", False):
                    import traceback

                    print(
                        f"[verl.overlap] CUDA graph capture (pred) failed for slot={slot} "
                        f"count={count}: {e}"
                    )
                    traceback.print_exc()
                self.cudagraph_capture_pred = False
                # Keep the stored signature in step with the flag just changed.
                self._cudagraph_sig = self._cudagraph_signature()

    def _maybe_capture_train_graph(self, *, slot: int, count: int, device: torch.device) -> None:
        if not (self.cudagraph_enabled and self.cudagraph_capture_train and self.train_enabled):
            return
        if not torch.cuda.is_available():
            return
        if torch.cuda.is_current_stream_capturing():
            return
        if (
            self.stream is None
            or self.predictor is None
            or self.optimizer is None
            or self.h1_buf is None
            or self.hl_buf is None
        ):
            return
        if self.logit_loss_weight > 0.0 and not isinstance(self.lm_head_weight, torch.Tensor):
            return
        if count <= 0:
            return
        key = (int(slot), int(count))
        if key in self._train_graphs:
            return
        with self._cudagraph_lock:
            if key in self._train_graphs:
                return
            try:
                capture_start = time.perf_counter()
                self._ensure_cudagraph_pool()
                torch.cuda.synchronize(device)

                src_view = self.h1_buf[slot, :count]
                target_view = self.hl_buf[slot, :count]
                use_logit_loss = self.logit_loss_weight > 0.0 and self.train_enabled
                logit_topk = 0
                logit_id_view = None
                logit_val_view = None
                weight_det = None
                bias_det = None
                if use_logit_loss:
                    weight = self.lm_head_weight
                    if not isinstance(weight, torch.Tensor):
                        use_logit_loss = False
                    else:
                        logit_topk = max(1, int(self.topk))
                        self._ensure_logit_buffers(topk=logit_topk, max_rows=count, device=device)
                        if self.logit_id_buf is None or self.logit_val_buf is None:
                            use_logit_loss = False
                        else:
                            logit_id_view = self.logit_id_buf[slot, :count, :logit_topk]
                            logit_val_view = self.logit_val_buf[slot, :count, :logit_topk]
                            weight_det = weight.detach()
                            if isinstance(self.lm_head_bias, torch.Tensor):
                                bias_det = self.lm_head_bias.detach()
                # Static outputs for metrics (written by graph, read by Python for logging).
                loss_out = torch.empty((), device=device, dtype=torch.float32)
                cos_out = torch.empty((), device=device, dtype=torch.float32)
                pred_norm_out = torch.empty((), device=device, dtype=torch.float32)
                target_norm_out = torch.empty((), device=device, dtype=torch.float32)

                def _run_train() -> None:
                    # NOTE: use set_to_none=False for graph safety.
                    self.optimizer.zero_grad(set_to_none=False)
                    src = src_view
                    if self.normalize_inputs:
                        src = _layer_norm_batch(src)
                    pred = self.predictor(src)
                    tgt = target_view
                    loss = F.mse_loss(pred, tgt)
                    if use_logit_loss and logit_id_view is not None and logit_val_view is not None:
                        flat_ids = logit_id_view.to(torch.long).reshape(-1)
                        weight_rows = weight_det.index_select(0, flat_ids).view(count, logit_topk, -1)
                        pred_logits = pred
                        if pred_logits.dtype != weight_rows.dtype:
                            pred_logits = pred_logits.to(dtype=weight_rows.dtype)
                        dist_topk = torch.bmm(weight_rows, pred_logits.unsqueeze(-1)).squeeze(-1)
                        if bias_det is not None:
                            bias_rows = bias_det.index_select(0, flat_ids).view(count, logit_topk)
                            dist_topk = dist_topk + bias_rows.to(dist_topk.dtype)
                        logit_loss = F.mse_loss(dist_topk.float(), logit_val_view.float())
                        loss = loss + self.logit_loss_weight * logit_loss

                    # Detach metrics so they don't create backward graph branches.
                    pred_det = pred.detach()
                    tgt_det = tgt.detach()
                    loss_out.copy_(loss.detach())
                    cos_out.copy_(F.cosine_similarity(pred_det, tgt_det, dim=-1).mean())
                    pred_norm_out.copy_(pred_det.norm(dim=-1).mean())
                    target_norm_out.copy_(tgt_det.norm(dim=-1).mean())

                    loss.backward()
                    self.optimizer.step()

                with torch.cuda.stream(self.stream):
                    # Capture needs autograd on even if the caller is in no_grad/inference_mode.
                    with torch.inference_mode(False), torch.enable_grad():
                        # Seed buffers to avoid NaNs in warmup metrics.
                        src_view.copy_(torch.randn_like(src_view))
                        target_view.copy_(torch.randn_like(target_view))
                        if use_logit_loss and logit_id_view is not None and logit_val_view is not None:
                            logit_id_view.zero_()
                            logit_val_view.copy_(torch.randn_like(logit_val_view))
                        for _ in range(self.cudagraph_warmup_iters):
                            _run_train()
                torch.cuda.synchronize(device)

                g = torch.cuda.CUDAGraph()
                with torch.cuda.stream(self.stream):
                    with torch.inference_mode(False), torch.enable_grad():
                        with torch.cuda.graph(g, pool=self._cudagraph_pool):
                            _run_train()

                self._train_graphs[key] = _CapturedTrainGraph(
                    graph=g,
                    src_view=src_view,
                    target_view=target_view,
                    loss_out=loss_out,
                    cos_out=cos_out,
                    pred_norm_out=pred_norm_out,
                    target_norm_out=target_norm_out,
                )
                self._note_capture("train", time.perf_counter() - capture_start)
            except Exception as e:
                if _env_flag("VERL_OVERLAP_CUDAGRAPH_DEBUG", False):
                    import traceback

                    print(
                        f"[verl.overlap] CUDA graph capture (train) failed for slot={slot} "
                        f"count={count}: {e}"
                    )
                    traceback.print_exc()
                self.cudagraph_capture_train = False
                self._cudagraph_sig = self._cudagraph_signature()

    def _ensure_events(self) -> None:
        # Create per-slot CUDA events used to synchronize buffers/streams.
        if not torch.cuda.is_available():
            return
        if self.events.get("slots") == self.buffer_slots:
            return
        # events: h1_ready/pred_ready/hl_ready/slot_free per slot.
        self.events = {
            "slots": self.buffer_slots,
            "h1_ready": [torch.cuda.Event() for _ in range(self.buffer_slots)], # hidden layer 1 ready
            "pred_ready": [torch.cuda.Event() for _ in range(self.buffer_slots)],
            "hl_ready": [torch.cuda.Event() for _ in range(self.buffer_slots)], # hidden layer Last ready
            "logits_ready": [torch.cuda.Event() for _ in range(self.buffer_slots)],
            "slot_free": [torch.cuda.Event() for _ in range(self.buffer_slots)],
        }
        for ev in self.events["slot_free"]:
            ev.record(torch.cuda.current_stream())

    def _ensure_buffers(self, hidden_dim: int, max_rows: int, device: torch.device) -> None:
        # Allocate/resize GPU buffers and distiller model as needed.
        if self.hidden_dim == hidden_dim and self.max_rows >= max_rows and self.h1_buf is not None:
            return
        # Any re-allocation invalidates captured CUDA graphs because tensor storage pointers change.
        self._clear_cudagraph_cache()
        self.hidden_dim = hidden_dim
        self.max_rows = max(max_rows, 1)
        shape = (self.buffer_slots, self.max_rows, hidden_dim)
        with torch.inference_mode(False):
            self.h1_buf = torch.empty(shape, device=device, dtype=torch.float32)
            self.hl_buf = torch.empty(shape, device=device, dtype=torch.float32)
            self.pred_buf = torch.empty(shape, device=device, dtype=torch.float32)

            if not self.batched_params:
                self.predictor = TransitionMLP(
                    hidden_dim,
                    hidden_dim,
                    mlp_hidden_dim=self.mlp_hidden_dim,
                ).to(device)
                self._sync_predictor_norm()
            self.predictor.train(self.train_enabled)
                # Use fused Adam if available for lower dispatch overhead.
                # When CUDA graph capture is enabled, prefer capturable optimizers.
                want_capturable = bool(
                    self.cudagraph_enabled and self.cudagraph_capture_train and torch.cuda.is_available()
                )
                try:
                    self.optimizer = torch.optim.Adam(
                        self.predictor.parameters(),
                        lr=5e-4,
                        fused=True,
                        capturable=want_capturable,
                    )
                except Exception:
                    try:
                        self.optimizer = torch.optim.Adam(
                            self.predictor.parameters(),
                            lr=5e-4,
                            capturable=want_capturable,
                        )
                    except Exception:
                        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=5e-4)
            else:
                self.predictor = None
                self.optimizer = None

    def set_final_norm(
        self,
        *,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        eps: float,
        norm_type: str,
    ) -> None:
        # Cache final norm parameters on GPU for pred normalization.
        # Updating these parameters changes tensor storage pointers -> invalidate captured graphs.
        self._clear_cudagraph_cache()
        if weight is None:
            self.final_norm_weight = None
            self.final_norm_bias = None
            self.final_norm_eps = float(eps)
            self.final_norm_type = norm_type
            self._sync_predictor_norm()
            return
        with torch.inference_mode(False):
            self.final_norm_weight = weight.detach().clone().to(torch.float32)
            if bias is not None:
                self.final_norm_bias = bias.detach().clone().to(torch.float32)
            else:
                self.final_norm_bias = None
        self.final_norm_eps = float(eps)
        self.final_norm_type = norm_type
        self._sync_predictor_norm()

    def set_reward_scale_norms(
        self,
        *,
        weight: Optional[torch.Tensor],
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        if weight is None or not self.reward_scale:
            self.weight_inv_norms = None
            self.weight_inv_norms_device = None
            self.weight_inv_norms_dtype = None
            return
        target_dtype = dtype or weight.dtype
        with torch.inference_mode(False):
            norms = weight.norm(dim=1).to(dtype=torch.float32) + 1e-6
            self.weight_inv_norms = norms.reciprocal().to(
                device=weight.device,
                dtype=target_dtype,
            )
            self.weight_inv_norms_device = weight.device
            self.weight_inv_norms_dtype = target_dtype

    def _apply_final_norm(self, hidden: torch.Tensor) -> torch.Tensor:
        # Apply cached final norm (prefer torch native kernels).
        weight = self.final_norm_weight
        if weight is None:
            return hidden
        eps = self.final_norm_eps
        weight = weight.to(device=hidden.device, dtype=hidden.dtype)
        bias = self.final_norm_bias
        if bias is not None:
            bias = bias.to(device=hidden.device, dtype=hidden.dtype)
        if self.final_norm_type == "layer":
            out = F.layer_norm(
                hidden,
                (hidden.shape[-1],),
                weight=weight,
                bias=bias,
                eps=eps,
            )
            return out
        if self.final_norm_type == "rms" and hasattr(F, "rms_norm"):
            out = F.rms_norm(
                hidden,
                (hidden.shape[-1],),
                weight=weight,
                eps=eps,
            )
            if bias is not None:
                out = out + bias
            return out
        if self.final_norm_type == "layer":
            mean = hidden.mean(dim=-1, keepdim=True)
            var = hidden.var(dim=-1, unbiased=False, keepdim=True)
            out = (hidden - mean) * torch.rsqrt(var + eps)
        else:
            var = hidden.pow(2).mean(dim=-1, keepdim=True)
            out = hidden * torch.rsqrt(var + eps)
        out = out * weight
        if bias is not None:
            out = out + bias
        return out

    def _ensure_log_buffers(self, topk: int, device: torch.device) -> None:
        """Allocate the mix-log ring buffers unless a matching set already exists.

        A set "matches" when the id buffer exists, was allocated for the same
        ``topk``, and lives on ``device``. Any (re)allocation resets the write
        cursor ``log_pos`` to 0.
        """
        existing = self.log_id_buf
        if existing is not None and self.log_topk == topk and existing.device == device:
            return

        self.log_topk = topk
        self.log_pos = 0
        rows = self.log_capacity

        def _new(shape, dtype):
            return torch.empty(shape, device=device, dtype=dtype)

        # Allocate outside inference mode.
        with torch.inference_mode(False):
            # Per-entry scalars are [log_capacity]; ids/values are [log_capacity, topk].
            self.log_step_buf = _new((rows,), torch.int32)
            self.log_row_buf = _new((rows,), torch.int32)
            self.log_req_buf = _new((rows,), torch.int32)
            self.log_id_buf = _new((rows, topk), torch.int32)
            self.log_pre_buf = _new((rows, topk), torch.float16)
            self.log_dist_buf = _new((rows, topk), torch.float16)
            self.log_post_buf = _new((rows, topk), torch.float16)

    def _ensure_logit_buffers(self, *, topk: int, max_rows: int, device: torch.device) -> None:
        """Make sure the per-slot top-k logit target buffers can hold this step.

        Existing buffers are kept when both exist, were built for the same
        ``topk``, have at least ``max_rows`` rows, and live on ``device``.
        Otherwise ``_clear_cudagraph_cache()`` is called and new
        ``[buffer_slots, max_rows, topk]`` buffers are allocated (int32 ids,
        float32 values), with the row count clamped to at least 1.
        """
        id_buf = self.logit_id_buf
        reusable = (
            id_buf is not None
            and self.logit_val_buf is not None
            and self.logit_topk == topk
            and self.logit_max_rows >= max_rows
            and id_buf.device == device
        )
        if reusable:
            return

        self._clear_cudagraph_cache()
        self.logit_topk = int(topk)
        self.logit_max_rows = max(int(max_rows), 1)
        dims = (self.buffer_slots, self.logit_max_rows, self.logit_topk)
        with torch.inference_mode(False):
            self.logit_id_buf = torch.empty(dims, device=device, dtype=torch.int32)
            self.logit_val_buf = torch.empty(dims, device=device, dtype=torch.float32)

    def _record_logit_targets(
        self,
        *,
        top_ids: torch.Tensor,
        top_vals: torch.Tensor,
        weight: torch.Tensor,
        embedding_bias: Optional[torch.Tensor],
        stream: torch.cuda.Stream,
        device: torch.device,
    ) -> None:
        """Store the current step's top-k logit ids/values as logit-loss targets.

        Does nothing unless training is enabled with a positive
        ``logit_loss_weight``, a slot is claimed, and ``top_ids`` is non-empty.
        On success the slot is flagged in ``logit_ready``, the detached LM-head
        weight (and bias, when it is a tensor) are cached, and the slot's
        ``logits_ready`` event is recorded on ``stream``.
        """
        if not self.train_enabled or self.logit_loss_weight <= 0.0:
            return
        if self.current_slot is None or top_ids.numel() == 0:
            return

        num_rows = top_ids.size(0)
        self._ensure_logit_buffers(topk=top_ids.size(1), max_rows=num_rows, device=device)
        id_buf = self.logit_id_buf
        val_buf = self.logit_val_buf
        if id_buf is None or val_buf is None:
            return

        slot = self.current_slot
        id_buf[slot, :num_rows, : top_ids.size(1)].copy_(top_ids.to(torch.int32))
        val_buf[slot, :num_rows, : top_vals.size(1)].copy_(top_vals.to(torch.float32))
        self.logit_ready[slot] = True
        self.lm_head_weight = weight.detach()
        if isinstance(embedding_bias, torch.Tensor):
            self.lm_head_bias = embedding_bias.detach()
        else:
            self.lm_head_bias = None
        self.events["logits_ready"][slot].record(stream)

    def record_logits(
        self,
        logits: torch.Tensor,
        *,
        lm_head,
        embedding_bias: Optional[torch.Tensor],
    ) -> None:
        """Capture top-k logits of the current decode rows as logit-loss targets.

        Returns without doing anything when logit-loss training is off, no step
        is active, ``logits`` is not a tensor, there are no decode rows, or
        ``lm_head`` has no tensor ``weight``.
        """
        if not self.train_enabled or self.logit_loss_weight <= 0.0:
            return
        if self.current_slot is None or not self.current_active:
            return
        if not isinstance(logits, torch.Tensor):
            return

        decode_rows = self._row_tensor(logits.device)
        if decode_rows is None or decode_rows.numel() == 0:
            return
        head_weight = getattr(lm_head, "weight", None)
        if not isinstance(head_weight, torch.Tensor):
            return

        active_stream = torch.cuda.current_stream(logits.device)
        selected = logits.index_select(0, decode_rows)
        k = min(self.topk, selected.shape[-1])
        best_vals, best_ids = torch.topk(selected, k=k, dim=-1)
        self._record_logit_targets(
            top_ids=best_ids,
            top_vals=best_vals,
            weight=head_weight,
            embedding_bias=embedding_bias,
            stream=active_stream,
            device=logits.device,
        )

    def _record_mix_log(
        self,
        *,
        row_idx: torch.Tensor,
        top_ids: torch.Tensor,
        pre_vals: torch.Tensor,
        dist_vals: torch.Tensor,
        post_vals: torch.Tensor,
        device: torch.device,
    ) -> None:
        """Append one log entry per row to the mix-log ring buffer.

        Args:
            row_idx: Row index of each entry (one per row of ``top_ids``).
            top_ids: ``[count, topk]`` token ids that were mixed.
            pre_vals: Values before mixing, same shape as ``top_ids``.
            dist_vals: Distiller values, same shape as ``top_ids``.
            post_vals: Values after mixing, same shape as ``top_ids``.
            device: Device on which the index tensor and buffers live.

        The buffers are (re)allocated through ``_ensure_log_buffers`` when they
        are missing or were built for a different top-k width. When a single
        call carries more rows than ``log_capacity``, only the newest
        ``log_capacity`` rows are written, so the ring indices stay unique.
        ``log_pos`` still advances by the full row count.
        """
        capacity = self.log_capacity
        if capacity <= 0:
            return
        width = top_ids.size(1)
        if (
            self.log_step_buf is None
            or self.log_id_buf is None
            or self.log_id_buf.size(1) != width
        ):
            # A width mismatch would make index_copy_ below fail; reallocate instead.
            self._ensure_log_buffers(width, device)
        if self.log_step_buf is None:
            return
        count = top_ids.size(0)
        if count <= 0:
            return

        start = self.log_pos
        if count > capacity:
            # Older rows of this batch would be overwritten anyway; drop them so
            # the ring indices are unique (index_copy_ with duplicate indices
            # has no defined result).
            skip = count - capacity
            row_idx = row_idx[skip:]
            top_ids = top_ids[skip:]
            pre_vals = pre_vals[skip:]
            dist_vals = dist_vals[skip:]
            post_vals = post_vals[skip:]
            start += skip
        kept = top_ids.size(0)

        # Modulo indexing handles wrap-around without splitting the write.
        idx = (torch.arange(kept, device=device, dtype=torch.long) + start % capacity) % capacity

        # Scalars are broadcast to every selected entry.
        self.log_step_buf[idx] = int(self.step)
        self.log_row_buf.index_copy_(0, idx, row_idx.to(dtype=torch.int32))
        if self.log_req_buf is not None:
            self.log_req_buf[idx] = int(self.current_req_id)

        self.log_id_buf.index_copy_(0, idx, top_ids.to(dtype=torch.int32))
        self.log_pre_buf.index_copy_(0, idx, pre_vals.to(dtype=torch.float16))
        self.log_dist_buf.index_copy_(0, idx, dist_vals.to(dtype=torch.float16))
        self.log_post_buf.index_copy_(0, idx, post_vals.to(dtype=torch.float16))
        self.log_pos += count

    def _ensure_loss_buffers(self, device: torch.device) -> None:
        """Allocate the loss-log ring buffers on ``device`` if not already there.

        Buffers are kept when ``loss_val_buf`` exists on ``device``; otherwise
        all seven ``[loss_capacity]`` buffers are allocated anew and the write
        cursor ``loss_pos`` is reset to 0.
        """
        current = self.loss_val_buf
        if current is not None and current.device == device:
            return

        self.loss_pos = 0
        length = (self.loss_capacity,)

        def _ints():
            return torch.empty(length, device=device, dtype=torch.int32)

        def _floats():
            return torch.empty(length, device=device, dtype=torch.float32)

        with torch.inference_mode(False):
            self.loss_step_buf = _ints()
            self.loss_train_step_buf = _ints()
            self.loss_val_buf = _floats()
            self.loss_cos_buf = _floats()
            self.loss_pred_norm_buf = _floats()
            self.loss_target_norm_buf = _floats()
            self.loss_req_buf = _ints()

    def _record_loss(
        self,
        loss: torch.Tensor,
        *,
        cos_val: Optional[torch.Tensor],
        pred_norm: Optional[torch.Tensor],
        target_norm: Optional[torch.Tensor],
        step: int,
        train_step: int,
        device: torch.device,
    ) -> None:
        """Write one training-loss record into the loss ring buffer.

        Args:
            loss: Loss value for this record (stored detached).
            cos_val: Optional cosine-similarity metric.
            pred_norm: Optional prediction-norm metric.
            target_norm: Optional target-norm metric.
            step: Step counter stored with the record.
            train_step: Training-step counter stored with the record.
            device: Device used if the buffers still have to be allocated.

        A metric passed as ``None`` is stored as 0.0. The buffers come from
        ``torch.empty`` and are reused as a ring, so leaving the slot untouched
        would later report uninitialized or stale data as a real value.
        """
        if self.loss_capacity <= 0:
            return
        if self.loss_val_buf is None:
            self._ensure_loss_buffers(device)
        if self.loss_val_buf is None:
            return
        offset = self.loss_pos % self.loss_capacity

        # Scalars are assigned straight into the buffer slot.
        self.loss_step_buf[offset] = int(step)
        self.loss_train_step_buf[offset] = int(train_step)
        if self.loss_req_buf is not None:
            self.loss_req_buf[offset] = int(self.current_req_id)

        self.loss_val_buf[offset] = loss.detach()
        optional_metrics = (
            (self.loss_cos_buf, cos_val),
            (self.loss_pred_norm_buf, pred_norm),
            (self.loss_target_norm_buf, target_norm),
        )
        for buf, value in optional_metrics:
            if buf is None:
                continue
            buf[offset] = 0.0 if value is None else value.detach()
        self.loss_pos += 1

    def _collect_log_arrays(self):
        """Copy the mix-log ring buffer to CPU lists, oldest entry first.

        Returns ``None`` when the buffers were never allocated; otherwise a
        dict with ``entries``, ``topk`` and the per-entry lists ``step``,
        ``row``, ``req``, ``ids``, ``pre``, ``dist`` and ``post``. ``req`` is
        filled with -1 when no request-id buffer exists.
        """
        if self.log_step_buf is None or self.log_id_buf is None:
            return None

        capacity = self.log_capacity
        written = self.log_pos
        entries = min(written, capacity)
        topk = self.log_topk or int(self.log_id_buf.size(1))
        if entries <= 0:
            empty = {"entries": 0, "topk": topk}
            for key in ("step", "row", "req", "ids", "pre", "dist", "post"):
                empty[key] = []
            return empty

        order = torch.arange(entries, device=self.log_step_buf.device, dtype=torch.long)
        if written > capacity:
            # The ring has wrapped: the oldest entry sits at the write cursor.
            order = (order + written % capacity) % capacity

        def _ordered(buf):
            return buf.index_select(0, order).cpu().tolist()

        steps = _ordered(self.log_step_buf)
        rows = _ordered(self.log_row_buf)
        reqs = [-1] * entries if self.log_req_buf is None else _ordered(self.log_req_buf)
        ids = _ordered(self.log_id_buf)
        pre = _ordered(self.log_pre_buf)
        dist = _ordered(self.log_dist_buf)
        post = _ordered(self.log_post_buf)
        return {
            "entries": entries,
            "topk": topk,
            "step": steps,
            "row": rows,
            "req": reqs,
            "ids": ids,
            "pre": pre,
            "dist": dist,
            "post": post,
        }

    def dump_log(self, file_path: str) -> int:
        """Write the mix log to ``file_path`` as CSV and return the rows written.

        An empty ``file_path`` writes nothing and returns 0. With no logged
        entries the file holds only the short ``step,req_id,row`` header.
        Otherwise each row holds step, request id and row index followed by
        the ``topk`` ids, pre-mix, distiller and post-mix values.
        """
        if not file_path:
            return 0

        data = self._collect_log_arrays()
        entries = 0 if data is None else data["entries"]
        if entries <= 0:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write("step,req_id,row\n")
            return 0

        topk = data["topk"]
        columns = ["step", "req_id", "row"]
        for prefix in ("tid", "pre", "dist", "post"):
            columns.extend(f"{prefix}_{i}" for i in range(topk))

        scalar_cols = (data["step"], data["req"], data["row"])
        vector_cols = (data["ids"], data["pre"], data["dist"], data["post"])
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(",".join(columns) + "\n")
            for i in range(entries):
                fields = [col[i] for col in scalar_cols]
                for col in vector_cols:
                    fields.extend(col[i])
                f.write(",".join(map(str, fields)) + "\n")
        return entries

    def _collect_loss_arrays(self):
        """Copy the loss ring buffer to CPU lists, oldest record first.

        Returns ``None`` when the loss buffer was never allocated; otherwise a
        dict with ``entries`` and the per-record lists ``step``, ``train``,
        ``req``, ``loss``, ``cos``, ``pred_norm`` and ``target_norm``. A missing
        optional buffer yields -1 for ``req`` and 0.0 for the metric columns.
        """
        if self.loss_val_buf is None:
            return None

        capacity = self.loss_capacity
        written = self.loss_pos
        entries = min(written, capacity)
        if entries <= 0:
            empty = {"entries": 0}
            for key in ("step", "train", "req", "loss", "cos", "pred_norm", "target_norm"):
                empty[key] = []
            return empty

        order = torch.arange(entries, device=self.loss_val_buf.device, dtype=torch.long)
        if written > capacity:
            # The ring has wrapped: the oldest record sits at the write cursor.
            order = (order + written % capacity) % capacity

        def _ordered(buf):
            return buf.index_select(0, order).cpu().tolist()

        def _ordered_or(buf, fill):
            return [fill] * entries if buf is None else _ordered(buf)

        steps = _ordered(self.loss_step_buf)
        train_steps = _ordered(self.loss_train_step_buf)
        losses = _ordered(self.loss_val_buf)
        reqs = _ordered_or(self.loss_req_buf, -1)
        cos = _ordered_or(self.loss_cos_buf, 0.0)
        pred_norms = _ordered_or(self.loss_pred_norm_buf, 0.0)
        target_norms = _ordered_or(self.loss_target_norm_buf, 0.0)
        return {
            "entries": entries,
            "step": steps,
            "train": train_steps,
            "req": reqs,
            "loss": losses,
            "cos": cos,
            "pred_norm": pred_norms,
            "target_norm": target_norms,
        }

    def dump_loss_log(self, file_path: str) -> int:
        """Write the loss log to ``file_path`` as CSV and return the rows written.

        An empty ``file_path`` writes nothing and returns 0. When no loss
        records exist the file contains only the header line.
        """
        if not file_path:
            return 0

        data = None if self.loss_val_buf is None else self._collect_loss_arrays()
        entries = 0 if data is None else data["entries"]
        column_keys = ("step", "req", "train", "loss", "cos", "pred_norm", "target_norm")
        with open(file_path, "w", encoding="utf-8") as f:
            f.write("step,req_id,train_step,loss,cos,pred_norm,target_norm\n")
            if entries <= 0:
                return 0
            for i in range(entries):
                f.write(",".join(str(data[key][i]) for key in column_keys) + "\n")
        return entries

    def start_step(
        self,
        sampling_metadata,
        attn_metadata,
        device: torch.device,
        req_id: Optional[int] = None,
    ) -> None:
        """Reset per-step state and claim a buffer slot for this step's decode rows.

        Args:
            sampling_metadata: Passed to ``_collect_decode_rows`` to find the
                decode row indices.
            attn_metadata: Passed to ``_stage_from_attn``; the step is only
                activated for the "decode" and "mixed" stages.
            device: Device whose current CUDA stream is used for this step.
            req_id: Optional request id stored with log records (-1 if omitted).

        The step stays inactive when the runtime is disabled, metadata is
        missing, the stage is neither decode nor mixed, or there are no decode
        rows. Otherwise the step counter advances, a slot is chosen
        round-robin, and the current stream waits on that slot's ``slot_free``
        event before the slot is reused.
        """
        self.current_active = False
        self.current_rows = None
        self.current_count = 0
        self.current_slot = None
        self.current_stream = None
        self.current_req_id = -1 if req_id is None else int(req_id)

        if not self.enabled:
            return
        if attn_metadata is None:
            return
        stage = _stage_from_attn(attn_metadata)
        if stage not in {"decode", "mixed"}:
            return

        rows = _collect_decode_rows(sampling_metadata)
        if rows is None:
            return
        if isinstance(rows, torch.Tensor):
            row_count = int(rows.numel())
        else:
            row_count = len(rows)
        if row_count == 0:
            return

        self.step += 1
        slot = self.step % self.buffer_slots
        self.current_slot = slot
        self.current_rows = rows
        self.current_count = row_count
        self.current_active = True
        self.current_stream = torch.cuda.current_stream(device)
        self.current_stream.wait_event(self.events["slot_free"][slot])
        if slot < len(self.logit_ready):
            self.logit_ready[slot] = False

        # Optional: pre-initialize buffers and capture CUDA graphs before the first layer hook
        # fires (hook runs during model forward, where graph capture is unsafe).
        # This used to be issued twice back to back; once is enough.
        if self.cudagraph_enabled:
            self._maybe_init_buffers_for_cudagraph(device=device, max_rows=row_count)
            self._maybe_capture_pred_graph(slot=slot, count=row_count, device=device)
            self._maybe_capture_train_graph(slot=slot, count=row_count, device=device)

    def start_step_with_rows(
        self,
        rows: torch.Tensor,
        device: torch.device,
        req_id: Optional[int] = None,
    ) -> None:
        """Reset per-step state and claim a slot using explicit row indices (v1 helper).

        The step stays inactive when the runtime is disabled or ``rows`` is
        ``None``, empty, or has no length. Otherwise the step counter advances,
        a slot is chosen round-robin, and the current stream on ``device``
        waits on that slot's ``slot_free`` event.
        """
        self.current_active = False
        self.current_rows = None
        self.current_count = 0
        self.current_slot = None
        self.current_stream = None
        self.current_req_id = int(req_id) if req_id is not None else -1

        if not self.enabled or rows is None:
            return

        if isinstance(rows, torch.Tensor):
            num_rows = int(rows.numel())
        else:
            try:
                num_rows = len(rows)
            except TypeError:
                # Objects without a length count as "no rows".
                num_rows = 0
        if num_rows <= 0:
            return

        self.step += 1
        claimed = self.step % self.buffer_slots
        self.current_slot = claimed
        self.current_rows = rows
        self.current_count = num_rows
        self.current_active = True
        step_stream = torch.cuda.current_stream(device)
        self.current_stream = step_stream
        step_stream.wait_event(self.events["slot_free"][claimed])
        if claimed < len(self.logit_ready):
            self.logit_ready[claimed] = False

        if not self.cudagraph_enabled:
            return
        self._maybe_init_buffers_for_cudagraph(device=device, max_rows=num_rows)
        self._maybe_capture_pred_graph(slot=claimed, count=num_rows, device=device)
        self._maybe_capture_train_graph(slot=claimed, count=num_rows, device=device)

    def _row_tensor(self, device: torch.device) -> Optional[torch.Tensor]:
        """Return the current step's row indices as an int64 tensor on ``device``.

        Returns ``None`` when no rows are set for the step. A tensor that is
        already int64 on ``device`` is returned as-is; other tensors are
        converted, and non-tensor row collections are turned into a new tensor.
        """
        rows = self.current_rows
        if rows is None or self.current_count == 0:
            return None
        if not isinstance(rows, torch.Tensor):
            return torch.tensor(rows, device=device, dtype=torch.long)
        already_ok = rows.device == device and rows.dtype == torch.long
        return rows if already_ok else rows.to(device=device, dtype=torch.long)

    def record_h1(self, output: torch.Tensor) -> None:
        """Save the hooked layer's hidden rows and launch distiller inference.

        ``output`` is the forward-hook output of the chosen layer; when it is a
        tuple or list its first element is used. The rows selected by the
        current step's row indices are copied (as float32) into this slot's
        ``h1_buf``. If an auxiliary stream and a predictor are configured, the
        prediction is computed on that stream and written to ``pred_buf``,
        followed by the slot's ``pred_ready`` event.
        """
        # Nothing to do outside an active step.
        if not self.current_active or self.current_slot is None:
            return
        if output is None:
            return
        if isinstance(output, (tuple, list)):
            output = output[0]
        if not isinstance(output, torch.Tensor):
            return

        device = output.device
        self._ensure_buffers(output.shape[-1], output.shape[0], device)
        row_idx = self._row_tensor(device)
        if row_idx is None or row_idx.numel() == 0:
            return

        slot = self.current_slot
        source = output.index_select(0, row_idx).to(torch.float32)
        self.h1_buf[slot, : source.size(0)].copy_(source)
        # Mark the h1 copy on the step's stream so the aux stream can wait on it.
        self.events["h1_ready"][slot].record(self.current_stream)

        if not self.stream or not self.predictor:
            return
        count = int(source.size(0))
        with torch.cuda.stream(self.stream):
            # The aux stream must not read h1_buf before the copy above is done.
            self.stream.wait_event(self.events["h1_ready"][slot])

            # Use captured CUDA graph if available (capture happens in start_step).
            if self.cudagraph_enabled and self.cudagraph_capture_pred:
                cg = self._pred_graphs.get((int(slot), int(count)))
            else:
                cg = None
            if cg is not None:
                cg.graph.replay()
            else:
                # Eager path: optional input normalization, then the predictor.
                src = self.h1_buf[slot, :count]
                if self.normalize_inputs:
                    src = _layer_norm_batch(src)
                pred = self.predictor(src)
                self.pred_buf[slot, : pred.size(0)].copy_(pred)

            # Consumers (mix_hidden / mix_logits) wait on this event.
            self.events["pred_ready"][slot].record(self.stream)

    def record_h1_from_buffer(
        self,
        source_buf: torch.Tensor,
        ready_event: Optional[torch.cuda.Event],
    ) -> None:
        """Use an externally populated 2-D buffer (e.g. from cudagraph) as h1 input.

        Args:
            source_buf: 2-D tensor of hidden rows; anything else is ignored.
            ready_event: Optional event the aux stream waits on before reading
                ``source_buf``.

        Row indices outside ``[0, source_buf.shape[0])`` are dropped. The copy
        into ``h1_buf``, the prediction, and both the ``h1_ready`` and
        ``pred_ready`` events are all issued on the aux stream; nothing is
        done when no aux stream or predictor is configured.
        """
        if not self.current_active or self.current_slot is None:
            return
        if not isinstance(source_buf, torch.Tensor):
            return
        if source_buf.dim() != 2:
            return
        device = source_buf.device
        row_idx = self._row_tensor(device)
        if row_idx is None or row_idx.numel() == 0:
            return
        max_rows = int(source_buf.shape[0])
        if max_rows <= 0:
            return
        # NOTE(review): .item() reads the min/max back to the host, which
        # presumably forces a device sync on CUDA tensors -- confirm this is
        # acceptable on this path.
        if row_idx.min().item() < 0 or row_idx.max().item() >= max_rows:
            # Keep only the indices that address a valid row of source_buf.
            row_idx = row_idx[(row_idx >= 0) & (row_idx < max_rows)]
            if row_idx.numel() == 0:
                return

        self._ensure_buffers(source_buf.shape[-1], source_buf.shape[0], device)
        slot = self.current_slot
        if not self.stream or not self.predictor:
            return
        with torch.cuda.stream(self.stream):
            if ready_event is not None:
                # Do not read source_buf before its producer has finished.
                self.stream.wait_event(ready_event)
            source = source_buf.index_select(0, row_idx).to(torch.float32)
            self.h1_buf[slot, : source.size(0)].copy_(source)
            self.events["h1_ready"][slot].record(self.stream)
            src = self.h1_buf[slot, : source.size(0)]
            if self.normalize_inputs:
                src = _layer_norm_batch(src)
            pred = self.predictor(src)
            self.pred_buf[slot, : pred.size(0)].copy_(pred)
            self.events["pred_ready"][slot].record(self.stream)

    def record_hl(self, hidden_states: torch.Tensor) -> None:
        """Save the last-layer hidden rows of the current step as the training target.

        Args:
            hidden_states: Hidden-state tensor, or a tuple/list whose first
                element is that tensor. Anything else is ignored.

        The rows selected by the current step's row indices are copied (as
        float32) into this slot's ``hl_buf`` and the slot's ``hl_ready`` event
        is recorded on the step's stream.
        """
        if not self.current_active or self.current_slot is None:
            return
        # Unwrap containers BEFORE the tensor check. With the checks in the
        # opposite order a tuple/list was rejected and the unwrap never ran.
        if isinstance(hidden_states, (tuple, list)):
            hidden_states = hidden_states[0] if hidden_states else None
        if not isinstance(hidden_states, torch.Tensor):
            return

        device = hidden_states.device
        row_idx = self._row_tensor(device)
        if row_idx is None or row_idx.numel() == 0:
            return
        self._ensure_buffers(hidden_states.shape[-1], hidden_states.shape[0], device)
        slot = self.current_slot
        selected = hidden_states.index_select(0, row_idx).to(torch.float32)
        count = selected.size(0)
        self.hl_buf[slot, :count].copy_(selected)
        self.events["hl_ready"][slot].record(self.current_stream)

    def maybe_train(self) -> None:
        """Run one distiller training step on the aux stream for the current slot.

        Does nothing unless training is enabled and a predictor, optimizer, aux
        stream and non-empty current slot are available. The aux stream waits
        for the slot's ``hl_ready`` event (and ``logits_ready`` when the logit
        loss is used). After the first ``TRAIN_WARMUP_STEPS`` ticks, only every
        ``train_sync_interval``-th tick actually trains; skipped ticks just
        release the slot. Training replays a captured CUDA graph when one
        exists for ``(slot, count)`` and otherwise runs the eager path. Every
        exit taken after the stream context is entered records ``slot_free``.
        """
        if not self.train_enabled or not self.predictor or not self.optimizer:
            return
        if not self.stream or self.current_slot is None or self.current_count == 0:
            return
        slot = self.current_slot
        count = self.current_count
        logit_ready = slot < len(self.logit_ready) and self.logit_ready[slot]
        # The logit loss needs recorded targets, a cached LM-head weight and buffers.
        use_logit_loss = (
            self.logit_loss_weight > 0.0
            and logit_ready
            and isinstance(self.lm_head_weight, torch.Tensor)
            and self.logit_id_buf is not None
            and self.logit_val_buf is not None
            and self.logit_topk > 0
        )
        with torch.cuda.stream(self.stream):
            # Training inputs are produced on other streams; wait for them here.
            self.stream.wait_event(self.events["hl_ready"][slot])
            if use_logit_loss:
                self.stream.wait_event(self.events["logits_ready"][slot])
            self.train_tick += 1
            # Past warmup, train only on every train_sync_interval-th tick.
            if (
                self.train_tick > TRAIN_WARMUP_STEPS
                and self.train_tick % self.train_sync_interval != 0
            ):
                self.events["slot_free"][slot].record(self.stream)
                return

            # Prefer a captured training graph if available.
            cg = (
                self._train_graphs.get((int(slot), int(count)))
                if (self.cudagraph_enabled and self.cudagraph_capture_train)
                else None
            )
            # The graph is skipped when a logit loss is configured but its
            # targets are not available for this slot.
            if cg is not None and (self.logit_loss_weight <= 0.0 or use_logit_loss):
                cg.graph.replay()
                self._record_loss(
                    cg.loss_out,
                    cos_val=cg.cos_out,
                    pred_norm=cg.pred_norm_out,
                    target_norm=cg.target_norm_out,
                    step=self.step,
                    train_step=self.train_steps + 1,
                    device=cg.loss_out.device,
                )
                self.train_steps += 1
                self.events["slot_free"][slot].record(self.stream)
                return

            # Eager fallback.
            with torch.enable_grad():
                src = self.h1_buf[slot, :count]
                if self.normalize_inputs:
                    src = _layer_norm_batch(src)
                pred = self.predictor(src)
                target = self.hl_buf[slot, :count]
                # Base loss: MSE between predicted and recorded last-layer hidden.
                loss = F.mse_loss(pred, target)
                if use_logit_loss:
                    # Project the prediction onto the LM-head rows of the
                    # recorded top-k ids and match the recorded logit values.
                    topk = int(self.logit_topk)
                    top_ids = self.logit_id_buf[slot, :count, :topk]
                    top_vals = self.logit_val_buf[slot, :count, :topk]
                    weight_det = self.lm_head_weight.detach()
                    flat_ids = top_ids.to(torch.long).reshape(-1)
                    weight_rows = weight_det.index_select(0, flat_ids).view(count, topk, -1)
                    pred_logits = pred
                    if pred_logits.dtype != weight_rows.dtype:
                        pred_logits = pred_logits.to(dtype=weight_rows.dtype)
                    dist_topk = torch.bmm(weight_rows, pred_logits.unsqueeze(-1)).squeeze(-1)
                    if isinstance(self.lm_head_bias, torch.Tensor):
                        bias_rows = self.lm_head_bias.detach().index_select(0, flat_ids).view(count, topk)
                        dist_topk = dist_topk + bias_rows.to(dist_topk.dtype)
                    logit_loss = F.mse_loss(dist_topk.float(), top_vals.float())
                    loss = loss + self.logit_loss_weight * logit_loss
                # Diagnostics only: computed on detached tensors.
                cos_val = None
                pred_norm = None
                target_norm = None
                if count > 0:
                    cos_val = F.cosine_similarity(pred.detach(), target.detach(), dim=-1).mean()
                    pred_norm = pred.detach().norm(dim=-1).mean()
                    target_norm = target.detach().norm(dim=-1).mean()
                self._record_loss(
                    loss,
                    cos_val=cos_val,
                    pred_norm=pred_norm,
                    target_norm=target_norm,
                    step=self.step,
                    train_step=self.train_steps + 1,
                    device=src.device,
                )
                loss.backward()
                self.optimizer.step()
                # Keep grad tensors allocated for potential future CUDA graph capture.
                self.optimizer.zero_grad(set_to_none=False)
                self.train_steps += 1
            # Release the slot only after the training work has been queued.
            self.events["slot_free"][slot].record(self.stream)

    def mix_hidden(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Blend the distiller prediction into the decode rows of ``hidden_states``.

        Returns ``hidden_states`` untouched when no step is active or the input
        is not a tensor. When ``beta`` is not positive, predictions are missing,
        or there are no decode rows, the slot is released via
        ``_mark_slot_free_no_train`` and the input is returned unchanged.
        Otherwise the selected rows are overwritten in place with
        ``rows - beta * pred`` (curiosity-only) or
        ``(1 + beta) * rows - beta * pred``.
        """
        if self.current_slot is None or not self.current_active:
            return hidden_states
        if not isinstance(hidden_states, torch.Tensor):
            return hidden_states

        device = hidden_states.device
        if self.beta <= 0.0 or self.pred_buf is None:
            self._mark_slot_free_no_train(device)
            return hidden_states

        slot = self.current_slot
        decode_rows = self._row_tensor(device)
        if decode_rows is None or decode_rows.numel() == 0:
            self._mark_slot_free_no_train(device)
            return hidden_states
        num_rows = decode_rows.numel()

        # The prediction is produced on the aux stream; wait for it here.
        active_stream = torch.cuda.current_stream(device)
        active_stream.wait_event(self.events["pred_ready"][slot])

        prediction = self.pred_buf[slot, :num_rows]
        if prediction.dtype != hidden_states.dtype:
            prediction = prediction.to(dtype=hidden_states.dtype)
        original = hidden_states.index_select(0, decode_rows)
        if self.curiosity_only:
            blended = original - self.beta * prediction
        else:
            blended = (1.0 + self.beta) * original - self.beta * prediction
        hidden_states.index_copy_(0, decode_rows, blended)

        # Without training nothing else reads the slot, so release it now.
        if not self.train_enabled:
            self.events["slot_free"][slot].record(active_stream)
        return hidden_states

    def mix_logits(self, logits: torch.Tensor, *, lm_head, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Mix distiller logits into the top-k entries of the decode rows.

        Args:
            logits: Logits tensor, modified in place for the current decode
                rows. Any non-tensor value (including ``None``) is returned
                unchanged.
            lm_head: Object whose tensor ``weight`` projects hidden to logits.
            embedding_bias: Optional per-token bias added to distiller logits.

        Returns:
            ``logits`` (the same object), mixed where applicable.

        The slot is released via ``_mark_slot_free_no_train`` on every early
        exit that previously did so. Mixing is skipped when ``mix_mode`` is not
        ``"logits"``, no step is active, ``beta`` is not positive, predictions
        are missing, there are no decode rows, or ``lm_head`` has no tensor
        weight. Each mix is also appended to the mix log.
        """
        # Check the type first: the guards below read ``logits.device``, which
        # raised AttributeError for a ``None`` / non-tensor logits.
        if not isinstance(logits, torch.Tensor):
            return logits
        if self.mix_mode != "logits":
            self._mark_slot_free_no_train(logits.device)
            return logits
        if not self.current_active or self.current_slot is None:
            return logits
        if self.beta <= 0.0:
            self._mark_slot_free_no_train(logits.device)
            return logits
        if self.pred_buf is None:
            self._mark_slot_free_no_train(logits.device)
            return logits

        slot = self.current_slot
        row_idx = self._row_tensor(logits.device)
        if row_idx is None or row_idx.numel() == 0:
            self._mark_slot_free_no_train(logits.device)
            return logits
        count = row_idx.numel()

        # The prediction is produced on the aux stream; wait for it here.
        stream = torch.cuda.current_stream(logits.device)
        stream.wait_event(self.events["pred_ready"][slot])

        pred = self.pred_buf[slot, :count]
        weight = getattr(lm_head, "weight", None)
        if not isinstance(weight, torch.Tensor):
            self._mark_slot_free_no_train(logits.device)
            return logits
        target_dtype = weight.dtype
        if pred.dtype != target_dtype:
            pred = pred.to(dtype=target_dtype)

        rows = logits.index_select(0, row_idx)
        topk = min(self.topk, rows.shape[-1])
        top_vals, top_ids = torch.topk(rows, k=topk, dim=-1)
        # Pre-mix top-k values double as logit-loss targets for training.
        self._record_logit_targets(
            top_ids=top_ids,
            top_vals=top_vals,
            weight=weight,
            embedding_bias=embedding_bias,
            stream=stream,
            device=logits.device,
        )
        flat_ids = top_ids.reshape(-1)
        weight_rows = weight.index_select(0, flat_ids).view(count, topk, -1)
        # Compute distiller's top-k logits via batched GEMM (usually cheaper than einsum
        # dispatch for small k).
        dist_topk = torch.bmm(weight_rows, pred.unsqueeze(-1)).squeeze(-1)
        if embedding_bias is not None:
            bias_rows = embedding_bias.index_select(0, flat_ids).view(count, topk)
            dist_topk = dist_topk + bias_rows
        if self.logits_rescale:
            # Scale distiller logits so their mean magnitude matches the model's.
            pre_mean = top_vals.abs().to(torch.float32).mean()
            dist_mean = dist_topk.abs().to(torch.float32).mean()
            scale = pre_mean / (dist_mean + 1e-6)
            dist_topk = dist_topk * scale

        if self.curiosity_only:
            mixed = top_vals - self.beta * dist_topk
        elif self.reward_scale:
            # Cache 1/||w|| per vocabulary row; rebuilt when device, dtype or
            # vocabulary size changes.
            if (
                self.weight_inv_norms is None
                or self.weight_inv_norms_device != weight.device
                or self.weight_inv_norms_dtype != top_vals.dtype
                or self.weight_inv_norms.numel() != weight.size(0)
            ):
                norms = weight.norm(dim=1).to(dtype=torch.float32) + 1e-6
                self.weight_inv_norms = norms.reciprocal().to(dtype=top_vals.dtype)
                self.weight_inv_norms_device = weight.device
                self.weight_inv_norms_dtype = top_vals.dtype
            inv_norms = self.weight_inv_norms.index_select(0, flat_ids).view(count, topk)
            diff = top_vals - dist_topk
            normalized_diff = diff * inv_norms
            mixed = top_vals + self.beta * normalized_diff
        else:
            mixed = (1.0 + self.beta) * top_vals - self.beta * dist_topk
        # Write the mixed values back into the selected rows, then into logits.
        rows.scatter_(1, top_ids, mixed)
        logits.index_copy_(0, row_idx, rows)
        self._record_mix_log(
            row_idx=row_idx,
            top_ids=top_ids,
            pre_vals=top_vals,
            dist_vals=dist_topk,
            post_vals=mixed,
            device=logits.device,
        )

        # Without training nothing else reads the slot, so release it now.
        if not self.train_enabled:
            self.events["slot_free"][slot].record(stream)
        return logits

    def _mark_slot_free_no_train(self, device: torch.device) -> None:
        """Record ``slot_free`` for the current slot when training is disabled.

        With training enabled, or when no slot is claimed, this does nothing.
        The event is recorded on the current CUDA stream of ``device``.
        """
        if self.train_enabled or self.current_slot is None:
            return
        free_event = self.events["slot_free"][self.current_slot]
        free_event.record(torch.cuda.current_stream(device))


# Module-level runtime instance used by update_distiller_overlap_state() and by
# the first-layer forward hook registered in _register_first_layer_hook().
_RUNTIME = _DistillerRuntime()


def update_distiller_overlap_state(
    *,
    enabled: bool,
    beta: float,
    topk: int,
    buffer_slots: int,
    first_layer: int,
    mix_mode: str,
    train_enabled: bool,
    train_sync_interval: int,
    normalize_inputs: bool = True,
    reward_scale: bool = False,
    logits_rescale: bool = False,
    logit_loss_weight: float = 0.0,
    curiosity_only: bool = False,
) -> None:
    """Configure the module-level runtime from client-supplied settings.

    Every argument is keyword-only and is handed to ``_RUNTIME.configure``
    unchanged, under the same keyword name.
    """
    settings = {
        "enabled": enabled,
        "beta": beta,
        "topk": topk,
        "buffer_slots": buffer_slots,
        "first_layer": first_layer,
        "mix_mode": mix_mode,
        "train_enabled": train_enabled,
        "train_sync_interval": train_sync_interval,
        "normalize_inputs": normalize_inputs,
        "reward_scale": reward_scale,
        "logits_rescale": logits_rescale,
        "logit_loss_weight": logit_loss_weight,
        "curiosity_only": curiosity_only,
    }
    _RUNTIME.configure(**settings)


def _register_first_layer_hook(model, layer_idx: int) -> bool:
    """Attach a forward hook that captures h1 from the chosen layer.

    The layer list is looked up at ``model.model.layers`` and, failing that,
    at ``model.model.decoder.layers``. The hook passes the layer's forward
    output to ``_RUNTIME.record_h1``, and its handle is stored on the layer as
    ``_distiller_overlap_hook``.

    Calling this again is safe: hooks installed by an earlier call (on any
    layer of the same list) are removed first, so exactly one capture hook is
    active afterwards.

    Args:
        model: Wrapper object expected to expose the layer stack via ``.model``.
        layer_idx: Index of the layer to hook; negative values count from the
            end of the layer list.

    Returns:
        True if a hook was registered; False if no layer list was found or the
        index is out of range (in which case existing hooks are left alone).
    """
    # Use getattr so a wrapper without `.model` reports failure instead of raising.
    inner = getattr(model, "model", None)
    layers = getattr(inner, "layers", None)
    if layers is None:
        decoder = getattr(inner, "decoder", None)
        layers = getattr(decoder, "layers", None) if decoder is not None else None
    if layers is None:
        return False

    idx = layer_idx if layer_idx >= 0 else len(layers) + layer_idx
    if idx < 0 or idx >= len(layers):
        return False

    # Remove handles left by a previous registration; otherwise every repeat
    # call would stack another hook and record h1 multiple times per forward.
    for candidate in layers:
        stale = getattr(candidate, "_distiller_overlap_hook", None)
        if stale is not None:
            stale.remove()
            candidate._distiller_overlap_hook = None

    layer = layers[idx]

    def hook(_module, _inputs, output):
        _RUNTIME.record_h1(output)

    layer._distiller_overlap_hook = layer.register_forward_hook(hook)
    return True


def apply_distiller_overlap_patch() -> bool:
    """Monkey-patch vLLM's execution and logits path with distiller wrappers.

    ``ModelRunner.execute_model`` and ``VLLMLogitsProcessor.forward`` are
    replaced by wrappers that call into the module-level runtime. The original
    callables are kept in ``_ORIG_EXECUTE_MODEL`` and ``_ORIG_LOGITS_FORWARD``.
    Installation happens under ``_PATCH_LOCK`` with the ``_PATCHED`` flag
    checked again once the lock is held.

    Returns:
        True if this call installed the patch, False if it was already in place.
    """
    global _PATCHED, _ORIG_EXECUTE_MODEL, _ORIG_LOGITS_FORWARD
    if _PATCHED:
        return False

    def _device_for(model_input) -> torch.device:
        # Use the device of `input_tokens` when it has one; otherwise fall
        # back to CUDA if available, else CPU.
        tokens = getattr(model_input, "input_tokens", None)
        if hasattr(tokens, "device"):
            return tokens.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def wrapped_execute_model(runner_self, model_input, *args, **kwargs):
        # Open a runtime step, run the original step, then let the runtime train.
        _RUNTIME.start_step(
            getattr(model_input, "sampling_metadata", None),
            getattr(model_input, "attn_metadata", None),
            _device_for(model_input),
        )
        result = _ORIG_EXECUTE_MODEL(runner_self, model_input, *args, **kwargs)
        _RUNTIME.maybe_train()
        return result

    def patched_forward(self, lm_head, hidden_states, sampling_metadata=None, embedding_bias=None):
        if _RUNTIME.mix_mode != "hidden":
            # Logits mode: compute logits first, then record hidden states and mix logits.
            raw_logits = _ORIG_LOGITS_FORWARD(self, lm_head, hidden_states, sampling_metadata, embedding_bias)
            _RUNTIME.record_hl(hidden_states)
            return _RUNTIME.mix_logits(raw_logits, lm_head=lm_head, embedding_bias=embedding_bias)

        # Hidden mode: record and mix hidden states before the logits are computed.
        _RUNTIME.record_hl(hidden_states)
        mixed_hidden = _RUNTIME.mix_hidden(hidden_states)
        logits = _ORIG_LOGITS_FORWARD(self, lm_head, mixed_hidden, sampling_metadata, embedding_bias)
        _RUNTIME.record_logits(logits, lm_head=lm_head, embedding_bias=embedding_bias)
        return logits

    with _PATCH_LOCK:
        # Another thread may have finished patching while we waited for the lock.
        if _PATCHED:
            return False

        _ORIG_EXECUTE_MODEL = ModelRunner.execute_model
        _ORIG_LOGITS_FORWARD = VLLMLogitsProcessor.forward
        ModelRunner.execute_model = wrapped_execute_model
        VLLMLogitsProcessor.forward = patched_forward
        _PATCHED = True

    return True


def install_distiller_overlap_hooks(model_executor) -> Optional[List[bool]]:
    """Register hooks and cache final-norm parameters through the executor.

    A per-model callback is passed to ``model_executor.apply_model``; this
    function returns whatever that call returns. For each model the callback:

    * registers the first-layer capture hook at ``_RUNTIME.first_layer``;
    * looks for a ``norm`` attribute on the model, ``model.model`` or
      ``model.module`` (in that order) and, if found, passes its weight, bias,
      epsilon and a coarse type label to ``_RUNTIME.set_final_norm``;
    * when ``_RUNTIME.reward_scale`` is set and ``model.lm_head.weight`` is a
      tensor, passes it to ``_RUNTIME.set_reward_scale_norms``.

    The callback returns True if the hook was registered or a norm was found.
    """

    def register(model) -> bool:
        hooked = _register_first_layer_hook(model, _RUNTIME.first_layer)

        holders = (model, getattr(model, "model", None), getattr(model, "module", None))
        norm = next(
            (found for found in (getattr(holder, "norm", None) for holder in holders) if found is not None),
            None,
        )
        if norm is not None:
            # Epsilon may be exposed as `eps` or `variance_epsilon`.
            eps = getattr(norm, "eps", None)
            if eps is None:
                eps = getattr(norm, "variance_epsilon", 1e-6)

            norm_cls = type(norm).__name__
            if isinstance(norm, torch.nn.LayerNorm) or "LayerNorm" in norm_cls:
                kind = "layer"
            elif "RMSNorm" in norm_cls:
                kind = "rms"
            else:
                kind = "other"

            _RUNTIME.set_final_norm(
                weight=getattr(norm, "weight", None),
                bias=getattr(norm, "bias", None),
                eps=float(eps or 1e-6),
                norm_type=kind,
            )
            hooked = True

        if _RUNTIME.reward_scale:
            head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
            if isinstance(head_weight, torch.Tensor):
                _RUNTIME.set_reward_scale_norms(weight=head_weight, dtype=head_weight.dtype)
        return hooked

    return model_executor.apply_model(register)


def export_distiller_overlap_log(model_executor, file_path: str) -> Optional[List[int]]:
    """Dump the runtime's log buffer to a CSV file from the worker.

    A callback that ignores its model argument and returns
    ``_RUNTIME.dump_log(file_path)`` is passed to
    ``model_executor.apply_model``; that call's result is returned as-is.
    """

    def _dump(_model) -> int:
        dumped = _RUNTIME.dump_log(file_path)
        return dumped

    per_model = model_executor.apply_model(_dump)
    return per_model


def export_distiller_overlap_loss_log(model_executor, file_path: str) -> Optional[List[int]]:
    """Dump the runtime's loss log to ``file_path`` through the executor.

    A callback that ignores its model argument and returns
    ``_RUNTIME.dump_loss_log(file_path)`` is passed to
    ``model_executor.apply_model``; that call's result is returned as-is.
    """

    def _dump(_model) -> int:
        dumped = _RUNTIME.dump_loss_log(file_path)
        return dumped

    per_model = model_executor.apply_model(_dump)
    return per_model


def export_distiller_overlap_cudagraph_stats() -> dict:
    """Return the CUDA-graph statistics reported by the module-level runtime."""
    stats = _RUNTIME.get_cudagraph_stats()
    return stats
