from __future__ import annotations

"""
Engine-parity sampling-time benchmark for M3 (SDPA vs Triton).

What this measures
- M3 decoding that matches InferenceEngine2 semantics:
  * marker Embedder (x-MLP + y-MLP + marker) used to embed context/targets
  * simple Gaussian head shared by both paths
  * per-step decode; the engine backbone re-prefills the context KV cache
    before every target step
  * per-layer K/V for context taken directly from engine caches
  * current-step K/V computed by the engine's k_proj/v_proj on the
    LayerNorm-ed (ln1) step input

Writes two series to a single JSON file (<out_dir>/engine_sampling_m3_parity.json):
- "Engine M3 (SDPA)"   → compiled_core.Decoder cross-attention (CA) via PyTorch SDPA
- "Engine M3 (Triton)" → same decoder with its CA patched to use Triton

Usage
  python scripts/fast_times/run_engine_sampling_m3_parity.py
"""

import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch


def _bootstrap_path() -> None:
    """Put the repository root on ``sys.path`` so ``scripts.*`` / ``src.*`` import.

    Walks upward from the directory containing this file (or from the current
    working directory when ``__file__`` is undefined, e.g. in an interactive
    session). The first directory containing ``scripts/``, ``pyproject.toml``
    or ``.git`` is chosen. If none is found, the starting *directory* is used.
    The original fallback inserted the file path itself, which is never a
    usable import root.

    The path is only inserted if it is not already present.
    """
    try:
        start = Path(__file__).resolve().parent
    except NameError:
        start = Path.cwd()
    markers = ("scripts", "pyproject.toml", ".git")
    root = next(
        (
            cand
            for cand in (start, *start.parents)
            if any((cand / marker).exists() for marker in markers)
        ),
        start,
    )
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))


# Make the repository root importable before the project imports below.
_bootstrap_path()

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    run_benchmark_grid,
    package_for_plot,
    data_gen,
)
from scripts.fast_times.compiled_core import (
    Decoder,
    GaussianHead,
    patch_decoder_with_triton,
)
from src.models.modules.embedders import Embedder
from src.models.modules.backbones import Transformer
from src.models.masks import create_context_self_attention_block_mask
from src.models.utils import expand_kv_heads
from src.models.ace import InferenceEngine2
from src.utils import DataAttr


def _compile_ok(mod: torch.nn.Module) -> torch.nn.Module:
    """Return ``mod`` wrapped by ``torch.compile``, or ``mod`` unchanged on failure.

    The compile options match the other runners: compiling the surrounding
    modules reduces Python overhead. Note that ``torch.compile`` compiles
    lazily on first call. This fallback therefore only covers failures raised
    while wrapping, not errors raised later during tracing.
    """
    options = {"mode": "default", "fullgraph": False, "dynamic": True}
    try:
        wrapped = torch.compile(mod, **options)
    except Exception:
        # Best effort: run eagerly if compilation is unavailable here.
        wrapped = mod
    return wrapped


@torch.no_grad()
def copy_engine_to_decoder(backbone: Transformer, dec: Decoder) -> None:
    for blk, lyr in zip(dec.blocks, backbone.layers):
        blk.ln1.weight.copy_(lyr.norm1.weight)
        blk.ln1.bias.copy_(lyr.norm1.bias)
        blk.ln2.weight.copy_(lyr.norm2.weight)
        blk.ln2.bias.copy_(lyr.norm2.bias)

        lin1 = blk.ff.net[0]
        lin2 = blk.ff.net[2]
        lin1.weight.copy_(lyr.ff1.weight)
        lin1.bias.copy_(lyr.ff1.bias)
        lin2.weight.copy_(lyr.ff2.weight)
        lin2.bias.copy_(lyr.ff2.bias)

        blk.ca.q_proj.weight.copy_(lyr.attn.q_proj.weight)
        if blk.ca.q_proj.bias is not None:
            blk.ca.q_proj.bias.zero_()
        blk.ca.out.weight.copy_(lyr.attn.o_proj.weight)
        if blk.ca.out.bias is not None:
            blk.ca.out.bias.zero_()

    dec.ln.weight.copy_(backbone.norm.weight)
    dec.ln.bias.copy_(backbone.norm.bias)


@torch.no_grad()
def get_ctx_kv_from_engine(engine: InferenceEngine2) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Slice each backbone layer's K/V cache down to the filled context prefix.

    The prefix length is read from ``engine.backbone.seq_len``. Each returned
    tensor is ``cache[:, :, :seq_len, :]`` made contiguous (dim 2 is assumed
    to be the sequence axis -- TODO confirm against the engine cache layout).

    Returns:
        ``(keys, values)``: two lists with one tensor per backbone layer.
    """
    n_ctx = engine.backbone.seq_len.item()
    per_layer = [
        (
            layer.k_cache[:, :, :n_ctx, :].contiguous(),
            layer.v_cache[:, :, :n_ctx, :].contiguous(),
        )
        for layer in engine.backbone.layers
    ]
    keys = [k for k, _ in per_layer]
    values = [v for _, v in per_layer]
    return keys, values


def _step_kv(attn: Any, qn: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Project normalized step input ``qn`` into per-head K/V via the engine attention.

    Uses ``attn.k_proj``/``attn.v_proj`` and reshapes to
    ``[batch, kv_heads, steps, head_dim]``. When the engine has fewer KV heads
    than query heads, the heads are expanded to ``attn.num_heads`` with
    ``expand_kv_heads``. Both outputs are returned contiguous.
    """
    batch, steps = qn.shape[0], qn.shape[1]
    kv_heads, num_heads, head_dim = attn.num_kv_heads, attn.num_heads, attn.head_dim
    kb = attn.k_proj(qn).view(batch, steps, kv_heads, head_dim).transpose(1, 2)
    vb = attn.v_proj(qn).view(batch, steps, kv_heads, head_dim).transpose(1, 2)
    if kv_heads != num_heads:
        kb = expand_kv_heads(kb, num_heads // kv_heads)
        vb = expand_kv_heads(vb, num_heads // kv_heads)
    return kb.contiguous(), vb.contiguous()


@torch.no_grad()
def decode_step(dec: Decoder, engine: InferenceEngine2,
                Kc_list: List[torch.Tensor], Vc_list: List[torch.Tensor],
                h_in: torch.Tensor) -> torch.Tensor:
    """Run one decode step through ``dec`` using engine-style per-layer K/V.

    For each (decoder block, engine layer) pair:
    - ``qn = blk.ln1(h)``.
    - Current-step K/V come from the engine layer's ``k_proj``/``v_proj``
      applied to ``qn``.
    - Context K/V come from ``Kc_list[i]``/``Vc_list[i]``.
    - ``h += blk.ca(qn, Kc, Vc, kb, vb)``, then ``h += blk.ff(blk.ln2(h))``.

    Returns ``dec.ln(h)``.

    The attention backend (SDPA or Triton) is whatever ``blk.ca`` implements.
    This function does not choose it.

    NOTE(review): this iterates ``dec.blocks`` directly. If ``dec`` is a
    ``torch.compile`` wrapper, attribute access presumably forwards to the
    original module, so the compiled ``forward`` is not used here -- confirm
    before attributing timing differences to compilation.

    ``h_in`` is assumed to be ``[batch, steps, d_model]`` -- TODO confirm.
    """
    h = h_in
    for li, (blk, lyr) in enumerate(zip(dec.blocks, engine.backbone.layers)):
        qn = blk.ln1(h)
        kb, vb = _step_kv(lyr.attn, qn)
        h = h + blk.ca(qn, Kc_list[li], Vc_list[li], kb, vb, attn_mask=None)
        h = h + blk.ff(blk.ln2(h))
    return dec.ln(h)


@torch.no_grad()
def decode_step_triton(dec: Decoder, engine: InferenceEngine2,
                       Kc_list: List[torch.Tensor], Vc_list: List[torch.Tensor],
                       h_in: torch.Tensor) -> torch.Tensor:
    """One-step decode for a Triton-patched decoder.

    The per-layer computation was a verbatim copy of ``decode_step``, so this
    now delegates to it. The Triton path differs only in ``dec`` itself, whose
    cross-attention modules ``patch_decoder_with_triton`` swapped at build
    time. Kept as a separate name so callers and profiles can tell the two
    variants apart.
    """
    return decode_step(dec, engine, Kc_list, Vc_list, h_in)


class EngineParityM3Adapter(torch.nn.Module):
    """Adapter exposing ``sample_joint_predictive`` for engine-parity M3 timing.

    Combines three pieces:
    - a shared ``InferenceEngine2``, which provides context/target embedding,
      KV caches and context updates;
    - a ``Decoder`` copy of the engine backbone;
    - a Gaussian head.

    Whether cross-attention runs on SDPA or Triton is decided by how ``dec``
    was built. ``use_triton`` only selects which decode helper is called.
    """

    def __init__(self, engine: InferenceEngine2, dec: Decoder, head: GaussianHead, *, use_triton: bool):
        super().__init__()
        self.engine = engine
        self.dec = dec
        self.head = head
        self.use_triton = use_triton
        # A zero-dim param to make device/dtype detection easy; it mirrors the
        # decoder's first parameter.
        ref = next(dec.parameters())
        self._dummy = torch.nn.Parameter(torch.zeros((), device=ref.device, dtype=ref.dtype))

    @torch.no_grad()
    def sample_joint_predictive(self, xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` autoregressive joint samples at target inputs ``xt``.

        Args:
            xc, yc: observed context inputs/outputs; ``yc.shape[-1]`` sets ``dy``.
            xt: target inputs of shape ``[1, T, dx]``.
            num_samples: number of independent autoregressive trajectories.

        Returns:
            Tensor of shape ``[1, T, num_samples, dy]`` on ``xt``'s device/dtype.

        Raises:
            ValueError: if the batch size of ``xt`` is not 1.
        """
        batch, num_targets, _ = xt.shape
        if batch != 1:
            raise ValueError("Adapter expects B=1 (benchmark harness uses B=1).")
        dy = yc.shape[-1]

        out = torch.empty(batch, num_targets, num_samples, dy, device=xt.device, dtype=xt.dtype)

        # Embed the observed context once; every trajectory starts from a clone of it.
        # engine.embedder is the marker-based embedder, which matches the runtime path.
        ctx_emb = self.engine.embedder.embed_context(DataAttr(xc=xc, yc=yc))

        for s in range(num_samples):
            out[:, :, s, :] = self._sample_trajectory(ctx_emb, xt, dy)

        return out

    @torch.no_grad()
    def _sample_trajectory(self, ctx_emb: torch.Tensor, xt: torch.Tensor, dy: int) -> torch.Tensor:
        """Sample one autoregressive trajectory over all targets; returns ``[1, T, dy]``.

        Resets the engine context to ``ctx_emb``. Then, for each target step:
        1. re-prefill the KV cache from the current context;
        2. decode the target embedding and sample ``y_t`` from the Gaussian head;
        3. append ``(x_t, y_t)`` to the engine context.
        """
        engine = self.engine
        batch, num_targets = xt.shape[0], xt.shape[1]
        device, dtype = xt.device, xt.dtype
        decode = decode_step_triton if self.use_triton else decode_step

        # Reset context so every trajectory conditions only on the observed data.
        engine.context_embeddings = ctx_emb.clone()
        engine.init_kv_cache(B=1, max_seq=ctx_emb.shape[1] + num_targets, device=device, dtype=dtype)

        seq_out = torch.empty(batch, num_targets, dy, device=device, dtype=dtype)
        for t in range(num_targets):
            # Prefill caches from the current (growing) context.
            n_ctx = engine.context_embeddings.shape[1]
            mask = create_context_self_attention_block_mask(
                current_num_context=n_ctx,
                q_block_size=engine.q_block_size,
                kv_block_size=engine.kv_block_size,
                device=device,
            )
            engine.prefill_kv_cache(mask)
            Kc_list, Vc_list = get_ctx_kv_from_engine(engine)

            # One-step target embedding (marker target).
            x_t = xt[:, t : t + 1, :]
            h_in = engine.embedder.embed_target(DataAttr(xt=x_t))

            z = decode(self.dec, engine, Kc_list, Vc_list, h_in)
            mu_t, sigma_t = self.head(z)
            y_t = torch.distributions.Normal(mu_t, sigma_t).sample()  # expected [1, 1, dy]
            seq_out[:, t : t + 1, :] = y_t

            # Feed the sampled y back into the engine context.
            engine.update_context_embeddings(DataAttr(xc=x_t, yc=y_t))

        return seq_out


def build_modules(cfg: BenchConfig):
    """Build the SDPA and Triton engine-parity M3 adapters described by ``cfg``.

    Both adapters share a single ``InferenceEngine2``, built from a marker
    embedder, a Transformer backbone and a Gaussian head. Each adapter gets
    its own ``Decoder`` whose weights are copied from that backbone. The
    Triton decoder is additionally patched with ``patch_decoder_with_triton``
    before compilation.

    Returns:
        ``(m3_sdpa, m3_trit)`` adapter pair.
    """
    D = cfg.dims
    device = cfg.runtime.device
    # fp16 only when actually running on CUDA. ``torch.device(...).type`` also
    # recognises "cuda:N" strings and ``torch.device`` objects, which a plain
    # ``device == "cuda"`` comparison silently routed to fp32.
    use_fp16 = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if use_fp16 else torch.float32

    # Marker embedder, compiled for parity with other runners.
    embedder = Embedder(
        dim_x=D.dx,
        dim_y=D.dy,
        hidden_dim=2 * D.d_model,
        out_dim=D.d_model,
        depth=2,
    ).to(device).to(dtype)
    embedder = _compile_ok(embedder)

    # Engine backbone.
    backbone = Transformer(
        num_layers=D.n_layers_dec,
        dim_model=D.d_model,
        num_head=D.n_heads,
        dim_feedforward=D.d_ff,
        dropout=0.0,
    ).to(device).to(dtype)

    # Gaussian head shared by the engine and both adapters.
    head = _compile_ok(GaussianHead(D.d_model, D.dy).to(device).to(dtype))

    # Engine wrapper (KV utilities and embedder); the buffer size follows ar_tokens.
    ar_tokens = torch.zeros(32, D.d_model, device=device, dtype=dtype)
    engine = InferenceEngine2(
        embedder=embedder,
        backbone=backbone,
        head=head,
        ar_tokens=ar_tokens,
        max_buffer_size=ar_tokens.shape[0],
    ).to(device)

    def _make_decoder(*, triton: bool):
        """Decoder copy of ``backbone``, optionally Triton-patched, then compiled."""
        dec = Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff).to(device).to(dtype)
        copy_engine_to_decoder(backbone, dec)
        if triton:
            patch_decoder_with_triton(dec, verbose=False, force_triton=True)
        return _compile_ok(dec)

    dec_sdpa = _make_decoder(triton=False)
    dec_triton = _make_decoder(triton=True)

    m3_sdpa = EngineParityM3Adapter(engine, dec_sdpa, head, use_triton=False)
    m3_trit = EngineParityM3Adapter(engine, dec_triton, head, use_triton=True)
    return m3_sdpa, m3_trit


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Benchmark both engine-parity M3 variants over the grid and save the results.

    Runs ``run_benchmark_grid`` in autoregressive mode, first for the SDPA
    adapter and then for the Triton adapter. Results are packaged with
    ``package_for_plot`` and written atomically to
    ``<cfg.out_dir>/engine_sampling_m3_parity.json``.
    """
    configure_torch_env()

    sdpa_model, triton_model = build_modules(cfg)

    grid = cfg.grid
    grid_kwargs = dict(
        Nt=grid.Nt,
        num_runs=grid.num_runs,
        method="autoregressive",
        Nc_values=tuple(grid.Nc_values),
        num_samples_values=tuple(grid.num_samples_values),
        dx=cfg.dims.dx,
        dy=cfg.dims.dy,
    )

    # (series label in the JSON, console banner, adapter) -- run in this order.
    variants = (
        ("Engine M3 (SDPA)", "\n=== Engine-parity M3 (SDPA) ===", sdpa_model),
        ("Engine M3 (Triton)", "\n=== Engine-parity M3 (Triton) ===", triton_model),
    )
    methods: Dict[str, Dict[str, Any]] = {}
    for label, banner, model in variants:
        print(banner)
        methods[label] = run_benchmark_grid(model, **grid_kwargs)

    payload = package_for_plot(
        methods,
        meta={"script": "engine_sampling_m3_parity", "config": cfg.to_dict()},
    )
    target_dir = cfg.out_dir
    # NOTE(review): passes a dummy child path; presumably _ensure_dir creates its parent dir.
    _ensure_dir(os.path.join(target_dir, "_"))
    out_path = os.path.join(target_dir, "engine_sampling_m3_parity.json")
    _atomic_write_json(out_path, payload)
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    # Script entry point: run with the default BenchConfig (``DEFAULT``).
    main()
