import pickle
from typing import Callable, Collection, Dict, Iterable, Literal, Union

import numpy as np
import scipy.special
import torch

# Names exported by ``from <module> import *``.
__all__ = [
    "prepare_machine_interpretability_score",
    "get_available_similarity_functions",
]
# Registry mapping a similarity-function name to its factory. A factory takes
# one string argument and returns a function that maps a list of image
# filenames to an array of similarities. Entries are added by the
# ``register_prepare_similarity_function`` decorator.
__similarity_function_registry: Dict[
    str, Callable[[str], Callable[[list[str]], np.ndarray]]
] = {}


def get_available_similarity_functions() -> Iterable[str]:
    """List the similarity functions that can currently be requested.

    Returns:
        The keys view of the similarity-function registry, i.e. every name
        under which a factory has been registered so far.
    """
    registered_names = __similarity_function_registry.keys()
    return registered_names


def register_prepare_similarity_function(name: str):
    """Build a decorator that files a similarity-function factory under ``name``.

    Args:
        name: Registry key under which the decorated factory is stored. An
            existing entry with the same name is overwritten.

    Returns:
        A decorator that records the factory in the registry and returns it
        unchanged.
    """

    def _add_to_registry(
        factory: Callable[[str], Callable[[list[str]], np.ndarray]]
    ):
        # Hand the factory back untouched so the decorated name keeps
        # referring to the original function.
        __similarity_function_registry[name] = factory
        return factory

    return _add_to_registry


def _strip_file_extension(path: str) -> str:
    """Drop everything from the first dot of the final path component onward.

    Only the part after the last "/" is inspected, so dots in directory names
    (or a leading "./") do not truncate the path.
    """
    directory, separator, basename = path.rpartition("/")
    return directory + separator + basename.split(".")[0]


@register_prepare_similarity_function("dreamsim")
def prepare_dreamsim_similarity(
    similarity_function_arg: str,
) -> Callable[[list[str]], np.ndarray]:
    """Prepares the dreamsim similarity function.

    The feature map is loaded once from a pickle file mapping image filenames
    to feature vectors. Keys are normalised by removing the file extension
    and, if the first key does not start with "train/", by replacing
    everything up to and including the first "train/" with "train/".

    Args:
        similarity_function_arg: The path to the pickled dreamsim feature map.

    Returns:
        A function that takes a list of image filenames (with or without file
        extension) and returns the square matrix of pairwise cosine
        similarities between their features.

    Raises:
        IndexError: If the keys need the "train/" rewrite but one of them does
            not contain "train/".
        KeyError: Raised by the returned function when a filename has no entry
            in the feature map.
    """
    # NOTE: pickle can execute arbitrary code while loading; only use feature
    # maps that come from a trusted source.
    with open(similarity_function_arg, "rb") as f:
        raw_feature_map: dict[str, np.ndarray] = pickle.load(f)

    fn_feature_map = {
        _strip_file_extension(fn): features
        for fn, features in raw_feature_map.items()
    }

    # The first key decides whether the map uses paths that need the prefix
    # rewrite. An empty map has nothing to rewrite.
    first_key = next(iter(fn_feature_map), None)
    if first_key is not None and not first_key.startswith("train/"):
        # we have absolute paths, so we need to remove the prefix
        fn_feature_map = {
            "train/" + k.split("train/")[1]: v for k, v in fn_feature_map.items()
        }

    def compute_similarities(image_fns: list[str]) -> np.ndarray:
        # Normalise the query names the same way as the feature-map keys.
        keys = [_strip_file_extension(fn) for fn in image_fns]

        features = torch.tensor(np.array([fn_feature_map[key] for key in keys]))
        # Broadcasting (n, 1, d) against (1, n, d) yields all n x n pairs.
        scores = torch.nn.functional.cosine_similarity(
            features.unsqueeze(1), features.unsqueeze(0), dim=-1
        ).numpy()

        return scores

    return compute_similarities


def prepare_machine_interpretability_score_inference(
    compute_similarities_fn, inference_model_arg: str
) -> Callable:
    """
    Prepare a function computing the machine interpretability score given similarities.

    Args:
        compute_similarities_fn: A function that takes a list of image filenames
            and returns a matrix of similarities between the images.
        inference_model_arg: A string that specifies the path to the inference
            model. The file is a pickled tuple of the form
            ``(inference_model, preprocess_x)``, where ``preprocess_x``
            provides ``transform(x)`` and ``inference_model`` provides
            ``predict(x)`` and ``predict_proba(x)``.

    Returns:
        A function that takes a batch of image-filename lists (20 filenames
        each) and returns ``(mis, mis_confidence)``: the fraction of batch
        items predicted as class 1 and the mean class-1 probability. With
        ``include_individual_scores=True`` it returns
        ``((mis, mis_confidence), (per_item_is_class_1, per_item_class_1_probability))``.
        The returned function raises ``ValueError`` for an empty batch, for
        similarity matrices that are not 20 x 20, or if ``predict_proba``
        does not return a 2-D array.
    """

    # NOTE: pickle can execute arbitrary code while loading; only use model
    # files that come from a trusted source.
    with open(inference_model_arg, "rb") as f:
        inference_model, preprocess_x = pickle.load(f)

    # Each task has 2 * n_references + 2 images. Features are read from the
    # columns at index ``n_references`` and at the last index (presumably the
    # two query images -- confirm against the caller).
    n_references = 9
    n_images = 2 * n_references + 2

    def compute_machine_interpretability_score(
        batch_image_filenames: list[list[str]],
        include_individual_scores: bool = False,
    ):
        similarity_matrices = [
            compute_similarities_fn(image_filenames)
            for image_filenames in batch_image_filenames
        ]
        if not similarity_matrices:
            raise ValueError("batch_image_filenames must not be empty")

        batch_scores = np.stack(similarity_matrices, 0)
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if batch_scores.ndim != 3 or batch_scores.shape[1:] != (n_images, n_images):
            raise ValueError(
                f"Expected similarities of shape (N, {n_images}, {n_images}), "
                f"got {batch_scores.shape}"
            )

        # Feature layout: first reference group vs. column n_references, first
        # group vs. last column, second group vs. column n_references, second
        # group vs. last column.
        x = np.concatenate(
            (
                batch_scores[:, :n_references, n_references],
                batch_scores[:, :n_references, -1],
                batch_scores[:, n_references + 1 : -1, n_references],
                batch_scores[:, n_references + 1 : -1, -1],
            ),
            1,
        )

        x = preprocess_x.transform(x)

        y_pred = inference_model.predict(x)
        y_scores = inference_model.predict_proba(x)
        if y_scores.ndim != 2:
            raise ValueError(
                f"Expected 2-D class probabilities, got shape {y_scores.shape}"
            )
        mis_confidence = y_scores[:, 1].mean(0)
        mis = (y_pred == 1).mean(0)

        result = (mis, mis_confidence)

        if include_individual_scores:
            result = (result, (y_pred == 1, y_scores[:, 1]))

        return result

    return compute_machine_interpretability_score


def prepare_similarity_function(
    similarity_function_name: str, similarity_function_arg: str
) -> Callable[[list[str]], np.ndarray]:
    """Look up a registered similarity-function factory and invoke it.

    Args:
        similarity_function_name: Registry name of the similarity function.
        similarity_function_arg: The single string argument handed to the
            registered factory.

    Returns:
        Whatever the factory returns: a function that takes a list of image
        filenames and returns a matrix of similarities between the images.

    Raises:
        ValueError: If no factory is registered under the given name.
    """
    # Reject unknown names first so the happy path stays unindented.
    if similarity_function_name not in __similarity_function_registry:
        raise ValueError(f"Unknown similarity function {similarity_function_name}")

    make_similarity_function = __similarity_function_registry[
        similarity_function_name
    ]
    return make_similarity_function(similarity_function_arg)


def prepare_machine_interpretability_score(
    similarity_function_name: str, similarity_function_args: Collection[str]
):
    """Build the machine interpretability score function.

    Args:
        similarity_function_name: Name of the similarity function to use.
        similarity_function_args: Indexable pair of arguments. Element 0 is
            passed to the similarity-function factory; element 1 is the path
            of the inference model.

    Returns:
        The scoring function produced by
        ``prepare_machine_interpretability_score_inference`` for the selected
        similarity function and inference model.
    """
    compute_similarities = prepare_similarity_function(
        similarity_function_name, similarity_function_args[0]
    )

    # Read the second element only after the similarity function is ready.
    inference_model_path = similarity_function_args[1]
    return prepare_machine_interpretability_score_inference(
        compute_similarities, inference_model_path
    )


class SimpleBinaryClassifier:
    """Binary classifier for a 2AFC task built on aggregated image similarities.

    The input features are read as four consecutive groups of
    ``n_references`` columns. Each group is reduced to one value by the
    aggregation function, the four values are combined into a single logit,
    and the logit is mapped to two class probabilities.

    Args:
        aggregation: "mean" selects ``np.mean``; any other value selects
            ``np.max``.
        n_references: Number of reference images, i.e. columns per group.
        probability_mode: "softmax" applies a softmax over
            ``(logit, -logit)``; any other value applies a logistic sigmoid.
    """

    # NOTE(review): instances are presumably stored inside the pickled
    # inference model, so attribute names should stay stable -- confirm.

    def __init__(
        self,
        aggregation: Union[Literal["max"], Literal["mean"]] = "mean",
        n_references: int = 9,
        probability_mode: Literal["sigmoid", "softmax"] = "sigmoid",
    ):
        if aggregation == "mean":
            self.aggregation_f = np.mean
        else:
            self.aggregation_f = np.max
        self.n_references = n_references
        self.probability_mode = probability_mode

    def predict_proba(self, x: np.ndarray) -> np.ndarray:
        """Return an ``(N, 2)`` array of class scores for ``x``.

        Args:
            x: Array with ``4 * n_references`` columns.

        Raises:
            ValueError: If ``x`` does not have ``4 * n_references`` columns.
        """
        group_size = self.n_references
        if x.shape[1] != 4 * group_size:
            raise ValueError(
                f"Expected input of shape (N, {4*self.n_references}), got {x.shape}"
            )

        # Reduce each of the four column groups to one value per row.
        min_top, min_bottom, max_top, max_bottom = (
            self.aggregation_f(
                x[:, group * group_size : (group + 1) * group_size], axis=1
            )
            for group in range(4)
        )

        logit = (max_top - min_top) - (max_bottom - min_bottom)

        if self.probability_mode == "softmax":
            return scipy.special.softmax(np.stack((logit, -logit), -1), axis=-1)

        first_class_score = 1 / (1 + np.exp(-logit))
        return np.stack((first_class_score, 1 - first_class_score), -1)

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Return the index of the higher-scoring class for each row of ``x``."""
        return self.predict_proba(x).argmax(axis=-1)
