from __future__ import annotations

"""
GPU forward+backward timing for baseline models + our method (TNPD with K extra context).

Default benchmark grid (override with --B, --Nc, --Nt and --runs):
  B  = {32, 64, 128, 256}
  Nc = {128, 256, 512, 1024}
  Nt = {64, 128, 256, 512}
  runs = 10

Model dims (defaults): d_model=128, n_heads=4, n_layers_enc/dec=6, d_ff=256, dx=1, dy=1.

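Saves JSON to outputs/fast_times/fwd_bwd_gpu.json (override with --out) with schema:
{
  "metadata": { ... },
  "methods": {
    "TNPD":   {"B":[], "Nc":[], "Nt":[], "mean_time":[], "std_time":[], "all_times":[]},
    "TNPA":   { ... },
    "TNP-ND": { ... },
    "Ours":   { ... }   # TNPD timed with Nc + ours_k context points; its "Nc" list stores Nc + ours_k
  }
}
Times are in seconds. Configurations that fail (e.g. CUDA OOM) are recorded with
mean_time = std_time = -1.0 and an empty all_times list.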
"""

import os
import sys
import time
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import torch

# Allow running as a file or pasted into a notebook: add repo root to sys.path
def _find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that looks like the repo root.

    A directory qualifies when it contains ``scripts/``, ``pyproject.toml`` or
    ``.git``. If ``start`` is not a directory (e.g. this script's own path), the
    search begins at its parent. When no marker is found, the starting directory
    is returned, so the result is always a directory and never a file path.
    """
    origin = start if start.is_dir() else start.parent
    for cand in (origin, *origin.parents):
        if any((cand / marker).exists() for marker in ("scripts", "pyproject.toml", ".git")):
            return cand
    return origin


if __package__ is None or __package__ == "":
    try:
        base = Path(__file__).resolve()
    except NameError:  # no __file__ (e.g. code pasted into a notebook): fall back to the CWD
        base = Path.cwd()
    ROOT = _find_repo_root(base)
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))

from scripts.fast_times.config import BenchConfig, BenchGrid, ModelDims, Runtime
from scripts.fast_times.common import (
    configure_torch_env,
    _ensure_dir,
    _atomic_write_json,
    clear_gpu_memory,
    data_gen,
)
from scripts.fast_times.ace_encoder import patch_tnp_encoder_with_ace
from src.utils import DataAttr
from src.models.benchmarks import TNPA, TNPD, TNPND


def _to_batch(xc: torch.Tensor, yc: torch.Tensor, xt: torch.Tensor, yt: torch.Tensor) -> DataAttr:
    """Bundle context (xc, yc) and target (xt, yt) tensors into the ``DataAttr`` passed to ``model.forward``."""
    fields = {"xc": xc, "yc": yc, "xt": xt, "yt": yt}
    return DataAttr(**fields)


def _fwd_bwd_step(model: torch.nn.Module, batch: DataAttr, *, use_amp: bool,
                  scaler: torch.cuda.amp.GradScaler | None) -> None:
    """Run one forward+backward pass of ``model`` on ``batch`` (no timing, no grad reset).

    Gradients accumulate into the model's ``.grad`` buffers; callers are
    responsible for zeroing them.

    Raises:
        ValueError: if ``use_amp`` is set but no ``scaler`` was supplied.
    """
    if use_amp:
        if scaler is None:
            raise ValueError("use_amp=True requires a GradScaler")
        # fp16 autocast for the forward; the scaler scales the loss before backward
        # so that fp16 gradients are less likely to underflow.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model.forward(batch).loss
        scaler.scale(loss).backward()
    else:
        loss = model.forward(batch).loss
        loss.backward()


def _time_fwd_bwd_once(model: torch.nn.Module, batch: DataAttr, *, use_amp: bool, scaler: torch.cuda.amp.GradScaler | None) -> float:
    """Time a single forward+backward pass and return the elapsed time in seconds.

    Gradients are reset (set to None) before the pass, outside the timed window.
    The time is the elapsed GPU time between two CUDA events recorded around the pass.

    Raises:
        ValueError: if ``use_amp`` is set but no ``scaler`` was supplied.
    """
    # Validate before touching CUDA so misuse fails fast with a clear error.
    if use_amp and scaler is None:
        raise ValueError("use_amp=True requires a GradScaler")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    model.zero_grad(set_to_none=True)
    start_event.record()
    _fwd_bwd_step(model, batch, use_amp=use_amp, scaler=scaler)
    end_event.record()
    # elapsed_time is only valid once end_event has actually completed on the GPU.
    torch.cuda.synchronize()
    return float(start_event.elapsed_time(end_event) / 1000.0)  # ms -> s


def benchmark_fwd_bwd(model: torch.nn.Module,
                      B: int, Nc: int, Nt: int,
                      *, num_runs: int, dx: int, dy: int,
                      device: str, dtype: torch.dtype, use_amp: bool = False,
                      num_warmup: int = 2,
                      empty_cache_between_runs: bool = True) -> Tuple[np.ndarray, float, float]:
    """Time ``num_runs`` forward+backward passes of ``model`` on one synthetic batch.

    A single batch is generated once with ``data_gen`` and reused for every pass.
    ``num_warmup`` untimed passes run first.

    Args:
        model: model whose ``forward(batch)`` returns an object with a ``.loss`` tensor.
        B, Nc, Nt: batch size, number of context points and number of target points
            forwarded to ``data_gen``.
        num_runs: number of timed passes; must be >= 1.
        dx, dy, device, dtype: forwarded to ``data_gen``.
        use_amp: run the forward under fp16 autocast with a ``GradScaler``.
        num_warmup: number of untimed warm-up passes (default 2); must be >= 0.
        empty_cache_between_runs: when True (default, the historical behaviour) the
            CUDA caching allocator is emptied after warm-up and after every timed
            run, so each timed pass must re-acquire its memory. This can inflate
            timings relative to steady-state training; pass False to time with a
            warm allocator.

    Returns:
        ``(times, mean, std)``: per-run times in seconds (float64 array), their mean
        and their population standard deviation.

    Raises:
        ValueError: on ``num_runs < 1`` or ``num_warmup < 0``. Errors raised by the
            model (e.g. ``torch.cuda.OutOfMemoryError``) propagate to the caller.
    """
    # Guard before allocating anything: zero runs would yield NaN mean/std.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if num_warmup < 0:
        raise ValueError(f"num_warmup must be >= 0, got {num_warmup}")

    xc, yc, xt, yt = data_gen(Nc, Nt, B=B, dx=dx, dy=dy, device=device, dtype=dtype)
    batch = _to_batch(xc, yc, xt, yt)
    # Only consulted when use_amp is set.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    with torch.enable_grad():
        for _ in range(num_warmup):
            model.zero_grad(set_to_none=True)
            _fwd_bwd_step(model, batch, use_amp=use_amp, scaler=scaler)
            torch.cuda.synchronize()
            model.zero_grad(set_to_none=True)
    if empty_cache_between_runs:
        torch.cuda.empty_cache()

    timings: List[float] = []
    with torch.enable_grad():
        for _ in range(num_runs):
            timings.append(_time_fwd_bwd_once(model, batch, use_amp=use_amp, scaler=scaler))
            if empty_cache_between_runs:
                # Release cached blocks so the next run starts from a cold allocator.
                torch.cuda.empty_cache()
            torch.cuda.synchronize()

    t_np = np.array(timings, dtype=np.float64)
    torch.cuda.empty_cache()
    return t_np, float(t_np.mean()), float(t_np.std())


def build_models(device: str, dims: ModelDims):
    """Instantiate the three benchmark models on ``device`` in fp32 training mode.

    All three share the hyper-parameters taken from ``dims``; ``dims.n_layers_enc``
    is passed as ``num_layers`` and ``dims.n_layers_dec`` is not read here. TNPD and
    TNP-ND are passed through ``patch_tnp_encoder_with_ace``; TNPA is left unpatched.

    Returns:
        ``(tnpd, tnpa, tnpnd)``.
    """
    fp32 = torch.float32  # keep training in fp32 for stability
    shared = dict(
        dim_x=dims.dx, dim_y=dims.dy,
        d_model=dims.d_model, emb_depth=1, dim_feedforward=dims.d_ff,
        nhead=dims.n_heads, dropout=0.0, num_layers=dims.n_layers_enc,
        bound_std=False, pos_emb_init=False,
    )

    def _place(module):
        return module.to(device=device, dtype=fp32)

    # Construction order is kept TNPD -> TNPA -> TNP-ND (parameter init draws from the RNG).
    tnpd = _place(TNPD(**shared))
    tnpa = _place(TNPA(**shared, permute=False))
    tnpnd = _place(TNPND(**shared, num_std_layers=2))

    # Apply ACE encoder optimization for non-AR models
    tnpd = patch_tnp_encoder_with_ace(tnpd)
    tnpnd = patch_tnp_encoder_with_ace(tnpnd)

    for net in (tnpd, tnpa, tnpnd):
        net.train()
    return tnpd, tnpa, tnpnd


def _empty_results() -> Dict[str, List[Any]]:
    """Return a fresh result table: parallel lists with one entry per grid cell."""
    return {"B": [], "Nc": [], "Nt": [], "mean_time": [], "std_time": [], "all_times": []}


def _bench_cell(table: Dict[str, List[Any]], model: torch.nn.Module, *,
                B: int, Nc: int, Nt: int, runs: int, dx: int, dy: int,
                device: str, dtype: torch.dtype, use_amp: bool) -> None:
    """Benchmark one (B, Nc, Nt) grid cell and append its outcome to ``table``.

    CUDA OOM / RuntimeError failures are recorded with a -1.0 sentinel and an
    empty ``all_times`` list instead of aborting the sweep; GPU memory is
    cleared afterwards in every case.
    """
    try:
        times, mean_t, std_t = benchmark_fwd_bwd(
            model, B=B, Nc=Nc, Nt=Nt, num_runs=runs, dx=dx, dy=dy,
            device=device, dtype=dtype, use_amp=use_amp,
        )
        all_times = times.tolist()
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        # Report the exception type so genuine model errors are not mistaken for OOM skips.
        print(f"   skip due to error ({type(e).__name__}): {str(e)[:120]}")
        # Use -1 sentinel so plotting code can drop negatives consistently
        mean_t, std_t, all_times = -1.0, -1.0, []
    finally:
        clear_gpu_memory()
    table["B"].append(B)
    table["Nc"].append(Nc)
    table["Nt"].append(Nt)
    table["mean_time"].append(mean_t)
    table["std_time"].append(std_t)
    table["all_times"].append(all_times)


def run_grid(device: str, dims: ModelDims, *,
             B_vals: Tuple[int, ...], Nc_vals: Tuple[int, ...], Nt_vals: Tuple[int, ...],
             runs: int, ours_k: int = 16, use_amp: bool = False) -> Dict[str, Dict[str, List[Any]]]:
    """Time every method over the full (B, Nc, Nt) grid.

    The baselines "TNPD", "TNPA" and "TNP-ND" are timed at each Nc as given.
    "Ours" reuses the TNPD model with ``ours_k`` extra context points: it is timed
    at ``Nc + ours_k``, and that effective value is what its ``"Nc"`` list stores.

    Returns:
        Mapping label -> table of parallel lists
        (B, Nc, Nt, mean_time, std_time, all_times); failed cells carry -1.0
        times and an empty ``all_times`` list.
    """
    tnpd, tnpa, tnpnd = build_models(device, dims)
    # (label, model, extra context points added to Nc); "Ours" is TNPD with K extra context.
    methods = [
        ("TNPD", tnpd, 0),
        ("TNPA", tnpa, 0),
        ("TNP-ND", tnpnd, 0),
        ("Ours", tnpd, ours_k),
    ]
    dx, dy = dims.dx, dims.dy
    dtype = torch.float32
    grid = [(B, Nc, Nt) for B in B_vals for Nc in Nc_vals for Nt in Nt_vals]
    total = len(grid)
    print(f"Total experiments per method: {total}")

    results: Dict[str, Dict[str, List[Any]]] = {}
    for label, model, extra_ctx in methods:
        table = results[label] = _empty_results()
        for idx, (B, Nc, Nt) in enumerate(grid, start=1):
            Nc_eff = Nc + extra_ctx
            print(f"[{label}] {idx}/{total}: B={B}, Nc={Nc_eff}, Nt={Nt}")
            _bench_cell(table, model, B=B, Nc=Nc_eff, Nt=Nt, runs=runs, dx=dx, dy=dy,
                        device=device, dtype=dtype, use_amp=use_amp)
        clear_gpu_memory()
    return results


def main() -> None:
    """Command-line entry point: run the forward+backward timing sweep and save JSON.

    Parses the grid / model / precision flags, configures the SDPA backends, runs
    ``run_grid`` on ``cuda`` and writes results plus run metadata to ``--out``.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for run_fwd_bwd_gpu; no GPU detected.")
    configure_torch_env()

    ap = argparse.ArgumentParser(description="GPU forward+backward timing for baselines + ours (TNPD +K ctx)")
    ap.add_argument("--B", type=str, default="32,64,128,256")
    ap.add_argument("--Nc", type=str, default="128,256,512,1024")
    ap.add_argument("--Nt", type=str, default="64,128,256,512")
    ap.add_argument("--runs", type=int, default=10)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--n_heads", type=int, default=4)
    ap.add_argument("--n_layers", type=int, default=6)
    ap.add_argument("--d_ff", type=int, default=256)
    ap.add_argument("--ours_k", type=int, default=16)
    ap.add_argument("--out", type=str, default="outputs/fast_times/fwd_bwd_gpu.json")
    # Precision / SDPA controls
    ap.add_argument("--amp", dest="amp", action="store_true", help="Use autocast fp16 for forward+backward")
    ap.add_argument("--no-amp", dest="amp", action="store_false", help="Disable autocast (fp32)")
    ap.add_argument("--flash", dest="flash", action="store_true", help="Enable SDPA flash kernel when possible")
    ap.add_argument("--no-flash", dest="flash", action="store_false", help="Disable SDPA flash kernel")
    ap.add_argument("--mem_efficient_sdp", action="store_true",
                    help="Enable memory-efficient SDPA kernel (independent of --flash)")
    # Defaults: fp16 on, flash off for comparability, mem-efficient off
    ap.set_defaults(amp=True, flash=False)
    # Accept unknown args (e.g., Jupyter passes "-f <kernel.json>")
    args, _unknown = ap.parse_known_args()

    # SDPA backend controls: math is always enabled; flash and mem-efficient follow
    # --flash / --mem_efficient_sdp (both off by default). CUDA was checked above.
    try:
        torch.backends.cuda.enable_flash_sdp(bool(args.flash))
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(bool(args.mem_efficient_sdp))
    except Exception as e:  # best effort, but not silent: results would not match the metadata
        print(f"Warning: failed to configure SDPA backends ({type(e).__name__}: {e}); "
              "kernel selection may not match the requested flags.")

    def _parse_list(flag: str, s: str) -> Tuple[int, ...]:
        """Parse a comma-separated int list; exit through argparse on bad or empty input."""
        try:
            vals = tuple(int(p) for p in s.split(",") if p.strip())
        except ValueError:
            ap.error(f"{flag} must be a comma-separated list of integers, got {s!r}")
        if not vals:
            ap.error(f"{flag} must contain at least one value")
        return vals

    B_vals = _parse_list("--B", args.B)
    Nc_vals = _parse_list("--Nc", args.Nc)
    Nt_vals = _parse_list("--Nt", args.Nt)
    if args.runs < 1:
        ap.error(f"--runs must be >= 1, got {args.runs}")

    dims = ModelDims(dx=1, dy=1, d_model=args.d_model, n_heads=args.n_heads,
                     n_layers_enc=args.n_layers, n_layers_dec=args.n_layers, d_ff=args.d_ff)
    cfg = BenchConfig(
        grid=BenchGrid(Nc_values=Nc_vals, num_samples_values=(), Nt=max(Nt_vals), num_runs=args.runs),
        dims=dims,
        runtime=Runtime(device="cuda", dtype="auto"),
        out_dir="outputs/fast_times",
        compile_mode="reduce-overhead",
    )

    methods = run_grid("cuda", dims, B_vals=B_vals, Nc_vals=Nc_vals, Nt_vals=Nt_vals, runs=args.runs, ours_k=args.ours_k, use_amp=args.amp)

    meta: Dict[str, Any] = {
        "script": "run_fwd_bwd_gpu",
        "config": cfg.to_dict(),
        "grid": {"B": list(B_vals), "Nc": list(Nc_vals), "Nt": list(Nt_vals)},
        "timing": "forward+backward",
        "ours_k": int(args.ours_k),
        "amp": bool(args.amp),
        "flash": bool(args.flash),
        "mem_efficient_sdp": bool(args.mem_efficient_sdp),
        # Provenance, so timings from different machines / builds are not conflated.
        "torch_version": str(torch.__version__),
        "device_name": torch.cuda.get_device_name(),
    }
    out = {"metadata": meta, "methods": methods}
    # Create the directory of the actual output file; --out may lie outside cfg.out_dir.
    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    _atomic_write_json(args.out, out)
    print(f"Saved: {args.out}")


# Script entry point. main() parses with parse_known_args, so extra arguments
# (e.g. Jupyter's "-f <kernel.json>") do not trigger an argparse error.
if __name__ == "__main__":
    main()
