from __future__ import annotations

"""
Numerical parity check for M3 (AR buffer) between the SDPA and Triton CA paths,
using the compiled_core model with MLP embedders and per-layer K/V precompute.

Two checks:
- Deterministic AR (mu fed back as y): runs unified_ar_driver through both
  decoders and compares the resulting mu sequences within a dtype-dependent
  tolerance.
- Stochastic AR sampling (num_samples=1): reseeds the RNG before each path so
  both draw the same noise, then compares the sampled y within a tolerance.
"""

import os
import sys
import copy
from pathlib import Path

import torch

# When executed as a plain script (no parent package), walk upward from this
# file to the first directory that looks like the repository root, and make it
# importable so the absolute `scripts.` imports resolve.
if not __package__:
    try:
        base = Path(__file__).resolve()
    except NameError:
        # __file__ is undefined in interactive / exec() contexts.
        base = Path.cwd()
    for cand in (base, *base.parents):
        if any((cand / marker).exists() for marker in ("scripts", "pyproject.toml", ".git")):
            ROOT = cand
            break
    else:
        # No marker found anywhere up the tree: fall back to the start point.
        ROOT = base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import DEFAULT, BenchConfig
from scripts.fast_times.common import configure_torch_env, data_gen
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead,
    CompiledFiveMethodAdapter, ARBufferProvider,
    unified_ar_driver, patch_decoder_with_triton,
)


def _select_dtype(device) -> torch.dtype:
    """Return the compute dtype for ``device``: fp16 on an available CUDA device, else fp32.

    ``device`` may be a string ("cpu", "cuda", "cuda:1", ...) or a ``torch.device``;
    comparing the device *type* means indexed CUDA devices also get fp16.
    """
    is_cuda = torch.device(device).type == "cuda"
    return torch.float16 if (is_cuda and torch.cuda.is_available()) else torch.float32


def build_modules(cfg: BenchConfig):
    """Build the M3 module stack on ``cfg.runtime.device``.

    Every module is moved to the runtime device, cast to the dtype chosen by
    ``_select_dtype``, and switched to eval mode. Eval mode prevents any
    train-mode-only behaviour (e.g. dropout, if a layer has it) from making the
    SDPA and Triton paths diverge in a parity check.

    Returns:
        Tuple ``(emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec)``.
    """
    device = cfg.runtime.device
    dtype = _select_dtype(device)
    D = cfg.dims

    def _place(module):
        # Move + cast in one call; nn.Module.eval() returns the module itself.
        return module.to(device=device, dtype=dtype).eval()

    emb_ctx = _place(ContextEmbedder(D.dx + D.dy, D.d_model))
    emb_tgt = _place(TargetEmbedder(D.dx, D.d_model))
    emb_buf = _place(BufferEmbedder(D.dx + D.dy, D.d_model))
    head = _place(GaussianHead(D.d_model, D.dy))
    ctx_enc = _place(ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff))
    dec = _place(Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff))
    return emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec


@torch.no_grad()
def check_deterministic_mu(cfg: BenchConfig, Nc: int = 16, Nt: int = 16) -> None:
    """Compare deterministic AR rollouts (mu fed back as y) between SDPA and Triton.

    One set of modules is built. The decoder is deep-copied and only the copy is
    patched with the Triton kernels, so both paths share identical weights. Both
    providers are driven by ``unified_ar_driver`` on the same synthetic data. The
    max absolute / relative differences of the returned mu tensors are printed
    with a PASS / WARN / FAIL verdict.

    Args:
        cfg: Benchmark configuration (runtime device and model dimensions).
        Nc: Number of context points.
        Nt: Number of target points.
    """
    print("Deterministic AR (mu-as-y) parity check...")
    device = cfg.runtime.device
    emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec_sdpa = build_modules(cfg)
    # Same weights, different attention backend: only the copy is patched.
    dec_triton = copy.deepcopy(dec_sdpa)
    patch_decoder_with_triton(dec_triton, verbose=False, force_triton=True)

    # Fixed synthetic data (B=1) in the modules' parameter dtype.
    xc, yc, xt, _ = data_gen(Nc, Nt, B=1, dx=cfg.dims.dx, dy=cfg.dims.dy, device=device,
                             dtype=next(ctx_enc.parameters()).dtype)

    prov_sdpa = ARBufferProvider(ctx_enc, emb_ctx, emb_buf, dec_sdpa)
    prov_trit = ARBufferProvider(ctx_enc, emb_ctx, emb_buf, dec_triton)

    mu_sdpa, _ = unified_ar_driver(prov_sdpa, dec_sdpa, emb_tgt, head, emb_ctx, xc, yc, xt, dy=cfg.dims.dy)
    mu_trit, _ = unified_ar_driver(prov_trit, dec_triton, emb_tgt, head, emb_ctx, xc, yc, xt, dy=cfg.dims.dy)

    # Subtracting tensors of different shapes would broadcast silently and yield
    # a meaningless metric, so a shape mismatch is reported as a failure.
    if mu_sdpa.shape != mu_trit.shape:
        print(f"FAIL: shape mismatch sdpa={tuple(mu_sdpa.shape)} triton={tuple(mu_trit.shape)}")
        return

    # Measure in fp32 so fp16 outputs don't round the reported difference;
    # the difference is computed once and reused for both metrics.
    max_abs = (mu_sdpa.float() - mu_trit.float()).abs().max().item()
    max_rel = max_abs / (mu_sdpa.float().abs().max().item() + 1e-6)
    print(f"max_abs_diff: {max_abs:.6e}, max_rel_diff: {max_rel:.6e}")
    # Tolerance follows the outputs' own dtype. A NaN metric makes both
    # comparisons False, so NaNs end up in the WARN branch.
    tol = 1e-4 if mu_sdpa.dtype == torch.float32 else 5e-3
    if max_abs <= tol or max_rel <= 5e-3:
        print("PASS: Deterministic mu parity within tolerance.")
    else:
        print("WARN: Differences exceed tolerance. Kernel or dtype drift present.")


@torch.no_grad()
def check_sampling(cfg: BenchConfig, Nc: int = 16, Nt: int = 16) -> None:
    """Compare one AR joint-predictive sample between the SDPA and Triton M3 adapters.

    Both adapters share the same module instances, except that the Triton
    decoder is a patched deep copy of the SDPA one. The global RNG is reseeded
    to the same value before each call, so both paths draw the same random
    numbers provided they consume the RNG identically. The max absolute /
    relative differences of the samples are printed with a PASS / WARN / FAIL
    verdict.

    Args:
        cfg: Benchmark configuration (runtime device and model dimensions).
        Nc: Number of context points.
        Nt: Number of target points.
    """
    print("Stochastic AR sampling parity check (num_samples=1)...")
    device = cfg.runtime.device
    emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec_sdpa = build_modules(cfg)
    # Same weights, different attention backend: only the copy is patched.
    dec_triton = copy.deepcopy(dec_sdpa)
    patch_decoder_with_triton(dec_triton, verbose=False, force_triton=True)

    m3_sdpa = CompiledFiveMethodAdapter("m3", ctx_enc, dec_sdpa, emb_ctx, emb_tgt, emb_buf, head)
    m3_trit = CompiledFiveMethodAdapter("m3", ctx_enc, dec_triton, emb_ctx, emb_tgt, emb_buf, head)

    # Fixed synthetic data (B=1) in the modules' parameter dtype.
    dtype = next(ctx_enc.parameters()).dtype
    xc, yc, xt, _ = data_gen(Nc, Nt, B=1, dx=cfg.dims.dx, dy=cfg.dims.dy, device=device, dtype=dtype)

    torch.manual_seed(1234)
    y_sdpa = m3_sdpa.sample_joint_predictive(xc, yc, xt, num_samples=1)
    # Reset RNG so both draw identical eps
    torch.manual_seed(1234)
    y_trit = m3_trit.sample_joint_predictive(xc, yc, xt, num_samples=1)

    # Subtracting tensors of different shapes would broadcast silently and yield
    # a meaningless metric, so a shape mismatch is reported as a failure.
    if y_sdpa.shape != y_trit.shape:
        print(f"FAIL: shape mismatch sdpa={tuple(y_sdpa.shape)} triton={tuple(y_trit.shape)}")
        return

    # Measure in fp32 so fp16 samples don't round the reported difference;
    # the difference is computed once and reused for both metrics.
    max_abs = (y_sdpa.float() - y_trit.float()).abs().max().item()
    max_rel = max_abs / (y_sdpa.float().abs().max().item() + 1e-6)
    print(f"max_abs_diff: {max_abs:.6e}, max_rel_diff: {max_rel:.6e}")
    # Tolerance follows the samples' own dtype. A NaN metric makes both
    # comparisons False, so NaNs end up in the WARN branch.
    tol = 2e-4 if y_sdpa.dtype == torch.float32 else 1e-2
    if max_abs <= tol or max_rel <= 1e-2:
        print("PASS: Sampled outputs within tolerance.")
    else:
        print("WARN: Differences exceed tolerance. Consider fp32 runs for tighter checks.")


def main(cfg: BenchConfig = DEFAULT):
    """Configure the torch environment, then run both SDPA-vs-Triton parity checks.

    Both checks use 16 context and 16 target points. The deterministic mu check
    runs first, then the sampling check.
    """
    configure_torch_env()
    problem_size = {"Nc": 16, "Nt": 16}
    for run_check in (check_deterministic_mu, check_sampling):
        run_check(cfg, **problem_size)


if __name__ == "__main__":
    main()
