'''
This file defines functions that compute an optimal decision threshold from the prediction scores of a runtime monitor on a dataset and the corresponding ground truth
(maximizing F-beta, G-mean, Youden's J statistic, MCC or Cohen's kappa), or a threshold that reaches a target TNR.
It also provides functions that evaluate metrics on a dataset given its ground truth (TNR@95TPR, TPR@95TNR, F-beta, G-mean, precision, recall, specificity, accuracy, ...).
'''


import numpy as np 
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
from sklearn.metrics import precision_score, recall_score, f1_score, fbeta_score, accuracy_score
from sklearn.metrics import confusion_matrix, matthews_corrcoef, cohen_kappa_score


### definition of used metrics
def get_tnr_frac_tpr_oms_new(scores, y_true, frac=0.95):
    """
    Compute TNR (True Negative Rate) when TPR (True Positive Rate) reaches a score of frac (usually 0.95).

    A sample counts as predicted positive when its score is greater than or
    equal to the operating threshold. The threshold is the positive-class
    score found at rank ``int((1 - frac) * n_positives)`` once the positive
    scores are sorted in ascending order.

    Args:
        scores: 1-D array-like of continuous scores (not modified).
        y_true: array-like of ground-truth labels with the same length as
            ``scores``; positives are the entries equal to 1 (or True) and
            negatives the entries equal to 0 (or False).
        frac: target TPR, in [0, 1].

    Returns:
        The TNR (fraction of negatives scoring strictly below the threshold).

    Raises:
        ValueError: if ``scores`` is boolean, if ``frac`` is outside [0, 1],
            or if either class is absent from ``y_true``.
    """
    scores = np.asarray(scores)
    y_true = np.asarray(y_true)

    if scores.dtype == "bool":
        raise ValueError("Scores must be continuous values, not booleans")
    if not 0 <= frac <= 1:
        raise ValueError("frac must lie in the interval [0, 1]")

    # Use explicit boolean masks: indexing with an integer 0/1 label array
    # would pick elements 0 and 1 of `scores` instead of the positive samples.
    scores_correct = np.sort(scores[y_true == 1])
    scores_wrong = scores[y_true == 0]
    if scores_correct.size == 0 or scores_wrong.size == 0:
        raise ValueError("y_true must contain at least one positive and one negative sample")

    # Clamp the rank so that frac=0 selects the largest positive score
    # instead of indexing one past the end of the array.
    limit_idx = min(int((1 - frac) * scores_correct.size), scores_correct.size - 1)
    limit = scores_correct[limit_idx]

    # Negatives at or above the threshold are false positives.
    excluded = np.count_nonzero(scores_wrong >= limit)
    return 1 - (excluded / scores_wrong.size)

def get_tpr_frac_tnr_oms_new(scores, y_true, frac=0.95):
    """
    Compute TPR (True Positive Rate) when TNR (True Negative Rate) reaches a score of frac (usually 0.95).

    A sample counts as predicted positive when its score is greater than or
    equal to the operating threshold. The threshold is the negative-class
    score found at rank ``int(frac * n_negatives)`` once the negative scores
    are sorted in ascending order.

    Args:
        scores: 1-D array-like of continuous scores (not modified).
        y_true: array-like of ground-truth labels with the same length as
            ``scores``; positives are the entries equal to 1 (or True) and
            negatives the entries equal to 0 (or False).
        frac: target TNR, in [0, 1]. For frac=1 the threshold is clamped to
            the largest negative score.

    Returns:
        The TPR (fraction of positives scoring at or above the threshold).

    Raises:
        ValueError: if ``scores`` is boolean, if ``frac`` is outside [0, 1],
            or if either class is absent from ``y_true``.
    """
    scores = np.asarray(scores)
    y_true = np.asarray(y_true)

    if scores.dtype == "bool":
        raise ValueError("Scores must be continuous values, not booleans")
    if not 0 <= frac <= 1:
        raise ValueError("frac must lie in the interval [0, 1]")

    # Use explicit boolean masks: indexing with an integer 0/1 label array
    # would pick elements 0 and 1 of `scores` instead of the positive samples.
    scores_posi = scores[y_true == 1]
    scores_nega = np.sort(scores[y_true == 0])
    if scores_posi.size == 0 or scores_nega.size == 0:
        raise ValueError("y_true must contain at least one positive and one negative sample")

    # Clamp the rank so that frac=1 selects the largest negative score
    # instead of indexing one past the end of the array.
    limit_idx = min(int(frac * scores_nega.size), scores_nega.size - 1)
    limit = scores_nega[limit_idx]

    true_posi = np.count_nonzero(scores_posi >= limit)
    return true_posi / scores_posi.size

def get_specificity_score(y_true, y_pred):
    """
    Compute the specificity (True Negative Rate), TN / (TN + FP), of binary predictions.

    Args:
        y_true: ground-truth binary labels.
        y_pred: predicted binary labels, same length as ``y_true``.

    Returns:
        The specificity, or NaN when ``y_true`` contains no negative sample
        (the ratio is then undefined).
    """
    matrix = confusion_matrix(y_true, y_pred)
    if matrix.shape == (1, 1):
        # Only one label occurs across y_true and y_pred, so the inferred
        # matrix is 1x1 and cannot be unpacked into four cells. Pin the labels
        # to the 0/1 (False/True) convention to recover the 2x2 layout.
        matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = matrix.ravel()

    negatives = tn + fp
    if negatives == 0:
        # No negative sample: return NaN explicitly rather than relying on a
        # 0/0 division and its runtime warning.
        return np.nan
    return tn / negatives

def get_optimal_threshold_f1_new(scores, y_true, beta=1, constraint=False, display=True):
    """
    Compute the threshold that maximizes the F-beta score (F1 by default) of the scores against their ground truth.

    The candidate thresholds are those returned by sklearn precision_recall_curve.

    Args:
        scores: continuous prediction scores.
        y_true: ground-truth binary labels, same length as ``scores``.
        beta: weight of recall relative to precision in the F-beta score.
        constraint: kept for interface compatibility; the Recall > Specificity
            constraint is currently NOT applied.
        display: when True, print the selected threshold and its metrics.

    Returns:
        Tuple ``(threshold_opt, precision_opt, recall_opt, f1_opt)``.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, scores)

    # precision/recall hold one more entry than thresholds: the final point
    # has no associated threshold, so it is left out of the search to keep
    # the three arrays index-aligned.
    precision = precision[:-1]
    recall = recall[:-1]

    # Computing fbeta, handle case where tp = 0, leading to a null precision and null recall
    f1_scores = np.zeros_like(precision)
    nom = (1 + beta**2) * recall * precision
    denom = recall + beta**2 * precision
    np.divide(nom, denom, out=f1_scores, where=(denom != 0))

    index_opt = np.argmax(f1_scores)
    threshold_opt = thresholds[index_opt]
    precision_opt = precision[index_opt]
    recall_opt = recall[index_opt]
    f1_opt = f1_scores[index_opt]

    if display:
        print("\n ...Optimization set fitting")
        print('Optimal threshold: ', threshold_opt)
        print('Optimal F{} Score: '.format(beta), f1_opt)
        print(f"Recall score {recall_opt}, Precision score {precision_opt}")
    return threshold_opt, precision_opt, recall_opt, f1_opt

def get_optimal_threshold_Gmean(scores, y_true, constraint=False, display=True):
    """
    Find the ROC threshold with the highest G-mean, i.e. sqrt(TPR * (1 - FPR)).

    Candidate thresholds are the ones produced by sklearn roc_curve.
    The ``constraint`` flag (Recall > Specificity) is accepted but not applied
    by this implementation.

    Returns:
        Tuple ``(threshold_opt, fpr_opt, tpr_opt, gmean_opt)``.
    """
    false_pos_rate, true_pos_rate, candidates = roc_curve(y_true, scores)
    geometric_means = np.sqrt(true_pos_rate * (1 - false_pos_rate))

    # np.argmax returns the first index that reaches the maximum.
    best = np.argmax(geometric_means)
    threshold_opt = candidates[best]
    fpr_opt = false_pos_rate[best]
    tpr_opt = true_pos_rate[best]
    gmean_opt = geometric_means[best]

    if display:
        print("\n ...Optimization set fitting")
        print('Optimal threshold: ', threshold_opt)
        print('Optimal G-mean: ', gmean_opt)
        print(f"FPR score: {fpr_opt}, TPR score: {tpr_opt}")

    return threshold_opt, fpr_opt, tpr_opt, gmean_opt

def get_optimal_threshold_YoudenJstat(scores, y_true, constraint=False, display=True):
    """
    Find the ROC threshold with the highest Youden's J statistic (TPR - FPR).

    Candidate thresholds are the ones produced by sklearn roc_curve.
    The ``constraint`` flag (Recall > Specificity) is accepted but not applied
    by this implementation.

    Returns:
        Tuple ``(threshold_opt, fpr_opt, tpr_opt, youdenJ_opt)``.
    """
    false_pos_rate, true_pos_rate, candidates = roc_curve(y_true, scores)
    j_statistic = true_pos_rate - false_pos_rate

    # np.argmax returns the first index that reaches the maximum.
    best = np.argmax(j_statistic)
    threshold_opt = candidates[best]
    fpr_opt = false_pos_rate[best]
    tpr_opt = true_pos_rate[best]
    youdenJ_opt = j_statistic[best]

    if display:
        print("\n ...Optimization set fitting")
        print('Optimal threshold: ', threshold_opt)
        print('Optimal Youden J statistic: ', youdenJ_opt)
        print(f"FPR score: {fpr_opt}, TPR score: {tpr_opt}")

    return threshold_opt, fpr_opt, tpr_opt, youdenJ_opt


def get_optimal_threshold_MCC(scores, y_true, display=True):
    """
    Compute the threshold that maximizes the Matthews correlation coefficient (MCC) of the scores against their ground truth.

    Every distinct score value is tried as a threshold, in ascending order.
    For a threshold ``t`` a sample is predicted positive when ``score <= t``.
    NOTE(review): sklearn's roc_curve / precision_recall_curve treat HIGHER
    scores as positive; confirm that the ``<=`` direction is intended here.

    Args:
        scores: 1-D array-like of prediction scores (not modified).
        y_true: ground-truth binary labels, same length as ``scores``.
        display: when True, print the selected threshold and its MCC.

    Returns:
        Tuple ``(threshold_opt, mcc_opt)``; on ties the smallest threshold wins.
    """
    scores = np.asarray(scores)

    # Predictions are derived from the scores in their original order so they
    # stay aligned with y_true. Sorting `scores` in place would both mutate
    # the caller's array and pair each prediction with the wrong label.
    thresholds = np.unique(scores)
    all_mcc = np.array([matthews_corrcoef(y_true, scores <= thresh) for thresh in thresholds])

    index_opt = np.argmax(all_mcc)
    threshold_opt = thresholds[index_opt]
    mcc_opt = all_mcc[index_opt]

    # Display
    if display:
        print("\n ...Optimization set fitting")
        print('Optimal threshold: ', threshold_opt)
        print('Optimal MCC: ', mcc_opt)

    return threshold_opt, mcc_opt

def get_optimal_threshold_kappa(scores, y_true, display=True):
    """
    Compute the threshold that maximizes Cohen's kappa statistic of the scores against their ground truth.

    Every distinct score value is tried as a threshold, in ascending order.
    For a threshold ``t`` a sample is predicted positive when ``score <= t``.
    NOTE(review): sklearn's roc_curve / precision_recall_curve treat HIGHER
    scores as positive; confirm that the ``<=`` direction is intended here.

    Args:
        scores: 1-D array-like of prediction scores (not modified).
        y_true: ground-truth binary labels, same length as ``scores``.
        display: when True, print the selected threshold and its kappa.

    Returns:
        Tuple ``(threshold_opt, kappa_opt)``; on ties the smallest threshold wins.
    """
    scores = np.asarray(scores)

    # Predictions are derived from the scores in their original order so they
    # stay aligned with y_true. Sorting `scores` in place would both mutate
    # the caller's array and pair each prediction with the wrong label.
    thresholds = np.unique(scores)
    all_kappa = np.array([cohen_kappa_score(y_true, scores <= thresh) for thresh in thresholds])

    index_opt = np.argmax(all_kappa)
    threshold_opt = thresholds[index_opt]
    kappa_opt = all_kappa[index_opt]

    # Display
    if display:
        print("\n ...Optimization set fitting")
        print('Optimal threshold: ', threshold_opt)
        print('Optimal Kappa: ', kappa_opt)

    return threshold_opt, kappa_opt

def get_threshold_fracTNR(scores, y_true, frac=0.95, display=True):
    """
    Compute the threshold which gives a TNR score at frac (usually 0.95) of the scores against their ground truth.

    The threshold is the negative-class score found at rank
    ``int(frac * n_negatives)`` once the negative scores are sorted in
    ascending order.

    Args:
        scores: 1-D array-like of continuous scores (not modified).
        y_true: array-like of ground-truth labels with the same length as
            ``scores``; negatives are the entries equal to 0 (or False).
        frac: target TNR, in [0, 1]. For frac=1 the threshold is clamped to
            the largest negative score.
        display: unused; kept for interface compatibility.

    Returns:
        The selected threshold (one of the negative-class scores).

    Raises:
        ValueError: if ``scores`` is boolean, if ``frac`` is outside [0, 1],
            or if ``y_true`` contains no negative sample.
    """
    scores = np.asarray(scores)
    y_true = np.asarray(y_true)

    if scores.dtype == "bool":
        raise ValueError("Scores must be continuous values, not booleans")
    if not 0 <= frac <= 1:
        raise ValueError("frac must lie in the interval [0, 1]")

    scores_nega = np.sort(scores[y_true == 0])
    if scores_nega.size == 0:
        raise ValueError("y_true must contain at least one negative sample")

    # Clamp the rank so that frac=1 selects the largest negative score
    # instead of indexing one past the end of the array.
    limit_idx = min(int(frac * scores_nega.size), scores_nega.size - 1)
    threshold = scores_nega[limit_idx]

    return threshold

def compute_metrics_thresholdOpt_evaluationset(y_true_evaluation, y_pred_evaluation, thresholdOpt, beta=1, display=True):
    """
    Evaluate binary predictions on the evaluation set and optionally print a report.

    Computes F-beta, precision, recall, accuracy and specificity, then derives
    the G-mean (sqrt(recall * specificity)) and Youden's J statistic
    (recall - (1 - specificity)) from them. ``thresholdOpt`` is only used in
    the printed report; the predictions are expected to be already binarized.

    Returns:
        Tuple ``(fbeta, precision, recall, accuracy, specificity, gmean, youden)``.
    """
    truth, predicted = y_true_evaluation, y_pred_evaluation

    # Threshold-based metrics
    fbeta = fbeta_score(truth, predicted, beta=beta)
    precision = precision_score(truth, predicted)
    recall = recall_score(truth, predicted)
    accuracy = accuracy_score(truth, predicted)
    specificity = get_specificity_score(truth, predicted)

    # Metrics derived from recall and specificity
    gmean = np.sqrt(recall * specificity)
    youden = recall - (1-specificity)

    if display:
        print("\n ...evaluation set evaluation")
        print('Optimal threshold = {} from training set is used to compute following metrics:'.format(np.round(thresholdOpt, 5)))
        report_lines = (
            f"F{beta}-score: {fbeta}",
            f"Recall score: {recall}",
            f"Precision score: {precision}",
            f"Accuracy score: {accuracy}",
            f"Specificity score: {specificity}",
            '\n...Additional metrics:',
            f"gmean-score: {gmean}",
            f"youden-score: {youden}",
        )
        for line in report_lines:
            print(line)

    return fbeta, precision, recall, accuracy, specificity, gmean, youden