import numpy as np
from sklearn import metrics



def confidences_auc(confidences, datasets):
    """Print AUROC / FPR95 of the ID set against every OOD set, plus averages.

    Args:
        confidences: Sequence of per-dataset confidence scores. The first
            entry is treated as the in-distribution (ID) reference and each
            later entry is scored against it with ``auc``. Entries may have
            different lengths.
        datasets: Dataset names aligned with ``confidences``. The first name
            (belonging to the ID entry) is skipped in the report.

    Returns:
        None. All results are printed.
    """
    # Convert entries one at a time instead of stacking the whole collection:
    # datasets usually differ in sample count, and building a single array
    # from ragged sequences raises ValueError on recent NumPy releases.
    id_confi = np.asarray(confidences[0])

    auc_aver = 0.
    fpr_aver = 0.
    n_evaluated = 0
    for ood_confi, dataset in zip(confidences[1:], datasets[1:]):
        auroc, fpr_95 = auc(id_confi, np.asarray(ood_confi))
        auc_aver += auroc
        fpr_aver += fpr_95
        n_evaluated += 1
        print(f"For {dataset}, AUC: {auroc}, FPR95: {fpr_95}")

    if n_evaluated == 0:
        # Nothing to average; avoid dividing by zero.
        print("No OOD datasets to evaluate; skipping average.")
        return

    # Divide by the number of pairs actually scored: zip() stops at the
    # shorter input, so len(datasets) - 1 can overstate the count.
    print(f"Average AUC: {auc_aver / n_evaluated}, FPR: {fpr_aver / n_evaluated}")

def search_k(confidences, datasets, K=10):
    """Print AUROC / FPR95 per OOD set for every score column of the inputs.

    For each column index ``i`` of the ID score matrix, column ``i`` of the
    ID scores is compared with column ``i`` of every OOD score matrix via
    ``auc`` and the per-dataset and average results are printed under a
    ``K is i+1`` header.

    Args:
        confidences: Sequence of 2-D score arrays (rows are samples, columns
            are the candidates being searched). The first entry is the
            in-distribution (ID) reference. Entries may differ in row count.
        datasets: Dataset names aligned with ``confidences``. The first name
            (belonging to the ID entry) is skipped in the report.
        K: Currently unused; every column of the ID matrix is evaluated
            regardless of this value. Kept for interface compatibility.

    Returns:
        None. All results are printed.
    """
    # Convert entries one at a time instead of stacking the whole collection:
    # datasets usually differ in sample count, and building a single array
    # from ragged inputs raises ValueError on recent NumPy releases.
    id_confi = np.asarray(confidences[0])

    # Materialise the (scores, name) pairs once so the per-column loop does
    # not repeat the conversion, and so the true pair count is known.
    ood_pairs = [
        (np.asarray(ood_confi), dataset)
        for ood_confi, dataset in zip(confidences[1:], datasets[1:])
    ]
    if not ood_pairs:
        # Nothing to average; avoid dividing by zero.
        print("No OOD datasets to evaluate; skipping search.")
        return

    n_evaluated = len(ood_pairs)
    for i in range(id_confi.shape[1]):

        auc_aver = 0.
        fpr_aver = 0.
        print(f"-------------- K is {i+1} ----------------")
        for ood_confi, dataset in ood_pairs:
            auroc, fpr_95 = auc(id_confi[:, i], ood_confi[:, i])
            auc_aver += auroc
            fpr_aver += fpr_95
            print(f"For {dataset}, AUC: {auroc}, FPR95: {fpr_95}")

        # Divide by the number of pairs actually scored: zip() stops at the
        # shorter input, so len(datasets) - 1 can overstate the count.
        print(f"Average AUC: {auc_aver / n_evaluated}, FPR: {fpr_aver / n_evaluated}")

def auc(ind_conf, ood_conf, recall=0.95):
    """Compute AUROC and the FPR at a target recall for ID-vs-OOD scores.

    In-distribution samples are labelled 1 and OOD samples 0, so higher
    confidence is expected to indicate in-distribution data.

    Args:
        ind_conf: Confidence scores of the in-distribution samples.
        ood_conf: Confidence scores of the out-of-distribution samples.
        recall: True-positive-rate level at which the false-positive rate is
            reported. Defaults to 0.95 (the usual FPR95 metric).

    Returns:
        Tuple ``(auroc, fpr_at_recall)``: the area under the ROC curve and
        the FPR at the first ROC point whose TPR is at least ``recall``.

    Raises:
        ValueError: If ``recall`` is not in the interval (0, 1].
    """
    if not 0 < recall <= 1:
        # An unreachable target would make argmax below silently return
        # index 0, reporting a misleading FPR.
        raise ValueError(f"recall must be in (0, 1], got {recall}")

    conf = np.concatenate((ind_conf, ood_conf))
    ind_indicator = np.concatenate(
        (np.ones_like(ind_conf), np.zeros_like(ood_conf))
    )

    fpr, tpr, _ = metrics.roc_curve(ind_indicator, conf)

    # argmax over the boolean mask yields the first index where the TPR
    # reaches the target recall.
    fpr_at_recall = fpr[np.argmax(tpr >= recall)]
    auroc = metrics.auc(fpr, tpr)

    return auroc, fpr_at_recall
