import torch
import pickle
import os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import recall_score, precision_recall_curve
from scipy.special import softmax

'''
    This script generates the weighted results for the language task
'''

# --- Reproducibility -------------------------------------------------------
# Seed shared by NumPy, torch and the Dirichlet generator used further down.
seed = 0

# --- Similarity weighting --------------------------------------------------
# How many of the most similar calibration samples are kept per test sample.
num_top_selection = 150
# Temperature dividing the similarities before the softmax that turns them
# into weights.
TEMP = 1

# --- Target recall grid ----------------------------------------------------
recall_step = 0.1
all_recalls = np.arange(start=0, stop=1, step=recall_step)

# --- Experiment sizes ------------------------------------------------------
# NUM_CAL:      number of calibration/test splits
# NUM_TEST:     size of the test set
# NUM_TEST_ENV: number of test environments (one mixture draw each)
NUM_CAL, NUM_TEST, NUM_TEST_ENV = 15, 500, 100

# All available scoring functions. Kept under its own name so that the active
# selection below is the only assignment to `target_scoring_functions`
# (the catalogue used to be assigned to that name and then overwritten).
ALL_SCORING_FUNCTIONS = [
    'SemanticEntropy',
    'Confidence',
    'Entropy',
    'MatrixDegreeUncertainty',
    'EccentricityUncertainty',
    'SumEigenUncertainty',
    'SentSAR',
    'PTrue',
    'SelfDetection',
    'KernelLanguageEntropy',
    'LARS',
    'Inside',
    'MARS',
    'AttentionScore',
    'MatrixDegreeConfidence',
    'EccentricityConfidence',
    'VerbalizedConfidence',
    'SAR',
]

# Alternative, larger selection kept for reference:
# target_scoring_functions = [
#     'SemanticEntropy',
#     'Confidence',
#     'PTrue',
#     'KernelLanguageEntropy',
#     'LARS',
#     'Inside',
#     'MARS',
#     'EccentricityUncertainty',
#     'EccentricityConfidence',
#     'VerbalizedConfidence',]

# Scoring functions actually evaluated by this script. To evaluate everything,
# use `target_scoring_functions = list(ALL_SCORING_FUNCTIONS)`.
target_scoring_functions = [
    'Confidence',
    'MARS',
    'MatrixDegreeUncertainty',
]

# Seed the global NumPy and torch generators so the shuffles and random picks
# below are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    The file is opened in a `with` block so the handle is closed as soon as
    loading finishes.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


#Load Data
trivia = _load_pickle("./data/trivia_qa_meta-llama_Meta-Llama-3-8B-Instruct_2500_0_with_id.pkl")     # 2496 samples
gsm8k = _load_pickle("./data/gsm8k_meta-llama_Meta-Llama-3-8B-Instruct_1.0_0_with_id.pkl")           # 1319 samples
# complete_data = trivia[:len(gsm8k)] + gsm8k
# Load the similarities
similarities = _load_pickle("./data/trivia_qa_gsm8k_meta-llama_Meta-Llama-3-8B-Instruct_similarities.pkl")

def find_threshold(recalls, thresholds, recall):
    """Pick the threshold whose recall matches (or brackets) a target recall.

    Both sequences are reversed before the binary search, so `recalls` is
    assumed to be sorted in decreasing order on input.

    Args:
        recalls: sequence of recall values (decreasing).
        thresholds: sequence of thresholds, indexed like `recalls`; it may be
            shorter than `recalls`.
        recall: target recall to look up.

    Returns:
        * the threshold at the matching position when `recall` is found;
        * the last (reversed) threshold when the target lies beyond the
          positions that have a threshold;
        * the first (reversed) threshold when the target is below every recall;
        * otherwise one of the two bracketing thresholds, chosen uniformly at
          random with `np.random.choice`.

    Raises:
        ValueError: if `recalls` or `thresholds` is empty.
    """
    from bisect import bisect_left

    # bisect requires ascending order, so flip both sequences.
    recalls = recalls[::-1]
    thresholds = thresholds[::-1]

    # Only positions that exist in BOTH sequences can be looked up safely.
    usable = min(len(recalls), len(thresholds))
    if usable == 0:
        raise ValueError('recalls and thresholds must not be empty')

    index = bisect_left(recalls, recall)

    # Bounds are checked before any indexing: previously `recalls[index]` was
    # evaluated first and raised IndexError for targets past the end.
    if index >= usable:
        return thresholds[-1]
    if recalls[index] == recall:
        return thresholds[index]
    if index == 0:
        # No left neighbour exists; `index - 1` would wrap around to the last
        # threshold.
        return thresholds[0]

    # The target falls strictly between two recalls: pick one side at random.
    random_index = np.random.choice([index - 1, index])
    return thresholds[random_index]
        
def get_threshold_decision_for_sample(test_sample, calibration_data, target_recall, similarities):
    """Compute a similarity-weighted threshold and decision for one test sample.

    For every name in `target_scoring_functions`, the calibration samples whose
    id is listed in `test_sample["most_sim_ids"]` are sorted by score, their
    softmax-normalised similarity weights are accumulated, and the threshold is
    the first score at which the cumulative weight reaches `target_recall`.

    Args:
        test_sample: dict with "id", "label", "most_sim_ids" and one score per
            scoring function name.
        calibration_data: iterable of dicts with "id" and the same score keys.
        target_recall: weighted quantile level used to pick the threshold.
        similarities: indexable as `similarities[test_id][other_id]`.

    Returns:
        dict mapping each scoring function name to
        `[threshold, decision, label]`, where `decision` is 1 when the test
        score is >= threshold and 0 otherwise. A NaN test score yields
        `[-1e5, 0, label]`.
    """
    similar_ids = test_sample["most_sim_ids"]
    test_id = test_sample["id"]

    # The neighbour set and its weights are the same for every scoring
    # function, so they are built once instead of once per key.
    neighbours = [s for s in calibration_data if s["id"] in similar_ids]

    # The test sample's self-similarity takes part in the softmax
    # normalisation and is then dropped, so the remaining weights sum to < 1.
    weights = np.array([similarities[test_id][s["id"]] for s in neighbours])
    weights = np.append(weights, similarities[test_id][test_id])
    weights = softmax(weights/TEMP)[:-1]

    res = {}
    for key in target_scoring_functions:
        if not np.isnan(test_sample[key]):
            calibration_scores = np.array([s[key] for s in neighbours])
            calibration_scores = np.nan_to_num(calibration_scores, nan=-1e5, posinf=1e5, neginf=-1e5)

            cal_pi = calibration_scores.argsort(0)
            cal_srt = np.take_along_axis(calibration_scores, cal_pi, axis=0)
            w_k = np.take_along_axis(weights, cal_pi, axis=0).cumsum(axis=0)

            # Number of sorted scores whose cumulative weight is still below
            # the target == position of the weighted quantile.
            index = int((w_k < target_recall).sum())
            if index < len(cal_srt):
                threshold = cal_srt[index]
            else:
                # The neighbours' total weight never reaches the target (or
                # there are no neighbours): no finite score qualifies, so the
                # threshold is +inf and the decision below becomes 0.
                # Indexing here used to raise IndexError.
                threshold = np.inf

            decision = int(test_sample[key] >= threshold)
        else:
            threshold = -1e5
            decision = 0
        res[key] = [threshold, decision, test_sample["label"]]
    return res


'''
    error_results --> dict (keys are truth method names)
        semantic_entropy --> dict 
            target_recalls --> list of target recalls
            recalls --> list of recalls corresponding to each target recall
            errors --> array of |recall - target recall|, one entry per target recall
            ARE --> currently the same array as errors (not reduced to a single number)
        length normalized scoring --> dict 
            target_recalls --> list of target recalls
            recalls --> list of recalls corresponding to each target recall
            errors --> array of |recall - target recall|, one entry per target recall
            ARE --> currently the same array as errors (not reduced to a single number)
        ...

    ************************************************************************
        
    results --> dict (keys are truth method names)
        semantic_entropy --> dict (keys are target recall values)
            target_recall 0.01 --> list (true recall, list)
                [recall, list of [threshold, decision, label]]
            target_recall 0.02 --> list
                [recall, list of [threshold, decision, label]]
            ...

        length normalized scoring --> dict of lists
            target_recall 0.01 --> list
                [recall, list of [threshold, decision, label]]
            target_recall 0.02 --> list
                [recall, list of [threshold, decision, label]]
            ...
        ...

'''
def get_results(calibration_data, test_data, similarities):
    """Evaluate every scoring function at every target recall for one split.

    Side effect: each dict in `test_data` gets a "most_sim_ids" entry holding
    the ids of its most similar label-0 calibration samples (at most
    `num_top_selection` of them).

    Returns:
        (error_results, results)
        results[name][target_recall] is `[recall, rows]`, where `rows` holds
        one `[threshold, decision, label]` per test sample and `recall` is the
        recall of class 0 computed from those rows.
        error_results[name] has the keys "recalls", "target_recalls",
        "errors" (|recall - target|) and "ARE".
    """
    # Neighbours are drawn only from calibration samples labelled 0
    # (hallucinations).
    hallucination_ids = np.array(
        [entry["id"] for entry in calibration_data if entry["label"] == 0]
    )
    for sample in test_data:
        sims_to_pool = similarities[sample["id"]][hallucination_ids]
        ranking = torch.argsort(sims_to_pool)
        # argsort is ascending, so the tail holds the most similar entries.
        sample["most_sim_ids"] = hallucination_ids[ranking[-(num_top_selection):]]

    results = {name: {} for name in target_scoring_functions}
    for target_recall in tqdm(all_recalls):
        # One {name: [threshold, decision, label]} dict per test sample.
        outcomes = [
            get_threshold_decision_for_sample(sample, calibration_data, target_recall, similarities)
            for sample in test_data
        ]
        for name in target_scoring_functions:
            rows = [outcome[name] for outcome in outcomes]
            table = np.array(rows)
            # Column 1 holds the decisions, column 2 the labels.
            achieved = recall_score(table[:, 2], table[:, 1], pos_label=0)
            results[name][target_recall] = [achieved, rows]

    # Summarise how far the achieved recall is from each target.
    error_results = {}
    for name in target_scoring_functions:
        achieved_recalls = [results[name][target][0] for target in all_recalls]
        targets = list(all_recalls)
        gaps = np.abs(np.array(achieved_recalls) - np.array(targets))
        error_results[name] = {
            "recalls": achieved_recalls,
            "target_recalls": targets,
            "errors": gaps,
            # NOTE(review): "ARE" is the very same per-target array as
            # "errors"; it is not averaged here.
            "ARE": gaps,
        }
    return error_results, results

'''
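    final_results --> dict (keys are truth method names)
        semantic_entropy --> list with one entry per test environment; each entry is a list of dicts
            [following dict for each calibration/test split of that environment]
                target_recalls --> list of target recalls
                recalls --> list of recalls corresponding to each target recall
                errors --> array of |recall - target recall|, one entry per target recall
                ARE --> currently the same array as errors (not reduced to a single number)
        length normalized scoring --> list with one entry per test environment; each entry is a list of dicts
            [following dict for each calibration/test split of that environment]
                target_recalls --> list of target recalls
                recalls --> list of recalls corresponding to each target recall
                errors --> array of |recall - target recall|, one entry per target recall
                ARE --> currently the same array as errors (not reduced to a single number)
        ...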

'''

# Each test environment is a (trivia, gsm8k) mixture drawn from a symmetric
# Dirichlet with concentration `d_alpha`.
d_alpha = 0.5
rng = np.random.default_rng(seed=seed)
distribution = rng.dirichlet(alpha=[d_alpha,d_alpha], size=NUM_TEST_ENV)

results_dir = f'./results_shift_{d_alpha}_weighted_{TEMP}_top{num_top_selection}'
os.makedirs(results_dir, exist_ok=True)
with open(f'{results_dir}/test_distributions.pkl', 'wb') as f:
    pickle.dump(distribution, f)


def _last_n(items, count):
    """Return the last `count` elements of `items` (an empty list for 0).

    A plain `items[-count:]` returns the WHOLE list when `count` is 0, which
    is why the zero case is handled explicitly.
    """
    return items[-count:] if count > 0 else []


# final_results[name][test_i] collects one error dict per calibration split.
final_results = {}
for key in target_scoring_functions:
    final_results[key] = [[] for _ in range(NUM_TEST_ENV)]

for test_i in range(NUM_TEST_ENV):
    print(f"Test environment {test_i+1}")

    # Number of test samples taken from each dataset in this environment.
    num_trivia_test = int(distribution[test_i][0]*NUM_TEST)
    num_gsm8k_test = NUM_TEST - num_trivia_test

    for cal_i in range(NUM_CAL):
        print(f"Calibration iteration {cal_i+1}/{NUM_CAL}")

        # Reshuffle both datasets and cut trivia down to the size of gsm8k.
        np.random.shuffle(trivia)
        np.random.shuffle(gsm8k)
        shortened_trivia = trivia[:len(gsm8k)]

        # The last NUM_TEST samples of each dataset are reserved for testing;
        # everything before them is used for calibration.
        calibration_data = shortened_trivia[:-NUM_TEST] + gsm8k[:-NUM_TEST]

        # Build the test pool from the reserved tails according to the mixture.
        # With a zero count the old `[-0:]` slice returned the entire dataset,
        # leaking calibration samples into the test set and inverting the
        # intended mixture.
        temp_test_data = _last_n(shortened_trivia, num_trivia_test) + _last_n(gsm8k, num_gsm8k_test)

        #sample from test data randomly
        np.random.shuffle(temp_test_data)
        test_data = temp_test_data[:NUM_TEST]

        print(f"{len(calibration_data)} calibration samples")
        print(f"{len(test_data)} test samples")

        error_results, _ = get_results(calibration_data, test_data, similarities)

        for key in target_scoring_functions:
            final_results[key][test_i].append(error_results[key])

    # Save after every environment. Each file holds the complete
    # `final_results` accumulated so far (later environments are still empty).
    with open(f'{results_dir}/trivia_qa_gsm8k_meta-llama_Meta-Llama-3-8B-Instruct-{num_top_selection}_{NUM_CAL}_{NUM_TEST}_{test_i}_{seed}.pkl', 'wb') as f:
        pickle.dump(final_results, f)