import sklearn 
import numpy as np

def roc_curve(id_scores, ood_scores):
    """Compute the ROC curve separating in-distribution from OOD scores.

    In-distribution (ID) samples are labelled 1 (positive class) and
    out-of-distribution (OOD) samples are labelled 0, then both are passed
    to ``sklearn.metrics.roc_curve``.

    Args:
        id_scores: 1-D sequence of scores for in-distribution samples.
        ood_scores: 1-D sequence of scores for out-of-distribution samples.

    Returns:
        Tuple ``(fpr, tpr, thr)`` exactly as returned by
        ``sklearn.metrics.roc_curve``.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.metrics`` is reachable as an attribute on
    # every scikit-learn version.
    from sklearn.metrics import roc_curve as _sk_roc_curve

    id_labels = np.ones(len(id_scores))
    ood_labels = np.zeros(len(ood_scores))

    scores = np.concatenate((id_scores, ood_scores))
    labels = np.concatenate((id_labels, ood_labels))

    fpr, tpr, thr = _sk_roc_curve(labels, scores)
    return fpr, tpr, thr

def auc(fpr, tpr):
    """Return the area under the curve defined by ``fpr`` (x) and ``tpr`` (y).

    Thin wrapper around ``sklearn.metrics.auc``.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.metrics`` is reachable as an attribute on
    # every scikit-learn version.
    from sklearn.metrics import auc as _sk_auc

    return _sk_auc(fpr, tpr)

def fpr95(fpr, tpr, thr):
    """Return the ROC operating point at the first index with TPR >= 0.95.

    Args:
        fpr: Sequence of false-positive rates, one per threshold.
        tpr: Sequence of true-positive rates, aligned with ``fpr``.
        thr: Sequence of thresholds, aligned with ``fpr`` and ``tpr``.

    Returns:
        Tuple ``(fpr, tpr, threshold)`` of Python floats taken at the first
        index whose TPR is at least 0.95.

    Raises:
        ValueError: If no entry of ``tpr`` reaches 0.95 (this includes empty
            input and TPR values that are all NaN).
    """
    # Coerce so that plain lists/tuples work as well as NumPy arrays.
    fpr = np.asarray(fpr)
    tpr = np.asarray(tpr)
    thr = np.asarray(thr)

    candidates = np.flatnonzero(tpr >= 0.95)
    if candidates.size == 0:
        raise ValueError("No threshold allows for TPR 95%.")

    # flatnonzero returns indices in ascending order, so [0] is the first hit.
    idx = candidates[0]
    return float(fpr[idx]), float(tpr[idx]), float(thr[idx])

def tnr05(fpr, tpr, thr):
    """Return ``(1 - fpr, 1 - tpr, threshold)`` at the first index with TPR >= 0.05.

    When no entry of ``tpr`` reaches 0.05, the values at index 0 are used.
    All three returned values are Python floats.
    """
    # First position whose TPR reaches 5%; default to the first entry.
    first_hit = next(
        (pos for pos, rate in enumerate(tpr) if rate >= 0.05),
        0,
    )
    tnr_value = 1. - fpr[first_hit]
    fnr_value = 1. - tpr[first_hit]
    return float(tnr_value), float(fnr_value), float(thr[first_hit])
    
    
def fnr95(fpr, tpr, thr):
    """Return the false-negative rate at the operating point with TNR >= 0.95.

    TNR is computed as ``1 - fpr`` and FNR as ``1 - tpr``. The inputs are
    assumed to be ordered as ROC-curve output, i.e. with non-decreasing
    ``fpr``, so TNR is non-increasing along the index. The *last* index that
    still satisfies TNR >= 0.95 is selected; picking the first one would
    always yield the trivial point where ``fpr == 0`` (TNR = 1), which for a
    ROC curve starting at the origin means FNR = 1.

    Args:
        fpr: Sequence of false-positive rates, one per threshold.
        tpr: Sequence of true-positive rates, aligned with ``fpr``.
        thr: Sequence of thresholds, aligned with ``fpr`` and ``tpr``.

    Returns:
        Tuple ``(fnr, tnr, threshold)`` of Python floats at the selected
        index.

    Raises:
        ValueError: If no entry reaches a TNR of 0.95 (this includes empty
            input and FPR values that are all NaN).
    """
    # Coerce so that plain lists/tuples work as well as NumPy arrays.
    tnr = 1.0 - np.asarray(fpr)
    fnr = 1.0 - np.asarray(tpr)
    thr = np.asarray(thr)

    candidates = np.flatnonzero(tnr >= 0.95)
    if candidates.size == 0:
        raise ValueError("No threshold allows for TNR 95%.")

    # flatnonzero returns ascending indices, so [-1] is the last point that
    # still keeps TNR at or above 95%.
    idx = candidates[-1]
    return float(fnr[idx]), float(tnr[idx]), float(thr[idx])
    