from __future__ import annotations

"""
Toy forward+backward timings for baseline models (TNPD, TNPA, TNPND).

Measures the time spent in a combined forward pass + loss.backward() across a
small grid of batch sizes (B), context sizes (Nc), and target sizes (Nt). An
extra "Ours" entry times TNPD with K additional context points per grid point.

Notes
- CUDA is the intended target (like other fast_times scripts); when the
  configured runtime device is not CUDA, the script still runs as a CPU toy
  benchmark timed with wall-clock time.
- Models run in float32 for robustness.
- We reuse the fast_times helpers for environment setup and data generation.

Run as a module for clean imports:
  uv run python -m scripts.fast_times.run_backward_baselines
"""

import os
import sys
import time
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add the repo root to sys.path
# so the absolute `scripts.` / `src.` imports below resolve. When run as a module
# (`python -m ...`), __package__ is already set and this is skipped.
if not __package__:
    try:
        # Start from the script's directory (not the file itself) so that the
        # fallback below is always a directory usable on sys.path.
        base = Path(__file__).resolve().parent
    except NameError:
        # No __file__ (e.g. a notebook cell): search from the working directory.
        base = Path.cwd()
    ROOT = None
    # Walk upwards until a directory carries one of the repo-root markers.
    for cand in [base] + list(base.parents):
        if (cand / "scripts").exists() or (cand / "pyproject.toml").exists() or (cand / ".git").exists():
            ROOT = cand
            break
    ROOT = ROOT or base
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    check_gpu_memory,
    clear_gpu_memory,
    data_gen,
)
from scripts.fast_times.config import BenchConfig, ModelDims, Runtime
from src.utils import DataAttr
from src.models.benchmarks import TNPA, TNPD, TNPND
from scripts.fast_times.ace_encoder import patch_tnp_encoder_with_ace


@dataclass
class BackwardGrid:
    """Sweep definition for the forward+backward timing grid.

    Every combination of ``B_values`` x ``Nc_values`` x ``Nt_values`` is
    benchmarked, with ``num_runs`` timed repetitions per combination.
    """

    # Batch sizes to sweep.
    B_values: Tuple[int, ...] = field(default=(1, 2, 4, 8))
    # Context-set sizes to sweep.
    Nc_values: Tuple[int, ...] = field(default=(32, 128, 512))
    # Target-set sizes to sweep.
    Nt_values: Tuple[int, ...] = field(default=(8, 16, 32))
    # Timed repetitions per grid point (warm-up iterations are not counted).
    num_runs: int = field(default=10)


@dataclass
class BackwardConfig:
    """Top-level settings for the forward+backward timing script.

    Bundles the sweep grid, model dimensions and runtime options with the
    output directory (``out_dir``) and result file stem (``name``).
    """

    grid: BackwardGrid = field(default_factory=BackwardGrid)
    dims: ModelDims = field(default_factory=ModelDims)
    runtime: Runtime = field(default_factory=Runtime)
    out_dir: str = field(default="outputs/fast_times")
    name: str = field(default="backward_baselines")

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view: nested dataclasses expanded via ``asdict``."""
        payload = {key: asdict(getattr(self, key)) for key in ("grid", "dims", "runtime")}
        payload["out_dir"] = self.out_dir
        payload["name"] = self.name
        return payload


def build_baseline_models(cfg: BackwardConfig):
    """Construct the TNPD, TNPA and TNP-ND baselines for the timing grid.

    All three share the dimensions from ``cfg.dims``, are moved to
    ``cfg.runtime.device`` in float32, and are left in train mode. TNPD and
    TNP-ND (the non-AR baselines) get their encoder patched with
    ``patch_tnp_encoder_with_ace``; TNPA is left unpatched.

    Returns:
        Tuple ``(tnpd, tnpa, tnpnd)``.
    """
    dims = cfg.dims
    target_device = cfg.runtime.device
    # Constructor arguments common to every baseline.
    shared_kwargs = dict(
        dim_x=dims.dx,
        dim_y=dims.dy,
        d_model=dims.d_model,
        emb_depth=1,
        dim_feedforward=dims.d_ff,
        nhead=dims.n_heads,
        dropout=0.0,
        num_layers=dims.n_layers_enc,
        bound_std=False,
        pos_emb_init=False,
    )

    def _placed(module):
        # Move to the target device first, then cast parameters to float32.
        return module.to(target_device).to(torch.float32)

    # Construction order is kept (TNPD, TNPA, TNP-ND) so parameter init draws
    # from the RNG in the same sequence.
    tnpd = _placed(TNPD(**shared_kwargs))
    tnpa = _placed(TNPA(**shared_kwargs, permute=False))
    tnpnd = _placed(TNPND(**shared_kwargs, num_std_layers=2))

    # Use ACE encoder optimization for non-AR baselines (TNPD, TNP-ND)
    tnpd = patch_tnp_encoder_with_ace(tnpd)
    tnpnd = patch_tnp_encoder_with_ace(tnpnd)

    # dropout=0.0 above, so train mode should not add dropout randomness.
    for built in (tnpd, tnpa, tnpnd):
        built.train()
    return tnpd, tnpa, tnpnd


def _to_batch(xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor) -> DataAttr:
    """Wrap context ``(xc, yc)`` and target ``(xt, yt)`` tensors in a ``DataAttr``."""
    named_tensors = {"xc": xc, "yc": yc, "xt": xt, "yt": yt}
    return DataAttr(**named_tensors)


def _time_fwd_bwd_once(model: torch.nn.Module, batch: DataAttr) -> float:
    # Use CUDA events only when the model is on CUDA
    p = next(model.parameters(), None)
    use_cuda = bool(p is not None and p.is_cuda)
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        model.zero_grad(set_to_none=True)
        start_event.record()
        loss = model.forward(batch).loss
        loss.backward()
        end_event.record()
        torch.cuda.synchronize()
        t_sec = start_event.elapsed_time(end_event) / 1000.0
    else:
        model.zero_grad(set_to_none=True)
        t0 = time.perf_counter()
        loss = model.forward(batch).loss
        loss.backward()
        t_sec = time.perf_counter() - t0
    return float(t_sec)


def benchmark_fwd_bwd(model: torch.nn.Module,
                      B: int, Nc: int, Nt: int,
                      *, num_runs: int, dx: int, dy: int,
                      device: str, dtype: torch.dtype,
                      warmup: int = 2) -> Tuple[np.ndarray, float, float]:
    """Benchmark forward + backward of ``model`` on one fixed synthetic batch.

    A single batch is generated once with ``data_gen`` and reused for every
    iteration, so only compute is timed. ``warmup`` untimed iterations run
    first, then ``num_runs`` timed iterations via ``_time_fwd_bwd_once``.

    Args:
        model: Module whose ``forward(batch)`` returns an object with ``.loss``.
        B, Nc, Nt: Batch size, number of context points, number of target points.
        num_runs: Number of timed iterations; must be >= 1.
        dx, dy: Input / output dimensionality passed to ``data_gen``.
        device, dtype: Where and in which dtype the batch is generated.
        warmup: Untimed warm-up iterations (default 2); must be >= 0.

    Returns:
        ``(all_times, mean, std)`` in seconds; ``all_times`` is a float64 array.

    Raises:
        ValueError: If ``num_runs < 1`` or ``warmup < 0``. An empty timing
            array would otherwise yield NaN mean/std.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    # Only touch the CUDA runtime when benchmarking on CUDA, so CPU toy runs
    # never initialise a CUDA context.
    target = torch.device(device)
    on_cuda = target.type == "cuda"

    # Pre-generate fixed synthetic batch so we time compute deterministically
    xc, yc, xt, yt = data_gen(Nc, Nt, B=B, dx=dx, dy=dy, device=device, dtype=dtype)
    batch = _to_batch(xc, yc, xt, yt)

    timings: List[float] = []
    with torch.enable_grad():
        # Warm-up (build caches/graphs; exclude from timing)
        for _ in range(warmup):
            model.zero_grad(set_to_none=True)
            loss = model.forward(batch).loss
            loss.backward()
            if on_cuda:
                torch.cuda.synchronize(target)
        model.zero_grad(set_to_none=True)

        for i in range(num_runs):
            # Periodically release cached blocks; done outside the timed region.
            if on_cuda and i % 5 == 0:
                torch.cuda.empty_cache()
            timings.append(_time_fwd_bwd_once(model, batch))

    t_np = np.asarray(timings, dtype=np.float64)
    return t_np, float(t_np.mean()), float(t_np.std())


def run_grid(models: Dict[str, torch.nn.Module], cfg: BackwardConfig, *, ours_k: int | None = 16) -> Dict[str, Dict[str, List[Any]]]:
    """Time forward+backward for every model over the full (B, Nc, Nt) grid.

    Args:
        models: Mapping from display label to model; labels become result keys.
        cfg: Benchmark configuration (grid, dims, runtime device).
        ours_k: Not used by this function (the "Ours" entry is produced by
            ``add_ours_tnpd``); kept so existing call sites keep working.

    Returns:
        ``{label: {"B", "Nc", "Nt", "mean_fwd_bwd", "std_fwd_bwd", "all_times"}}``
        with one element per grid point in every list. Points that fail with an
        OOM / RuntimeError are kept with ``-1.0`` mean/std and empty ``all_times``.
    """
    grid = cfg.grid
    B_vals, Nc_vals, Nt_vals = grid.B_values, grid.Nc_values, grid.Nt_values
    runs = grid.num_runs
    dx, dy = cfg.dims.dx, cfg.dims.dy
    device = cfg.runtime.device
    dtype = torch.float32

    def _record(series: Dict[str, List[Any]], B: int, Nc: int, Nt: int,
                mean_t: float, std_t: float, times: List[float]) -> None:
        """Append one grid point so all lists in ``series`` stay aligned."""
        series["B"].append(B)
        series["Nc"].append(Nc)
        series["Nt"].append(Nt)
        series["mean_fwd_bwd"].append(mean_t)
        series["std_fwd_bwd"].append(std_t)
        series["all_times"].append(times)

    results: Dict[str, Dict[str, List[Any]]] = {}
    print(f"Running forward+backward timings over |B|={len(B_vals)}, |Nc|={len(Nc_vals)}, |Nt|={len(Nt_vals)}")
    print("Methods:", ", ".join(models.keys()))
    # Number of grid points per model; the progress counter is reset per model.
    total = len(B_vals) * len(Nc_vals) * len(Nt_vals)

    for label, model in models.items():
        series: Dict[str, List[Any]] = {"B": [], "Nc": [], "Nt": [], "mean_fwd_bwd": [], "std_fwd_bwd": [], "all_times": []}
        results[label] = series
        idx = 0
        for B in B_vals:
            for Nc in Nc_vals:
                for Nt in Nt_vals:
                    idx += 1
                    print(f"[{label}] {idx}/{total}: B={B}, Nc={Nc}, Nt={Nt}")
                    if not check_gpu_memory(1.0):
                        clear_gpu_memory()
                    try:
                        times, mean_t, std_t = benchmark_fwd_bwd(
                            model, B=B, Nc=Nc, Nt=Nt, num_runs=runs, dx=dx, dy=dy, device=device, dtype=dtype
                        )
                    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                        print(f"   skip due to error: {str(e)[:120]}")
                        # Use -1 sentinel so plotting can drop negatives
                        _record(series, B, Nc, Nt, -1.0, -1.0, [])
                        clear_gpu_memory()
                    else:
                        _record(series, B, Nc, Nt, mean_t, std_t, times.tolist())
                        print(f"   fwd+bwd: {mean_t:.6f}s ± {std_t:.6f}s")
                    finally:
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
    return results


def add_ours_tnpd(results: Dict[str, Dict[str, List[Any]]], tnqd_model: torch.nn.Module,
                  cfg: BackwardConfig, *, K: int) -> None:
    """Add an ``"Ours"`` entry to ``results`` by timing TNPD at Nc + K.

    Our method is treated as equivalent to augmenting the context with ``K``
    extra points. Each grid point (B, Nc, Nt) is therefore benchmarked with
    ``Nc + K`` context points. The recorded ``"Nc"`` value is that effective
    size ``Nc + K``, not the grid Nc. Failed points (OOM / RuntimeError) are
    recorded with ``-1.0`` mean/std and empty ``all_times``, as in ``run_grid``.

    Args:
        results: Result dict to extend in place (key ``"Ours"`` is overwritten).
        tnqd_model: The TNPD model to time.
        cfg: Benchmark configuration (grid, dims, runtime device).
        K: Number of extra context points; must be >= 0.

    Raises:
        ValueError: If ``K`` is negative.
    """
    if K < 0:
        raise ValueError(f"K must be >= 0, got {K}")

    label = "Ours"
    series: Dict[str, List[Any]] = {"B": [], "Nc": [], "Nt": [], "mean_fwd_bwd": [], "std_fwd_bwd": [], "all_times": []}
    results[label] = series

    grid = cfg.grid
    B_vals, Nc_vals, Nt_vals = grid.B_values, grid.Nc_values, grid.Nt_values
    runs = grid.num_runs
    dx, dy = cfg.dims.dx, cfg.dims.dy
    device = cfg.runtime.device
    dtype = torch.float32

    def _record(B: int, Nc_eff: int, Nt: int, mean_t: float, std_t: float, times: List[float]) -> None:
        """Append one grid point so all lists in ``series`` stay aligned."""
        series["B"].append(B)
        series["Nc"].append(Nc_eff)
        series["Nt"].append(Nt)
        series["mean_fwd_bwd"].append(mean_t)
        series["std_fwd_bwd"].append(std_t)
        series["all_times"].append(times)

    total = len(B_vals) * len(Nc_vals) * len(Nt_vals)
    idx = 0
    for B in B_vals:
        for Nc in Nc_vals:
            for Nt in Nt_vals:
                idx += 1
                Nc_eff = Nc + K
                print(f"[{label}] {idx}/{total}: B={B}, Nc={Nc}+{K}={Nc_eff}, Nt={Nt}")
                if not check_gpu_memory(1.0):
                    clear_gpu_memory()
                try:
                    times, mean_t, std_t = benchmark_fwd_bwd(
                        tnqd_model, B=B, Nc=Nc_eff, Nt=Nt, num_runs=runs, dx=dx, dy=dy, device=device, dtype=dtype
                    )
                except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                    # Report instead of silently dropping the timing.
                    print(f"   skip due to error: {str(e)[:120]}")
                    _record(B, Nc_eff, Nt, -1.0, -1.0, [])
                    clear_gpu_memory()
                else:
                    _record(B, Nc_eff, Nt, mean_t, std_t, times.tolist())
                    print(f"   fwd+bwd: {mean_t:.6f}s ± {std_t:.6f}s")
                finally:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()


def main(cfg: BackwardConfig | None = None, *, ours_k: int = 16) -> None:
    """Run the forward+backward timing grid for all baselines and save JSON.

    Builds TNPD / TNPA / TNP-ND and times them over ``cfg.grid``. It then adds
    an ``"Ours"`` entry (TNPD with ``ours_k`` extra context points) and writes
    ``{"metadata", "methods"}`` atomically to ``<out_dir>/<name>.json``.

    Args:
        cfg: Benchmark configuration; defaults to ``BackwardConfig()``.
        ours_k: Extra context points used for the "Ours" entry (default 16).
    """
    cfg = cfg or BackwardConfig()
    # Only require CUDA when targeting CUDA; otherwise allow CPU toy runs
    dev = str(cfg.runtime.device)
    if dev.startswith("cuda"):
        configure_torch_env()
    else:
        try:
            torch.set_float32_matmul_precision("high")
        except Exception:
            # Best effort: a precision hint only, not required for the benchmark.
            pass

    m_tnpd, m_tnpa, m_tnpnd = build_baseline_models(cfg)
    models = {
        "TNPD": m_tnpd,
        "TNPA": m_tnpa,
        "TNP-ND": m_tnpnd,
    }

    methods = run_grid(models, cfg, ours_k=ours_k)
    # Add our method: TNPD evaluated with K extra context points. A failure here
    # should not discard the baseline results, but it must not be silent either.
    try:
        add_ours_tnpd(methods, models["TNPD"], cfg, K=ours_k)
    except Exception as e:
        print(f"Warning: could not add 'Ours' timings: {type(e).__name__}: {e}")

    meta = {"script": "backward_baselines", "config": cfg.to_dict(), "timing": "fwd_bwd", "ours_k": ours_k}
    out_obj = {"metadata": meta, "methods": methods}

    out_dir = cfg.out_dir
    _ensure_dir(os.path.join(out_dir, "_"))
    out_path = os.path.join(out_dir, f"{cfg.name}.json")
    _atomic_write_json(out_path, out_obj)
    print(f"Saved: {out_path}")


# Script entry point: run the full benchmark with the default BackwardConfig.
if __name__ == "__main__":
    main()
