from __future__ import annotations

"""
Exact Triton M3 (AR buffer) path ported from the notebook snippet.

It provides:
- Triton kernels: ca_shared_kv_fwd, ca_shared_kv_fwd_lq1
- Autotuned wrapper: ca_shared_kv_forward_autotuned
- Minimal model blocks used by M3: embedders, encoder, decoder, head
- CrossAttentionConcat that routes to the Triton fast path when possible
- ARBufferProvider and a compiled adapter matching the notebook behavior

This file is intentionally self-contained to mirror the notebook’s M3 logic
without perturbing the existing compiled_core module.
"""

import os
import json
import gc
import warnings
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------
# Triton detection
# --------------------
def _env_flag(name: str, default: str = "1") -> bool:
    """Read a boolean switch from the environment.

    The flag is considered off only when the variable is set to ``0``,
    ``false`` or ``no`` (case-insensitive); any other value, or an unset
    variable with the default ``"1"``, counts as on.
    """
    return os.environ.get(name, default).lower() not in {"0", "false", "no"}


# Runtime switches, each overridable through a FAST_TIMES_* environment variable.
COMPILE_ENABLED_DEFAULT = _env_flag("FAST_TIMES_COMPILE")
COMPILE_MODE_DEFAULT = os.environ.get("FAST_TIMES_COMPILE_MODE", "reduce-overhead")
FULLGRAPH_DEFAULT = _env_flag("FAST_TIMES_FULLGRAPH")
FP32_LN_DEFAULT = _env_flag("FAST_TIMES_FP32_LN")

try:
    import triton  # type: ignore
    import triton.language as tl  # type: ignore
    HAS_TRITON = True
except Exception:  # pragma: no cover - runtime check
    HAS_TRITON = False

    # The kernels below are decorated with ``@triton.autotune`` / ``@triton.jit``
    # at import time. Without stand-ins for ``triton`` and ``tl`` the module
    # would raise NameError on import and HAS_TRITON could never be consulted.
    # These stubs keep the module importable; launching a kernel still fails
    # loudly.

    class _MissingTritonKernel:
        """Stand-in for a jitted kernel that raises as soon as it is launched."""

        def __init__(self, fn):
            self.fn = fn
            self.__name__ = getattr(fn, "__name__", "kernel")

        def _fail(self, *args, **kwargs):
            raise RuntimeError(
                f"Triton is not available; kernel '{self.__name__}' cannot be launched."
            )

        def __getitem__(self, grid):
            # Mirrors the ``kernel[grid](...)`` launch syntax.
            return self._fail

        __call__ = _fail

    class _MissingTriton:
        """Minimal subset of the ``triton`` namespace used by this module."""

        @staticmethod
        def jit(fn=None, **kwargs):
            if fn is None:
                return lambda f: _MissingTritonKernel(f)
            return _MissingTritonKernel(fn)

        @staticmethod
        def autotune(*args, **kwargs):
            return lambda kernel: kernel

        @staticmethod
        def Config(kwargs=None, **meta):
            return (kwargs, meta)

        @staticmethod
        def cdiv(a, b):
            return -(-a // b)

        @staticmethod
        def next_power_of_2(n):
            return 1 if n <= 1 else 1 << (int(n) - 1).bit_length()

    class _MissingTritonLanguage:
        """Placeholder for ``triton.language``; only ``constexpr`` is referenced."""

        constexpr = int

    triton = _MissingTriton()
    tl = _MissingTritonLanguage()


# --------------------
# Helpers
# --------------------

def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reshape ``[B, L, D]`` activations into per-head layout ``[B, H, L, Dh]``.

    ``D`` must be a multiple of ``n_heads`` (checked with an assertion).
    The returned tensor is contiguous.
    """
    B, L, D = x.shape
    assert D % n_heads == 0, f"D={D} not divisible by n_heads={n_heads}"
    # Split the feature axis into (head, head_dim), then move heads ahead of length.
    heads_last = x.view(B, L, n_heads, D // n_heads)
    return heads_last.transpose(1, 2).contiguous()


def combine_heads(y: torch.Tensor) -> torch.Tensor:
    """Merge the head axis back into the feature axis: ``[B, H, L, Dh]`` -> ``[B, L, H * Dh]``."""
    B, H, L, Dh = y.shape
    length_major = y.transpose(1, 2).contiguous()  # [B, L, H, Dh]
    return length_major.view(B, L, H * Dh)


class FFN(nn.Module):
    """Two-layer MLP: ``Linear -> activation -> Linear``.

    ``d_out`` falls back to ``d_in`` when it is falsy (``None`` or ``0``).
    ``act`` is a zero-argument factory producing the activation module.
    """

    def __init__(self, d_in: int, d_hid: int, d_out: Optional[int] = None, act=nn.GELU, bias: bool = True):
        super().__init__()
        out_features = d_out or d_in
        layers = [
            nn.Linear(d_in, d_hid, bias=bias),
            act(),
            nn.Linear(d_hid, out_features, bias=bias),
        ]
        # Kept under ``net`` so parameter names remain ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        projected = self.net(x)
        return projected


class FP32LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, computed explicitly in float32.

    The input is upcast to float32 for the mean/variance statistics and the
    affine transform, and the result is cast back to the input dtype. The
    scale and shift parameters are created as float32.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(normalized_shape, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and return it in ``x``'s dtype."""
        in_dtype = x.dtype
        x32 = x.to(torch.float32)
        mu = x32.mean(dim=-1, keepdim=True)
        # Population variance (no Bessel correction).
        sigma2 = x32.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x32 - mu) * torch.rsqrt(sigma2 + self.eps)
        return (normed * self.weight + self.bias).to(dtype=in_dtype)


# --------------------
# Triton kernels (notebook)
# --------------------


# Fused cross-attention forward kernel (general Lq).
#
# Each program handles one (batch-group, head, query index) triple: it loads
# GROUP_B query rows, attends first over the context K/V that is shared by all
# batch rows and then over the per-batch buffer K/V, combining both with a
# single online (running max / running sum) softmax, and writes the normalised
# result to Out[b, h, l, :].
#
# Launch grid: axis 0 = groups of GROUP_B batch rows, axis 1 = head,
# axis 2 = query position. Tile sizes are picked by the autotuner, re-tuned
# whenever any of the `key` arguments changes.
@triton.autotune(
    configs=[
        triton.Config({'GROUP_B': 16, 'BLOCK_N_CTX': 64,  'BLOCK_N_BUF': 32}, num_warps=4, num_stages=2),
        triton.Config({'GROUP_B': 16, 'BLOCK_N_CTX': 128, 'BLOCK_N_BUF': 32}, num_warps=4, num_stages=3),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 64,  'BLOCK_N_BUF': 32}, num_warps=8, num_stages=3),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 128, 'BLOCK_N_BUF': 32}, num_warps=8, num_stages=3),
        # Wider context / buffer tiles for large Nc/Tbuf
        triton.Config({'GROUP_B': 16, 'BLOCK_N_CTX': 256, 'BLOCK_N_BUF': 32}, num_warps=8, num_stages=3),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 256, 'BLOCK_N_BUF': 32}, num_warps=8, num_stages=3),
        triton.Config({'GROUP_B': 16, 'BLOCK_N_CTX': 128, 'BLOCK_N_BUF': 64}, num_warps=8, num_stages=3),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 128, 'BLOCK_N_BUF': 64}, num_warps=8, num_stages=3),
    ],
    key=['B', 'H', 'Lq', 'D', 'Nctx', 'Tbuf'],
)
@triton.jit
def ca_shared_kv_fwd(
    Q_ptr,             # [B, H, Lq, D]
    Kc_ptr, Vc_ptr,    # [H, Nctx, D]     (shared across batch)
    Kb_ptr, Vb_ptr,    # [B, H, Tbuf, D]  (per-batch; Tbuf can be 0)
    Out_ptr,           # [B, H, Lq, D]

    B, H, Lq, D, Nctx, Tbuf,
    # strides
    stride_q_b, stride_q_h, stride_q_l, stride_q_d,
    stride_kc_h, stride_kc_n, stride_kc_d,
    stride_vc_h, stride_vc_n, stride_vc_d,
    stride_kb_b, stride_kb_h, stride_kb_n, stride_kb_d,
    stride_vb_b, stride_vb_h, stride_vb_n, stride_vb_d,
    stride_o_b, stride_o_h, stride_o_l, stride_o_d,

    scale,  # 1/sqrt(D)
    GROUP_B: tl.constexpr,
    BLOCK_N_CTX: tl.constexpr,
    BLOCK_N_BUF: tl.constexpr,
    BLOCK_D: tl.constexpr,   # == D
):
    pid_b = tl.program_id(0)  # batch-group id
    pid_h = tl.program_id(1)  # head id
    pid_l = tl.program_id(2)  # query index (0..Lq-1)

    # Batch rows covered by this program; rows past B are masked out everywhere.
    offs_b = pid_b * GROUP_B + tl.arange(0, GROUP_B)           # [G]
    mask_b = offs_b < B

    offs_d = tl.arange(0, BLOCK_D)                              # [D]
    mask_d = offs_d < D
    # NOTE(review): the return value of tl.multiple_of is discarded here;
    # confirm whether the alignment hint actually takes effect.
    if BLOCK_D % 16 == 0:
        tl.multiple_of(offs_d, 16)

    # ---- load Q [G,D] ----
    q_ptrs = (Q_ptr
              + offs_b[:, None] * stride_q_b
              + pid_h * stride_q_h
              + pid_l * stride_q_l
              + offs_d[None, :] * stride_q_d)
    mask_q = mask_b[:, None] & mask_d[None, :]
    q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32)  # [G,D]

    # online softmax accumulators
    # m: running row maximum, l: running sum of exp(score - m),
    # acc: running sum of exp(score - m) * V. All kept in float32.
    m = tl.full((GROUP_B,), -float("inf"), tl.float32)
    l = tl.zeros((GROUP_B,), tl.float32)
    acc = tl.zeros((GROUP_B, BLOCK_D), tl.float32)              # [G,D]

    # ---- context tiles (shared) ----
    # One K/V tile is loaded per step and reused by every batch row of the group.
    for n0 in range(0, Nctx, BLOCK_N_CTX):
        offs_n = n0 + tl.arange(0, BLOCK_N_CTX)
        mask_n = offs_n < Nctx

        kc_ptrs = (Kc_ptr + pid_h * stride_kc_h
                   + offs_n[:, None] * stride_kc_n
                   + offs_d[None, :] * stride_kc_d)
        vc_ptrs = (Vc_ptr + pid_h * stride_vc_h
                   + offs_n[:, None] * stride_vc_n
                   + offs_d[None, :] * stride_vc_d)

        # stream context through L2
        k_tile = tl.load(kc_ptrs, mask=mask_n[:, None] & mask_d[None, :],
                         other=0.0, cache_modifier=".cg").to(tl.float32)  # [Nc_blk,D]
        v_tile = tl.load(vc_ptrs, mask=mask_n[:, None] & mask_d[None, :],
                         other=0.0, cache_modifier=".cg").to(tl.float32)

        s = tl.dot(q, tl.trans(k_tile)) * scale          # [G, Nc_blk]
        # Positions past Nctx get -inf so they contribute exp(-inf) = 0.
        s = tl.where(mask_n[None, :], s, -float("inf"))

        # Online-softmax update: rescale the previous sums to the new running max.
        smax = tl.max(s, 1)
        m_new = tl.maximum(m, smax)
        alpha = tl.exp(m - m_new)
        s_exp = tl.exp(s - m_new[:, None])

        l = l * alpha + tl.sum(s_exp, 1)
        acc = acc * alpha[:, None] + tl.dot(s_exp, v_tile)      # [G,D]
        m = m_new

    # ---- buffer tiles (vectorized) ----
    # Buffer K/V differ per batch row, so tiles are 3-D [G, Nb_blk, D] and the
    # products are formed with broadcast-multiply + reduce instead of tl.dot.
    for n0 in range(0, Tbuf, BLOCK_N_BUF):
        offs_n = n0 + tl.arange(0, BLOCK_N_BUF)
        mask_n = offs_n < Tbuf

        kb_ptrs = (Kb_ptr
                   + offs_b[:, None, None] * stride_kb_b
                   + pid_h * stride_kb_h
                   + offs_n[None, :, None] * stride_kb_n
                   + offs_d[None, None, :] * stride_kb_d)
        vb_ptrs = (Vb_ptr
                   + offs_b[:, None, None] * stride_vb_b
                   + pid_h * stride_vb_h
                   + offs_n[None, :, None] * stride_vb_n
                   + offs_d[None, None, :] * stride_vb_d)

        mask_3d = mask_b[:, None, None] & mask_n[None, :, None] & mask_d[None, None, :]
        kbt = tl.load(kb_ptrs, mask=mask_3d, other=0.0).to(tl.float32)  # [G,Nb_blk,D]
        vbt = tl.load(vb_ptrs, mask=mask_3d, other=0.0).to(tl.float32)

        # scores via broadcast-sum (vectorized across GROUP_B)
        s_buf = tl.sum(kbt * q[:, None, :], axis=2) * scale              # [G,Nb_blk]
        s_buf = tl.where(mask_b[:, None] & mask_n[None, :], s_buf, -float("inf"))

        # Same online-softmax update as in the context loop.
        smax = tl.max(s_buf, 1)
        m_new = tl.maximum(m, smax)
        alpha = tl.exp(m - m_new)
        s_exp = tl.exp(s_buf - m_new[:, None])                           # [G,Nb_blk]

        l = l * alpha + tl.sum(s_exp, 1)
        acc = acc * alpha[:, None] + tl.sum(s_exp[:, :, None] * vbt, axis=1)
        m = m_new

    # Final softmax normalisation. If both loops ran zero iterations
    # (Nctx == 0 and Tbuf == 0), l is still zero at this point.
    out = acc / l[:, None]                                              # [G,D]
    out_ptrs = (Out_ptr
                + offs_b[:, None] * stride_o_b
                + pid_h * stride_o_h
                + pid_l * stride_o_l
                + offs_d[None, :] * stride_o_d)
    # Cast back to the output element type; masked rows/columns are not written.
    tl.store(out_ptrs, out.to(out_ptrs.dtype.element_ty), mask=mask_q)


# Fused cross-attention forward kernel specialised for a single query (Lq == 1).
#
# Same computation as the general kernel, but the launch grid has only two
# axes (batch-group, head) and the query/output offset along the length axis
# is fixed at 0, so no stride_q_l / stride_o_l term is added to the pointers.
# Lq is therefore not part of the signature or of the autotune key.
@triton.autotune(
    configs=[
        triton.Config({'GROUP_B': 16, 'BLOCK_N_CTX': 64,  'BLOCK_N_BUF': 32}, num_warps=4, num_stages=2),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 64,  'BLOCK_N_BUF': 32}, num_warps=8, num_stages=2),
        triton.Config({'GROUP_B': 32, 'BLOCK_N_CTX': 128, 'BLOCK_N_BUF': 32}, num_warps=8, num_stages=3),
    ],
    key=['B', 'H', 'D', 'Nctx', 'Tbuf'],
)
@triton.jit
def ca_shared_kv_fwd_lq1(
    Q_ptr,             # [B, H, 1, D]
    Kc_ptr, Vc_ptr,    # [H, Nctx, D]
    Kb_ptr, Vb_ptr,    # [B, H, Tbuf, D]
    Out_ptr,           # [B, H, 1, D]

    B, H, D, Nctx, Tbuf,
    # strides
    stride_q_b, stride_q_h, stride_q_l, stride_q_d,
    stride_kc_h, stride_kc_n, stride_kc_d,
    stride_vc_h, stride_vc_n, stride_vc_d,
    stride_kb_b, stride_kb_h, stride_kb_n, stride_kb_d,
    stride_vb_b, stride_vb_h, stride_vb_n, stride_vb_d,
    stride_o_b, stride_o_h, stride_o_l, stride_o_d,

    scale,
    GROUP_B: tl.constexpr,
    BLOCK_N_CTX: tl.constexpr,
    BLOCK_N_BUF: tl.constexpr,
    BLOCK_D: tl.constexpr,     # == D
):
    pid_b = tl.program_id(0)  # batch-group id
    pid_h = tl.program_id(1)  # head id

    # Batch rows covered by this program; rows past B are masked out everywhere.
    offs_b = pid_b * GROUP_B + tl.arange(0, GROUP_B)           # [G]
    mask_b = offs_b < B

    offs_d = tl.arange(0, BLOCK_D)                              # [D]
    mask_d = offs_d < D
    # NOTE(review): the return value of tl.multiple_of is discarded here;
    # confirm whether the alignment hint actually takes effect.
    if BLOCK_D % 16 == 0:
        tl.multiple_of(offs_d, 16)

    # ---- load Q at l=0 : [G,D] ----
    q_ptrs = (Q_ptr
              + offs_b[:, None] * stride_q_b
              + pid_h * stride_q_h
              # + 0 * stride_q_l
              + offs_d[None, :] * stride_q_d)
    mask_q = mask_b[:, None] & mask_d[None, :]
    q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32)  # [G,D]

    # online softmax accumulators
    # m: running row maximum, l: running sum of exp(score - m),
    # acc: running sum of exp(score - m) * V. All kept in float32.
    m = tl.full((GROUP_B,), -float("inf"), tl.float32)
    l = tl.zeros((GROUP_B,), tl.float32)
    acc = tl.zeros((GROUP_B, BLOCK_D), tl.float32)              # [G,D]

    # ---- context tiles (shared across rows) ----
    for n0 in range(0, Nctx, BLOCK_N_CTX):
        offs_n = n0 + tl.arange(0, BLOCK_N_CTX)
        mask_n = offs_n < Nctx

        kc_ptrs = (Kc_ptr + pid_h * stride_kc_h
                   + offs_n[:, None] * stride_kc_n
                   + offs_d[None, :] * stride_kc_d)
        vc_ptrs = (Vc_ptr + pid_h * stride_vc_h
                   + offs_n[:, None] * stride_vc_n
                   + offs_d[None, :] * stride_vc_d)

        k_tile = tl.load(kc_ptrs, mask=mask_n[:, None] & mask_d[None, :],
                         other=0.0, cache_modifier=".cg").to(tl.float32)  # [Nc_blk,D]
        v_tile = tl.load(vc_ptrs, mask=mask_n[:, None] & mask_d[None, :],
                         other=0.0, cache_modifier=".cg").to(tl.float32)

        # use tl.dot for context
        s = tl.dot(q, tl.trans(k_tile)) * scale                # [G, Nc_blk]
        # Positions past Nctx get -inf so they contribute exp(-inf) = 0.
        s = tl.where(mask_n[None, :], s, -float("inf"))

        # Online-softmax update: rescale the previous sums to the new running max.
        smax = tl.max(s, 1)
        m_new = tl.maximum(m, smax)
        alpha = tl.exp(m - m_new)
        s_exp = tl.exp(s - m_new[:, None])

        l = l * alpha + tl.sum(s_exp, 1)
        acc = acc * alpha[:, None] + tl.dot(s_exp, v_tile)     # [G,D]
        m = m_new

    # ---- buffer tiles (vectorized across GROUP_B) ----
    # Buffer K/V differ per batch row, so tiles are 3-D [G, Nb_blk, D] and the
    # products are formed with broadcast-multiply + reduce instead of tl.dot.
    for n0 in range(0, Tbuf, BLOCK_N_BUF):
        offs_n = n0 + tl.arange(0, BLOCK_N_BUF)
        mask_n = offs_n < Tbuf

        kb_ptrs = (Kb_ptr
                   + offs_b[:, None, None] * stride_kb_b
                   + pid_h * stride_kb_h
                   + offs_n[None, :, None] * stride_kb_n
                   + offs_d[None, None, :] * stride_kb_d)
        vb_ptrs = (Vb_ptr
                   + offs_b[:, None, None] * stride_vb_b
                   + pid_h * stride_vb_h
                   + offs_n[None, :, None] * stride_vb_n
                   + offs_d[None, None, :] * stride_vb_d)

        mask_3d = mask_b[:, None, None] & mask_n[None, :, None] & mask_d[None, None, :]
        kbt = tl.load(kb_ptrs, mask=mask_3d, other=0.0).to(tl.float32)  # [G,Nb_blk,D]
        vbt = tl.load(vb_ptrs, mask=mask_3d, other=0.0).to(tl.float32)

        # vectorized per-group score: sum over D
        s_buf = tl.sum(kbt * q[:, None, :], axis=2) * scale              # [G,Nb_blk]
        s_buf = tl.where(mask_b[:, None] & mask_n[None, :], s_buf, -float("inf"))

        # Same online-softmax update as in the context loop.
        smax = tl.max(s_buf, 1)
        m_new = tl.maximum(m, smax)
        alpha = tl.exp(m - m_new)
        s_exp = tl.exp(s_buf - m_new[:, None])                           # [G,Nb_blk]

        l = l * alpha + tl.sum(s_exp, 1)
        acc = acc * alpha[:, None] + tl.sum(s_exp[:, :, None] * vbt, axis=1)
        m = m_new

    # Final softmax normalisation. If both loops ran zero iterations
    # (Nctx == 0 and Tbuf == 0), l is still zero at this point.
    out = acc / l[:, None]                                               # [G,D]
    out_ptrs = (Out_ptr
                + offs_b[:, None] * stride_o_b
                + pid_h * stride_o_h
                # + 0 * stride_o_l
                + offs_d[None, :] * stride_o_d)
    # Cast back to the output element type; masked rows/columns are not written.
    tl.store(out_ptrs, out.to(out_ptrs.dtype.element_ty), mask=mask_q)


def ca_shared_kv_forward_autotuned(Q: torch.Tensor,
                                   K_ctx: torch.Tensor,
                                   V_ctx: torch.Tensor,
                                   K_buf: Optional[torch.Tensor] = None,
                                   V_buf: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Launch the fused shared-context cross-attention kernel.

    Args:
        Q:      [B, H, Lq, Dh] queries.
        K_ctx:  [H, Nctx, Dh] context keys, shared across the batch.
        V_ctx:  [H, Nctx, Dh] context values, shared across the batch.
        K_buf:  [B, H, Tbuf, Dh] per-batch buffer keys, or None.
        V_buf:  [B, H, Tbuf, Dh] per-batch buffer values, or None.

    Returns:
        Out: [B, H, Lq, Dh], same dtype and layout as ``Q``.

    The buffer is attended to only when both ``K_buf`` and ``V_buf`` are
    given and ``K_buf`` is non-empty; otherwise only the context is used.
    ``Lq == 1`` dispatches to the single-query kernel with a 2-D grid.
    The kernels are forward-only (no autograd support).
    """
    assert Q.is_cuda and K_ctx.is_cuda and V_ctx.is_cuda
    B, H, Lq, D = Q.shape
    Nctx = K_ctx.shape[1]

    # Plain `is None` / numel checks only, so this stays simple inside a
    # compiled region. A half-supplied buffer (one of K/V missing) is treated
    # as "no buffer" instead of crashing on `None.stride()`.
    has_buf = (K_buf is not None) and (V_buf is not None) and (K_buf.numel() > 0)
    if has_buf:
        Tbuf = K_buf.shape[2]
    else:
        Tbuf = 0
        # Dummy tensors: the kernel still needs valid pointers and strides,
        # but with Tbuf == 0 its buffer loop never runs.
        K_buf = Q.new_empty((1, 1, 1, D))
        V_buf = Q.new_empty((1, 1, 1, D))

    Out = torch.empty_like(Q)
    scale = 1.0 / (D ** 0.5)

    # Triton block extents must be powers of two (and tl.dot needs at least
    # 16 along each dimension). The kernels mask `offs_d < D`, so rounding
    # BLOCK_D up is safe for head sizes such as 48 or 80.
    block_d = max(16, triton.next_power_of_2(D))

    stride_args = (
        *Q.stride(),
        *K_ctx.stride(), *V_ctx.stride(),
        *K_buf.stride(), *V_buf.stride(),
        *Out.stride(),
    )

    def grid_general(META):
        return (triton.cdiv(B, META['GROUP_B']), H, Lq)

    def grid_lq1(META):
        return (triton.cdiv(B, META['GROUP_B']), H)

    if Lq == 1:
        ca_shared_kv_fwd_lq1[grid_lq1](
            Q, K_ctx, V_ctx, K_buf, V_buf, Out,
            B, H, D, Nctx, Tbuf,
            *stride_args,
            scale,
            BLOCK_D=block_d,
        )
    else:
        ca_shared_kv_fwd[grid_general](
            Q, K_ctx, V_ctx, K_buf, V_buf, Out,
            B, H, Lq, D, Nctx, Tbuf,
            *stride_args,
            scale,
            BLOCK_D=block_d,
        )
    return Out


# --------------------
# Blocks (notebook-style)
# --------------------


class SelfAttention(nn.Module):
    """Multi-head self-attention with bias-free Q/K/V/output projections."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        # Created in q, k, v, o order so parameter registration and
        # initialisation order stay the same.
        for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, name, nn.Linear(d_model, d_model, bias=False))

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False):
        """Attend ``x`` ([B, L, D]) to itself and return a tensor of the same shape."""
        H = self.n_heads
        q = split_heads(self.q_proj(x), H)
        k = split_heads(self.k_proj(x), H)
        v = split_heads(self.v_proj(x), H)
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal
        )
        merged = combine_heads(attended)
        return self.o_proj(merged)


class CrossAttentionConcat(nn.Module):
    """Cross-attention over a shared context plus an optional per-batch buffer.

    Queries come from the targets. Keys and values arrive already head-split:
    the context K/V (batch dim 1 when shared, or B) and an optional per-batch
    buffer K/V. When the conditions in ``forward`` hold, the fused Triton
    kernel is used; otherwise the context and buffer are concatenated along
    the length axis and passed to ``scaled_dot_product_attention``.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x_q: torch.Tensor,
        Kc: torch.Tensor,
        Vc: torch.Tensor,
        Kb: Optional[torch.Tensor] = None,
        Vb: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute cross-attention output of shape [B, Lq, D].

        Args:
            x_q: [B, Lq, D] query-side activations.
            Kc, Vc: [1, H, Nc, Dh] (shared) or [B, H, Nc, Dh] context keys/values.
            Kb, Vb: optional [B, H, Tbuf, Dh] buffer keys/values. The buffer
                is used only when both are given and ``Kb`` is non-empty.
            attn_mask: optional mask; supplying one forces the SDPA path.
        """
        B, Lq, D = x_q.shape
        H = self.n_heads
        Dh = D // H
        Q = split_heads(self.q_proj(x_q), H).contiguous()  # [B,H,Lq,Dh]

        # One rule for both paths: a buffer counts only if both halves are
        # present and non-empty. (Previously the Triton path could receive
        # K_buf without V_buf and crash.)
        has_buf = (Kb is not None) and (Vb is not None) and (Kb.numel() > 0)

        # The Triton kernels are forward-only: their output has no autograd
        # history, so using them while gradients are being tracked would
        # silently drop gradients for Q / K / V. Fall back to SDPA then.
        needs_grad = torch.is_grad_enabled() and (
            Q.requires_grad
            or (Kc is not None and Kc.requires_grad)
            or (Vc is not None and Vc.requires_grad)
            or (Kb is not None and Kb.requires_grad)
            or (Vb is not None and Vb.requires_grad)
        )

        use_triton = (
            HAS_TRITON
            and x_q.is_cuda
            and (attn_mask is None)
            and (not needs_grad)
            and (Kc is not None) and (Vc is not None)
            and (Kc.shape[0] == 1) and (Vc.shape[0] == 1)
            and (Dh % 16 == 0)
        )

        if use_triton:
            K_ctx = Kc[0].contiguous()  # [H,Nc,Dh]
            V_ctx = Vc[0].contiguous()
            K_buf = Kb.contiguous() if has_buf else None
            V_buf = Vb.contiguous() if has_buf else None
            Y = ca_shared_kv_forward_autotuned(Q, K_ctx, V_ctx, K_buf, V_buf)  # [B,H,Lq,Dh]
            y = combine_heads(Y)                                               # [B,Lq,D]
            return self.o_proj(y)

        # SDPA fallback: broadcast the shared context over the batch and
        # concatenate the buffer along the key/value length axis.
        if Kc.shape[0] == 1 and B > 1:
            Kc = Kc.expand(B, -1, -1, -1).contiguous()
            Vc = Vc.expand(B, -1, -1, -1).contiguous()
        if has_buf:
            K = torch.cat([Kc, Kb], dim=2)
            V = torch.cat([Vc, Vb], dim=2)
        else:
            K, V = Kc, Vc
        y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        y = combine_heads(y)
        return self.o_proj(y)


class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: self-attention and feed-forward, each with a residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        # The norm implementation is selected once by the module-level flag.
        norm_cls = FP32LayerNorm if FP32_LN_DEFAULT else nn.LayerNorm
        self.ln1 = norm_cls(d_model)
        self.sa = SelfAttention(d_model, n_heads)
        self.ln2 = norm_cls(d_model)
        self.ffn = FFN(d_model, d_ff)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.sa(self.ln1(x), attn_mask=attn_mask, is_causal=False)
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        return x + ffn_out


class ContextEncoder(nn.Module):
    """Stack of encoder blocks plus per-layer K/V projection modules."""

    def __init__(self, d_model: int, n_heads: int, n_layers: int, d_ff: int):
        super().__init__()
        layer_ids = range(n_layers)
        self.blocks = nn.ModuleList([TransformerEncoderBlock(d_model, n_heads, d_ff) for _ in layer_ids])
        # The k/v projections are kept for reference parity; only
        # kv_from_encoded reads them.
        self.k_proj = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in layer_ids])
        self.v_proj = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in layer_ids])
        self.n_heads = n_heads

    def encode(self, x_mem: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x_mem`` through every encoder block in order."""
        hidden = x_mem
        for block in self.blocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden

    def kv_from_encoded(self, E: torch.Tensor):  # not used by default here
        """Project ``E`` with each layer's k/v projection and split heads.

        Returns a pair of tuples ``(Ks, Vs)`` with one head-split tensor per layer.
        """
        n_heads = self.n_heads
        per_layer = [
            (split_heads(k_lin(E), n_heads).contiguous(), split_heads(v_lin(E), n_heads).contiguous())
            for k_lin, v_lin in zip(self.k_proj, self.v_proj)
        ]
        keys = tuple(k for k, _ in per_layer)
        values = tuple(v for _, v in per_layer)
        return keys, values


class DecoderBlock(nn.Module):
    """Pre-norm decoder block: cross-attention then feed-forward, each with a residual."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        # The norm implementation is selected once by the module-level flag.
        norm_cls = FP32LayerNorm if FP32_LN_DEFAULT else nn.LayerNorm
        self.ln_q = norm_cls(d_model)
        self.ca = CrossAttentionConcat(d_model, n_heads)
        self.ln_f = norm_cls(d_model)
        self.ffn = FFN(d_model, d_ff)

    def forward(self, h_in: torch.Tensor, Kc: torch.Tensor, Vc: torch.Tensor,
                Kb: Optional[torch.Tensor] = None, Vb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend the normalised ``h_in`` to context/buffer K/V, then apply the FFN."""
        attended = self.ca(self.ln_q(h_in), Kc, Vc, Kb, Vb)
        h_mid = h_in + attended
        return h_mid + self.ffn(self.ln_f(h_mid))


class Decoder(nn.Module):
    """Decoder blocks together with per-layer buffer K/V projections."""

    def __init__(self, d_model: int, n_heads: int, n_layers: int, d_ff: int):
        super().__init__()
        layer_ids = range(n_layers)
        self.blocks = nn.ModuleList([DecoderBlock(d_model, n_heads, d_ff) for _ in layer_ids])
        self.n_heads = n_heads
        self.buf_k_proj = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in layer_ids])
        self.buf_v_proj = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in layer_ids])

    def buf_kv_from_token(self, buf_tok: torch.Tensor):
        """Project a buffer token into head-split K and V for every layer.

        Returns a pair of tuples ``(Ks, Vs)`` with one tensor per layer.
        """
        n_heads = self.n_heads
        per_layer = [
            (split_heads(k_lin(buf_tok), n_heads).contiguous(), split_heads(v_lin(buf_tok), n_heads).contiguous())
            for k_lin, v_lin in zip(self.buf_k_proj, self.buf_v_proj)
        ]
        keys = tuple(k for k, _ in per_layer)
        values = tuple(v for _, v in per_layer)
        return keys, values


@torch.no_grad()
def build_decoder_ctx_kv(E: torch.Tensor, dec: Decoder):
    """Build per-decoder-layer context K/V from the encoded memory ``E``.

    No projection is applied: K and V are each the head-split form of ``E``
    ([B, H, L, Dh]), and the same K (resp. V) tensor is repeated once per
    decoder block, so both returned tuples have length ``len(dec.blocks)``.
    """
    n_layers = len(dec.blocks)
    n_heads = dec.n_heads
    keys = split_heads(E, n_heads).contiguous()    # [B,H,L,Dh]
    values = split_heads(E, n_heads).contiguous()
    return (keys,) * n_layers, (values,) * n_layers


# --------------------
# Embedders & Head
# --------------------


class ContextEmbedder(nn.Module):
    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        self.mlp = FFN(d_in, 2 * d_model, d_model)

    def forward(self, x_ctx: torch.Tensor) -> torch.Tensor:
        return self.mlp(x_ctx)


class TargetEmbedder(nn.Module):
    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        self.mlp = FFN(d_in, 2 * d_model, d_model)

    def forward(self, x_tgt: torch.Tensor) -> torch.Tensor:
        if x_tgt.dim() == 2:
            x_tgt = x_tgt.unsqueeze(0)
        return self.mlp(x_tgt)


class BufferEmbedder(nn.Module):
    def __init__(self, d_in: int, d_model: int):
        super().__init__()
        self.mlp = FFN(d_in, 2 * d_model, d_model)

    def forward(self, x_buf: torch.Tensor) -> torch.Tensor:
        return self.mlp(x_buf)


class GaussianHead(nn.Module):
    def __init__(self, d_model: int, d_out: int):
        super().__init__()
        self.net = FFN(d_model, 2 * d_model, 2 * d_out)

    def forward(self, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out = self.net(h)
        mu, log_sigma = out.chunk(2, dim=-1)
        log_sigma = log_sigma.clamp(-7.0, 3.0)
        return mu, log_sigma


# --------------------
# Provider (AR buffer)
# --------------------


class ARBufferProvider:
    """Holds the context K/V and the growing autoregressive buffer K/V.

    Typical call order: ``prepare_context`` once, ``start_sequence`` once,
    then for each step ``memory_at(t)`` followed by ``observe(t, ...)``.
    """

    def __init__(self, ctx_enc: ContextEncoder, emb_ctx: nn.Module, emb_buf: nn.Module, dec: Decoder):
        self.ctx_enc = ctx_enc
        self.emb_ctx = emb_ctx
        self.emb_buf = emb_buf
        self.dec = dec
        # Per-decoder-layer context K/V, filled by prepare_context.
        self.Kc_list: Optional[List[torch.Tensor]] = None
        self.Vc_list: Optional[List[torch.Tensor]] = None
        # Per-decoder-layer buffer K/V of shape [B, H, T, Dh], allocated by start_sequence.
        self.K_buf: Optional[List[torch.Tensor]] = None
        self.V_buf: Optional[List[torch.Tensor]] = None
        self.B: Optional[int] = None
        self.T: Optional[int] = None
        self.xt_feats: Optional[torch.Tensor] = None
        self.device: Optional[torch.device] = None

    @torch.no_grad()
    def prepare_context(self, xc_feats: torch.Tensor) -> None:
        """Embed and encode the context features and cache per-layer K/V.

        ``xc_feats`` holds raw context features; they are passed through
        ``emb_ctx`` before the encoder, mirroring how target features go
        through the target embedder. (Previously the raw features were fed
        straight to the encoder and ``emb_ctx`` was never used, which only
        worked when the feature width happened to equal the model width.)
        """
        self.device = xc_feats.device
        E_ctx = self.ctx_enc.encode(self.emb_ctx(xc_feats))
        self.Kc_list, self.Vc_list = build_decoder_ctx_kv(E_ctx, self.dec)

    @torch.no_grad()
    def start_sequence(self, B: int, T: int, xt_feats: torch.Tensor) -> None:
        """Allocate empty per-layer buffers for ``B`` samples and ``T`` steps.

        Raises:
            RuntimeError: if ``prepare_context`` has not been called yet.
        """
        if self.Kc_list is None or self.Vc_list is None:
            raise RuntimeError("prepare_context() must be called before start_sequence()")
        self.B, self.T = B, T
        self.xt_feats = xt_feats
        k_ref, v_ref = self.Kc_list[0], self.Vc_list[0]
        H, Dh = k_ref.shape[1], k_ref.shape[-1]
        n_layers = len(self.Kc_list)
        # Uninitialised storage: slot t only becomes valid after observe(t, ...).
        self.K_buf = [torch.empty(B, H, T, Dh, device=self.device, dtype=k_ref.dtype) for _ in range(n_layers)]
        self.V_buf = [torch.empty(B, H, T, Dh, device=self.device, dtype=v_ref.dtype) for _ in range(n_layers)]

    @torch.no_grad()
    def memory_at(self, t: int):
        """Return ``(Kc_list, Vc_list, Kb_list, Vb_list)`` visible at step ``t``.

        At ``t == 0`` the buffer entries are ``None``; otherwise each is a
        contiguous copy of the first ``t`` buffer slots.
        """
        if t == 0:
            return list(self.Kc_list), list(self.Vc_list), [None] * len(self.Kc_list), [None] * len(self.Vc_list)
        Kb_list = [Kb[:, :, :t, :].contiguous() for Kb in self.K_buf]
        Vb_list = [Vb[:, :, :t, :].contiguous() for Vb in self.V_buf]
        return list(self.Kc_list), list(self.Vc_list), Kb_list, Vb_list

    @torch.no_grad()
    def observe(self, t: int, x_t_feat_B: torch.Tensor, y_t: torch.Tensor, h_t: torch.Tensor, dec: Decoder) -> None:
        """Write the K/V of the step-``t`` (feature, sample) pair into slot ``t``.

        ``h_t`` is accepted for interface compatibility and is not read here.
        """
        buf_in = torch.cat([x_t_feat_B, y_t], dim=-1)
        buf_tok = self.emb_buf(buf_in)
        K_new_list, V_new_list = dec.buf_kv_from_token(buf_tok)
        for layer, (K_dst, V_dst) in enumerate(zip(self.K_buf, self.V_buf)):
            K_dst[:, :, t:t + 1, :].copy_(K_new_list[layer])
            V_dst[:, :, t:t + 1, :].copy_(V_new_list[layer])


# --------------------
# AR decode driver + compiled adapter
# --------------------


@torch.no_grad()
def unified_ar_decode(provider: ARBufferProvider,
                      dec: Decoder,
                      emb_tgt: nn.Module,
                      head: nn.Module,
                      xc_feats: torch.Tensor,      # [1,Nc,dx]
                      xt_feats: torch.Tensor,      # [T,dx]
                      B: int) -> torch.Tensor:     # -> [B,T,dy]
    """Sample ``B`` autoregressive trajectories over the ``T`` target rows.

    For each target row in order: embed it as the query, run it through every
    decoder block against the provider's context and buffer K/V, draw
    ``y_t = mu + exp(log_sigma) * eps`` from the head's output, and hand the
    (feature, sample) pair back to the provider so later steps can attend to it.
    The per-step samples are concatenated along dim 1.
    """
    n_steps = xt_feats.shape[0]
    provider.prepare_context(xc_feats)
    provider.start_sequence(B, n_steps, xt_feats)

    samples: List[torch.Tensor] = []
    for t in range(n_steps):
        x_row = xt_feats[t:t + 1, :]
        query = emb_tgt(x_row).squeeze(0).expand(B, -1, -1).contiguous()   # [B,1,D]
        Kc_list, Vc_list, Kb_list, Vb_list = provider.memory_at(t)

        h = query
        for layer, block in enumerate(dec.blocks):
            Kb = Kb_list[layer] if Kb_list is not None else None
            Vb = Vb_list[layer] if Vb_list is not None else None
            h = block(h, Kc_list[layer], Vc_list[layer], Kb, Vb)

        mu, log_sigma = head(h)                    # [B,1,dy]
        noise = torch.randn_like(mu)
        y_t = mu + torch.exp(log_sigma) * noise
        samples.append(y_t)

        # The same target features are replicated for every sample in the batch.
        x_t_feat_B = x_row.unsqueeze(0).expand(B, -1, -1).contiguous()
        provider.observe(t, x_t_feat_B, y_t, h, dec)

    return torch.cat(samples, dim=1)               # [B,T,dy]


def build_m3_exec(
    ctx_enc: ContextEncoder,
    dec: Decoder,
    emb_ctx: nn.Module,
    emb_tgt: nn.Module,
    emb_buf: nn.Module,
    head: nn.Module,
    x_ctx_feats: torch.Tensor,
    x_tgt_feats: torch.Tensor,
    B: int,
    compile: bool = True,
    mode: str = COMPILE_MODE_DEFAULT,
    fullgraph: bool = FULLGRAPH_DEFAULT,
):
    """Build a zero-argument callable that runs the AR decode on fixed inputs.

    The returned callable closes over the given modules, feature tensors and
    batch size ``B``. With ``compile=True`` it is wrapped by ``torch.compile``
    (static shapes); if that wrapping call raises, a warning is emitted and
    the eager callable is returned instead.

    NOTE(review): only exceptions raised by the ``torch.compile(...)`` call
    itself are caught here; confirm whether failures on the first invocation
    of the compiled callable should also fall back to eager.
    """
    provider = ARBufferProvider(ctx_enc, emb_ctx, emb_buf, dec)

    def runner():
        return unified_ar_decode(provider, dec, emb_tgt, head, x_ctx_feats, x_tgt_feats, B)

    if compile:
        try:
            return torch.compile(runner, fullgraph=fullgraph, dynamic=False, mode=mode)
        except Exception as e:  # pragma: no cover
            warnings.warn(f"torch.compile failed, falling back to eager: {e}")
    return runner


def _match(x: torch.Tensor, device: torch.device | str, dtype: torch.dtype) -> torch.Tensor:
    """Return ``x`` on ``device`` with ``dtype``, requesting a non-blocking transfer."""
    moved = x.to(device=device, dtype=dtype, non_blocking=True)
    return moved


class CompiledM3Adapter(nn.Module):
    """Compiled M3-only adapter exposing predict() / sample_joint_predictive().

    Intended for benchmarking: for every distinct (Nc, Nt, B, dx) combination
    a zero-argument runner is built once and cached. The runner closes over
    *random* context/target features created in ``_ensure_runner``; the
    tensors passed to ``sample_joint_predictive`` only determine shapes,
    device and dtype, not the values that are decoded.
    """

    def __init__(self,
                 ctx_enc: nn.Module,
                 dec: nn.Module,
                 emb_ctx: nn.Module,
                 emb_tgt: nn.Module,
                 emb_buf: nn.Module,
                 head: nn.Module):
        super().__init__()
        self.ctx_enc, self.dec = ctx_enc, dec
        self.emb_ctx, self.emb_tgt, self.emb_buf = emb_ctx, emb_tgt, emb_buf
        self.head = head
        # Device and dtype are taken from the encoder's first parameter.
        p = next(ctx_enc.parameters())
        self._device, self._dtype = p.device, p.dtype
        self._cache = {}  # cache key (see _key) -> zero-arg runner
        # Expected input feature width, if it can be discovered (optional check).
        self._dx_expected = self._expected_dx(self.emb_tgt)
        # compile controls
        self._compile_enabled = COMPILE_ENABLED_DEFAULT
        self._compile_mode = COMPILE_MODE_DEFAULT
        self._fullgraph = FULLGRAPH_DEFAULT

    @staticmethod
    def _expected_dx(emb_tgt: nn.Module) -> Optional[int]:
        """Return ``in_features`` of the target embedder's first layer, or None.

        The embedder's ``mlp`` keeps its layers in a ``net`` container, so the
        first layer is looked up there; an ``mlp`` that is itself an indexable
        container is also supported. (Looking for child ``"0"`` directly on
        ``mlp`` always missed, which silently disabled the dx check.)
        """
        mlp = getattr(emb_tgt, "mlp", None)
        stack = getattr(mlp, "net", mlp)
        first = getattr(stack, "0", None)
        return getattr(first, "in_features", None)

    def _key(self, Nc: int, Nt: int, B: int, dx: Optional[int] = None):
        """Cache key for a runner; includes ``dx`` when it is provided."""
        key = (int(Nc), int(Nt), int(B), str(self._dtype), str(self._device))
        return key if dx is None else key + (int(dx),)

    @torch.no_grad()
    def _ensure_runner(self, Nc: int, Nt: int, B: int, dx: int):
        """Return the cached runner for this shape, building it on first use."""
        # dx is part of the key so a runner built for one feature width is
        # never reused for another.
        k = self._key(Nc, Nt, B, dx)
        if k in self._cache:
            return self._cache[k]
        device, dtype = self._device, self._dtype
        # Random stand-in inputs: the runner is a fixed-input benchmark closure.
        x_ctx_feats = torch.randn(1, Nc, dx, device=device, dtype=dtype)
        x_tgt_feats = torch.randn(Nt, dx, device=device, dtype=dtype)
        fn = build_m3_exec(
            self.ctx_enc,
            self.dec,
            self.emb_ctx,
            self.emb_tgt,
            self.emb_buf,
            self.head,
            x_ctx_feats,
            x_tgt_feats,
            B,
            compile=self._compile_enabled,
            mode=self._compile_mode,
            fullgraph=self._fullgraph,
        )
        self._cache[k] = fn
        return fn

    @torch.no_grad()
    def predict(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int = 50):
        """Alias of ``sample_joint_predictive`` (M3 has a single sampling path)."""
        return self.sample_joint_predictive(xc, yc, xt, num_samples=num_samples)

    @torch.no_grad()
    def sample_joint_predictive(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int = 50):
        """Run the cached runner for the shapes of ``xc`` / ``xt``.

        ``Nc`` and ``Nt`` are read from dim 1 of ``xc`` and ``xt`` and ``dx``
        from the last dim of ``xt``; ``yc`` is not used. Returns a tensor of
        shape [1, Nt, B, dy] with ``B = num_samples``.

        Raises:
            ValueError: if ``dx`` differs from the target embedder's input width.
        """
        xc = _match(xc, self._device, self._dtype)
        xt = _match(xt, self._device, self._dtype)
        Nc, Nt, dx = xc.shape[1], xt.shape[1], xt.shape[-1]
        if self._dx_expected is not None and dx != self._dx_expected:
            raise ValueError(f"[m3] dx mismatch: data dx={dx} vs embedder expects {self._dx_expected}")
        B = int(num_samples)
        runner = self._ensure_runner(Nc, Nt, B, dx=dx)
        y = runner()  # [B,Nt,dy]
        return y.permute(1, 0, 2).unsqueeze(0).contiguous()  # [1,Nt,B,dy]


# Convenience builders for the script


def build_modules_for_m3(dx: int, dy: int, d_model: int, n_heads: int, n_layers_enc: int, n_layers_dec: int, d_ff: int,
                         device: str) -> tuple[nn.Module, nn.Module, nn.Module, nn.Module, nn.Module, nn.Module]:
    """Construct all M3 modules on ``device``.

    Modules are created in float16 when ``device`` is a CUDA device and CUDA
    is available, otherwise in float32. The device type is parsed with
    ``torch.device`` so indexed strings such as ``"cuda:0"`` are recognised
    as CUDA too (an exact ``== "cuda"`` comparison would give them float32).

    Returns:
        ``(emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec)``.
    """
    on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32

    def place(module: nn.Module) -> nn.Module:
        # Move first, then cast, exactly as each module was handled before.
        return module.to(device).to(dtype)

    emb_ctx = place(ContextEmbedder(dx, d_model))
    emb_tgt = place(TargetEmbedder(dx, d_model))
    # Buffer tokens embed the concatenation of target features and sampled outputs.
    emb_buf = place(BufferEmbedder(dx + dy, d_model))
    head = place(GaussianHead(d_model, dy))
    ctx_enc = place(ContextEncoder(d_model, n_heads, n_layers_enc, d_ff))
    dec = place(Decoder(d_model, n_heads, n_layers_dec, d_ff))
    return emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec
