from __future__ import annotations

"""
Step 2 Isolation: First-step (t=0) parity check for M3 between SDPA and Triton.

What it does
- Forces fp32 and disables Flash/MemEfficient SDPA to remove kernel/dtype variance.
- Builds modules from compiled_core with MLP embedders.
- Precomputes per-layer, per-head K/V for context once.
- Runs a single decode step (Lq=1) on SDPA and Triton decoders using identical K/V.
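- Prints max absolute/relative differences for:
  * Context K/V equality across decoders (should be ~0)
  * First-block cross-attention (CA) output (should be tiny)
  * Full-stack output after all decoder blocks (should be tiny)
  * Per-head attention micro-check comparing the SDPA module, the Triton kernel
    and a manual fp32 softmax reference against PyTorch SDPA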

Usage
  python scripts/fast_times/debug_m3_first_step.py --Nc 16 --Nt 16
"""

import argparse
import copy
import sys
from pathlib import Path

import torch


def _bootstrap_import_path() -> None:
    try:
        base = Path(__file__).resolve()
    except NameError:
        base = Path.cwd()
    root = None
    for cand in [base] + list(base.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            root = cand
            break
    root = root or base
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))


_bootstrap_import_path()

from scripts.fast_times.config import DEFAULT, BenchConfig
from scripts.fast_times.common import configure_torch_env, data_gen
from scripts.fast_times.compiled_core import (
    ContextEmbedder,
    TargetEmbedder,
    BufferEmbedder,
    ContextEncoder,
    Decoder,
    GaussianHead,
    patch_decoder_with_triton,
    build_decoder_ctx_kv,
)
import scripts.fast_times.compiled_core as cc


def build_modules(cfg: BenchConfig):
    """Instantiate the model components in float32 on the configured device.

    Returns a 6-tuple ``(emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec)``.
    Each module is moved to ``cfg.runtime.device`` and then cast to float32.
    Construction order is kept fixed. Parameter init presumably draws from the
    global RNG, so reordering would change the weights.
    """
    dims = cfg.dims
    target_device = cfg.runtime.device
    fp32 = torch.float32  # force fp32 for isolation

    def place(module):
        # Same two-step move as before: device first, then dtype.
        return module.to(target_device).to(fp32)

    ctx_in = dims.dx + dims.dy  # context/buffer tokens carry both x and y
    emb_ctx = place(ContextEmbedder(ctx_in, dims.d_model))
    emb_tgt = place(TargetEmbedder(dims.dx, dims.d_model))
    emb_buf = place(BufferEmbedder(ctx_in, dims.d_model))
    head = place(GaussianHead(dims.d_model, dims.dy))
    ctx_enc = place(ContextEncoder(dims.d_model, dims.n_heads, dims.n_layers_enc, dims.d_ff))
    dec = place(Decoder(dims.d_model, dims.n_heads, dims.n_layers_dec, dims.d_ff))
    return emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec


def _max_abs_rel_diff(ref: torch.Tensor, other: torch.Tensor, eps: float = 1e-6) -> tuple[float, float]:
    """Return ``(max|ref - other|, max|ref - other| / (max|ref| + eps))``.

    The relative error is normalised by the largest magnitude in *ref*, so pass
    the reference tensor first.
    """
    max_abs = (ref - other).abs().max().item()
    max_rel = max_abs / (ref.abs().max().item() + eps)
    return max_abs, max_rel


def _merge_heads(yh: torch.Tensor) -> torch.Tensor:
    """Merge per-head tensors ``[B, H, L, Dh]`` back into ``[B, L, H*Dh]``.

    NOTE(review): assumes ``cc.split_heads`` places heads as contiguous chunks
    of the model dim (view to ``[B, L, H, Dh]`` then transpose). For the Lq=1
    case used here, any contiguous-chunk layout merges identically.
    """
    B, H, L, Dh = yh.shape
    return yh.transpose(1, 2).reshape(B, L, H * Dh)


def _manual_attention_lq1(qh: torch.Tensor, kh: torch.Tensor, vh: torch.Tensor) -> torch.Tensor:
    """Numerically stable single-query softmax attention computed in fp32.

    qh is ``[B, H, 1, Dh]`` and kh/vh are ``[B, H, Nc, Dh]``. Returns a
    ``[B, H, 1, Dh]`` float32 tensor, using the same ``1/sqrt(Dh)`` scaling as
    SDPA's default.
    """
    qf, kf, vf = qh.float(), kh.float(), vh.float()
    dh = qf.shape[-1]
    scores = (kf * qf).sum(dim=-1) / (dh ** 0.5)           # [B,H,Nc] (q broadcasts over Nc)
    scores = scores - scores.amax(dim=2, keepdim=True)     # max-subtraction for stability
    w = scores.exp()
    out = (w.unsqueeze(-1) * vf).sum(dim=2) / w.sum(dim=2, keepdim=True)  # [B,H,Dh]
    return out.unsqueeze(2)


@torch.no_grad()
def run_first_step_debug(cfg: BenchConfig, *, Nc: int, Nt: int) -> None:
    """Compare SDPA and Triton decoders on a single (t=0, Lq=1) decode step.

    Builds fp32 modules and encodes a B=1 synthetic context once. The same
    per-layer context K/V is then fed to an SDPA decoder and to a Triton-patched
    deep copy of it. Prints max abs/rel differences for:

    - per-layer context K/V,
    - the full decoder stack,
    - the first block's cross-attention (CA),
    - a per-head kernel micro-check.

    Args:
        cfg: Benchmark config. ``cfg.dims`` and ``cfg.runtime.device`` are read.
        Nc: Number of context points.
        Nt: Number of target points generated (only the first is decoded).

    Raises:
        RuntimeError: If the two decoders yield different numbers of per-layer
            K/V tensors.
    """
    print("Step2: first-step (t=0) isolation with precomputed per-layer K/V (fp32)")
    configure_torch_env()
    # Pin the math SDPA backend. This is best-effort (older torch builds may lack
    # these toggles), but a failure is reported instead of silently ignored.
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
    except Exception as err:  # noqa: BLE001 - deliberate best-effort
        print(f"  note: could not pin SDPA backends ({err!r}); continuing")
    # configure_torch_env() may enable TF32 for benchmarking. TF32 matmuls would
    # undermine the fp32 isolation, so turn it off explicitly.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    device = cfg.runtime.device
    # Buffer embedder and head are unused, but still built so module
    # construction (and hence init) matches the rest of the tooling.
    emb_ctx, emb_tgt, _emb_buf, _head, ctx_enc, dec_sdpa = build_modules(cfg)
    dec_triton = copy.deepcopy(dec_sdpa)  # identical weights, different attention kernel
    patch_decoder_with_triton(dec_triton, verbose=False, force_triton=True)

    # Fixed synthetic data (B=1)
    xc, yc, xt, _ = data_gen(Nc, Nt, B=1, dx=cfg.dims.dx, dy=cfg.dims.dy, device=device, dtype=torch.float32)

    # Context encode once
    xc_feats = emb_ctx(torch.cat([xc, yc], dim=-1))
    E = ctx_enc.encode(xc_feats)

    # Per-layer, per-head context K/V from each decoder's CA weights
    Kc_sdpa, Vc_sdpa = build_decoder_ctx_kv(E, dec_sdpa)
    Kc_tr, Vc_tr = build_decoder_ctx_kv(E, dec_triton)
    if len(Kc_sdpa) != len(Kc_tr) or len(Vc_sdpa) != len(Vc_tr):
        raise RuntimeError(
            f"per-layer K/V count mismatch: SDPA K={len(Kc_sdpa)} V={len(Vc_sdpa)}, "
            f"Triton K={len(Kc_tr)} V={len(Vc_tr)}"
        )
    max_k = max(((a - b).abs().max().item() for a, b in zip(Kc_sdpa, Kc_tr)), default=0.0)
    max_v = max(((a - b).abs().max().item() for a, b in zip(Vc_sdpa, Vc_tr)), default=0.0)
    print(f"Context K/V per-layer equality: max|ΔK|={max_k:.3e}, max|ΔV|={max_v:.3e}")

    # Single-step query embedding (Lq=1)
    q0 = xt[:, :1, :]
    h_in = emb_tgt(q0)

    # Full-stack decode on both with the SAME (SDPA-built) K/V lists; no buffer, no mask
    h0_sdpa = dec_sdpa(h_in, Kc_sdpa, Vc_sdpa, None, None)
    h0_trit = dec_triton(h_in, Kc_sdpa, Vc_sdpa, None, None)
    max_abs_h, max_rel_h = _max_abs_rel_diff(h0_sdpa, h0_trit)
    print(f"Full-stack: max|Δh|={max_abs_h:.3e}, maxRel={max_rel_h:.3e}")

    # Isolate the first decoder block's CA only (no FFN)
    blk_s = dec_sdpa.blocks[0]
    blk_t = dec_triton.blocks[0]
    q1 = blk_s.ln1(h_in)  # deepcopy => same LN weights in both decoders
    K0, V0 = Kc_sdpa[0], Vc_sdpa[0]
    y_ca_s = blk_s.ca(q1, K0, V0, None, None)
    y_ca_t = blk_t.ca(q1, K0, V0, None, None)
    max_abs_ca, max_rel_ca = _max_abs_rel_diff(y_ca_s, y_ca_t)
    print(f"First-block CA: max|Δy|={max_abs_ca:.3e}, maxRel={max_rel_ca:.3e}")

    # ---- Kernel vs SDPA micro-check on per-head tensors ----
    H = dec_sdpa.n_heads
    qh = cc.split_heads(blk_s.ca.q_proj(q1), H).contiguous()
    # Expand context across B if needed (here B==1)
    kh = K0.expand(qh.shape[0], -1, -1, -1).contiguous()
    vh = V0.expand(qh.shape[0], -1, -1, -1).contiguous()

    yh_ref = torch.nn.functional.scaled_dot_product_attention(
        qh, kh, vh, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    yh_tri = cc._triton_ca_lq1_forward(qh, kh, vh)
    yh_man = _manual_attention_lq1(qh, kh, vh).to(yh_ref.dtype)

    # The CA module output is taken after out_proj, so map the per-head SDPA
    # reference into the same space before comparing. Comparing raw per-head
    # attention against post-projection output is meaningless. If the module has
    # no out_proj, the merged reference is compared directly.
    y_ref_merged = _merge_heads(yh_ref)
    out_proj = getattr(blk_s.ca, "out_proj", None)
    y_ref_module_space = out_proj(y_ref_merged) if out_proj is not None else y_ref_merged

    d_ref_mod = _max_abs_rel_diff(y_ref_module_space, y_ca_s)
    d_ref_tri = _max_abs_rel_diff(yh_ref, yh_tri)
    d_man_ref = _max_abs_rel_diff(yh_man, yh_ref)
    d_man_tri = _max_abs_rel_diff(yh_man, yh_tri)
    print("Kernel micro-check (per-head):")
    print(f"  SDPA module vs out_proj(SDPA):  max|Δ|={d_ref_mod[0]:.3e}, maxRel={d_ref_mod[1]:.3e}")
    print(f"  Triton kernel vs SDPA per-head: max|Δ|={d_ref_tri[0]:.3e}, maxRel={d_ref_tri[1]:.3e}")
    print(f"  Manual vs SDPA per-head:        max|Δ|={d_man_ref[0]:.3e}, maxRel={d_man_ref[1]:.3e}")
    print(f"  Manual vs Triton per-head:      max|Δ|={d_man_tri[0]:.3e}, maxRel={d_man_tri[1]:.3e}")


def main(argv: list[str] | None = None) -> None:
    """Parse CLI options and run the first-step parity check on the ``DEFAULT`` config.

    Args:
        argv: Optional argument list. ``None`` (the default) reads ``sys.argv``
            as before. Passing a list lets the check be driven programmatically.

    Exits with status 2 (via ``argparse``) if ``--Nc`` or ``--Nt`` is < 1.
    """
    parser = argparse.ArgumentParser(description="M3 first-step parity (SDPA vs Triton)")
    parser.add_argument("--Nc", type=int, default=16, help="Number of context points")
    parser.add_argument("--Nt", type=int, default=16, help="Number of target points (only first used)")
    # Use parse_known_args to play nice with Jupyter, which injects a "-f <kernel.json>" arg
    args, _ = parser.parse_known_args(argv)
    # The check needs at least one context point (the softmax runs over Nc) and
    # one target point (the t=0 query is xt[:, :1]).
    if args.Nc < 1:
        parser.error("--Nc must be >= 1")
    if args.Nt < 1:
        parser.error("--Nt must be >= 1")
    run_first_step_debug(DEFAULT, Nc=args.Nc, Nt=args.Nt)


if __name__ == "__main__":
    main()
