from uncertainty.response_generator import LLM_RESULTS, StandardGenerator
from uncertainty.utils import LLM
from uncertainty.uncertainty_evaluation import Uncertainty_Evaluator
from uncertainty.uncertainty_estimation.utils import entropy, expand_truncated_logits, kl_div, hellinger_distance, Bhatacharyya_distance
from uncertainty.uncertainty_estimation.esi.estimation_func import DistributionDistance
from uncertainty.uncertainty_estimation.methods import average_entropy,average_prob
from tqdm import tqdm
import numpy as np
import pandas as pd
import math
import torch
import os
import json
import random
from itertools import combinations
import math
from collections import defaultdict
import argparse
import psutil
import gc




def key(prompt, response, model_name):
    """
    Build the cache lookup key for one (prompt, response, model) triple.

    The three strings are joined with the literal separator "[sep]", in the
    order prompt, response, model_name.
    """
    parts = (prompt, response, model_name)
    return "[sep]".join(parts)

def read_from_cache(prompts, responses, cached_result):
    """
    Gather cached entries for the given prompt/response pairs, column-wise.

    prompts, responses - iterables walked in lockstep; each pair is turned
        into a cache key together with cached_result["model_name"]
    cached_result - dict with a "model_name" entry and a "data" dict keyed by
        the strings produced by key()

    Returns a dict with the keys "input_ids", "logits" (itself a dict with
    "scores" and "ids") and "transition_scores". Each leaf is a list holding
    one item per pair, in input order. Dict-valued fields of a cached entry
    are spread over the matching nested lists; every other field is appended
    to its top-level list.
    """
    collected = {
        "input_ids": [],
        "logits": {"scores": [], "ids": []},
        "transition_scores": [],
    }
    for prompt, response in zip(prompts, responses):
        entries = cached_result["data"]
        record = entries[key(prompt, response, cached_result["model_name"])]
        for field, value in record.items():
            if not isinstance(value, dict):
                collected[field].append(value)
                continue
            for sub_field, sub_value in value.items():
                collected[field][sub_field].append(sub_value)
    return collected

def expanded_logits_generator(logits,indexes, vocab_size):
    """
    Lazily expand truncated logits, one example at a time.

    Walks logits and indexes in lockstep and yields the result of
    expand_truncated_logits for each pair, passing vocab_size through and
    always setting expand_to_max_acceptable_size=True. Nothing is computed
    until the generator is iterated.
    """
    pairs = zip(logits, indexes)
    for truncated, token_ids in pairs:
        expanded = expand_truncated_logits(
            truncated,
            indexes=token_ids,
            vocab_size=vocab_size,
            expand_to_max_acceptable_size=True,
        )
        yield expanded


def pop_random_k_indexes(k, total_elements_num, max_sample_num=10):
    """
    Draw distinct random index subsets, each prefixed with index 0.

    Every returned entry has the form [0, i_1, ..., i_k], where the i_j are k
    distinct integers taken from 1..total_elements_num in ascending order.
    Entries are distinct as combinations, i.e. regardless of element order.

    k - int, the number of indexes drawn per subset (in addition to index 0)
    total_elements_num - int, the pool is the integers 1..total_elements_num
    max_sample_num - int, the maximum number of subsets to return; fewer are
        returned when fewer combinations exist, and an empty list is returned
        when it is not positive

    Returns a list of lists of int.
    Raises AssertionError when k is larger than total_elements_num.
    """
    assert k <= total_elements_num, f"the given elements num k {k} is larger than the total elements num {total_elements_num}"
    total_comb_num = math.comb(total_elements_num, k)
    sample_num = min(max_sample_num, total_comb_num)
    if sample_num <= 0:
        # Nothing requested (or nothing available): avoid dividing by zero.
        return []

    pool = range(1, total_elements_num + 1)

    # Compare in integers: math.comb can exceed the float range, so a true
    # division here would raise OverflowError for large pools.
    if total_comb_num <= 20 * sample_num:
        # Few enough combinations to enumerate them all and sample exactly.
        all_k_combs = list(combinations(pool, k))
        return [[0, *comb] for comb in random.sample(all_k_combs, sample_num)]

    # Too many combinations to enumerate: rejection-sample instead. Each draw
    # is sorted before the duplicate check so that two permutations of the
    # same combination are not both accepted.
    seen = set()
    sampled_indexes = []
    while len(sampled_indexes) < sample_num:
        candidate = tuple(sorted(random.sample(pool, k)))
        if candidate in seen:
            continue
        seen.add(candidate)
        sampled_indexes.append([0, *candidate])

    return sampled_indexes
    
    
        
        

def compute_scores_for_different_k(logits, indexes, labels, max_sample_num=10, max_search_n = 40, measure = "hellinger", k = None):
    """
    Evaluate AUROC of the distribution-distance score for several subset sizes.

    logits, indexes - per-example truncated logits and their index data, as
        accepted by expand_truncated_logits; entry 0 of each example is always
        kept in every subset, so len(logits[0]) - 1 entries are available to
        sample from
    labels - passed unchanged to Uncertainty_Evaluator.evaluate together with
        each list of scores
    max_sample_num - int, maximum number of random subsets drawn per size
    max_search_n - int, largest subset size to try; capped at the number of
        available entries
    measure - forwarded to DistributionDistance as distance_measure
    k - None to search a grid of sizes (even numbers up to 30, then steps of 5
        from 35), an int for a single size, or any other value used as the
        list of sizes as-is

    Returns a defaultdict(list) with two parallel lists: "score" (the AUROC
    for one sampled subset) and "element_num" (the subset size it belongs to).
    """
    print("start to do k search.")

    # One weight per example: entropy() of the example's entry at position 0.
    entropy_weight = [
        entropy(example_logits[0], index=example_ids[0])
        for example_logits, example_ids in zip(logits, indexes)
    ]

    total_elements_num = len(logits[0]) - 1
    evaluator = Uncertainty_Evaluator(["auroc"])
    max_search_n = min(max_search_n, total_elements_num)

    # Decide which subset sizes to evaluate.
    if k is not None:
        k_indexes = [k] if isinstance(k, int) else k
    elif max_search_n <= 30:
        k_indexes = list(range(2, max_search_n + 1, 2))
    else:
        k_indexes = list(range(2, 31, 2)) + list(range(35, max_search_n + 1, 5))

    # The same random subsets are reused for every example, so scores in one
    # bucket are comparable across examples.
    sample_indexes_list = [
        pop_random_k_indexes(subset_size, total_elements_num, max_sample_num=max_sample_num)
        for subset_size in k_indexes
    ]
    score_dict = {
        subset_size: [[] for _ in subsets]
        for subset_size, subsets in zip(k_indexes, sample_indexes_list)
    }

    for ex_id, (logit, index) in tqdm(enumerate(zip(logits, indexes))):
        expand_logit = expand_truncated_logits(logit, indexes=index, expand_to_max_acceptable_size=True)
        weight = entropy_weight[ex_id]
        for subset_size, subsets in zip(k_indexes, sample_indexes_list):
            buckets = score_dict[subset_size]
            for sample_id, subset in enumerate(subsets):
                sample_logit = torch.stack([expand_logit[i] for i in subset])
                distance = DistributionDistance(sample_logit, weight, distance_measure=measure)
                buckets[sample_id].append(distance["mean"])

    result_scores = defaultdict(list)
    for subset_size, scores_list in score_dict.items():
        for scores in scores_list:
            auroc = evaluator.evaluate(scores, labels, verbose=False)["auroc"].loc["method"]
            result_scores["score"].append(auroc)
            result_scores["element_num"].append(subset_size)

    return result_scores

# Lookup from a distance-measure name to the matching distance function
# imported from uncertainty.uncertainty_estimation.utils.
# NOTE(review): nothing in the visible part of this file reads this mapping;
# compute_scores_for_different_k hands the measure name to DistributionDistance
# directly. Confirm whether importers rely on it before removing it.
DISTRIBUTION_DISTANCE_FUNC_MAPPING = {
    "kl": kl_div,
    "hellinger": hellinger_distance,
    "Bhatacharyya":Bhatacharyya_distance,
}

if __name__ == "__main__":
    # Command-line entry point: load cached generations and cached augmented
    # logits for one dataset/model pair, score several subset sizes with
    # compute_scores_for_different_k, and write the per-size mean/std of the
    # AUROC to an Excel file.
    parser = argparse.ArgumentParser(description='search different sample num')
    parser.add_argument('-d', '--dataset', type=str, help='dataset name')
    parser.add_argument('-m', '--model', type=str, help='model name')
    parser.add_argument("-p", "--percent", type=float, default=0.5, help='percentage of tokens to be intervened')
    parser.add_argument('-a', '--augfunc', default="skip", type=str, help='text augmentation method used')
    parser.add_argument("-n", "--cached_sample_num", type=int, default=40, help='saved sample num in cached file')
    parser.add_argument("-sn", "--search_num", type=int, default=10, help='max sample times for each k searching')
    parser.add_argument("-k", "--sample_num", type=int, help='the specific k to search')
    parser.add_argument("--max_sample_num", type=int, help='the max k to search')

    # Unknown arguments are tolerated on purpose (parse_known_args).
    args, _ = parser.parse_known_args()
    dataset = args.dataset
    model = args.model

    # "paraphrasellama" shares the "paraphrase" cache-file naming but lives in
    # its own sub-directory and gets its own output name.
    if args.augfunc == "paraphrasellama":
        aug_func = "paraphrase"
        augfunc_name = "paraphrase_llama"
    else:
        aug_func = args.augfunc
        augfunc_name = args.augfunc

    percent = args.percent
    sample_num = args.cached_sample_num
    # None means "search the default grid of subset sizes".
    k = args.sample_num
    # Without an explicit cap, search up to the number of cached samples.
    max_k = sample_num if args.max_sample_num is None else args.max_sample_num

    data_path = f"./output/experimental_results/{dataset}/{model}"
    result_path = f"./output/experimental_results/sampling_num/esi/{dataset}-{model}"
    print(f"start to load results for model {model}, dataset {dataset} and percent {percent}")
    os.makedirs(result_path, exist_ok=True)

    cached_generation_path = os.path.join(data_path, "results.json")
    results = LLM_RESULTS.load(cached_generation_path)
    # NOTE(review): the tokenizer is not used anywhere below; the call is kept
    # in case initial_tokenizer has side effects. Confirm and drop if not.
    tokenizer = LLM.initial_tokenizer(results.model_name)

    test_aug_file_name = f"esi_{aug_func}_{sample_num}_{percent}.json"
    if augfunc_name == "paraphrase_llama":
        test_aug_file_name = "llama3_paraphrase/" + test_aug_file_name
    cached_test_aug_path = os.path.join(data_path, test_aug_file_name)
    with open(cached_test_aug_path, "r", encoding="utf-8") as f:
        ta_cached_result = json.load(f)
    results.evaluate_correctness("rougeL")

    print("results loaded")
    prompts = list(results.prompts)
    responses = list(results.responses)

    ta_logits = read_from_cache(prompts, responses, ta_cached_result)
    logits = ta_logits["logits"]["scores"]
    indexes = ta_logits["logits"]["ids"]

    # NOTE(review): correctness is evaluated with "rougeL" above, but labels
    # are derived from the "bem" scores; this presumably relies on "bem"
    # already being present in the loaded results. Verify.
    metric = "bem"
    threshold = 0.7
    # Label 1 marks an example whose score falls below the threshold.
    truthlabel = (np.array(results.scores[metric]) < threshold).astype(int)

    # --search_num controls how many random subsets are drawn per size; it
    # was previously parsed but never forwarded.
    k_scores = compute_scores_for_different_k(
        logits,
        indexes,
        truthlabel,
        max_sample_num=args.search_num,
        max_search_n=max_k,
        k=k,
    )
    result_score = pd.DataFrame.from_dict(k_scores).groupby("element_num").agg({'score':['mean','std']})
    save_path = os.path.join(result_path, f"{augfunc_name}_{sample_num}_{percent}.xlsx")
    result_score.to_excel(save_path)
    # Report success only once the file has actually been written.
    print(f"results saved to {save_path}")
    
