from __future__ import annotations

"""
Triton fast-path LL times for our AR method (M3 only).
Mandatory Triton: raises if Triton is unavailable.
"""

import os
import sys
import time
from pathlib import Path

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add the repo root to
# sys.path so the absolute ``scripts.*`` imports below can resolve.
if __package__ is None or __package__ == "":
    try:
        base = Path(__file__).resolve()
    except NameError:
        # No __file__ (e.g. a notebook cell): search upward from the working directory.
        base = Path.cwd()
    # Always search from a directory. When ``base`` is this script file, start
    # at its parent so the fallback below never puts a *file* on sys.path.
    _start = base if base.is_dir() else base.parent
    ROOT = None
    # The first ancestor that contains a repo marker is taken as the root.
    for cand in [_start, *_start.parents]:
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    # No marker found: fall back to the starting directory.
    ROOT = ROOT or _start
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    check_gpu_memory,
    clear_gpu_memory,
    data_gen,
    package_for_plot,
)
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead,
    CompiledFiveMethodLLAdapter, patch_decoder_with_triton, TRITON_AVAILABLE,
)


def build_modules(cfg: BenchConfig):
    """Instantiate the model components used by the benchmark.

    Reads the target device from ``cfg.runtime.device`` and the layer sizes
    from ``cfg.dims``. Every module is moved to that device and cast to
    float16 when the device is a CUDA device and CUDA is available, otherwise
    to float32. The decoder is then passed to ``patch_decoder_with_triton``
    with ``force_triton=True``.

    Args:
        cfg: Benchmark configuration providing ``runtime.device`` and ``dims``.

    Returns:
        The tuple ``(emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec)``.
    """
    device = cfg.runtime.device
    # Compare on the device *type* so that "cuda:0" and torch.device objects
    # are recognised as CUDA too, not only the exact string "cuda".
    on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32
    D = cfg.dims

    def _place(module):
        # Single move + cast; equivalent to .to(device).to(dtype).
        return module.to(device=device, dtype=dtype)

    emb_ctx = _place(ContextEmbedder(D.dx + D.dy, D.d_model))
    emb_tgt = _place(TargetEmbedder(D.dx, D.d_model))
    emb_buf = _place(BufferEmbedder(D.dx + D.dy, D.d_model))
    head = _place(GaussianHead(D.d_model, D.dy))
    ctx_enc = _place(ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff))
    dec = _place(Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff))
    # NOTE(review): presumably swaps the decoder's attention for Triton kernels
    # in place; confirm against compiled_core.
    patch_decoder_with_triton(dec, verbose=False, force_triton=True)
    return emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec


def benchmark_ll(adapter: CompiledFiveMethodLLAdapter, Nc: int, Nt: int, *, num_perms: int, num_runs: int,
                 dx: int, dy: int) -> tuple[np.ndarray, float, float]:
    """Time ``adapter.eval_log_joint_likelihood`` on one generated batch.

    Data is generated once with ``data_gen`` (``B=1``) on the device and dtype
    of the adapter's first parameter. Two untimed warm-up calls are made, then
    ``num_runs`` timed calls. Everything runs under ``torch.no_grad()``.

    Args:
        adapter: Model adapter exposing ``parameters()`` and
            ``eval_log_joint_likelihood``.
        Nc: Passed to ``data_gen`` as the first size argument.
        Nt: Passed to ``data_gen`` as the second size argument.
        num_perms: Forwarded to ``eval_log_joint_likelihood``.
        num_runs: Number of timed repetitions.
        dx: Forwarded to ``data_gen``.
        dy: Forwarded to ``data_gen``.

    Returns:
        ``(timings, mean, std)`` where ``timings`` is an array of per-run
        durations in seconds and ``mean``/``std`` are its NumPy mean and
        standard deviation.
    """
    first_param = next(adapter.parameters())
    device, dtype = first_param.device, first_param.dtype
    xc, yc, xt, yt = data_gen(Nc, Nt, B=1, dx=dx, dy=dy, device=device, dtype=dtype)

    # Pick the timer from where the model actually lives: CUDA events only
    # make sense when the work is launched on a GPU.
    use_cuda_events = device.type == "cuda"

    timings: list[float] = []
    with torch.no_grad():
        # Warm up under the same no-grad conditions as the timed runs, so the
        # warm-up does not build an autograd graph the measurement never uses.
        for _ in range(2):
            adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)

        for i in range(num_runs):
            if use_cuda_events:
                # Periodically release cached blocks; done outside the timed region.
                if i % 5 == 0:
                    torch.cuda.empty_cache()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)
                end_event.record()
                # Wait for the GPU so both events have completed before reading them.
                torch.cuda.synchronize()
                # elapsed_time() reports milliseconds; store seconds.
                timings.append(start_event.elapsed_time(end_event) / 1000.0)
            else:
                t0 = time.perf_counter()
                adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)
                timings.append(time.perf_counter() - t0)

    t = np.array(timings)
    return t, float(t.mean()), float(t.std())


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Run the Triton LL benchmark over the configured grid and save the results.

    For every ``(Nc, num_samples)`` pair in ``cfg.grid`` the M3 adapter is
    timed with ``benchmark_ll``. Pairs that fail with a CUDA out-of-memory or
    other ``RuntimeError`` are recorded with NaN timings. The collected
    results are packaged with ``package_for_plot`` and written to
    ``triton_ll.json`` inside ``cfg.out_dir``.

    Args:
        cfg: Benchmark configuration; defaults to ``DEFAULT``.

    Raises:
        RuntimeError: If ``TRITON_AVAILABLE`` is false.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is required for run_triton_ll; please install Triton on a CUDA system.")
    configure_torch_env()

    emb_ctx, emb_tgt, emb_buf, head, ctx_enc, dec = build_modules(cfg)
    m3 = CompiledFiveMethodLLAdapter("m3", ctx_enc, dec, emb_ctx, emb_tgt, emb_buf, head_diag=head, mv_head=None)

    grid = cfg.grid
    context_sizes = tuple(grid.Nc_values)
    perm_counts = tuple(grid.num_samples_values)
    Nt = grid.Nt
    runs = grid.num_runs
    dims = cfg.dims
    dx, dy = dims.dx, dims.dy

    results = {"Nc": [], "num_samples": [], "mean_time": [], "std_time": []}

    def _record(nc, n_perms, mean_t, std_t):
        # Append one grid point to the column-oriented results table.
        results["Nc"].append(nc)
        results["num_samples"].append(n_perms)
        results["mean_time"].append(mean_t)
        results["std_time"].append(std_t)

    print("\n=== TRITON LL (M3 only) ===")
    for Nc in context_sizes:
        for n_perms in perm_counts:
            # Free cached memory first when check_gpu_memory(2.0) reports false.
            if not check_gpu_memory(2.0):
                clear_gpu_memory()
            try:
                _, mean_t, std_t = benchmark_ll(m3, Nc=Nc, Nt=Nt, num_perms=n_perms, num_runs=runs, dx=dx, dy=dy)
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # Keep the grid complete: failed points are stored as NaN.
                print(f"Skip Nc={Nc} B={n_perms} due to error: {str(e)[:100]}")
                nan = float('nan')
                _record(Nc, n_perms, nan, nan)
                clear_gpu_memory()
            else:
                _record(Nc, n_perms, mean_t, std_t)
                print(f"Nc={Nc} B={n_perms}: {mean_t:.4f} ± {std_t:.4f} s")

    meta = {"script": "triton_ll", "config": cfg.to_dict()}
    out = package_for_plot({"TR LL M3 Ours": results}, meta=meta)

    out_dir = cfg.out_dir
    # NOTE(review): _ensure_dir is handed a placeholder child path; presumably it
    # creates the parent directory of the path it receives. Confirm in common.py.
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, "triton_ll.json")
    _atomic_write_json(out_path, out)
    print(f"Saved: {out_path}")


# Script entry point: run the benchmark with the default configuration (DEFAULT).
if __name__ == "__main__":
    main()
