# tcav_scoring.py
from __future__ import annotations
from typing import Any, Sequence
import numpy as np
import pandas as pd
from typing import Literal, Tuple, Any, Iterable, List, Dict

from .activations import get_gradient_at_layer
from .cache import save_df_bundle, try_load_df_bundle

try:  # optional (only needed for fast-path TCAR on Nyström CARs)
    from .nystrom_car import NystromCARClassifier
except Exception:  # pragma: no cover
    NystromCARClassifier = None  # type: ignore

def tcar_scores(
    car: dict,
    H: np.ndarray,
    y: np.ndarray,
    classes: Iterable[Any],
    positive_label: int = 1,
) -> dict:
    """Compute a TCAR score for each requested class.

    For every class ``k`` in *classes*, the score is the fraction of rows of
    *H* whose label in *y* equals ``k`` that the CAR classifier predicts as
    *positive_label*.

    Args:
        car: CAR object. Anything ``_get_car_classifier`` accepts: a dict with
            a ``"clf"`` (or ``"classifier"``) entry, or an object exposing
            ``predict`` directly.
        H: Representations; rows are selected with a boolean mask built from
            *y*, so rows must be aligned with *y*.
        y: Per-example labels, compared with ``==`` against each class.
        classes: Class labels to score.
        positive_label: Prediction value counted as "inside the concept".

    Returns:
        Mapping ``class -> score``. A class with no examples in *y* maps to
        ``nan``.

    Raises:
        TypeError: If no classifier can be found in *car*.
    """
    # Use the shared resolver so this function accepts the same CAR shapes
    # as the rest of the module (previously only ``car["clf"]`` worked).
    clf = _get_car_classifier(car)
    H = np.asarray(H)
    y = np.asarray(y)
    scores: dict[Any, float] = {}
    for k in classes:
        mask = y == k
        if not np.any(mask):
            scores[k] = float("nan")
            continue
        # np.asarray so a classifier returning a plain list still compares
        # elementwise against positive_label.
        preds = np.asarray(clf.predict(H[mask]))
        scores[k] = float((preds == positive_label).mean())
    return scores


def tcar_score(
    car: dict,
    H: np.ndarray,
    y: np.ndarray,
    class_label: Any,
    positive_label: int = 1,
) -> float:
    """Return the TCAR score of a single class.

    Convenience wrapper around ``tcar_scores`` that scores only
    *class_label* and unwraps the resulting one-entry mapping.
    """
    per_class = tcar_scores(
        car,
        H,
        y,
        classes=[class_label],
        positive_label=positive_label,
    )
    return per_class[class_label]


def _get_cav_vector(obj: Any) -> np.ndarray:
    if isinstance(obj, dict) and "vector" in obj:
        return np.asarray(obj["vector"], dtype=float)
    return np.asarray(obj, dtype=float)


def _get_car_classifier(obj: Any):
    if hasattr(obj, "predict"):
        return obj
    if isinstance(obj, dict):
        if "clf" in obj:
            return obj["clf"]
        if "classifier" in obj:
            return obj["classifier"]
    raise TypeError("Could not find classifier in CAR object")


def _is_car_method(method: str) -> bool:
    m = (method or "").lower()
    return m in {"car", "tcar", "car_nystrom", "tcar_nystrom", "nystrom_car", "car-nystrom"}


def compute_concept_score(
    method: str,
    *,
    gradients: np.ndarray | None = None,       # [M, D] (TCAV)
    representations: np.ndarray | None = None, # [M, D] (TCAR)
    concept_obj: Any,
    positive_label: int = 1,
) -> float:
    """Score one concept on a batch of examples.

    CAR methods (as recognised by ``_is_car_method``) compute TCAR: the
    fraction of *representations* that the concept's CAR classifier predicts
    as *positive_label*. Every other method computes TCAV: the number of
    strictly positive directional derivatives (gradient · CAV) divided by the
    number of examples.

    Args:
        method: Scoring method name (case-insensitive).
        gradients: Per-example gradients, ``[M, D]``. A single ``[D]`` vector
            is treated as one example. Required for TCAV.
        representations: Per-example representations, passed as-is to the
            CAR classifier's ``predict``. Required for TCAR.
        concept_obj: For TCAV, a CAV (vector or ``{"vector": ...}``). For
            TCAR, a CAR (classifier or ``{"clf"|"classifier": ...}``).
        positive_label: Prediction value counted as inside the concept (TCAR).

    Returns:
        The score described above. TCAV on zero examples returns ``0.0``.

    Raises:
        ValueError: If the input required by the chosen method is missing.
        TypeError: If a CAR classifier cannot be found in *concept_obj*.
    """
    # TCAR: fraction of examples the CAR classifier places in the positive
    # (concept) region.
    if _is_car_method(method):
        if representations is None:
            raise ValueError("representations is required for CAR / TCAR scoring")
        clf = _get_car_classifier(concept_obj)
        # np.asarray so a list-returning classifier still compares elementwise.
        preds = np.asarray(clf.predict(representations))
        return float((preds == positive_label).mean())

    # Default: TCAV (fraction of positive directional derivatives).
    if gradients is None:
        raise ValueError("gradients is required for CAV / TCAV scoring")
    grads = np.asarray(gradients)
    if grads.ndim == 1:
        # A lone gradient vector is one example. Without this, shape[0]
        # would be D and the score would be divided by the feature count.
        grads = grads[np.newaxis, :]
    v = _get_cav_vector(concept_obj)
    s = grads @ v
    num_examples = grads.shape[0]
    return float((s > 0).sum()) / max(1, num_examples)


def _can_share_nystrom_transform(clfs: Sequence[Any]) -> bool:
    """Return True when one Nyström feature map can serve every classifier in *clfs*."""
    if NystromCARClassifier is None or not clfs:
        return False
    if not all(isinstance(c, NystromCARClassifier) for c in clfs):
        return False
    if len(clfs) == 1:
        # A classifier trivially shares its own basis.
        return True
    basis_path0 = getattr(clfs[0], "basis_path", None)
    # A missing path (None) cannot prove the bases are the same. Treating
    # None == None as "shared" could apply one classifier's feature map to
    # another classifier's different basis, so fall back to per-classifier
    # predict instead.
    if basis_path0 is None:
        return False
    return all(getattr(c, "basis_path", None) == basis_path0 for c in clfs[1:])


def compute_concept_scores_for_run(
    method: str,
    *,
    gradients: np.ndarray | None = None,
    representations: np.ndarray | None = None,
    concepts_for_run: Sequence[Any],
    positive_label: int = 1,
) -> np.ndarray:
    """Score every concept of one run; returns a float array, one entry per concept.

    For TCAV methods each concept is scored independently with
    ``compute_concept_score``.

    For CAR methods, a fast path applies when every classifier is a
    ``NystromCARClassifier`` provably sharing the same basis file. In that
    case the representations are transformed once with the first
    classifier's ``transform`` (unless their width already equals its
    ``feature_dim``), and every classifier predicts on that shared result.
    Otherwise each classifier's ``predict`` receives the raw representations.

    Raises:
        ValueError: For CAR methods when *representations* is None.
        TypeError: If a CAR classifier cannot be found in a concept object.
    """
    if not _is_car_method(method):
        # Default path: TCAV, one concept at a time.
        return np.asarray(
            [
                compute_concept_score(
                    method,
                    gradients=gradients,
                    representations=representations,
                    concept_obj=obj,
                    positive_label=positive_label,
                )
                for obj in concepts_for_run
            ],
            dtype=float,
        )

    if representations is None:
        raise ValueError("representations is required for CAR / TCAR scoring")
    if len(concepts_for_run) == 0:
        return np.asarray([], dtype=float)

    reps = np.asarray(representations)
    # Resolve each classifier once; the list is reused for the basis check
    # and for scoring.
    clfs = [_get_car_classifier(obj) for obj in concepts_for_run]

    Phi = reps
    if (
        _can_share_nystrom_transform(clfs)
        and reps.ndim == 2
        and reps.shape[1] != clfs[0].feature_dim
    ):
        Phi = clfs[0].transform(reps)

    # np.asarray so list-returning classifiers still compare elementwise.
    scores = [
        float((np.asarray(clf.predict(Phi)) == positive_label).mean())
        for clf in clfs
    ]
    return np.asarray(scores, dtype=float)


def compute_concept_score_variance(
    method: str,
    *,
    gradients: np.ndarray | None = None,
    representations: np.ndarray | None = None,
    concepts_for_run: Sequence[Any],
    positive_label: int = 1,
) -> float:
    """Return the sample variance (``ddof=1``) of a run's per-concept scores.

    Scores come from ``compute_concept_scores_for_run``. With fewer than two
    concepts the sample variance is undefined, so ``0.0`` is returned instead.
    """
    per_concept = compute_concept_scores_for_run(
        method,
        gradients=gradients,
        representations=representations,
        concepts_for_run=concepts_for_run,
        positive_label=positive_label,
    )
    enough_samples = len(per_concept) >= 2
    return float(np.var(per_concept, ddof=1)) if enough_samples else 0.0

