from .methods import average_entropy, max_entropy, average_prob, min_prob, unnormalized_prob
from ..generation_evaluation import SemSimCalculator
from .esi import ESI_ESTIMATOR, DEFAULT_ESI_CONFIG, load_paraphrase
from .semantic_entropy import semantic_entropy, DEFAULT_SEMANTIC_ENTROPY_CONFIG
from .self_consistency import self_consistency,  DEFAULT_SELF_CON_CONFIG,spectral_clustering_metrics, spectral_clustering_metrics_plus_doc
from .predictive_entropy import predictive_entropy
from .inside import DEFAULT_INSIDE_CONFIG, calculate_inside
from .sar import DEFAULT_SAR_CONFIG, sar
from .p_true import calculate_p_true
from .ice import calculate_ice, DEFAULT_ICE_CONFIG
from .mi import DEFAULT_MI_CONFIG, read_mu_1, read_mu_2, calculate_mi_score
from ..response_generator import LLM_RESULTS, construct_hash, StandardGenerator, reshape_sequences
import copy
import itertools
import torch
from loguru import logger
from ..utils import LLM, load_data, get_gpu_memory, PromptTemplate
from .semantic_entropy.semantic_entropy_utils import EntailmentDeberta, get_semantic_ids
from .utils import load_sampling_results
from tqdm import tqdm
from itertools import chain
import json
import gc
import numpy as np

import os

# Methods scored directly from the transition log-probabilities of the
# original generation: method name -> scoring function.
LOGPROB_BASED_METHODS = dict(
    mean_logprob=average_prob,
    logprob=unnormalized_prob,
)

# Every method name that `run_estimation` accepts; it is also the default
# set of methods that `run_estimation` runs.
AVAILABLE_ESTIMATION_METHODS = [
    "logprob",
    "mean_logprob",
    "semantic_entropy",
    "ice",
    "self_consistency",
    "sar",
    "esi",
    "inside",
    "mi",
    "p_true",
]




def _merged_generation_kwargs(outputs, method_config):
    """Return a copy of the original generation config overridden by the method's own."""
    generation_kwargs = copy.deepcopy(outputs.config["generation_config"])
    generation_kwargs.update(method_config["generation_config"])
    return generation_kwargs


def _sampling_param_name(method_config):
    """Build the '<num_responses>_temp_<temperature>' tag used in result file names."""
    sampling_config = method_config["generation_config"]
    return str(sampling_config["num_responses_per_prompt"]) + "_temp_" + str(sampling_config["temperature"])


def _resolve_save_path(save_path, cached_path, file_prefix, param_name):
    """Return the JSON path results should be written to, or None when nothing is configured.

    An explicit `save_path` wins: a directory gets a generated file name, anything
    else must already be a '.json' path. Otherwise the file is placed next to
    `cached_path`. With neither given there is nowhere to save, so None is returned.
    """
    file_name = file_prefix + "_" + param_name + ".json"
    if save_path is not None:
        if os.path.isdir(save_path):
            return os.path.join(save_path, file_name)
        assert save_path.endswith(".json"), "save file should be end with '.json'"
        return save_path
    if cached_path is not None:
        return os.path.join(os.path.dirname(cached_path), file_name)
    return None


def _load_cached_sampling(method_config, generation_kwargs):
    """Load cached sampling results if a cache path is configured, else return None."""
    if method_config["cached_path"] is not None:
        return load_sampling_results(method_config["cached_path"], generation_kwargs)
    logger.info("No cached path given, start to sample from scratch")
    return None


def _generate_samples(outputs, prompts, generation_kwargs, device_name):
    """Sample responses for `prompts` with the model of `outputs`, then release loaded models."""
    tokenization_kwargs = copy.deepcopy(outputs.config["tokenization_config"])
    sampling_outputs = LLM_RESULTS.from_dict(
        LLM.lm_generate(
            outputs.model_name,
            prompts,
            generation_kwargs,
            tokenization_kwargs,
            device_name=device_name,
            verbose=True,
        )
    )
    LLM.release_all()
    return sampling_outputs


def _ensure_sampling_outputs(outputs, sampling_outputs, generation_kwargs, device_name, save_path):
    """Return sampling results for the original prompts, generating and saving them if missing."""
    if sampling_outputs is None:
        sampling_outputs = _generate_samples(outputs, outputs.prompts, generation_kwargs, device_name)
        sampling_outputs.queries = outputs.queries
        sampling_outputs.queries_for_similarity = outputs.queries_for_similarity
        sampling_outputs.ground_truth = outputs.ground_truth
        if save_path is not None:
            sampling_outputs.save(os.path.abspath(save_path))
    else:
        # Cached results keep their own queries; only the similarity queries are refreshed.
        sampling_outputs.queries_for_similarity = outputs.queries_for_similarity
    return sampling_outputs


def _similarity_queries(outputs):
    """Return `queries_for_similarity` when set, otherwise the plain queries."""
    if outputs.queries_for_similarity is not None:
        return outputs.queries_for_similarity
    return outputs.queries


def run_estimation(
    outputs,
    estimation_methods=AVAILABLE_ESTIMATION_METHODS,
    is_logits=True,
    temperature=1.0,
    device_name=None,
    batch_size=20,
    esi_config=DEFAULT_ESI_CONFIG,
    se_config=DEFAULT_SEMANTIC_ENTROPY_CONFIG,
    self_con_config=DEFAULT_SELF_CON_CONFIG,
    sar_config=DEFAULT_SAR_CONFIG,
    inside_config=DEFAULT_INSIDE_CONFIG,
    ice_config=DEFAULT_ICE_CONFIG,
    mi_config=DEFAULT_MI_CONFIG,
    sampling_only=False,
):
    """
    Run the requested uncertainty-estimation methods on a set of generation results.

    input:
    outputs - generation results of the model under evaluation. This function reads
        `raw_config`, `config`, `transition_scores`, `prompts`, `model_name`, `responses`,
        `queries`, `queries_for_similarity` and `ground_truth` from it, and also indexes
        it with `outputs["model_name"]`.
    estimation_methods - names of the methods to run, in order. Each name must be in
        AVAILABLE_ESTIMATION_METHODS.
    is_logits, temperature - accepted for interface compatibility; not used in this function.
    device_name - forwarded to the generation / similarity / scoring calls.
    batch_size - forwarded to the ESI estimator only.
    esi_config, se_config, self_con_config, sar_config, inside_config, ice_config, mi_config -
        per-method settings. `esi_config` is merged over DEFAULT_ESI_CONFIG; the others are
        read directly (e.g. "generation_config", "cached_path", "save_path").
    sampling_only - if True, sampling-based methods stop after their sampling results are
        produced/saved (the "mi" method additionally computes and saves its mu2 results first)
        and the function returns None. Methods without a sampling stage ("logprob",
        "mean_logprob", "inside", "p_true") still run, but their scores are discarded.

    output:
    dict mapping score names to per-sample score lists, or None when `sampling_only` is True.
    Several methods add more than one key (they merge a dict of scores into the result).

    note:
    sampling results and mu2 results are written to disk only when a save path or a cached
    path is configured for the method; otherwise they are kept in memory only.
    """
    LLM.ddp = outputs.raw_config["ddp"]
    transition_logprobs = outputs.transition_scores
    # NOTE(review): the tokenizer itself is not used below; the call is kept in case
    # it primes state inside LLM that later calls rely on - TODO confirm and drop.
    LLM.initial_tokenizer(outputs["model_name"])
    estimation_scores = dict()

    for method in estimation_methods:

        logger.info(f"start to calculate scores of method '{method}'")
        assert method in AVAILABLE_ESTIMATION_METHODS, f"given estiamtion method {method} is not supported"

        if method in LOGPROB_BASED_METHODS:
            assert transition_logprobs is not None, f"estimation method '{method}' requries transition log probability, but None is given"
            score_fn = LOGPROB_BASED_METHODS[method]
            # Scores are negated log-probability statistics.
            estimation_scores[method] = [
                -1 * score_fn(tran_logprobs, return_prob=False).item()
                for tran_logprobs in transition_logprobs
            ]

        elif method == "esi":
            config = copy.deepcopy(DEFAULT_ESI_CONFIG)
            config.update(esi_config)
            sample_num = config.pop("sample_num")
            cached_result_path = config.pop("cached_result_path")
            estimator = ESI_ESTIMATOR(outputs, sample_num=sample_num, cached_result_path=cached_result_path)

            if sampling_only:
                estimator.estimate(device_name=device_name, batch_size=batch_size, sampling_only=sampling_only, **config)
                continue
            estimation_scores.update(estimator.estimate(device_name=device_name, batch_size=batch_size, **config))

        elif method == "inside":
            generation_kwargs = _merged_generation_kwargs(outputs, inside_config)
            tokenization_kwargs = copy.deepcopy(outputs.config["tokenization_config"])
            inside_scores = calculate_inside(
                outputs.prompts,
                outputs.model_name,
                generation_kwargs,
                tokenization_kwargs,
                cached_path=inside_config["cached_path"],
                save_path=inside_config["save_path"],
            )
            estimation_scores.update(inside_scores)

        elif method == "semantic_entropy":
            generation_kwargs = _merged_generation_kwargs(outputs, se_config)
            sampling_outputs = _load_cached_sampling(se_config, generation_kwargs)
            se_save_path = _resolve_save_path(
                se_config["save_path"], se_config["cached_path"],
                "se_sampling_results", _sampling_param_name(se_config),
            )
            if se_save_path is not None:
                logger.info(f"save sampling results to {se_save_path}")

            sampling_outputs = _ensure_sampling_outputs(outputs, sampling_outputs, generation_kwargs, device_name, se_save_path)
            if sampling_only:
                continue

            se_queries = _similarity_queries(outputs)

            if sampling_outputs.semantic_cluster_ids is None:
                entailment_model = EntailmentDeberta()
                se_inputs = [[f'{q} {r}' for r in rs] for q, rs in zip(se_queries, sampling_outputs.responses)]
                sampling_outputs.semantic_cluster_ids = get_semantic_ids(se_inputs, entailment_model, batch_size=se_config["sim_batch_size"])
                # Persist the cluster ids so a later run can reuse them from the cache.
                if se_save_path is not None:
                    sampling_outputs.save(os.path.abspath(se_save_path))
                del entailment_model

            estimation_scores.update(semantic_entropy(sampling_outputs.transition_scores, sampling_outputs.semantic_cluster_ids))

            del sampling_outputs
            gc.collect()

        elif method == "self_consistency":
            generation_kwargs = _merged_generation_kwargs(outputs, self_con_config)
            n = generation_kwargs["num_responses_per_prompt"]
            sampling_outputs = _load_cached_sampling(self_con_config, generation_kwargs)

            self_con_save_path = _resolve_save_path(
                self_con_config["save_path"], self_con_config["cached_path"],
                "sc_sampling_results", _sampling_param_name(self_con_config),
            )
            if self_con_save_path is not None:
                logger.info(f"results will be saved to {self_con_save_path}")

            sampling_outputs = _ensure_sampling_outputs(outputs, sampling_outputs, generation_kwargs, device_name, self_con_save_path)
            if sampling_only:
                continue

            sim_model = self_con_config["model"]
            sim_batch_size = self_con_config["sim_batch_size"]

            if sampling_outputs.sim_matrix is None or (sim_model not in sampling_outputs.sim_matrix):
                sampling_outputs.evaluate_response_similarity(model=sim_model, device_name=device_name, batch_size=sim_batch_size)
                if self_con_save_path is not None:
                    sampling_outputs.save(os.path.abspath(self_con_save_path))

            sc_queries = _similarity_queries(outputs)

            # Cached similarities are looked up by a hash of the original responses,
            # the sampled responses and the queries.
            key = construct_hash(outputs.responses + sampling_outputs.responses + sc_queries)

            if (
                (sampling_outputs.sim_to_original is None)
                or (sim_model not in sampling_outputs.sim_to_original)
                or (key not in sampling_outputs.sim_to_original[sim_model])
            ):
                sampling_outputs.evaluate_similarity_to_original_answers(
                    outputs.responses, model=sim_model, device_name=device_name,
                    batch_size=sim_batch_size, queries=sc_queries,
                )
                if self_con_save_path is not None:
                    sampling_outputs.save(os.path.abspath(self_con_save_path))

            sim_scores = torch.tensor(sampling_outputs.sim_to_original[sim_model][key])
            # A sample counts as agreeing with the original answer when similarity > 0.5;
            # the score is the fraction of the n samples that do not agree.
            identifier = (sim_scores > 0.5).int()
            estimation_scores[method] = (1 - identifier.sum(dim=-1)/n).tolist()

            spectral_scores = spectral_clustering_metrics(sampling_outputs.sim_matrix[sim_model])
            estimation_scores.update(spectral_scores)

            del sampling_outputs
            gc.collect()

        elif method == "sar":
            generation_kwargs = _merged_generation_kwargs(outputs, sar_config)
            sampling_outputs = _load_cached_sampling(sar_config, generation_kwargs)

            sar_save_path = _resolve_save_path(
                sar_config["save_path"], sar_config["cached_path"],
                "sar_sampling_results", _sampling_param_name(sar_config),
            )
            if sar_save_path is not None:
                logger.info(f"sampling results will be saved to {sar_save_path}")

            sampling_outputs = _ensure_sampling_outputs(outputs, sampling_outputs, generation_kwargs, device_name, sar_save_path)
            if sampling_only:
                continue

            importance_model = sar_config["token_importance_model"]
            similarity_model = sar_config["sentence_similarity_model"]

            if (sampling_outputs.token_importance) is None or (importance_model not in sampling_outputs.token_importance):
                sampling_outputs.evaluate_token_importance(model=importance_model, device_name=device_name, batch_size=sar_config["sim_batch_size"])
                if sar_save_path is not None:
                    sampling_outputs.save(os.path.abspath(sar_save_path))

            if sampling_outputs.sim_matrix is None or (similarity_model not in sampling_outputs.sim_matrix):
                sampling_outputs.evaluate_response_similarity(model=similarity_model, device_name=device_name, batch_size=sar_config["sim_batch_size"])
                if sar_save_path is not None:
                    sampling_outputs.save(os.path.abspath(sar_save_path))

            token_importance = sampling_outputs.token_importance[importance_model]
            sim_matrix = sampling_outputs.sim_matrix[similarity_model]

            estimation_scores.update(sar(sampling_outputs.transition_scores, token_importance, sim_matrix))

            # Predictive entropy reuses the SAR samples: mean token log-prob per sampled response.
            logprobs_list = [[np.mean(ts) for ts in tss] for tss in sampling_outputs.transition_scores]
            predictive_entropy_score = {"predictive_entropy": [predictive_entropy(logprobs).item() for logprobs in logprobs_list]}
            estimation_scores.update(predictive_entropy_score)

            del sampling_outputs
            gc.collect()

        elif method == "ice":
            generation_kwargs = _merged_generation_kwargs(outputs, ice_config)
            sampling_outputs = _load_cached_sampling(ice_config, generation_kwargs)

            ice_save_path = _resolve_save_path(
                ice_config["save_path"], ice_config["cached_path"],
                "ice_sampling_results", _sampling_param_name(ice_config),
            )
            if ice_save_path is not None:
                logger.info(f"save sampling results to {ice_save_path}")

            if sampling_outputs is None:
                # Sampling is done on paraphrases of each query, not on the original prompts.
                paraphrased_queries_mapping = load_paraphrase(ice_config["paraphrase_path"], ice_config["paraphrase_num"])
                paraphrased_queries = list(chain.from_iterable(paraphrased_queries_mapping[q.strip()] for q in outputs.queries))

                template_config = getattr(outputs, "raw_config", None)
                if template_config is None:
                    template_config = {
                        "verbose": False,
                        "system_id": None,
                        "template_id": 2,
                        "generate_kwargs": dict()
                    }

                # NOTE(review): when raw_config exists this mutates it in place.
                template_config["model_name"] = outputs["model_name"]
                if template_config["system_id"] == 0:
                    template_config["system_id"] = None
                prompt_template = StandardGenerator(template_config).prompt_template
                prompts = [prompt_template.build_prompt({"query": q}) for q in paraphrased_queries]

                sampling_outputs = _generate_samples(outputs, prompts, generation_kwargs, device_name)

                if ice_save_path is not None:
                    sampling_outputs.save(os.path.abspath(ice_save_path))

            if sampling_only:
                continue

            ice_queries = _similarity_queries(outputs)

            # Group the flat response list back into one group per original query.
            ice_responses = reshape_sequences(sampling_outputs["responses"], ice_config["paraphrase_num"])

            if sampling_outputs.semantic_cluster_ids is None:
                entailment_model = EntailmentDeberta()
                ice_inputs = [[f'{q} {r}' for r in chain.from_iterable(rs)] for q, rs in zip(ice_queries, ice_responses)]
                semantic_ids_list = get_semantic_ids(ice_inputs, entailment_model, batch_size=ice_config["sim_batch_size"])
                semantic_ids_list = [reshape_sequences(si, ice_config["generation_config"]["num_responses_per_prompt"]) for si in semantic_ids_list]

                # Stored flat (one entry per paraphrased prompt) to mirror the responses layout.
                sampling_outputs.semantic_cluster_ids = list(chain.from_iterable(semantic_ids_list))
                if ice_save_path is not None:
                    sampling_outputs.save(os.path.abspath(ice_save_path))
                del entailment_model
            else:
                semantic_ids_list = reshape_sequences(sampling_outputs.semantic_cluster_ids, ice_config["paraphrase_num"])

            estimation_scores.update(calculate_ice(semantic_ids_list))

            del sampling_outputs
            gc.collect()

        elif method == "mi":
            generation_kwargs = _merged_generation_kwargs(outputs, mi_config)
            sampling_outputs = _load_cached_sampling(mi_config, generation_kwargs)

            if mi_config["cached_mu2_path"] is not None:
                with open(mi_config["cached_mu2_path"], "r", encoding="utf-8") as f:
                    cached_mu2_result = json.load(f)
            else:
                cached_mu2_result = None
                logger.info("No cached mu2 results given, start to read from scratch")

            param_name = _sampling_param_name(mi_config)
            # May be None when neither a save path nor a cached path is configured;
            # in that case the sampling results are simply not persisted.
            mi_save_path = _resolve_save_path(
                mi_config["save_path"], mi_config["cached_path"],
                "mi_sampling_results", param_name,
            )
            if mi_save_path is not None:
                logger.info(f"save sampling results to {mi_save_path}")

            sampling_outputs = _ensure_sampling_outputs(outputs, sampling_outputs, generation_kwargs, device_name, mi_save_path)

            # Sequence probability of each sampled response: exp(sum of token log-probs).
            raw_prob_list = [np.exp([np.sum(ts) for ts in ts_scores]) for ts_scores in sampling_outputs.transition_scores]
            mu_1_probs_dict_list = read_mu_1(sampling_outputs.responses, raw_prob_list)

            cached_mu2_result = read_mu_2(
                mu_1_probs_dict_list,
                sampling_outputs.queries,
                sampling_outputs.model_name,
                cached_mu_2_result=cached_mu2_result,
                device_name=device_name,
                batch_size=mi_config["generation_config"]["batch_size"]*5,
            )

            # Same rule as above: without any configured path the mu2 results stay in memory.
            mu2_save_path = _resolve_save_path(
                mi_config["save_mu2_path"], mi_config["cached_mu2_path"],
                "mi_mu2_results", param_name,
            )
            if mu2_save_path is not None:
                logger.info(f"save condititonal prob mu2 results to {mu2_save_path}")
                with open(mu2_save_path, "w", encoding="utf-8") as f:
                    json.dump(cached_mu2_result, f, indent=4)
                logger.info(f" condititonal prob mu2 results saved to {mu2_save_path}")
            else:
                logger.warning("No mu2 save path or cached mu2 path given, mu2 results will not be saved")

            if sampling_only:
                continue

            estimation_scores.update({"mi_score": calculate_mi_score(sampling_outputs.queries, mu_1_probs_dict_list, cached_mu2_result)})

            del sampling_outputs
            gc.collect()

        elif method == "p_true":
            tokenization_kwargs = copy.deepcopy(outputs.config["tokenization_config"])
            model_name = LLM.initial_lm(outputs.model_name, device_name=device_name, verbose=True, tokenizer_kwargs=tokenization_kwargs)
            model, tokenizer = LLM.loaded_llms[model_name]

            estimation_scores[method] = calculate_p_true(outputs.queries, outputs.responses, model, tokenizer)

            LLM.release_all()

    if sampling_only:
        return None
    return estimation_scores




    

