"""
Compiled log-likelihood evaluation times for five methods (M1..M5), matching the
"Log Likelihood Evaluations Notebook" flow: torch.compile, dynamic shapes,
and per-(Nc,T) compiled runners that accept num_perms=B for AR methods.
"""

from __future__ import annotations

import os
import sys
import time
from pathlib import Path
from typing import Dict

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add repo root to sys.path.
# Only done when this module is executed outside of a package.
if __package__ is None or __package__ == "":
    try:
        base = Path(__file__).resolve()
    except NameError:
        # __file__ is undefined when the code is pasted into a notebook/REPL.
        base = Path.cwd()
    # When base is this script's own path, search from its directory: a file
    # path can never contain the marker entries and is useless on sys.path.
    search_start = base if base.is_dir() else base.parent
    ROOT = None
    # Walk upwards until a directory that looks like the repository root is found.
    for cand in [search_start] + list(search_start.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    # Fall back to the starting directory when no marker was found.
    ROOT = ROOT or search_start
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    check_gpu_memory,
    clear_gpu_memory,
    data_gen,
    package_for_plot,
)
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead, MvGaussianHead,
    CompiledFiveMethodLLAdapter,
)


def build_modules(cfg: BenchConfig):
    """Instantiate the network components shared by all log-likelihood adapters.

    Every module is moved to ``cfg.runtime.device``. Parameters are cast to
    float16 when that device is a CUDA device and CUDA is available, and to
    float32 otherwise.

    The device is normalised through ``torch.device`` so that indexed strings
    such as ``"cuda:0"`` and ``torch.device`` objects are recognised as CUDA,
    not only the bare string ``"cuda"``.

    Args:
        cfg: Benchmark configuration; ``cfg.runtime.device`` and the sizes in
            ``cfg.dims`` are read.

    Returns:
        The tuple ``(emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec)``.
    """
    device = cfg.runtime.device
    on_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32
    D = cfg.dims
    emb_ctx = ContextEmbedder(D.dx + D.dy, D.d_model).to(device).to(dtype)
    emb_tgt = TargetEmbedder(D.dx, D.d_model).to(device).to(dtype)
    emb_buf = BufferEmbedder(D.dx + D.dy, D.d_model).to(device).to(dtype)
    head = GaussianHead(D.d_model, D.dy).to(device).to(dtype)
    mv_head = MvGaussianHead(d_model=D.d_model, dy=D.dy, n_heads=D.n_heads, d_ff=D.d_ff,
                             n_std_layers=2, prj_dim=8, bound_diag=True, min_diag=0.05).to(device).to(dtype)
    ctx_enc = ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff).to(device).to(dtype)
    dec = Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff).to(device).to(dtype)
    return emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec


def benchmark_ll_compiled(adapter: CompiledFiveMethodLLAdapter, Nc: int, Nt: int, *, num_perms: int, num_runs: int,
                          dx: int, dy: int) -> tuple[np.ndarray, float, float]:
    """Time repeated ``eval_log_joint_likelihood`` calls for one problem size.

    Inputs are generated once with ``data_gen`` (``B=1``) on the device and
    dtype of the adapter's first parameter. Two untimed warm-up calls are made,
    then ``num_runs`` timed calls. The first timed measurement is discarded
    whenever more than one was taken.

    CUDA events are used for timing only when the adapter itself lives on a
    CUDA device; otherwise wall-clock time is used. Deciding on the adapter's
    device (rather than on global CUDA availability) keeps CPU benchmarks
    correct on machines that also have a GPU.

    Args:
        adapter: Adapter to benchmark; must expose ``parameters()`` and
            ``eval_log_joint_likelihood``.
        Nc: Context-size argument forwarded to ``data_gen``.
        Nt: Target-size argument forwarded to ``data_gen``.
        num_perms: Forwarded to ``eval_log_joint_likelihood``.
        num_runs: Number of timed evaluations.
        dx: Input dimensionality forwarded to ``data_gen``.
        dy: Output dimensionality forwarded to ``data_gen``.

    Returns:
        ``(timings, mean, std)`` in seconds, computed over the retained
        measurements. With ``num_runs == 0`` the statistics are NaN.
    """
    first_param = next(adapter.parameters())
    device = first_param.device
    dtype = first_param.dtype
    on_cuda = device.type == "cuda"
    xc, yc, xt, yt = data_gen(Nc, Nt, B=1, dx=dx, dy=dy, device=device, dtype=dtype)
    # Warm-up twice under no_grad to avoid building autograd graphs and allow cudagraph fast paths
    with torch.no_grad():
        for _ in range(2):
            adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)
    if on_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)

    timings: list[float] = []
    for _ in range(num_runs):
        if on_cuda:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            with torch.no_grad():
                adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)
            end_event.record()
            # Wait for the queued work to finish before reading the event timestamps.
            torch.cuda.synchronize(device)
            # elapsed_time reports milliseconds; convert to seconds.
            timings.append(start_event.elapsed_time(end_event) / 1000.0)
        else:
            t0 = time.perf_counter()
            with torch.no_grad():
                adapter.eval_log_joint_likelihood(xc, yc, xt, yt, num_perms=num_perms)
            timings.append(time.perf_counter() - t0)
    # Drop first measurement to avoid counting any lingering setup
    t = np.array(timings[1:] if len(timings) > 1 else timings)
    return t, float(t.mean()), float(t.std())


def run_ll_grid_compiled(adapter: CompiledFiveMethodLLAdapter, *, Nt: int, num_runs: int,
                         Nc_values, num_samples_values, dx: int, dy: int) -> Dict[str, list]:
    """Benchmark ``adapter`` over the grid ``Nc_values`` x ``num_samples_values``.

    Each grid point is timed with ``benchmark_ll_compiled`` (``num_perms`` is
    set to the current entry of ``num_samples_values``). A grid point that
    fails with a CUDA out-of-memory error or any other ``RuntimeError`` is
    reported, recorded with NaN timings, and the sweep continues. GPU memory
    is cleared once after every grid point, whether it succeeded or not.

    Args:
        adapter: Adapter to benchmark.
        Nt: Target-size argument forwarded to ``benchmark_ll_compiled``.
        num_runs: Number of timed runs per grid point.
        Nc_values: Iterable of context sizes (outer loop).
        num_samples_values: Iterable of ``num_perms`` values (inner loop).
        dx: Input dimensionality forwarded to the benchmark.
        dy: Output dimensionality forwarded to the benchmark.

    Returns:
        Dict with parallel lists under ``"Nc"``, ``"num_samples"``,
        ``"mean_time"`` and ``"std_time"``, one entry per grid point.
    """
    results = {"Nc": [], "num_samples": [], "mean_time": [], "std_time": []}
    for Nc in Nc_values:
        for B in num_samples_values:
            # NOTE(review): presumably checks for ~2 GB of free GPU memory —
            # confirm against common.check_gpu_memory.
            if not check_gpu_memory(2.0):
                clear_gpu_memory()
            try:
                _, mean_t, std_t = benchmark_ll_compiled(adapter, Nc=Nc, Nt=Nt, num_perms=B, num_runs=num_runs, dx=dx, dy=dy)
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # Best effort: keep the sweep going and mark this point as missing.
                print(f"Skip Nc={Nc} B={B} due to error: {str(e)[:100]}")
                mean_t = std_t = float('nan')
            else:
                print(f"Nc={Nc} B={B}: {mean_t:.4f} ± {std_t:.4f} s")
            finally:
                # Single cleanup point for both the success and the failure path.
                clear_gpu_memory()
            results['Nc'].append(Nc)
            results['num_samples'].append(B)
            results['mean_time'].append(mean_t)
            results['std_time'].append(std_t)
    return results


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Run the compiled log-likelihood benchmark for M1..M5 and save the results.

    Builds one set of shared modules, wraps them in five adapters, sweeps each
    adapter over the configured grid (M4 and M5 first, then M1, M2, M3),
    and writes the packaged results to ``compiled_ll.json`` inside
    ``cfg.out_dir``.

    Args:
        cfg: Benchmark configuration; defaults to ``DEFAULT``.
    """
    configure_torch_env()
    emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec = build_modules(cfg)

    # All adapters share the same encoder/decoder/embedders; only the head differs.
    shared = (ctx_enc, dec, emb_ctx, emb_tgt, emb_buf)
    adapters = {
        tag: CompiledFiveMethodLLAdapter(tag, *shared, head_diag=head, mv_head=None)
        for tag in ("m1", "m2", "m3", "m4")
    }
    adapters["m5"] = CompiledFiveMethodLLAdapter("m5", *shared, head_diag=None, mv_head=mv_head)

    grid_kwargs = dict(
        Nt=cfg.grid.Nt,
        num_runs=cfg.grid.num_runs,
        Nc_values=tuple(cfg.grid.Nc_values),
        num_samples_values=tuple(cfg.grid.num_samples_values),
        dx=cfg.dims.dx,
        dy=cfg.dims.dy,
    )

    # (adapter tag, banner printed before the sweep, key in the results dict).
    # TNPA and TNP-ND run first, then the remaining three.
    schedule = (
        ("m4", "\n=== LL M4 (TNPA masked AR) ===", "LL M4 TNPA AR"),
        ("m5", "\n=== LL M5 (TNP-ND MVN) ===", "LL M5 TNP-ND"),
        ("m1", "\n=== LL M1 (independent) ===", "LL M1 TNP-D Indep"),
        ("m2", "\n=== LL M2 (AR re-encode) ===", "LL M2 TNP-D AR"),
        ("m3", "\n=== LL M3 (AR buffer) ===", "LL M3 Ours AR Buffer"),
    )

    methods: Dict[str, dict] = {}
    for tag, banner, result_key in schedule:
        print(banner)
        methods[result_key] = run_ll_grid_compiled(adapters[tag], **grid_kwargs)
        clear_gpu_memory()

    payload = package_for_plot(methods, meta={"script": "compiled_ll", "config": cfg.to_dict()})
    out_dir = cfg.out_dir
    # NOTE(review): presumably _ensure_dir creates the parent directory of the
    # given path, hence the dummy "_" component — confirm against common._ensure_dir.
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, "compiled_ll.json")
    _atomic_write_json(out_path, payload)
    print(f"Saved: {out_path}")


# Script entry point: run the benchmark with the default configuration.
if __name__ == "__main__":
    main()
