import numpy as np
import os
import pickle

def expand_question(q):
    """Wrap a question in the single-sentence answering prompt.

    Args:
        q: The question; it is interpolated into the prompt with the
            default f-string formatting.

    Returns:
        A three-line prompt: the instruction, ``Question: <q>``, and a
        trailing ``Answer:`` cue.
    """
    prompt_lines = (
        "Answer the following question in a single brief but complete sentence.",
        f"Question: {q}",
        "Answer:",
    )
    return "\n".join(prompt_lines)

def process_raw_gt(gt):
    """Convert a raw yes/no judgement string into a binary score.

    Spaces and periods are removed and the text is lower-cased before
    matching, so ``"Yes."``, ``" yes"`` and ``"YES"`` are treated alike.

    Args:
        gt: Raw judgement text.

    Returns:
        ``0`` if the normalised text contains "yes" but not "no";
        ``1`` if it contains "no" but not "yes".

    Raises:
        NotImplementedError: If the normalised text contains both "yes"
            and "no", or neither of them.
    """
    normalised = gt.replace(" ", "").replace(".", "").lower()

    # NOTE: this is substring matching on space-stripped text, so any
    # occurrence of the letters "no" (e.g. inside "not" or "know") counts
    # as a "no".
    has_yes = "yes" in normalised
    has_no = "no" in normalised

    if has_yes == has_no:
        # Both or neither present: refuse to guess, and report the
        # offending label so it can be located in the data.
        raise NotImplementedError(
            f"Cannot interpret ground-truth label as yes/no: {gt!r}"
        )
    return 0 if has_yes else 1

def get_gt_scores(gt_corrected):
    """Map every question to the binary scores of its raw judgements.

    Args:
        gt_corrected: Mapping from a question key to an iterable of raw
            judgement strings accepted by ``process_raw_gt``.

    Returns:
        A new dict with the same keys, where each value is the list of
        ``process_raw_gt`` results for that question, in input order.

    Raises:
        NotImplementedError: Propagated from ``process_raw_gt`` when a
            judgement cannot be interpreted.
    """
    return {
        question: [process_raw_gt(raw_label) for raw_label in raw_labels]
        for question, raw_labels in gt_corrected.items()
    }

def get_avg_scores(dict_gt_scores):
    """Average the per-question score lists.

    Args:
        dict_gt_scores: Mapping from a question key to a sequence of
            numeric scores.

    Returns:
        A new dict with the same keys, where each value is ``np.mean`` of
        the corresponding scores.
    """
    return {
        question: np.mean(scores)
        for question, scores in dict_gt_scores.items()
    }


def get_dict_metric(dirname, metric):
    """Load the pickled object stored at ``<dirname>/<metric>.pickle``.

    Args:
        dirname: Directory that holds the pickle file.
        metric: Metric name; ``.pickle`` is appended to form the file name.

    Returns:
        Whatever object the pickle file contains (the caller in this file
        treats it as a dict).
    """
    pickle_file = os.path.join(dirname, "{}.pickle".format(metric))
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(pickle_file, "rb") as handle:
        return pickle.load(handle)


def check_nans(x):
    """Raise if any element of ``x`` is NaN.

    Args:
        x: Array-like of numeric values; it is converted with ``np.array``
            before the check.

    Raises:
        ValueError: If at least one element is NaN.
    """
    nan_mask = np.isnan(np.array(x))
    if nan_mask.any():
        raise ValueError("The array contains NaN values.")


def get_metric(dataset, model, temp, samples, metric, cat=1, judge="gpt-4"):
    """Load the per-question values of one metric for a single run.

    The run directory is
    ``./answers_<dataset>/<model>_p_0.9_temp_<temp>_samples_<samples>``,
    with every ``/`` in ``model`` replaced by ``_``.

    Args:
        dataset: Dataset name used in the directory path.
        model: Model identifier; ``/`` characters are replaced by ``_``.
        temp: Temperature value embedded in the directory name.
        samples: Sample count embedded in the directory name.
        metric: ``"gt"`` to load the judged ground-truth labels and return
            their per-question average; any other value is the stem of a
            ``<metric>.pickle`` file loaded via ``get_dict_metric``.
        cat: Mode selector; only non-zero values are supported.
            NOTE(review): the meaning of ``cat`` is not visible here.
        judge: Judge name used in the ``gt_corrected_<judge>.pickle`` file
            name (only read when ``metric == "gt"``).

    Returns:
        A list of per-question values, in the iteration order of the
        loaded dict. For ``metric == "gt"`` each value is the mean of the
        0/1 scores produced by ``process_raw_gt``.

    Raises:
        NotImplementedError: If ``cat == 0``.
        ValueError: If a non-"gt" metric contains NaN values.
    """
    # Guard first so unsupported modes fail before any path is built.
    if cat == 0:
        raise NotImplementedError("get_metric does not support cat=0")

    model_tag = model.replace("/", "_")
    # The "p_0.9" segment is fixed here; presumably the top-p setting of
    # the run that wrote the directory -- TODO confirm.
    dirname = f"./answers_{dataset}/{model_tag}_p_0.9_temp_{temp}_samples_{samples}"

    if metric == "gt":
        gt_path = os.path.join(dirname, f"gt_corrected_{judge}.pickle")
        # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
        with open(gt_path, "rb") as f:
            dict_gt_corrected = pickle.load(f)
        dict_avg_gt = get_avg_scores(get_gt_scores(dict_gt_corrected))
        return list(dict_avg_gt.values())

    metric_values = list(get_dict_metric(dirname, metric).values())
    check_nans(metric_values)
    return metric_values
    

