import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import json
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, roc_curve

def is_ambig(label_list):
    """Tell whether an example's label collection marks it as ambiguous.

    An example counts as ambiguous exactly when the string 'singleAnswer'
    is absent from ``label_list`` (membership is tested with ``in``).

    Args:
        label_list: The example's labels as read from the JSON file.

    Returns:
        True if 'singleAnswer' is not among the labels, otherwise False.
    """
    return 'singleAnswer' not in label_list

filepath = 'logs/CLARA/ambigqa/ambigqa_clara.json'
with open(filepath, 'r', encoding='utf-8') as f:
    content = json.load(f)

print(len(content))

roc_list = []  # NOTE(review): not used anywhere in the visible script

# Per-question scores. The ambig_*/unambig_* lists hold the scores of each
# gold class; LAPLACE / CLARA / CLARAoq / CLARAn hold every question's score
# in file order, aligned element-wise with refined_labels (True = ambiguous).
refined_labels = []
ambig_CLARA = []
ambig_LAPLACE = []
unambig_CLARA = []
unambig_LAPLACE = []
ambig_CLARAoq = []
unambig_CLARAoq = []
ambig_CLARAn = []
unambig_CLARAn = []
LAPLACE = []
CLARA = []
CLARAoq = []
CLARAn = []
for item in content:
    ambig_flag = is_ambig(item['label'])

    # Each score is computed once per question and appended to both the
    # per-class list and the all-questions list below.
    laplace = np.array(item['laplace']).mean()
    # The CLARA-family values are mapped through log(1 + x) before averaging.
    clara = np.log(1 + np.array(item['CLARA'])).mean()
    clara_oq = np.log(1 + np.array(item['CLARAoq'])).mean()
    clara_n = np.log(1 + np.array(item['CLARAn'])).mean()

    refined_labels.append(ambig_flag)
    if ambig_flag:
        ambig_LAPLACE.append(laplace)
        ambig_CLARA.append(clara)
        ambig_CLARAoq.append(clara_oq)
        ambig_CLARAn.append(clara_n)
    else:
        unambig_LAPLACE.append(laplace)
        unambig_CLARA.append(clara)
        unambig_CLARAoq.append(clara_oq)
        unambig_CLARAn.append(clara_n)

    LAPLACE.append(laplace)
    CLARA.append(clara)
    CLARAoq.append(clara_oq)
    CLARAn.append(clara_n)

# Turn the per-question lists into NumPy arrays. ys_array and refined_labels
# name the same boolean array (True = ambiguous question).
ys_array = refined_labels = np.array(refined_labels)
xs_LAPLACE, xs_CLARA, xs_CLARAoq, xs_CLARAn = (
    np.array(scores) for scores in (LAPLACE, CLARA, CLARAoq, CLARAn)
)

def report_score(name, xs, ambig_scores, unambig_scores, labels):
    """Print ambiguity-detection metrics for one per-question score.

    A question is predicted ambiguous when its score is strictly greater
    than a threshold. Prints, in order: the AUROC of ``xs`` against
    ``labels``; the mean score of the ambiguous and unambiguous groups; the
    best F1 over a 0.01-spaced threshold grid together with the precision,
    recall and threshold at that point; and the fraction of each group
    classified correctly at that threshold.

    Args:
        name: Score name used in the printed headings (e.g. "CLARA").
        xs: Scores for every question, aligned element-wise with ``labels``.
        ambig_scores: Scores of the questions labelled ambiguous.
        unambig_scores: Scores of the questions labelled unambiguous.
        labels: Boolean gold labels, True for ambiguous questions.

    Returns:
        The AUROC as returned by ``roc_auc_score``.

    Raises:
        ValueError: from ``roc_auc_score`` when ``labels`` holds one class only.
    """
    xs = np.asarray(xs)
    print("============================")
    print(f"{name} :")

    auroc = roc_auc_score(labels, xs)
    print(f"auroc {name}:", auroc)

    print(f"{name} ambig: ", np.mean(ambig_scores))
    print(f"{name} unambig: ", np.mean(unambig_scores))

    # 0.01-spaced candidate thresholds, stopping before max(xs) * 100 / 100.
    # The grid starts at -0.01, or further down when the lowest score is
    # <= -0.01: that keeps the "everything ambiguous" operating point in the
    # search and guarantees a non-empty grid (np.argmax of an empty list
    # raises ValueError).
    start = min(-1.0, np.ceil(xs.min() * 100) - 1)
    thres_cdts = np.arange(start, xs.max() * 100) / 100

    f1s, precisions, recalls = [], [], []
    for thres in thres_cdts:
        preds = xs > thres
        # zero_division=0 returns the same 0.0 sklearn falls back to when no
        # question is predicted ambiguous, without a warning per threshold.
        f1s.append(f1_score(labels, preds, zero_division=0))
        precisions.append(precision_score(labels, preds, zero_division=0))
        recalls.append(recall_score(labels, preds, zero_division=0))

    best_idx = int(np.argmax(f1s))  # first threshold that reaches the best F1
    best_thres = thres_cdts[best_idx]
    print("best f1: ", f1s[best_idx])
    print("best precision: ", precisions[best_idx])
    print("best recall: ", recalls[best_idx])
    print("best thres: ", best_thres)

    # Per-group accuracy at the chosen threshold: ambiguous questions should
    # score above it, unambiguous ones at or below it.
    ambig_acc = np.mean(np.asarray(ambig_scores) > best_thres)
    unambig_acc = np.mean(np.asarray(unambig_scores) <= best_thres)
    print("ambig acc: ", ambig_acc)
    print("unambig acc: ", unambig_acc)
    return auroc


auroc_LAPLACE = report_score(
    "LAPLACE", xs_LAPLACE, ambig_LAPLACE, unambig_LAPLACE, ys_array
)
auroc_CLARA = report_score(
    "CLARA", xs_CLARA, ambig_CLARA, unambig_CLARA, ys_array
)
auroc_CLARAoq = report_score(
    "CLARAoq", xs_CLARAoq, ambig_CLARAoq, unambig_CLARAoq, ys_array
)
# NOTE(review): CLARAn scores (xs_CLARAn, ambig_CLARAn, unambig_CLARAn) are
# collected but never reported; add a report_score("CLARAn", ...) call here
# if that evaluation is intended.




