from __future__ import annotations

"""
Baseline log-likelihood evaluation times for 4 baselines:
- TNPD (independent)
- TNPD (autoregressive)
- TNPA (autoregressive)
- TNPND (independent MVN)

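For the AR methods, each entry of the grid's `num_samples_values` is interpreted as the
number of AR target orderings evaluated in one batch (num_perms=B). For the
order-independent methods (M1/M5; labelled `TNPD-Independent` and `TNP-ND` here), repeating
the evaluation would only scale cost linearly without changing the result, so they are
timed with a single batch row for every grid entry. We preserve the same JSON structure
used by plotting.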
"""

import os
import sys
import time
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add repo root to sys.path
if __package__ is None or __package__ == "":
    try:
        # Start from the directory that contains this script rather than the
        # script file itself, so the fallback below is always a directory
        # (a file path on sys.path is never importable from).
        base = Path(__file__).resolve().parent
    except NameError:
        # __file__ is undefined when the code is pasted into a notebook/REPL.
        base = Path.cwd()
    # Walk upwards until a directory carrying one of the repo markers is found.
    for cand in [base] + list(base.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    else:
        # No marker anywhere above us: fall back to the starting directory.
        ROOT = base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    check_gpu_memory,
    clear_gpu_memory,
    data_gen,
    package_for_plot,
)
from src.models.benchmarks import TNPA, TNPD, TNPND
from scripts.fast_times.ace_encoder import patch_tnp_encoder_with_ace


def _permute_targets_once(xt: torch.Tensor, yt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T = xt.shape[0], xt.shape[1]
    device = xt.device
    r = torch.rand(B, T, device=device)
    perm = torch.argsort(r, dim=1)
    b_idx = torch.arange(B, device=device)[:, None]
    return xt[b_idx, perm], yt[b_idx, perm]


def benchmark_ll_baseline(
    model: torch.nn.Module,
    Nc: int,
    Nt: int,
    *,
    num_perms: int,
    num_runs: int,
    method_label: str,
    dx: int = 1,
    dy: int = 1,
) -> Tuple[np.ndarray, float, float]:
    """Time batched log-likelihood evaluation of ``model``.

    One task (B=1) is drawn with ``data_gen`` and tiled to ``num_perms`` batch
    rows. For the labels "TNPD-AR" and "TNPA" the targets of every row are
    shuffled independently; for all other labels the rows are plain copies.
    The whole batch is evaluated in a single call per timed run.

    Args:
        model: Module exposing ``eval_log_likelihood`` (used when
            ``method_label == "TNPD-Independent"``) or
            ``eval_log_joint_likelihood`` (used for every other label).
        Nc: Number of context points passed to ``data_gen``.
        Nt: Number of target points passed to ``data_gen``.
        num_perms: Batch size B the single task is expanded to.
        num_runs: Number of timed evaluations.
        method_label: Selects the evaluation method and whether targets are permuted.
        dx: Input dimensionality passed to ``data_gen``.
        dy: Output dimensionality passed to ``data_gen``.

    Returns:
        ``(timings, mean, std)`` where ``timings`` holds one wall-time in
        seconds per run and ``mean``/``std`` are its statistics.
    """
    params = list(model.parameters())
    device = params[0].device if params else ("cuda" if torch.cuda.is_available() else "cpu")
    dtype = params[0].dtype if params else torch.float32
    # Decide on CUDA-event timing from where the work actually runs: a CPU model
    # on a CUDA-capable machine would otherwise be "timed" on an idle GPU stream.
    use_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"

    # Ensure eval mode for warmup and timing
    model.eval()

    # Base batch (B0=1)
    xc0, yc0, xt0, yt0 = data_gen(Nc, Nt, B=1, dx=dx, dy=dy, device=device, dtype=dtype)

    # Build B=num_perms batch by tiling context and targets
    B = num_perms
    xcB = xc0.expand(B, -1, -1).contiguous()
    ycB = yc0.expand(B, -1, -1).contiguous()
    xtB = xt0.expand(B, -1, -1).contiguous()
    ytB = yt0.expand(B, -1, -1).contiguous()
    if method_label in {"TNPD-AR", "TNPA"}:
        # AR: make independent per-row permutations
        xtB, ytB = _permute_targets_once(xtB, ytB)

    # Resolve the evaluation entry point once instead of re-branching per call.
    eval_fn = (
        model.eval_log_likelihood
        if method_label == "TNPD-Independent"
        else model.eval_log_joint_likelihood
    )

    # Warm-up (no grad to avoid activation storage + leaks)
    with torch.no_grad():
        for _ in range(2):
            _ = eval_fn(xcB, ycB, xtB, ytB)
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    timings: list[float] = []
    with torch.no_grad():
        for i in range(num_runs):
            if use_cuda:
                if i % 5 == 0:
                    torch.cuda.empty_cache()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                _ = eval_fn(xcB, ycB, xtB, ytB)
                end_event.record()
                # elapsed_time is only valid once both events have completed.
                torch.cuda.synchronize()
                # CUDA events report milliseconds; store seconds.
                timings.append(start_event.elapsed_time(end_event) / 1000.0)
            else:
                t0 = time.perf_counter()
                _ = eval_fn(xcB, ycB, xtB, ytB)
                timings.append(time.perf_counter() - t0)

    t = np.array(timings)
    # Proactively drop large local tensors before returning
    del xc0, yc0, xt0, yt0, xcB, ycB, xtB, ytB
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return t, float(t.mean()), float(t.std())


def run_ll_grid_baselines(model: torch.nn.Module, *, Nt: int, num_runs: int, method_label: str,
                          Nc_values, num_samples_values, dx: int, dy: int) -> Dict[str, list]:
    """Benchmark ``model`` over the (Nc, num_samples) grid.

    Every grid cell produces exactly one row in the returned columns. A cell
    whose benchmark raises ``RuntimeError`` (which includes CUDA out-of-memory
    errors) is recorded with ``-1.0`` for both timing columns instead of
    aborting the sweep.

    Args:
        model: Model handed to ``benchmark_ll_baseline``.
        Nt: Number of target points.
        num_runs: Timed repetitions per grid cell.
        method_label: Baseline label; "TNPD-Independent" and "TNP-ND" are
            benchmarked with a single batch row regardless of the grid value.
        Nc_values: Iterable of context sizes.
        num_samples_values: Iterable of sample counts B.
        dx: Input dimensionality.
        dy: Output dimensionality.

    Returns:
        Dict with parallel lists under 'Nc', 'num_samples', 'mean_time' and
        'std_time'.
    """
    results: Dict[str, list] = {
        'Nc': [],
        'num_samples': [],
        'mean_time': [],
        'std_time': [],
    }
    order_independent = method_label in {"TNPD-Independent", "TNP-ND"}
    for Nc in Nc_values:
        for B in num_samples_values:
            if not check_gpu_memory(2.0):
                clear_gpu_memory()
            # For order-independent baselines, eval cost doesn't depend on B; avoid huge tiling
            B_eff = 1 if order_independent else B
            # Keep the try body to the benchmark call only: a failure after the
            # row has been recorded must not append a second row for this cell.
            try:
                _, mean_t, std_t = benchmark_ll_baseline(
                    model, Nc=Nc, Nt=Nt, num_perms=B_eff, num_runs=num_runs, method_label=method_label, dx=dx, dy=dy
                )
            except RuntimeError as e:
                # CUDA OOM is raised as a RuntimeError subclass, so it is covered here.
                print(f"Skip Nc={Nc} B={B} due to error: {str(e)[:100]}")
                # Use -1 sentinel instead of NaN to keep JSON/plot pipelines simple
                mean_t, std_t = -1.0, -1.0
            else:
                print(f"Nc={Nc} B={B}: {mean_t:.4f} ± {std_t:.4f} s")
            results['Nc'].append(Nc)
            results['num_samples'].append(B)
            results['mean_time'].append(mean_t)
            results['std_time'].append(std_t)
            clear_gpu_memory()
    return results


def build_baseline_models(cfg: BenchConfig):
    """Instantiate the three baseline networks described by ``cfg``.

    All models share the dimensions in ``cfg.dims``, are moved to
    ``cfg.runtime.device`` and cast to float32. The TNPD and TNPND models are
    additionally passed through ``patch_tnp_encoder_with_ace``; TNPA is
    returned unpatched.

    Returns:
        ``(tnpd, tnpa, tnpnd)`` in that order.
    """
    target_device = cfg.runtime.device
    # Baselines run in float32 for robustness (match notebook defaults)
    target_dtype = torch.float32
    dims = cfg.dims

    # Constructor arguments common to all three baselines.
    shared_kwargs = dict(
        dim_x=dims.dx,
        dim_y=dims.dy,
        d_model=dims.d_model,
        emb_depth=1,
        dim_feedforward=dims.d_ff,
        nhead=dims.n_heads,
        dropout=0.0,
        num_layers=dims.n_layers_enc,
        bound_std=False,
        pos_emb_init=False,
    )

    def _place(module):
        # Move to the benchmark device first, then cast.
        return module.to(target_device).to(target_dtype)

    # Construction order is kept (TNPD, TNPA, TNPND) so parameter initialisation
    # consumes the RNG in the same sequence.
    tnpd = _place(TNPD(**shared_kwargs))
    tnpa = _place(TNPA(permute=False, **shared_kwargs))
    tnpnd = _place(TNPND(num_std_layers=2, **shared_kwargs))

    # Optimize encoder for non-AR baselines (TNPD, TNP-ND)
    tnpd = patch_tnp_encoder_with_ace(tnpd)
    tnpnd = patch_tnp_encoder_with_ace(tnpnd)
    return tnpd, tnpa, tnpnd


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Run the log-likelihood timing grid for all four baselines and save it.

    Builds the baseline models from ``cfg``, benchmarks each method over the
    configured grid, and writes the packaged results to
    ``<cfg.out_dir>/baseline_ll.json``.
    """
    configure_torch_env()
    # Baselines use masked attention in encoders; allow math/mem-efficient SDPA
    if torch.cuda.is_available():
        try:
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
        except Exception:
            # Best effort only: keep going if these toggles cannot be applied.
            pass
    m_tnpd, m_tnpa, m_tnpnd = build_baseline_models(cfg)

    # Grid settings shared by every method.
    grid_kwargs = dict(
        Nc_values=tuple(cfg.grid.Nc_values),
        num_samples_values=tuple(cfg.grid.num_samples_values),
        Nt=cfg.grid.Nt,
        num_runs=cfg.grid.num_runs,
        dx=cfg.dims.dx,
        dy=cfg.dims.dy,
    )

    # (result key / method label, console banner, model) in execution order.
    # TNPD is benchmarked twice: once independently and once autoregressively.
    schedule = (
        ("TNPD-Independent", "\n=== LL: TNPD (independent) ===", m_tnpd),
        ("TNPD-AR", "\n=== LL: TNPD (autoregressive) ===", m_tnpd),
        ("TNPA", "\n=== LL: TNPA (autoregressive) ===", m_tnpa),
        ("TNP-ND", "\n=== LL: TNP-ND (independent MVN) ===", m_tnpnd),
    )

    methods: Dict[str, dict] = {}
    for label, banner, net in schedule:
        print(banner)
        methods[label] = run_ll_grid_baselines(net, method_label=label, **grid_kwargs)

    payload = package_for_plot(
        methods,
        meta={
            "script": "baseline_ll",
            "config": cfg.to_dict(),
        },
    )
    target_dir = cfg.out_dir
    _ensure_dir(os.path.join(target_dir, "_"))
    out_path = os.path.join(target_dir, "baseline_ll.json")
    _atomic_write_json(out_path, payload)
    print(f"Saved: {out_path}")


# Script entry point: run the full baseline benchmark with the default config.
if __name__ == "__main__":
    main()
