from __future__ import annotations

"""
Compiled implementations (sampling and LL) for five methods (M1..M5), plus
an optional Triton fast-path cross-attention used by the AR buffer method (M3).

This file ports the core architecture and adapters from the notebooks, keeping
the compile configuration and call signatures aligned so we can time them from
short runner scripts.
"""

import math
from typing import Optional, Tuple, List, Dict, Any
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------
# Helpers: heads and attn
# -------------------------

# Value handed to torch.compile(mode=...) when compiled runners are built in this file.
# Read once at import time from the FAST_TIMES_COMPILE_MODE environment variable;
# falls back to "reduce-overhead" when the variable is unset.
COMPILE_MODE = os.environ.get("FAST_TIMES_COMPILE_MODE", "reduce-overhead")

def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reshape ``[B, L, D]`` activations into per-head layout ``[B, n_heads, L, D // n_heads]``.

    The returned tensor is contiguous.
    """
    batch, length, width = x.shape
    head_dim = width // n_heads
    per_head = x.view(batch, length, n_heads, head_dim)
    # Swap the sequence and head axes: [B, L, H, Dh] -> [B, H, L, Dh].
    return per_head.transpose(1, 2).contiguous()


def combine_heads(y: torch.Tensor) -> torch.Tensor:
    """Merge per-head activations ``[B, H, L, Dh]`` back into ``[B, L, H * Dh]``."""
    batch, heads, length, head_dim = y.shape
    # Bring the sequence axis in front of the head axis, then flatten heads into features.
    seq_major = y.transpose(1, 2).contiguous()
    return seq_major.view(batch, length, heads * head_dim)


class FFN(nn.Module):
    """Two-layer perceptron: ``Linear(d_in, d_hid) -> act() -> Linear(d_hid, d_out)``.

    Args:
        d_in: input feature width.
        d_hid: hidden width.
        d_out: output width; a falsy value (``None`` or ``0``) means "same as ``d_in``".
        act: zero-argument factory for the activation module.
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, d_in: int, d_hid: int, d_out: Optional[int] = None, act=nn.GELU, bias: bool = True):
        super().__init__()
        if not d_out:
            d_out = d_in
        layers = [
            nn.Linear(d_in, d_hid, bias=bias),
            act(),
            nn.Linear(d_hid, d_out, bias=bias),
        ]
        # Kept as ``self.net`` (a Sequential) so parameter names stay ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP along the last dimension of ``x``."""
        out = self.net(x)
        return out


class SelfAttention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    Mask convention: a *boolean* ``attn_mask`` is interpreted the way the mask
    builders in this file produce it, i.e. ``True`` marks a position that must
    NOT be attended. ``F.scaled_dot_product_attention`` uses the opposite
    boolean convention (``True`` = take part in attention), so boolean masks
    are inverted before the call. Non-boolean masks are forwarded unchanged
    (SDPA adds them to the attention scores).
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
        """Attend ``x`` ([B, L, D]) to itself and return a tensor of the same shape.

        Args:
            x: input sequence; q, k and v are all projected from it.
            attn_mask: optional mask; boolean masks use ``True`` = masked (see class docstring).
            is_causal: forwarded to ``F.scaled_dot_product_attention``.
        """
        q = split_heads(self.q_proj(x), self.n_heads)
        k = split_heads(self.k_proj(x), self.n_heads)
        v = split_heads(self.v_proj(x), self.n_heads)
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            # Translate "True = masked" into SDPA's "True = attend". Passing the mask
            # through unchanged would hide exactly the allowed positions and leave
            # fully-visible rows with no keys at all (NaN after softmax).
            attn_mask = torch.logical_not(attn_mask)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
        return self.out(combine_heads(y))


class CrossAttentionConcat(nn.Module):
    """Q from targets; K/V is concat of context and optional buffer K/V.

    Each of the context and buffer K/V inputs may be given either as

    - a per-head tensor ``[B or 1, H, L, Dh]`` (used as-is, expanded to the query
      batch when its batch dimension is 1), or
    - any other input, which is passed through this layer's ``k_proj`` /
      ``v_proj`` and then split into heads.

    Mask convention: a *boolean* ``attn_mask`` follows this file's mask builders
    (``True`` = masked). ``F.scaled_dot_product_attention`` uses the opposite
    boolean convention (``True`` = take part in attention), so boolean masks are
    inverted before the call. Non-boolean masks are forwarded unchanged.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def _as_heads(self, K: Any, V: Any, batch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-head ``(k, v)`` for one K/V source, projecting only when needed."""
        if isinstance(K, torch.Tensor) and K.dim() == 4 and K.shape[1] == self.n_heads:
            # Already per-head: V must have the same layout.
            assert isinstance(V, torch.Tensor) and V.dim() == 4 and V.shape[1] == self.n_heads
            k, v = K, V
            if k.shape[0] == 1 and batch > 1:
                # Shared (batch-1) K/V: materialise one copy per query row.
                k = k.expand(batch, -1, -1, -1).contiguous()
                v = v.expand(batch, -1, -1, -1).contiguous()
            return k, v
        # Otherwise apply this layer's k/v projections, then split to heads.
        k = split_heads(self.k_proj(K), self.n_heads)
        v = split_heads(self.v_proj(V), self.n_heads)
        return k, v

    def forward(self, x_q: torch.Tensor, Kc: Any, Vc: Any,
                Kb: Optional[Any] = None, Vb: Optional[Any] = None,
                attn_mask: Optional[torch.Tensor] = None):
        """Cross-attend queries ``x_q`` ([B, Lq, D]) to ``[context | buffer]`` keys/values.

        Args:
            x_q: query-side activations.
            Kc, Vc: context keys/values (see class docstring for accepted layouts).
            Kb, Vb: optional buffer keys/values; appended after the context along the
                key axis only when both are provided.
            attn_mask: optional mask over the concatenated keys; boolean masks use
                ``True`` = masked.

        Returns:
            Tensor shaped like ``x_q``.
        """
        q = split_heads(self.q_proj(x_q), self.n_heads)  # [B,H,Lq,Dh]
        batch = q.shape[0]

        k_h, v_h = self._as_heads(Kc, Vc, batch)
        if Kb is not None and Vb is not None:
            kb_h, vb_h = self._as_heads(Kb, Vb, batch)
            k_h = torch.cat([k_h, kb_h], dim=2)
            v_h = torch.cat([v_h, vb_h], dim=2)

        if attn_mask is not None and attn_mask.dtype == torch.bool:
            # Translate "True = masked" into SDPA's "True = attend"; forwarding the mask
            # unchanged would attend only to the forbidden keys and produce NaN for
            # query rows that are allowed to see every key.
            attn_mask = torch.logical_not(attn_mask)

        y = F.scaled_dot_product_attention(q, k_h, v_h, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        return self.out(combine_heads(y))


class TransformerEncoderBlock(nn.Module):
    """Pre-LayerNorm encoder layer: residual self-attention followed by a residual FFN."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        # Construction order is kept so parameter initialisation order is unchanged.
        self.sa = SelfAttention(d_model, n_heads)
        self.ff = FFN(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return ``x`` after one attention + feed-forward update; ``attn_mask`` goes to the attention."""
        attn_out = self.sa(self.ln1(x), attn_mask=attn_mask, is_causal=False)
        residual = x + attn_out
        return residual + self.ff(self.ln2(residual))


class ContextEncoder(nn.Module):
    """Stack of ``TransformerEncoderBlock`` layers followed by a final LayerNorm."""

    def __init__(self, d_model: int, n_heads: int, n_layers: int, d_ff: int):
        super().__init__()
        layers = [TransformerEncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln = nn.LayerNorm(d_model)

    def encode(self, x_mem: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x_mem`` through every block (same ``attn_mask`` for each) and normalise the result."""
        hidden = x_mem
        for layer in self.blocks:
            hidden = layer(hidden, attn_mask=attn_mask)
        return self.ln(hidden)

    def kv_from_encoded(self, E: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(E, E)``: the encoded sequence serves unchanged as both K and V source.

        No projection happens here.
        """
        keys_src = E
        values_src = E
        return keys_src, values_src


class DecoderBlock(nn.Module):
    """Pre-LayerNorm decoder layer: residual cross-attention followed by a residual FFN."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.ca = CrossAttentionConcat(d_model, n_heads)
        self.ff = FFN(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, h_in: torch.Tensor, Kc: torch.Tensor, Vc: torch.Tensor,
                Kb: Optional[torch.Tensor] = None, Vb: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None):
        """Update queries ``h_in`` by attending to context (and optional buffer) K/V.

        All K/V arguments and ``attn_mask`` are forwarded to the cross-attention layer.
        """
        queries = self.ln1(h_in)
        attended = self.ca(queries, Kc, Vc, Kb, Vb, attn_mask=attn_mask)
        h_mid = h_in + attended
        return h_mid + self.ff(self.ln2(h_mid))


class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers with a final LayerNorm."""

    def __init__(self, d_model: int, n_heads: int, n_layers: int, d_ff: int):
        super().__init__()
        layers = [DecoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln = nn.LayerNorm(d_model)
        self.n_heads = n_heads

    def forward(self, h_in: torch.Tensor,
                Kc: Any,
                Vc: Any,
                Kb: Optional[Any] = None,
                Vb: Optional[Any] = None,
                attn_mask: Optional[torch.Tensor] = None):
        """Decode ``h_in`` against shared or per-layer K/V.

        - When ``Kc`` and ``Vc`` are both lists/tuples they must have one entry per
          layer; layer ``i`` receives ``Kc[i]`` / ``Vc[i]`` and, when given,
          ``Kb[i]`` / ``Vb[i]``.
        - Otherwise the same ``Kc``/``Vc``/``Kb``/``Vb`` objects are handed to every
          layer, whose cross-attention decides how to turn them into heads.

        The same ``attn_mask`` is passed to every layer. Returns the layer-normalised
        output of the last block.
        """
        per_layer = isinstance(Kc, (list, tuple)) and isinstance(Vc, (list, tuple))
        hidden = h_in

        if not per_layer:
            for layer in self.blocks:
                hidden = layer(hidden, Kc, Vc, Kb, Vb, attn_mask=attn_mask)
            return self.ln(hidden)

        n_layers = len(self.blocks)
        assert len(Kc) == n_layers and len(Vc) == n_layers
        for idx, layer in enumerate(self.blocks):
            kb_i = Kb[idx] if Kb is not None else None
            vb_i = Vb[idx] if Vb is not None else None
            hidden = layer(hidden, Kc[idx], Vc[idx], kb_i, vb_i, attn_mask=attn_mask)
        return self.ln(hidden)

    def buf_kv_from_token(self, buf_tok: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Project one buffer token ``[B, 1, D]`` with every layer's cross-attention k/v projections.

        Returns ``(K_list, V_list)``, each holding one per-head tensor per layer
        (``[B, H, 1, Dh]`` after the head split).
        """
        _, steps, _ = buf_tok.shape
        assert steps == 1, "buf_kv_from_token expects a single-step token [B,1,D]"
        n_heads = self.n_heads
        K_list: List[torch.Tensor] = [
            split_heads(layer.ca.k_proj(buf_tok), n_heads).contiguous() for layer in self.blocks
        ]
        V_list: List[torch.Tensor] = [
            split_heads(layer.ca.v_proj(buf_tok), n_heads).contiguous() for layer in self.blocks
        ]
        return K_list, V_list


class ContextEmbedder(nn.Module):
    """Embed context tokens with a 2-layer MLP (``d_in -> 2 * d_model -> d_model``)."""

    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        hidden_width = 2 * d_model
        # 2-layer MLP kept for parity with the Triton path.
        self.net = FFN(d_in, hidden_width, d_model)

    def forward(self, x_ctx: torch.Tensor) -> torch.Tensor:
        """Return the MLP embedding of ``x_ctx``."""
        embedded = self.net(x_ctx)
        return embedded


class TargetEmbedder(nn.Module):
    """Embed target inputs with a 2-layer MLP (``d_in -> 2 * d_model -> d_model``)."""

    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        hidden_width = 2 * d_model
        self.net = FFN(d_in, hidden_width, d_model)

    def forward(self, x_tgt: torch.Tensor) -> torch.Tensor:
        """Return the MLP embedding of ``x_tgt``."""
        embedded = self.net(x_tgt)
        return embedded


class BufferEmbedder(nn.Module):
    """Embed buffer tokens with a 2-layer MLP (``d_in -> 2 * d_model -> d_model``)."""

    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        hidden_width = 2 * d_model
        self.net = FFN(d_in, hidden_width, d_model)

    def forward(self, x_buf: torch.Tensor) -> torch.Tensor:
        """Return the MLP embedding of ``x_buf``."""
        embedded = self.net(x_buf)
        return embedded


class GaussianHead(nn.Module):
    """Diagonal Gaussian head: one linear layer producing a mean and a standard deviation."""

    def __init__(self, d_model: int, d_out: int):
        super().__init__()
        # First d_out outputs are the mean, the last d_out the log standard deviation.
        self.proj = nn.Linear(d_model, 2 * d_out)

    def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, sigma)``; ``sigma = exp(clamp(log_sigma, -7, 3))``."""
        params = self.proj(h)
        mu, raw_log_sigma = params.chunk(2, dim=-1)
        sigma = raw_log_sigma.clamp(-7.0, 3.0).exp()
        return mu, sigma


# ---- Multivariate head (LL only) ----

class MvGaussianHead(nn.Module):
    """Multivariate Gaussian head producing a flattened mean and a lower-triangular factor.

    The factor is ``tril(P @ P^T)`` where ``P`` comes from a transformer encoder
    followed by an MLP projector. With ``bound_diag`` its diagonal is replaced by
    ``min_diag + relu(diagonal)``.
    """

    def __init__(self, d_model: int, dy: int, n_heads: int, d_ff: int,
                 n_std_layers: int = 2, prj_dim: int = 8, bound_diag: bool = True, min_diag: float = 0.05):
        super().__init__()
        self.mean = nn.Linear(d_model, dy)
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, batch_first=True)
        self.std_encoder = nn.TransformerEncoder(encoder_layer, n_std_layers)
        self.projector = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, prj_dim * dy),
        )
        self.bound_diag = bound_diag
        self.min_diag = min_diag

    def forward(self, h_seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map ``h_seq`` ([B, T, D]) to ``(mu, L)`` with shapes ``[B, T*dy]`` and ``[B, T*dy, T*dy]``."""
        batch, steps, _ = h_seq.shape
        n_out = steps * self.mean.out_features

        mu = self.mean(h_seq).view(batch, -1)

        encoded = self.std_encoder(h_seq)
        proj = self.projector(encoded).view(batch, n_out, -1)
        gram = torch.bmm(proj, proj.transpose(1, 2))
        factor = torch.tril(gram)
        if self.bound_diag:
            # Overwrite the diagonal in place with a value bounded below by min_diag.
            diag_idx = torch.arange(n_out, device=h_seq.device)
            factor[:, diag_idx, diag_idx] = self.min_diag + torch.relu(factor[:, diag_idx, diag_idx])
        return mu, factor


# ------------------------------
# Providers and AR drivers
# ------------------------------

@torch.no_grad()
def build_decoder_ctx_kv(E: torch.Tensor, dec: Decoder) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Pre-project encoded memory ``E`` into per-layer, per-head K/V for the decoder.

    For every decoder block, ``E`` is passed through that block's cross-attention
    ``k_proj`` / ``v_proj`` and split into heads. The batch dimension of ``E`` is
    kept as-is (no expansion). Runs under ``torch.no_grad()``.

    Returns:
        ``(K_list, V_list)`` with one tensor per decoder layer.
    """
    n_heads = dec.n_heads
    K_list: List[torch.Tensor] = [
        split_heads(layer.ca.k_proj(E), n_heads).contiguous() for layer in dec.blocks
    ]
    V_list: List[torch.Tensor] = [
        split_heads(layer.ca.v_proj(E), n_heads).contiguous() for layer in dec.blocks
    ]
    return K_list, V_list

class MemoryProvider:
    """Interface between the AR drivers and a memory strategy.

    The base-class methods do nothing and return ``None``; concrete providers
    override them.
    """

    def prepare_context(self, xc_feats: torch.Tensor):
        """Receive the embedded context features before decoding starts."""
        return None

    def start_sequence(self, B: int, T: int, xt_feats: torch.Tensor):
        """Reset per-sequence state for a batch of ``B`` rows and ``T`` steps."""
        return None

    def memory_at(self, t: int):
        """Return the memory visible at step ``t``; drivers unpack it as ``(Kc, Vc, Kb, Vb)``."""
        return None

    def observe(self, t: int, x_t_feat_B: torch.Tensor, y_t: torch.Tensor, h_t: torch.Tensor, dec: Decoder):
        """Record the outcome of step ``t`` so later steps can condition on it."""
        return None


class ARBufferProvider(MemoryProvider):
    """Encode the context once and grow a preallocated per-layer K/V buffer step by step."""

    def __init__(self, ctx_enc: ContextEncoder, emb_ctx: nn.Module, emb_buf: nn.Module, dec: Decoder):
        self.ctx_enc = ctx_enc
        self.emb_ctx = emb_ctx
        self.emb_buf = emb_buf
        self.dec = dec
        self.E = None
        # Per-layer per-head context K/V, filled by prepare_context.
        self.Kc: Optional[List[torch.Tensor]] = None  # each [B or 1,H,Nc,Dh]
        self.Vc: Optional[List[torch.Tensor]] = None
        # Per-layer per-head AR buffers, allocated by start_sequence.
        self.Kb_buf: Optional[List[torch.Tensor]] = None  # each [B,H,T,Dh]
        self.Vb_buf: Optional[List[torch.Tensor]] = None
        # Feature widths seen on the first observe() call.
        self.dx: Optional[int] = None
        self.dy: Optional[int] = None

    def prepare_context(self, xc_feats):
        """Encode the context features and pre-project them into per-layer decoder K/V."""
        encoded = self.ctx_enc.encode(xc_feats)
        self.E = encoded
        self.Kc, self.Vc = build_decoder_ctx_kv(encoded, self.dec)

    def start_sequence(self, B, T, xt_feats):
        """Allocate uninitialised ``[B, H, T, Dh]`` K and V buffers, one pair per decoder layer.

        Device, dtype and the head width (``xt_feats.shape[-1] // H``) are taken from ``xt_feats``.
        """
        n_heads = self.dec.n_heads
        head_dim = xt_feats.shape[-1] // n_heads
        n_layers = len(self.dec.blocks)

        def _fresh_buffers():
            return [
                torch.empty(B, n_heads, T, head_dim, device=xt_feats.device, dtype=xt_feats.dtype)
                for _ in range(n_layers)
            ]

        self.Kb_buf = _fresh_buffers()
        self.Vb_buf = _fresh_buffers()

    def memory_at(self, t):
        """Return ``(Kc, Vc, Kb, Vb)`` where the buffer part holds the first ``t`` written steps.

        ``Kb``/``Vb`` are ``None`` when no buffer is allocated or ``t`` is not positive.
        """
        if self.Kb_buf is None or t <= 0:
            return self.Kc, self.Vc, None, None
        Kb = [buf[:, :, :t, :].contiguous() for buf in self.Kb_buf]
        Vb = [buf[:, :, :t, :].contiguous() for buf in self.Vb_buf]
        return self.Kc, self.Vc, Kb, Vb

    def observe(self, t, x_t_feat_B, y_t, h_t, dec: Decoder):
        """Embed ``[x_t | y_t]``, project it to per-layer K/V and write it into buffer slot ``t``."""
        if self.dx is None:
            self.dx = x_t_feat_B.shape[-1]
        if self.dy is None:
            self.dy = y_t.shape[-1]
        token = torch.cat([x_t_feat_B, y_t], dim=-1)        # [B,1,dx+dy]
        token_feat = self.emb_buf(token)                    # [B,1,D]
        # One projection per layer, done once per step.
        k_new, v_new = dec.buf_kv_from_token(token_feat)    # lists of [B,H,1,Dh]
        if self.Kb_buf is None or self.Vb_buf is None:
            # No preallocated buffers (start_sequence was not called): nothing is stored.
            return
        for layer_idx, k_dst in enumerate(self.Kb_buf):
            k_dst[:, :, t:t+1, :].copy_(k_new[layer_idx])
            self.Vb_buf[layer_idx][:, :, t:t+1, :].copy_(v_new[layer_idx])


class TNPDRencodeProvider(MemoryProvider):
    """Re-encode [C | TY<=t] each step; no KV buffer reuse."""

    def __init__(self, ctx_enc: ContextEncoder, emb_ctx: nn.Module, emb_buf: nn.Module, dec: Decoder):
        self.ctx_enc, self.emb_ctx, self.emb_buf, self.dec = ctx_enc, emb_ctx, emb_buf, dec
        self.xc_feats: Optional[torch.Tensor] = None
        # Observed raw [x|y] tokens, one entry per step. Created here as well as in
        # start_sequence so memory_at()/observe() never hit a missing attribute
        # (matches TNPAReencodeProvider).
        self.Ybuf: List[torch.Tensor] = []

    def prepare_context(self, xc_feats):
        """Store the embedded context; it is re-encoded on every ``memory_at`` call."""
        self.xc_feats = xc_feats

    def start_sequence(self, B, T, xt_feats):
        """Forget all observations from a previous sequence."""
        self.Ybuf = []

    def memory_at(self, t):
        """Encode the context plus every observed token and return per-layer decoder K/V.

        Returns ``(Kc, Vc, None, None)``: this provider keeps no separate buffer K/V.
        """
        # Encode [C | TY<=t]
        if self.Ybuf:
            buf = torch.cat(self.Ybuf, dim=1)  # [B,t,dx+dy]
            buf_feat = self.emb_buf(buf)
            mem = torch.cat([self.xc_feats, buf_feat], dim=1)
        else:
            mem = self.xc_feats
        E = self.ctx_enc.encode(mem)
        Kc, Vc = build_decoder_ctx_kv(E, self.dec)
        return Kc, Vc, None, None

    def observe(self, t, x_t_feat_B, y_t, h_t, dec: Decoder):
        """Append the raw ``[x_t | y_t]`` token to the observation list."""
        self.Ybuf.append(torch.cat([x_t_feat_B, y_t], dim=-1))


class TNPAReencodeProvider(MemoryProvider):
    """TNPA-style provider: re-encodes a masked ``[C | TY | T0]`` memory at every step."""

    def __init__(self, ctx_enc: ContextEncoder, emb_ctx: nn.Module, emb_buf: nn.Module, dec: Decoder):
        self.ctx_enc = ctx_enc
        self.emb_ctx = emb_ctx
        self.emb_buf = emb_buf
        self.dec = dec
        self.xc_feats: Optional[torch.Tensor] = None  # [B,Nc,D]
        self.Ybuf: list[torch.Tensor] = []           # list of [B,1,dx+dy]
        # Widths of the x and y parts of a buffer token, recorded by observe().
        self.dx: Optional[int] = None
        self.dy: Optional[int] = None

    def prepare_context(self, xc_feats):
        """Store the embedded context for later re-encoding."""
        self.xc_feats = xc_feats

    def start_sequence(self, B, T, xt_feats):
        """Drop observations from any previous sequence."""
        self.Ybuf = []

    def memory_at(self, t):
        """Encode the TNPA memory built from everything observed so far.

        With ``t1 = len(Ybuf)`` observations the memory is ``[C | TY | T0]``:

        - C: the stored context features,
        - TY: ``emb_buf([x | y])`` for each observation,
        - T0: ``emb_buf([x | 0])`` for each observation,

        encoded under the ``build_tnpa_mask_fast(Nc, t1)`` mask. With no
        observations only the context is encoded (no mask). Returns
        ``(Kc, Vc, None, None)`` with per-layer decoder K/V.
        """
        ctx = self.xc_feats
        n_ctx = ctx.shape[1]
        n_obs = len(self.Ybuf)
        if n_obs == 0:
            encoded = self.ctx_enc.encode(ctx)
        else:
            observed = torch.cat(self.Ybuf, dim=1)          # [B,t1,dx+dy]
            assert self.dx is not None and self.dy is not None
            x_part = observed[..., : self.dx]
            y_part = observed[..., self.dx : self.dx + self.dy]
            with_y = self.emb_buf(torch.cat([x_part, y_part], dim=-1))                     # [B,t1,D]
            without_y = self.emb_buf(torch.cat([x_part, torch.zeros_like(y_part)], dim=-1))  # [B,t1,D]
            memory = torch.cat([ctx, with_y, without_y], dim=1)                            # [B,Nc+2*t1,D]
            tnpa_mask = build_tnpa_mask_fast(n_ctx, n_obs, device=memory.device)           # [L,L] bool
            encoded = self.ctx_enc.encode(memory, attn_mask=tnpa_mask)
        Kc, Vc = build_decoder_ctx_kv(encoded, self.dec)
        return Kc, Vc, None, None

    def observe(self, t, x_t_feat_B, y_t, h_t, dec: Decoder):
        """Append ``[x_t | y_t]``; the first call also records the x and y widths."""
        if self.dx is None:
            self.dx = x_t_feat_B.shape[-1]
        if self.dy is None:
            self.dy = y_t.shape[-1]
        token = torch.cat([x_t_feat_B, y_t], dim=-1)
        self.Ybuf.append(token)

def build_tnpa_mask_fast(Nc: int, t1: int, device) -> torch.Tensor:
    """Build the TNPA attention mask as an ``[L, L]`` bool tensor, ``True`` = masked.

    Token layout along both axes: ``[ctx (Nc) | ty_1..ty_t1 | t0_1..t0_t1]``, so
    ``L = Nc + 2 * t1``.

    Visibility rules:
    - context rows see context columns only;
    - a ``ty_i`` or ``t0_i`` row sees every context column plus every ``ty_j`` /
      ``t0_j`` column with ``j <= i``.

    Note that ``True`` = masked is the opposite of the boolean convention of
    ``F.scaled_dot_product_attention`` (where ``True`` = attend).
    """
    L = Nc + 2 * t1
    if t1 == 0:
        # Context only: nothing is masked.
        return torch.zeros((L, L), dtype=torch.bool, device=device)

    pos = torch.arange(L, device=device)
    is_ctx = pos < Nc
    # Time index of each position inside its own stream (TY or T0); negative for context.
    step = torch.where(pos < Nc + t1, pos - Nc, pos - (Nc + t1))

    ctx_cols = is_ctx[None, :]
    stream_rows = ~is_ctx[:, None]
    not_future = step[None, :] <= step[:, None]

    # Context columns are visible to every row; stream rows additionally see all
    # stream columns whose time index does not exceed their own.
    visible = ctx_cols | (stream_rows & not_future)
    return ~visible


# ------------------------------
# Methods: sampling
# ------------------------------

@torch.no_grad()
def method1_independent_with_embed_head(ctx_enc: ContextEncoder, dec: Decoder, emb_ctx: nn.Module, emb_tgt: nn.Module,
                                        head: GaussianHead, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode all targets in one pass against the encoded context (no buffer).

    Pipeline: embed ``[xc | yc]`` -> encode -> per-layer decoder K/V -> decode the
    embedded targets -> ``head``. Runs under ``torch.no_grad()``.

    Returns:
        ``(mu, sigma)`` as produced by ``head``.
    """
    context_tokens = torch.cat([xc, yc], dim=-1)
    encoded_ctx = ctx_enc.encode(emb_ctx(context_tokens))
    Kc, Vc = build_decoder_ctx_kv(encoded_ctx, dec)
    queries = emb_tgt(xt)
    decoded = dec(queries, Kc, Vc, None, None)
    return head(decoded)


def unified_ar_driver(provider: MemoryProvider, dec: Decoder, emb_tgt: nn.Module, head: GaussianHead,
                      emb_ctx: nn.Module, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, *, dy: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the autoregressive loop over the ``T`` targets using ``provider`` as memory.

    For each step ``t`` the decoder is queried with the embedding of target ``t``
    against ``provider.memory_at(t)``; the predicted mean ``mu_t`` is then fed back
    through ``provider.observe`` together with the raw ``x_t``.

    Args:
        provider: memory strategy (context preparation, per-step memory, observation).
        dec, emb_tgt, head, emb_ctx: decoder, target embedder, output head, context embedder.
        xc, yc: context inputs/outputs, concatenated on the last axis before embedding.
        xt: target inputs; dimension 1 is the step axis.
        dy: output width. Kept for call-signature compatibility; not needed by the loop.

    Returns:
        ``(mu, sigma)``: the per-step head outputs concatenated along dimension 1.
    """
    B, T = xt.shape[0], xt.shape[1]
    # Context embeddings [x|y]
    xc_feats = emb_ctx(torch.cat([xc, yc], dim=-1))
    provider.prepare_context(xc_feats)
    xt_feats = emb_tgt(xt)
    provider.start_sequence(B, T, xt_feats)
    mu_out = []
    sigma_out = []
    for t in range(T):
        Kc, Vc, Kb, Vb = provider.memory_at(t)
        # Query with THIS step's target embedding. Feeding the previous decoder output
        # back in as the query would make every step after the first ignore x_t.
        h_t = dec(xt_feats[:, t:t + 1, :], Kc, Vc, Kb, Vb)
        mu_t, sigma_t = head(h_t)
        mu_out.append(mu_t)
        sigma_out.append(sigma_t)
        # Pass raw x_t (dx) with y_t (dy) to emb_buf inside provider
        provider.observe(t, xt[:, t:t + 1, :], mu_t, h_t, dec)
    mu = torch.cat(mu_out, dim=1)
    sigma = torch.cat(sigma_out, dim=1)
    return mu, sigma


def unified_ar_ll_driver(provider: MemoryProvider, dec: Decoder, emb_tgt: nn.Module, head: GaussianHead,
                         emb_ctx: nn.Module, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor,
                         *, dy: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Teacher-forcing AR driver for LL: writes ground-truth y_t into buffer.

    For each step ``t`` the decoder is queried with the embedding of target ``t``
    against ``provider.memory_at(t)``; afterwards the raw ``x_t`` and the
    ground-truth ``yt[:, t]`` are handed to ``provider.observe``.

    Args:
        provider: memory strategy (context preparation, per-step memory, observation).
        dec, emb_tgt, head, emb_ctx: decoder, target embedder, output head, context embedder.
        xc, yc: context inputs/outputs, concatenated on the last axis before embedding.
        xt, yt: target inputs and ground-truth outputs; dimension 1 is the step axis.
        dy: output width. Kept for call-signature compatibility; not needed by the loop.

    Returns:
        ``(mu, sigma)``: the per-step head outputs concatenated along dimension 1.
    """
    B, T = xt.shape[0], xt.shape[1]
    xc_feats = emb_ctx(torch.cat([xc, yc], dim=-1))
    provider.prepare_context(xc_feats)
    xt_feats = emb_tgt(xt)
    provider.start_sequence(B, T, xt_feats)
    mu_out = []
    sigma_out = []
    for t in range(T):
        Kc, Vc, Kb, Vb = provider.memory_at(t)
        # Query with THIS step's target embedding. Feeding the previous decoder output
        # back in as the query would make every step after the first ignore x_t.
        h_t = dec(xt_feats[:, t:t + 1, :], Kc, Vc, Kb, Vb)
        mu_t, sigma_t = head(h_t)
        mu_out.append(mu_t)
        sigma_out.append(sigma_t)
        # Use raw x_t with ground-truth y_t for buffer embedding
        provider.observe(t, xt[:, t:t + 1, :], yt[:, t:t + 1, :], h_t, dec)
    mu = torch.cat(mu_out, dim=1)
    sigma = torch.cat(sigma_out, dim=1)
    return mu, sigma


@torch.no_grad()
def build_arbuffer_ca_mask(Nc: int, T: int, device: torch.device) -> torch.Tensor:
    """Cross-attention mask for the AR buffer: ``[T, Nc + T]`` bool, ``True`` = masked.

    Key layout is ``[context (Nc) | buffer (T)]``. Every query sees all context
    keys; query ``t`` sees buffer positions ``0..t`` inclusive, so only buffer
    positions strictly after ``t`` are masked.

    NOTE(review): the inclusive bound lets query ``t`` see buffer slot ``t``, whereas
    ``ARBufferProvider.memory_at(t)`` exposes only slots ``< t`` - confirm which
    one is intended.
    """
    query_idx = torch.arange(T, device=device).unsqueeze(1)        # [T,1]
    key_idx = torch.arange(Nc + T, device=device).unsqueeze(0)     # [1,Nc+T]
    in_buffer = key_idx >= Nc
    buffer_step = key_idx - Nc
    # Masked exactly where the key is a buffer slot from a later step than the query.
    return in_buffer & (buffer_step > query_idx)


@torch.no_grad()
def method3_arbuffer_ll_parallel(ctx_enc: ContextEncoder, dec: Decoder,
                                 emb_ctx: nn.Module, emb_tgt: nn.Module, emb_buf: nn.Module, head: GaussianHead,
                                 xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Teacher-forcing pass for M3 that decodes all ``T`` queries at once.

    Buffer tokens for every step are built from the ground-truth ``yt``; a single
    decoder call then attends to ``[context | buffer]`` under the mask from
    ``build_arbuffer_ca_mask`` (context always visible, buffer restricted per query).
    Runs under ``torch.no_grad()``.

    Returns:
        ``(mu, sigma)`` from ``head`` for all targets.
    """
    _, T, _ = xt.shape  # xt must be 3-D; only the step count is needed

    # Context: embed, encode once, pre-project to per-layer K/V.
    ctx_feats = emb_ctx(torch.cat([xc, yc], dim=-1))   # [B,Nc,D]
    n_ctx = ctx_feats.shape[1]
    Kc_list, Vc_list = build_decoder_ctx_kv(ctx_enc.encode(ctx_feats), dec)

    # Buffer: one embedded [x|y] token per step, projected by every layer's CA.
    buf_feats = emb_buf(torch.cat([xt, yt], dim=-1))   # [B,T,D]
    n_heads = dec.n_heads
    Kb_list: List[torch.Tensor] = [
        split_heads(layer.ca.k_proj(buf_feats), n_heads).contiguous() for layer in dec.blocks
    ]
    Vb_list: List[torch.Tensor] = [
        split_heads(layer.ca.v_proj(buf_feats), n_heads).contiguous() for layer in dec.blocks
    ]

    # Single masked decode over all queries.
    queries = emb_tgt(xt)                              # [B,T,D]
    mask = build_arbuffer_ca_mask(n_ctx, T, xt.device) # [T,Nc+T]
    decoded = dec(queries, Kc_list, Vc_list, Kb_list, Vb_list, attn_mask=mask)
    return head(decoded)


@torch.no_grad()
def method4_tnpa_ll_parallel(ctx_enc: ContextEncoder, dec: Decoder,
                             emb_ctx: nn.Module, emb_tgt: nn.Module, emb_buf: nn.Module, head: GaussianHead,
                             xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """One-pass TNPA pass: encode the full ``[C | TY | T0]`` memory once, then decode all targets.

    - TY tokens are ``emb_buf([xt | yt])``, T0 tokens are ``emb_buf([xt | 0])``.
    - The encoder runs under the ``build_tnpa_mask_fast(Nc, T)`` mask.
    - The decoder attends to the whole encoded memory without a cross-attention mask.

    Runs under ``torch.no_grad()``. Returns ``(mu, sigma)`` from ``head``.
    """
    _, T, _ = xt.shape  # xt must be 3-D; only the step count is needed

    ctx_feats = emb_ctx(torch.cat([xc, yc], dim=-1))                       # [B,Nc,D]
    with_y = emb_buf(torch.cat([xt, yt], dim=-1))                          # [B,T,D]
    without_y = emb_buf(torch.cat([xt, torch.zeros_like(yt)], dim=-1))     # [B,T,D]
    memory = torch.cat([ctx_feats, with_y, without_y], dim=1)              # [B,Nc+2T,D]

    tnpa_mask = build_tnpa_mask_fast(ctx_feats.shape[1], T, device=memory.device)  # [L,L]
    encoded = ctx_enc.encode(memory, attn_mask=tnpa_mask)                  # [B,L,D]
    Kc, Vc = build_decoder_ctx_kv(encoded, dec)

    queries = emb_tgt(xt)                                                  # [B,T,D]
    decoded = dec(queries, Kc, Vc, None, None)
    return head(decoded)


# ------------------------------
# Methods: log-likelihood
# ------------------------------

@torch.no_grad()
def method5_mvnd_ll(ctx_enc: ContextEncoder, dec: Decoder, emb_ctx: nn.Module, emb_tgt: nn.Module, mv_head: MvGaussianHead,
                    xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor) -> torch.Tensor:
    """Joint log-likelihood of ``yt`` under the multivariate Gaussian from ``mv_head``.

    The targets are decoded once against the encoded context; ``mv_head`` returns a
    flattened mean and a lower-triangular factor ``L``. The log-density is

        -0.5 * (T*dy*log(2*pi) + ||L^{-1}(y - mu)||^2) - sum(log(diag(L)))

    computed in float32 and cast back to ``yt.dtype``. Runs under ``torch.no_grad()``.

    Returns:
        Tensor with one log-likelihood per batch row.
    """
    B, T, dy = yt.shape

    encoded_ctx = ctx_enc.encode(emb_ctx(torch.cat([xc, yc], dim=-1)))
    Kc, Vc = build_decoder_ctx_kv(encoded_ctx, dec)
    decoded = dec(emb_tgt(xt), Kc, Vc, None, None)
    mu_vec, L = mv_head(decoded)

    # The triangular solve is done in fp32 (half precision is not supported for it on CUDA).
    residual = (yt.view(B, -1) - mu_vec).float()
    factor = L.float()
    whitened = torch.linalg.solve_triangular(factor, residual.unsqueeze(-1), upper=False).squeeze(-1)

    maha = (whitened ** 2).sum(-1)
    log_det = torch.log(torch.diagonal(factor, dim1=-2, dim2=-1)).sum(-1)
    const = T * dy * math.log(2 * math.pi)
    ll = -0.5 * (const + maha) - log_det
    return ll.to(yt.dtype)


# ------------------------------
# Compiled builders + adapters
# ------------------------------

class CompiledFiveMethodAdapter(nn.Module):
    """Sampling adapter for M1..M5 compiled paths.
    m1: TNP-D independent; m2: TNP-D AR (re-encode); m3: ours AR buffer;
    m4: TNPA AR masked re-encode; m5: TNP-ND (MVN) decode-once.

    ``predict`` goes through a ``torch.compile``-d runner cached per ``(Nc, T, B_eff)``;
    ``sample_joint_predictive`` runs the AR loop eagerly with per-step sampling
    (m5 uses the compiled decode-once runner in both entry points).
    """

    def __init__(self, method_id: str, ctx_enc: ContextEncoder, dec: Decoder,
                 emb_ctx: nn.Module, emb_tgt: nn.Module, emb_buf: nn.Module,
                 head: GaussianHead, mv_head: Optional[MvGaussianHead] = None):
        super().__init__()
        assert method_id in {"m1", "m2", "m3", "m4", "m5"}
        self.method_id = method_id
        self.ctx_enc, self.dec = ctx_enc, dec
        self.emb_ctx, self.emb_tgt, self.emb_buf = emb_ctx, emb_tgt, emb_buf
        self.head, self.mv_head = head, mv_head
        # Compiled runners keyed by (Nc, T, B_eff).
        self._compiled: Dict[Tuple[int, int, int], Any] = {}

    def _key(self, Nc: int, T: int, B_eff: int) -> Tuple[int, int, int]:
        """Cache key for a compiled runner."""
        return (Nc, T, B_eff)

    def _make_provider(self) -> MemoryProvider:
        """Create a fresh memory provider for the AR methods.

        m2 -> TNPDRencodeProvider, m3 -> ARBufferProvider, anything else that
        reaches an AR path -> TNPAReencodeProvider (m4).
        """
        if self.method_id == "m2":
            return TNPDRencodeProvider(self.ctx_enc, self.emb_ctx, self.emb_buf, self.dec)
        if self.method_id == "m3":
            return ARBufferProvider(self.ctx_enc, self.emb_ctx, self.emb_buf, self.dec)
        return TNPAReencodeProvider(self.ctx_enc, self.emb_ctx, self.emb_buf, self.dec)

    @staticmethod
    def _sample_mvn(mu_vec: torch.Tensor, L: torch.Tensor, num_samples: int, T: int, dy: int) -> torch.Tensor:
        """Draw ``num_samples`` samples ``mu + L @ eps`` and return them as ``[B, T, S, dy]``."""
        B = mu_vec.shape[0]
        Dtot = T * dy
        eps = torch.randn(num_samples, B, Dtot, device=mu_vec.device, dtype=mu_vec.dtype)
        # Compute y_delta = (L @ eps^T)^T with proper batching
        y_delta_BDS = torch.matmul(L, eps.permute(1, 2, 0))  # [B,Dtot,S]
        y_delta_SBD = y_delta_BDS.permute(2, 0, 1)           # [S,B,Dtot]
        y = (mu_vec.unsqueeze(0) + y_delta_SBD).view(num_samples, B, T, dy)
        return y.permute(1, 2, 0, 3).contiguous()

    def _ensure_runner(self, Nc: int, T: int, B_eff: int, dx: int, dy: int):
        """Return the compiled runner for ``(Nc, T, B_eff)``, building and caching it on first use.

        The runner takes ``(xc, yc, xt)`` and returns ``(mu, sigma)`` for m1..m4 or
        ``(mu_vec, L)`` for m5.
        """
        k = self._key(Nc, T, B_eff)
        if k in self._compiled:
            return self._compiled[k]
        # Build compiled function that accepts real tensors (xc,yc,xt)
        if self.method_id == "m1":
            def fn(xc_in, yc_in, xt_in):
                mu, sigma = method1_independent_with_embed_head(self.ctx_enc, self.dec, self.emb_ctx, self.emb_tgt, self.head, xc_in, yc_in, xt_in)
                return mu, sigma
        elif self.method_id == "m5":
            assert self.mv_head is not None, "mv_head is required for m5"
            def fn(xc_in, yc_in, xt_in):
                xc_feats = self.emb_ctx(torch.cat([xc_in, yc_in], dim=-1))
                E = self.ctx_enc.encode(xc_feats)
                Kc, Vc = build_decoder_ctx_kv(E, self.dec)
                xt_feats = self.emb_tgt(xt_in)
                h = self.dec(xt_feats, Kc, Vc, None, None)
                mu_vec, L = self.mv_head(h)  # [B, T*dy], [B, T*dy, T*dy]
                return mu_vec, L
        else:  # m2 / m3 / m4: AR loop with the matching provider
            provider = self._make_provider()
            def fn(xc_in, yc_in, xt_in):
                mu, sigma = unified_ar_driver(provider, self.dec, self.emb_tgt, self.head, self.emb_ctx, xc_in, yc_in, xt_in, dy=dy)
                return mu, sigma
        cfn = torch.compile(fn, fullgraph=True, dynamic=False, mode=COMPILE_MODE)
        self._compiled[k] = cfn
        return cfn

    # Entry points expected by the sampling benchmark harness
    def predict(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int):
        """Run the compiled runner once and draw ``num_samples`` samples from its output.

        Returns a tensor shaped ``[B, T, S, dy]``: MVN samples for m5, independent
        Normal samples from ``(mu, sigma)`` for every other method.
        """
        Nc = xc.shape[1]
        T = xt.shape[1]
        B_eff = 1
        dx = xt.shape[-1]
        dy = yc.shape[-1]
        runner = self._ensure_runner(Nc, T, B_eff, dx, dy)
        if self.method_id == "m5":
            mu_vec, L = runner(xc, yc, xt)
            return self._sample_mvn(mu_vec, L, num_samples, T, dy)
        mu, sigma = runner(xc, yc, xt)
        # Draw S samples to match sampling benchmark semantics
        y = torch.distributions.Normal(mu, sigma).sample([num_samples])  # [S,B,T,dy]
        return y.permute(1, 2, 0, 3).contiguous()

    @torch.no_grad()
    def sample_joint_predictive(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` joint samples by running the AR loop with per-step sampling.

        m5 decodes once and samples from the MVN. The AR methods sample ``y_t`` at every
        step and feed it back through the provider. Runs under ``torch.no_grad()``;
        the samples are never differentiable (``Normal.sample``), so no graph is needed.

        Returns:
            Tensor shaped ``[B0, T, S, dy]``.
        """
        B0, T, dx = xt.shape
        dy = yc.shape[-1]
        device = xt.device
        dtype = xt.dtype

        if self.method_id == "m5":
            # MVN sampling (decode once)
            Nc = xc.shape[1]
            runner = self._ensure_runner(Nc, T, 1, dx, dy)
            mu_vec, L = runner(xc, yc, xt)
            return self._sample_mvn(mu_vec, L, num_samples, T, dy)

        S = num_samples
        provider = self._make_provider()
        out = torch.empty(S, B0, T, dy, device=device, dtype=dtype)

        # Optimized path when B0==1: embed context once and broadcast; embed targets once and broadcast each step
        if B0 == 1:
            xc_feats_base = self.emb_ctx(torch.cat([xc, yc], dim=-1))   # [1,Nc,D]
            xt_feats_base = self.emb_tgt(xt)                            # [1,T,D]
            if self.method_id in {"m2", "m4"}:
                # For re-encode methods, memory is [S, ...]; broadcast context to S
                provider.prepare_context(xc_feats_base.expand(S, -1, -1).contiguous())
            else:
                provider.prepare_context(xc_feats_base)
            provider.start_sequence(S, T, xt_feats_base)

            for t in range(T):
                # Broadcast per-step target embedding to S rows
                h_in = xt_feats_base[:, t:t+1, :].expand(S, -1, -1).contiguous()  # [S,1,D]
                Kc, Vc, Kb, Vb = provider.memory_at(t)
                h = self.dec(h_in, Kc, Vc, Kb, Vb)                                  # [S,1,D]
                mu_t, sigma_t = self.head(h)                                        # [S,1,dy]
                y_t = torch.distributions.Normal(mu_t, sigma_t).sample()
                out[:, 0:1, t:t+1, :] = y_t.unsqueeze(1)
                # Grow buffer/memory with S samples
                x_t_S = xt[:, t:t+1, :].expand(S, -1, -1).contiguous()               # [S,1,dx]
                provider.observe(t, x_t_S, y_t, h, self.dec)
            return out.permute(1, 2, 0, 3).contiguous()

        # General B0 > 1 case: tile the batch S times (sample-major, row = s * B0 + b).
        # Tensor.expand cannot be used here because it only broadcasts size-1 dimensions.
        rows = S * B0
        xt_rep = xt.repeat(S, 1, 1)                                                  # [S*B0,T,dx]
        xc_feats = self.emb_ctx(torch.cat([xc, yc], dim=-1)).repeat(S, 1, 1)         # [S*B0,Nc,D]
        xt_feats = self.emb_tgt(xt).repeat(S, 1, 1)                                  # [S*B0,T,D]
        provider.prepare_context(xc_feats)
        provider.start_sequence(rows, T, xt_feats)
        for t in range(T):
            x_t = xt_rep[:, t:t+1, :]
            Kc, Vc, Kb, Vb = provider.memory_at(t)
            h = self.dec(xt_feats[:, t:t+1, :], Kc, Vc, Kb, Vb)                      # [S*B0,1,D]
            mu_t, sigma_t = self.head(h)                                             # [S*B0,1,dy]
            y_t = torch.distributions.Normal(mu_t, sigma_t).sample()
            # Un-tile the sample-major rows into [S, B0] and fill every batch row.
            out[:, :, t, :] = y_t.reshape(S, B0, dy)
            provider.observe(t, x_t, y_t, h, self.dec)
        return out.permute(1, 2, 0, 3).contiguous()


class CompiledFiveMethodLLAdapter(nn.Module):
    """Compiled log-likelihood evaluation for methods M1..M5.

    One ``torch.compile``d runner is built lazily and cached per
    ``(Nc, T, B_eff)`` triple, where ``B_eff`` is ``num_perms`` for the AR
    methods (m2/m3/m4) and 1 otherwise. Runners are compiled with
    ``dynamic=False``; repeated calls with the same triple reuse the cached
    runner.
    """

    # Methods whose input batch is expanded to ``num_perms`` rows.
    _AR_METHODS = frozenset({"m2", "m3", "m4"})

    def __init__(self, method_id: str, ctx_enc: ContextEncoder, dec: Decoder,
                 emb_ctx: nn.Module, emb_tgt: nn.Module, emb_buf: nn.Module,
                 head_diag: Optional[GaussianHead], mv_head: Optional[MvGaussianHead],
                 ):  # head_tnpa omitted for brevity in this consolidated port
        super().__init__()
        assert method_id in {"m1", "m2", "m3", "m4", "m5"}
        self.method_id = method_id
        self.ctx_enc, self.dec = ctx_enc, dec
        self.emb_ctx, self.emb_tgt, self.emb_buf = emb_ctx, emb_tgt, emb_buf
        self.head_diag, self.mv_head = head_diag, mv_head
        # Cache of compiled runners keyed by (Nc, T, B_eff).
        self._compiled: Dict[Tuple[int, int, int], Any] = {}

    def _key(self, Nc: int, T: int, B_eff: int) -> Tuple[int, int, int]:
        """Return the cache key for a (context length, target length, batch) triple."""
        return (Nc, T, B_eff)

    def _ensure_runner(self, Nc: int, T: int, B_eff: int, dx: int, dy: int):
        """Return the compiled runner for this shape triple, building it on first use.

        ``dx`` and ``dy`` are accepted for call-site compatibility only. They are
        not part of the cache key, so nothing derived from them is baked into the
        runner: the m2 path reads the output dimension from ``yt_in`` at call
        time instead of capturing ``dy``.
        """
        k = self._key(Nc, T, B_eff)
        cached = self._compiled.get(k)
        if cached is not None:
            return cached

        def diag_gaussian_ll(yt_in, mu, sigma):
            # Diagonal-Gaussian log-density, summed over the last two dims -> [B].
            return -(0.5 * ((yt_in - mu) / sigma) ** 2
                     + torch.log(sigma * math.sqrt(2 * math.pi))).sum(dim=(-1, -2))

        if self.method_id == "m5":
            # The multivariate head computes its own log-likelihood.
            def fn(xc_in, yc_in, xt_in, yt_in):
                return method5_mvnd_ll(self.ctx_enc, self.dec, self.emb_ctx, self.emb_tgt,
                                       self.mv_head, xc_in, yc_in, xt_in, yt_in)
        else:
            # Diagonal-head methods differ only in how (mu, sigma) are produced.
            if self.method_id == "m1":
                def predict(xc_in, yc_in, xt_in, yt_in):
                    return method1_independent_with_embed_head(
                        self.ctx_enc, self.dec, self.emb_ctx, self.emb_tgt, self.head_diag,
                        xc_in, yc_in, xt_in)
            elif self.method_id == "m2":
                provider = TNPDRencodeProvider(self.ctx_enc, self.emb_ctx, self.emb_buf, self.dec)

                def predict(xc_in, yc_in, xt_in, yt_in):
                    # Take dy from the live input: the cache key does not include
                    # it, so a captured value could go stale.
                    return unified_ar_ll_driver(
                        provider, self.dec, self.emb_tgt, self.head_diag, self.emb_ctx,
                        xc_in, yc_in, xt_in, yt_in, dy=yt_in.shape[-1])
            elif self.method_id == "m3":
                def predict(xc_in, yc_in, xt_in, yt_in):
                    return method3_arbuffer_ll_parallel(
                        self.ctx_enc, self.dec, self.emb_ctx, self.emb_tgt, self.emb_buf,
                        self.head_diag, xc_in, yc_in, xt_in, yt_in)
            else:  # m4 TNPA
                def predict(xc_in, yc_in, xt_in, yt_in):
                    return method4_tnpa_ll_parallel(
                        self.ctx_enc, self.dec, self.emb_ctx, self.emb_tgt, self.emb_buf,
                        self.head_diag, xc_in, yc_in, xt_in, yt_in)

            def fn(xc_in, yc_in, xt_in, yt_in):
                mu, sigma = predict(xc_in, yc_in, xt_in, yt_in)
                return diag_gaussian_ll(yt_in, mu, sigma)  # [B]

        cfn = torch.compile(fn, fullgraph=True, dynamic=False, mode=COMPILE_MODE)
        self._compiled[k] = cfn
        return cfn

    def eval_log_joint_likelihood(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor, num_perms: int = 1) -> torch.Tensor:
        """Evaluate the joint log-likelihood of ``yt`` and return a 1-D tensor.

        For AR methods (m2/m3/m4) with ``num_perms > 1`` the inputs are expanded
        along dim 0 to ``num_perms`` rows before the runner is called, and the
        runner output is returned unchanged. For m1/m5 the runner is called on
        the inputs as given and its output is tiled ``num_perms`` times.
        """
        Nc = xc.shape[1]
        T = xt.shape[1]
        dx = xt.shape[-1]
        dy = yt.shape[-1]
        is_ar = self.method_id in self._AR_METHODS
        B_eff = num_perms if is_ar else 1
        runner = self._ensure_runner(Nc, T, B_eff, dx, dy)
        # Expand batch for AR methods so runtime scales with num_perms like the baselines
        if is_ar and num_perms > 1:
            xc_in = xc.expand(num_perms, -1, -1).contiguous()
            yc_in = yc.expand(num_perms, -1, -1).contiguous()
            xt_in = xt.expand(num_perms, -1, -1).contiguous()
            yt_in = yt.expand(num_perms, -1, -1).contiguous()
        else:
            xc_in, yc_in, xt_in, yt_in = xc, yc, xt, yt
        out = runner(xc_in, yc_in, xt_in, yt_in)
        # Standardize shape to [num_perms] vector (tile if needed)
        return out if is_ar else out.repeat(num_perms)


# ------------------------------
# Triton fast-path CA (optional)
# ------------------------------

try:
    import triton  # type: ignore
    import triton.language as tl  # type: ignore
    TRITON_AVAILABLE = True
except Exception:
    # Broad catch kept from the original guard (presumably so a broken Triton
    # install also falls back instead of aborting the import -- confirm).
    TRITON_AVAILABLE = False

    class _TritonStub:
        """Import-time stand-in for the ``triton`` module.

        The kernels below are decorated at module level with
        ``@triton.autotune(...)`` / ``@triton.jit``. Without this stub those
        decorators raise ``NameError`` when Triton is missing, so the module
        cannot be imported at all. The stub turns the decorators into no-ops;
        the launchers still refuse to run because they check
        ``TRITON_AVAILABLE`` first.
        """

        @staticmethod
        def jit(fn):
            return fn

        @staticmethod
        def autotune(*args, **kwargs):
            return lambda fn: fn

        @staticmethod
        def Config(*args, **kwargs):
            return None

    triton = _TritonStub()  # type: ignore
    # ``tl`` is only referenced inside kernel bodies and in parameter
    # annotations, which stay unevaluated because this file enables
    # ``from __future__ import annotations``.
    tl = None  # type: ignore


# Single-query softmax attention kernel, autotuned over BLOCK_K.
#
# Each program instance handles one (batch, head) pair, selected by program ids
# 0 and 1. Only query row 0 is read and only output row 0 is written. The keys
# are traversed twice: pass 1 finds the maximum logit, pass 2 accumulates the
# exp-weights and the weighted sum of V. Accumulation is in float32 and the
# result is cast back to the dtype of the loaded query.
#
# B, H and Lq are accepted as constexpr arguments but are not read in the body.
# NOTE(review): `tl.arange(0, Dh)` presumably requires Dh to be a power of two
# -- confirm against the Triton version in use.
# NOTE(review): if Lk is 0 neither loop runs and the final division is 0 / 0.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_K': 64},  num_warps=4, num_stages=2),
        triton.Config({'BLOCK_K': 128}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_K': 256}, num_warps=8, num_stages=3),
    ],
    key=['Dh']
)
@triton.jit
def _ca_lq1_fwd(Q, K, V, Y,
                stride_q_b, stride_q_h, stride_q_lq, stride_q_d,
                stride_k_b, stride_k_h, stride_k_l, stride_k_d,
                stride_v_b, stride_v_h, stride_v_l, stride_v_d,
                stride_y_b, stride_y_h, stride_y_lq, stride_y_d,
                B: tl.constexpr, H: tl.constexpr, Lq: tl.constexpr, Lk, Dh: tl.constexpr,
                BLOCK_K: tl.constexpr):
    # One program per (batch, head).
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)

    # Offsets along the head dimension.
    off_d = tl.arange(0, Dh)

    # Load the query vector at row 0 (Lq must be 1) and promote it to float32.
    q = tl.load(Q + pid_b * stride_q_b + pid_h * stride_q_h + 0 * stride_q_lq + off_d * stride_q_d)
    qf = q.to(tl.float32)
    # Logit scale 1 / sqrt(Dh).
    Dh_f32 = tl.full([1], Dh, dtype=tl.float32)
    scale = 1.0 / tl.sqrt(Dh_f32)

    # Pass 1: running maximum of the scaled logits over all key blocks.
    m = tl.full([1], -float('inf'), dtype=tl.float32)
    for start in range(0, Lk, BLOCK_K):
        idx_k = start + tl.arange(0, BLOCK_K)
        mask_k = idx_k < Lk
        # Rows past Lk are loaded as zeros ...
        k_block = tl.load(K + pid_b * stride_k_b + pid_h * stride_k_h + idx_k[:, None] * stride_k_l + off_d[None, :] * stride_k_d,
                          mask=mask_k[:, None], other=0.0)
        logits = tl.sum(k_block.to(tl.float32) * qf[None, :], axis=1) * scale
        # ... and their logits forced to -inf so they cannot win the max.
        logits = tl.where(mask_k, logits, -float('inf'))
        m = tl.maximum(m, tl.max(logits, axis=0))

    # Pass 2: softmax denominator and un-normalized weighted sum of V.
    denom = tl.zeros([1], dtype=tl.float32)
    out = tl.zeros([Dh], dtype=tl.float32)
    for start in range(0, Lk, BLOCK_K):
        idx_k = start + tl.arange(0, BLOCK_K)
        mask_k = idx_k < Lk
        k_block = tl.load(K + pid_b * stride_k_b + pid_h * stride_k_h + idx_k[:, None] * stride_k_l + off_d[None, :] * stride_k_d,
                          mask=mask_k[:, None], other=0.0)
        v_block = tl.load(V + pid_b * stride_v_b + pid_h * stride_v_h + idx_k[:, None] * stride_v_l + off_d[None, :] * stride_v_d,
                          mask=mask_k[:, None], other=0.0)
        logits = tl.sum(k_block.to(tl.float32) * qf[None, :], axis=1) * scale
        # Masked rows get -inf, so their weight exp(-inf - m) is 0.
        logits = tl.where(mask_k, logits, -float('inf'))
        w = tl.exp(logits - m)
        denom += tl.sum(w, axis=0)
        out += tl.sum(v_block.to(tl.float32) * w[:, None], axis=0)

    # Normalize, cast back to the query dtype and store at output row 0.
    y = (out / denom).to(q.dtype)
    tl.store(Y + pid_b * stride_y_b + pid_h * stride_y_h + 0 * stride_y_lq + off_d * stride_y_d, y)


def _triton_ca_lq1_forward(Qh: torch.Tensor, K_all_h: torch.Tensor, V_all_h: torch.Tensor) -> torch.Tensor:
    """Run the single-query attention kernel and return its output.

    ``Qh`` is unpacked as ``[B, H, Lq, Dh]`` with ``Lq`` required to be 1; the
    key length is read from dim 2 of ``K_all_h``. The result has shape
    ``[B, H, 1, Dh]`` and the device/dtype of ``Qh``.

    When the ``FAST_TRITON_TUNE`` environment variable is ``"1"`` (the default)
    the autotuned kernel is launched; otherwise the fixed kernel is launched
    with ``BLOCK_K=128`` to avoid per-step autotuning.
    """
    assert TRITON_AVAILABLE, "Triton is required for this path"
    B, H, Lq, Dh = Qh.shape
    assert Lq == 1, f"Triton fast-path requires Lq=1, got {Lq}"
    Lk = K_all_h.shape[2]
    Yh = torch.empty((B, H, 1, Dh), device=Qh.device, dtype=Qh.dtype)
    grid = (B, H)

    meta = dict(B=B, H=H, Lq=Lq, Lk=Lk, Dh=Dh)
    if os.environ.get("FAST_TRITON_TUNE", "1") == "1":
        launch = _ca_lq1_fwd[grid]
    else:
        # Fixed BLOCK_K to avoid per-step autotune. Tune once if needed and set here.
        launch = _ca_lq1_fwd_fixed[grid]
        meta["BLOCK_K"] = 128

    # Tensor pointers, then the first four strides of each tensor in the same order.
    tensors = (Qh, K_all_h, V_all_h, Yh)
    strides = [t.stride(axis) for t in tensors for axis in range(4)]
    launch(*tensors, *strides, **meta)
    return Yh


# Single-query softmax attention over two key/value sources, autotuned over
# BLOCK_K: a context set (Kc, Vc, length Nc) and a buffer set (Kb, Vb, length
# Nb), normalized jointly as if they were concatenated.
#
# Each program instance handles one (batch, head) pair, selected by program ids
# 0 and 1. Only query row 0 is read and only output row 0 is written. The max
# pass runs over both sources first; the accumulation pass then runs over both
# sources again using that shared maximum. Accumulation is in float32 and the
# result is cast back to the dtype of the loaded query.
#
# B, H and Lq are accepted as constexpr arguments but are not read in the body.
# NOTE(review): `tl.arange(0, Dh)` presumably requires Dh to be a power of two
# -- confirm against the Triton version in use.
# NOTE(review): if Nc and Nb are both 0 no loop runs and the final division is
# 0 / 0.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_K': 64},  num_warps=4, num_stages=2),
        triton.Config({'BLOCK_K': 128}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_K': 256}, num_warps=8, num_stages=3),
    ],
    key=['Dh']
)
@triton.jit
def _ca_lq1_shared_kv(Q, Kc, Vc, Kb, Vb, Y,
                      stride_q_b, stride_q_h, stride_q_lq, stride_q_d,
                      stride_kc_b, stride_kc_h, stride_kc_l, stride_kc_d,
                      stride_vc_b, stride_vc_h, stride_vc_l, stride_vc_d,
                      stride_kb_b, stride_kb_h, stride_kb_l, stride_kb_d,
                      stride_vb_b, stride_vb_h, stride_vb_l, stride_vb_d,
                      stride_y_b, stride_y_h, stride_y_lq, stride_y_d,
                      B: tl.constexpr, H: tl.constexpr, Lq: tl.constexpr, Nc, Nb, Dh: tl.constexpr,
                      BLOCK_K: tl.constexpr):
    # One program per (batch, head).
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)

    # Offsets along the head dimension.
    off_d = tl.arange(0, Dh)

    # Query vector at row 0, promoted to float32; logit scale is 1 / sqrt(Dh).
    q = tl.load(Q + pid_b * stride_q_b + pid_h * stride_q_h + 0 * stride_q_lq + off_d * stride_q_d)
    qf = q.to(tl.float32)
    Dh_f32 = tl.full([1], Dh, dtype=tl.float32)
    scale = 1.0 / tl.sqrt(Dh_f32)

    # Pass 1a: running max of the scaled logits over the context keys.
    # Rows past Nc load as zeros and have their logits forced to -inf.
    m = tl.full([1], -float('inf'), dtype=tl.float32)
    for start in range(0, Nc, BLOCK_K):
        idx = start + tl.arange(0, BLOCK_K)
        msk = idx < Nc
        kc_blk = tl.load(Kc + pid_b*stride_kc_b + pid_h*stride_kc_h + idx[:, None]*stride_kc_l + off_d[None,:]*stride_kc_d,
                         mask=msk[:, None], other=0.0)
        logits = tl.sum(kc_blk.to(tl.float32) * qf[None, :], axis=1) * scale
        logits = tl.where(msk, logits, -float('inf'))
        m = tl.maximum(m, tl.max(logits, axis=0))
    # Pass 1b: continue the same running max over the buffer keys.
    for start in range(0, Nb, BLOCK_K):
        idx = start + tl.arange(0, BLOCK_K)
        msk = idx < Nb
        kb_blk = tl.load(Kb + pid_b*stride_kb_b + pid_h*stride_kb_h + idx[:, None]*stride_kb_l + off_d[None,:]*stride_kb_d,
                         mask=msk[:, None], other=0.0)
        logits = tl.sum(kb_blk.to(tl.float32) * qf[None, :], axis=1) * scale
        logits = tl.where(msk, logits, -float('inf'))
        m = tl.maximum(m, tl.max(logits, axis=0))

    # Shared accumulators: softmax denominator and un-normalized output.
    denom = tl.zeros([1], dtype=tl.float32)
    out = tl.zeros([Dh], dtype=tl.float32)
    # Pass 2a: accumulate the context contribution (masked rows get weight 0).
    for start in range(0, Nc, BLOCK_K):
        idx = start + tl.arange(0, BLOCK_K)
        msk = idx < Nc
        kc_blk = tl.load(Kc + pid_b*stride_kc_b + pid_h*stride_kc_h + idx[:, None]*stride_kc_l + off_d[None,:]*stride_kc_d,
                         mask=msk[:, None], other=0.0)
        vc_blk = tl.load(Vc + pid_b*stride_vc_b + pid_h*stride_vc_h + idx[:, None]*stride_vc_l + off_d[None,:]*stride_vc_d,
                         mask=msk[:, None], other=0.0)
        logits = tl.sum(kc_blk.to(tl.float32) * qf[None, :], axis=1) * scale
        logits = tl.where(msk, logits, -float('inf'))
        w = tl.exp(logits - m)
        denom += tl.sum(w, axis=0)
        out += tl.sum(vc_blk.to(tl.float32) * w[:, None], axis=0)
    # Pass 2b: accumulate the buffer contribution into the same accumulators.
    for start in range(0, Nb, BLOCK_K):
        idx = start + tl.arange(0, BLOCK_K)
        msk = idx < Nb
        kb_blk = tl.load(Kb + pid_b*stride_kb_b + pid_h*stride_kb_h + idx[:, None]*stride_kb_l + off_d[None,:]*stride_kb_d,
                         mask=msk[:, None], other=0.0)
        vb_blk = tl.load(Vb + pid_b*stride_vb_b + pid_h*stride_vb_h + idx[:, None]*stride_vb_l + off_d[None,:]*stride_vb_d,
                         mask=msk[:, None], other=0.0)
        logits = tl.sum(kb_blk.to(tl.float32) * qf[None, :], axis=1) * scale
        logits = tl.where(msk, logits, -float('inf'))
        w = tl.exp(logits - m)
        denom += tl.sum(w, axis=0)
        out += tl.sum(vb_blk.to(tl.float32) * w[:, None], axis=0)

    # Normalize, cast back to the query dtype and store at output row 0.
    y = (out / denom).to(q.dtype)
    tl.store(Y + pid_b * stride_y_b + pid_h * stride_y_h + 0 * stride_y_lq + off_d * stride_y_d, y)


def _triton_ca_lq1_shared_kv_forward(Qh: torch.Tensor, Kc_h: torch.Tensor, Vc_h: torch.Tensor,
                                     Kb_h: torch.Tensor, Vb_h: torch.Tensor) -> torch.Tensor:
    """Run the two-source single-query attention kernel and return its output.

    ``Qh`` is unpacked as ``[B, H, Lq, Dh]`` with ``Lq`` required to be 1 (the
    kernel only reads query row 0). Context and buffer lengths are read from
    dim 2 of ``Kc_h`` and ``Kb_h``. The result has shape ``[B, H, 1, Dh]`` and
    the device/dtype of ``Qh``.

    When the ``FAST_TRITON_TUNE`` environment variable is ``"1"`` (the default)
    the autotuned kernel is launched. Otherwise the underlying JIT function is
    launched directly with ``BLOCK_K=128``.
    """
    assert TRITON_AVAILABLE, "Triton is required for this path"
    B, H, Lq, Dh = Qh.shape
    # Same precondition as the single-source launcher: only one query row is computed.
    assert Lq == 1, f"Triton fast-path requires Lq=1, got {Lq}"
    Nc = Kc_h.shape[2]
    Nb = Kb_h.shape[2]
    Yh = torch.empty((B, H, 1, Dh), device=Qh.device, dtype=Qh.dtype)
    grid = (B, H)

    # Tensor pointers, then the first four strides of each tensor in the same order.
    tensors = (Qh, Kc_h, Vc_h, Kb_h, Vb_h, Yh)
    strides = [t.stride(axis) for t in tensors for axis in range(4)]
    meta = dict(B=B, H=H, Lq=Lq, Nc=Nc, Nb=Nb, Dh=Dh)

    if os.environ.get("FAST_TRITON_TUNE", "1") == "1":
        _ca_lq1_shared_kv[grid](*tensors, *strides, **meta)
    else:
        # BLOCK_K is an autotuned meta-parameter of `_ca_lq1_shared_kv`, so
        # passing it through the autotuner would define it twice. Launch the
        # wrapped JIT function instead so the fixed block size is the only
        # value supplied and no tuning runs.
        fixed_kernel = getattr(_ca_lq1_shared_kv, "fn", _ca_lq1_shared_kv)
        fixed_kernel[grid](*tensors, *strides, **meta, BLOCK_K=128)
    return Yh
# Single-query softmax attention kernel with a caller-supplied BLOCK_K (no
# autotuning). Each program instance handles one (batch, head) pair; only
# query row 0 is read and only output row 0 is written. Pass 1 finds the
# maximum logit, pass 2 accumulates exp-weights and the weighted sum of V in
# float32; the result is cast back to the dtype of the loaded query.
#
# Rows past Lk are loaded as zeros, which would give them a logit of 0 and a
# non-zero softmax weight. Their logits are therefore forced to -inf in both
# passes, matching `_ca_lq1_fwd`, so results are correct when Lk is not a
# multiple of BLOCK_K.
@triton.jit
def _ca_lq1_fwd_fixed(Q, K, V, Y,
                      stride_q_b, stride_q_h, stride_q_lq, stride_q_d,
                      stride_k_b, stride_k_h, stride_k_l, stride_k_d,
                      stride_v_b, stride_v_h, stride_v_l, stride_v_d,
                      stride_y_b, stride_y_h, stride_y_lq, stride_y_d,
                      B: tl.constexpr, H: tl.constexpr, Lq: tl.constexpr, Lk, Dh: tl.constexpr,
                      BLOCK_K: tl.constexpr):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)

    off_d = tl.arange(0, Dh)

    # Load q vector (Lq must be 1)
    q = tl.load(Q + pid_b * stride_q_b + pid_h * stride_q_h + 0 * stride_q_lq + off_d * stride_q_d)
    qf = q.to(tl.float32)
    Dh_f32 = tl.full([1], Dh, dtype=tl.float32)
    scale = 1.0 / tl.sqrt(Dh_f32)

    # Pass 1: max logits (mask invalid rows)
    m = tl.full([1], -float('inf'), dtype=tl.float32)
    for start in range(0, Lk, BLOCK_K):
        idx_k = start + tl.arange(0, BLOCK_K)
        mask_k = idx_k < Lk
        k_block = tl.load(K + pid_b * stride_k_b + pid_h * stride_k_h + idx_k[:, None] * stride_k_l + off_d[None, :] * stride_k_d,
                          mask=mask_k[:, None], other=0.0)
        logits = tl.sum(k_block.to(tl.float32) * qf[None, :], axis=1) * scale
        # Padded rows must not take part in the max.
        logits = tl.where(mask_k, logits, -float('inf'))
        m = tl.maximum(m, tl.max(logits, axis=0))

    # Pass 2: denom and output (mask invalid rows)
    denom = tl.zeros([1], dtype=tl.float32)
    out = tl.zeros([Dh], dtype=tl.float32)
    for start in range(0, Lk, BLOCK_K):
        idx_k = start + tl.arange(0, BLOCK_K)
        mask_k = idx_k < Lk
        k_block = tl.load(K + pid_b * stride_k_b + pid_h * stride_k_h + idx_k[:, None] * stride_k_l + off_d[None, :] * stride_k_d,
                          mask=mask_k[:, None], other=0.0)
        v_block = tl.load(V + pid_b * stride_v_b + pid_h * stride_v_h + idx_k[:, None] * stride_v_l + off_d[None, :] * stride_v_d,
                          mask=mask_k[:, None], other=0.0)
        logits = tl.sum(k_block.to(tl.float32) * qf[None, :], axis=1) * scale
        # Padded rows get weight exp(-inf - m) == 0 instead of exp(0 - m).
        logits = tl.where(mask_k, logits, -float('inf'))
        w = tl.exp(logits - m)
        denom += tl.sum(w, axis=0)
        out += tl.sum(v_block.to(tl.float32) * w[:, None], axis=0)

    y = (out / denom).to(q.dtype)
    tl.store(Y + pid_b * stride_y_b + pid_h * stride_y_h + 0 * stride_y_lq + off_d * stride_y_d, y)


class CrossAttentionConcatTriton(CrossAttentionConcat):
    """Cross-attention that always goes through the Triton kernel (no fallback).

    Construction raises ``RuntimeError`` when Triton is unavailable; ``forward``
    raises ``RuntimeError`` for a non-``None`` ``attn_mask`` or a query length
    other than 1.
    """

    def __init__(self, d_model: int, n_heads: int, *, verbose: bool = False, force_triton: bool = True):
        super().__init__(d_model, n_heads)
        if not TRITON_AVAILABLE:
            raise RuntimeError("Triton is required for CrossAttentionConcatTriton but is not available")
        self.verbose = verbose
        self.force_triton = force_triton

    def _kv_to_head_layout(self, K: Any, V: Any) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(K, V)`` per head.

        A 4-D tensor ``K`` whose dim 1 equals ``n_heads`` is treated as already
        split and both inputs are passed through untouched; otherwise both are
        projected with ``k_proj`` / ``v_proj`` and split into heads.
        """
        pre_split = isinstance(K, torch.Tensor) and K.dim() == 4 and K.shape[1] == self.n_heads
        if pre_split:
            return K, V
        return split_heads(self.k_proj(K), self.n_heads), split_heads(self.v_proj(V), self.n_heads)

    def forward(self, x_q, Kc, Vc, Kb=None, Vb=None, attn_mask=None):
        """Attend one query token over context (and optional buffer) keys/values."""
        if attn_mask is not None:
            raise RuntimeError("Triton CA does not support attn_mask; expected None")
        B, Lq, D = x_q.shape
        if Lq != 1:
            raise RuntimeError(f"Triton CA requires Lq==1, got {Lq}")

        # Query projection, split to heads: [B,H,1,Dh]
        q_heads = split_heads(self.q_proj(x_q), self.n_heads).contiguous()

        # Context is assumed shared across the batch: keep only batch row 0 -> [H,Nc,Dh]
        ctx_k, ctx_v = self._kv_to_head_layout(Kc, Vc)
        ctx_k = ctx_k[0].contiguous()
        ctx_v = ctx_v[0].contiguous()

        # Buffer keeps its batch dimension: [B,H,Nb,Dh]; absent when either part is None
        if Kb is None or Vb is None:
            buf_k = buf_v = None
        else:
            buf_k, buf_v = self._kv_to_head_layout(Kb, Vb)
            buf_k = buf_k.contiguous()
            buf_v = buf_v.contiguous()

        # Delegate to notebook Triton kernel
        from scripts.fast_times import triton_m3_core as tm3
        heads_out = tm3.ca_shared_kv_forward_autotuned(q_heads, ctx_k, ctx_v, buf_k, buf_v)
        return self.out(combine_heads(heads_out))


def patch_decoder_with_triton(dec: Decoder, *, verbose: bool = False, force_triton: bool = False) -> Decoder:
    """Swap each decoder block's cross-attention for a Triton-backed copy, in place.

    For every block in ``dec.blocks`` a ``CrossAttentionConcatTriton`` is built
    with the same model width and head count, the q/k/v/out projection weights
    (and biases, when both sides have one) are copied over, and the new module
    is moved to the device/dtype of the block's first ``ln1`` parameter (or of
    the decoder's first parameter when ``ln1`` has none). The ``force_triton``
    argument is not forwarded; the replacement is always built with
    ``force_triton=True``. Returns the same ``dec`` object.
    """
    projection_names = ("q_proj", "k_proj", "v_proj", "out")
    for block in dec.blocks:
        source_ca = block.ca
        triton_ca = CrossAttentionConcatTriton(
            source_ca.out.in_features,
            source_ca.n_heads,
            verbose=verbose,
            force_triton=True,
        )
        # Copy projection parameters so the replacement preserves numerics.
        with torch.no_grad():
            for name in projection_names:
                src_proj = getattr(source_ca, name)
                dst_proj = getattr(triton_ca, name)
                dst_proj.weight.copy_(src_proj.weight)
                if not (src_proj.bias is None or dst_proj.bias is None):
                    dst_proj.bias.copy_(src_proj.bias)
        # Pick a reference parameter for device/dtype placement.
        anchor = next(block.ln1.parameters(), None)
        if anchor is None:
            anchor = next(dec.parameters())
        block.ca = triton_ca.to(device=anchor.device, dtype=anchor.dtype)
    return dec


# ------------------------------
# M3 Triton Sampler (eager, AR buffer)
# ------------------------------

class M3TritonSamplerAdapter(nn.Module):
    """Our method (M3) sampler using Triton CA fast-path at each AR step.

    This mirrors the notebook: encode C once, then for t=1..T do
      h_t = Decoder(emb_tgt(x_t), Kc, Vc, Kb, Vb);
      (mu_t, sigma_t) = head(h_t);
      y_t ~ Normal(mu_t, sigma_t);
      Kb,Vb grow with emb_buf([x_t|y_t]) via dec.buf_kv_from_token.

    Returns samples of shape [B, T, S, dy].
    """

    def __init__(self, ctx_enc: ContextEncoder, dec: Decoder,
                 emb_ctx: nn.Module, emb_tgt: nn.Module, emb_buf: nn.Module,
                 head: GaussianHead):
        super().__init__()
        self.ctx_enc = ctx_enc
        self.dec = dec
        self.emb_ctx = emb_ctx
        self.emb_tgt = emb_tgt
        self.emb_buf = emb_buf
        self.head = head

    @torch.no_grad()
    def sample_joint_predictive(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` autoregressive joint samples for the targets ``xt``.

        The context is encoded once and reused for every sample. Samples are
        drawn one after another; within a sample the buffer keys/values grow by
        one token per target step. The output has shape
        ``[B, T, num_samples, dy]`` with ``B, T`` taken from ``xt``, ``dy`` from
        the last dim of ``yc``, and the device/dtype of ``xt``.
        """
        B, T, _ = xt.shape
        dy = yc.shape[-1]
        device = xt.device
        dtype = xt.dtype

        # Context encode once
        xc_feats = self.emb_ctx(torch.cat([xc, yc], dim=-1))  # [B,Nc,dx+dy]->[B,Nc,D]
        E = self.ctx_enc.encode(xc_feats)
        Kc, Vc = self.ctx_enc.kv_from_encoded(E)

        # Output tensor
        out = torch.empty(B, T, num_samples, dy, device=device, dtype=dtype)

        for s in range(num_samples):
            # Each sample starts from an empty buffer.
            Kb_list: List[torch.Tensor] = []
            Vb_list: List[torch.Tensor] = []
            for t in range(T):
                x_t = xt[:, t:t+1, :]                               # [B,1,dx]
                h_in = self.emb_tgt(x_t)                             # [B,1,D]
                if Kb_list:
                    Kb = torch.cat(Kb_list, dim=1)                  # [B,t,D]
                    Vb = torch.cat(Vb_list, dim=1)
                else:
                    Kb = Vb = None
                h = self.dec(h_in, Kc, Vc, Kb, Vb)                   # [B,1,D]
                mu_t, sigma_t = self.head(h)                         # [B,1,dy]
                y_t = torch.distributions.Normal(mu_t, sigma_t).sample()  # [B,1,dy]
                # Index the sample axis with an integer so the destination is
                # [B,1,dy], the same shape as y_t. A [B,1,1,dy] slice would only
                # accept y_t by broadcasting when B == 1.
                out[:, t:t+1, s, :] = y_t

                # Grow buffer with raw token [x_t|y_t]
                buf_tok = torch.cat([x_t, y_t], dim=-1)              # [B,1,dx+dy]
                buf_feat = self.emb_buf(buf_tok)                     # [B,1,D]
                Kb_t, Vb_t = self.dec.buf_kv_from_token(buf_feat)
                Kb_list.append(Kb_t)
                Vb_list.append(Vb_t)

        return out
