from uncertainty.uncertainty_estimation import run_estimation
from uncertainty.uncertainty_evaluation import Uncertainty_Evaluator    
import argparse
import torch
import json
import os
import numpy as np
import random
import debugpy
import copy
from uncertainty.uncertainty_estimation import DEFAULT_SEMANTIC_ENTROPY_CONFIG, DEFAULT_SAR_CONFIG, DEFAULT_SELF_CON_CONFIG, DEFAULT_INSIDE_CONFIG, DEFAULT_ICE_CONFIG, DEFAULT_MI_CONFIG
from uncertainty import LLM_RESULTS
from loguru import logger
import time
from tqdm import tqdm
if __name__ == '__main__':
    # parse args
    # Every option is optional at the argparse level; options whose default is
    # None mean "keep the value from the corresponding DEFAULT_*_CONFIG".
    parser = argparse.ArgumentParser(description='a project on uncertainty estimation')

    # general options
    parser.add_argument('-d', '--debug', action='store_true', help='wait for a debugpy client on port 14328 and log at DEBUG level')
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('-c', '--cached_result_path', type=str, help='path cached the outputs of generation')
    parser.add_argument('-o', '--output_dir', type=str, help='the directory to save the result')

    # correctness labelling and ESI options
    parser.add_argument("-m", "--correctness_metric", type=str, default="bem", help='metrics to evaluate correctness')
    parser.add_argument("-t", "--correctness_threshold", type=float, default=0.7, help='threshold to evaluate correctness')
    parser.add_argument("-b", "--batch_size", type=int, default=20, help='batch size for reading logits')
    parser.add_argument("-n", "--sample_num", type=int, default=10, help='num of samples per prompt for esi, i.e. the number of intervened variants')
    parser.add_argument("-f", "--aug_func", type=str, default="skip", help='function to intervene the prompts')
    parser.add_argument("-p", "--percent", type=float, default=0.3, help='percentage of tokens to be intervened')
    parser.add_argument("-ct", "--cached_path_for_esi", type=str, help='path cached the intermediate results of esi')
    parser.add_argument("-pp", "--paraphrase_result_path", type=str, help='path cached the paraphrased queries')
    parser.add_argument('-ot', '--dir_to_cache_for_esi', type=str, help='the directory to save the intermediate result of esi')
    parser.add_argument("--num_scores_returned", type=int, default=100, help='num of logits cached')
    # NOTE(review): --store_score is parsed but never read anywhere in this script.
    parser.add_argument('--store_score', action='store_true', help='whether to store the uncertainty scores')
    parser.add_argument('--sampling_only', action='store_true', help='whether to sampling responses only and remain score calculation for further computation')
    parser.add_argument('--esi_only', action='store_true', help='whether to compute ESI score only')
    parser.add_argument('--evaluate_method', type=str, help='the single method name used to evaluate uncertainty score')
    parser.add_argument('--sim_batch_size', type=int, default=256, help='batch size for computing semantic similarity')

    #args for semantic entropy
    parser.add_argument("--se_cached_path", type=str, help='path cached the sampling results for Semantic entropy ')
    parser.add_argument("--se_save_path", type=str, help='path to save the sampling outputs of Semantic entropy')
    parser.add_argument("--se_temperature", type=float, help='temperature for semantic entropy sampling')
    parser.add_argument("--se_n", type=int, help='number of sampling sentence for computing semantic entropy')
    parser.add_argument("--se_batch_size", type=int, help='batch_size when sampling')

    #args for input clarification ensemble
    parser.add_argument("--ice_cached_path", type=str, help='path cached the sampling results for ICE')
    parser.add_argument("--ice_save_path", type=str, help='path to save the sampling outputs of ICE')
    parser.add_argument("--ice_batch_size", type=int, help='batch_size when sampling')

    #args for self-con
    parser.add_argument("--sc_cached_path", type=str, help='path cached the sampling results for self-consistency')
    parser.add_argument("--sc_save_path", type=str, help='path to save the sampling outputs of self-consistency')
    parser.add_argument("--sc_temperature", type=float, help='temperature for self-consistency sampling')
    parser.add_argument("--sc_n", type=int, help='number of sampling sentence for computing self-consistency')
    parser.add_argument("--sc_model", type=str, help='model used to compute semantic similarity for computing self-consistency')
    parser.add_argument("--sc_batch_size", type=int, help='batch size when sampling sentences for computing self-consistency')

    #args for sar
    parser.add_argument("--sar_cached_path", type=str, help='path cached the sampling results for SAR')
    parser.add_argument("--sar_save_path", type=str, help='path to save the sampling outputs of SAR')
    parser.add_argument("--sar_temperature", type=float, help='temperature for SAR sampling')
    parser.add_argument("--sar_n", type=int, help='number of sampling sentence for computing SAR')
    parser.add_argument("--sar_token_model", type=str, help='model used to compute token_wise_importance for computing SAR')
    parser.add_argument("--sar_sentence_model", type=str, help='model used to compute sentence similarity for computing SAR')
    parser.add_argument("--sar_batch_size", type=int, help='batch size when sampling sentences for computing sar')

    #args for inside
    parser.add_argument("--inside_cached_path", type=str, help='path cached the sampling hidden states for INSIDE')
    parser.add_argument("--inside_save_path", type=str, help='path to save the sampling hidden states of INSIDE')
    parser.add_argument("--inside_n", type=int, help='number of sampling responses for computing INSIDE')
    parser.add_argument("--inside_batch_size", type=int, help='batch size when sampling sentences for computing inside')

    #args for mi
    parser.add_argument("--mi_cached_path", type=str, help='path cached the sampling results for MI ')
    parser.add_argument("--mi_cached_mu2_path", type=str, help='path cached the conditional prob mu2 for MI ')
    parser.add_argument("--mi_save_path", type=str, help='path to save the sampling outputs of MI')
    parser.add_argument("--mi_temperature", type=float, help='temperature for MI sampling')
    parser.add_argument("--mi_n", type=int, help='number of sampling sentence for computing MI')
    parser.add_argument("--mi_batch_size", type=int, help='batch_size when sampling')

    parser.add_argument("-l", "--log_name", type=str, help='name of the log file')

    # parse_known_args() tolerates unrecognised options; they are discarded here.
    args, _ = parser.parse_known_args()

    # With --debug, open a debugpy server on port 14328 (bound to all
    # interfaces) and block until a debugger client attaches.
    if args.debug:
        debugpy.listen(("0.0.0.0", 14328))
        print("listen ready")
        debugpy.wait_for_client()

    # set seed
    # Seed every RNG this script has imported so runs are repeatable.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # -o/--output_dir has no default, so it may be None. It is only consumed
    # when scores are actually computed (i.e. not with --sampling_only), so
    # report a usage error up front in that case instead of crashing on None.
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    elif not args.sampling_only:
        parser.error("-o/--output_dir is required unless --sampling_only is given")

    # read generation results
    cached_results = LLM_RESULTS.load(args.cached_result_path)
    # cached_results = LLM_RESULTS.from_records(cached_results.to_records()[:5])
    model_name = cached_results.model_name

    
    # Decide which uncertainty estimators to run:
    #   * --evaluate_method NAME -> only that method (must be a supported name)
    #   * --esi_only             -> only "esi"
    #   * otherwise              -> the default suite listed below
    if args.evaluate_method is None:
        if args.esi_only:
            estimation_methods = ['esi']
        else:
            estimation_methods = ["logprob", "mean_logprob", "esi", "sar","semantic_entropy", "inside", "mi" ]
    else:
        supported_methods = ["logprob", "mean_logprob", "sar", "semantic_entropy", "self_consistency", "inside", "ice", "esi", "mi", "p_true"]
        # Explicit check rather than `assert`, so the validation still runs
        # under `python -O`; parser.error prints the usage text and exits.
        if args.evaluate_method not in supported_methods:
            parser.error(f"the given method {args.evaluate_method} is not supported")
        estimation_methods = [args.evaluate_method]

    
    # set config for sar methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when SAR is one of the selected estimators.
    sar_config = copy.deepcopy(DEFAULT_SAR_CONFIG)
    if "sar" in estimation_methods:
        sar_config["sim_batch_size"] = args.sim_batch_size

        # model overrides (None means "keep the default")
        for option, override in (
            ("token_importance_model", args.sar_token_model),
            ("sentence_similarity_model", args.sar_sentence_model),
        ):
            if override is not None:
                sar_config[option] = override

        # sampling overrides (None means "keep the default")
        for option, override in (
            ("num_responses_per_prompt", args.sar_n),
            ("temperature", args.sar_temperature),
            ("batch_size", args.sar_batch_size),
        ):
            if override is not None:
                sar_config["generation_config"][option] = override

        # A path that is an existing directory gets the default file name
        # "sar_sampling_results_<n>_temp_<temperature>.json" appended;
        # anything else (including None) is used verbatim.
        sar_paths = {}
        for slot, location in (
            ("save_path", args.sar_save_path),
            ("cached_path", args.sar_cached_path),
        ):
            if location is not None and os.path.isdir(location):
                sampling = sar_config["generation_config"]
                default_name = "_".join([
                    "sar_sampling_results",
                    str(sampling["num_responses_per_prompt"]),
                    "temp",
                    str(sampling["temperature"]),
                ]) + ".json"
                location = os.path.join(location, default_name)
            sar_paths[slot] = location

        sar_config["cached_path"] = sar_paths["cached_path"]
        sar_config["save_path"] = sar_paths["save_path"]

    # set config for semantic entropy methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when semantic entropy is selected.
    se_config = copy.deepcopy(DEFAULT_SEMANTIC_ENTROPY_CONFIG)
    if "semantic_entropy" in estimation_methods:
        se_config["sim_batch_size"] = args.sim_batch_size

        # sampling overrides (None means "keep the default")
        for option, override in (
            ("num_responses_per_prompt", args.se_n),
            ("temperature", args.se_temperature),
            ("batch_size", args.se_batch_size),
        ):
            if override is not None:
                se_config["generation_config"][option] = override

        # A path that is an existing directory gets the default file name
        # "se_sampling_results_<n>_temp_<temperature>.json" appended;
        # anything else (including None) is used verbatim.
        se_paths = {}
        for slot, location in (
            ("save_path", args.se_save_path),
            ("cached_path", args.se_cached_path),
        ):
            if location is not None and os.path.isdir(location):
                sampling = se_config["generation_config"]
                default_name = "_".join([
                    "se_sampling_results",
                    str(sampling["num_responses_per_prompt"]),
                    "temp",
                    str(sampling["temperature"]),
                ]) + ".json"
                location = os.path.join(location, default_name)
            se_paths[slot] = location

        se_config["cached_path"] = se_paths["cached_path"]
        se_config["save_path"] = se_paths["save_path"]

    # set config for self-consistency methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when self-consistency is selected.
    sc_config = copy.deepcopy(DEFAULT_SELF_CON_CONFIG)
    if "self_consistency" in estimation_methods:
        sc_config["sim_batch_size"] = args.sim_batch_size
        if args.sc_model is not None:
            sc_config["model"] = args.sc_model

        # sampling overrides (None means "keep the default")
        for option, override in (
            ("num_responses_per_prompt", args.sc_n),
            ("temperature", args.sc_temperature),
            ("batch_size", args.sc_batch_size),
        ):
            if override is not None:
                sc_config["generation_config"][option] = override

        # A path that is an existing directory gets the default file name
        # "sc_sampling_results_<n>_temp_<temperature>.json" appended;
        # anything else (including None) is used verbatim.
        sc_paths = {}
        for slot, location in (
            ("save_path", args.sc_save_path),
            ("cached_path", args.sc_cached_path),
        ):
            if location is not None and os.path.isdir(location):
                sampling = sc_config["generation_config"]
                default_name = "_".join([
                    "sc_sampling_results",
                    str(sampling["num_responses_per_prompt"]),
                    "temp",
                    str(sampling["temperature"]),
                ]) + ".json"
                location = os.path.join(location, default_name)
            sc_paths[slot] = location

        sc_config["cached_path"] = sc_paths["cached_path"]
        sc_config["save_path"] = sc_paths["save_path"]

    # set config for inside methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when INSIDE is selected.
    inside_config = copy.deepcopy(DEFAULT_INSIDE_CONFIG)
    if "inside" in estimation_methods:
        # sampling overrides (None means "keep the default"); there is no
        # temperature option for INSIDE, so that value always comes from the defaults
        for option, override in (
            ("batch_size", args.inside_batch_size),
            ("num_responses_per_prompt", args.inside_n),
        ):
            if override is not None:
                inside_config["generation_config"][option] = override

        # A path that is an existing directory gets the default file name
        # "inside_sampling_results_<n>_temp_<temperature>.json" appended;
        # anything else (including None) is used verbatim.
        inside_paths = {}
        for slot, location in (
            ("save_path", args.inside_save_path),
            ("cached_path", args.inside_cached_path),
        ):
            if location is not None and os.path.isdir(location):
                sampling = inside_config["generation_config"]
                default_name = "_".join([
                    "inside_sampling_results",
                    str(sampling["num_responses_per_prompt"]),
                    "temp",
                    str(sampling["temperature"]),
                ]) + ".json"
                location = os.path.join(location, default_name)
            inside_paths[slot] = location

        inside_config["cached_path"] = inside_paths["cached_path"]
        inside_config["save_path"] = inside_paths["save_path"]

    # set config for input clarification ensemble methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when ICE is selected.
    ice_config = copy.deepcopy(DEFAULT_ICE_CONFIG)
    if "ice" in estimation_methods:
        # ICE cannot run without the paraphrased queries. Checked explicitly
        # (not with `assert`) so the validation survives `python -O`;
        # parser.error prints the usage text and exits.
        if args.paraphrase_result_path is None:
            parser.error("paraphrased queries should be given for method ice")

        ice_config["sim_batch_size"] = args.sim_batch_size
        ice_config["paraphrase_path"] = args.paraphrase_result_path

        if args.ice_batch_size is not None:
            ice_config["generation_config"]["batch_size"] = args.ice_batch_size

        # A path that is an existing directory gets the default file name
        # "ice_sampling_results_<n>_temp_<temperature>.json" appended;
        # anything else (including None) is used verbatim.
        ice_paths = {}
        for slot, location in (
            ("save_path", args.ice_save_path),
            ("cached_path", args.ice_cached_path),
        ):
            if location is not None and os.path.isdir(location):
                sampling = ice_config["generation_config"]
                default_name = "_".join([
                    "ice_sampling_results",
                    str(sampling["num_responses_per_prompt"]),
                    "temp",
                    str(sampling["temperature"]),
                ]) + ".json"
                location = os.path.join(location, default_name)
            ice_paths[slot] = location

        ice_config["cached_path"] = ice_paths["cached_path"]
        ice_config["save_path"] = ice_paths["save_path"]

    # set config for MI methods
    # Always start from a private copy of the defaults; the command-line
    # overrides are applied only when MI is selected.
    mi_config = copy.deepcopy(DEFAULT_MI_CONFIG)
    if "mi" in estimation_methods:
        # sampling overrides (None means "keep the default")
        for option, override in (
            ("num_responses_per_prompt", args.mi_n),
            ("temperature", args.mi_temperature),
            ("batch_size", args.mi_batch_size),
        ):
            if override is not None:
                mi_config["generation_config"][option] = override

        def _mi_file_name(stem):
            """Return '<stem>_<n>_temp_<temperature>.json' for the current MI sampling settings."""
            sampling = mi_config["generation_config"]
            return "_".join([
                stem,
                str(sampling["num_responses_per_prompt"]),
                "temp",
                str(sampling["temperature"]),
            ]) + ".json"

        # Save targets: the mu2 file always gets its default name and is
        # placed next to the sampling results, i.e. inside the given directory,
        # or in the parent directory of the given file path.
        save_path_for_mi = args.mi_save_path
        mu2_save_path_for_mi = None
        if save_path_for_mi is not None:
            if os.path.isdir(save_path_for_mi):
                mu2_dir = save_path_for_mi
                save_path_for_mi = os.path.join(mu2_dir, _mi_file_name("mi_sampling_results"))
            else:
                mu2_dir = os.path.dirname(save_path_for_mi)
            mu2_save_path_for_mi = os.path.join(mu2_dir, _mi_file_name("mi_mu2_results"))

        # Cached inputs: an existing directory gets the default file name
        # appended, anything else (including None) is used verbatim.
        cached_path_for_mi = args.mi_cached_path
        if cached_path_for_mi is not None and os.path.isdir(cached_path_for_mi):
            cached_path_for_mi = os.path.join(cached_path_for_mi, _mi_file_name("mi_sampling_results"))

        mu2_cached_path_for_mi = args.mi_cached_mu2_path
        if mu2_cached_path_for_mi is not None and os.path.isdir(mu2_cached_path_for_mi):
            mu2_cached_path_for_mi = os.path.join(mu2_cached_path_for_mi, _mi_file_name("mi_mu2_results"))

        mi_config["cached_mu2_path"] = mu2_cached_path_for_mi
        mi_config["cached_path"] = cached_path_for_mi
        mi_config["save_path"] = save_path_for_mi
        mi_config["save_mu2_path"] = mu2_save_path_for_mi


    
    # set config for esi methods
    # The ESI config is built unconditionally (it is always handed to
    # run_estimation), so the paraphrase check applies whenever
    # --aug_func paraphrase is given. Checked explicitly (not with `assert`)
    # so the validation survives `python -O`.
    if args.aug_func == "paraphrase" and args.paraphrase_result_path is None:
        parser.error("paraphrased queries should be given for paraphrase-based esi")

    # Default file name used when a directory, not a file, is supplied:
    # "esi_<aug_func>_<sample_num>_<percent>.json"
    esi_default_name = "_".join(["esi", args.aug_func, str(args.sample_num), str(args.percent)]) + ".json"

    # Where to save the intermediate ESI results. Only an *existing* directory
    # is treated as a directory; any other value is used verbatim as the target
    # path. (The original nested a makedirs call under the isdir() check, where
    # it could never run, so it has been dropped.)
    save_path = args.dir_to_cache_for_esi
    if save_path is not None and os.path.isdir(save_path):
        save_path = os.path.join(save_path, esi_default_name)

    # Where to read previously cached ESI results from; same directory rule.
    cached_path_for_esi = args.cached_path_for_esi
    if cached_path_for_esi is not None and os.path.isdir(cached_path_for_esi):
        cached_path_for_esi = os.path.join(cached_path_for_esi, esi_default_name)

    esi_config = {
            "cached_result_path": cached_path_for_esi,
            "save_path": save_path,
            "augfunc": args.aug_func,
            "sample_num": args.sample_num,
            "percent": args.percent,
            "sim_batch_size": args.sim_batch_size,
            "num_scores_returned": args.num_scores_returned,
            "paraphrase_path": args.paraphrase_result_path
    }
    
    # set logger
    # --debug lowers the threshold to DEBUG; otherwise INFO.
    log_level = 'DEBUG' if args.debug else 'INFO'

    # Drop the handlers that are currently registered before adding our own.
    logger.remove()

    # Optional file sink: log/<log_name>.log (the "log" directory is created on demand).
    if args.log_name:
        if not os.path.exists("log"):
            os.makedirs("log")
        log_file = os.path.join("log", args.log_name + ".log")
        logger.add(log_file, level=log_level)

    def _emit_via_tqdm(message):
        """Hand each formatted log message to tqdm.write without adding a newline."""
        return tqdm.write(message, end='')

    logger.add(_emit_via_tqdm, colorize=True, level=log_level)
    
    

    # gpu setting
    # "gpu0" when CUDA is available, otherwise "cpu".
    device_name = "gpu0" if torch.cuda.is_available() else "cpu"
   

    
    
    logger.info(f"start to compute the uncertainty scores for outputs generated by model '{model_name}' on dataset from '{args.cached_result_path}'")
    estimated_scores = run_estimation(
        cached_results,
        estimation_methods=estimation_methods,
        device_name=device_name,
        batch_size=args.batch_size,
        esi_config=esi_config,
        se_config=se_config,
        self_con_config=sc_config,
        sar_config=sar_config,
        inside_config=inside_config,
        ice_config=ice_config,
        sampling_only=args.sampling_only,
        mi_config=mi_config,
    )

    # With --sampling_only nothing below runs: no scores are written or evaluated.
    if not args.sampling_only:
        # Output names are "<method or full_test>_<aug_func>_<sample_num>_<percent>".
        run_label = estimation_methods[0] if len(estimation_methods) == 1 else "full_test"
        name = "_".join([run_label, args.aug_func, str(args.sample_num), str(args.percent)])

        # NOTE(review): the score file goes to the *parent* of output_dir
        # (os.path.dirname), and is written regardless of --store_score -
        # confirm both are intended.
        uncertainty_score_save_path = os.path.join(
            os.path.dirname(args.output_dir),
            name + "_uncertainty_scores.json",
        )
        with open(uncertainty_score_save_path, 'w', encoding='utf-8') as f:
            json.dump(estimated_scores, f, indent=4)

        # Label is 1 when the correctness score falls below the threshold, else 0.
        correctness_scores = np.array(cached_results.scores[args.correctness_metric])
        truth_label = (correctness_scores < args.correctness_threshold).astype(int)

        logger.info('start to evaluate error detection performace')
        evaluator = Uncertainty_Evaluator(metrics="auroc")
        evaluator.evaluate(estimated_scores, truth_label)
        if args.output_dir is not None:
            evaluator.to_excel(args.output_dir, name=name)

    