import os
import argparse
from matplotlib.pyplot import axis
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_curve, auc, roc_curve
from terminaltables import AsciiTable
import torch

def parse_args():
    '''Build the command-line parser and parse the process arguments.

    Command instruction:
        source activate mmaction
        python experiments/compare_openness.py

    Returns:
        argparse.Namespace with the attributes ``base_model``, ``ood_data``,
        ``thresholds``, ``baseline_results`` and ``UOSAR``.
    '''
    cli = argparse.ArgumentParser(description='Compare the performance of openness')
    # (flag, keyword arguments for add_argument), registered in declaration order.
    option_specs = (
        ('--base_model', {'default': 'tsm', 'help': 'the backbone model name'}),
        ('--ood_data', {'default': 'HMDB', 'help': 'the name of OOD dataset.'}),
        ('--thresholds', {'nargs': '+', 'type': float, 'default': [-1] * 12}),
        ('--baseline_results', {'nargs': '+', 'help': 'the testing results files.'}),
        # Kept as a plain string option; it is not converted to a bool here.
        ('--UOSAR', {'default': 'True', 'help': 'whether UOSAR or not'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def parse_results(result_file, method='softmax'):
    """Load per-sample scores and labels for the SoftMax / OpenMax baselines.

    Args:
        result_file: path to a file readable by ``np.load`` that exposes the
            keys ``ind_label``, ``ood_label`` and ``ind_<method>`` /
            ``ood_<method>``.
        method: ``'softmax'`` or ``'openmax'``; selects which score arrays
            are read.

    Returns:
        Tuple ``(ind_scores, ood_scores, ind_labels, ood_labels)``.

    Raises:
        ValueError: if ``method`` is not one of the two supported names.
        AssertionError: if ``result_file`` does not exist.
    """
    # Fail loudly on an unknown method instead of silently returning None,
    # which would only surface later as an unpacking error at the call site.
    if method not in ('softmax', 'openmax'):
        raise ValueError("Unsupported method %r; expected 'softmax' or 'openmax'." % (method,))
    assert os.path.exists(result_file), "File not found! Run baseline_openmax.py first to get softmax testing results!\n%s"%(result_file)
    # NOTE(review): allow_pickle=True can execute arbitrary code when loading;
    # only point this at result files you produced yourself.
    results = np.load(result_file, allow_pickle=True)
    ind_labels = results['ind_label']  # (N1,)
    ood_labels = results['ood_label']  # (N2,)
    # Keys are 'ind_softmax'/'ood_softmax' or 'ind_openmax'/'ood_openmax'.
    ind_scores = results['ind_' + method]  # (N1, C) for softmax, (N1, C+1) for openmax
    ood_scores = results['ood_' + method]  # (N2, C) for softmax, (N2, C+1) for openmax
    return ind_scores, ood_scores, ind_labels, ood_labels


def eval_osr(y_true, y_pred):
    """Compute binary open-set detection metrics.

    Args:
        y_true: binary ground-truth labels (1 is treated as the positive class
            by the ROC curve call).
        y_pred: scores (or hard decisions) where larger means more positive.

    Returns:
        Tuple ``(auroc, aupr, fpr95)``: area under the ROC curve, area under
        the precision-recall curve, and the false-positive rate at the ROC
        operating point whose true-positive rate is closest to 0.95.
    """
    # Area under the ROC curve.
    roc_area = roc_auc_score(y_true, y_pred)

    # Area under the precision-recall curve (average_precision_score would be
    # an approximate alternative).
    prec_vals, rec_vals, _ = precision_recall_curve(y_true, y_pred)
    pr_area = auc(rec_vals, prec_vals)

    # FPR at the ROC point whose TPR is nearest to 95%.
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred, pos_label=1)
    nearest = np.argmin(np.abs(true_pos_rate - 0.95))

    return roc_area, pr_area, false_pos_rate[nearest]

def eval_confidence_methods(ind_probs, ood_probs, ind_labels, ood_labels, score='max_prob', ind_ncls=101, threshold=-1, idx=0):
    """Evaluate a probability-based method on closed-set and open-set metrics.

    Args:
        ind_probs: per-class scores of in-distribution samples; argmax / max
            are taken over axis 1.
        ood_probs: per-class scores of out-of-distribution samples, same layout.
        ind_labels: ground-truth class ids of the in-distribution samples.
        ood_labels: labels of the OOD samples; used only to drop the classes
            listed in ``repeated_clss``.
        score: ``'binary'`` flags a sample as unknown when its argmax equals
            ``ind_ncls``; ``'max_prob'`` scores samples by maximum probability.
        ind_ncls: class index that stands for "unknown" in ``'binary'`` mode.
        threshold: in ``'max_prob'`` mode, a positive value yields hard 0/1
            decisions (1 when the max probability is below it); otherwise the
            continuous score ``1 - max_prob`` is used.
        idx: unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(acc, aurc, auroc_uosr, auroc_osr)``.

    Raises:
        ValueError: if ``score`` is neither ``'binary'`` nor ``'max_prob'``.
    """
    # close-set accuracy (multi-class)
    ind_results = np.argmax(ind_probs, axis=1)
    ood_results = np.argmax(ood_probs, axis=1)
    acc = accuracy_score(ind_labels, ind_results)

    # Drop OOD samples whose label is in this hard-coded list.
    # NOTE(review): presumably these are OOD classes that overlap with the
    # in-distribution label set -- confirm against the dataset definition.
    repeated_clss = [35, 29, 15, 26, 30, 34, 43, 31]
    keep = ~np.isin(ood_labels, repeated_clss)
    ood_probs = ood_probs[keep]
    ood_results = ood_results[keep]
    ood_labels = ood_labels[keep]

    # open-set evaluation (binary class): 1 = unknown, 0 = known
    if score == 'binary':
        preds = np.concatenate((ind_results, ood_results), axis=0)
        preds = (preds == ind_ncls).astype(preds.dtype)
    elif score == 'max_prob':
        ind_conf = np.max(ind_probs, axis=1)
        ood_conf = np.max(ood_probs, axis=1)
        confs = np.concatenate((ind_conf, ood_conf), axis=0)
        if threshold > 0:
            # Hard decision: low confidence means unknown.
            preds = (confs < threshold).astype(ind_results.dtype)
        else:
            # Continuous uncertainty score.
            preds = 1 - confs
    else:
        raise ValueError("Unsupported score %r; expected 'binary' or 'max_prob'." % (score,))

    u_ind_gt = np.zeros_like(ind_labels)
    u_ood_gt = np.ones_like(ood_labels)

    # UOSR ground truth additionally marks misclassified in-distribution
    # samples as positives; OSR ground truth marks only OOD samples.
    u_ind_gt_uosr = u_ind_gt.copy()
    u_ind_gt_uosr[ind_results != ind_labels] = 1

    labels_uosr = np.concatenate((u_ind_gt_uosr, u_ood_gt))
    labels_osr = np.concatenate((u_ind_gt, u_ood_gt))

    auroc_uosr, _, _ = eval_osr(labels_uosr, preds)
    auroc_osr, _, _ = eval_osr(labels_osr, preds)

    # Risk-coverage area: confidence is 1 - preds, correctness is 1 - UOSR label.
    aurc, _ = calc_aurc_eaurc(1 - preds, 1 - labels_uosr)

    # ECE is intentionally not computed here: it is not part of the return
    # value, and computing it needs `ind_conf`, which does not exist in
    # 'binary' mode.
    return acc, aurc, auroc_uosr, auroc_osr



def eval_uncertainty_methods(result_file, threshold=-1):
    """Evaluate an uncertainty-based method from a saved result file.

    The file must expose the keys ``ind_unctt``, ``ood_unctt``, ``ind_pred``,
    ``ood_pred``, ``ind_label`` and ``ood_label``. When the path contains
    ``"bnn"`` or ``"dear"``, column 0 of the two uncertainty arrays is used.

    Args:
        result_file: path to the result file (loaded with ``np.load``).
        threshold: acts only as a switch. Any positive value selects hard 0/1
            decisions using a cut-off recomputed from the in-distribution
            uncertainties (the ``N - int(N * 0.85)``-th largest value); the
            passed number itself is not used as the cut-off. Otherwise the raw
            uncertainties are used as continuous scores.

    Returns:
        Tuple ``(acc, aurc, auroc_uosr, auroc_osr)``.

    Raises:
        AssertionError: if ``result_file`` does not exist.
    """
    assert os.path.exists(result_file), "File not found! Run ood_detection first!\n%s"%(result_file)
    # load the testing results
    # NOTE(review): allow_pickle=True can execute arbitrary code when loading;
    # only use it on trusted result files.
    results = np.load(result_file, allow_pickle=True)
    if "bnn" in result_file or "dear" in result_file:
        ind_uncertainties = results['ind_unctt'][:,0]  # (N1,)
        ood_uncertainties = results['ood_unctt'][:,0]  # (N2,)
    else:
        ind_uncertainties = results['ind_unctt']  # (N1,)
        ood_uncertainties = results['ood_unctt']  # (N2,)
    ind_results = results['ind_pred']  # (N1,)
    ood_results = results['ood_pred']  # (N2,)
    ind_labels = results['ind_label']
    ood_labels = results['ood_label']

    # close-set accuracy (multi-class)
    acc = accuracy_score(ind_labels, ind_results)

    # Drop OOD samples whose label is in this hard-coded list.
    # NOTE(review): presumably these are OOD classes that overlap with the
    # in-distribution label set -- confirm against the dataset definition.
    repeated_clss = [35, 29, 15, 26, 30, 34, 43, 31]
    keep = ~np.isin(ood_labels, repeated_clss)
    ood_uncertainties = ood_uncertainties[keep]
    ood_results = ood_results[keep]
    ood_labels = ood_labels[keep]

    # open-set evaluation (binary class): 1 = unknown, 0 = known
    uncertains = np.concatenate((ind_uncertainties, ood_uncertainties), axis=0)
    if threshold > 0:
        # Cut-off taken from the descending-sorted in-distribution uncertainties.
        uncertain_sort = np.sort(ind_uncertainties)[::-1]
        N = ind_uncertainties.shape[0]
        topK = N - int(N * 0.85)
        threshold = uncertain_sort[topK-1]
        preds = (uncertains > threshold).astype(ind_results.dtype)
    else:
        preds = uncertains

    u_ind_gt = np.zeros_like(ind_labels)
    u_ood_gt = np.ones_like(ood_labels)

    # UOSR ground truth additionally marks misclassified in-distribution
    # samples as positives; OSR ground truth marks only OOD samples.
    u_ind_gt_uosr = u_ind_gt.copy()
    u_ind_gt_uosr[ind_results != ind_labels] = 1

    labels_uosr = np.concatenate((u_ind_gt_uosr, u_ood_gt))
    labels_osr = np.concatenate((u_ind_gt, u_ood_gt))

    auroc_uosr, _, _ = eval_osr(labels_uosr, preds)
    auroc_osr, _, _ = eval_osr(labels_osr, preds)

    # Risk-coverage area: confidence is 1 - preds, correctness is 1 - UOSR label.
    aurc, _ = calc_aurc_eaurc(1 - preds, 1 - labels_uosr)

    # ECE is intentionally not computed here because it is not part of the
    # return value.
    return acc, aurc, auroc_uosr, auroc_osr

def eval_calibration(predictions, confidences, labels, M=15):
    """Bin samples by confidence and report per-bin accuracy and confidence.

    The range [0, 1] is split into ``M`` equal-width bins. Bin ``m`` covers
    ``(m / M, (m + 1) / M]``, except that the first bin also includes its
    lower edge so a confidence of exactly 0 is counted.

    Args:
        predictions: predicted class id per sample.
        confidences: confidence per sample, expected in [0, 1].
        labels: ground-truth class id per sample.
        M: number of bins for confidence scores.

    Returns:
        Tuple ``(accs, confs, num_Bm, conf_intervals)``: per-bin accuracy
        (float32), per-bin mean confidence (float32), per-bin sample count
        (int32) and the lower edge of every bin. Empty bins keep zeros.

    Raises:
        ValueError: if ``M`` is smaller than 1.
    """
    if M < 1:
        raise ValueError("M must be a positive integer, got %r" % (M,))
    # Accept plain sequences as well as arrays (fancy indexing is used below).
    predictions = np.asarray(predictions)
    confidences = np.asarray(confidences)
    labels = np.asarray(labels)

    num_Bm = np.zeros((M,), dtype=np.int32)
    accs = np.zeros((M,), dtype=np.float32)
    confs = np.zeros((M,), dtype=np.float32)
    for m in range(M):
        lower, upper = m / M, (m + 1) / M
        # Only the first bin is closed on the left; otherwise a confidence of
        # exactly 0 would fall into no bin at all.
        above_lower = confidences >= lower if m == 0 else confidences > lower
        Bm = np.where(above_lower & (confidences <= upper))[0]
        if len(Bm) > 0:
            num_Bm[m] = len(Bm)
            accs[m] = np.sum(predictions[Bm] == labels[Bm]) / len(Bm)
            confs[m] = np.mean(confidences[Bm])
    # Integer arange divided by M always has exactly M entries; a float step
    # in np.arange can produce one element too many.
    conf_intervals = np.arange(M) / M
    return accs, confs, num_Bm, conf_intervals

def calc_aurc_eaurc(softmax, correct):
    """Compute AURC and E-AURC from per-sample confidences and correctness.

    Samples are ranked by descending confidence (ties keep their input order),
    the risk at every coverage level is derived from that ranking, and the
    resulting risk list is turned into the two area values.

    Args:
        softmax: per-sample confidence values.
        correct: per-sample correctness flags; a value of 0 counts as an error.

    Returns:
        Tuple ``(aurc, eaurc)``.
    """
    ranked_pairs = list(zip(softmax[:], correct[:]))
    ranked_pairs.sort(key=lambda pair: pair[0], reverse=True)
    ranked_conf, ranked_correct = zip(*ranked_pairs)
    # Only the risk values are needed; the coverage list is discarded.
    risks, _coverages = coverage_risk(ranked_conf, ranked_correct)
    return aurc_eaurc(risks)

def coverage_risk(confidence, correctness):
    """Build the risk and coverage values for a confidence-ranked sample list.

    Args:
        confidence: ranked confidence values; only its length is used.
        correctness: per-sample flags in the same order; 0 marks an error.

    Returns:
        Tuple ``(risk_list, coverage_list)``. Entry ``k`` holds the error rate
        among the first ``k + 1`` samples and the fraction ``(k + 1) / n`` of
        samples covered, respectively.
    """
    total = len(confidence)
    coverage_list = [(pos + 1) / total for pos in range(total)]

    risk_list = []
    errors_so_far = 0
    for pos in range(total):
        if correctness[pos] == 0:
            errors_so_far += 1
        risk_list.append(errors_so_far / (pos + 1))

    return risk_list, coverage_list

def aurc_eaurc(risk_list):
    """Compute AURC and excess-AURC from a list of risk values.

    Args:
        risk_list: risk at each coverage step; the last entry is the risk at
            full coverage.

    Returns:
        Tuple ``(aurc, eaurc)`` where ``aurc`` is the mean of ``risk_list``
        and ``eaurc`` is ``aurc`` minus the optimal area
        ``r + (1 - r) * log(1 - r)`` for the full-coverage risk ``r``.

    Raises:
        IndexError: if ``risk_list`` is empty.
    """
    r = risk_list[-1]
    aurc = sum(risk_list) / len(risk_list)
    if r < 1:
        optimal_risk_area = r + (1 - r) * np.log(1 - r)
    else:
        # (1 - r) * log(1 - r) tends to 0 as r -> 1; evaluating it directly
        # would give 0 * -inf = nan.
        optimal_risk_area = r
    eaurc = aurc - optimal_risk_area

    return aurc, eaurc


def main():
    """Evaluate every baseline result file and print a comparison table.

    Reads the module-level ``args`` namespace (set in the ``__main__`` guard):
    ``args.baseline_results`` supplies one result file per method row, in row
    order, relative to ``./tsm_video``; ``args.thresholds`` is indexed by row
    for the rows evaluated from softmax scores.

    Raises:
        ValueError: if fewer result files than method rows were supplied.
    """
    display_data = [["Methods", "Closed-Set ACC", "AURC", "AUROC-UOSR", "AUROC-OSR"], 
                    ["OpenMax"], ["MC Dropout"], ["BNN SVI"], ["SoftMax"], ["RPL"], ["DEAR"], ["BCE"], ["CRL"], ["DOCTOR"],
                    ["SIRC-MSP-z"], ["SIRC-MSP-res"], ["SIRC-H-z"], ["SIRC-H-res"], ["OE"], ["EB"], ["ENERGY"], ["VOS"], ["MCD"],
                    ]  # table heads and rows
    exp_dir = './tsm_video'
    # Row indices (0-based, header excluded) whose files store softmax scores
    # and are therefore evaluated through parse_results: "SoftMax" and "RPL".
    softmax_format_rows = (3, 4)

    method_rows = display_data[1:]
    result_files = args.baseline_results or []
    if len(result_files) < len(method_rows):
        raise ValueError("--baseline_results needs %d result files (one per method row), got %d"
                         % (len(method_rows), len(result_files)))

    for i, row in enumerate(method_rows):
        result_path = os.path.join(exp_dir, result_files[i])
        if i in softmax_format_rows:
            ind_softmax, ood_softmax, ind_labels, ood_labels = parse_results(result_path, method='softmax')
            acc, aurc, auroc_uosr, auroc_osr = eval_confidence_methods(ind_softmax, ood_softmax, ind_labels, ood_labels, threshold=args.thresholds[i], idx=i)
        else:
            acc, aurc, auroc_uosr, auroc_osr = eval_uncertainty_methods(result_path, threshold=-1)
        # Percentages for ACC / AUROC; AURC is reported scaled by 1000.
        row.extend(["%.2f"%(acc * 100), "%.2f"%(aurc * 1000), "%.2f"%(auroc_uosr * 100), "%.2f"%(auroc_osr * 100)])

    table = AsciiTable(display_data)
    table.inner_footing_row_border = True
    table.justify_columns = {0: 'left', 1: 'center', 2: 'center', 3: 'center', 4: 'center'}
    print(table.table)
    print("\n")

# Script entry point: seed NumPy, parse the CLI, then run the comparison.
if __name__ == "__main__":

    # Seed NumPy's global RNG.
    # NOTE(review): no function visible in this file draws random numbers --
    # confirm whether the seed is still needed.
    np.random.seed(123)
    # `args` must be bound at module level before main() runs, because main()
    # reads it as a global rather than taking it as a parameter.
    args = parse_args()

    main()