import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
root_path = "/root/workspace/self-improvement"
sys.path.append(root_path)
import copy
import json
import torch
import logging

from vllm import LLM, SamplingParams

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoModelWithLMHead, GenerationConfig

from inference.chat_template import CHAT_TEMPLATE, USER_START_END
from inference.inference_args import parse_args
from utils.get_data import get_datasets, apply_chat_template
from utils.extract_QA import extract_QA
from pathlib import Path
from eval.scorer2 import score_mix2, compute_score

# Clear the process umask so the permission bits requested when this script
# creates files/directories are not masked off.
os.umask(0)

# Module-level logger; basicConfig configures the root logger at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level='INFO')


def save_data(data, tgt_dir, tgt_file_name):
    """Write ``data`` to ``tgt_dir/tgt_file_name``, falling back to ``torch.save``.

    The format is chosen from the file-name suffix: names ending in ``json``
    are dumped with ``json.dump`` (UTF-8, non-ASCII kept, 2-space indent) and
    names ending in ``txt`` are written verbatim. If that fails for any
    reason (unsupported suffix, non-serialisable object, I/O error, ...), the
    object is saved with ``torch.save`` next to the intended target, using the
    same file name with its extension replaced by ``.pt``.

    Args:
        data: Object to persist; a ``str`` is expected for ``txt`` targets.
        tgt_dir: Directory that receives the file (not created here).
        tgt_file_name: Target file name, including its extension.

    Raises:
        Whatever ``torch.save`` raises if the fallback itself fails.
    """
    tgt_path = os.path.join(tgt_dir, tgt_file_name)
    try:
        if tgt_file_name.endswith("json"):
            with open(tgt_path, 'w', encoding='utf-8') as fp:
                json.dump(data, fp, ensure_ascii=False, indent=2)
        elif tgt_file_name.endswith("txt"):
            with open(tgt_path, 'w', encoding='utf-8') as fp:
                fp.write(data)
        else:
            raise NotImplementedError(f"unsupported file type: {tgt_file_name}")
    except Exception as err:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt /
        # SystemExit still propagate. ``splitext`` strips only the final
        # extension, so "run.v2.json" falls back to "run.v2.pt", not "run.pt".
        fallback_path = os.path.join(tgt_dir, os.path.splitext(tgt_file_name)[0] + ".pt")
        logging.getLogger(__name__).warning(
            "Could not write %s (%s: %s); saving with torch.save to %s instead",
            tgt_path, type(err).__name__, err, fallback_path,
        )
        torch.save(data, fallback_path)


def llm_generate(llm, prompts, sampling_params):
    """Run ``llm.generate`` on ``prompts`` and return one text per result.

    Only the first candidate (``outputs[0]``) of each returned item is kept;
    the texts are returned in the order ``llm.generate`` yields its results.
    """
    request_outputs = llm.generate(prompts, sampling_params=sampling_params)
    return [request_output.outputs[0].text for request_output in request_outputs]


def main(args):
    """Generate DQA/QA data with vLLM, or run evaluation when ``args.do_eval`` is set.

    Outputs go to ``<root_path>/<args.tgt_dir>/<run_name>``, where ``run_name``
    is the last two ``/``-separated components of ``args.model_path`` joined
    with ``-``. ``args.tgt_dir`` is overwritten with that resolved directory
    and ``args.prompt_templates`` is populated from ``args.prompt_template_dir``.

    Args:
        args: Parsed options. Fields read here: ``model_path``, ``tgt_dir``,
            ``chat_template_name``, ``prompt_template_dir``, ``do_eval``,
            ``dataset_type``, ``dataset``, ``bf16``, ``fp16`` and ``seed``.
    """
    run_name = "-".join(args.model_path.split("/")[-2:])
    args.tgt_dir = os.path.join(args.tgt_dir, run_name)
    args.tgt_dir = os.path.join(root_path, args.tgt_dir)
    tgt_dir = Path(args.tgt_dir)
    tgt_dir.mkdir(parents=True, exist_ok=True)

    print("tgt_dir:", args.tgt_dir)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer.chat_template = CHAT_TEMPLATE[args.chat_template_name]

    args.prompt_templates = {}
    prompt_template_names = ['dqa', 'qa']
    for prompt_template_name in prompt_template_names:
        template_path = os.path.join(args.prompt_template_dir, prompt_template_name + '.txt')
        # Read as UTF-8 explicitly (every file this script writes is UTF-8)
        # instead of depending on the locale's default encoding.
        with open(template_path, 'r', encoding='utf-8') as f:
            args.prompt_templates[prompt_template_name] = f.read()
    if args.do_eval:
        inference_tasks = ["eval"]
    else:
        # e.g. a dataset_type of "dqa_qa" enables both generation stages.
        inference_tasks = args.dataset_type.split("_")

    generation_config = GenerationConfig.from_pretrained(args.model_path, trust_remote_code=True)
    sampling_kwargs = {
        "early_stopping": False,
        "top_p": generation_config.top_p,
        # vLLM's SamplingParams documents -1 as "consider all tokens"; map a
        # top_k of 0 in the generation config to that value.
        "top_k": -1 if generation_config.top_k == 0 else generation_config.top_k,
        "temperature": generation_config.temperature,
        "max_tokens": 1500,
    }
    if args.do_eval:
        sampling_kwargs['max_tokens'] = 512
    sampling_params = SamplingParams(**sampling_kwargs)
    llm = LLM(
        model=args.model_path,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=8,
        dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32),
        seed=args.seed,
        max_model_len=2048,  # should be enough
        enforce_eager=False,
        trust_remote_code=True,
        max_num_seqs=128,
    )
    if "dqa" in inference_tasks:
        dataset = get_datasets(dataset_name=args.dataset, dataset_type="dqa")
        for data in dataset:
            _data = apply_chat_template(args, tokenizer, data, "dqa", task="generation")
            data.update(_data)

        prompts = [data['text'] for data in dataset]
        if prompts:  # an empty dataset has no example to show
            print("Example: \n", prompts[0])

        generated_texts = llm_generate(llm, prompts, sampling_params)
        print("len(datase):", len(dataset))
        new_dataset = []
        num_dropped = 0
        for data, prompt, generated_text in zip(dataset, prompts, generated_texts):
            # Keep only samples whose generation yields both a question and an
            # answer. ``Exception`` (not a bare ``except``) lets Ctrl-C and
            # SystemExit propagate.
            try:
                _dict = extract_QA(generated_text)
                question = _dict['question']
                answer = _dict['answer']
            except Exception:
                num_dropped += 1
                logger.debug("extract_QA failed; dropping sample", exc_info=True)
                continue
            data['prompt'] = prompt
            data['question'] = question
            data['answer'] = answer
            new_dataset.append(data)
        if num_dropped:
            logger.warning(
                "Dropped %d of %d DQA samples whose question/answer could not be extracted",
                num_dropped, len(dataset),
            )
        dataset = new_dataset
        print("final len(datase):", len(dataset))
        if dataset:
            print("Example response:\n")
            print("question:", dataset[0]['question'])
            print("answer:", dataset[0]['answer'])
        else:
            logger.warning("No DQA sample survived extraction; dqa.json will be empty")
        save_data(dataset, tgt_dir=args.tgt_dir, tgt_file_name="dqa.json")

    if "qa" in inference_tasks:
        # When the DQA stage ran, answer the questions it just produced;
        # otherwise load the QA dataset.
        if "dqa" not in inference_tasks:
            dataset = get_datasets(dataset_name=args.dataset, dataset_type="qa")
        for data in dataset:
            _data = apply_chat_template(args, tokenizer, data, "qa", task="generation")
            data.update(_data)

        prompts = [data['text'] for data in dataset]

        generated_texts = llm_generate(llm, prompts, sampling_params)
        for data, prompt, generated_text in zip(dataset, prompts, generated_texts):
            data['qa_prompt'] = prompt
            if "dqa" in inference_tasks:
                # Keep the DQA-stage answer under 'answer'; store this one separately.
                data['qa_answer'] = generated_text
            else:
                data['answer'] = generated_text

        save_data(dataset, tgt_dir=args.tgt_dir, tgt_file_name="qa.json")

        if "dqa" in inference_tasks:
            # Second copy in which the DQA-stage answer is renamed to 'dqa_answer'.
            for data in dataset:
                data['dqa_answer'] = data["answer"]
                del data['answer']
            save_data(dataset, tgt_dir=args.tgt_dir, tgt_file_name="dqa_qa.json")

    if args.do_eval:
        dataset = get_datasets(dataset_name=args.dataset, dataset_type="eval")
        for data in dataset:
            _data = apply_chat_template(args, tokenizer, data, "eval", task="generation")
            data.update(_data)

        prompts = [data['text'] for data in dataset]
        if prompts:  # an empty dataset has no example to show
            print("Example: \n", prompts[0])

        generated_texts = llm_generate(llm, prompts, sampling_params)

        for data, prompt, generated_text in zip(dataset, prompts, generated_texts):
            data['prompt'] = prompt
            data['huatuo_answer_0'] = generated_text

        # Use the first idx for which output_eval_<n>_<idx>.json does not exist yet.
        # NOTE(review): the probe looks for "output_eval_*" while this function
        # saves "eval_*"; presumably score_mix2 writes outstr_path - confirm.
        idx = 1
        while os.path.exists(os.path.join(args.tgt_dir, f"output_eval_{len(dataset)}_{idx}.json")):
            idx += 1
        outstr_path = os.path.join(args.tgt_dir, f"output_eval_{len(dataset)}_{idx}.json")
        val_res = score_mix2(outstr_path, dataset, True, 1)
        print(val_res)
        # Leave the 'InputOutputTable' entry out of the text summary.
        del val_res['InputOutputTable']
        outstr = json.dumps(val_res, ensure_ascii=False, indent=2)

        outstr += '\n' + f'model_name_or_path: {args.model_path}'
        outstr += '\n' + f'output: {outstr_path}'
        outstr += '\n' + f"acc_str: {val_res['acc_str']}"
        outstr += '\n\n' + f'data_example: {json.dumps(dataset[0], ensure_ascii=False, indent=2)}'
        print(outstr)
        save_data(dataset, tgt_dir=args.tgt_dir, tgt_file_name=f"eval_{len(dataset)}_{idx}.json")
        save_data(outstr, tgt_dir=args.tgt_dir, tgt_file_name=f"eval_{len(dataset)}_{idx}.txt")


# Script entry point: parse the command-line options and run inference/evaluation.
if __name__ == '__main__':
    args = parse_args()
    main(args)

"""
# inference
python vllm_predict.py --model_path /root/workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct

# evaluate
python inference/vllm_predict.py --model_path ckpts/Llama-3.2-3B-Instruct/MedQA_en_dqa_qa_lr_2e-05_init_train_True \
    --chat_template_name llama-3.2-chat \
    --dataset MedQA_en \
    --do_eval True
python inference/vllm_predict.py --model_path /root/workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct \
    --chat_template_name llama-3.1-chat \
    --dataset MedQA_en \
    --do_eval True
python inference/vllm_predict.py --model_path /root/workspace/hf_models/meta-llama/Llama-3.2-3B-Instruct \
    --chat_template_name llama-3.2-chat \
    --dataset MedQA_en \
    --do_eval True
"""
