"""SDPA-based inference for TabularACE with frozen embedder.

This module provides a decoder-only inference path for TabularACE that:
- Runs the tabular embedder (ISAB + row encoder) once at inference time to
  produce per-row embeddings for context and targets, then freezes it.
- Builds per-layer context K/V once using the trained Transformer backbone
  with dense self-attention restricted to context rows.
- Supports three decode modes using scaled dot-product attention (SDPA):
  1) Predict all targets independently (no target-target attention)
  2) Autoregressive K-batch decode without target self-attend for current query
  3) Autoregressive re-encode per step (refresh context K/V after each sample)

Note: This implementation focuses first on Mode A (predict all targets at once)
and scaffolds the APIs for Modes B/C. The attention backend uses
torch.nn.functional.scaled_dot_product_attention and can switch between
math, mem-efficient, and flash kernels where available.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.utils import expand_kv_heads
from src.models.tabular_embedder import TabularACE
from src.utils import DataAttr, LossAttr


# Accepted values for the SDPA kernel selector passed to SDPAKernel and
# TabularInferenceSDPA ("auto" leaves the current kernel flags untouched).
SDPABackend = Literal["auto", "flash", "mem", "math"]
# Decode strategy names; presumably these correspond to the three decode modes
# listed in the module docstring (predict-all, AR K-batch, AR re-encode).
DecodeMode = Literal["all", "ar_kbatch", "ar_reencode"]


def _split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reshape a packed projection into per-head form.

    Takes a ``[B, L, D]`` tensor and returns a contiguous ``[B, H, L, Dh]``
    tensor with ``H = n_heads`` and ``Dh = D // n_heads``. ``D`` must be an
    exact multiple of ``n_heads``.
    """
    batch, length, D = x.shape
    assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}"
    # Splitting only the trailing dimension is always a valid view; the head
    # axis is then moved in front of the sequence axis and materialised.
    per_head = x.view(batch, length, n_heads, D // n_heads)
    return per_head.permute(0, 2, 1, 3).contiguous()


def _combine_heads(y: torch.Tensor) -> torch.Tensor:
    """Merge per-head outputs back together: ``[B, H, L, Dh]`` -> ``[B, L, H * Dh]``."""
    batch, heads, length, head_dim = y.shape
    # Bring the sequence axis ahead of the head axis, then flatten (H, Dh).
    seq_major = y.permute(0, 2, 1, 3).contiguous()
    return seq_major.view(batch, length, heads * head_dim)


@dataclass
class _CtxKV:
    """Per-layer context key/value caches.

    Holds one key tensor and one value tensor per backbone layer, in layer
    order, as produced by ``TabularInferenceSDPA.build_context_kv``.

    NOTE(review): the field comments advertise an un-batched
    ``[1, H, Nc, Dh]`` layout, but ``build_context_kv`` appends tensors
    derived from the batched context rows, so the leading dimension is
    presumably ``B`` in practice -- confirm and reconcile.
    """

    Kc: list[torch.Tensor]  # per-layer keys; documented as [1, H, Nc, Dh] (see note above)
    Vc: list[torch.Tensor]  # per-layer values; same layout as the matching Kc entry


class SDPAKernel:
    """Lightweight context manager that selects the SDPA backend.

    On CUDA, the global flash / memory-efficient / math kernel switches in
    ``torch.backends.cuda`` are set as requested for the duration of the
    ``with`` block and restored on exit. ``backend="auto"`` leaves the
    switches as they are. Without CUDA the manager is a no-op.

    Attributes:
        backend: Requested backend name.
        prev: ``(flash, math, mem_efficient)`` flags captured on entry, or
            ``None`` when there is nothing to restore.
    """

    def __init__(self, backend: SDPABackend = "auto") -> None:
        self.backend = backend
        self.prev = None

    def __enter__(self):
        self.prev = None
        # Only CUDA supports alternative kernels.
        if not torch.cuda.is_available():
            return self

        cuda_backends = torch.backends.cuda
        # ``torch.backends.cuda.sdp_kernel`` is itself a context-manager
        # function and has no query/toggle attributes; the per-kernel
        # ``*_sdp_enabled`` / ``enable_*_sdp`` functions are the real API.
        self.prev = (
            cuda_backends.flash_sdp_enabled(),
            cuda_backends.math_sdp_enabled(),
            cuda_backends.mem_efficient_sdp_enabled(),
        )
        if self.backend == "auto":
            # Leave the current kernel selection untouched.
            return self
        if self.backend not in ("flash", "mem", "math"):
            # An unknown name would otherwise disable every kernel.
            self.prev = None
            raise ValueError(f"Unknown SDPA backend: {self.backend!r}")

        cuda_backends.enable_flash_sdp(self.backend == "flash")
        cuda_backends.enable_mem_efficient_sdp(self.backend == "mem")
        cuda_backends.enable_math_sdp(self.backend == "math")
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.prev is not None:
            flash, math, mem = self.prev
            cuda_backends = torch.backends.cuda
            cuda_backends.enable_flash_sdp(flash)
            cuda_backends.enable_math_sdp(math)
            cuda_backends.enable_mem_efficient_sdp(mem)
            # Drop the snapshot so the instance can be re-entered cleanly.
            self.prev = None
        # Never suppress exceptions raised inside the block.
        return False


class TabularInferenceSDPA(nn.Module):
    """SDPA-based inference engine for TabularACE with frozen embedder.

    Workflow:
    - precompute_rows(batch): run TabularEmbedder once on [context|targets]
      and store per-row embeddings without labels/tokens.
    - build_context_kv(): add target labels to context rows, run the backbone
      with dense self-attn over context only, and capture per-layer Kc/Vc.
    - Decode using one of the supported modes (all, ar_kbatch, ar_reencode).
    """

    def __init__(
        self,
        model: TabularACE,
        *,
        backend: SDPABackend = "auto",
    ) -> None:
        super().__init__()
        self.model = model
        self.embedder = model.tabular_embedder
        self.backbone = model.backbone
        self.head = model.head
        self.backend: SDPABackend = backend

        # Cached, frozen row embeddings
        self._emb_ctx_rows: Optional[torch.Tensor] = None  # [B,Nc,D]
        self._emb_tgt_rows: Optional[torch.Tensor] = None  # [B,Nt,D]
        # Per-layer context K/V (shape [1,H,Nc,Dh], unexpanded for batch)
        self._ctx_kv: Optional[_CtxKV] = None

        # Convenience flags
        self.concat_cls: bool = bool(getattr(model, "concat_cls", False))
        self.num_cls_tokens: int = int(getattr(model, "num_cls_tokens", 1))

        # Pre-expanded AR tokens if concat_cls=True
        with torch.no_grad():
            ar = self.embedder.ar_tokens  # [Nb_max, E]
            if self.concat_cls:
                ar = ar.repeat_interleave(self.num_cls_tokens, dim=-1)
            self._ar_tokens_expanded = ar  # [Nb_max, D]

    @classmethod
    def from_trained_model(
        cls, model: TabularACE, backend: SDPABackend = "auto"
    ) -> "TabularInferenceSDPA":
        return cls(model, backend=backend)

    # ------------------------- Precompute & Context KV -------------------------
    @torch.no_grad()
    def precompute_rows(self, batch: DataAttr) -> None:
        """Run tabular embedder once over [context | targets] and cache rows.

        Expects batch.xc, batch.yc, and batch.xt to be provided (Nb=0 at inference).
        """
        assert batch.xc is not None and batch.yc is not None and batch.xt is not None
        B, Nc = batch.xc.shape[0], batch.xc.shape[1]
        Nt = batch.xt.shape[1]

        # Build a minimal batch for the embedder: Nb=0
        tmp = DataAttr(xc=batch.xc, yc=batch.yc, xb=batch.xt.new_zeros(B, 0, batch.xt.shape[-1]), yb=batch.yc.new_zeros(B, 0, 1), xt=batch.xt, yt=None)

        # Row embeddings from TabularEmbedder (ISAB + row encoder)
        rows = self.embedder(tmp)  # [B, Nc+Nt, D]
        self._emb_ctx_rows = rows[:, :Nc, :].contiguous()
        self._emb_tgt_rows = rows[:, Nc:, :].contiguous()
        # Cache context labels for K/V build
        self.cache_context_labels(batch.yc)

    @torch.no_grad()
    def build_context_kv(self) -> None:
        """Build per-layer K/V for context-only self-attention.

        Adds target labels to context rows and runs the backbone as an encoder
        with dense attention restricted to the context prefix. Captures per-layer
        Kc/Vc tensors in [1, H, Nc, Dh] form (unexpanded for batch).
        """
        assert self._emb_ctx_rows is not None, "Call precompute_rows() first"
        B, Nc, D = self._emb_ctx_rows.shape
        device = self._emb_ctx_rows.device

        # Add target labels to context rows, matching model.dim_model shape
        yc_enc = self.embedder.target_encoder(
            # Context labels are 1D; we need only the slice matching context size
            # The caller must ensure the same batch used in precompute_rows
            self._get_context_y()  # [B, Nc, 1]
        )  # [B, Nc, E]
        if self.concat_cls:
            yc_enc = yc_enc.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,Nc,D]

        h = self._emb_ctx_rows + yc_enc  # [B,Nc,D]

        Kc_list: list[torch.Tensor] = []
        Vc_list: list[torch.Tensor] = []

        with SDPAKernel(self.backend):
            for lyr in self.backbone.layers:
                H = lyr.attn.num_heads
                Dh = lyr.attn.head_dim

                y = lyr.norm1(h)
                q = lyr.attn.q_proj(y)
                k = lyr.attn.k_proj(y)
                v = lyr.attn.v_proj(y)

                qh = _split_heads(q, H)  # [B,H,Nc,Dh]
                kh = _split_heads(k, lyr.attn.num_kv_heads)  # [B,Hkv,Nc,Dh]
                vh = _split_heads(v, lyr.attn.num_kv_heads)
                kh = expand_kv_heads(kh, H // lyr.attn.num_kv_heads)  # [B,H,Nc,Dh]
                vh = expand_kv_heads(vh, H // lyr.attn.num_kv_heads)

                # Dense self-attention over context rows
                attn = F.scaled_dot_product_attention(
                    qh, kh, vh, attn_mask=None, dropout_p=0.0, is_causal=False
                )  # [B,H,Nc,Dh]
                out = lyr.attn.o_proj(_combine_heads(attn))  # [B,Nc,D]
                h = h + lyr.drop_attn(out)

                y2 = lyr.norm2(h)
                h = h + lyr.ff2(lyr.drop_ff(F.gelu(lyr.ff1(y2))))  # [B,Nc,D]

                # Capture per-layer K/V from the same normalized state (y)
                # Save un-batched [1,H,Nc,Dh] so we can expand to any B later
                kc = _split_heads(lyr.attn.k_proj(y), lyr.attn.num_kv_heads)
                vc = _split_heads(lyr.attn.v_proj(y), lyr.attn.num_kv_heads)
                kc = expand_kv_heads(kc, H // lyr.attn.num_kv_heads)
                vc = expand_kv_heads(vc, H // lyr.attn.num_kv_heads)
                # Store per-batch context K/V for each layer: [B, H, Nc, Dh]
                Kc_list.append(kc.contiguous())
                Vc_list.append(vc.contiguous())

        self._ctx_kv = _CtxKV(Kc=Kc_list, Vc=Vc_list)

    # ------------------------------ Mode A (all) ------------------------------
    @torch.no_grad()
    def predict_all_targets(
        self,
        *,
        return_params: bool = False,
        num_samples: int = 1,
    ) -> LossAttr | DataAttr:
        """Predict all targets independently (no target-target attention).

        Returns:
            - If return_params=True: LossAttr with mixture params (means, sds, weights)
            - Else: DataAttr with sampled predictions at target positions (yc)
        """
        assert self._emb_tgt_rows is not None and self._ctx_kv is not None
        h = self._emb_tgt_rows  # [B,T,D]
        B, T, _ = h.shape

        with SDPAKernel(self.backend):
            for l, lyr in enumerate(self.backbone.layers):
                H = lyr.attn.num_heads
                y = lyr.norm1(h)
                q = lyr.attn.q_proj(y)
                qh = _split_heads(q, H)  # [B,H,T,Dh]

                # Per-batch context K/V (no targets in K/V)
                Kc = self._ctx_kv.Kc[l]
                Vc = self._ctx_kv.Vc[l]
                attn = F.scaled_dot_product_attention(
                    qh, Kc, Vc, attn_mask=None, dropout_p=0.0, is_causal=False
                )  # [B,H,T,Dh]
                out = lyr.attn.o_proj(_combine_heads(attn))  # [B,T,D]
                h = h + lyr.drop_attn(out)

                y2 = lyr.norm2(h)
                h = h + lyr.ff2(lyr.drop_ff(F.gelu(lyr.ff1(y2))))

        z = self.backbone.norm(h)
        if return_params:
            return self.head(z, yt=None, num_samples=0)
        else:
            samples = self.head.sample(z, num_samples=max(1, int(num_samples)))  # [B,T,S,1]
            if samples.dim() == 4:
                if samples.shape[2] > 1:
                    ymean = samples.mean(dim=2)
                else:
                    ymean = samples.squeeze(2)
            else:
                ymean = samples
            return DataAttr(xc=None, yc=ymean)

    @torch.no_grad()
    def predict_ll_independent(self, batch: DataAttr) -> torch.Tensor:
        """Exact per-target log-likelihood under context-only conditioning (Mode A).

        This computes z for all targets by attending to context K/V only (no target–
        target interactions), then evaluates the exact mixture log-likelihood via
        the head. Returns a tensor of shape [B, T, 1] in normalized space.

        Note: This method runs its own precompute and context K/V build to ensure
        a self-contained call. Callers that manage state can instead call
        precompute_rows() and build_context_kv() beforehand and reuse this path.
        """
        assert batch.xt is not None and batch.yt is not None and batch.xc is not None and batch.yc is not None

        # Prepare embeddings and context K/V
        self.precompute_rows(batch)
        self.build_context_kv()

        h = self._emb_tgt_rows  # [B,T,D]
        B, T, _ = h.shape

        with SDPAKernel(self.backend):
            for l, lyr in enumerate(self.backbone.layers):
                H = lyr.attn.num_heads
                y = lyr.norm1(h)
                q = lyr.attn.q_proj(y)
                qh = _split_heads(q, H)  # [B,H,T,Dh]

                Kc = self._ctx_kv.Kc[l]
                Vc = self._ctx_kv.Vc[l]
                attn = torch.nn.functional.scaled_dot_product_attention(
                    qh, Kc, Vc, attn_mask=None, dropout_p=0.0, is_causal=False
                )
                out = lyr.attn.o_proj(_combine_heads(attn))
                h = h + lyr.drop_attn(out)

                y2 = lyr.norm2(h)
                h = h + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2))))

        z = self.backbone.norm(h)
        loss_attr = self.head(z, batch.yt, num_samples=0)
        return loss_attr.log_likelihood  # [B,T,1]

    @torch.no_grad()
    def evaluate_ll_reencode_tf(
        self,
        batch: DataAttr,
        order: torch.Tensor,
    ) -> torch.Tensor:
        """Teacher-forcing LL with re-encode-after-each-step (Mode B).

        Args:
            batch: DataAttr with xc,yc,xt,yt (normalized). Batch size B.
            order: LongTensor [B, T] giving the target permutation per batch item.

        Returns:
            ll_per_step: Tensor [B, T, 1] of per-step exact log-likelihoods (in step order).
        """
        assert batch.xt is not None and batch.yt is not None and batch.xc is not None and batch.yc is not None

        # Precompute frozen row embeddings and store context encodings
        self.precompute_rows(batch)

        B, Nc, D = self._emb_ctx_rows.shape
        T = self._emb_tgt_rows.shape[1]
        assert order.shape[0] == B and order.shape[1] == T

        # Context label encodings (repeat across CLS if needed)
        yc_enc = self.embedder.target_encoder(batch.yc)  # [B,Nc,E]
        if self.concat_cls:
            yc_enc = yc_enc.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,Nc,D]

        ll = torch.zeros(B, T, 1, device=self._emb_tgt_rows.device, dtype=self._emb_tgt_rows.dtype)

        with SDPAKernel(self.backend):
            for s in range(T):
                # Build prefix embeddings for true previous targets in the given order
                if s > 0:
                    idx_prev = order[:, :s]  # [B,s]
                    idx_prev_exp = idx_prev.unsqueeze(-1).expand(-1, -1, D)  # [B,s,D]
                    prev_rows = torch.gather(self._emb_tgt_rows, 1, idx_prev_exp)  # [B,s,D]
                    y_prev = torch.gather(batch.yt, 1, idx_prev.unsqueeze(-1))  # [B,s,1]
                    y_prev_enc = self.embedder.target_encoder(y_prev)  # [B,s,E]
                    if self.concat_cls:
                        y_prev_enc = y_prev_enc.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,s,D]
                    gen_part = prev_rows + y_prev_enc  # [B,s,D]
                    h_prefix = torch.cat([self._emb_ctx_rows + yc_enc, gen_part], dim=1)  # [B,Nc+s,D]
                else:
                    h_prefix = self._emb_ctx_rows + yc_enc  # [B,Nc,D]

                # Run transformer over prefix; capture per-layer K/V
                Kc_list: list[torch.Tensor] = []
                Vc_list: list[torch.Tensor] = []
                h = h_prefix
                for lyr in self.backbone.layers:
                    H = lyr.attn.num_heads
                    y = lyr.norm1(h)
                    q = lyr.attn.q_proj(y)
                    k = lyr.attn.k_proj(y)
                    v = lyr.attn.v_proj(y)

                    qh = _split_heads(q, H)
                    kh = _split_heads(k, lyr.attn.num_kv_heads)
                    vh = _split_heads(v, lyr.attn.num_kv_heads)
                    kh = expand_kv_heads(kh, H // lyr.attn.num_kv_heads)
                    vh = expand_kv_heads(vh, H // lyr.attn.num_kv_heads)

                    attn = torch.nn.functional.scaled_dot_product_attention(
                        qh, kh, vh, attn_mask=None, dropout_p=0.0, is_causal=False
                    )
                    out = lyr.attn.o_proj(_combine_heads(attn))
                    h = h + lyr.drop_attn(out)
                    y2 = lyr.norm2(h)
                    h = h + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2))))

                    Kc_list.append(kh)
                    Vc_list.append(vh)

                # Decode current target (in order)
                idx_curr = order[:, s]  # [B]
                # Gather query embedding for current indices
                idx_curr_exp = idx_curr.view(B, 1, 1).expand(-1, 1, D)
                h_q = torch.gather(self._emb_tgt_rows, 1, idx_curr_exp)  # [B,1,D]

                for l, lyr in enumerate(self.backbone.layers):
                    H = lyr.attn.num_heads
                    yq = lyr.norm1(h_q)
                    q = lyr.attn.q_proj(yq)
                    qh = _split_heads(q, H)  # [B,H,1,Dh]
                    Kc = Kc_list[l]  # [B,H,Nc+s,Dh]
                    Vc = Vc_list[l]
                    attn_q = torch.nn.functional.scaled_dot_product_attention(
                        qh, Kc, Vc, attn_mask=None, dropout_p=0.0, is_causal=False
                    )
                    out_q = lyr.attn.o_proj(_combine_heads(attn_q))
                    h_q = h_q + lyr.drop_attn(out_q)
                    y2q = lyr.norm2(h_q)
                    h_q = h_q + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2q))))

                z = self.backbone.norm(h_q)
                # Gather true y for current indices
                y_curr = torch.gather(batch.yt, 1, idx_curr.view(B, 1, 1))  # [B,1,1]
                loss_attr = self.head(z, y_curr, num_samples=0)
                ll[:, s : s + 1, :] = loss_attr.log_likelihood  # [B,1,1]

        return ll

    @torch.no_grad()
    def evaluate_ll_buffer_tf_kchunk(
        self,
        batch: DataAttr,
        order: torch.Tensor,
        ar_indexing: str = "chunk",
    ) -> torch.Tensor:
        """Teacher-forcing LL with buffer-based AR and local chunk AR tokens (Mode C).

        For each step s, commits the TRUE previous targets as buffer tokens using
        row_embedding + target_encoder(y_true) + AR position tokens, then decodes
        the current target and evaluates exact LL via the head.

        Args:
            batch: DataAttr with xc,yc,xt,yt (normalized). Batch size B.
            order: LongTensor [B, T] giving the target permutation per batch item.
            ar_indexing: 'chunk' (default) uses local index s-1 for AR tokens; 'global'
                would use (global_step % max_buffer_size).

        Returns:
            ll_per_step: Tensor [B, T, 1] of per-step exact log-likelihoods (in step order).
        """
        assert batch.xt is not None and batch.yt is not None and batch.xc is not None and batch.yc is not None

        # Precompute and build context K/V once
        self.precompute_rows(batch)
        self.build_context_kv()

        B, T, D = self._emb_tgt_rows.shape
        assert order.shape[0] == B and order.shape[1] == T
        max_buf = self._ar_tokens_expanded.shape[0]
        assert T <= max_buf, f"T={T} exceeds max_buffer_size={max_buf} for Mode C teacher-forcing"

        # Precompute true target encodings for all targets
        y_enc_all = self.embedder.target_encoder(batch.yt)  # [B,T,E]
        if self.concat_cls:
            y_enc_all = y_enc_all.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,T,D]

        # Per-layer lists of committed buffer K/V (grow with steps)
        Kb_lists = [[] for _ in self.backbone.layers]
        Vb_lists = [[] for _ in self.backbone.layers]

        ll = torch.zeros(B, T, 1, device=self._emb_tgt_rows.device, dtype=self._emb_tgt_rows.dtype)

        with SDPAKernel(self.backend):
            for s in range(T):
                # Commit buffer for previous true target (teacher-forcing)
                if s > 0:
                    idx_prev = order[:, s - 1]  # [B]
                    idx_prev_rows = idx_prev.view(B, 1, 1).expand(-1, 1, D)
                    prev_rows = torch.gather(self._emb_tgt_rows, 1, idx_prev_rows)  # [B,1,D]
                    y_prev = torch.gather(y_enc_all, 1, idx_prev.view(B, 1, 1).expand(-1, 1, D))  # [B,1,D]
                    ar_idx = (s - 1) if ar_indexing == "chunk" else ((s - 1) % max_buf)
                    ar_tok = self._ar_tokens_expanded[ar_idx].unsqueeze(0).unsqueeze(0).expand(B, 1, -1)
                    h_b = prev_rows + y_prev + ar_tok  # [B,1,D]

                    # Compute per-layer buffer K/V and append
                    for l, lyr in enumerate(self.backbone.layers):
                        H = lyr.attn.num_heads
                        yb = lyr.norm1(h_b)
                        kb = _split_heads(lyr.attn.k_proj(yb), lyr.attn.num_kv_heads)
                        vb = _split_heads(lyr.attn.v_proj(yb), lyr.attn.num_kv_heads)
                        kb = expand_kv_heads(kb, H // lyr.attn.num_kv_heads)  # [B,H,1,Dh]
                        vb = expand_kv_heads(vb, H // lyr.attn.num_kv_heads)
                        Kb_lists[l].append(kb)
                        Vb_lists[l].append(vb)

                        # Propagate buffer hidden state for consistency
                        qh_b = _split_heads(lyr.attn.q_proj(yb), H)
                        K_full_b = torch.cat([self._ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                        V_full_b = torch.cat([self._ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                        attn_b = torch.nn.functional.scaled_dot_product_attention(
                            qh_b, K_full_b, V_full_b, attn_mask=None, dropout_p=0.0, is_causal=False
                        )
                        out_b = lyr.attn.o_proj(_combine_heads(attn_b))
                        h_b = h_b + lyr.drop_attn(out_b)
                        y2b = lyr.norm2(h_b)
                        h_b = h_b + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2b))))

                # Decode current target in the given order against context + committed buffers
                idx_curr = order[:, s]
                idx_curr_rows = idx_curr.view(B, 1, 1).expand(-1, 1, D)
                h_q = torch.gather(self._emb_tgt_rows, 1, idx_curr_rows)  # [B,1,D]

                for l, lyr in enumerate(self.backbone.layers):
                    H = lyr.attn.num_heads
                    yq = lyr.norm1(h_q)
                    q = lyr.attn.q_proj(yq)
                    qh_q = _split_heads(q, H)
                    if len(Kb_lists[l]) > 0:
                        K_full_q = torch.cat([self._ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                        V_full_q = torch.cat([self._ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                    else:
                        K_full_q = self._ctx_kv.Kc[l]
                        V_full_q = self._ctx_kv.Vc[l]
                    attn_q = torch.nn.functional.scaled_dot_product_attention(
                        qh_q, K_full_q, V_full_q, attn_mask=None, dropout_p=0.0, is_causal=False
                    )
                    out_q = lyr.attn.o_proj(_combine_heads(attn_q))
                    h_q = h_q + lyr.drop_attn(out_q)
                    y2q = lyr.norm2(h_q)
                    h_q = h_q + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2q))))

                z = self.backbone.norm(h_q)
                y_curr = torch.gather(batch.yt, 1, idx_curr.view(B, 1, 1))  # [B,1,1]
                loss_attr = self.head(z, y_curr, num_samples=0)
                ll[:, s : s + 1, :] = loss_attr.log_likelihood

        return ll

    @torch.no_grad()
    def predict_mean_buffer_tf_kchunk(
        self,
        batch: DataAttr,
        order: torch.Tensor,
        y_path: torch.Tensor | None = None,
        ar_indexing: str = "chunk",
    ) -> torch.Tensor:
        """Teacher-forcing mean prediction with buffer-based AR and local chunk AR tokens (Mode C).

        For each step s, commits the PREVIOUS targets as buffer tokens using
        row_embedding + target_encoder(y_prev) + AR position tokens, then decodes
        the current target and computes the exact mixture mean via the head.

        Args:
            batch: DataAttr with xc,yc,xt (normalized). Batch size B.
            order: LongTensor [B, T] giving the target permutation per batch item.
            y_path: Optional Tensor [B, T, 1] of targets used for teacher forcing;
                if None, uses batch.yt (true) for TF. Pass predicted paths to smooth.
            ar_indexing: 'chunk' (default) uses local index s-1 for AR tokens; 'global'
                would use (s-1) % max_buffer.

        Returns:
            mu_per_step: Tensor [B, T, 1] of per-step mixture means (in step order).
        """
        assert batch.xt is not None and batch.xc is not None and batch.yc is not None

        # Precompute and build context K/V once
        self.precompute_rows(batch)
        self.build_context_kv()

        B = batch.xc.shape[0]
        T = batch.xt.shape[1]
        D = self._emb_tgt_rows.shape[-1]
        assert order.shape[0] == B and order.shape[1] == T
        max_buf = self._ar_tokens_expanded.shape[0]
        assert T <= max_buf, f"T={T} exceeds max_buffer_size={max_buf} for Mode C teacher-forcing"

        # Encode teacher-forcing path
        if y_path is None:
            assert batch.yt is not None, "Need y_path or batch.yt for teacher-forcing"
            y_enc_all = self.embedder.target_encoder(batch.yt)  # [B,T,E]
        else:
            y_enc_all = self.embedder.target_encoder(y_path)  # [B,T,E]
        if self.concat_cls:
            y_enc_all = y_enc_all.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,T,D]

        # Per-layer lists of committed buffer K/V
        Kb_lists = [[] for _ in self.backbone.layers]
        Vb_lists = [[] for _ in self.backbone.layers]

        mu = torch.zeros(B, T, 1, device=self._emb_tgt_rows.device, dtype=self._emb_tgt_rows.dtype)

        with SDPAKernel(self.backend):
            for s in range(T):
                # Commit buffer for previous (teacher-forced) target
                if s > 0:
                    idx_prev = order[:, s - 1]  # [B]
                    idx_prev_rows = idx_prev.view(B, 1, 1).expand(-1, 1, D)
                    prev_rows = torch.gather(self._emb_tgt_rows, 1, idx_prev_rows)  # [B,1,D]
                    y_prev = torch.gather(y_enc_all, 1, idx_prev.view(B, 1, 1).expand(-1, 1, D))  # [B,1,D]
                    ar_idx = (s - 1) if ar_indexing == "chunk" else ((s - 1) % max_buf)
                    ar_tok = self._ar_tokens_expanded[ar_idx % max_buf].unsqueeze(0).unsqueeze(0).expand(B, 1, -1)
                    h_b = prev_rows + y_prev + ar_tok  # [B,1,D]

                    # Compute per-layer buffer K/V and append
                    for l, lyr in enumerate(self.backbone.layers):
                        H = lyr.attn.num_heads
                        yb = lyr.norm1(h_b)
                        kb = _split_heads(lyr.attn.k_proj(yb), lyr.attn.num_kv_heads)
                        vb = _split_heads(lyr.attn.v_proj(yb), lyr.attn.num_kv_heads)
                        kb = expand_kv_heads(kb, H // lyr.attn.num_kv_heads)  # [B,H,1,Dh]
                        vb = expand_kv_heads(vb, H // lyr.attn.num_kv_heads)
                        Kb_lists[l].append(kb)
                        Vb_lists[l].append(vb)

                        # Propagate buffer hidden state for consistency
                        qh_b = _split_heads(lyr.attn.q_proj(yb), H)
                        K_full_b = torch.cat([self._ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                        V_full_b = torch.cat([self._ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                        attn_b = torch.nn.functional.scaled_dot_product_attention(
                            qh_b, K_full_b, V_full_b, attn_mask=None, dropout_p=0.0, is_causal=False
                        )
                        out_b = lyr.attn.o_proj(_combine_heads(attn_b))
                        h_b = h_b + lyr.drop_attn(out_b)
                        y2b = lyr.norm2(h_b)
                        h_b = h_b + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2b))))

                # Decode current target in the given order against context + committed buffers
                idx_curr = order[:, s]
                idx_curr_rows = idx_curr.view(B, 1, 1).expand(-1, 1, D)
                h_q = torch.gather(self._emb_tgt_rows, 1, idx_curr_rows)  # [B,1,D]

                for l, lyr in enumerate(self.backbone.layers):
                    H = lyr.attn.num_heads
                    yq = lyr.norm1(h_q)
                    q = lyr.attn.q_proj(yq)
                    qh_q = _split_heads(q, H)
                    if len(Kb_lists[l]) > 0:
                        K_full_q = torch.cat([self._ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                        V_full_q = torch.cat([self._ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                    else:
                        K_full_q = self._ctx_kv.Kc[l]
                        V_full_q = self._ctx_kv.Vc[l]
                    attn_q = torch.nn.functional.scaled_dot_product_attention(
                        qh_q, K_full_q, V_full_q, attn_mask=None, dropout_p=0.0, is_causal=False
                    )
                    out_q = lyr.attn.o_proj(_combine_heads(attn_q))
                    h_q = h_q + lyr.drop_attn(out_q)
                    y2q = lyr.norm2(h_q)
                    h_q = h_q + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(y2q))))

                z = self.backbone.norm(h_q)
                # Deterministic mixture mean from head
                la = self.head(z, yt=None, num_samples=0)
                mu_step = (la.weights * la.means).sum(dim=2)  # [B,1,1]
                mu[:, s : s + 1, :] = mu_step

        return mu

    # --------------------------- Modes B/C (stubs) ----------------------------
    @torch.no_grad()
    def decode_ar_kbatch(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Autoregressive decode across targets without current-query self-attend.

        Implements Mode B semantics: use precomputed row embeddings and per-layer
        context K/V; commit only buffer tokens (previous predictions) into per-layer
        K/V memory; never write the current query's K/V. Within each step, propagate
        both buffer and query through all layers to produce the next prediction.

        Args:
            batch: Accepted but not read here; decoding relies on the cached
                target row embeddings and context K/V stored on the engine.
            K: Accepted but not read here.

        Returns:
            DataAttr with ``xc=None`` and ``yc`` holding one sampled value per
            target, shaped [B, T, 1].

        Raises:
            AssertionError: If the cached target embeddings or the context K/V
                are missing.
        """
        assert self._emb_tgt_rows is not None and self._ctx_kv is not None
        emb_tgt = self._emb_tgt_rows
        ctx_kv = self._ctx_kv
        layers = self.backbone.layers
        B, T, _ = emb_tgt.shape

        def _project_kv(lyr, y_norm):
            """K/V of a normalized state, with KV heads expanded to the query-head count."""
            rep = lyr.attn.num_heads // lyr.attn.num_kv_heads
            k = expand_kv_heads(_split_heads(lyr.attn.k_proj(y_norm), lyr.attn.num_kv_heads), rep)
            v = expand_kv_heads(_split_heads(lyr.attn.v_proj(y_norm), lyr.attn.num_kv_heads), rep)
            return k, v

        def _layer_step(lyr, h, y_norm, k_all, v_all):
            """Residual SDPA over (k_all, v_all) followed by the residual GELU feed-forward."""
            qh = _split_heads(lyr.attn.q_proj(y_norm), lyr.attn.num_heads)
            attn = torch.nn.functional.scaled_dot_product_attention(
                qh, k_all, v_all, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            h = h + lyr.drop_attn(lyr.attn.o_proj(_combine_heads(attn)))
            return h + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(lyr.norm2(h)))))

        # Allocate per-layer buffer KV up to T
        K_buf: list[torch.Tensor] = []
        V_buf: list[torch.Tensor] = []
        for lyr in layers:
            k_store = torch.zeros(
                B, lyr.attn.num_heads, T, lyr.attn.head_dim, device=emb_tgt.device, dtype=emb_tgt.dtype
            )
            K_buf.append(k_store)
            V_buf.append(torch.zeros_like(k_store))

        # Storage for predictions
        ypred = torch.zeros(B, T, 1, device=emb_tgt.device, dtype=emb_tgt.dtype)

        with SDPAKernel(self.backend):
            for t in range(T):
                # Query embedding for current target
                h_q = emb_tgt[:, t : t + 1, :]

                # Buffer token for the previous target: row embedding + encoded
                # prediction + AR position token. There is none at t == 0.
                h_b = None
                if t > 0:
                    y_enc = self.embedder.target_encoder(ypred[:, t - 1 : t, :])  # [B,1,E]
                    if self.concat_cls:
                        y_enc = y_enc.repeat_interleave(self.num_cls_tokens, dim=-1)  # [B,1,D]
                    ar_idx = (t - 1) % self._ar_tokens_expanded.shape[0]
                    ar_tok = self._ar_tokens_expanded[ar_idx].unsqueeze(0).unsqueeze(0).expand(B, 1, -1)
                    h_b = emb_tgt[:, t - 1 : t, :] + y_enc + ar_tok  # [B,1,D]

                for l, lyr in enumerate(layers):
                    if h_b is None:
                        # No previous buffer; attend to context only
                        k_all, v_all = ctx_kv.Kc[l], ctx_kv.Vc[l]
                    else:
                        # Commit buffer K/V at position t-1 from the pre-attention state
                        yb = lyr.norm1(h_b)
                        kb, vb = _project_kv(lyr, yb)  # [B,H,1,Dh]
                        K_buf[l][:, :, t - 1 : t, :] = kb
                        V_buf[l][:, :, t - 1 : t, :] = vb
                        # Buffer token and query attend to the same keys/values
                        # (context + buffers 0..t-1), so build the concat once.
                        k_all = torch.cat([ctx_kv.Kc[l], K_buf[l][:, :, :t, :]], dim=2)
                        v_all = torch.cat([ctx_kv.Vc[l], V_buf[l][:, :, :t, :]], dim=2)
                        # Propagate the buffer state so the next layer's K/V are consistent
                        h_b = _layer_step(lyr, h_b, yb, k_all, v_all)

                    # Current query: its own K/V is never written, so it cannot self-attend
                    h_q = _layer_step(lyr, h_q, lyr.norm1(h_q), k_all, v_all)

                # Sample prediction from final query representation
                z = self.backbone.norm(h_q)
                y_samp = self.head.sample(z, num_samples=1).squeeze(2)  # [B,1,1]
                ypred[:, t : t + 1, :] = y_samp

        return DataAttr(xc=None, yc=ypred)

    @torch.no_grad()
    def decode_ar_buffer_kchunk(
        self,
        batch: DataAttr,
        K: int = 32,
        ar_indexing: str = "chunk",
        order: torch.Tensor | None = None,
        refresh_mode: str = "none",
    ) -> DataAttr:
        """Autoregressive decode with buffer tokens using K-chunking.

        - Precompute row embeddings and context K/V beforehand.
        - Commit only buffer tokens; never write the current query's K/V.
        - Without ``order``: targets are decoded in natural order in chunks of
          size K. With ``ar_indexing='chunk'`` the buffer token built at chunk
          position ``s`` uses AR token index ``s-1`` (0 when ``s == 0``); any
          other value uses the global index ``(t-1) % max_buffer_size``.
        - With ``order``: targets are decoded in the given per-batch order and
          buffer K/V are kept as ordered per-layer lists. On this path the AR
          token index is always ``(s-1) % max_buffer_size``; ``ar_indexing`` is
          not consulted.
        - Returns predictions in normalized space.

        Args:
            batch: Accepted but not read here; decoding relies on the cached
                target row embeddings and context K/V stored on the engine.
            K: Chunk size; must not exceed the number of AR position tokens.
            ar_indexing: AR token indexing policy for the natural-order path.
            order: Optional [B, T] tensor of target indices (used as a gather
                index, so it must be an integer index tensor).
            refresh_mode: Only consulted when ``order`` is given. With
                ``'boundary'``, every K steps the buffer K/V of the K most
                recently decoded targets are rebuilt with chunk-local AR indices.

        Returns:
            DataAttr with ``xc=None`` and ``yc`` shaped [B, T, 1], written at the
            targets' original positions.

        Raises:
            AssertionError: If cached embeddings / context K/V are missing, if
                ``K`` exceeds the AR token count, or if ``order`` is not [B, T].
        """
        assert self._emb_tgt_rows is not None and self._ctx_kv is not None
        emb_tgt = self._emb_tgt_rows
        ctx_kv = self._ctx_kv
        layers = self.backbone.layers
        B, T, D = emb_tgt.shape
        max_buf = self._ar_tokens_expanded.shape[0]
        assert K <= max_buf, f"K={K} exceeds max_buffer_size={max_buf}"

        def _buffer_token(rows, y_vals, ar_idx):
            """Buffer-token embedding: row embedding + encoded label + AR position token."""
            y_enc = self.embedder.target_encoder(y_vals)
            if self.concat_cls:
                y_enc = y_enc.repeat_interleave(self.num_cls_tokens, dim=-1)
            ar_tok = self._ar_tokens_expanded[ar_idx].unsqueeze(0).unsqueeze(0).expand(B, 1, -1)
            return rows + y_enc + ar_tok

        def _project_kv(lyr, y_norm):
            """K/V of a normalized state, with KV heads expanded to the query-head count."""
            rep = lyr.attn.num_heads // lyr.attn.num_kv_heads
            k = expand_kv_heads(_split_heads(lyr.attn.k_proj(y_norm), lyr.attn.num_kv_heads), rep)
            v = expand_kv_heads(_split_heads(lyr.attn.v_proj(y_norm), lyr.attn.num_kv_heads), rep)
            return k, v

        def _layer_step(lyr, h, y_norm, k_all, v_all):
            """Residual SDPA over (k_all, v_all) followed by the residual GELU feed-forward."""
            qh = _split_heads(lyr.attn.q_proj(y_norm), lyr.attn.num_heads)
            attn = torch.nn.functional.scaled_dot_product_attention(
                qh, k_all, v_all, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            h = h + lyr.drop_attn(lyr.attn.o_proj(_combine_heads(attn)))
            return h + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(lyr.norm2(h)))))

        def _gather_rows(src, idx):
            """Select one row per batch element: src [B, T, C], idx [B] -> [B, 1, C]."""
            return torch.gather(src, 1, idx.view(B, 1, 1).expand(-1, 1, src.shape[-1]))

        ypred = torch.zeros(B, T, 1, device=emb_tgt.device, dtype=emb_tgt.dtype)

        # If a permutation order is provided, process targets in that order and
        # build buffer K/V as an ordered list (independent of global positions).
        if order is not None:
            assert order.shape[0] == B and order.shape[1] == T, "order must be [B,T]"
            # Per-layer lists of committed buffer K/V (grow with steps)
            Kb_lists = [[] for _ in layers]
            Vb_lists = [[] for _ in layers]

            with SDPAKernel(self.backend):
                for s in range(T):
                    # Optional periodic refresh of the chunk [s-K, s) at chunk boundaries
                    refreshed = refresh_mode == "boundary" and s > 0 and (s % K == 0)
                    if refreshed:
                        start = max(0, s - K)
                        new_K = [[] for _ in layers]
                        new_V = [[] for _ in layers]
                        # Tokens are the outer loop so each token's hidden state is
                        # carried from layer to layer; deeper layers must see the
                        # propagated state, not the raw row embedding.
                        for jj in range(start, s):
                            idx_j = order[:, jj]
                            h_bj = _buffer_token(
                                _gather_rows(emb_tgt, idx_j),
                                _gather_rows(ypred, idx_j),
                                (jj - start) % max_buf,  # chunk-local AR index
                            )
                            for l, lyr in enumerate(layers):
                                ybj = lyr.norm1(h_bj)
                                kbj, vbj = _project_kv(lyr, ybj)
                                new_K[l].append(kbj)
                                new_V[l].append(vbj)
                                # Attend to context, buffers before the chunk, and the
                                # chunk tokens refreshed so far (including this one).
                                k_all = torch.cat([ctx_kv.Kc[l]] + Kb_lists[l][:start] + new_K[l], dim=2)
                                v_all = torch.cat([ctx_kv.Vc[l]] + Vb_lists[l][:start] + new_V[l], dim=2)
                                h_bj = _layer_step(lyr, h_bj, ybj, k_all, v_all)
                        for l in range(len(new_K)):
                            # Replaces entries start..s-2 and adds s-1, leaving s entries.
                            Kb_lists[l][start:s] = new_K[l]
                            Vb_lists[l][start:s] = new_V[l]

                    # Commit buffer for previous predicted target (if any). After a
                    # refresh the buffer for target s-1 is already in the lists, so
                    # committing again would insert a duplicate entry.
                    if s > 0 and not refreshed:
                        idx_prev = order[:, s - 1]  # [B]
                        # NOTE: `ar_indexing` does not influence this index.
                        h_b = _buffer_token(
                            _gather_rows(emb_tgt, idx_prev),  # [B,1,D]
                            _gather_rows(ypred, idx_prev),  # [B,1,1]
                            (s - 1) % max_buf,
                        )

                        # Compute per-layer buffer K/V and append
                        for l, lyr in enumerate(layers):
                            yb = lyr.norm1(h_b)
                            kb, vb = _project_kv(lyr, yb)
                            Kb_lists[l].append(kb)
                            Vb_lists[l].append(vb)

                            # Propagate buffer hidden state for consistency
                            k_all = torch.cat([ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                            v_all = torch.cat([ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                            h_b = _layer_step(lyr, h_b, yb, k_all, v_all)

                    # Decode current target in the given order against context + committed buffers
                    idx_curr = order[:, s]
                    h_q = _gather_rows(emb_tgt, idx_curr)  # [B,1,D]

                    for l, lyr in enumerate(layers):
                        if Kb_lists[l]:
                            k_all = torch.cat([ctx_kv.Kc[l]] + Kb_lists[l], dim=2)
                            v_all = torch.cat([ctx_kv.Vc[l]] + Vb_lists[l], dim=2)
                        else:
                            k_all, v_all = ctx_kv.Kc[l], ctx_kv.Vc[l]
                        h_q = _layer_step(lyr, h_q, lyr.norm1(h_q), k_all, v_all)

                    z = self.backbone.norm(h_q)
                    y_samp = self.head.sample(z, num_samples=1).squeeze(2)  # [B,1,1]
                    # Scatter predicted y back to global index position in one op
                    # (avoids a per-batch-element Python loop with .item() syncs).
                    ypred.scatter_(1, idx_curr.view(B, 1, 1), y_samp[:, :1, :].to(ypred.dtype))

            return DataAttr(xc=None, yc=ypred)

        # Default path: process in natural order with K-chunking (original behavior)
        # Allocate per-layer buffer KV up to T
        K_buf: list[torch.Tensor] = []
        V_buf: list[torch.Tensor] = []
        for lyr in layers:
            k_store = torch.zeros(
                B, lyr.attn.num_heads, T, lyr.attn.head_dim, device=emb_tgt.device, dtype=emb_tgt.dtype
            )
            K_buf.append(k_store)
            V_buf.append(torch.zeros_like(k_store))

        with SDPAKernel(self.backend):
            for chunk_start in range(0, T, K):
                chunk_len = min(K, T - chunk_start)
                for s in range(chunk_len):
                    t = chunk_start + s

                    # Current query embedding
                    h_q = emb_tgt[:, t : t + 1, :]

                    # Build buffer token for previous step, if any (t >= s, so the
                    # only step without a predecessor is t == 0).
                    h_b = None
                    if t > 0:
                        if ar_indexing == "chunk":
                            ar_idx = (s - 1) % max_buf if s > 0 else 0
                        else:
                            ar_idx = (t - 1) % max_buf
                        h_b = _buffer_token(emb_tgt[:, t - 1 : t, :], ypred[:, t - 1 : t, :], ar_idx)

                    for l, lyr in enumerate(layers):
                        # Commit buffer K/V at global position t-1
                        if h_b is not None:
                            yb = lyr.norm1(h_b)
                            kb, vb = _project_kv(lyr, yb)
                            K_buf[l][:, :, t - 1 : t, :] = kb
                            V_buf[l][:, :, t - 1 : t, :] = vb

                        # Buffer token and query attend to the same keys/values
                        # (context + buffers 0..t-1; an empty slice at t == 0),
                        # so build the concat once per layer.
                        k_all = torch.cat([ctx_kv.Kc[l], K_buf[l][:, :, :t, :]], dim=2)
                        v_all = torch.cat([ctx_kv.Vc[l], V_buf[l][:, :, :t, :]], dim=2)

                        # Propagate buffer hidden state for consistency
                        if h_b is not None:
                            h_b = _layer_step(lyr, h_b, yb, k_all, v_all)

                        # Query attends to context + committed buffers (no self write)
                        h_q = _layer_step(lyr, h_q, lyr.norm1(h_q), k_all, v_all)

                    z = self.backbone.norm(h_q)
                    y_samp = self.head.sample(z, num_samples=1).squeeze(2)
                    ypred[:, t : t + 1, :] = y_samp

        return DataAttr(xc=None, yc=ypred)

    @torch.no_grad()
    def decode_ar_reencode(self, batch: DataAttr, K: int = 4) -> DataAttr:
        """Placeholder for Mode C (autoregressive re-encode per step).

        Raises:
            NotImplementedError: Always; neither ``batch`` nor ``K`` is read.
        """
        reason = "AR re-encode (Mode C) not implemented yet"
        raise NotImplementedError(reason)

    # ------------------------------ Internals ---------------------------------
    def _get_context_y(self) -> torch.Tensor:
        """Return the context labels (yc) previously stashed on this engine.

        The labels are read from the private ``_cached_yc`` attribute, the same
        attribute that ``cache_context_labels`` writes. Keeping them as a
        side-channel avoids threading extra state through the public decode APIs.

        Raises:
            RuntimeError: If ``_cached_yc`` has not been set on this instance.
        """
        # A unique sentinel distinguishes "attribute absent" from any stored value.
        absent = object()
        cached_labels = getattr(self, "_cached_yc", absent)
        if cached_labels is absent:
            raise RuntimeError("Context labels not cached; call precompute_rows with batch")
        return cached_labels

    @torch.no_grad()
    def cache_context_labels(self, yc: torch.Tensor) -> None:
        """Store the context labels (yc) on the engine as ``_cached_yc``.

        ``_get_context_y`` reads this attribute back later; the value is stored
        as given, without copying.
        """
        self._cached_yc = yc

    @torch.no_grad()
    def decode_reencode(self, batch: DataAttr) -> DataAttr:
        """Re-encode-after-each-prediction (no buffer) autoregressive decode.

        Procedure per step t:
        - Build a prefix of size Nc + t: original context rows with labels, plus
          the first t target rows with their generated labels (no AR tokens).
        - Run the transformer over this prefix with dense self-attention and
          capture per-layer K/V for the updated prefix.
        - Decode the current query (target row t) by attending to the updated
          prefix K/V only (no targets in K/V beyond the prefix), sample y_t.
        - Append y_t to the generated list and proceed.

        Args:
            batch: Accepted but not read here; decoding relies on the cached
                context/target row embeddings and the cached context labels.

        Returns:
            DataAttr with yc=[B,T,1] in normalized space.

        Raises:
            AssertionError: If the cached context or target row embeddings are missing.
            RuntimeError: If the context labels were never cached (raised by
                ``_get_context_y``).
        """
        assert self._emb_ctx_rows is not None and self._emb_tgt_rows is not None
        emb_tgt = self._emb_tgt_rows
        layers = self.backbone.layers
        B, _, _ = self._emb_ctx_rows.shape
        T = emb_tgt.shape[1]

        def _encode_labels(y):
            """Encode labels and, when CLS tokens are concatenated, widen to the row width."""
            enc = self.embedder.target_encoder(y)
            if self.concat_cls:
                enc = enc.repeat_interleave(self.num_cls_tokens, dim=-1)
            return enc

        def _project_kv(lyr, y_norm):
            """K/V of a normalized state, with KV heads expanded to the query-head count."""
            rep = lyr.attn.num_heads // lyr.attn.num_kv_heads
            k = expand_kv_heads(_split_heads(lyr.attn.k_proj(y_norm), lyr.attn.num_kv_heads), rep)
            v = expand_kv_heads(_split_heads(lyr.attn.v_proj(y_norm), lyr.attn.num_kv_heads), rep)
            return k, v

        def _layer_step(lyr, h, y_norm, k_all, v_all):
            """Residual SDPA over (k_all, v_all) followed by the residual GELU feed-forward."""
            qh = _split_heads(lyr.attn.q_proj(y_norm), lyr.attn.num_heads)
            attn = torch.nn.functional.scaled_dot_product_attention(
                qh, k_all, v_all, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            h = h + lyr.drop_attn(lyr.attn.o_proj(_combine_heads(attn)))
            return h + lyr.ff2(lyr.drop_ff(torch.nn.functional.gelu(lyr.ff1(lyr.norm2(h)))))

        # Context tokens (row embedding + encoded label) do not change between
        # steps, so build them once instead of on every iteration.
        ctx_tokens = self._emb_ctx_rows + _encode_labels(self._get_context_y())  # [B,Nc,D]

        ypred = torch.zeros(B, T, 1, device=emb_tgt.device, dtype=emb_tgt.dtype)

        with SDPAKernel(self.backend):
            for t in range(T):
                # Build prefix embeddings: [ctx + generated targets]
                if t > 0:
                    tgt_tokens = emb_tgt[:, :t, :] + _encode_labels(ypred[:, :t, :])  # [B,t,D]
                    h = torch.cat([ctx_tokens, tgt_tokens], dim=1)  # [B, Nc+t, D]
                else:
                    h = ctx_tokens

                # Run transformer over prefix and capture per-layer K/V
                prefix_kv: list[tuple[torch.Tensor, torch.Tensor]] = []
                for lyr in layers:
                    y = lyr.norm1(h)
                    kh, vh = _project_kv(lyr, y)
                    h = _layer_step(lyr, h, y, kh, vh)
                    prefix_kv.append((kh, vh))

                # Decode current query (target t) against updated prefix K/V
                h_q = emb_tgt[:, t : t + 1, :]
                for lyr, (kh, vh) in zip(layers, prefix_kv):
                    h_q = _layer_step(lyr, h_q, lyr.norm1(h_q), kh, vh)

                z = self.backbone.norm(h_q)
                y_samp = self.head.sample(z, num_samples=1).squeeze(2)
                ypred[:, t : t + 1, :] = y_samp

        return DataAttr(xc=None, yc=ypred)
