import warnings
from collections.abc import Sequence

import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, \
    average_precision_score

# NOTE(review): with no category/module argument this filter suppresses every
# warning for the whole process as soon as this module is imported —
# presumably meant to quiet sklearn's metric warnings; consider narrowing it.
warnings.filterwarnings("ignore")

# Registry of supported metrics: maps the name given in a metric's
# ``function`` parameter to the sklearn scoring function that implements it.
# "ap" and "average_precision" are aliases for the same function.
METRIC_MAPPING = {
    "accuracy": accuracy_score,
    "f1": f1_score,
    "precision": precision_score,
    "recall": recall_score,
    "roc-auc": roc_auc_score,
    "ap": average_precision_score,
    "average_precision": average_precision_score
}


class Scorer:
    """Compute a configurable set of classification metrics for predictions.

    Each configured metric yields a dictionary holding the aggregated score
    under the key ``"0_total"`` and, when the metric is configured with an
    ``average`` other than ``None``, one additional entry per class.
    """

    # Names of metrics that are evaluated on continuous scores instead of
    # hard (discretised) label predictions.
    _SCORE_BASED_METRICS = frozenset({"roc-auc", "ap", "average_precision"})

    def __init__(self, metrics: dict, criterion: str, invert_crit: bool, class_names=None):
        """
        Initialize the Scorer with the metrics to be calculated.

        Args:
            metrics (dict): A dictionary where keys are metric names and values are dictionaries of parameters for
                each metric. Every parameter dictionary must contain a ``"function"`` entry naming a key of
                ``METRIC_MAPPING``; the remaining entries are passed to the metric function as keyword arguments.
            criterion (str): The key metric to be used for early stopping or evaluation, or ``"loss"``.
            invert_crit (bool): Whether to invert the key metric for early stopping (lower should always be better).
            class_names (list[str], optional): List of class names for multi-class metrics. If None, default names
                will be used.

        Raises:
            ValueError: If ``criterion`` is ``"loss"`` and ``invert_crit`` is True.
        """
        if criterion == "loss" and invert_crit:
            # For a loss lower is already better, so negating it would be wrong.
            raise ValueError("Criterion 'loss' should not be inverted. Set invert_crit to False.")

        self.metric_params = metrics
        self.criterion = criterion
        self.invert = invert_crit
        self.class_names = class_names

    def __call__(self, targets: torch.Tensor, predictions: torch.Tensor) -> tuple[float, dict[str, dict[str, float]]]:
        """Shorthand for :meth:`calc_test_scores`."""
        return self.calc_test_scores(targets, predictions)

    def calc_score(self,
                   targets: torch.Tensor,
                   predictions: torch.Tensor,
                   params: dict,
                   class_names: list[str] = None) -> dict[str, float]:
        """
        Calculate a single metric for the passed predictions and targets.

        The total score is always calculated. If ``params`` contains an ``average`` that is not None, the metric is
        evaluated a second time with ``average=None`` to obtain the separate scores for each class.

        :param targets: True labels
        :param predictions: Model outputs (scores, probabilities or labels)
        :param params: Parameters for the metric function; ``params["function"]`` selects the metric from
            ``METRIC_MAPPING``. The dictionary is not modified.
        :param class_names: Names used as keys for the per-class scores. Falls back to the names passed to the
            constructor and finally to ``class_1``, ``class_2``, ...
        :return: Dictionary of the total and separate scores for each class
        :raises KeyError: If ``params`` has no ``"function"`` entry or names an unknown metric
        """
        # Work on a copy so that the caller's configuration stays untouched.
        params = dict(params)
        function_name = params.pop("function")
        score_prediction = function_name in self._SCORE_BASED_METRICS

        prediction_dims = predictions.dim()
        target_dims = targets.dim()
        # Rounding is only meaningful for floating point outputs; integer
        # (or boolean) predictions are already discrete.
        needs_rounding = torch.is_floating_point(predictions)

        if score_prediction:
            # Binary classification with one output column per class: the
            # metric expects the score of the positive class only. Outputs
            # with another number of columns are passed on unchanged.
            if prediction_dims == 2 and target_dims == 1 and predictions.shape[1] == 2:
                predictions = predictions[:, 1]
        elif prediction_dims > 1 and target_dims == 1:
            # Multi-class classification: pick the most likely class.
            predictions = torch.argmax(predictions, dim=1)
        elif target_dims > 1 or prediction_dims == 1:
            # Multi-label targets, or a single continuous score per sample:
            # threshold the outputs by rounding them to 0/1.
            if needs_rounding:
                predictions = torch.round(predictions, decimals=0)

        metric = METRIC_MAPPING[function_name]

        score_dict = {
            "0_total": metric(targets, predictions, **params)
        }

        if params.get("average") is not None:
            # Re-evaluate without averaging to obtain one score per class.
            separate = metric(targets, predictions, **{**params, "average": None})

            # The metric functions return numpy arrays, which are not
            # registered as ``Sequence``; accept any non-scalar array-like too.
            if isinstance(separate, Sequence) or getattr(separate, "ndim", 0) > 0:
                names = class_names if class_names is not None else self.class_names
                if names is None:
                    names = [f"class_{i + 1}" for i in range(len(separate))]
                for name, score in zip(names, separate):
                    score_dict[name] = score

        return score_dict

    def calc_test_scores(self, targets: torch.Tensor, predictions: torch.Tensor) -> tuple[
        float, dict[str, dict[str, float]]]:
        """
        Calculate all configured metrics for the passed predictions and targets.

        Given the input of targets and predictions this function calculates the total score and the separate scores
        for each class for each of the metrics.

        :param targets: True labels
        :param predictions: Model outputs (scores, probabilities or labels)
        :return: The early stopping score and a dict with a dict for each score containing the total and separate
            scores for each class. The early stopping score is the (optionally negated) total of the criterion
            metric, or None if the criterion is ``"loss"``.
        :raises KeyError: If the criterion is neither ``"loss"`` nor one of the configured metrics
        """
        # Detach tensors from the computation graph and move them to CPU
        predictions = predictions.detach().cpu()
        targets = targets.detach().cpu()

        # Calculate scores for each metric
        scores = {
            name: self.calc_score(targets, predictions, params)
            for name, params in self.metric_params.items()
        }

        if self.criterion == "loss":
            # The loss is tracked by the caller; there is no metric to report.
            return None, scores

        if self.criterion not in scores:
            raise KeyError(
                f"Criterion '{self.criterion}' is not among the configured metrics: {list(scores)}"
            )

        criterion_total = scores[self.criterion]["0_total"]
        early_stopping_score = -criterion_total if self.invert else criterion_total

        return early_stopping_score, scores
