from __future__ import annotations

"""
Compiled sampling-time benchmarks for five methods (M1..M5):
  m1: TNP-D, independent
  m2: TNP-D, autoregressive (re-encode)
  m3: ours, autoregressive (buffer KV)
  m4: TNPA, autoregressive (masked re-encode)
  m5: TNP-ND (multivariate; decodes once, then samples via a diagonal-head surrogate)

Saves the results to <out_dir>/compiled_sampling.json in the format expected by the plotting helper.
"""

import os
import sys
from pathlib import Path
from typing import Dict

import torch

# Allow running as a file or pasted into a notebook: add repo root to sys.path
# so the absolute `scripts.fast_times...` imports below resolve.
if not __package__:  # None (notebook / exec) or "" (run as a top-level script)
    try:
        # Start from the script's directory: the file path itself can never
        # contain a `scripts/` folder, and must never end up in sys.path.
        _start = Path(__file__).resolve().parent
    except NameError:
        # No __file__ (e.g. pasted into a notebook): search from the CWD.
        _start = Path.cwd()
    # First ancestor (including the start dir) that looks like the repo root.
    # Fallback is the start directory, so ROOT is always a directory.
    ROOT = next(
        (
            cand
            for cand in (_start, *_start.parents)
            if (cand / "scripts").exists()
            or (cand / "pyproject.toml").exists()
            or (cand / ".git").exists()
        ),
        _start,
    )
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, DEFAULT
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    run_benchmark_grid,
    package_for_plot,
)
from scripts.fast_times.compiled_core import (
    ContextEmbedder, TargetEmbedder, BufferEmbedder,
    ContextEncoder, Decoder, GaussianHead, MvGaussianHead,
    CompiledFiveMethodAdapter,
)


def build_modules(cfg: BenchConfig):
    """Instantiate every benchmark module on the configured device.

    Device and dtype:
      * All modules are moved to ``cfg.runtime.device``.
      * Floating-point params/buffers are cast to float16 when that device is
        CUDA (any index, e.g. "cuda:0") and CUDA is available.
      * Otherwise they are cast to float32.

    Compilation:
      * The decoder is always passed through ``torch.compile``.
      * The three embedders and the Gaussian head are compiled too, unless the
        environment variable ``FAST_TIMES_COMPILE_MODULES`` is set to a
        false-like value ("0", "false", "no", "off"; case-insensitive).
      * ``mv_head`` and ``ctx_enc`` are returned uncompiled.

    Args:
        cfg: Benchmark configuration; reads ``cfg.runtime.device`` and ``cfg.dims``.

    Returns:
        Tuple ``(emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec)``.
    """
    device = cfg.runtime.device
    # Compare the device *type* so "cuda:0" / torch.device("cuda", 1) also get
    # fp16; a plain string comparison with "cuda" would silently use fp32.
    use_half = torch.device(device).type == "cuda" and torch.cuda.is_available()
    dtype = torch.float16 if use_half else torch.float32
    D = cfg.dims

    def _place(mod: torch.nn.Module) -> torch.nn.Module:
        # Single .to() call: move to the device and cast floating-point tensors.
        return mod.to(device=device, dtype=dtype)

    emb_ctx = _place(ContextEmbedder(D.dx + D.dy, D.d_model))
    emb_tgt = _place(TargetEmbedder(D.dx, D.d_model))
    emb_buf = _place(BufferEmbedder(D.dx + D.dy, D.d_model))
    head = _place(GaussianHead(D.d_model, D.dy))
    mv_head = _place(
        MvGaussianHead(
            d_model=D.d_model, dy=D.dy, n_heads=D.n_heads, d_ff=D.d_ff,
            n_std_layers=2, prj_dim=8, bound_diag=True, min_diag=0.05,
        )
    )
    ctx_enc = _place(ContextEncoder(D.d_model, D.n_heads, D.n_layers_enc, D.d_ff))
    dec = _place(Decoder(D.d_model, D.n_heads, D.n_layers_dec, D.d_ff))

    def _compile_ok(mod: torch.nn.Module) -> torch.nn.Module:
        """Wrap ``mod`` with ``torch.compile``; fall back to the eager module on error.

        torch.compile compiles lazily on the first call, so this guard only
        covers the wrapping step, not later graph-compilation failures.
        """
        try:
            # mode="default" avoids the CUDA-graph capture used by "reduce-overhead".
            return torch.compile(mod, mode="default", fullgraph=False, dynamic=True)
        except Exception as exc:  # best effort: benchmark eagerly rather than abort
            print(f"[warn] torch.compile failed for {type(mod).__name__}: {exc!r}; using eager module")
            return mod

    # Optional compilation of the small modules for fair full-pipeline timing.
    compile_flag = os.environ.get("FAST_TIMES_COMPILE_MODULES", "1").strip().lower()
    if compile_flag not in {"0", "false", "no", "off"}:
        emb_ctx = _compile_ok(emb_ctx)
        emb_tgt = _compile_ok(emb_tgt)
        emb_buf = _compile_ok(emb_buf)
        head = _compile_ok(head)
    # Always compile decoder for fair full-pipeline timing.
    # NOTE(review): ctx_enc and mv_head are never compiled here; confirm this is
    # intentional (e.g. compiled inside the adapter) for a like-for-like comparison.
    dec = _compile_ok(dec)
    return emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec


def main(cfg: BenchConfig = DEFAULT) -> None:
    """Benchmark compiled sampling time for M1..M5 and save the results as JSON.

    Steps:
      1. Build all five adapters up front, before any timing starts.
      2. Run the benchmark grid for each method in the order M1, M2, M3, M4, M5.
      3. Package the timings with ``package_for_plot`` (the metadata includes
         ``cfg.to_dict()``).
      4. Write the result with ``_atomic_write_json`` to
         ``<cfg.out_dir>/compiled_sampling.json``.

    Args:
        cfg: Benchmark configuration (module dims, grid sizes, output directory).
    """
    configure_torch_env()
    emb_ctx, emb_tgt, emb_buf, head, mv_head, ctx_enc, dec = build_modules(cfg)
    shared_modules = (ctx_enc, dec, emb_ctx, emb_tgt, emb_buf, head)

    # Run plan, in execution order: (adapter tag, banner, result key, sampling method).
    plan = (
        ("m1", "\n=== M1 (independent) ===", "M1 TNP-D Indep", "independent"),
        ("m2", "\n=== M2 (AR re-encode) ===", "M2 TNP-D AR", "autoregressive"),
        ("m3", "\n=== M3 (AR buffer) ===", "M3 Ours AR Buffer", "autoregressive"),
        ("m4", "\n=== M4 (TNPA masked AR) ===", "M4 TNPA AR", "autoregressive"),
        ("m5", "\n=== M5 (TNP-ND MVN) ===", "M5 TNP-ND", "independent"),
    )
    # Only the TNP-ND adapter is given the multivariate head.
    adapter_extras = {"m5": {"mv_head": mv_head}}

    # Construct every adapter (in plan order) before the first benchmark runs.
    adapters = {
        tag: CompiledFiveMethodAdapter(tag, *shared_modules, **adapter_extras.get(tag, {}))
        for tag, *_ in plan
    }

    grid = cfg.grid
    grid_kwargs = {
        "Nt": grid.Nt,
        "num_runs": grid.num_runs,
        "Nc_values": tuple(grid.Nc_values),
        "num_samples_values": tuple(grid.num_samples_values),
        "dx": cfg.dims.dx,
        "dy": cfg.dims.dy,
    }

    methods: Dict[str, dict] = {}
    for tag, banner, result_key, sampling in plan:
        print(banner)
        methods[result_key] = run_benchmark_grid(adapters[tag], method=sampling, **grid_kwargs)

    payload = package_for_plot(
        methods,
        meta={"script": "compiled_sampling", "config": cfg.to_dict()},
    )
    target_dir = cfg.out_dir
    # _ensure_dir receives a dummy child path; presumably it creates that path's
    # parent directory (i.e. target_dir) — confirm against scripts.fast_times.common.
    _ensure_dir(os.path.join(target_dir, "_"))
    json_path = os.path.join(target_dir, "compiled_sampling.json")
    _atomic_write_json(json_path, payload)
    print(f"Saved: {json_path}")


if __name__ == "__main__":
    # Executed directly (script or notebook cell): run with the default config.
    main(cfg=DEFAULT)
