import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import json
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, roc_curve

def is_ambig(label_list) -> bool:
    """Return whether ``label_list`` lacks the ``'singleAnswer'`` label.

    Args:
        label_list: Any container supporting the ``in`` operator (the caller
            in this script passes each entry's ``'label'`` value).

    Returns:
        ``True`` when ``'singleAnswer'`` is not a member of ``label_list``
        (the entry is treated as ambiguous), ``False`` otherwise.
    """
    # The membership test already yields a bool, so no if/else is needed.
    return 'singleAnswer' not in label_list


# Location of the JSON log to evaluate; being a relative path, it is resolved
# against the current working directory.
filepath = 'logs/clarification/ask_conf_ambigqa.json'
with open(filepath, encoding='utf-8') as log_file:
    # Later code indexes `content` by position and reads the 'label' and
    # 'score' keys from each entry.
    content = json.loads(log_file.read())

# Report how many entries were loaded.
print(len(content))

# Not referenced again in the visible script; kept so the module still
# defines the same names as before.
roc_list = []

# Per-entry results:
#   refined_labels  - True when the entry is ambiguous (see is_ambig)
#   askconf         - mean of the entry's 'score' values, for every entry
#   ambig_askconf   - the same means, restricted to ambiguous entries
#   unambig_askconf - the same means, restricted to unambiguous entries
refined_labels = []
ambig_askconf = []
unambig_askconf = []
askconf = []
for entry in content:
    entry_is_ambig = is_ambig(entry['label'])
    # Compute the mean score once and reuse it for both the per-class list
    # and the overall list.
    mean_score = np.array(entry['score']).mean()

    refined_labels.append(entry_is_ambig)
    if entry_is_ambig:
        ambig_askconf.append(mean_score)
    else:
        unambig_askconf.append(mean_score)
    askconf.append(mean_score)

# Array views used by the metric computations below: ys_array is the boolean
# target (True = ambiguous) and xs_askconf the matching score per entry.
refined_labels = np.array(refined_labels)
ys_array = refined_labels
xs_askconf = np.array(askconf)


# Section banner for the metrics printed below.
print("============================")
print("askconf :")

# AUROC of the mean score used as a ranking signal, with "ambiguous"
# (True in refined_labels) as the positive class.
auroc_askconf = roc_auc_score(y_true=refined_labels, y_score=xs_askconf)
print("auroc askconf:", auroc_askconf)

# Average score within each class, ambiguous first.
for caption, class_scores in (
    ("askconf ambig: ", ambig_askconf),
    ("askconf unambig: ", unambig_askconf),
):
    print(caption, np.mean(class_scores))

# Sweep candidate decision thresholds and keep F1 / precision / recall for
# each one; an entry is predicted ambiguous when its score is strictly
# greater than the threshold.
all_f1s = []
all_precisions = []
all_recalls = []
# Grid with step 0.01 that starts at -0.01 and stops below max(askconf).
# NOTE(review): the -0.01 starting point presumably assumes non-negative
# scores (so the first threshold predicts every entry as ambiguous) - confirm.
thres_cdts = np.arange(-1, np.max(askconf) * 100) / 100
for thres in thres_cdts:
    # Vectorized comparison yields the boolean prediction array directly.
    pred_correctness_labels = xs_askconf > thres

    all_f1s.append(f1_score(ys_array, pred_correctness_labels))
    all_precisions.append(precision_score(ys_array, pred_correctness_labels))
    all_recalls.append(recall_score(ys_array, pred_correctness_labels))

# np.argmax returns the first maximum, so ties resolve to the lowest threshold.
best_idx = np.argmax(all_f1s)
print("best f1: ", np.max(all_f1s))
print('best precision: ', all_precisions[best_idx])
print("best recall: ", all_recalls[best_idx])
print("best thres: ", thres_cdts[best_idx])

# Threshold that achieved the best F1 in the sweep above.
best_thres = thres_cdts[best_idx]

# Boolean "prediction is correct" masks per class: ambiguous entries are
# correct when scored above the threshold, unambiguous ones when at or below.
ambig_preds = np.asarray(ambig_askconf) > best_thres
unambig_preds = np.asarray(unambig_askconf) <= best_thres

# Fraction of correctly classified entries within each class.
ambig_pred_acc = ambig_preds.sum() / len(ambig_askconf)
print("ambig acc: ", ambig_pred_acc)

unambig_pred_acc = unambig_preds.sum() / len(unambig_askconf)
print("unambig acc: ", unambig_pred_acc)


