from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Any, Sequence, Optional, Union

import numpy as np
import pandas as pd

try:
    from fun.toy_fun import BONormalizer
except Exception:
    BONormalizer = None

def _here(*paths: str) -> str:
    """Join *paths* onto the absolute directory that contains this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, *paths)

# Default dataset: "cof.csv" located next to this module.
DEFAULT_CSV_PATH = _here("cof.csv")

# Input parameter names. The order is fixed and every entry must match a
# CSV column name.
PARAM_NAMES = [
    "pore_diameter", "void_fraction", "surface_area", "crystal_density",
    "B", "O", "C", "H", "Si", "N", "S", "P",
    "halogens", "metals",
]
D = len(PARAM_NAMES)  # input dimensionality (14)
METRIC_KEY = "gcmc_y"  # CSV target column name

# ------------------------ Regular-grid N-dimensional linear interpolation ------------------------
class RegularGridNDInterpolator:
    """
    Generic N-dimensional regular-grid multilinear interpolation.

    - grids: list[np.ndarray], each strictly increasing, length >= 2
    - values: np.ndarray, shape = tuple(len(g) for g in grids)

    Query points outside the grid are clamped to the nearest boundary along
    each axis (no extrapolation). A NaN query coordinate produces a NaN
    result instead of an error.

    Each evaluation visits all 2**ndim corners of the enclosing cell, so the
    cost per query grows exponentially with the number of dimensions.

    Raises:
        ValueError: if ``values`` does not have the grid shape, if any grid
            has fewer than 2 points, or if any grid is not strictly
            increasing (a grid containing NaN is rejected as well).
    """
    def __init__(self, grids: List[np.ndarray], values: np.ndarray):
        self.grids = [np.asarray(g, dtype=float) for g in grids]
        self.values = np.asarray(values, dtype=float)
        self.ndim = len(self.grids)
        shape = tuple(len(g) for g in self.grids)
        if self.values.shape != shape:
            raise ValueError(f"values shape {self.values.shape} != grid shape {shape}")
        if any(len(g) < 2 for g in self.grids):
            raise ValueError("Each grid must have at least 2 points")
        # "not all(diff > 0)" instead of "any(diff <= 0)": comparisons with
        # NaN are always False, so only this form rejects NaN grid points.
        if any(not np.all(np.diff(g) > 0) for g in self.grids):
            raise ValueError("Each grid must be strictly increasing")
        self.bounds = [(g[0], g[-1]) for g in self.grids]

    def _locate(self, g: np.ndarray, x: float) -> Tuple[int, int, float]:
        """
        Locate x within 1D grid g.

        Returns (i0, i1, t): the indices of the bracketing grid points and
        the fractional position t in [0, 1] between them. x outside the grid
        is clamped to the first/last interval. For NaN x, t is NaN so that
        the interpolated value becomes NaN.
        """
        if np.isnan(x):
            # searchsorted places NaN after the last element, which would
            # give an out-of-range index; propagate NaN through the weight.
            return 0, 1, float("nan")
        if x <= g[0]:
            return 0, 1, 0.0
        if x >= g[-1]:
            return len(g) - 2, len(g) - 1, 1.0
        j = int(np.searchsorted(g, x, side="right"))  # g[j-1] <= x < g[j]
        i0, i1 = j - 1, j
        t = (x - g[i0]) / (g[i1] - g[i0])
        return i0, i1, float(t)

    def __call__(self, x: Sequence[float]) -> float:
        """
        Interpolate at a single point.

        Args:
            x: one coordinate per grid axis, in grid order.

        Returns:
            The multilinearly interpolated value as a float.

        Raises:
            ValueError: if len(x) differs from the number of grid axes.
        """
        if len(x) != self.ndim:
            raise ValueError(f"Dim mismatch: got {len(x)} but ndim={self.ndim}")
        i0s, i1s, ts = [], [], []
        for xi, g in zip(x, self.grids):
            i0, i1, t = self._locate(g, float(xi))
            i0s.append(i0)
            i1s.append(i1)
            ts.append(t)

        acc = 0.0
        # Bit d of `mask` chooses the upper (1) or lower (0) neighbour on axis d.
        for mask in range(1 << self.ndim):
            w = 1.0
            idx = []
            for d in range(self.ndim):
                if (mask >> d) & 1:
                    w *= ts[d]
                    idx.append(i1s[d])
                else:
                    w *= (1.0 - ts[d])
                    idx.append(i0s[d])
            acc += w * self.values[tuple(idx)]
        return float(acc)

    def batch(self, X: np.ndarray) -> np.ndarray:
        """
        Interpolate at every row of X.

        Args:
            X: array-like of shape (N, ndim).

        Returns:
            Float array of shape (N,) with one interpolated value per row.

        Raises:
            ValueError: if X is not 2-D with ndim columns.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2 or X.shape[1] != self.ndim:
            raise ValueError(f"X must be (N,{self.ndim})")
        return np.array([self(xi) for xi in X], dtype=float)


# ------------------------ Evaluator: interpolate gcmc_y from CSV ------------------------
@dataclass
class COFEvaluator:
    """
    Evaluator mapping continuous 14D inputs to scalar gcmc_y.

    Logic:
      1) Try to map the CSV data to a regular grid; if complete, use
         RegularGridNDInterpolator (multilinear interpolation).
      2) If the grid is incomplete or irregular (including the case where
         some parameter takes a single value only), fall back to scatter
         interpolation (kNN + IDW) in normalized [0,1]^14 space.

    Rows whose parameter or target columns are not numeric are dropped
    before anything else is computed.

    Public interfaces:
      - evaluate_from_dict(params)
      - evaluate_from_list(x)
      - evaluate(x)  # supports single point or batch
      - query(x)     # alias of evaluate_from_list

    Raises (on construction):
        ValueError: if a required column is missing from the CSV, or if no
            row has numeric values in all required columns.
    """
    csv_path: str = DEFAULT_CSV_PATH
    knn_k: int = 12                 # number of neighbors for kNN (scatter mode)
    idw_power: float = 2.0          # inverse distance weighting power (scatter mode)
    eps: float = 1e-12              # avoid division by zero
    bounds: List[Tuple[float, float]] = field(init=False)

    # Derived array-valued fields are excluded from the generated __repr__ and
    # __eq__: comparing numpy arrays inside the dataclass tuple comparison
    # raises, and printing them floods the output.

    # Regular-grid mode
    use_grid: bool = field(init=False, default=False)
    levels: Optional[List[np.ndarray]] = field(init=False, default=None, repr=False, compare=False)
    values: Optional[np.ndarray] = field(init=False, default=None, repr=False, compare=False)
    interp: Optional[RegularGridNDInterpolator] = field(init=False, default=None, repr=False, compare=False)

    # Scatter-mode cache
    X_scaled: Optional[np.ndarray] = field(init=False, default=None, repr=False, compare=False)
    y: Optional[np.ndarray] = field(init=False, default=None, repr=False, compare=False)
    lo: Optional[np.ndarray] = field(init=False, default=None, repr=False, compare=False)
    span: Optional[np.ndarray] = field(init=False, default=None, repr=False, compare=False)

    # Largest dense grid (in cells) that will ever be allocated. Deliberately
    # left without an annotation so the dataclass does not treat it as a field.
    _MAX_GRID_CELLS = 1e8

    def __post_init__(self):
        """Load the CSV, compute per-parameter bounds, and pick grid or scatter mode."""
        df = pd.read_csv(self.csv_path)

        use_cols = PARAM_NAMES + [METRIC_KEY]
        miss = [c for c in use_cols if c not in df.columns]
        if miss:
            raise ValueError(f"CSV missing required columns: {miss}")

        for c in use_cols:
            df[c] = pd.to_numeric(df[c], errors="coerce")
        df = df.dropna(subset=use_cols).reset_index(drop=True)
        if df.empty:
            # Without this check the min/max below fails with an opaque
            # "zero-size array" error.
            raise ValueError(
                f"CSV has no rows with numeric values in all of: {use_cols}"
            )

        X = df[PARAM_NAMES].to_numpy(dtype=float)
        y = df[METRIC_KEY].to_numpy(dtype=float)

        self.bounds = []
        for lo_c, hi_c in zip(X.min(axis=0).tolist(), X.max(axis=0).tolist()):
            if hi_c == lo_c:
                hi_c = lo_c + 1e-12  # keep the interval non-degenerate
            self.bounds.append((lo_c, hi_c))

        # Sorted distinct values of each parameter: the candidate grid axes.
        levels = [np.unique(X[:, d]) for d in range(X.shape[1])]

        grid_values = self._build_grid_values(X, y, levels)
        if grid_values is not None:
            self.use_grid = True
            self.levels = levels
            self.values = grid_values
            self.interp = RegularGridNDInterpolator(self.levels, self.values)
            return

        self.use_grid = False
        lo = np.array([b[0] for b in self.bounds], dtype=float)
        hi = np.array([b[1] for b in self.bounds], dtype=float)
        span = np.maximum(hi - lo, 1e-12)

        self.X_scaled = (X - lo) / span
        self.y = y
        self.lo = lo
        self.span = span

    def _build_grid_values(
        self, X: np.ndarray, y: np.ndarray, levels: List[np.ndarray]
    ) -> Optional[np.ndarray]:
        """
        Return the dense value array if the rows form a complete regular grid.

        Returns None when grid mode is not usable: an axis has fewer than two
        levels, the grid would exceed _MAX_GRID_CELLS, there are fewer rows
        than cells, allocation fails, or some cell is left unfilled.
        When several rows map to the same cell, the last row wins.
        """
        shape = tuple(len(g) for g in levels)

        # RegularGridNDInterpolator requires >= 2 points per axis; a constant
        # column must therefore go through the scatter fallback.
        if any(n < 2 for n in shape):
            return None

        # Each row fills exactly one cell, so a grid with more cells than rows
        # can never be complete; bail out before allocating anything.
        limit = min(float(X.shape[0]), float(self._MAX_GRID_CELLS))
        total = 1.0
        for n in shape:
            total *= float(n)
            if total > limit:
                return None

        try:
            arr = np.empty(shape, dtype=float)
            filled = np.zeros(shape, dtype=bool)
        except (ValueError, MemoryError):
            return None

        # Every value in column d is an element of levels[d] (its sorted unique
        # values), so searchsorted returns the exact level index.
        cell_index = np.column_stack(
            [np.searchsorted(g, X[:, d]) for d, g in enumerate(levels)]
        )
        for idx, val in zip(cell_index.tolist(), y.tolist()):
            key = tuple(idx)
            arr[key] = val
            filled[key] = True

        if not filled.all():
            return None
        return arr

    def _idw_single(self, x: Sequence[float]) -> float:
        """
        Scatter-mode estimate at one point (kNN + inverse-distance weighting).

        The point is normalized with the stored lo/span, Euclidean distances
        to all data rows are computed, and the knn_k nearest rows are averaged
        with weights 1 / (distance + eps) ** idw_power. If the nearest row is
        within eps, its target value is returned directly.
        """
        x = np.asarray(x, dtype=float)
        xs = (x - self.lo) / self.span
        diff = self.X_scaled - xs[None, :]
        dist = np.sqrt(np.sum(diff * diff, axis=1))  # (N,)
        j = np.argmin(dist)
        if dist[j] <= self.eps:
            return float(self.y[j])
        k = min(self.knn_k, dist.shape[0])
        # Stable sort: equidistant rows are taken in data order, so the chosen
        # neighbours do not depend on the sort implementation.
        idx = np.argsort(dist, kind="stable")[:k]
        d = dist[idx]
        w = 1.0 / np.power(d + self.eps, self.idw_power)
        w = w / np.sum(w)
        return float(np.sum(w * self.y[idx]))

    def _idw_batch(self, X: np.ndarray) -> np.ndarray:
        """Apply _idw_single to every row of X and return the results as a float array."""
        return np.array([self._idw_single(x) for x in X], dtype=float)

    def _pack_x_from_dict(self, params: Dict[str, Any]) -> List[float]:
        """Extract the parameters from a mapping as floats, in PARAM_NAMES order.

        Raises:
            ValueError: if a key is missing or a value is not float-like.
        """
        try:
            return [float(params[k]) for k in PARAM_NAMES]
        except Exception as e:
            raise ValueError(f"params must have keys {PARAM_NAMES} (float-like).") from e

    def evaluate_from_dict(self, params: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate one point given as a mapping keyed by PARAM_NAMES."""
        x = self._pack_x_from_dict(params)
        return self.evaluate_from_list(x)

    def evaluate_from_list(self, x: List[Any]) -> Dict[str, float]:
        """
        Evaluate one point given as a sequence in PARAM_NAMES order.

        Returns:
            {METRIC_KEY: value}

        Raises:
            ValueError: if x does not have length D.
        """
        if len(x) != D:
            raise ValueError(f"x must have length {D} in order {PARAM_NAMES}")
        xx = [float(v) for v in x]
        if self.use_grid:
            y = self.interp(xx)
        else:
            y = self._idw_single(xx)
        return {METRIC_KEY: float(y)}

    def evaluate(self, x: Union[Sequence[float], Sequence[Sequence[float]]]) -> Union[Dict[str, float], List[Dict[str, float]]]:
        """
        Evaluate a single point or a batch of points.

        A batch is a 2-D numpy array, or a non-empty list/tuple whose first
        element is itself a list, tuple, or numpy array; it must have shape
        (N, D). Anything else is treated as a single point of length D.

        Returns:
            A list of {METRIC_KEY: value} dicts for a batch, otherwise a
            single {METRIC_KEY: value} dict.

        Raises:
            ValueError: if a batch does not have shape (N, D) or a single
                point does not have length D.
        """
        if isinstance(x, np.ndarray):
            is_batch = x.ndim >= 2
        else:
            is_batch = (
                isinstance(x, (list, tuple))
                and len(x) > 0
                and (
                    isinstance(x[0], (list, tuple))
                    or (isinstance(x[0], np.ndarray) and x[0].ndim >= 1)
                )
            )

        if is_batch:
            X = np.asarray(x, dtype=float)
            # Validate here: in scatter mode a row of width 1 would otherwise
            # broadcast silently against the stored per-parameter arrays.
            if X.ndim != 2 or X.shape[1] != D:
                raise ValueError(f"batch x must have shape (N, {D}); got {X.shape}")
            if self.use_grid:
                ys = self.interp.batch(X)
            else:
                ys = self._idw_batch(X)
            return [{METRIC_KEY: float(v)} for v in ys]

        return self.evaluate_from_list([float(v) for v in x])

    def query(self, x: List[Any]) -> Dict[str, float]:
        """Alias of evaluate_from_list for compatibility."""
        return self.evaluate_from_list(x)


def make_bo_normalizer_cof(
    csv_path: Optional[str] = None,
    bounds_override: Optional[List[Tuple[float, float]]] = None
) -> BONormalizer:
    """Create a BONormalizer for COF.

    Args:
        csv_path: CSV to derive bounds from; DEFAULT_CSV_PATH when omitted.
            Ignored if ``bounds_override`` is given.
        bounds_override: explicit (lo, hi) pairs, one per entry of
            PARAM_NAMES. When given, the CSV is not read.

    Returns:
        ``BONormalizer(bounds)`` where ``bounds`` is either the override or
        the per-parameter (min, max) over the usable CSV rows. Usable rows
        are those with numeric values in every parameter column and in the
        target column, which is the same row filter COFEvaluator applies
        before computing its own bounds.

    Raises:
        RuntimeError: if BONormalizer could not be imported.
        ValueError: if ``bounds_override`` has the wrong length, a required
            column is missing, or the CSV has no usable rows.
    """
    if BONormalizer is None:
        raise RuntimeError("BONormalizer not found. Ensure fun/toy_fun.py is importable.")

    if bounds_override is not None:
        if len(bounds_override) != D:
            raise ValueError(f"bounds_override must have length {D}")
        return BONormalizer(bounds_override)

    csv_path = csv_path or DEFAULT_CSV_PATH
    df = pd.read_csv(csv_path)
    use_cols = PARAM_NAMES + [METRIC_KEY]
    for c in use_cols:
        if c not in df.columns:
            raise ValueError(f"CSV missing column: {c}")

    # Coerce and filter whole rows (not column by column) so the bounds cover
    # exactly the rows that survive COFEvaluator's cleaning.
    clean = pd.DataFrame(
        {c: pd.to_numeric(df[c], errors="coerce") for c in use_cols}
    ).dropna()
    if clean.empty:
        # min()/max() of an empty column would yield NaN bounds silently.
        raise ValueError(
            f"CSV has no rows with numeric values in all of: {use_cols}"
        )

    bounds = []
    for c in PARAM_NAMES:
        lo, hi = float(clean[c].min()), float(clean[c].max())
        if hi == lo:
            hi = lo + 1e-12  # keep the interval non-degenerate
        bounds.append((lo, hi))
    return BONormalizer(bounds)


if __name__ == "__main__":
    # Running this file as a script does nothing: the guard body is empty.
    pass

