from typing import Callable, Tuple, Dict, Any

import numpy as np
import torch
from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc

from .ts_precision_recall import constant_bias_fn, inverse_proportional_cardinality_fn, ts_precision_and_recall, \
    improved_cardinality_fn, compute_window_indices


class Evaluator:
    """
    A class that can compute several evaluation metrics for a dataset. Each method must return the score as a single float,
    but it can also return additional information in a dict.

    Point-wise metrics (``auc``, ``f1_score``, ``best_fbeta_score``, ``auprc``, ``average_precision``) delegate to
    scikit-learn. The ``ts_*`` metrics evaluate range-based precision/recall through ``ts_precision_and_recall`` at every
    distinct score value used as a threshold.
    """
    def __init__(self, precision: int = 2):
        super().__init__()

        # NOTE(review): not read by any metric method in this class; presumably the number of decimal places callers
        # use when reporting scores — confirm.
        self.precision = precision

    @staticmethod
    def _to_numpy(x: Any) -> np.ndarray:
        """Convert a tensor (possibly on GPU or requiring grad) or an array-like to a NumPy array."""
        if isinstance(x, torch.Tensor):
            # ``.numpy()`` alone fails for CUDA tensors and tensors that require grad.
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def auc(self, labels: torch.Tensor, scores: torch.Tensor) -> Tuple[float, Dict[str, Any]]:
        """Compute the ROC AUC score over the dataset.

        :param labels: Ground-truth binary labels.
        :type labels: torch.Tensor
        :param scores: Scores for the positive class, passed to ``sklearn.metrics.roc_auc_score``.
        :type scores: torch.Tensor
        :return: ROC AUC score and an empty info dict.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        # float() normalizes numpy scalars and plain floats alike, as required by the class contract.
        return float(roc_auc_score(self._to_numpy(labels), self._to_numpy(scores))), {}

    def f1_score(self, labels: torch.Tensor, scores: torch.Tensor, pos_label: int = 1) -> Tuple[float, Dict[str, Any]]:
        """Compute the F1 score for the predictions.

        :param labels: Ground-truth labels.
        :type labels: torch.Tensor
        :param scores: Predicted labels. They are passed to ``sklearn.metrics.f1_score`` as-is, so they must already be
            discrete class labels, not continuous scores.
        :type scores: torch.Tensor
        :param pos_label: Class to report.
        :type pos_label: int
        :return: F1 score and an empty info dict.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        # ``.item()`` fails when scikit-learn returns a plain Python float (recent versions); float() handles both.
        return float(f1_score(self._to_numpy(labels), self._to_numpy(scores), pos_label=pos_label)), {}

    def best_fbeta_score(self, labels: torch.Tensor, scores: torch.Tensor, beta: float = 1) -> Tuple[float, Dict[str, Any]]:
        """Find the threshold on ``scores`` that maximizes the point-wise F_beta score.

        Candidate thresholds are those returned by ``sklearn.metrics.precision_recall_curve``, i.e. a point counts as
        positive when its score is ``>=`` the threshold.

        :param labels: Ground-truth binary labels.
        :type labels: torch.Tensor
        :param scores: Continuous scores for the positive class.
        :type scores: torch.Tensor
        :param beta: Weight of recall relative to precision in the F-score.
        :type beta: float
        :return: Best F_beta score and a dict with the corresponding ``threshold``.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        precision, recall, thresholds = precision_recall_curve(self._to_numpy(labels), self._to_numpy(scores))

        # 0/0 happens where precision and recall are both 0; silence the warning and map the NaN to a score of 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            f_score = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)
        f_score = np.nan_to_num(f_score, nan=0)

        # The final (precision=1, recall=0) point has no associated threshold and always scores 0, so it is excluded
        # to keep ``best_index`` a valid index into ``thresholds``.
        best_index = np.argmax(f_score[:-1])

        return float(f_score[best_index]), dict(threshold=float(thresholds[best_index]))

    def best_f1_score(self, labels: torch.Tensor, scores: torch.Tensor) -> Tuple[float, Dict[str, Any]]:
        """Find the threshold on ``scores`` that maximizes the point-wise F1 score (``best_fbeta_score`` with beta=1)."""
        return self.best_fbeta_score(labels, scores, 1)

    def auprc(self, labels: torch.Tensor, scores: torch.Tensor, integration='trapezoid') -> Tuple[float, Dict[str, Any]]:
        """Compute the area under the point-wise precision-recall curve.

        :param labels: Ground-truth binary labels.
        :type labels: torch.Tensor
        :param scores: Continuous scores for the positive class.
        :type scores: torch.Tensor
        :param integration: ``'riemann'`` sums recall decrements weighted by precision (the average-precision formula
            used by scikit-learn). Any other value uses trapezoidal integration via ``sklearn.metrics.auc``.
        :type integration: str
        :return: Area under the curve and an empty info dict.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        precision, recall, _ = precision_recall_curve(self._to_numpy(labels), self._to_numpy(scores))
        # recall is nan in the case where all ground-truth labels are 0. Simply set it to zero here
        # so that it does not contribute to the area
        recall = np.nan_to_num(recall, nan=0)

        if integration == 'riemann':
            # recall is returned in decreasing order, so diff(recall) <= 0; negate to obtain a positive area.
            area = -np.sum(np.diff(recall) * precision[:-1])
        else:
            area = auc(recall, precision)

        # ``auc`` may return a plain Python float in recent scikit-learn versions, which has no ``.item()``.
        return float(area), {}

    def average_precision(self, labels: torch.Tensor, scores: torch.Tensor) -> Tuple[float, Dict[str, Any]]:
        """Compute point-wise average precision (``auprc`` with Riemann integration)."""
        return self.auprc(labels, scores, integration='riemann')

    def ts_auprc(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                 recall_bias_fn: Callable = constant_bias_fn,
                 recall_cardinality_fn: Callable = improved_cardinality_fn,
                 precision_bias_fn: Callable = None,
                 precision_cardinality_fn: Callable = None,
                 integration='trapezoid',
                 weighted_precision: bool = False) \
            -> Tuple[float, Dict[str, Any]]:
        """Compute the AUPRC score over the dataset of time series.

        Every distinct value in ``scores`` is used as a threshold (a point is predicted anomalous when its score is
        ``>=`` the threshold). Range-based precision and recall are computed at each threshold. Labels and scores are
        presumably tensors of the same shape, e.g. (batch_size, length) — TODO confirm against
        ``compute_window_indices``.

        :param alpha: Passed through to ``ts_precision_and_recall`` (presumably the existence-reward weight of
            range-based recall — confirm).
        :type alpha: float
        :param recall_bias_fn: Function that computes the bias term for a batch of segments.
        :type recall_bias_fn: Callable
        :param recall_cardinality_fn: Function that computes the cardinality for a batch of segments.
        :type recall_cardinality_fn: Callable
        :param precision_bias_fn: Function that computes the bias term for a batch of segments. If None, then
            recall_bias_fn will be used
        :type precision_bias_fn: Callable
        :param precision_cardinality_fn: Function that computes the cardinality for a batch of segments. If None, then
            recall_cardinality_fn will be used
        :type precision_cardinality_fn: Callable
        :param integration: ``'riemann'`` for the average-precision sum; any other value uses trapezoidal integration.
        :type integration: str
        :param weighted_precision: If True, the precision score of a predicted window will be weighted with the
            length of the window in the final score. Otherwise, each window will have the same weight.
        :type weighted_precision: bool
        :return: AUPRC score and an empty info dict.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        thresholds = torch.unique(input=scores, sorted=True)

        # One extra slot at the end for the "threshold = +inf" point.
        precision = torch.empty(thresholds.shape[0] + 1, dtype=torch.float, device=thresholds.device)
        recall = torch.empty(thresholds.shape[0] + 1, dtype=torch.float, device=thresholds.device)
        predictions = torch.empty_like(scores, dtype=torch.long)

        # Set last values when threshold is at infinity so that no point is predicted as anomalous.
        # Precision is not defined in this case, we set it to 1 to stay consistent with scikit-learn
        precision[-1] = 1
        recall[-1] = 0

        # Derived from the labels only, so it is computed once outside the threshold loop.
        label_ranges = compute_window_indices(labels)

        for i, t in enumerate(thresholds):
            torch.greater_equal(scores, t, out=predictions)
            prec, rec = ts_precision_and_recall(labels, predictions, alpha,
                                                recall_bias_fn, recall_cardinality_fn, precision_bias_fn,
                                                precision_cardinality_fn, anomaly_ranges=label_ranges,
                                                weighted_precision=weighted_precision)
            precision[i] = prec
            recall[i] = rec

        if integration == 'riemann':
            # Thresholds ascend, so recall is expected to be non-increasing; negate to obtain a positive area.
            area = -torch.sum(torch.diff(recall) * precision[:-1])
        else:
            # scikit-learn needs host NumPy arrays; ``.numpy()`` alone fails for CUDA tensors.
            area = auc(self._to_numpy(recall), self._to_numpy(precision))

        # float() works for both a 0-d tensor and whatever scalar type ``auc`` returns.
        return float(area), {}

    def ts_average_precision(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                             recall_bias_fn: Callable = constant_bias_fn,
                             recall_cardinality_fn: Callable = improved_cardinality_fn,
                             precision_bias_fn: Callable = None,
                             precision_cardinality_fn: Callable = None,
                             weighted_precision: bool = False
                             ) -> Tuple[float, Dict[str, Any]]:
        """Compute the Average Precision score over the dataset of time series.

        This is ``ts_auprc`` with Riemann (step-wise) integration.

        :param alpha: Passed through to ``ts_precision_and_recall``.
        :type alpha: float
        :param recall_bias_fn: Function that computes the bias term for a batch of segments.
        :type recall_bias_fn: Callable
        :param recall_cardinality_fn: Function that computes the cardinality for a batch of segments.
        :type recall_cardinality_fn: Callable
        :param precision_bias_fn: Function that computes the bias term for a batch of segments. If None, then
            recall_bias_fn will be used
        :type precision_bias_fn: Callable
        :param precision_cardinality_fn: Function that computes the cardinality for a batch of segments. If None, then
            recall_cardinality_fn will be used
        :type precision_cardinality_fn: Callable
        :param weighted_precision: If True, the precision score of a predicted window will be weighted with the
            length of the window in the final score. Otherwise, each window will have the same weight.
        :type weighted_precision: bool
        :return: Average Precision score and an empty info dict.
        :rtype: Tuple[float, Dict[str, Any]]
        """

        return self.ts_auprc(labels, scores, alpha, recall_bias_fn, recall_cardinality_fn, precision_bias_fn,
                             precision_cardinality_fn, integration='riemann', weighted_precision=weighted_precision)

    def ts_auprc_v3(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                    recall_bias_fn: Callable = constant_bias_fn,
                    recall_cardinality_fn: Callable = improved_cardinality_fn,
                    precision_bias_fn: Callable = None,
                    precision_cardinality_fn: Callable = None,
                    integration='trapezoid') \
            -> Tuple[float, Dict[str, Any]]:
        """Compute ``ts_auprc`` with length-weighted precision (``weighted_precision=True``)."""
        return self.ts_auprc(labels, scores, alpha=alpha,
                             recall_bias_fn=recall_bias_fn,
                             recall_cardinality_fn=recall_cardinality_fn,
                             precision_bias_fn=precision_bias_fn,
                             precision_cardinality_fn=precision_cardinality_fn,
                             integration=integration,
                             weighted_precision=True)

    def best_ts_f1_score(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                         recall_bias_fn: Callable = constant_bias_fn,
                         recall_cardinality_fn: Callable = inverse_proportional_cardinality_fn,
                         precision_bias_fn: Callable = None,
                         precision_cardinality_fn: Callable = None,
                         weighted_precision: bool = False
                         ) -> Tuple[float, Dict[str, Any]]:
        """Compute the best range-based F1 score (``best_ts_fbeta_score`` with beta=1).

        :param weighted_precision: Forwarded to ``best_ts_fbeta_score``; defaults to False, matching the previous
            behavior of this method.
        :type weighted_precision: bool
        """
        return self.best_ts_fbeta_score(labels, scores, alpha, 1, recall_bias_fn, recall_cardinality_fn, precision_bias_fn,
                                        precision_cardinality_fn, weighted_precision=weighted_precision)

    def best_ts_fbeta_score(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0, beta: float = 1,
                            recall_bias_fn: Callable = constant_bias_fn,
                            recall_cardinality_fn: Callable = inverse_proportional_cardinality_fn,
                            precision_bias_fn: Callable = None,
                            precision_cardinality_fn: Callable = None,
                            weighted_precision: bool = False
                            ) -> Tuple[float, Dict[str, Any]]:
        """Compute the best F_beta score over the dataset of time series.

        Every distinct value in ``scores`` is tried as a threshold. A point is predicted anomalous when its score is
        strictly greater than the threshold, so the largest threshold yields no positive predictions.

        :param alpha: Passed through to ``ts_precision_and_recall``.
        :type alpha: float
        :param beta: Weight of recall relative to precision in the F-score.
        :type beta: float
        :param recall_bias_fn: Function that computes the bias term for a batch of segments.
        :type recall_bias_fn: Callable
        :param recall_cardinality_fn: Function that computes the cardinality for a batch of segments.
        :type recall_cardinality_fn: Callable
        :param precision_bias_fn: Function that computes the bias term for a batch of segments. If None, then
            recall_bias_fn will be used
        :type precision_bias_fn: Callable
        :param precision_cardinality_fn: Function that computes the cardinality for a batch of segments. If None, then
            recall_cardinality_fn will be used
        :type precision_cardinality_fn: Callable
        :param weighted_precision: If True, the precision score of a predicted window will be weighted with the
            length of the window in the final score. Otherwise, each window will have the same weight.
        :type weighted_precision: bool
        :return: Best F_beta score and a dict with the corresponding ``threshold``, ``precision`` and ``recall``.
        :rtype: Tuple[float, Dict[str, Any]]
        """
        thresholds = torch.unique(input=scores, sorted=True)

        precision = torch.empty_like(thresholds, dtype=torch.float)
        recall = torch.empty_like(thresholds, dtype=torch.float)
        predictions = torch.empty_like(scores, dtype=torch.long)

        # Derived from the labels only, so it is computed once outside the threshold loop.
        label_ranges = compute_window_indices(labels)

        for i, t in enumerate(thresholds):
            torch.greater(scores, t, out=predictions)
            prec, rec = ts_precision_and_recall(labels, predictions, alpha,
                                                recall_bias_fn, recall_cardinality_fn, precision_bias_fn,
                                                precision_cardinality_fn, anomaly_ranges=label_ranges,
                                                weighted_precision=weighted_precision)
            # Store the true values: they are reported in the info dict for the selected threshold.
            precision[i] = prec
            recall[i] = rec

        # Precision and recall can both be 0 (an extremely bad classifier, or all predictions 0). The F-score is
        # defined as 0 there, so guard the division instead of overwriting the stored recall.
        numerator = (1 + beta ** 2) * precision * recall
        denominator = beta ** 2 * precision + recall
        f_score = torch.where(denominator > 0, numerator / denominator, torch.zeros_like(numerator))
        max_score_index = torch.argmax(f_score)

        return f_score[max_score_index].item(), dict(threshold=thresholds[max_score_index].item(),
                                                     precision=precision[max_score_index].item(),
                                                     recall=recall[max_score_index].item())

    def best_ts_fbeta_score_v2(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0, beta: float = 1,
                               recall_bias_fn: Callable = constant_bias_fn,
                               precision_bias_fn: Callable = None,
                               precision_cardinality_fn: Callable = None,
                               weighted_precision: bool = False
                               ) -> Tuple[float, Dict[str, Any]]:
        """Compute ``best_ts_fbeta_score`` with ``improved_cardinality_fn`` as the recall cardinality function."""
        return self.best_ts_fbeta_score(labels, scores, alpha, beta, recall_bias_fn, improved_cardinality_fn,
                                        precision_bias_fn, precision_cardinality_fn,
                                        weighted_precision=weighted_precision)

    def best_ts_f1_score_v2(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                            recall_bias_fn: Callable = constant_bias_fn,
                            precision_bias_fn: Callable = None,
                            precision_cardinality_fn: Callable = None,
                            weighted_precision: bool = False
                            ) -> Tuple[float, Dict[str, Any]]:
        """Compute ``best_ts_fbeta_score_v2`` with beta=1."""
        return self.best_ts_fbeta_score_v2(labels, scores, alpha, 1, recall_bias_fn, precision_bias_fn,
                                           precision_cardinality_fn, weighted_precision=weighted_precision)

    def best_ts_fbeta_score_v3(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0, beta: float = 1,
                               recall_bias_fn: Callable = constant_bias_fn,
                               precision_bias_fn: Callable = None,
                               precision_cardinality_fn: Callable = None
                               ) -> Tuple[float, Dict[str, Any]]:
        """Compute ``best_ts_fbeta_score_v2`` with length-weighted precision (``weighted_precision=True``)."""
        return self.best_ts_fbeta_score_v2(labels, scores, alpha, beta, recall_bias_fn, precision_bias_fn,
                                           precision_cardinality_fn, weighted_precision=True)

    def best_ts_f1_score_v3(self, labels: torch.Tensor, scores: torch.Tensor, alpha: float = 0,
                            recall_bias_fn: Callable = constant_bias_fn,
                            precision_bias_fn: Callable = None,
                            precision_cardinality_fn: Callable = None
                            ) -> Tuple[float, Dict[str, Any]]:
        """Compute ``best_ts_fbeta_score_v3`` with beta=1."""
        return self.best_ts_fbeta_score_v3(labels, scores, alpha, 1, recall_bias_fn, precision_bias_fn,
                                           precision_cardinality_fn)