# cav.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Tuple, Any, Iterable
import numpy as np
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import LinearSVC, SVC
from .utils import l2_normalize
import torch

from .nystrom_car import (
    GammaSpec,
    NystromCARClassifier,
    NystromBasis,
    build_nystrom_basis,
    load_nystrom_basis,
)

# Names of the linear-probe methods accepted by ``train_cav``.
MethodName = Literal["dom", "logistic", "hinge", "svm"]
# Runtime tuple listing the same four method names as ``MethodName``.
method_names: Tuple[str, ...] = ("dom", "logistic", "hinge", "svm")
# Kernel names accepted by the kernel-SVC CAR trainers (passed as ``SVC(kernel=...)``).
CarKernelName = Literal["linear", "poly", "rbf", "sigmoid"]


@dataclass
class Concept:
    """A named concept with an optional tensor attached.

    Attributes
    ----------
    name:
        Identifier of the concept.
    tensor:
        Tensor stored alongside the name, or ``None`` (the default) when no
        tensor has been attached.
    """

    name: str
    # The default is ``None``, so the annotation must admit it.
    tensor: "torch.Tensor | None" = None


# ---- linear probes for CAVs -------------------------------------------------

def _train_logistic(
    X: np.ndarray,
    y: np.ndarray,
    C: float = 1.0,
    max_iter: int = 2000,
    random_state: int = 0,
):
    """Fit a liblinear logistic-regression probe on ``(X, y)``.

    Parameters
    ----------
    X, y:
        Training features and labels, passed to ``LogisticRegression.fit``.
    C, max_iter, random_state:
        Forwarded unchanged to ``LogisticRegression``.

    Returns
    -------
    tuple
        ``(w, b, acc)``: the flattened coefficient array, the first intercept
        as a float, and the accuracy measured on the training data itself.
    """
    model = LogisticRegression(
        solver="liblinear",
        C=C,
        max_iter=max_iter,
        random_state=random_state,
    )
    model.fit(X, y)

    weights = model.coef_.ravel()
    intercept = float(model.intercept_.ravel()[0])
    # Accuracy is measured on the same samples the model was fitted on.
    hits = model.predict(X) == y
    train_acc = float(hits.mean())
    return weights, intercept, train_acc


def _train_hinge(
    X: np.ndarray,
    y: np.ndarray,
    alpha: float = 1e-4,
    max_iter: int = 3000,
    random_state: int = 0,
):
    """Fit a hinge-loss linear probe with ``SGDClassifier``.

    Parameters
    ----------
    X, y:
        Training features and labels, passed to ``SGDClassifier.fit``.
    alpha, max_iter, random_state:
        Forwarded unchanged to ``SGDClassifier`` (``tol`` is fixed at 1e-3).

    Returns
    -------
    tuple
        ``(w, b, acc)``: the flattened coefficient array, the intercept as a
        float (``0.0`` when the estimator reports ``fit_intercept`` as
        falsy), and the accuracy measured on the training data itself.
    """
    model = SGDClassifier(
        loss="hinge",
        alpha=alpha,
        max_iter=max_iter,
        tol=1e-3,
        random_state=random_state,
    )
    model.fit(X, y)

    weights = model.coef_.ravel()
    if model.fit_intercept:
        intercept = float(model.intercept_.ravel()[0])
    else:
        intercept = 0.0
    # Accuracy is measured on the same samples the model was fitted on.
    hits = model.predict(X) == y
    train_acc = float(hits.mean())
    return weights, intercept, train_acc


def _train_dom(Xp: np.ndarray, Xn: np.ndarray):
    mu_p = Xp.mean(axis=0)
    mu_n = Xn.mean(axis=0)
    w = mu_p - mu_n
    b = -0.5 * float((mu_p + mu_n) @ w)
    X = np.vstack([Xp, Xn])
    y = np.concatenate([np.ones(len(Xp)), np.zeros(len(Xn))])
    acc = float((((X @ w) + b >= 0).astype(int) == y).mean())
    return w, b, acc


def _train_svm(
    X: np.ndarray,
    y: np.ndarray,
    C: float = 1.0,
    max_iter: int = 10_000,
    random_state: int = 0,
):
    """Fit a ``LinearSVC`` probe on ``(X, y)``.

    Parameters
    ----------
    X, y:
        Training features and labels, passed to ``LinearSVC.fit``.
    C, max_iter, random_state:
        Forwarded unchanged to ``LinearSVC``.

    Returns
    -------
    tuple
        ``(w, b, acc)``: the flattened coefficient array, the first intercept
        as a float, and the accuracy measured on the training data itself.
    """
    model = LinearSVC(C=C, max_iter=max_iter, random_state=random_state)
    model.fit(X, y)

    weights = model.coef_.ravel()
    intercept = float(model.intercept_.ravel()[0])
    # Accuracy is measured on the same samples the model was fitted on.
    hits = model.predict(X) == y
    train_acc = float(hits.mean())
    return weights, intercept, train_acc

def _train_car(
    X: np.ndarray,
    y: np.ndarray,
    kernel: CarKernelName = "rbf",
    C: float = 1.0,
    gamma: float | str = "scale",
    class_weight: Any = "balanced",
    random_state: int = 0,
    **kwargs: Any,
):
    """Fit a kernel ``SVC`` on ``(X, y)``.

    Parameters
    ----------
    X, y:
        Training features and labels, passed to ``SVC.fit``.
    kernel, C, gamma, class_weight, random_state:
        Forwarded unchanged to ``SVC``; ``probability`` is fixed to ``False``.
    **kwargs:
        Additional keyword arguments forwarded to ``SVC``.

    Returns
    -------
    tuple
        ``(clf, acc)``: the fitted classifier and its accuracy on the
        training data itself.
    """
    classifier = SVC(
        kernel=kernel,
        C=C,
        gamma=gamma,
        class_weight=class_weight,
        probability=False,
        random_state=random_state,
        **kwargs,
    )
    classifier.fit(X, y)
    # Accuracy is measured on the same samples the classifier was fitted on.
    hits = classifier.predict(X) == y
    train_acc = float(hits.mean())
    return classifier, train_acc


def train_cav(
    X_pos: np.ndarray,
    X_neg: np.ndarray,
    method: MethodName = "dom",
    random_state: int = 0,
    **kwargs: Any,
) -> dict:
    """Train a linear probe for a CAV and return its unit-norm direction.

    Parameters
    ----------
    X_pos, X_neg:
        Positive and negative samples. They are stacked row-wise and
        labelled ``1`` and ``0`` respectively.
    method:
        Which trainer to use: ``"dom"``, ``"logistic"``, ``"hinge"`` or
        ``"svm"``.
    random_state:
        Forwarded to the ``"logistic"``, ``"hinge"`` and ``"svm"`` trainers
        and recorded under ``"meta"``; not used by ``"dom"``.
    **kwargs:
        Extra keyword arguments forwarded to the selected trainer. They are
        ignored when ``method == "dom"``.

    Returns
    -------
    dict
        ``{"vector", "bias", "acc", "method", "meta"}``. ``"vector"`` is the
        result of ``l2_normalize`` on the trained weights and ``"bias"`` is
        the trained intercept divided by the weight norm, so the pair
        describes the same decision boundary as the trained probe.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    X = np.vstack([X_pos, X_neg])
    y = np.concatenate([np.ones(len(X_pos)), np.zeros(len(X_neg))])

    if method == "dom":
        w, b, acc = _train_dom(X_pos, X_neg)
    elif method == "logistic":
        w, b, acc = _train_logistic(X, y, **kwargs, random_state=random_state)
    elif method == "hinge":
        w, b, acc = _train_hinge(X, y, **kwargs, random_state=random_state)
    elif method == "svm":
        w, b, acc = _train_svm(X, y, **kwargs, random_state=random_state)
    else:
        raise ValueError(f"Unknown method: {method}")

    # Scaling (w, b) by one positive constant leaves the sign of w.x + b
    # unchanged; scaling only w does not. Divide the bias by the weight norm
    # so it stays consistent with the normalized vector (this assumes
    # ``l2_normalize`` divides by the Euclidean norm -- TODO confirm).
    w_norm = float(np.linalg.norm(w))
    if w_norm > 0.0:
        # A zero weight vector is left alone: the rule is then sign(b) anyway.
        b = b / w_norm

    w = l2_normalize(w)
    return {
        "vector": w,
        "bias": float(b),
        "acc": float(acc),
        "method": method,
        "meta": {"random_state": random_state},
    }


def sample_train_cav(
    X_pos_all: np.ndarray,
    X_neg_all: np.ndarray,
    n_examples: int,
    method: MethodName,
    random_state: int = 0,
    with_replacement: bool = False,
    **kwargs: Any,
) -> dict:
    """Draw a random subset from each pool and train a CAV on it.

    Parameters
    ----------
    X_pos_all, X_neg_all:
        Full positive and negative pools to sample rows from.
    n_examples:
        Number of rows to draw from each pool. Without replacement the draw
        is capped at the pool size.
    method:
        Trainer name forwarded to ``train_cav``.
    random_state:
        Seeds the sampling generator and is also forwarded to ``train_cav``.
    with_replacement:
        Whether rows may be drawn more than once.
    **kwargs:
        Forwarded to ``train_cav``.

    Returns
    -------
    dict
        Whatever ``train_cav`` returns for the sampled subsets.
    """
    generator = np.random.default_rng(random_state)

    def _draw(pool: np.ndarray) -> np.ndarray:
        count = n_examples if with_replacement else min(n_examples, len(pool))
        picked = generator.choice(
            len(pool), size=count, replace=bool(with_replacement)
        )
        return pool[picked]

    # Both draws share one generator, so positives must be drawn first.
    pos_subset = _draw(X_pos_all)
    neg_subset = _draw(X_neg_all)
    return train_cav(
        pos_subset,
        neg_subset,
        method=method,
        random_state=random_state,
        **kwargs,
    )


# ---- CARs: kernel SVC concept activation regions ----------------------------

def train_car(
    X_pos: np.ndarray,
    X_neg: np.ndarray,
    kernel: CarKernelName = "rbf",
    C: float = 1.0,
    gamma: float | str = "scale",
    class_weight: Any = "balanced",
    random_state: int = 0,
    **kwargs: Any,
) -> dict:
    """Train a kernel-SVC CAR on positive versus negative samples.

    Parameters
    ----------
    X_pos, X_neg:
        Positive and negative samples. They are stacked row-wise and
        labelled ``1`` and ``0`` respectively.
    kernel, C, gamma, class_weight, random_state:
        Forwarded unchanged to ``_train_car``.
    **kwargs:
        Additional keyword arguments forwarded to ``_train_car``.

    Returns
    -------
    dict
        ``{"clf", "acc", "kernel", "meta"}``: the fitted classifier, its
        training accuracy, the kernel name, and the recorded random state.
    """
    features = np.vstack([X_pos, X_neg])
    positive_labels = np.ones(len(X_pos))
    negative_labels = np.zeros(len(X_neg))
    labels = np.concatenate([positive_labels, negative_labels])

    fitted = _train_car(
        features,
        labels,
        kernel=kernel,
        C=C,
        gamma=gamma,
        class_weight=class_weight,
        random_state=random_state,
        **kwargs,
    )
    classifier, train_acc = fitted

    result = {
        "clf": classifier,
        "acc": train_acc,
        "kernel": kernel,
        "meta": {"random_state": random_state},
    }
    return result


def sample_train_car(
    X_pos_all: np.ndarray,
    X_neg_all: np.ndarray,
    n_examples: int,
    kernel: CarKernelName = "rbf",
    random_state: int = 0,
    with_replacement: bool = False,
    **kwargs: Any,
) -> dict:
    """Draw a random subset from each pool and train a kernel-SVC CAR on it.

    Parameters
    ----------
    X_pos_all, X_neg_all:
        Full positive and negative pools to sample rows from.
    n_examples:
        Number of rows to draw from each pool. Without replacement the draw
        is capped at the pool size.
    kernel:
        Kernel name forwarded to ``train_car``.
    random_state:
        Seeds the sampling generator and is also forwarded to ``train_car``.
    with_replacement:
        Whether rows may be drawn more than once.
    **kwargs:
        Forwarded to ``train_car``.

    Returns
    -------
    dict
        Whatever ``train_car`` returns for the sampled subsets.
    """
    generator = np.random.default_rng(random_state)

    def _draw(pool: np.ndarray) -> np.ndarray:
        count = n_examples if with_replacement else min(n_examples, len(pool))
        picked = generator.choice(
            len(pool), size=count, replace=bool(with_replacement)
        )
        return pool[picked]

    # Both draws share one generator, so positives must be drawn first.
    pos_subset = _draw(X_pos_all)
    neg_subset = _draw(X_neg_all)
    return train_car(
        pos_subset,
        neg_subset,
        kernel=kernel,
        random_state=random_state,
        **kwargs,
    )


# ---- CARs: Nyström approximation -------------------------------------------

def get_or_create_nystrom_basis(
    X_pos_all: np.ndarray,
    X_neg_all: np.ndarray,
    *,
    basis_path: str,
    n_components: int = 200,
    gamma: GammaSpec = "scale",
    basis_random_state: int = 0,
) -> NystromBasis:
    """Load a Nyström basis from disk, or create and save it.

    The basis is constructed on the concatenated pool ``[X_pos_all; X_neg_all]``.

    Parameters
    ----------
    X_pos_all, X_neg_all:
        Pools that are stacked row-wise to build the basis. They are only
        used when the basis cannot be loaded from ``basis_path``.
    basis_path:
        Location the basis is loaded from and, after a rebuild, saved to.
    n_components, gamma:
        Forwarded to ``build_nystrom_basis``.
    basis_random_state:
        Forwarded to ``build_nystrom_basis`` as ``random_state``.

    Returns
    -------
    NystromBasis
        The basis as returned by ``load_nystrom_basis(basis_path)``.

    Raises
    ------
    ValueError
        If ``basis_path`` is ``None``.
    """
    import logging  # local import keeps this function self-contained

    if basis_path is None:
        raise ValueError("basis_path must be provided")

    # Fast path: already on disk.
    try:
        return load_nystrom_basis(basis_path)
    except FileNotFoundError:
        # Ordinary cache miss: fall through and build the basis.
        pass
    except Exception as exc:
        # Still best-effort (rebuild below), but an unreadable cache is
        # reported instead of being swallowed silently.
        logging.getLogger(__name__).warning(
            "Could not load Nystrom basis from %r (%s: %s); rebuilding it.",
            basis_path,
            type(exc).__name__,
            exc,
        )

    X_pool = np.vstack([X_pos_all, X_neg_all])
    basis = build_nystrom_basis(
        X_pool,
        n_components=n_components,
        gamma=gamma,
        random_state=basis_random_state,
    )
    basis.save(basis_path)
    # Ensure the cache sees the on-disk version.
    return load_nystrom_basis(basis_path)


def train_car_nystrom_from_features(
    Phi_pos: np.ndarray,
    Phi_neg: np.ndarray,
    *,
    basis_path: str,
    C: float = 1.0,
    class_weight: Any = "balanced",
    max_iter: int = 10_000,
    random_state: int = 0,
) -> dict:
    """Fit a linear SVM on explicit Nyström features and wrap it as a CAR.

    Parameters
    ----------
    Phi_pos, Phi_neg:
        Nyström features of the positive and negative pools. They are stacked
        row-wise, cast to float32, and labelled ``1`` and ``0``.
    basis_path:
        Path to the basis file. It is handed to the returned classifier so
        that it can also accept raw representations later (transforming them
        lazily).
    C, class_weight, max_iter, random_state:
        Forwarded unchanged to ``LinearSVC``.

    Returns
    -------
    dict
        A dict compatible with the rest of the codebase, holding the
        classifier under ``"clf"``, the normalized weights under ``"vector"``,
        the rescaled intercept under ``"bias"``, the training accuracy under
        ``"acc"``, plus ``"kernel"``, ``"approx"``, ``"nystrom_m"`` and
        ``"meta"`` entries.
    """
    features = np.vstack([Phi_pos, Phi_neg]).astype(np.float32, copy=False)
    labels = np.concatenate(
        [np.ones(len(Phi_pos)), np.zeros(len(Phi_neg))]
    ).astype(int)

    linear_svm = LinearSVC(
        C=C,
        class_weight=class_weight,
        max_iter=max_iter,
        random_state=random_state,
        dual="auto",
    )
    linear_svm.fit(features, labels)

    beta = linear_svm.coef_.ravel().astype(np.float32, copy=False)
    offset = float(linear_svm.intercept_.ravel()[0])

    # Divide weights and intercept by the same positive constant: this keeps
    # predictions unchanged while making runs comparable in scale.
    scale = float(np.linalg.norm(beta) + 1e-12)
    beta = beta / scale
    offset = offset / scale

    clf = NystromCARClassifier(w=beta, b=offset, basis_path=basis_path)
    # Accuracy is measured on the same features the SVM was fitted on.
    hits = clf.predict(features) == labels
    train_acc = float(hits.mean())

    return {
        "clf": clf,
        "vector": beta,  # explicit beta in R^m
        "bias": float(offset),
        "acc": float(train_acc),
        "kernel": "rbf",
        "approx": "nystrom",
        "nystrom_m": int(beta.shape[0]),
        "meta": {"random_state": random_state},
    }


def sample_train_car_nystrom_from_features(
    Phi_pos_all: np.ndarray,
    Phi_neg_all: np.ndarray,
    n_examples: int,
    *,
    basis_path: str,
    random_state: int = 0,
    with_replacement: bool = False,
    **kwargs: Any,
) -> dict:
    """Sample Nyström features from each pool and train a Nyström-CAR on them.

    Parameters
    ----------
    Phi_pos_all, Phi_neg_all:
        Full positive and negative feature pools to sample rows from.
    n_examples:
        Number of rows to draw from each pool. Without replacement the draw
        is capped at the pool size.
    basis_path:
        Forwarded to ``train_car_nystrom_from_features``.
    random_state:
        Seeds the sampling generator and is also forwarded to the trainer.
    with_replacement:
        Whether rows may be drawn more than once.
    **kwargs:
        Forwarded to ``train_car_nystrom_from_features``.

    Returns
    -------
    dict
        Whatever ``train_car_nystrom_from_features`` returns for the sampled
        subsets.
    """
    generator = np.random.default_rng(random_state)

    def _draw(pool: np.ndarray) -> np.ndarray:
        count = n_examples if with_replacement else min(n_examples, len(pool))
        picked = generator.choice(
            len(pool), size=count, replace=bool(with_replacement)
        )
        return pool[picked]

    # Both draws share one generator, so positives must be drawn first.
    pos_subset = _draw(Phi_pos_all)
    neg_subset = _draw(Phi_neg_all)
    return train_car_nystrom_from_features(
        pos_subset,
        neg_subset,
        basis_path=basis_path,
        random_state=random_state,
        **kwargs,
    )
