import collections
import statistics
from abc import ABC, abstractmethod
from functools import reduce
from typing import Iterable

import numpy as np
import sklearn.metrics as sm
from sklearn.metrics import roc_auc_score


class PredictionFilter(ABC):
    """Abstract strategy that narrows raw model outputs down to the kept predictions.

    Concrete filters implement ``filter`` and expose a short ``identifier`` string that
    evaluators embed in their report keys.
    """

    @abstractmethod
    def filter(self, predictions, return_mode):
        """Return the kept predictions, in the representation selected by ``return_mode``."""

    @property
    @abstractmethod
    def identifier(self):
        """Short text tag describing this filter's configuration."""


class TopKPredictionFilter(PredictionFilter):
    """Keep the k predictions with the highest confidence for every sample."""

    def __init__(self, k: int, prediction_mode='prob'):
        """
        Args:
            k: k predictions with highest confidence
            prediction_mode: can be 'indices' or 'prob', indicating whether the predictions
            are a set of class indices or predicted probabilities.
        """
        assert k >= 0
        assert prediction_mode == 'prob' or prediction_mode == 'indices', \
            f"Prediction mode {prediction_mode} is not supported!"

        self.prediction_mode = prediction_mode
        self.k = k

    def filter(self, predictions, return_mode):
        """Select the top-k predictions of every sample.

        Args:
            predictions: array of shape (N, num_class) with scores when ``prediction_mode`` is
                'prob'; when it is 'indices', an (N, m) array of class indices whose first k
                columns are kept (so they are assumed to be ordered by confidence).
            return_mode: 'indices' returns a list of N integer arrays holding the selected
                class indices (not sorted when k > 1); any other value returns a boolean
                mask shaped like ``predictions``.
        """
        n_sample = predictions.shape[0]
        # k can never exceed the number of available columns.
        k = min(predictions.shape[1], self.k)

        if self.prediction_mode == 'prob':
            if k == 0:
                # One empty selection per *sample* (rows), not per class.
                top_k_pred_indices = np.empty((n_sample, 0), dtype=int)
            elif k == 1:
                top_k_pred_indices = np.argmax(predictions, axis=1).reshape((-1, 1))
            else:
                # argpartition is O(num_class) per row; the k winners are not ordered.
                top_k_pred_indices = np.argpartition(predictions, -k, axis=1)[:, -k:]
        else:
            top_k_pred_indices = predictions[:, :k]

        if return_mode == 'indices':
            return list(top_k_pred_indices)

        # NOTE(review): in 'indices' prediction mode the mask has the shape of the index
        # array, so class indices >= its width cannot be represented here.
        preds = np.zeros_like(predictions, dtype=bool)
        row_index = np.repeat(np.arange(n_sample), k)
        col_index = top_k_pred_indices.reshape(-1)
        preds[row_index, col_index] = True

        return preds

    @property
    def identifier(self):
        return f'top{self.k}'


class ThresholdPredictionFilter(PredictionFilter):
    """Keep every prediction whose confidence reaches a fixed threshold."""

    def __init__(self, threshold: float):
        """
        Args:
            threshold: confidence threshold; scores greater than or equal to it are kept
        """
        self.threshold = threshold

    def filter(self, predictions, return_mode):
        """Keep predictions whose confidence is at least the threshold.

        Args:
            predictions: the model output numpy array. Shape (N, num_class)
            return_mode: 'indices' for a per-sample list of kept class indices; any other
                value returns the boolean mask ``predictions >= threshold``
        Returns:
            the kept labels for each sample, in the requested representation
        """
        keep_mask = predictions >= self.threshold
        if return_mode != 'indices':
            return keep_mask

        kept_per_sample = [[] for _ in range(len(predictions))]
        for position in np.argwhere(keep_mask):
            kept_per_sample[position[0]].append(position[1])
        return kept_per_sample

    @property
    def identifier(self):
        return f'thres={self.threshold}'


def _indices_to_vec(indices, length):
    target_vec = np.zeros(length, dtype=int)
    target_vec[indices] = 1

    return target_vec


def _targets_to_mat(targets, n_class):
    if len(targets.shape) == 1:
        target_mat = np.zeros((len(targets), n_class), dtype=int)
        for i, t in enumerate(targets):
            target_mat[i, t] = 1
    else:
        target_mat = targets

    return target_mat


class Evaluator(ABC):
    """Abstract base for objects that accumulate model outputs and report metrics."""

    def __init__(self):
        # Set before reset() so the attribute exists even if a subclass's reset()
        # does not call the base implementation.
        self.custom_fields = {}
        self.reset()

    @abstractmethod
    def add_predictions(self, predictions, targets):
        """Accumulate one batch of predictions together with its targets."""
        raise NotImplementedError

    @abstractmethod
    def get_report(self, **kwargs):
        """Return a dict that maps metric names to values."""
        raise NotImplementedError

    def add_custom_field(self, name, value):
        """Record an extra field; the value is stored in its ``str`` form."""
        text_value = str(value)
        self.custom_fields[name] = text_value

    def reset(self):
        """Drop accumulated state; at this level only the custom fields are cleared."""
        self.custom_fields = {}


class EvaluatorAggregator(Evaluator):
    """Evaluator that forwards every call to a list of child evaluators and merges their reports."""

    def __init__(self, evaluators):
        """
        Args:
            evaluators: child evaluators that receive every batch and contribute to the report
        """
        # Must be set before Evaluator.__init__, which calls self.reset().
        self.evaluators = evaluators
        super(EvaluatorAggregator, self).__init__()

    def add_predictions(self, predictions, targets):
        for evaluator in self.evaluators:
            evaluator.add_predictions(predictions, targets)

    def get_report(self, **kwargs):
        """Merge the children's reports into a new dict.

        On key collisions the later evaluator wins. The children's own report dicts are not
        modified, and an aggregator with no children reports an empty dict.
        """
        report = {}
        for evaluator in self.evaluators:
            report.update(evaluator.get_report(**kwargs))
        return report

    def reset(self):
        # Clear this aggregator's own custom fields as the base contract requires,
        # then reset every child.
        super(EvaluatorAggregator, self).reset()
        for evaluator in self.evaluators:
            evaluator.reset()


class MemorizingEverythingEvaluator(Evaluator, ABC):
    """
    Base evaluator that keeps every batch of targets and predictions and scores them all at once.

    Subclasses provide ``_calculate`` (the metric on the stacked arrays) and ``_get_id``
    (the key used in the report).
    """

    def __init__(self, prediction_filter=None):
        """
        Args:
            prediction_filter: optional filter; when set, each batch of predictions is replaced
                by its 'vec' representation before being stored
        """
        # Storage exists before Evaluator.__init__ triggers reset().
        self.all_targets = np.array([])
        self.all_predictions = np.array([])

        super(MemorizingEverythingEvaluator, self).__init__()
        self.prediction_filter = prediction_filter

    def reset(self):
        super(MemorizingEverythingEvaluator, self).reset()
        self.all_predictions = np.array([])
        self.all_targets = np.array([])

    def add_predictions(self, predictions, targets):
        """ Store a batch of predictions for later evaluation.
        Args:
            predictions: the model output array. Shape (N, num_class)
            targets: the ground truths. Shape (N, num_class) for multi-label or (N,) for multi-class
        """
        assert len(predictions) == len(targets)

        if self.prediction_filter:
            predictions = self.prediction_filter.filter(predictions, 'vec')

        def _accumulate(stored, batch):
            # An empty store is replaced by a copy so later appends never alias the caller's array.
            return np.append(stored, batch, axis=0) if stored.size != 0 else batch.copy()

        self.all_predictions = _accumulate(self.all_predictions, predictions)
        self.all_targets = _accumulate(self.all_targets, targets)

    def calculate_score(self, average='macro', filter_out_zero_tgt=True):
        """
        average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:
        ``'micro'``:
            Calculate metrics globally by considering each element of the label
            indicator matrix as a label.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean.  This does not take label imbalance into account.
        ``'weighted'``:
            Calculate metrics for each label, and find their average, weighted
            by support (the number of true instances for each label).
        ``'samples'``:
            Calculate metrics for each instance, and find their average.
        filter_out_zero_tgt : bool
        Drop target columns that are entirely zero before scoring. Precision needs this
            to be False; otherwise false positives on those classes would be discarded.

        Returns 0.0 when nothing has been stored or no column survives the filtering.
        """
        if self.all_predictions.size == 0:
            return 0.0

        tar_mat = _targets_to_mat(self.all_targets, self.all_predictions.shape[1])
        assert tar_mat.size == self.all_predictions.size
        if tar_mat.size == 0:
            return 0.0

        if filter_out_zero_tgt:
            kept_columns = np.flatnonzero(np.any(tar_mat != 0, axis=0))
        else:
            kept_columns = np.arange(tar_mat.shape[1])

        if kept_columns.size == 0:
            return 0.0
        return self._calculate(tar_mat[:, kept_columns], self.all_predictions[:, kept_columns],
                               average=average)

    @abstractmethod
    def _calculate(self, targets, predictions, average):
        pass

    @abstractmethod
    def _get_id(self):
        pass

    def get_report(self, **kwargs):
        average = kwargs.get('average', 'macro')
        score = self.calculate_score(average)
        return {self._get_id(): score}


class TopKAccuracyEvaluator(Evaluator):
    """
    Top k accuracy evaluator for multiclass classification: a sample counts as correct when
    its target class is among the k highest-scoring classes.
    """

    def __init__(self, k: int):
        assert k > 0
        self.total_num = 0
        self.topk_correct_num = 0

        super(TopKAccuracyEvaluator, self).__init__()
        self.prediction_filter = TopKPredictionFilter(k)

    def reset(self):
        super(TopKAccuracyEvaluator, self).reset()
        self.total_num = 0
        self.topk_correct_num = 0

    def add_predictions(self, predictions, targets):
        """ Count top-k hits in a batch.
        Args:
            predictions: the model output numpy array. Shape (N, num_class)
            targets: the ground-truth class indices. Shape (N,)
        """
        assert len(predictions) == len(targets)
        assert len(targets.shape) == 1

        candidates = self.prediction_filter.filter(predictions, 'indices')
        hits = 0
        for sample_idx, target in enumerate(targets):
            if target in candidates[sample_idx]:
                hits += 1

        self.topk_correct_num += hits
        self.total_num += len(predictions)

    def get_report(self, **kwargs):
        key = f'accuracy_{self.prediction_filter.identifier}'
        if not self.total_num:
            return {key: 0.0}
        return {key: float(self.topk_correct_num) / self.total_num}


class F1ScoreEvaluator(EvaluatorAggregator):
    """
    F1 score evaluator for both multi-class and multi-label classification. The report also
    contains the precision and recall the F1 score is derived from.
    """

    def __init__(self, prediction_filter):
        super().__init__([RecallEvaluator(prediction_filter), PrecisionEvaluator(prediction_filter)])
        self._filter_id = prediction_filter.identifier

    def get_report(self, **kwargs):
        average = kwargs.get('average', 'macro')
        report = super(F1ScoreEvaluator, self).get_report(average=average)

        precision = report[f'precision_{self._filter_id}']
        recall = report[f'recall_{self._filter_id}']
        denominator = precision + recall
        # The harmonic mean is undefined when both inputs are 0; report 0.0 instead.
        if denominator > 0:
            f1 = 2 * (precision * recall) / denominator
        else:
            f1 = 0.0
        report[f'f1_score_{self._filter_id}'] = f1

        return report


class PrecisionEvaluator(MemorizingEverythingEvaluator):
    """
    Precision evaluator for both multi-class and multi-label classification
    """

    def __init__(self, prediction_filter):
        """
        Args:
            prediction_filter: filter turning scores into hard predictions; its identifier is
                part of the report key
        """
        super().__init__(prediction_filter)

    def _get_id(self):
        filter_id = self.prediction_filter.identifier
        return f'precision_{filter_id}'

    def _calculate(self, targets, predictions, average):
        score = sm.precision_score(targets, predictions, average=average)
        return score

    def get_report(self, **kwargs):
        # All-zero target columns are kept: dropping them would discard false positives.
        score = self.calculate_score(average=kwargs.get('average', 'macro'), filter_out_zero_tgt=False)
        return {self._get_id(): score}


class RecallEvaluator(MemorizingEverythingEvaluator):
    """
    Recall evaluator for both multi-class and multi-label classification
    """

    def __init__(self, prediction_filter):
        """
        Args:
            prediction_filter: filter turning scores into hard predictions; its identifier is
                part of the report key
        """
        super().__init__(prediction_filter)

    def _get_id(self):
        filter_id = self.prediction_filter.identifier
        return f'recall_{filter_id}'

    def _calculate(self, targets, predictions, average):
        score = sm.recall_score(targets, predictions, average=average)
        return score


class AveragePrecisionEvaluator(MemorizingEverythingEvaluator):
    """
    Average Precision evaluator for both multi-class and multi-label classification
    """

    def __init__(self):
        super().__init__()

    def _get_id(self):
        return 'average_precision'

    def calculate_score(self, average='macro'):
        """Average precision over all memorized predictions.

        With ``average='macro'`` the AP is computed one class at a time and averaged over the
        classes that have at least one positive target; classes without positives are skipped.
        Any other ``average`` is delegated to the base implementation (sklearn averaging).
        """
        if average != 'macro':
            return super().calculate_score(average=average)

        if self.all_targets.size == 0:
            return 0.0

        ap_sum = 0.0
        n_class_with_gt = 0
        for c_idx in range(self.all_predictions.shape[1]):
            class_target_vec = self._class_target_vec(self.all_targets, c_idx)
            if class_target_vec is None:
                continue
            ap_sum += sm.average_precision_score(class_target_vec, self.all_predictions[:, c_idx])
            n_class_with_gt += 1

        return ap_sum / n_class_with_gt if n_class_with_gt > 0 else 0.0

    @staticmethod
    def _class_target_vec(targets, c_idx):
        """Vectorized 0/1 target vector for class ``c_idx``, or None if the class has no positives.

        1-D targets are class labels (positive where label == c_idx). 2-D targets are a label
        matrix (positive where the entry is non-zero); columns beyond its width have no positives.
        """
        if targets.ndim == 1:
            positives = targets == c_idx
        elif c_idx < targets.shape[1]:
            positives = targets[:, c_idx] != 0
        else:
            return None
        if not positives.any():
            return None
        return positives.astype(int)

    def _calculate(self, targets, predictions, average):
        return sm.average_precision_score(targets, predictions, average=average)


class RocAucEvaluator(Evaluator):
    """
    Utilize sklearn.metrics.roc_auc_score to Compute Area Under the Receiver Operating
    Characteristic Curve (ROC AUC) from prediction scores.
    """

    def __init__(self):
        super(RocAucEvaluator, self).__init__()
        self.all_targets = None
        self.all_predictions = None

    def reset(self):
        super(RocAucEvaluator, self).reset()
        self.all_targets = None
        self.all_predictions = None

    def add_predictions(self, predictions, targets):
        """ add predictions and targets.
        Args:
            predictions: predictions of array-like of shape (n_samples,) or (n_samples, n_classes)
            targets: targets of array-like of shape (n_samples,) or (n_samples, n_classes)
        """
        targets = np.array(targets)
        predictions = np.array(predictions)
        # Test for None explicitly: the truth value of a multi-element numpy array is ambiguous
        # (raises), and a one-element falsy array such as [0] would be silently overwritten.
        self.all_targets = targets if self.all_targets is None \
            else np.concatenate([self.all_targets, targets])
        self.all_predictions = predictions if self.all_predictions is None \
            else np.concatenate([self.all_predictions, predictions])

    def get_report(self, **kwargs):
        """Compute ROC AUC over everything added so far.

        Keyword arguments ``average``, ``sample_weight``, ``max_fpr``, ``multi_class`` and
        ``labels`` are forwarded to ``sklearn.metrics.roc_auc_score``. For 1-D targets with a
        two-column score matrix, only the second (positive-class) column is scored.
        """
        average = kwargs.get('average', 'macro')
        sample_weight = kwargs.get('sample_weight')
        max_fpr = kwargs.get('max_fpr')
        multi_class = kwargs.get('multi_class', 'raise')
        labels = kwargs.get('labels')

        if len(self.all_targets.shape) == 1 and \
                len(self.all_predictions.shape) == 2 and self.all_predictions.shape[1] == 2:
            all_predictions = self.all_predictions[:, 1]
        else:
            all_predictions = self.all_predictions
        return {
            'roc_auc': sm.roc_auc_score(y_true=self.all_targets, y_score=all_predictions, average=average,
                                        sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class,
                                        labels=labels)
        }


class MeanAveragePrecisionEvaluatorForSingleIOU(Evaluator):
    """Object-detection mean average precision (mAP) at one IoU threshold.

    Each ``add_predictions`` call matches predicted boxes to ground-truth boxes per class and
    per image, and appends the per-prediction correctness, scores and ground-truth counts to
    per-class accumulators. ``get_report`` turns those into per-class AP and their mean.
    """

    def __init__(self, iou=0.5, report_tag_wise=False):
        """
        Args:
            iou: float, single IoU for matching
            report_tag_wise: if assigned True, also return the per class average precision
        """
        # Evaluator.__init__ calls self.reset(), which creates the per-class accumulators.
        super(MeanAveragePrecisionEvaluatorForSingleIOU, self).__init__()
        self.iou = iou
        self.report_tag_wise = report_tag_wise

    def add_predictions(self, predictions, targets):
        """ Evaluate list of image with object detection results using single IOU evaluation.
        Args:
            predictions: list of predictions [[[label_idx, probability, L, T, R, B], ...], [...], ...]
            targets: list of image targets [[[label_idx, L, T, R, B], ...], ...]

        Image indices are positions within this call's lists, so predictions are only matched
        against ground truths of the same image in the same call.
        """

        assert len(predictions) == len(targets)

        # label -> [[img_idx, probability, L, T, R, B], ...]
        eval_predictions = collections.defaultdict(list)
        # label -> {img_idx: [[L, T, R, B], ...]}
        eval_ground_truths = collections.defaultdict(dict)
        for img_idx, prediction in enumerate(predictions):
            for bbox in prediction:
                label = int(bbox[0])
                eval_predictions[label].append(
                    [img_idx, float(bbox[1]), float(bbox[2]), float(bbox[3]), float(bbox[4]), float(bbox[5])])

        for img_idx, target in enumerate(targets):
            for bbox in target:
                label = int(bbox[0])
                if img_idx not in eval_ground_truths[label]:
                    eval_ground_truths[label][img_idx] = []
                eval_ground_truths[label][img_idx].append(
                    [float(bbox[1]), float(bbox[2]), float(bbox[3]), float(bbox[4])])

        # Every class seen in either predictions or targets. Indexing the defaultdicts below
        # yields an empty list/dict for classes present on only one side.
        class_indices = set(list(eval_predictions.keys()) + list(eval_ground_truths.keys()))
        for class_index in class_indices:
            is_correct, probabilities = self._evaluate_predictions(eval_ground_truths[class_index],
                                                                   eval_predictions[class_index], self.iou)
            true_num = sum([len(t) for t in eval_ground_truths[class_index].values()])

            self.is_correct[class_index].extend(is_correct)
            self.probabilities[class_index].extend(probabilities)
            self.true_num[class_index] += true_num

    @staticmethod
    def _calculate_area(rect):
        """Area of an [L, T, R, B] box (1e-5 is added to width and height); 0.0 if either is non-positive."""
        w = rect[2] - rect[0] + 1e-5
        h = rect[3] - rect[1] + 1e-5
        return float(w * h) if w > 0 and h > 0 else 0.0

    @staticmethod
    def _calculate_iou(rect0, rect1):
        """Intersection over union of two [L, T, R, B] boxes."""
        # NOTE(review): if both boxes have zero area the union is 0 and this raises
        # ZeroDivisionError.
        rect_intersect = [max(rect0[0], rect1[0]),
                          max(rect0[1], rect1[1]),
                          min(rect0[2], rect1[2]),
                          min(rect0[3], rect1[3])]
        calc_area = MeanAveragePrecisionEvaluatorForSingleIOU._calculate_area
        area_intersect = calc_area(rect_intersect)
        return area_intersect / (calc_area(rect0) + calc_area(rect1) - area_intersect)

    def _is_true_positive(self, prediction, ground_truth, already_detected, iou_threshold):
        """Greedily match one prediction to a ground-truth box of the same image.

        The prediction is compared only with its highest-IoU ground-truth box. If that IoU is
        below ``iou_threshold``, or that box was already claimed by an earlier (higher-scoring)
        prediction, the prediction is a false positive; there is no fallback to the next-best box.

        Returns:
            (is_true_positive, already_detected); ``already_detected`` is the same set object,
            updated in place with (image_id, box_index) on a match.
        """
        image_id = prediction[0]
        prediction_rect = prediction[2:6]
        if image_id not in ground_truth:
            return False, already_detected

        ious = np.array([self._calculate_iou(prediction_rect, g) for g in ground_truth[image_id]])
        best_bb = np.argmax(ious)
        best_iou = ious[best_bb]

        if best_iou < iou_threshold or (image_id, best_bb) in already_detected:
            return False, already_detected

        already_detected.add((image_id, best_bb))
        return True, already_detected

    def _evaluate_predictions(self, ground_truths, predictions, iou_threshold):
        """ Evaluate the correctness of the given predictions.
        Args:
            ground_truths: ground truths for the class, as a mapping img_idx -> [[L, T, R, B], ...].
            predictions: List of predictions for the class, each [img_idx, probability, L, T, R, B].
            iou_threshold: Minimum IOU threshold to be considered as a same bounding box.
        Returns:
            (is_correct, probabilities): numpy arrays aligned with the predictions sorted by
            descending probability.
        """

        # Sort the predictions by the probability (highest first) so higher-scoring predictions
        # claim ground-truth boxes first.
        sorted_predictions = sorted(predictions, key=lambda x: -x[1])
        already_detected = set()
        is_correct = []
        for prediction in sorted_predictions:
            correct, already_detected = self._is_true_positive(prediction, ground_truths, already_detected,
                                                               iou_threshold)
            is_correct.append(correct)

        is_correct = np.array(is_correct)
        probabilities = np.array([p[1] for p in sorted_predictions])

        return is_correct, probabilities

    @staticmethod
    def _calculate_average_precision(is_correct, probabilities, true_num, average='macro'):
        """AP of one class from its accumulated per-prediction correctness and scores.

        Returns 0 when the class has no ground-truth boxes or no correct prediction. Otherwise
        it returns sklearn's average precision over the predictions multiplied by the recall
        (matched predictions / ``true_num``), so unmatched ground-truth boxes lower the score.
        """
        if true_num == 0:
            return 0
        if not is_correct or not any(is_correct):
            return 0
        recall = float(np.sum(is_correct)) / true_num
        return sm.average_precision_score(is_correct, probabilities, average=average) * recall

    def get_report(self, **kwargs):
        """Report the mean AP over every class seen so far.

        Classes that only ever had predictions (no ground truth) get AP 0 and are included in
        the mean. ``self.aps`` is (re)filled as a side effect.
        """
        average = kwargs.get('average', 'macro')
        for class_index in self.is_correct:
            ap = MeanAveragePrecisionEvaluatorForSingleIOU._calculate_average_precision(
                self.is_correct[class_index],
                self.probabilities[class_index],
                self.true_num[class_index],
                average
            )
            self.aps[class_index] = ap

        mean_ap = float(statistics.mean([self.aps[x] for x in self.aps])) if self.aps else 0.0
        # NOTE(review): int() truncates, so float error can shift the key
        # (e.g. 0.29 * 100 == 28.999..., giving 'mAP_28').
        key_name = f'mAP_{int(self.iou * 100)}'
        report = {key_name: mean_ap}
        if self.report_tag_wise:
            # Ordered by self.aps insertion order, not necessarily by class index.
            report[f'tag_wise_AP_{int(self.iou * 100)}'] = [self.aps[class_index] for class_index in self.aps]
        return report

    def reset(self):
        """Clear the per-class accumulators (also invoked from Evaluator.__init__)."""
        self.is_correct = collections.defaultdict(list)
        self.probabilities = collections.defaultdict(list)
        self.true_num = collections.defaultdict(int)
        self.aps = collections.defaultdict(float)
        super(MeanAveragePrecisionEvaluatorForSingleIOU, self).reset()


class MeanAveragePrecisionEvaluatorForMultipleIOUs(EvaluatorAggregator):
    """Runs one single-IoU mAP evaluator per IoU threshold and merges their reports."""
    DEFAULT_IOU_VALUES = (0.3, 0.5, 0.75, 0.9)

    def __init__(self, ious=DEFAULT_IOU_VALUES, report_tag_wise=None):
        """
        Args:
            ious: IoU thresholds, one evaluator each
            report_tag_wise: per-threshold flags for per-class AP output; a falsy value
                disables it for every threshold
        """
        if not report_tag_wise:
            report_tag_wise = [False] * len(ious)

        assert len(ious) == len(report_tag_wise)
        per_iou_evaluators = [MeanAveragePrecisionEvaluatorForSingleIOU(iou, tag_wise)
                              for iou, tag_wise in zip(ious, report_tag_wise)]
        super(MeanAveragePrecisionEvaluatorForMultipleIOUs, self).__init__(per_iou_evaluators)


class BalancedAccuracyScoreEvaluator(MemorizingEverythingEvaluator):
    """
    Average of recall obtained on each class, for multiclass classification problem
    """

    def _calculate(self, targets, predictions, average):
        # ``average`` is part of the _calculate interface but unused: balanced accuracy
        # has a single definition.
        single_targets = np.argmax(targets, axis=1)
        y_single_preds = np.argmax(predictions, axis=1)
        return sm.balanced_accuracy_score(single_targets, y_single_preds)

    def _get_id(self):
        return 'balanced_accuracy'

    def get_report(self, **kwargs):
        """Report balanced accuracy computed over all class columns.

        Class columns without targets must not be dropped before the argmax. Otherwise a
        prediction whose top score belongs to such a class is re-routed to the best remaining
        class, and can be counted as correct.
        """
        average = kwargs.get('average', 'macro')
        return {self._get_id(): self.calculate_score(average=average, filter_out_zero_tgt=False)}


class PrecisionRecallCurveMixin:
    """
    N-point interpolated precision-recall curve, averaged over samples
    """

    def __init__(self, n_points=11):
        super().__init__()
        self.ap_n_points_eval = []
        self.n_points = n_points

    def _calc_precision_recall_interp(self, predictions, targets, recall_thresholds):
        """ Interpolated precision at each recall threshold.
        Args:
            predictions: the probability or score of the data to be 'positive'. Shape (N,)
            targets: the binary ground truths in {0, 1} or {-1, 1}. Shape (N,)
            recall_thresholds: recall levels, walked in order alongside the curve (given in
                decreasing order by the caller in this module)
        Returns:
            array where entry i is the maximum precision seen among curve points with
            recall >= recall_thresholds[i] consumed so far
        """
        assert len(predictions) == len(targets)
        assert len(targets.shape) == 1

        precision, recall, _ = sm.precision_recall_curve(targets, predictions)
        interpolated = np.empty(len(recall_thresholds))
        n_points_on_curve = len(recall)
        cursor = 0
        running_max = 0
        for out_idx, level in enumerate(recall_thresholds):
            # Consume every remaining curve point that still reaches this recall level.
            while cursor < n_points_on_curve and recall[cursor] >= level:
                running_max = max(running_max, precision[cursor])
                cursor += 1
            interpolated[out_idx] = running_max
        return interpolated


class MeanAveragePrecisionNPointsEvaluator(PrecisionRecallCurveMixin, MemorizingEverythingEvaluator):
    """
    N-point interpolated average precision, averaged over classes
    """

    def _calculate(self, targets, predictions, average):
        # Recall levels from 1 down to 0, inclusive on both ends.
        recall_levels = np.linspace(1, 0, self.n_points, endpoint=True).tolist()
        per_class_ap = [
            np.mean(self._calc_precision_recall_interp(predictions[:, col], targets[:, col], recall_levels))
            for col in range(predictions.shape[1])
        ]
        return np.mean(per_class_ap)

    def _get_id(self):
        return f'mAP_{self.n_points}_points'


def accuracy(y_label, y_pred):
    """ Compute Top1 accuracy
    Args:
        y_label: the ground-truth class indices. Shape (N,)
        y_pred: the per-class scores predicted by a model. Shape (N, num_class)
    Returns:
        fraction of samples whose highest-scoring class equals the label (0.0 when empty)
    """
    top1 = TopKAccuracyEvaluator(1)
    top1.add_predictions(predictions=y_pred, targets=y_label)
    report = top1.get_report()
    return report['accuracy_top1']


def map_11_points(y_label, y_pred_proba):
    """11-point interpolated mean average precision.

    Args:
        y_label: class indices of shape (N,) or a label matrix of shape (N, num_class)
        y_pred_proba: predicted scores of shape (N, num_class)
    """
    map_evaluator = MeanAveragePrecisionNPointsEvaluator(11)
    map_evaluator.add_predictions(predictions=y_pred_proba, targets=y_label)
    metric_key = map_evaluator._get_id()
    return map_evaluator.get_report()[metric_key]


def balanced_accuracy_score(y_label, y_pred):
    """Mean per-class recall (balanced accuracy).

    Args:
        y_label: class indices of shape (N,) or a one-hot matrix of shape (N, num_class)
        y_pred: predicted scores of shape (N, num_class); the argmax is the predicted class
    """
    bal_evaluator = BalancedAccuracyScoreEvaluator()
    bal_evaluator.add_predictions(y_pred, y_label)
    metric_key = bal_evaluator._get_id()
    return bal_evaluator.get_report()[metric_key]


def roc_auc(y_true, y_score):
    """ROC AUC of ``y_score`` against ``y_true`` via sklearn's ``roc_auc_score``.

    When ``y_true`` is 1-D and ``y_score`` has exactly two columns, the scores are treated as
    [negative, positive] and only the positive column is used; this matches
    ``RocAucEvaluator``. 1-D scores and other shapes (e.g. multilabel targets) are passed
    through unchanged. Array-likes such as lists are accepted.
    """
    y_score = np.asarray(y_score)
    if np.ndim(y_true) == 1 and y_score.ndim == 2 and y_score.shape[1] == 2:
        return roc_auc_score(y_true, y_score[:, 1])
    return roc_auc_score(y_true, y_score)


def get_metric(metric_name):
    """Return the metric function registered under ``metric_name``.

    Supported names: "accuracy", "mean-per-class", "11point_mAP", "roc_auc".

    Raises:
        NotImplementedError: for any other name.
    """
    registry = (
        ("accuracy", accuracy),
        ("mean-per-class", balanced_accuracy_score),
        ("11point_mAP", map_11_points),
        ("roc_auc", roc_auc),
    )
    for registered_name, metric_fn in registry:
        if metric_name == registered_name:
            return metric_fn
    raise NotImplementedError
