from __future__ import annotations

"""
Baseline sampling-time benchmarks for 4 baselines:
- TNPD (independent)
- TNPD (autoregressive)
- TNPA (autoregressive)
- TNPND (independent)

Saves a JSON with shape:
{ "metadata": {...},
  "methods": { label: {Nc, num_samples, mean_time, std_time} } }
"""

import os
import sys
from pathlib import Path
from typing import Dict

import torch

# Bootstrap for running this file directly (or pasting it into a notebook):
# with no package context, find the repository root and put it on sys.path so
# the absolute `scripts.` / `src.` imports below can resolve.
if not __package__:
    try:
        base = Path(__file__).resolve()
    except NameError:
        # `__file__` is undefined (e.g. notebook cell): start from the CWD.
        base = Path.cwd()
    # Walk from `base` upwards; the first location holding any marker wins.
    ROOT = None
    for cand in (base, *base.parents):
        if any((cand / marker).exists() for marker in ("scripts", "pyproject.toml", ".git")):
            ROOT = cand
            break
    if ROOT is None:
        # No marker found anywhere up the tree: fall back to the start point.
        ROOT = base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    run_benchmark_grid,
    package_for_plot,
)

from src.models.benchmarks import TNPA, TNPD, TNPND
from scripts.fast_times.ace_encoder import patch_tnp_encoder_with_ace


def build_baseline_models(cfg: BenchConfig):
    """Instantiate the three baseline models used by the timing benchmarks.

    Each model is built from the sizes in ``cfg.dims``, moved to
    ``cfg.runtime.device`` and cast to float32. TNPD and TNPND are then passed
    through ``patch_tnp_encoder_with_ace``; TNPA is returned unpatched.

    Args:
        cfg: Benchmark configuration. Reads ``cfg.runtime.device`` and, from
            ``cfg.dims``: ``dx``, ``dy``, ``d_model``, ``d_ff``, ``n_heads``
            and ``n_layers_enc``.

    Returns:
        The tuple ``(tnpd, tnpa, tnpnd)``.
    """
    dims = cfg.dims
    target_device = cfg.runtime.device
    # Baselines run in float32 for robustness
    target_dtype = torch.float32

    # Constructor arguments shared by all three model classes
    # (kept close to the notebooks' spirit: dx=dy=1, D=128, H=4, etc.)
    shared_kwargs = dict(
        dim_x=dims.dx,
        dim_y=dims.dy,
        d_model=dims.d_model,
        emb_depth=1,
        dim_feedforward=dims.d_ff,
        nhead=dims.n_heads,
        dropout=0.0,
        num_layers=dims.n_layers_enc,
        bound_std=False,
        pos_emb_init=False,
    )

    def _place(model):
        # Move to the benchmark device first, then cast to the benchmark dtype.
        return model.to(target_device).to(target_dtype)

    # Build in the order TNPD, TNPA, TNPND.
    tnpd = _place(TNPD(**shared_kwargs))
    tnpa = _place(TNPA(permute=False, **shared_kwargs))
    tnpnd = _place(TNPND(num_std_layers=2, **shared_kwargs))

    # Use ACE encoder optimization for non-AR baselines (TNPD, TNP-ND)
    tnpd = patch_tnp_encoder_with_ace(tnpd)
    tnpnd = patch_tnp_encoder_with_ace(tnpnd)
    return tnpd, tnpa, tnpnd


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Time sampling for every baseline configuration and save the results.

    Runs ``run_benchmark_grid`` for TNPD (independent and autoregressive),
    TNPA (autoregressive) and TNP-ND (independent) over the grid described by
    ``cfg.grid``, packages the per-method results with ``package_for_plot``
    and writes them to ``<cfg.out_dir>/baseline_sampling.json``.

    Args:
        cfg: Benchmark configuration. Reads ``cfg.grid`` (``Nc_values``,
            ``num_samples_values``, ``Nt``, ``num_runs``), ``cfg.dims``
            (``dx``, ``dy``), ``cfg.out_dir`` and ``cfg.to_dict()``.
    """
    configure_torch_env()
    # Baselines use masked attention in encoders; allow math/mem-efficient SDPA
    if torch.cuda.is_available():
        try:
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
        except Exception as exc:
            # Still best-effort (the benchmark proceeds), but no longer silent:
            # a different attention backend can change the measured timings.
            print(
                f"Warning: could not enable math/mem-efficient SDPA backends: {exc!r}",
                file=sys.stderr,
            )
    m_tnpd, m_tnpa, m_tnpnd = build_baseline_models(cfg)

    # Grid settings shared by every benchmark run below.
    grid_kwargs = dict(
        Nt=cfg.grid.Nt,
        num_runs=cfg.grid.num_runs,
        Nc_values=tuple(cfg.grid.Nc_values),
        num_samples_values=tuple(cfg.grid.num_samples_values),
        dx=cfg.dims.dx,
        dy=cfg.dims.dy,
    )

    # (console heading, result label, model, sampling method), in run order.
    benchmark_plan = (
        ("TNPD (independent)", "TNPD-Independent", m_tnpd, "independent"),
        ("TNPD (autoregressive)", "TNPD-AR", m_tnpd, "autoregressive"),
        ("TNPA (autoregressive)", "TNPA", m_tnpa, "autoregressive"),
        ("TNP-ND (independent MVN)", "TNP-ND", m_tnpnd, "independent"),
    )

    methods: Dict[str, dict] = {}
    for heading, label, model, sampling_method in benchmark_plan:
        print(f"\n=== {heading} ===")
        methods[label] = run_benchmark_grid(model, method=sampling_method, **grid_kwargs)

    meta = {
        "script": "baseline_sampling",
        "config": cfg.to_dict(),
    }
    out = package_for_plot(methods, meta=meta)
    out_dir = cfg.out_dir
    # NOTE(review): the dummy "_" leaf suggests _ensure_dir creates the parent
    # directory of the path it is given — confirm against its definition.
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, "baseline_sampling.json")
    _atomic_write_json(out_path, out)
    print(f"Saved: {out_path}")


if __name__ == "__main__":
    main()
