"""Triton fast-path log-likelihood (LL) timings for our AR method (M3 only).

If Triton is not available, SDPA is used instead with identical call semantics.
"""

# The docstring must be the first statement to become the module's __doc__;
# a `from __future__` import is allowed to follow it.
from __future__ import annotations

import os
import sys
import time
from pathlib import Path
from typing import Dict

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add repo root to sys.path
if not __package__:
    try:
        base = Path(__file__).resolve()
    except NameError:
        # __file__ is undefined when the code is pasted into a notebook/REPL;
        # start the search from the current working directory instead.
        base = Path.cwd()
    # Walk upward from `base` and take the first path that carries a repo marker.
    ROOT = None
    for cand in [base] + list(base.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    if ROOT is None:
        # No marker found: fall back to a directory. `base` is the script file
        # itself when it came from __file__, and a file path on sys.path is useless.
        ROOT = base if base.is_dir() else base.parent
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    check_gpu_memory,
    clear_gpu_memory,
    data_gen,
    package_for_plot,
)
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead,
    CompiledFiveMethodLLAdapter, patch_decoder_with_triton,
)


def build_modules(cfg: BenchConfig):
    """Instantiate the embedders, encoder, decoder and head for the benchmark.

    Every module is moved to ``cfg.runtime.device``. float16 is used when that
    device is a CUDA device and CUDA is available; otherwise float32 is used.
    The decoder is then passed to ``patch_decoder_with_triton`` with
    ``force_triton=False``.

    Args:
        cfg: Benchmark configuration; ``cfg.runtime.device`` and ``cfg.dims``
            (``dx``, ``dy``, ``d_model``, ``n_heads``, ``n_layers_enc``,
            ``n_layers_dec``, ``d_ff``) are read.

    Returns:
        The tuple ``(emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec)``.
    """
    device = cfg.runtime.device
    # Compare the device *type* rather than the raw value, so that an indexed
    # device string ("cuda:0") or a torch.device object also selects float16.
    on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32
    D = cfg.dims

    def _place(module):
        # Move a freshly built module to the target device, then cast it.
        return module.to(device).to(dtype)

    emb_ctx = _place(ContextEmbedder(D.dx + D.dy, D.d_model))
    emb_tgt = _place(TargetEmbedder(D.dx, D.d_model))
    emb_buf = _place(BufferEmbedder(D.dx + D.dy, D.d_model))
    head = _place(GaussianHead(D.d_model, D.dy))
    ctx_enc = _place(ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff))
    dec = _place(Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff))
    # NOTE(review): presumably swaps in Triton kernels when available and keeps
    # the SDPA path otherwise (force_triton=False) -- confirm in compiled_core.
    patch_decoder_with_triton(dec, verbose=False, force_triton=False)
    return emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec


def benchmark_ll(adapter: CompiledFiveMethodLLAdapter, Nc: int, Nt: int, *, num_perms: int, num_runs: int,
                 dx: int, dy: int) -> tuple[np.ndarray, float, float]:
    """Time ``adapter.eval_log_joint_likelihood`` on one generated problem.

    Inputs are produced by ``data_gen(Nc, Nt, B=1, ...)`` on the device and
    dtype of the adapter's first parameter. Two untimed warm-up calls are made,
    followed by ``num_runs`` timed calls, all under ``torch.no_grad()``.

    Args:
        adapter: Module exposing ``parameters()`` and ``eval_log_joint_likelihood``.
        Nc: Context-set size passed to ``data_gen``.
        Nt: Target-set size passed to ``data_gen``.
        num_perms: Forwarded as ``num_perms`` to ``eval_log_joint_likelihood``.
        num_runs: Number of timed repetitions.
        dx: Input dimensionality passed to ``data_gen``.
        dy: Output dimensionality passed to ``data_gen``.

    Returns:
        ``(timings, mean, std)`` where ``timings`` is an array of per-run
        durations in seconds and ``mean``/``std`` are its statistics
        (``nan`` when ``num_runs`` is 0).
    """
    first_param = next(adapter.parameters())
    device, dtype = first_param.device, first_param.dtype
    xc, yc, xt, yt = data_gen(Nc, Nt, B=1, dx=dx, dy=dy, device=device, dtype=dtype)
    # Choose the timer from where the adapter actually lives: CUDA events only
    # measure GPU work, so they are meaningless for a CPU adapter even when a
    # GPU happens to be present on the machine.
    use_cuda_events = device.type == "cuda"

    def _run_once() -> None:
        with torch.no_grad():
            adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)

    # warm-up, under the same no_grad conditions as the timed runs so that it
    # exercises the same code path and does not retain autograd state
    for _ in range(2):
        _run_once()

    timings: list[float] = []
    for i in range(num_runs):
        if use_cuda_events:
            if i % 5 == 0:
                # Periodically release cached blocks; done outside the timed region.
                torch.cuda.empty_cache()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            _run_once()
            end_event.record()
            # Both events must have completed before elapsed_time can be read.
            torch.cuda.synchronize()
            # elapsed_time reports milliseconds; store seconds.
            timings.append(start_event.elapsed_time(end_event) / 1000.0)
        else:
            t0 = time.perf_counter()
            _run_once()
            timings.append(time.perf_counter() - t0)
    t = np.array(timings)
    return t, float(t.mean()), float(t.std())


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Run the M3 log-likelihood timing sweep and write the results to JSON.

    For every combination of ``cfg.grid.Nc_values`` and
    ``cfg.grid.num_samples_values`` the adapter is timed with ``benchmark_ll``.
    Combinations that raise an out-of-memory or other ``RuntimeError`` are
    recorded with NaN timings. The packaged results are written to
    ``<cfg.out_dir>/triton_ll_ours.json``.

    Args:
        cfg: Benchmark configuration; defaults to ``DEFAULT``.
    """
    configure_torch_env()
    emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec = build_modules(cfg)
    m3 = CompiledFiveMethodLLAdapter("m3", ctx_enc, dec, emb_ctx, emb_tgt, emb_buf, head_diag=head, mv_head=None)

    context_sizes = tuple(cfg.grid.Nc_values)
    perm_counts = tuple(cfg.grid.num_samples_values)
    Nt = cfg.grid.Nt
    runs = cfg.grid.num_runs
    dx = cfg.dims.dx
    dy = cfg.dims.dy

    results = {key: [] for key in ("Nc", "num_samples", "mean_time", "std_time")}

    def _record(nc, n_samples, mean_value, std_value) -> None:
        # Append one row to the column-oriented results table.
        results["Nc"].append(nc)
        results["num_samples"].append(n_samples)
        results["mean_time"].append(mean_value)
        results["std_time"].append(std_value)

    print("\n=== TRITON LL (M3 only) ===")
    for Nc in context_sizes:
        for B in perm_counts:
            # Free cached GPU memory first when the memory check fails.
            if not check_gpu_memory(2.0):
                clear_gpu_memory()
            try:
                _, mean_t, std_t = benchmark_ll(m3, Nc=Nc, Nt=Nt, num_perms=B, num_runs=runs, dx=dx, dy=dy)
                _record(Nc, B, mean_t, std_t)
                print(f"Nc={Nc} B={B}: {mean_t:.4f} ± {std_t:.4f} s")
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                print(f"Skip Nc={Nc} B={B} due to error: {str(e)[:100]}")
                # Keep the grid complete: failed points are stored as NaN.
                _record(Nc, B, float('nan'), float('nan'))
                clear_gpu_memory()

    meta = {"script": "triton_ll_ours", "config": cfg.to_dict()}
    out = package_for_plot({"TR LL M3 Ours": results}, meta=meta)

    out_dir = cfg.out_dir
    # NOTE(review): the "_" placeholder suggests _ensure_dir creates the parent
    # directory of the path it is given -- confirm in scripts.fast_times.common.
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, "triton_ll_ours.json")
    _atomic_write_json(out_path, out)
    print(f"Saved: {out_path}")


# Script entry point: run the benchmark with the default configuration (DEFAULT).
if __name__ == "__main__":
    main()
