import numpy as np
import pandas as pd
import os
import json
import copy
from tqdm import tqdm
from uncertainty.uncertainty_estimation.utils import load_sampling_results
from uncertainty.uncertainty_estimation.sar import sar, DEFAULT_SAR_CONFIG
from uncertainty.uncertainty_estimation.mi import DEFAULT_MI_CONFIG, read_mu_1, read_mu_2, calculate_mi_score 
from uncertainty.response_generator import LLM_RESULTS
from uncertainty.uncertainty_evaluation import Uncertainty_Evaluator
from uncertainty.uncertainty_estimation.semantic_entropy import semantic_entropy, DEFAULT_SEMANTIC_ENTROPY_CONFIG
from uncertainty.uncertainty_estimation.esi.estimation_func import DistributionDistance
from uncertainty.uncertainty_estimation.inside import calculate_EigenScore  


import matplotlib.pyplot as plt
from itertools import combinations
import math
from collections import defaultdict
import random
import torch

# Generation settings applied on top of the base generation config when
# loading cached sampling results for the MI method (see read_mi_score).
MI_DEFAULT_CONFIG = dict(
    do_sample=True,
    num_responses_per_prompt=10,
    temperature=0.9,
    top_p=1.0,
    output_scores=False,
    return_normalized_transition_scores=True,
    batch_size=5,
)


def predictive_entropy(log_probs):
    """Monte Carlo estimate of entropy from sampled log-probabilities.

    Approximates ``E[-log p(x)]`` by ``-(1/N) * sum_i log p(x_i)``, i.e. the
    negated mean of the supplied log-probabilities.

    Args:
        log_probs: Sequence of log-probabilities, one per sample.

    Returns:
        The negated sum of ``log_probs`` divided by ``len(log_probs)``, as
        produced by NumPy.
    """
    sample_count = len(log_probs)
    summed_log_prob = np.sum(log_probs)
    return -summed_log_prob / sample_count

def extract_random_k_indexes(k, total_elements_num, max_sample_num=10):
    """Draw distinct random k-element index subsets of ``range(total_elements_num)``.

    Args:
        k: Number of indexes in every subset.
        total_elements_num: Size of the index pool (indexes ``0 .. total_elements_num - 1``).
        max_sample_num: Upper bound on how many subsets to draw.

    Returns:
        A list of ``min(max_sample_num, C(total_elements_num, k))`` subsets.
        Each subset is a list of indexes in ascending order, and no subset
        occurs twice. An empty list is returned when ``max_sample_num`` is
        not positive.

    Raises:
        AssertionError: If ``k`` exceeds ``total_elements_num``.
    """
    assert k <= total_elements_num, f"the given elements num k {k} is larger than the total elements num {total_elements_num}"
    total_comb_num = math.comb(total_elements_num, k)
    sample_num = min(max_sample_num, total_comb_num)
    if sample_num <= 0:
        # Nothing requested; also avoids a division by zero in the ratio check.
        return []

    indexs_list = list(range(total_elements_num))

    # When the number of combinations is small relative to the number of
    # requested subsets, enumerate them all and sample without replacement.
    # (Integer form of ``total_comb_num / sample_num <= 20``.)
    if total_comb_num <= 20 * sample_num:
        all_k_combs = list(combinations(indexs_list, k))
        return [list(comb) for comb in random.sample(all_k_combs, sample_num)]

    # Otherwise use rejection sampling. Each draw is canonicalised to a sorted
    # tuple so that two draws of the same subset in a different order are
    # recognised as duplicates, and so the output ordering matches the
    # enumeration branch above.
    seen_subsets = set()
    sampled_indexes = []
    while len(sampled_indexes) < sample_num:
        subset = tuple(sorted(random.sample(indexs_list, k)))
        if subset in seen_subsets:
            continue
        seen_subsets.add(subset)
        sampled_indexes.append(list(subset))

    return sampled_indexes

def read_sar_score(dataset, model, result, cached_num, sample_num, sample_index=None):
    """Compute AUROC of SAR and predictive entropy from cached SAR samples.

    Args:
        dataset: Dataset name, used to build the cache directory path.
        model: Model name, used to build the cache directory path.
        result: Loaded results object; this function reads
            ``result.config["generation_config"]`` and ``result.scores``.
        cached_num: Number of cached responses, part of the cache file name.
        sample_num: Value written into ``num_responses_per_prompt``.
        sample_index: Forwarded unchanged to ``load_sampling_results``.

    Returns:
        Tuple ``(sar_auroc, predictive_entropy_auroc)``.
    """
    similarity_model = "cross-encoder/stsb-roberta-large"

    # SAR defaults override the base generation config; the sample count
    # overrides both.
    overrides = copy.deepcopy(DEFAULT_SAR_CONFIG["generation_config"])
    overrides["num_responses_per_prompt"] = sample_num
    generation_kwargs = copy.deepcopy(result.config["generation_config"])
    generation_kwargs.update(overrides)

    cache_file = os.path.join(
        f"output/experimental_results/{dataset}/{model}",
        f"sar_sampling_results_{cached_num}_temp_1.0.json",
    )
    sampling_outputs = load_sampling_results(cache_file, generation_kwargs, sample_index=sample_index)

    sar_score = sar(
        sampling_outputs.transition_scores,
        sampling_outputs.token_importance[similarity_model],
        sampling_outputs.sim_matrix[similarity_model],
    )

    # Predictive entropy per prompt: each response contributes the mean of
    # its transition scores, and those means are fed to predictive_entropy.
    pe_values = []
    for per_prompt_scores in sampling_outputs.transition_scores:
        mean_logprobs = [np.mean(token_scores) for token_scores in per_prompt_scores]
        pe_values.append(predictive_entropy(mean_logprobs).item())
    predictive_entropy_score = {"predictive_entropy": pe_values}

    # Label is 1 when the "bem" score falls below the threshold.
    # ("rougeL" is the alternative metric key noted in the original code.)
    metric = "bem"
    threshold = 0.7
    labels = (np.array(result.scores[metric]) < threshold).astype(int)

    evaluator = Uncertainty_Evaluator(["auroc"])
    sar_auroc = evaluator.evaluate(sar_score, labels, verbose=False)["auroc"].loc["sar"]
    pe_auroc = evaluator.evaluate(predictive_entropy_score, labels, verbose=False)["auroc"].loc["predictive_entropy"]
    return sar_auroc, pe_auroc

def read_se_score(dataset, model, result, cached_num, sample_num, sample_index=None):
    """Compute the AUROC of semantic entropy from cached SE samples.

    Args:
        dataset: Dataset name, used to build the cache directory path.
        model: Model name, used to build the cache directory path.
        result: Loaded results object; this function reads
            ``result.config["generation_config"]`` and ``result.scores``.
        cached_num: Number of cached responses, part of the cache file name.
        sample_num: Value written into ``num_responses_per_prompt``.
        sample_index: Forwarded unchanged to ``load_sampling_results``.

    Returns:
        The ``"semantic_entropy"`` entry of the evaluator's ``"auroc"`` result.
    """
    # Semantic-entropy defaults override the base generation config; the
    # sample count overrides both.
    overrides = copy.deepcopy(DEFAULT_SEMANTIC_ENTROPY_CONFIG["generation_config"])
    overrides["num_responses_per_prompt"] = sample_num
    generation_kwargs = copy.deepcopy(result.config["generation_config"])
    generation_kwargs.update(overrides)

    cache_file = os.path.join(
        f"output/experimental_results/{dataset}/{model}",
        f"se_sampling_results_{cached_num}_temp_1.0.json",
    )
    sampling_outputs = load_sampling_results(cache_file, generation_kwargs, sample_index=sample_index)

    se_scores = semantic_entropy(
        sampling_outputs.transition_scores,
        sampling_outputs.semantic_cluster_ids,
    )

    # Label is 1 when the "bem" score falls below the threshold.
    # ("rougeL" is the alternative metric key noted in the original code.)
    metric = "bem"
    threshold = 0.7
    labels = (np.array(result.scores[metric]) < threshold).astype(int)

    evaluator = Uncertainty_Evaluator(["auroc"])
    auroc_table = evaluator.evaluate(se_scores, labels, verbose=False)["auroc"]
    return auroc_table.loc["semantic_entropy"]


def read_mi_score(dataset, model, result, cached_num, sample_num, sample_index=None):
    """Compute the AUROC of the MI score from cached MI samples.

    Besides reading the sampling cache, this reads the cached mu2 results,
    passes them through ``read_mu_2`` and writes the returned value back to
    the same JSON file.

    Args:
        dataset: Dataset name, used to build the cache directory path.
        model: Model name, used to build the cache directory path.
        result: Loaded results object; this function reads
            ``result.config["generation_config"]`` and ``result.scores``.
        cached_num: Number of cached responses, part of the cache file names.
        sample_num: Value written into ``num_responses_per_prompt``.
        sample_index: Forwarded unchanged to ``load_sampling_results``.

    Returns:
        The ``"mi"`` entry of the evaluator's ``"auroc"`` result.
    """
    overrides = copy.deepcopy(MI_DEFAULT_CONFIG)
    overrides["num_responses_per_prompt"] = sample_num
    generation_kwargs = copy.deepcopy(result.config["generation_config"])
    generation_kwargs.update(overrides)

    cache_dir = f"output/experimental_results/mi/{dataset}/{model}"
    sampling_file = os.path.join(cache_dir, f"mi_sampling_results_{cached_num}_temp_0.9.json")
    mu2_file = os.path.join(cache_dir, f"mi_mu2_results_{cached_num}_temp_0.9.json")

    sampling_outputs = load_sampling_results(sampling_file, generation_kwargs, sample_index=sample_index)
    with open(mu2_file, "r", encoding="utf-8") as cache_in:
        mu2_cache = json.load(cache_in)

    # Per prompt: exponentiate the summed transition scores of each response.
    sequence_probs = []
    for per_prompt_scores in sampling_outputs.transition_scores:
        summed_scores = [np.sum(token_scores) for token_scores in per_prompt_scores]
        sequence_probs.append(np.exp(summed_scores))
    mu_1_probs = read_mu_1(sampling_outputs.responses, sequence_probs)

    mu2_cache = read_mu_2(
        mu_1_probs,
        sampling_outputs.queries,
        sampling_outputs.model_name,
        cached_mu_2_result=mu2_cache,
        batch_size=1,
    )

    # Persist whatever read_mu_2 returned so later calls can reuse it.
    with open(mu2_file, "w", encoding="utf-8") as cache_out:
        json.dump(mu2_cache, cache_out, indent=4)

    mi_scores = {"mi": calculate_mi_score(sampling_outputs.queries, mu_1_probs, mu2_cache)}

    # Label is 1 when the "bem" score falls below the threshold.
    # ("rougeL" is the alternative metric key noted in the original code.)
    metric = "bem"
    threshold = 0.7
    labels = (np.array(result.scores[metric]) < threshold).astype(int)

    evaluator = Uncertainty_Evaluator(["auroc"])
    auroc_table = evaluator.evaluate(mi_scores, labels, verbose=False)["auroc"]
    return auroc_table.loc["mi"]

def read_inside_score(cached_hidden_states, result,  sample_index=None):
    """Compute the AUROC of the INSIDE (EigenScore) measure from cached hidden states.

    Args:
        cached_hidden_states: One entry per prompt; each entry is indexable by
            integer sample position (presumably a list of per-sample hidden
            states loaded from JSON -- confirm against the cache writer).
        result: Loaded results object; this function reads ``result.scores``.
        sample_index: Positions of the samples to keep for every prompt.
            ``None`` keeps every cached sample (previously ``None`` raised a
            ``TypeError`` on ``len(None)``).

    Returns:
        The ``"inside"`` AUROC converted to a Python scalar via ``.item()``.
    """
    if sample_index is None:
        hidden_states = [list(h) for h in cached_hidden_states]
    else:
        hidden_states = [[h[i] for i in sample_index] for h in cached_hidden_states]

    # The sample count handed to calculate_EigenScore is the number of
    # selected samples; with an explicit sample_index this equals
    # len(sample_index), exactly as before.
    inside_scores = {
        "inside": [calculate_EigenScore(hs, len(hs)).item() for hs in hidden_states]
    }

    # Label is 1 when the "bem" score falls below the threshold.
    # ("rougeL" is the alternative metric key noted in the original code.)
    metric = "bem"
    threshold = 0.7
    labels = (np.array(result.scores[metric]) < threshold).astype(int)

    evaluator = Uncertainty_Evaluator(["auroc"])
    auroc = evaluator.evaluate(inside_scores, labels, verbose=False)["auroc"].loc["inside"]
    return auroc.item()

if __name__ == "__main__":
    # Experiment grid. Alternatives used in earlier runs:
    #   datasets: "sciq", "truthfulqa", "ambigqa", "triviaqa", "coqa"
    #   models:   "llama2-chat-7b", "llama3-8b-instruct", "llama3.1-8b-instruct",
    #             "mistral-nemo-instruct", "llama3-70b-instruct", "qwen2.5-14b-instruct"
    #   methods:  "se", "pe", "sar", "inside", "mi"
    datasets = ["triviaqa"]
    models = ["qwen2.5-14b-instruct"]
    methods = ["mi"]

    sample_times = 10  # max number of random index subsets drawn per sample size
    cached_num = 20    # number of cached responses per prompt
    # Use list(range(2, 11, 1)) to sweep over sample sizes.
    sample_nums = [10]
    result_path = "./output/experimental_results/sampling_num/"

    # Fixed order in which per-method summaries are written to disk.
    known_methods = ("sar", "pe", "se", "inside", "mi")
    active_methods = [name for name in known_methods if name in methods]

    # exist_ok=True replaces the check-then-create pattern.
    for method_name in active_methods:
        os.makedirs(os.path.join(result_path, method_name), exist_ok=True)

    for dataset in tqdm(datasets):
        for model in models:
            cached_path = f"output/experimental_results/{dataset}/{model}"
            file_path = os.path.join(cached_path, "results.json")
            result = LLM_RESULTS.load(file_path)
            print(f"start to estimate model {model} on dataset {dataset}")

            # AUROC values per method; sample_num_list holds the sample size
            # for each evaluated index subset, aligned with the score lists.
            scores_by_method = {name: [] for name in known_methods}
            sample_num_list = []

            if "inside" in methods:
                inside_cached_path = f"inside_sampling_results_{cached_num}_temp_0.5.json"
                with open(os.path.join(cached_path, inside_cached_path), "r") as f:
                    cached_inside_hidden_states = json.load(f)
                # Re-order the cached entries to follow result.prompts.
                cached_inside_hidden_states = [cached_inside_hidden_states[p] for p in result.prompts]

            for sn in tqdm(sample_nums, desc="[sample num]"):
                sample_indexes = extract_random_k_indexes(sn, cached_num, sample_times)
                for index in sample_indexes:
                    # SAR and predictive entropy come from the same cached
                    # samples, so they are always computed together.
                    if "sar" in methods or "pe" in methods:
                        sar_s, pe_s = read_sar_score(dataset, model, result, cached_num, sn, sample_index=index)
                        scores_by_method["sar"].append(sar_s)
                        scores_by_method["pe"].append(pe_s)
                    if "se" in methods:
                        scores_by_method["se"].append(
                            read_se_score(dataset, model, result, cached_num, sn, sample_index=index)
                        )
                    if "inside" in methods:
                        scores_by_method["inside"].append(
                            read_inside_score(cached_inside_hidden_states, result, sample_index=index)
                        )
                    if "mi" in methods:
                        scores_by_method["mi"].append(
                            read_mi_score(dataset, model, result, cached_num, sn, sample_index=index)
                        )

                    sample_num_list.append(sn)

            # Aggregate every enabled method first (mean/std of AUROC per
            # sample size), then write the summaries.
            summaries = {}
            for method_name in active_methods:
                frame = pd.DataFrame.from_dict({
                    "scores": scores_by_method[method_name],
                    "sample_num": sample_num_list,
                })
                summaries[method_name] = frame.groupby("sample_num").agg({'scores': ['mean', 'std']})

            for method_name in active_methods:
                save_path = os.path.join(result_path, f"{method_name}/{dataset}-{model}-{cached_num}.xlsx")
                summaries[method_name].to_excel(save_path)

