import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve

from utils.common import to_numpy


def draw_roc(fpr, tpr, roc_auc):
    """
    Display a ROC curve in a new matplotlib figure.

    fpr: false positive rates, plotted on the x axis.
    tpr: true positive rates, plotted on the y axis.
    roc_auc: area under the curve; shown in the legend with 4 decimals.

    The figure is shown with plt.show(); nothing is returned or saved.
    """
    line_width = 2

    plt.figure()

    # The curve itself, followed by the dashed diagonal from (0, 0) to (1, 1).
    curve_label = 'ROC curve (area = %0.4f)' % roc_auc
    plt.plot(fpr, tpr, color='darkorange', lw=line_width, label=curve_label)
    plt.plot([0, 1], [0, 1], color='navy', lw=line_width, linestyle='--')

    # The y limit extends slightly above 1.0.
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")

    # Saving to disk is currently disabled:
    # plt.savefig('roc.png')
    plt.show()


def print_ood_output(res_tar1, res_tar2, res_big_tar):
    """
    Print one summary line of OOD metrics for each of three evaluations.

    Each argument is a mapping that provides the keys 'auroc',
    'fpr_at_95_tpr', 'aupr_in' and 'aupr_out'. Values are printed with four
    decimals, in the order TAR1, TAR2, TAR1+TAR2.
    """
    metric_keys = ('auroc', 'fpr_at_95_tpr', 'aupr_in', 'aupr_out')
    labelled_results = (
        ("SRC->TAR1:      ", res_tar1),
        ("SRC->TAR2:      ", res_tar2),
        ("SRC->TAR1+TAR2: ", res_big_tar),
    )

    # Read every metric up front so that a missing key fails before any line is printed.
    rows = [(prefix, [result[key] for key in metric_keys])
            for prefix, result in labelled_results]

    for prefix, (auroc_val, fpr95, aupr_in_val, aupr_out_val) in rows:
        print(f"{prefix}AUROC: {auroc_val:.4f}, FPR95: {fpr95:.4f}, AUPR_IN: {aupr_in_val:.4f}, AUPR_OUT: {aupr_out_val:.4f}")


def get_ood_metrics(src_scores, tar_scores, src_label=0):
    """
    Compute OOD metrics from the scores of source and target samples.

    Scores can be distances, confidences, ...
    In the author's experiments the score is a chamfer distance, so a low value
    means the sample is predicted "normal" (source) and a high value points to a
    target sample. For that reason the target label is the one handed to
    calc_metrics() as the positive label.

    src_scores: scores of the source samples (passed through to_numpy()).
    tar_scores: scores of the target samples (passed through to_numpy()).
    src_label: label given to source samples (0 by default); target samples
               receive int(not src_label).

    Returns the dictionary produced by calc_metrics().
    """
    tar_label = int(not src_label)

    src_array = to_numpy(src_scores)
    tar_array = to_numpy(tar_scores)

    # One label per sample, laid out in the same order as the concatenated scores.
    src_labels = np.full(src_array.shape[0], src_label, dtype=np.int64)
    tar_labels = np.full(tar_array.shape[0], tar_label, dtype=np.int64)

    all_labels = np.concatenate([src_labels, tar_labels], axis=0)
    all_scores = np.concatenate([src_array, tar_array], axis=0)
    return calc_metrics(all_scores, all_labels, tar_label)


def auroc(preds, labels, pos_label=1):
    """
    Return the area under the ROC curve for unthresholded scores and binary labels.

    preds: array, shape = [n_samples]
           Scores such as probability estimates of the positive class, confidence
           values, or non-thresholded decision values. A higher value means the
           sample is more likely to belong to the positive class.

    labels: array, shape = [n_samples]
            True binary labels in range {0, 1} or {-1, 1}.

    pos_label: label of the positive class (1 by default)
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, preds, pos_label=pos_label)
    area = auc(false_pos_rate, true_pos_rate)
    return area


def aupr(preds, labels, pos_label=1):
    """
    Return the area under the Precision-Recall curve for unthresholded scores
    and binary labels.

    preds: array, shape = [n_samples]
           Scores such as probability estimates of the positive class, confidence
           values, or non-thresholded decision values. A higher value means the
           sample is more likely to belong to the positive class.

    labels: array, shape = [n_samples]
            True binary labels in range {0, 1} or {-1, 1}.

    pos_label: label of the positive class (1 by default)
    """
    precision_vals, recall_vals, _thresholds = precision_recall_curve(labels, preds, pos_label=pos_label)
    # Recall is the x axis and precision the y axis of the integrated curve.
    area = auc(recall_vals, precision_vals)
    return area


def fpr_at_95_tpr(preds, labels, pos_label=1):
    """
    Return the false positive rate at the point where the true positive rate is 95%.

    preds: array, shape = [n_samples]
           Scores such as probability estimates of the positive class, confidence
           values, or non-thresholded decision values. A higher value means the
           sample is more likely to belong to the positive class.

    labels: array, shape = [n_samples]
            True binary labels in range {0, 1} or {-1, 1}.

    pos_label: label of the positive class (1 by default)

    Returns 0 when no ROC point reaches a TPR of 0.95, the smallest FPR when
    every ROC point reaches it, and otherwise the FPR linearly interpolated
    at TPR == 0.95.
    """
    target_tpr = 0.95
    fpr, tpr, _ = roc_curve(labels, preds, pos_label=pos_label)

    # Guard 1: the target TPR is never reached by any threshold.
    if all(tpr < target_tpr):
        return 0

    # Guard 2: every threshold reaches the target, so each FPR is a candidate
    # and the lowest one is reported.
    if all(tpr >= target_tpr):
        return min(fpr)

    # General case: interpolate the ROC curve linearly at the target TPR.
    return np.interp(target_tpr, tpr, fpr)


def detection_error(preds, labels, pos_label=1):
    """
    Return the minimum misclassification probability over all ROC operating
    points whose TPR is at least 95%.

    preds: array, shape = [n_samples]
           Scores such as probability estimates of the positive class, confidence
           values, or non-thresholded decision values. A higher value means the
           sample is more likely to belong to the positive class.

    labels: array, shape = [n_samples]
            True binary labels in range {0, 1} or {-1, 1}.

    pos_label: label of the positive class (1 by default)

    Raises ValueError when no ROC point reaches a TPR of 0.95.
    """
    fpr, tpr, _ = roc_curve(labels, preds, pos_label=pos_label)

    # Empirical class priors.
    pos_ratio = np.sum(np.array(labels) == pos_label) / len(labels)
    neg_ratio = 1 - pos_ratio

    # Operating points that satisfy the TPR >= 95% constraint.
    meets_target = tpr >= 0.95
    if not any(meets_target):
        raise ValueError("no ROC operating point reaches TPR >= 0.95")

    # P(error) = P(positive) * FNR + P(negative) * FPR.
    # A missed positive can only happen on a positive sample and a false alarm
    # only on a negative one, so each rate is weighted by its own class prior.
    errors = pos_ratio * (1 - tpr) + neg_ratio * fpr

    # Return the minimum detection error such that TPR >= 0.95
    return min(errors[meets_target])


def calc_metrics(predictions, labels, pos_label=1):
    """
    Using predictions and labels, return a dictionary containing all novelty detection performance statistics.
    These metrics conform to how results are reported in the paper 'Enhancing The
    Reliability Of Out-of-Distribution Image Detection In Neural Networks'.

        predictions: array, shape = [n_samples]
           Scores such as probability estimates of the positive class, confidence
           values, or non-thresholded decision values. A higher value means the
           sample is more likely to belong to the positive class.

        labels: array, shape = [n_samples]
            True binary labels in range {0, 1} or {-1, 1}.

        pos_label: label of the positive class (1 by default)

    Returns a dict with the keys 'fpr_at_95_tpr', 'detection_error', 'auroc',
    'aupr_in' and 'aupr_out'.
    """
    # AUPR-out swaps the roles of the two classes: the scores are negated and
    # the original negative class becomes the positive one (encoded as 1).
    # The swapped labels are derived by comparing against pos_label instead of
    # computing `1 - label`, which only works for {0, 1} labels; with {-1, 1}
    # it yields {2, 0} and leaves no positive sample.
    negated_scores = -np.asarray(predictions, dtype=np.float64)
    swapped_labels = (np.asarray(labels) != pos_label).astype(np.int64)

    return {
        'fpr_at_95_tpr': fpr_at_95_tpr(predictions, labels, pos_label=pos_label),
        'detection_error': detection_error(predictions, labels, pos_label=pos_label),
        'auroc': auroc(predictions, labels, pos_label=pos_label),
        'aupr_in': aupr(predictions, labels, pos_label=pos_label),
        'aupr_out': aupr(negated_scores, swapped_labels, pos_label=1)
    }
