import argparse
import ipdb
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
from vllm import EngineArgs, LLMEngine, SamplingParams
import torch
try:
    from tigerscore import TIGERScorer
except:
    pass
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig, AutoModel
import json
import os
import sys
import warnings
from huggingface_hub import login
from vllm import LLM

from utils.data_processor import DataProcessor
from utils.utils import load_jsonl
import evaluate_configuration
from infer.infer import infer_hf, infer_openai, infer_hf_critic_model
from infer.result_processor import ResultProcessor


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: the parsed options (task selection, model
        identifiers, prompt type, paths to existing result files and the
        prompt / data-cache / output directories).
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Which stages to run.
    add("--tasks", type=str, default="GQC", choices=["GQC", "QC", "G", "Q", "C"],
        help="Specify the evaluation task. Options: G (generation), Q (critique), C (correction),"
             "The correction task must be executed after the critique, "
             "or a critique result file must be specified.")

    # Which model produces the results.
    add("--hf_model", type=str, default=None,
        help="Path or identifier for a HuggingFace model.")
    add("--hf_critic_model", type=str, default=None,
        help="Identifier for a critique model from Hugging Face, "
             "supporting GAIR/autoj-13b and openbmb/UltraCM-13b models.")
    add("--openai_model", type=str, default=None,
        help="Model from openai")
    add("--enable_code_execution", action="store_true",
        help="Enable code execution for generation and correction.")

    # Runtime / prompting options.
    add("--available_gpus", type=str, default="0",
        help="Specifies which GPUs to use, by ID, separated by commas (e.g., '0,1').")
    add("--prompt_type", type=str, default="fs",
        choices=["fs", "zs-crit-cot", "zs-crit-ao-1", "zs-crit-ao-2", "zs-crit-ao-3"])

    # Previously produced result files, one optional flag per stage.
    for flag, stage in (("--existed_gen_file", "generation"),
                        ("--existed_crit_file", "critique"),
                        ("--existed_corr_file", "correction")):
        add(flag, type=str, default=None,
            help=f"Path to an existing file with {stage} results to be used or evaluated.")

    # Working directories.
    for flag, default_dir in (("--prompt_dir", "./prompt"),
                              ("--data_cache_dir", "./data"),
                              ("--output_dir", "./output")):
        add(flag, type=str, default=default_dir)

    return cli.parse_args()

def set_environment(args):
    """Export process environment variables and load API credentials.

    Sets ``CUDA_VISIBLE_DEVICES`` from ``args.available_gpus``, disables
    tokenizer parallelism, sets ``HF_ALLOW_CODE_EVAL`` when code execution is
    enabled, and copies the ``OPENAI_API_KEY`` / ``HUGGINGFACE_TOKEN``
    environment variables (``None`` when unset) onto ``evaluate_configuration``.

    Args:
        args: parsed CLI options; reads ``available_gpus`` and
            ``enable_code_execution``.
    """
    overrides = {
        "CUDA_VISIBLE_DEVICES": args.available_gpus,
        "TOKENIZERS_PARALLELISM": "false",
    }
    if args.enable_code_execution:
        overrides["HF_ALLOW_CODE_EVAL"] = "1"
    os.environ.update(overrides)

    # Credentials are read from the environment only; a missing variable yields None.
    for attr_name, env_name in (("openai_api_key", "OPENAI_API_KEY"),
                                ("hf_token", "HUGGINGFACE_TOKEN")):
        setattr(evaluate_configuration, attr_name, os.environ.get(env_name))


def set_existed_file(args):
    """Collect the user-supplied result files, keyed by task letter.

    Args:
        args: parsed CLI options; reads ``existed_gen_file``,
            ``existed_crit_file`` and ``existed_corr_file``.

    Returns:
        tuple: ``(existed_file_by_task, has_existed_file)`` where the first
        item maps ``"G"``/``"Q"``/``"C"`` to the corresponding path (or
        ``None``) and the second is ``True`` when at least one path was given.
    """
    supplied_paths = (args.existed_gen_file, args.existed_crit_file, args.existed_corr_file)
    existed_file_by_task = dict(zip(("G", "Q", "C"), supplied_paths))
    has_existed_file = any(path is not None for path in supplied_paths)
    return existed_file_by_task, has_existed_file


def set_prompt_type_by_task(prompt_type):
    """Expand the global prompt type into one prompt type per task.

    Args:
        prompt_type: the value of ``--prompt_type``.

    Returns:
        dict: maps ``"G"``, ``"Q"`` and ``"C"`` to a prompt type. ``"fs"``
        applies to every task; the ``zs-crit-*`` types apply to the critique
        task only, with generation staying ``"fs"`` and correction using
        ``"zs-cot"`` or ``"zs-ao"``. Any other value maps every task to ``""``.
    """
    task_letters = ("G", "Q", "C")
    if prompt_type == "fs":
        return dict.fromkeys(task_letters, "fs")

    if prompt_type == "zs-crit-cot":
        correction_prompt = "zs-cot"
    elif prompt_type in ("zs-crit-ao-1", "zs-crit-ao-2", "zs-crit-ao-3"):
        correction_prompt = "zs-ao"
    else:
        # Unrecognised prompt type: every task gets an empty prompt type.
        return dict.fromkeys(task_letters, "")

    return {"G": "fs", "Q": prompt_type, "C": correction_prompt}


def load_exist_result(existed_file):
    """Load previously produced results from a JSONL file.

    Args:
        existed_file: path of the result file.

    Returns:
        list: the records yielded by ``load_jsonl`` for that file.

    Raises:
        SystemExit: when the path does not exist.
    """
    if not os.path.exists(existed_file):
        sys.exit(f"file:{existed_file} not existed")
    print(f"evaluate result in {existed_file}")
    return list(load_jsonl(existed_file))

# Checkpoint paths that are served through lmdeploy with model_name='internlm2'.
# NOTE(review): the list also contains Llama-3 / Llama-3.1 / Qwen2 snapshots, and
# the original author's inline note said the 'internlm2' name should be dropped
# for Llama 3.1 -- confirm these entries really belong here.
_INTERNLM2_MODEL_PATHS = (
    '/cpfs02/llm/shared/public/lantian/exp/s2_add_critictuning_v01rc1/s2_no_critic_344/aliyun_Ampere_7B_v1.1_enchance_FT_v1.0.0_s1_rc47_s2_no_critictuning_v01rc1/344_hf_ckpt',
    '/cpfs01/shared/public/public_hdd/llmeval/model_weights/hf_hub/models--internlm--internlm2-chat-20b/snapshots/3f710f76f56f8c40dc5dd800dbe66f3341cb2c87',
    '/cpfs02/llm/shared/public/lantian/exp/20240418_aliyun_Ampere_7B_v1_1_FT_v1_0_0_s1_rc48_1660_hf_ckpt',
    '/cpfs02/llm/shared/public/lantian/exp/s2_add_critictuning_v01rc1/s2_add_critic_530/aliyun_Ampere_7B_v1.1_enchance_FT_v1.0.0_s1_rc47_s2_add_critictuning_v01rc1/530_hf_ckpt',
    '/cpfs02/llm/shared/public/lantian/exp/20240727_d4_st_5_l1_resumm_debug/aliyun_Ampere_7B_v1.1_enchance_FT_v1.0.0_s1_rc47_FINAL_critic_v4_st_5_l1_resumm/0_hf_ckpt',
    '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f',
    '/cpfs01/shared/public/llmeval/model_weights/hf_hub/models--Qwen--Qwen2-7B-Instruct/snapshots/f2826a00ceef68f0f2b946d945ecc0477ce4450c',
    '/cpfs02/llm/shared/public/lantian/exp/internlm2_5/s1_d4_st_5_l1_resumm_20240906_internlm2_5_7b_s1/aliyun_Ampere_7B_v1.1_enchance_FT_v1.0.0_s1_rc47_s1_d4_st_5_l1_resumm_20240906_internlm2_5_7b_s1/385_hf_ckpt',
    '/cpfs02/llm/shared/public/lantian/hf_models/meta-llama-3-8B-Instruct',
)

# Substrings of a model path/identifier that also select model_name='internlm2'.
# 'promethues' (sic) is kept next to 'prometheus' so misspelled paths still match.
_INTERNLM2_PATH_MARKERS = (
    'transfer_from_tos',
    'd4_st_5_l1_resumm',
    'v4_st_5_l1_resumm',
    'step_train_exp_20240925',
    'promethues',
    'prometheus',
)


def _lmdeploy_model_name(model):
    """Return the lmdeploy ``model_name`` ('internlm2' or 'llama3') used for ``model``."""
    if model in _INTERNLM2_MODEL_PATHS or any(marker in model for marker in _INTERNLM2_PATH_MARKERS):
        return 'internlm2'
    # Fallback for every other model; the original author annotated this
    # branch as "only for llama-3-70b".
    return 'llama3'


def init_hf_model_with_vllm(model, available_gpus, llm):
    """Create (or reuse) the inference engine for a HuggingFace model.

    The backend is chosen from substrings of ``model``:

    * ``autoj-13b`` / ``UltraCM`` / ``Llama-2``: a vLLM ``LLM`` with one
      tensor-parallel shard per entry of ``available_gpus``;
    * ``TIGERScore``: a ``TIGERScorer`` running on vLLM;
    * ``Themis``: a vLLM ``LLMEngine``;
    * anything else: an lmdeploy ``pipeline`` on the PyTorch engine.

    Args:
        model: path or identifier of the model.
        available_gpus: comma-separated GPU ids (e.g. ``"0,1"``); only the
            number of entries is used, and only by the first backend above.
        llm: an already initialised engine, or ``None``. A non-``None`` value
            is returned unchanged so that callers looping over several tasks
            do not load the same model more than once.

    Returns:
        The engine object for ``model``.

    Raises:
        ImportError: a TIGERScore model was requested but ``TIGERScorer``
            is not available in this module.
    """
    if llm is not None:
        # Engine already built by a previous call: loading it again would
        # only duplicate the weights in GPU memory.
        return llm

    if 'autoj-13b' in model or 'UltraCM' in model or 'Llama-2' in model:
        gpu_ids = available_gpus.split(',')
        return LLM(model=model, tensor_parallel_size=len(gpu_ids), trust_remote_code=True)

    if 'TIGERScore' in model:
        # The module-level import of tigerscore is optional and may have
        # failed silently; fail here with an actionable message instead of
        # a bare NameError.
        scorer_cls = globals().get('TIGERScorer')
        if scorer_cls is None:
            raise ImportError(
                f"model '{model}' requires the 'tigerscore' package, which could not be imported")
        return scorer_cls(model_name=model, use_vllm=True)

    if 'Themis' in model:
        # NOTE(review): the weights are loaded from this fixed path and the
        # `model` argument is not used -- confirm this is intended.
        engine_args = EngineArgs(
            model='/cpfs02/llm/shared/public/lantian/hf_models/PKU-ONELab/Themis',
            tensor_parallel_size=1,
            max_num_seqs=1024,
            max_num_batched_tokens=2048,
            gpu_memory_utilization=0.98,
            swap_space=16)
        return LLMEngine.from_engine_args(engine_args)

    backend_config = PytorchEngineConfig(
        session_len=32768,
        model_name=_lmdeploy_model_name(model),
        tp=1
    )
    return pipeline(model, backend_config=backend_config)


def evaluate_model(args, existed_file_by_task, prompt_type_by_task, data_processor, result_processor):
    """Run and score the selected tasks (generation, critique, correction).

    For every task letter contained in ``args.tasks`` the results come from,
    in priority order: the existing result file for that task, the
    HuggingFace model (``args.hf_model``), or the OpenAI model
    (``args.openai_model``). The results are checked and scored by
    ``result_processor`` and a JSON summary of all scores is printed.

    Args:
        args: parsed CLI options.
        existed_file_by_task: maps ``"G"``/``"Q"``/``"C"`` to an existing
            result file path or ``None``.
        prompt_type_by_task: maps ``"G"``/``"Q"``/``"C"`` to a prompt type.
        data_processor: provides ``prepare_data_with_prompt_info``.
        result_processor: provides ``result_check`` and ``analyse_result``.

    Raises:
        SystemExit: a selected task has neither an existing result file nor
            a model to produce its results.
    """
    llm = None
    if args.hf_model is not None and evaluate_configuration.hf_token is not None:
        login(token=evaluate_configuration.hf_token)
    result_dict = {}
    for task in ["G", "Q", "C"]:
        if task not in args.tasks:
            continue
        task_name = evaluate_configuration.full_task_name_dict[task]
        print(f"----------Start {task_name}----------")
        dataset_with_prompt = data_processor.prepare_data_with_prompt_info(task_name)
        existed_file = existed_file_by_task[task]
        prompt_type = prompt_type_by_task[task]

        if existed_file is not None:
            if prompt_type not in existed_file:
                warnings.warn("Mismatched prompt types may lead to issues with the extract answer function.",
                              UserWarning)
            result_list = load_exist_result(existed_file=existed_file)
        elif args.hf_model is not None:
            if llm is None:
                # Build the engine once and share it across tasks instead of
                # reloading the same model for every task.
                llm = init_hf_model_with_vllm(model=args.hf_model,
                                              available_gpus=args.available_gpus,
                                              llm=llm)
            result_list = infer_hf(model=args.hf_model,
                                   llm=llm,
                                   dataset_with_prompt=dataset_with_prompt,
                                   out_dir=args.output_dir,
                                   task=task_name,
                                   prompt_type=prompt_type)
        elif args.openai_model is not None:
            result_list = infer_openai(model=args.openai_model,
                                       api_key=evaluate_configuration.openai_api_key,
                                       dataset_with_prompt=dataset_with_prompt,
                                       out_dir=args.output_dir,
                                       task=task_name,
                                       prompt_type=prompt_type)
        else:
            # Without this branch `result_list` would be unbound, or would
            # silently still hold the previous task's results.
            sys.exit(f"no result source for task '{task_name}': specify --hf_model, "
                     f"--openai_model, or an existing result file for this task")

        result_by_id = result_processor.result_check(task=task_name,
                                                     dataset_with_prompt=dataset_with_prompt,
                                                     result_list=result_list,
                                                     enable_code_execution=args.enable_code_execution)
        total_score, score_by_type, score_by_dataset = result_processor.analyse_result(
            task=task_name,
            dataset_with_prompt=dataset_with_prompt,
            result_by_id=result_by_id)
        result_dict[task_name] = {
            f"{task_name} score": total_score,
            f"{task_name} score by type": score_by_type,
            f"{task_name} score by dataset": score_by_dataset
        }
        if task == "Q":
            # Attach each critique to its sample; presumably the samples are
            # shared with data_processor so the correction task can use
            # them -- verify against DataProcessor.
            for sample in dataset_with_prompt:
                sample_id = sample["id"]
                sample["critique"] = result_by_id[sample_id]["critique_result"]
    print("----------Evaluation result----------")
    print(json.dumps(result_dict, indent=2))


def evaluate_critic_model(args, data_processor, result_processor):
    """Evaluate a dedicated critique model on the critique task.

    Results are loaded from ``args.existed_crit_file`` when it is given;
    otherwise the model ``args.hf_critic_model`` is initialised and run on
    ``data_processor.dataset``. The results are checked and scored by
    ``result_processor`` and the total and per-type scores are printed.

    Args:
        args: parsed CLI options; reads ``hf_critic_model``,
            ``existed_crit_file``, ``available_gpus`` and ``output_dir``.
        data_processor: provides the ``dataset`` to critique.
        result_processor: provides ``critic_model_result_check`` and
            ``analyse_result``.
    """
    # No model-name filter here: every critic model follows the same path.
    # Restricting it to specific names would leave `result_list` unbound for
    # the other models that init_hf_model_with_vllm knows how to load.
    if args.existed_crit_file is None:
        llm = init_hf_model_with_vllm(model=args.hf_critic_model,
                                      available_gpus=args.available_gpus,
                                      llm=None)
        result_list = infer_hf_critic_model(model=args.hf_critic_model,
                                            llm=llm,
                                            dataset=data_processor.dataset,
                                            out_dir=args.output_dir,
                                            task="critique")
    else:
        result_list = load_exist_result(existed_file=args.existed_crit_file)

    result_by_id = result_processor.critic_model_result_check(model=args.hf_critic_model,
                                                              task="critique",
                                                              dataset=data_processor.dataset,
                                                              result_list=result_list)
    total_score, score_by_type, _ = result_processor.analyse_result(
        task="critique",
        dataset_with_prompt=data_processor.dataset,
        result_by_id=result_by_id)
    print(f"critique score: {total_score}")
    print(f"critique score by type: {json.dumps(score_by_type, indent=2)}")


if __name__ == '__main__':
    # Script entry point: parse options, prepare the environment and the
    # data/result processors, then dispatch to the matching evaluation routine.
    args = parse_args()
    set_environment(args)
    existed_file_by_task, has_existed_file = set_existed_file(args)
    prompt_type_by_task = set_prompt_type_by_task(prompt_type=args.prompt_type)
    data_processor = DataProcessor(data_cache_dir=args.data_cache_dir,
                                   prompt_dir=args.prompt_dir,
                                   prompt_type_by_task=prompt_type_by_task)
    result_processor = ResultProcessor()
    if args.hf_critic_model is not None:
        # A dedicated critique model takes precedence over every other option.
        evaluate_critic_model(args=args,
                              data_processor=data_processor,
                              result_processor=result_processor)
    elif args.hf_model is not None or args.openai_model is not None or has_existed_file:
        evaluate_model(args=args,
                       existed_file_by_task=existed_file_by_task,
                       prompt_type_by_task=prompt_type_by_task,
                       data_processor=data_processor,
                       result_processor=result_processor)
    else:
        # Make the no-op explicit instead of exiting silently.
        warnings.warn("Nothing to evaluate: specify --hf_critic_model, --hf_model, --openai_model, "
                      "or at least one --existed_*_file.",
                      UserWarning)
