"""Nyström-approximated CARs (Concept Activation Regions).

This module provides a lightweight representation of kernel CARs using a
Nyström feature map.

Motivation
----------
Full kernel SVM CARs (e.g. RBF `sklearn.svm.SVC`) are expensive to store and
slow to evaluate:

* **Storage**: the pickled estimator contains support vectors of shape
  ``[n_sv, d]``. For flattened conv activations, ``d`` can be 100k–400k, so a
  single CAR can take gigabytes on disk.
* **Variance in RKHS**: comparing CAR weights in the implicit RKHS requires
  kernel evaluations between support vectors.

With a Nyström approximation, we obtain an *explicit* feature map
``Phi(x) in R^m`` (e.g. ``m=200``). Then:

* We store only an explicit weight vector ``beta in R^m`` per CAR.
* Variance can be computed directly as ``tr(Cov(beta))``.
* Prediction is O(m) per sample (after the feature map).

The intended usage is:
1) Create (and cache) a shared Nyström basis per layer/concept.
2) Precompute Nyström features for the positive/negative pools once.
3) Train many linear SVMs in the m-dimensional feature space.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Literal, Optional, Union

import numpy as np

from sklearn.kernel_approximation import Nystroem
from sklearn.metrics.pairwise import rbf_kernel


# Accepted ways to specify the RBF bandwidth: an explicit numeric value, or one
# of the sklearn-style heuristic names "scale" / "auto", which
# `_resolve_rbf_gamma` turns into a concrete float from the data.
GammaSpec = Union[float, Literal["scale", "auto"]]


def _resolve_rbf_gamma(
    gamma: GammaSpec,
    X: np.ndarray,
    *,
    random_state: int = 0,
    max_samples: int = 2048,
) -> float:
    """Resolve sklearn-style gamma for RBF.

    * ``'auto'``: 1 / d
    * ``'scale'``: 1 / (d * var(X))

    For very large pools we estimate ``var(X)`` on a subsample to keep this
    step cheap and deterministic.
    """
    if isinstance(gamma, (int, float, np.floating)):
        return float(gamma)

    g = str(gamma).lower()
    d = int(X.shape[1])
    if d <= 0:
        raise ValueError("X must have at least 1 feature")

    if g == "auto":
        return 1.0 / d

    if g != "scale":
        raise ValueError(f"Unsupported gamma spec: {gamma!r}")

    # Estimate variance on a (deterministic) subset of rows.
    Xf = np.asarray(X, dtype=np.float32)
    n = int(Xf.shape[0])
    if n == 0:
        raise ValueError("Cannot resolve gamma on empty X")

    if n > max_samples:
        rng = np.random.default_rng(int(random_state))
        idx = rng.choice(n, size=int(max_samples), replace=False)
        Xf = Xf[idx]

    var = float(np.var(Xf))
    if not np.isfinite(var) or var <= 0:
        # Fallback to 'auto' if the sample variance is degenerate.
        return 1.0 / d
    return 1.0 / (d * var)


@dataclass
class NystromBasis:
    """A saved Nyström basis for an RBF kernel.

    Holds the landmark points, the normalization matrix and the RBF bandwidth,
    which together define the explicit feature map
    ``Phi(X) = rbf_kernel(X, components) @ normalization``.
    """

    components: np.ndarray  # [m, d]
    normalization: np.ndarray  # [m, m]
    gamma: float
    kernel: str = "rbf"

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Map raw representations ``X`` to Nyström features ``Phi(X)``.

        ``X`` and the stored components are cast to float32 before the kernel
        is evaluated, and the result is returned as float32.

        Raises:
            NotImplementedError: if ``self.kernel`` is not ``"rbf"``.
        """
        if self.kernel != "rbf":
            raise NotImplementedError(f"Only RBF kernel supported, got {self.kernel!r}")

        Xf = np.asarray(X, dtype=np.float32, order="C")
        comps = np.asarray(self.components, dtype=np.float32, order="C")
        K_nm = rbf_kernel(Xf, comps, gamma=float(self.gamma))  # float32 if inputs are float32
        # NOTE(review): the normalization is applied without a transpose. This
        # is kept as-is because stored classifier weights depend on the exact
        # feature values produced here.
        return (K_nm @ np.asarray(self.normalization, dtype=np.float32)).astype(np.float32, copy=False)

    def save(self, path: str | Path, *, components_dtype: np.dtype = np.float16) -> None:
        """Save basis to a compressed ``.npz``.

        `components_dtype` defaults to float16 to reduce disk usage. The
        computation always casts to float32 during transform.

        Parent directories are created if missing. ``gamma`` is stored as
        float64 so that a loaded basis uses exactly the same bandwidth as the
        in-memory one. NumPy appends ``.npz`` to ``path`` when the name does
        not already end with it; `load` accounts for that.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            path,
            components=np.asarray(self.components).astype(components_dtype, copy=False),
            normalization=np.asarray(self.normalization).astype(np.float32, copy=False),
            # float64 round-trips a Python float exactly; float32 would not.
            gamma=np.asarray(float(self.gamma), dtype=np.float64),
            kernel=np.asarray(str(self.kernel)),
        )

    @classmethod
    def load(cls, path: str | Path) -> "NystromBasis":
        """Load a basis previously written by `save`.

        If ``path`` does not exist and does not end in ``.npz``, the same name
        with ``.npz`` appended is tried, mirroring the suffix NumPy adds when
        saving. Components keep their on-disk dtype; the normalization is
        returned as float32. Arrays are read with ``allow_pickle=False``.
        """
        path = Path(path)
        if not path.name.endswith(".npz") and not path.exists():
            with_suffix = path.with_name(path.name + ".npz")
            if with_suffix.exists():
                path = with_suffix
        with np.load(path, allow_pickle=False) as data:
            components = data["components"]
            normalization = data["normalization"].astype(np.float32, copy=False)
            gamma = float(data["gamma"])
            kernel = str(data["kernel"].item() if hasattr(data["kernel"], "item") else data["kernel"])
        return cls(components=components, normalization=normalization, gamma=gamma, kernel=kernel)


def build_nystrom_basis(
    X_pool: np.ndarray,
    *,
    n_components: int = 200,
    gamma: GammaSpec = "scale",
    random_state: int = 0,
    max_gamma_samples: int = 2048,
    components_dtype: np.dtype = np.float16,
) -> NystromBasis:
    """Fit an RBF Nyström approximation on ``X_pool`` and package it as a basis.

    The pool is cast to float32. The number of components is capped at the
    number of pool rows, and ``gamma`` is resolved to a concrete float with
    `_resolve_rbf_gamma` (using ``random_state`` and ``max_gamma_samples``)
    before fitting. The fitted components are stored as ``components_dtype``
    and the normalization matrix as float32.

    Raises:
        ValueError: if ``X_pool`` has zero rows.
    """
    pool = np.asarray(X_pool, dtype=np.float32, order="C")
    n_rows = int(pool.shape[0])
    if not n_rows:
        raise ValueError("X_pool is empty")

    # Cannot select more landmark points than there are rows in the pool.
    rank = int(min(n_components, n_rows))
    bandwidth = _resolve_rbf_gamma(
        gamma,
        pool,
        random_state=random_state,
        max_samples=max_gamma_samples,
    )

    approximator = Nystroem(
        kernel="rbf",
        gamma=bandwidth,
        n_components=rank,
        random_state=int(random_state),
    )
    approximator.fit(pool)

    landmarks = np.asarray(approximator.components_, dtype=components_dtype)
    whitening = np.asarray(approximator.normalization_, dtype=np.float32)
    return NystromBasis(
        components=landmarks,
        normalization=whitening,
        gamma=float(bandwidth),
        kernel="rbf",
    )


@lru_cache(maxsize=64)
def load_nystrom_basis(path: str) -> NystromBasis:
    """Load a Nyström basis from ``path``, memoizing the result.

    Up to 64 distinct path strings are kept in a process-local LRU cache, so
    repeated calls with the same string return the same `NystromBasis` object
    without re-reading the file.
    """
    basis = NystromBasis.load(path)
    return basis


@dataclass
class NystromCARClassifier:
    """Linear decision rule ``Phi(x) @ w + b`` on top of a stored Nyström basis.

    Only the weight vector, the bias and the path of the basis file are held;
    the basis itself is fetched through `load_nystrom_basis` when needed.
    """

    w: np.ndarray  # [m]
    b: float
    basis_path: str
    positive_label: int = 1
    negative_label: int = 0

    def __post_init__(self) -> None:
        # Normalize field types: flat float32 weights, float bias, str path.
        flat_weights = np.asarray(self.w, dtype=np.float32)
        self.w = flat_weights.reshape(-1)
        self.b = float(self.b)
        self.basis_path = str(self.basis_path)

    @property
    def feature_dim(self) -> int:
        """Length of the weight vector."""
        return int(self.w.shape[0])

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Map raw representations to Nyström features using the stored basis."""
        return load_nystrom_basis(self.basis_path).transform(X)

    def decision_function(self, X: np.ndarray) -> np.ndarray:
        """Return the signed score ``Phi @ w + b`` for each row of ``X``.

        A 2-D input whose second dimension equals `feature_dim` is taken to be
        Nyström features already and is only cast to float32; anything else is
        passed through `transform` first.
        """
        raw = np.asarray(X)
        already_features = raw.ndim == 2 and raw.shape[1] == self.feature_dim

        if not already_features:
            feats = self.transform(raw)
        else:
            feats = np.asarray(raw, dtype=np.float32, order="C")

        margin = feats @ self.w
        return margin + np.float32(self.b)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return ``positive_label`` where the score is >= 0, else ``negative_label``."""
        on_positive_side = self.decision_function(X) >= 0
        labels = np.where(on_positive_side, self.positive_label, self.negative_label)
        return labels.astype(int)
