from __future__ import annotations

"""
Engine ↔ Triton Sampling Parity (markers embedder + Gaussian head)

Compares InferenceEngine2 sampling step-by-step against a Triton-backed
compiled_core.Decoder using the same:
  - Embedder with markers (src.models.modules.embedders.Embedder)
  - Simple Gaussian head (mu/sigma) for both paths
  - Transformer weights (copied from engine backbone into compiled_core.Decoder)

Key points:
  - We feed the Triton CA with per-layer, per-head context K/V taken directly
    from the engine’s prefill KV caches, ensuring K/V identical for context.
  - For the current query’s self position, we compute per-layer K/V via the
    engine’s k_proj/v_proj applied to the ln1-normalized query at each layer
    (pre-LN block layout), matching engine semantics, and pass as Kb/Vb to
    Triton CA.
  - We only check sampling: at each step t we generate a single eps ~ N(0,1)
    and produce y_t = mu + sigma * eps for both branches using the same eps.

Usage:
  python scripts/fast_times/check_engine_vs_triton_sampling.py --Nc 16 --Nt 16
"""

import argparse
from typing import List, Tuple
from pathlib import Path
import sys

import torch


def _bootstrap_path() -> None:
    """Put the repository root on ``sys.path`` so project imports resolve.

    The search starts from this file (or from the current working directory
    when ``__file__`` is undefined) and walks up the directory tree. The first
    location that contains a ``scripts`` entry, a ``pyproject.toml`` or a
    ``.git`` entry is taken as the root. When no marker is found, the starting
    directory is used. The chosen path is prepended to ``sys.path`` unless it
    is already listed there.
    """
    try:
        base = Path(__file__).resolve()
    except NameError:
        base = Path.cwd()

    markers = ("scripts", "pyproject.toml", ".git")
    root = next(
        (
            cand
            for cand in (base, *base.parents)
            if any((cand / marker).exists() for marker in markers)
        ),
        None,
    )
    if root is None:
        # A file path is useless as a sys.path entry; fall back to its directory.
        root = base if base.is_dir() else base.parent

    root_str = str(root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)


# Must run before the project-local imports that follow so they can resolve.
_bootstrap_path()

from scripts.fast_times.config import DEFAULT, BenchConfig
from scripts.fast_times.common import data_gen
from scripts.fast_times.compiled_core import (
    Decoder,
    GaussianHead,
    patch_decoder_with_triton,
)
from src.models.modules.embedders import Embedder
from src.models.modules.backbones import Transformer
from src.models.masks import create_context_self_attention_block_mask
from src.models.utils import expand_kv_heads
from src.models.ace import InferenceEngine2
from src.utils import DataAttr


def configure_env_for_parity() -> None:
    """Configure PyTorch numerics for the tightest fp32 parity.

    Sets the float32 matmul precision to ``"highest"`` and, on a best-effort
    basis, disables the Flash and memory-efficient scaled-dot-product
    attention backends while enabling the math backend.
    """
    # "high" would allow TF32 / bf16-split arithmetic inside float32 matmuls;
    # only "highest" keeps full float32 internal precision.
    torch.set_float32_matmul_precision("highest")
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    except (AttributeError, RuntimeError):
        # Best effort only: builds that lack these toggles (or reject them)
        # keep their default SDPA backend selection.
        pass


def build_components(cfg: BenchConfig):
    """Instantiate every module the parity check needs.

    Each module is moved to ``cfg.runtime.device`` and cast to float32.

    Returns:
        The tuple ``(embedder, backbone, head, engine, dec)``. ``engine``
        wraps the same ``embedder``/``backbone``/``head`` instances, and
        ``dec`` is a ``Decoder`` built from the same model width, head count,
        layer count and feed-forward width as ``backbone``.
    """
    dims = cfg.dims
    target_device = cfg.runtime.device
    # Parity is checked in fp32; change this to experiment with fp16/bf16.
    fp = torch.float32

    def _place(module):
        # Device move followed by dtype cast, applied uniformly to each module.
        return module.to(target_device).to(fp)

    # Markers embedder, as used by the engine.
    embedder = _place(
        Embedder(
            dim_x=dims.dx,
            dim_y=dims.dy,
            hidden_dim=2 * dims.d_model,
            out_dim=dims.d_model,
            depth=2,
        )
    )

    # Transformer backbone owned by the engine.
    backbone = _place(
        Transformer(
            num_layers=dims.n_layers_dec,
            dim_model=dims.d_model,
            num_head=dims.n_heads,
            dim_feedforward=dims.d_ff,
            dropout=0.0,
        )
    )

    # Placeholder AR tokens: the engine constructor requires them, but this
    # script never reads them directly.
    n_ar_tokens = getattr(dims, "DEFAULT_TARGETS_BLOCK_SIZE", 32)
    ar_tokens = torch.zeros(n_ar_tokens, dims.d_model, device=target_device, dtype=fp)

    # One Gaussian head shared by the engine path and the Triton path.
    head = _place(GaussianHead(dims.d_model, dims.dy))

    # Engine wrapper; the head is passed for API completeness, while mu/sigma
    # are computed by calling the head directly elsewhere in this script.
    engine = InferenceEngine2(
        embedder=embedder,
        backbone=backbone,
        head=head,
        ar_tokens=ar_tokens,
        max_buffer_size=ar_tokens.shape[0],
    ).to(target_device)

    # Triton-backed decoder with matching dimensions.
    dec = _place(Decoder(dims.d_model, dims.n_heads, dims.n_layers_dec, dims.d_ff))

    return embedder, backbone, head, engine, dec


@torch.no_grad()
def copy_engine_to_decoder(backbone: Transformer, dec: Decoder) -> None:
    """Copy Transformer weights from the engine backbone into ``dec`` in place.

    Mapping (engine -> decoder):
      - per-layer norm1/norm2 -> ln1/ln2
      - ff1/ff2 -> ff.net[0]/ff.net[2]
      - attn q_proj/o_proj -> ca.q_proj/ca.out
      - final norm -> dec.ln

    For the attention projections, a decoder bias is overwritten with the
    engine bias when the engine projection has one and zeroed otherwise, so
    both paths apply the same affine map. A decoder projection without a bias
    is left as is. Key/value projections are not copied.

    Raises:
        ValueError: if ``dec.blocks`` and ``backbone.layers`` differ in
            length (a plain ``zip`` would silently copy only a prefix).
    """
    blocks = list(dec.blocks)
    layers = list(backbone.layers)
    if len(blocks) != len(layers):
        raise ValueError(
            f"layer count mismatch: decoder has {len(blocks)} blocks, "
            f"backbone has {len(layers)} layers"
        )

    for blk, lyr in zip(blocks, layers):
        # LayerNorms
        for dst_norm, src_norm in ((blk.ln1, lyr.norm1), (blk.ln2, lyr.norm2)):
            dst_norm.weight.copy_(src_norm.weight)
            dst_norm.bias.copy_(src_norm.bias)

        # Feed-forward: the Linear layers sit at indices 0 and 2 of ff.net.
        for dst_lin, src_lin in ((blk.ff.net[0], lyr.ff1), (blk.ff.net[2], lyr.ff2)):
            dst_lin.weight.copy_(src_lin.weight)
            dst_lin.bias.copy_(src_lin.bias)

        # Attention projections (q and out only).
        for dst_proj, src_proj in ((blk.ca.q_proj, lyr.attn.q_proj), (blk.ca.out, lyr.attn.o_proj)):
            dst_proj.weight.copy_(src_proj.weight)
            if dst_proj.bias is None:
                continue
            src_bias = getattr(src_proj, "bias", None)
            if src_bias is None:
                dst_proj.bias.zero_()
            else:
                dst_proj.bias.copy_(src_bias)

    # Final norm
    dec.ln.weight.copy_(backbone.norm.weight)
    dec.ln.bias.copy_(backbone.norm.bias)


@torch.no_grad()
def get_ctx_kv_from_engine(engine: InferenceEngine2) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Return per-layer context K and V slices read from the engine's caches.

    For every backbone layer, ``k_cache`` and ``v_cache`` are cut to the first
    ``engine.backbone.seq_len`` positions along dim 2 and made contiguous.
    Intended to be called after prefill; the result is documented by the
    script as [B, H, Nc, Dh] per layer.

    Returns:
        ``(keys, values)``: two lists with one tensor per backbone layer.
    """
    backbone = engine.backbone
    ctx_len = backbone.seq_len.item()
    keys = [layer.k_cache[:, :, :ctx_len, :].contiguous() for layer in backbone.layers]
    values = [layer.v_cache[:, :, :ctx_len, :].contiguous() for layer in backbone.layers]
    return keys, values


@torch.no_grad()
def decode_step_triton(
    engine: InferenceEngine2,
    dec: Decoder,
    Kc_list: List[torch.Tensor],
    Vc_list: List[torch.Tensor],
    h_in: torch.Tensor,
) -> torch.Tensor:
    """Push one query through the decoder blocks using engine-derived K/V.

    Per layer ``l``:
      1. normalise the running hidden state with the block's ``ln1``;
      2. project the normalised query with the *engine* layer's
         ``k_proj``/``v_proj``, split into KV heads, and expand to
         ``num_heads`` when the layer has fewer KV heads than query heads;
      3. call the block's ``ca`` with the normalised query, the supplied
         context ``Kc_list[l]``/``Vc_list[l]`` and the self K/V just built;
      4. add the attention output residually, then add ``ff(ln2(h))``.

    Returns:
        ``dec.ln`` applied to the final hidden state.
    """
    hidden = h_in
    for layer_idx, (block, eng_layer) in enumerate(zip(dec.blocks, engine.backbone.layers)):
        normed = block.ln1(hidden)

        attn = eng_layer.attn
        n_kv = attn.num_kv_heads
        n_q = attn.num_heads
        # (batch, query length, kv heads, head dim) before moving heads to dim 1.
        split_shape = (hidden.shape[0], normed.shape[1], n_kv, attn.head_dim)
        self_k, self_v = (
            proj(normed).view(*split_shape).transpose(1, 2)
            for proj in (attn.k_proj, attn.v_proj)
        )
        if n_kv != n_q:
            repeat = n_q // n_kv
            self_k = expand_kv_heads(self_k, repeat)
            self_v = expand_kv_heads(self_v, repeat)

        attn_out = block.ca(
            normed,
            Kc_list[layer_idx],
            Vc_list[layer_idx],
            self_k.contiguous(),
            self_v.contiguous(),
            attn_mask=None,
        )
        hidden = hidden + attn_out
        hidden = hidden + block.ff(block.ln2(hidden))
    return dec.ln(hidden)


@torch.no_grad()
def run_sampling_parity(cfg: BenchConfig, Nc: int, Nt: int) -> None:
    """Sample ``Nt`` targets step by step through both paths and print the gap.

    At every step the engine KV cache is prefilled with the current context,
    the target is embedded once, and the same embedding is decoded by the
    engine and by the Triton-backed decoder. Both outputs go through the same
    Gaussian head and are turned into samples with one shared noise draw.
    The engine's sample is then appended to the engine context.

    Args:
        cfg: Benchmark configuration; ``cfg.runtime.device`` and ``cfg.dims``
            are read.
        Nc: Number of initial context points.
        Nt: Number of target points to sample; must be at least 1.

    Raises:
        ValueError: if ``Nt`` is smaller than 1 (there would be nothing to
            compare and the max-reductions below would fail on empty tensors).
    """
    if Nt < 1:
        raise ValueError(f"Nt must be >= 1 to compare samples, got {Nt}")

    device = cfg.runtime.device
    dtype = torch.float32
    B = 1

    # Build components
    embedder, backbone, head, engine, dec = build_components(cfg)
    copy_engine_to_decoder(backbone, dec)
    patch_decoder_with_triton(dec, verbose=False, force_triton=True)

    # Data
    xc, yc, xt, _ = data_gen(Nc, Nt, B=B, dx=cfg.dims.dx, dy=cfg.dims.dy, device=device, dtype=dtype)

    # Engine context store and cache init
    engine.eval()
    dec.eval()
    engine.store_context_embeddings(DataAttr(xc=xc, yc=yc))
    engine.init_kv_cache(B=B, max_seq=Nc + Nt + 8, device=xc.device, dtype=dtype)

    # Outputs
    y_engine = torch.empty(B, Nt, cfg.dims.dy, device=device, dtype=dtype)
    y_triton = torch.empty_like(y_engine)

    # Step loop
    for t in range(Nt):
        # Prefill with the current context; the mask is rebuilt every step
        # because the context length is re-read after each update.
        L0 = engine.context_embeddings.shape[1]
        mask = create_context_self_attention_block_mask(
            current_num_context=L0,
            q_block_size=engine.q_block_size,
            kv_block_size=engine.kv_block_size,
            device=xc.device,
        )
        engine.prefill_kv_cache(mask)
        Kc_list, Vc_list = get_ctx_kv_from_engine(engine)

        # Target embedding (markers embedder)
        x_t = xt[:, t : t + 1, :]
        h_in = embedder.embed_target(DataAttr(xt=x_t))  # [B,1,D]

        # Engine decode (no head here)
        z_eng = engine.transformer_decode(h_in)  # [B,1,D]
        # Triton decode with engine-aligned K/V
        z_tri = decode_step_triton(engine, dec, Kc_list, Vc_list, h_in)

        # Same epsilon for both
        mu_e, sigma_e = head(z_eng)
        mu_t, sigma_t = head(z_tri)
        eps = torch.randn_like(mu_e)
        y_e = mu_e + sigma_e * eps
        y_ti = mu_t + sigma_t * eps
        # Index column t directly: writing a [B, dy] tensor into a [B, 1, dy]
        # slice only broadcasts correctly when B == 1.
        y_engine[:, t, :] = y_e.squeeze(1)
        y_triton[:, t, :] = y_ti.squeeze(1)

        # Grow context with predicted sample (engine embedder semantics)
        engine.update_context_embeddings(DataAttr(xc=x_t, yc=y_e))

    # Report deltas; the relative figure is normalised by the largest engine
    # sample magnitude (plus a small epsilon against division by zero).
    max_abs = (y_engine - y_triton).abs().max().item()
    denom = y_engine.abs().max().item() + 1e-6
    max_rel = max_abs / denom
    print("Sampling parity (markers + Gaussian head):")
    print(f"  Nc={Nc}, Nt={Nt}, B={B}, dtype={dtype}")
    print(f"  max|Δy|={max_abs:.6e}, maxRel={max_rel:.6e}")


def main() -> None:
    """Parse ``--Nc``/``--Nt`` from the command line and run the parity check.

    Unrecognised command-line arguments are ignored. The run uses CUDA when
    available and the CPU otherwise.
    """
    cli = argparse.ArgumentParser(description="Engine↔Triton sampling parity (markers + Gaussian head)")
    for flag, description in (
        ("--Nc", "Number of context points"),
        ("--Nt", "Number of target points"),
    ):
        cli.add_argument(flag, type=int, default=16, help=description)
    opts, _unknown = cli.parse_known_args()

    configure_env_for_parity()

    # cfg aliases the module-level DEFAULT config, so the device assignment
    # below mutates that shared object.
    cfg = DEFAULT
    cfg.runtime.device = "cuda" if torch.cuda.is_available() else "cpu"

    run_sampling_parity(cfg, Nc=opts.Nc, Nt=opts.Nt)


# Script entry point.
if __name__ == "__main__":
    main()

