from __future__ import annotations

import os
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Sequence, Tuple

import torch


@dataclass
class BenchGrid:
    """Sweep values and repetition settings for a benchmark grid.

    The two ``*_values`` fields default to tuples, so the defaults are
    immutable and can safely be shared between instances. They are annotated
    as ``Sequence[int]`` because both the tuple defaults and a caller-supplied
    list are valid values.
    """

    # Values of Nc to sweep over.
    # NOTE(review): the meaning of "Nc" is not defined in this file — confirm
    # against the benchmark scripts.
    Nc_values: Sequence[int] = (32, 64, 128, 256, 512, 1024)
    # Default matches notebooks: batch sizes used in grids
    num_samples_values: Sequence[int] = (128, 256, 512, 1024)
    # NOTE(review): the meaning of "Nt" is not defined in this file — confirm.
    Nt: int = 16
    # presumably the number of repetitions per grid point; verify against the
    # code that consumes it.
    num_runs: int = 10


@dataclass
class ModelDims:
    """Integer size hyperparameters for the benchmarked model, with defaults.

    NOTE(review): the field names suggest a transformer encoder/decoder; the
    per-field notes below are read off the names only — confirm against the
    model that consumes this config.
    """

    dx: int = 1  # presumably the input (x) dimension — TODO confirm
    dy: int = 1  # presumably the output (y) dimension — TODO confirm
    d_model: int = 128  # presumably the model/embedding width
    n_heads: int = 4  # presumably the attention head count
    n_layers_enc: int = 6  # presumably the number of encoder layers
    n_layers_dec: int = 6  # presumably the number of decoder layers
    d_ff: int = 256  # presumably the feed-forward hidden width


def _require_cuda_default() -> str:
    """Choose the default device string without silently falling back to CPU.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true. Otherwise
        ``"cpu"``, but only if the ``FAST_TIMES_CPU_OK`` environment variable
        is set to ``1``, ``true`` or ``yes`` (case-insensitive, surrounding
        whitespace ignored).

    Raises:
        RuntimeError: if CUDA is unavailable and the CPU override is not set.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Allow explicit CPU runs for toy smoke tests when env is set. The value
    # is stripped so that stray whitespace (e.g. from a shell export or a
    # .env file) does not defeat the opt-in.
    raw_flag = os.environ.get("FAST_TIMES_CPU_OK", "")
    allow_cpu = raw_flag.strip().lower() in {"1", "true", "yes"}
    if allow_cpu:
        return "cpu"
    raise RuntimeError(
        "CUDA is required for fast_times scripts. No GPU detected. "
        "Ensure you're on a GPU node and PyTorch sees CUDA. "
        "Set FAST_TIMES_CPU_OK=1 to allow a CPU run for smoke tests."
    )


@dataclass
class Runtime:
    """Device and precision selection for a benchmark run."""

    # When not passed explicitly, the device is resolved per instance by
    # calling _require_cuda_default() at construction time; that call may
    # raise RuntimeError.
    device: str = field(default_factory=_require_cuda_default)
    # "auto" -> fp16 on CUDA else fp32. The resolution of "auto" is not
    # performed in this class; it is left to the code that consumes the config.
    dtype: str = "auto"


@dataclass
class BenchConfig:
    """Top-level benchmark configuration bundling grid, model and runtime.

    Each nested section is built by its own dataclass constructor when it is
    not supplied, so every ``BenchConfig`` gets fresh section instances.
    """

    grid: BenchGrid = field(default_factory=BenchGrid)
    dims: ModelDims = field(default_factory=ModelDims)
    runtime: Runtime = field(default_factory=Runtime)
    out_dir: str = "outputs/fast_times"  # output directory for JSONs
    compile_mode: str = "reduce-overhead"  # torch.compile mode

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain nested dictionary.

        The three dataclass sections are converted with
        ``dataclasses.asdict``; the scalar settings are copied as-is. Keys
        appear in the order grid, dims, runtime, out_dir, compile_mode.
        """
        payload: Dict[str, Any] = {
            section: asdict(getattr(self, section))
            for section in ("grid", "dims", "runtime")
        }
        payload["out_dir"] = self.out_dir
        payload["compile_mode"] = self.compile_mode
        return payload


# Module-level default configuration, built once at import time. Constructing
# it runs Runtime's device default factory, so importing this module raises
# RuntimeError when CUDA is unavailable and FAST_TIMES_CPU_OK is not set.
DEFAULT = BenchConfig()
