from collections import OrderedDict

import numpy as np
from sklearn.metrics import (
    roc_curve as sklearn_roc_curve,
    auc as sklearn_auc
)


def intersect_size(yhat, y, axis):
    """Count positions that are truthy in both ``yhat`` and ``y``.

    The element-wise logical AND is summed along ``axis`` and the result is
    cast to float.
    """
    both = np.logical_and(yhat, y)
    return np.sum(both, axis=axis).astype(float)


def union_size(yhat, y, axis):
    """Count positions that are truthy in ``yhat``, in ``y``, or in both.

    The element-wise logical OR is summed along ``axis`` and the result is
    cast to float.
    """
    either = np.logical_or(yhat, y)
    return np.sum(either, axis=axis).astype(float)


class AccuracyPrecisionRecallF1:
    """Jaccard accuracy, precision, recall and F1 for binary predictions.

    Every ratio is taken column-wise (over axis 0) and then averaged with
    ``np.mean``. ``EPS`` is added to each denominator; whenever a denominator
    is 0 the matching intersection is 0 too, so that column contributes 0
    instead of dividing by zero.
    """

    EPS = 1e-10

    @classmethod
    def calculate_accuracy(cls, y_hat, y):
        """Return the mean over columns of |y_hat AND y| / |y_hat OR y|."""
        overlap = intersect_size(y_hat, y, 0)
        coverage = union_size(y_hat, y, 0) + cls.EPS
        return np.mean(overlap / coverage)

    @classmethod
    def calculate_precision(cls, y_hat, y):
        """Return the mean over columns of |y_hat AND y| / |y_hat|."""
        overlap = intersect_size(y_hat, y, 0)
        predicted = y_hat.sum(axis=0) + cls.EPS
        return np.mean(overlap / predicted)

    @classmethod
    def calculate_recall(cls, y_hat, y):
        """Return the mean over columns of |y_hat AND y| / |y|."""
        overlap = intersect_size(y_hat, y, 0)
        actual = y.sum(axis=0) + cls.EPS
        return np.mean(overlap / actual)

    @classmethod
    def calculate_precision_recall_f1(cls, y_hat, y):
        """Return ``(precision, recall, f1)``.

        F1 is the harmonic mean of the already-averaged precision and
        recall, and is 0 when both are 0.
        """
        precision = cls.calculate_precision(y_hat, y)
        recall = cls.calculate_recall(y_hat, y)
        total = precision + recall
        f1 = 0. if total == 0 else 2 * (precision * recall) / total
        return precision, recall, f1

    @classmethod
    def calculate(cls, y_hat, y, average='micro'):
        """Compute accuracy, precision, recall and F1 with the given averaging.

        Args:
            y_hat: binary prediction array.
            y: binary ground-truth array of the same shape.
            average: ``'micro'`` flattens both arrays first, so every cell is
                pooled into one ratio; ``'macro'`` keeps the column-wise
                ratios and averages them.

        Returns:
            OrderedDict with keys ``acc_<average>``, ``prec_<average>``,
            ``rec_<average>`` and ``f1_<average>``, in that order.

        Raises:
            KeyError: if ``average`` is neither ``'micro'`` nor ``'macro'``.
        """
        if average not in ('micro', 'macro'):
            raise KeyError('average should be "micro" or "macro"')
        if average == 'micro':
            # After flattening, axis 0 spans every cell, so each per-column
            # helper collapses to a single pooled ratio.
            y_hat, y = y_hat.ravel(), y.ravel()

        accuracy = cls.calculate_accuracy(y_hat, y)
        precision, recall, f1 = cls.calculate_precision_recall_f1(y_hat, y)
        return OrderedDict([
            (f'acc_{average}', accuracy),
            (f'prec_{average}', precision),
            (f'rec_{average}', recall),
            (f'f1_{average}', f1),
        ])


class AUC:
    """Macro- and micro-averaged ROC AUC for multi-label scores."""

    @classmethod
    def calculate(cls, y_hat_raw, y):
        """Compute macro and micro ROC AUC.

        Args:
            y_hat_raw: 2-D array of raw scores, one column per label.
            y: 2-D ground-truth array with the same column layout.

        Returns:
            OrderedDict with:
            - ``'auc_macro'``: the mean AUC over labels whose column has at
              least one positive and whose AUC is not NaN. It is NaN when no
              label qualifies.
            - ``'auc_micro'``: the AUC over all flattened score/label pairs.
        """
        label_aucs = []
        for i in range(y.shape[1]):
            # Labels with no positive examples are skipped entirely.
            if y[:, i].sum() <= 0:
                continue
            fpr, tpr, _ = sklearn_roc_curve(y[:, i], y_hat_raw[:, i])
            if len(fpr) > 1 and len(tpr) > 1:
                score = sklearn_auc(fpr, tpr)
                if not np.isnan(score):
                    label_aucs.append(score)

        # Explicit NaN instead of np.mean([]), which warns ("Mean of empty
        # slice") before returning NaN anyway.
        auc_macro = np.mean(label_aucs) if label_aucs else np.nan

        fpr_micro, tpr_micro, _ = sklearn_roc_curve(y.ravel(), y_hat_raw.ravel())
        auc_micro = sklearn_auc(fpr_micro, tpr_micro)

        return OrderedDict([
            ('auc_macro', auc_macro),
            ('auc_micro', auc_micro),
        ])


class RecallPrecisionF1K:
    """Precision@k, recall@k and F1@k for ranked multi-label scores.

    Each row of ``y_hat_raw`` is ranked by descending score. The first ``k``
    columns of that ranking are treated as the row's predicted labels.
    """

    @staticmethod
    def _top_k_hits(y_hat_raw, y, k):
        """Count, per row, the true labels among the ``k`` highest-scored columns.

        Returns:
            ``(hits, width)``:
            - ``hits`` is a 1-D array with one count per row.
            - ``width`` is the number of columns actually taken. It equals
              ``min(k, n_columns)`` for non-negative ``k``.
        """
        ranking = np.argsort(y_hat_raw)[:, ::-1]
        top_k = ranking[:, :k]
        hits = np.take_along_axis(y, top_k, axis=1).sum(axis=1)
        return hits, top_k.shape[1]

    @classmethod
    def calculate_recall_k(cls, y_hat_raw, y, k):
        """Return the mean over rows of (true labels in top-k) / (true labels in row).

        Rows without any true label contribute 0. Returns NaN when there are
        no rows.
        """
        hits, _ = cls._top_k_hits(y_hat_raw, y, k)
        n_true = y.sum(axis=1)
        # Masked divide: rows without positives stay 0 instead of passing
        # through a 0/0 NaN (and its RuntimeWarning) first.
        recall = np.divide(
            hits, n_true,
            out=np.zeros(hits.shape, dtype=float),
            where=n_true > 0,
        )
        if recall.size == 0:
            return np.nan
        return np.mean(recall)

    @classmethod
    def calculate_precision_k(cls, y_hat_raw, y, k):
        """Return the mean over rows of (true labels in top-k) / (columns taken).

        When ``k`` exceeds the number of columns, the denominator is the
        number of columns. Returns NaN when no column or no row is taken.
        """
        hits, width = cls._top_k_hits(y_hat_raw, y, k)
        if width == 0 or hits.size == 0:
            return np.nan
        return np.mean(hits / width)

    @classmethod
    def calculate_precision_recall_f1_k(cls, y_hat_raw, y, k):
        """Return ``(precision@k, recall@k, f1@k)``.

        F1@k is 0 when both precision@k and recall@k are 0.
        """
        precision_k = cls.calculate_precision_k(y_hat_raw, y, k)
        recall_k = cls.calculate_recall_k(y_hat_raw, y, k)
        if precision_k + recall_k == 0:
            f1_k = 0.
        else:
            f1_k = 2 * (precision_k * recall_k) / (precision_k + recall_k)
        return precision_k, recall_k, f1_k

    @classmethod
    def calculate(cls, y_hat_raw, y, ks):
        """Return ``{'p@k', 'r@k', 'f1@k'}`` entries for every ``k`` in ``ks``."""
        metrics = {}
        for k in ks:
            precision_k, recall_k, f1_k = cls.calculate_precision_recall_f1_k(y_hat_raw, y, k)
            metrics[f'p@{k}'] = precision_k
            metrics[f'r@{k}'] = recall_k
            metrics[f'f1@{k}'] = f1_k
        return metrics


class ICDCodingMetrics:
    """Accumulate batched multi-label predictions and report ICD-coding metrics.

    Call the instance once per batch with ``(preds, targets)``. ``compute``
    then concatenates everything seen so far along axis 0, scores it, and
    clears the buffers.
    """

    def __init__(self, ks):
        """
        Args:
            ks: cut-offs for precision@k and recall@k. Stored as a list
                because it is iterated more than once.
        """
        self.y_hat_raw = []
        self.y = []
        self.ks = list(ks)

        # The @k keys are derived from ``ks``. Hard-coding p@5/p@8/p@15 made
        # ``_compute`` raise KeyError for any other choice of cut-offs.
        self.selected_metrics = [
            'auc_macro', 'auc_micro',
            'f1_macro', 'f1_micro',
            *(f'p@{k}' for k in self.ks),
            *(f'r@{k}' for k in self.ks),
        ]

    def __call__(self, preds, targets):
        """Buffer one batch of raw scores and targets."""
        self.y_hat_raw.append(preds)
        self.y.append(targets)

    def reset(self):
        """Drop all buffered batches."""
        self.y_hat_raw = []
        self.y = []

    def _compute(self, y_hat_raw: np.ndarray, y: np.ndarray, threshold):
        """Score ``y_hat_raw`` against ``y``.

        Scores are binarized as ``y_hat_raw >= threshold`` for the
        thresholded metrics. AUC and @k metrics use the raw scores.

        Returns:
            dict restricted to ``self.selected_metrics``, in that order.
        """
        y_hat = (y_hat_raw >= threshold).astype(int)
        results = [
            AccuracyPrecisionRecallF1.calculate(y_hat, y, average='macro'),
            AccuracyPrecisionRecallF1.calculate(y_hat, y, average='micro'),
            AUC.calculate(y_hat_raw, y),
            RecallPrecisionF1K.calculate(y_hat_raw, y, self.ks),
        ]
        merged = {}
        for result in results:
            merged.update(result)
        return {key: merged[key] for key in self.selected_metrics}

    def compute(self, printing=True):
        """Score all buffered batches at threshold 0, optionally print, then reset.

        Threshold 0 suggests the buffered scores are logits; NOTE(review):
        confirm with the model producing them.

        Raises:
            ValueError: if no batch has been buffered since the last reset.
        """
        if not self.y_hat_raw:
            raise ValueError('compute() called before any predictions were accumulated')
        y_hat_raw = np.concatenate(self.y_hat_raw, axis=0)
        y = np.concatenate(self.y, axis=0)

        metrics = self._compute(y_hat_raw, y, 0)

        if printing:
            self.print_metrics(metrics)

        self.reset()

        return metrics

    def print_metrics(self, metrics):
        """Print ``metrics`` as a two-row, pipe-delimited table (names above values)."""
        names = list(metrics.keys())
        cells = [f'{value:.4f}' for value in metrics.values()]

        # Each column is one character wider than its longest entry, plus a
        # leading space inside every cell.
        widths = [max(len(name), len(cell)) + 1 for name, cell in zip(names, cells)]

        key_line = '|' + '|'.join(' ' + name.ljust(w) for name, w in zip(names, widths)) + '|'
        value_line = '|' + '|'.join(' ' + cell.ljust(w) for cell, w in zip(cells, widths)) + '|'
        print(key_line)
        print(value_line)


class ICDCodingDevTestMetrics(ICDCodingMetrics):
    """Metrics for a buffered stream whose first ``split`` rows are dev, the rest test.

    Dev rows are scored at threshold 0. A micro-F1-maximizing threshold is
    then tuned on the dev rows and used to score the test rows.
    """

    def __init__(self, ks, split):
        """
        Args:
            ks: cut-offs for precision@k and recall@k.
            split: number of leading buffered rows that belong to dev.
        """
        super().__init__(ks)
        self.split = split

    def find_threshold(self, y_hat_raw, y):
        """Return the observed score that maximizes pooled (micro) F1.

        A threshold ``t`` predicts every element with ``score >= t`` as
        positive, the same rule ``_compute`` applies. Both arrays are
        flattened, so all labels are pooled.
        """
        scores = y_hat_raw.reshape(-1)
        labels = y.reshape(-1).astype(np.float64)
        order = np.argsort(scores)
        sorted_scores = scores[order]
        sorted_labels = labels[order]

        n = sorted_labels.shape[0]
        total_pos = np.sum(sorted_labels)
        # Threshold sorted_scores[j] predicts indices >= j, so the true
        # positives are every label at index >= j, INCLUDING j. The previous
        # inclusive cumsum dropped index j and could pick a worse threshold.
        true_pos = total_pos - (np.cumsum(sorted_labels) - sorted_labels)
        predicted = n - np.arange(n)
        f1 = 2 * true_pos / (predicted + total_pos)

        # With tied scores, ">=" also selects earlier tied elements, so the
        # counts above are exact only at the first index of each tied run.
        first_of_run = np.ones(n, dtype=bool)
        first_of_run[1:] = sorted_scores[1:] != sorted_scores[:-1]
        f1 = np.where(first_of_run, f1, -np.inf)

        return sorted_scores[np.argmax(f1)]

    def compute(self, printing=True):
        """Score the dev and test rows, optionally print both, then reset.

        Returns:
            ``(dev_metrics, test_metrics)``. Dev uses threshold 0; test uses
            the threshold returned by ``find_threshold`` on the dev rows.

        Raises:
            ValueError: if no batch has been buffered since the last reset.
        """
        if not self.y_hat_raw:
            raise ValueError('compute() called before any predictions were accumulated')
        y_hat_raw = np.concatenate(self.y_hat_raw, axis=0)
        y = np.concatenate(self.y, axis=0)

        y_hat_raw_dev = y_hat_raw[:self.split]
        y_dev = y[:self.split]
        dev_metrics = self._compute(y_hat_raw_dev, y_dev, 0)

        threshold = self.find_threshold(y_hat_raw_dev, y_dev)

        y_hat_raw_test = y_hat_raw[self.split:]
        y_test = y[self.split:]
        test_metrics = self._compute(y_hat_raw_test, y_test, threshold)

        self.reset()

        if printing:
            self.print_metrics(dev_metrics)
            print(f'find threshold: {threshold}')
            self.print_metrics(test_metrics)

        return dev_metrics, test_metrics
