from __future__ import annotations

import os
import json
import time
import warnings
from typing import Any, Dict, List, Tuple

import numpy as np
import torch


def configure_torch_env() -> None:
    """Require CUDA and apply the notebook runtime toggles for compile/inductor/triton runs.

    In order, this:
      * raises if PyTorch reports no usable CUDA device;
      * sets defaults (never overriding values already present) for the
        ``TORCH_COMPILE_ENABLE_REDUCE_OVERHEAD``, ``TORCHINDUCTOR_CUDAGRAPHS`` and
        ``TORCHINDUCTOR_CACHE_DIR`` environment variables;
      * sets float32 matmul precision to ``"high"``;
      * enables the flash SDP backend and disables the mem-efficient and math ones;
      * raises the TorchDynamo recompile/cache limits;
      * turns CUDA graphs off in the Inductor config.

    The last three groups are best effort. Each TorchDynamo/Inductor knob is
    applied on its own, so a knob the installed PyTorch rejects no longer
    prevents the remaining ones from being applied, and every skipped setting
    is reported with a ``RuntimeWarning`` instead of being dropped silently.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` is false.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is required for fast_times scripts. No GPU detected. "
            "Make sure you're on a GPU machine and CUDA is available to PyTorch."
        )

    def _warn_skipped(what: str, exc: Exception) -> None:
        # stacklevel=3 attributes the warning to the caller of configure_torch_env.
        warnings.warn(
            f"configure_torch_env: skipped {what}: {exc!r}",
            RuntimeWarning,
            stacklevel=3,
        )

    # Prefer lower compile overhead for grid sweeps
    os.environ.setdefault("TORCH_COMPILE_ENABLE_REDUCE_OVERHEAD", "1")
    # Avoid CUDAGraphs capture issues across repeated runs
    os.environ.setdefault("TORCHINDUCTOR_CUDAGRAPHS", "0")
    torch.set_float32_matmul_precision("high")
    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", os.path.expanduser("~/.cache/torchinductor"))

    # The three SDP toggles stay in one group on purpose: if flash cannot be
    # enabled we do not want to go on and disable the fallback backends.
    try:
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception as exc:
        _warn_skipped("SDP backend selection", exc)

    # Lift recompile limits and expand cache like in the notebooks. Applied one
    # by one so an unsupported name does not skip the others.
    for name, value in (
        ("recompile_limit", 10_000),
        ("cache_size_limit", 4096),
        ("fail_on_recompile_limit_hit", False),
    ):
        try:
            setattr(torch._dynamo.config, name, value)
        except Exception as exc:
            _warn_skipped(f"torch._dynamo.config.{name}", exc)

    # Disable Inductor CUDAGraphs to avoid tensor reuse invalidation across warmups/loops
    try:
        import torch._inductor.config as inductor_config
    except Exception as exc:
        _warn_skipped("torch._inductor.config import", exc)
        return

    try:
        inductor_config.cudagraphs = False
    except Exception as exc:
        _warn_skipped("torch._inductor.config.cudagraphs", exc)
    try:
        inductor_config.triton.cudagraphs = False
    except Exception as exc:
        _warn_skipped("torch._inductor.config.triton.cudagraphs", exc)


def _ensure_dir(path: str) -> None:
    """Create the parent directory of ``path`` when it is missing.

    Nothing happens when ``path`` has no directory component or when that
    directory path already exists on disk.
    """
    parent = os.path.dirname(path)
    if not parent:
        # Bare file name: it lives in the current directory, nothing to create.
        return
    if os.path.exists(parent):
        return
    os.makedirs(parent, exist_ok=True)


def _atomic_write_json(path: str, obj: Any) -> None:
    """Serialize ``obj`` as indented JSON to ``path`` via a temporary file.

    The JSON is first written to ``path + ".tmp"`` and then moved over ``path``
    with ``os.replace``, so an existing ``path`` is only overwritten once the
    serialization has fully succeeded.

    If writing or replacing fails, the temporary file is removed before the
    error is re-raised, so a failed call leaves no stray ``.tmp`` file behind
    and leaves any previous ``path`` content untouched.

    Raises:
        TypeError / ValueError: propagated from ``json.dump`` for values that
            cannot be serialized.
        OSError: propagated from opening the temporary file or replacing ``path``.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as f:
            json.dump(obj, f, indent=2)
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup; the temp file may not exist if open() itself failed.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise


def data_gen(Nc: int, Nt: int, *, B: int = 1, dx: int = 1, dy: int = 1,
             device: torch.device | str | None = None,
             dtype: torch.dtype | None = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a small synthetic batch of standard-normal tensors.

    Returns ``(xc, yc, xt, yt)`` with shapes ``[B, Nc, dx]``, ``[B, Nc, dy]``,
    ``[B, Nt, dx]`` and ``[B, Nt, dy]``.

    When ``device`` is not given it is "cuda" if CUDA is available, else "cpu";
    when ``dtype`` is not given it is float16 if CUDA is available, else
    float32. All four tensors are drawn, in the order listed, from one
    generator seeded with 0, so repeated calls with the same arguments use the
    same seed.
    """
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if dtype is None:
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    rng = torch.Generator(device=device)
    rng.manual_seed(0)

    # Draw order matters: the tensors share one generator stream.
    shapes = ((B, Nc, dx), (B, Nc, dy), (B, Nt, dx), (B, Nt, dy))
    xc, yc, xt, yt = (
        torch.randn(*shape, device=device, dtype=dtype, generator=rng)
        for shape in shapes
    )
    return xc, yc, xt, yt


# ---- Benchmark helpers copied to match notebook semantics ----

def clear_gpu_memory() -> None:
    """Run Python garbage collection, then release PyTorch's cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before the cache is emptied; emptying the cache first
    would leave their blocks held by the caching allocator. On machines
    without CUDA only the garbage collection step runs.
    """
    import gc

    gc.collect()
    if torch.cuda.is_available():
        # Let queued kernels finish before handing cached blocks back.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


def check_gpu_memory(required_gb: float | None = None, *, verbose: bool = True) -> bool:
    """Report GPU memory and tell whether at least ``required_gb`` is free.

    Returns ``True`` when CUDA is unavailable, when ``required_gb`` is falsy
    (``None`` or 0), or when the free memory reported by
    ``torch.cuda.mem_get_info`` is not below ``required_gb`` (in units of 1e9
    bytes). Otherwise a warning is emitted and ``False`` is returned. With
    ``verbose`` set, the used/free figures are printed first.
    """
    if not torch.cuda.is_available():
        return True

    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    free_gb = free_bytes / 1e9
    used_gb = torch.cuda.memory_allocated() / 1e9

    if verbose:
        print(f"   GPU Memory: {used_gb:.2f}GB used, {free_gb:.2f}GB free")

    if not required_gb or not (free_gb < required_gb):
        return True

    warnings.warn(
        f"Low GPU memory: {free_gb:.2f}GB available, {required_gb:.2f}GB recommended"
    )
    return False


def benchmark_predictions(model: torch.nn.Module, Nc: int, Nt: int, *, num_samples: int,
                          num_runs: int = 10, method: str = "independent",
                          dx: int = 1, dy: int = 1) -> Tuple[List[torch.Tensor], List[float], float, float]:
    """Time repeated sampling calls of ``model`` on a synthetic batch.

    A batch of size 1 is built with ``data_gen`` on the device and dtype of the
    model's first parameter. For a model without parameters the fallback is
    CUDA/float16 when CUDA is available, else CPU/float32. After three untimed
    warm-up calls, the selected sampling method is called ``num_runs`` times
    under ``torch.no_grad()``.

    The timer follows the device the inputs live on: CUDA events recorded on
    that device's current stream for a CUDA model, ``time.perf_counter`` for a
    CPU model. Previously the choice depended only on whether any GPU was
    present, and events/synchronization always targeted the current CUDA
    device, which mis-times a model placed on a different GPU.

    Args:
        model: module exposing ``predict`` and/or ``sample_joint_predictive``,
            called as ``fn(xc, yc, xt, num_samples=num_samples)``.
        Nc: number of context points in the synthetic batch.
        Nt: number of target points in the synthetic batch.
        num_samples: forwarded to the sampling method.
        num_runs: number of timed calls.
        method: ``"independent"`` uses ``model.predict``; ``"autoregressive"``
            uses ``model.sample_joint_predictive``.
        dx: last dimension of the x tensors.
        dy: last dimension of the y tensors.

    Returns:
        ``(predictions, timings, mean, std)``: the object returned by each
        timed call, the per-call durations in seconds, and the mean and
        standard deviation of those durations.

    Raises:
        ValueError: if ``method`` is neither of the two supported names.
    """
    # Validate first so an invalid call does not allocate the synthetic batch.
    if method == "independent":
        _sample = model.predict
    elif method == "autoregressive":
        _sample = model.sample_joint_predictive
    else:
        raise ValueError("method must be 'independent' or 'autoregressive'")

    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
        dtype = first_param.dtype
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    on_cuda = device.type == "cuda"

    B = 1
    xc, yc, xt, _ = data_gen(Nc, Nt, B=B, dx=dx, dy=dy, device=device, dtype=dtype)

    predictions: List[torch.Tensor] = []
    timings: List[float] = []

    print("Warming up...")
    for _ in range(3):
        with torch.no_grad():
            _ = _sample(xc, yc, xt, num_samples=num_samples)

    timer_name = "CUDA events" if on_cuda else "time.perf_counter"
    print(f"Running {num_runs} timed iterations with {timer_name}")
    for i in range(num_runs):
        if on_cuda:
            if i % 5 == 0:
                torch.cuda.empty_cache()
            # Record on the stream of the device that actually runs the model.
            stream = torch.cuda.current_stream(device)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record(stream)
            with torch.no_grad():
                pred = _sample(xc, yc, xt, num_samples=num_samples)
            end_event.record(stream)
            # elapsed_time requires both events to have completed.
            torch.cuda.synchronize(device)
            t_sec = start_event.elapsed_time(end_event) / 1000.0
        else:
            t0 = time.perf_counter()
            with torch.no_grad():
                pred = _sample(xc, yc, xt, num_samples=num_samples)
            t_sec = time.perf_counter() - t0
        predictions.append(pred)
        timings.append(t_sec)

    timings_np = np.array(timings)
    print(f"  Mean time: {timings_np.mean():.4f}s ± {timings_np.std():.4f}s")
    print(f"  Median time: {np.median(timings_np):.4f}s")
    print(f"  Min time: {timings_np.min():.4f}s")
    print(f"  Max time: {timings_np.max():.4f}s")
    print(f"  Total time: {timings_np.sum():.4f}s")
    print(f"  Throughput: {num_runs/timings_np.sum():.2f} predictions/second")

    # Match notebook return: timings mean & std as the last two values
    return predictions, timings, float(timings_np.mean()), float(timings_np.std())


def _record_failed_experiment(results: Dict[str, Any], Nc: int, num_samples: int, status: str) -> None:
    """Append a NaN-filled row to ``results`` for an experiment that produced no timings."""
    row = {
        "Nc": Nc, "num_samples": num_samples, "mean_time": np.nan, "std_time": np.nan,
        "pred_mean": np.nan, "pred_std": np.nan, "throughput": np.nan, "all_times": [],
        "status": status,
    }
    for key, value in row.items():
        results[key].append(value)


def run_benchmark_grid(model: torch.nn.Module, *, Nt: int = 16, num_runs: int = 30,
                       method: str = "independent", memory_threshold_gb: float = 2.0,
                       Nc_values: List[int] | Tuple[int, ...] = (32, 64, 128, 256, 512, 1024),
                       num_samples_values: List[int] | Tuple[int, ...] = (128, 256, 512, 1024),
                       dx: int = 1, dy: int = 1) -> Dict[str, Any]:
    """Run ``benchmark_predictions`` over every ``(Nc, num_samples)`` pair.

    For each pair the free GPU memory is checked against
    ``memory_threshold_gb`` (with one cache-clearing retry) before the
    benchmark runs. Out-of-memory and other ``RuntimeError`` failures are
    recorded as NaN rows instead of aborting the sweep.

    Returns:
        A dict of parallel lists with one entry per experiment, under the keys
        ``Nc``, ``num_samples``, ``mean_time``, ``std_time``, ``pred_mean``,
        ``pred_std``, ``throughput``, ``all_times`` and ``status``. ``status``
        is one of ``"success"``, ``"skipped_oom"``, ``"insufficient_memory"``,
        ``"oom_error"`` or ``"runtime_error"``.

    NOTE: ``pred_mean`` / ``pred_std`` store the third and fourth values
    returned by ``benchmark_predictions``, which are the mean and standard
    deviation of the timings, not statistics of the predicted values. The key
    names are kept as they are for compatibility with existing result files.
    """
    results: Dict[str, Any] = {
        "Nc": [],
        "num_samples": [],
        "mean_time": [],
        "std_time": [],
        "pred_mean": [],
        "pred_std": [],
        "throughput": [],
        "all_times": [],
        "status": [],
    }

    total_experiments = len(Nc_values) * len(num_samples_values)
    experiment_counter = 0
    oom_configs = set()
    print(f"Running {total_experiments} experiments ({len(Nc_values)} Nc x {len(num_samples_values)} num_samples)")
    print(f"Memory threshold: {memory_threshold_gb:.1f}GB")
    print("=" * 60)

    for Nc in Nc_values:
        for num_samples in num_samples_values:
            experiment_counter += 1
            print(f"\nExperiment {experiment_counter}/{total_experiments}")
            print(f"Nc={Nc}, num_samples={num_samples}")
            print("-" * 40)

            # Only reachable when the value lists contain a repeated pair.
            if (Nc, num_samples) in oom_configs:
                print("Skipping - previously caused OOM")
                _record_failed_experiment(results, Nc, num_samples, "skipped_oom")
                continue

            if not check_gpu_memory(memory_threshold_gb):
                clear_gpu_memory()
                if not check_gpu_memory(memory_threshold_gb, verbose=False):
                    print("Insufficient memory - skipping")
                    _record_failed_experiment(results, Nc, num_samples, "insufficient_memory")
                    continue

            error_type = None
            try:
                predictions, times, pred_mean, pred_std = benchmark_predictions(
                    model, Nc=Nc, Nt=Nt, num_samples=num_samples, num_runs=num_runs, method=method, dx=dx, dy=dy
                )
                # The raw samples are not used here. Drop them now so they do
                # not stay allocated through the next experiment's memory
                # check and benchmark run.
                del predictions
                times_np = np.array(times)
                mean_time = times_np.mean()
                std_time = times_np.std()
                throughput = num_runs / times_np.sum()
                pkg = {
                    "Nc": Nc,
                    "num_samples": num_samples,
                    "mean_time": mean_time,
                    "std_time": std_time,
                    "pred_mean": float(pred_mean) if torch.is_tensor(pred_mean) else pred_mean,
                    "pred_std": float(pred_std) if torch.is_tensor(pred_std) else pred_std,
                    "throughput": throughput,
                    "all_times": times,
                    "status": "success",
                }
                for k, v in pkg.items():
                    results[k].append(v)
                print(f"Time: {mean_time:.4f}s +/- {std_time:.4f}s")
                print(f"Throughput: {throughput:.2f} pred/sec")
            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                if isinstance(e, torch.cuda.OutOfMemoryError) or "out of memory" in str(e).lower():
                    print("GPU OOM Error - recording NaN")
                    oom_configs.add((Nc, num_samples))
                    error_type = "oom_error"
                else:
                    print(f"Runtime Error: {str(e)[:100]}")
                    error_type = "runtime_error"

            if error_type is not None:
                # Recovery runs after the except clause has ended: while the
                # exception is alive its traceback keeps the failed call's
                # frames (and their tensors) referenced, so clearing memory
                # inside the handler could not release them.
                _record_failed_experiment(results, Nc, num_samples, error_type)
                clear_gpu_memory()

    print("\n" + "=" * 60)
    print("Benchmark Summary:")
    print("-" * 60)
    status_counts: Dict[str, int] = {}
    for status in results["status"]:
        status_counts[status] = status_counts.get(status, 0) + 1
    print(f"Total experiments: {total_experiments}")
    for status, count in status_counts.items():
        print(f"  {status}: {count}")
    return results


def package_for_plot(methods: Dict[str, Dict[str, Any]], *, meta: Dict[str, Any]) -> Dict[str, Any]:
    """Build the plotting payload ``{"metadata": ..., "methods": {label: {...}}}``.

    For every label only the ``Nc``, ``num_samples``, ``mean_time`` and
    ``std_time`` entries are kept; an entry missing from the input becomes an
    empty list. ``meta`` is shallow-copied into ``"metadata"``.
    """
    plot_keys = ("Nc", "num_samples", "mean_time", "std_time")
    trimmed: Dict[str, Dict[str, Any]] = {
        label: {key: series.get(key, []) for key in plot_keys}
        for label, series in methods.items()
    }
    return {"metadata": dict(meta), "methods": trimmed}
