import torch
import json
import os
import argparse
import random

import time
import shortuuid
from tqdm import tqdm

from os.path import exists, join, isdir
from packaging import version
from peft.tuners.lora import LoraLayer


from eval_mt_bench.common import load_questions, temperature_config
from eval_mt_bench.conversation import get_conv_template

import importlib

import transformers
from typing import Optional, Dict, Sequence

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaTokenizer,
    GenerationConfig
)

from peft import (
    PeftModel
)

# Label value for positions to be ignored; not referenced in the visible part
# of this file (presumably kept for parity with the training script -- TODO confirm).
IGNORE_INDEX = -100
# Literal pad token string; only referenced from commented-out code in this file.
DEFAULT_PAD_TOKEN = "<pad>"

# def smart_tokenizer_and_embedding_resize(
#         args,
#         special_tokens_dict: Dict,
#         tokenizer: transformers.PreTrainedTokenizer,
#         model: transformers.PreTrainedModel,
# ):
#     """Resize tokenizer and embedding.
#
#     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
#     """
#     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
#     model.config.pad_token_id = tokenizer.pad_token_id
#     model.resize_token_embeddings(len(tokenizer))
#
#     if num_new_tokens > 0:
#         input_embeddings_data = model.get_input_embeddings().weight.data
#         output_embeddings_data = model.get_output_embeddings().weight.data
#
#         input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
#         output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
#
#         input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
#         output_embeddings_data[-num_new_tokens:] = output_embeddings_avg

def sort_dict_by_value(d, largest_first=True):
    """Return a new dict with the items of ``d`` ordered by their values.

    Args:
        d: Mapping whose values are mutually comparable.
        largest_first: When True (default) the largest value comes first;
            otherwise the smallest value comes first.

    Returns:
        A dict whose insertion order follows the sorted values. Items with
        equal values keep their original relative order.
    """
    def _value_of(pair):
        return pair[1]

    ranked_items = sorted(d.items(), key=_value_of, reverse=largest_first)
    return dict(ranked_items)

def divide_dict(d, total_steps, ratio=0.5):
    """Keep only the entries of ``d`` whose value exceeds ``total_steps * ratio``.

    Args:
        d: Mapping from keys to numeric values.
        total_steps: Number that is scaled by ``ratio`` to form the threshold.
        ratio: Fraction of ``total_steps`` used as the cut-off (default 0.5).

    Returns:
        A new dict with the entries whose value is strictly greater than the
        threshold, in their original order.
    """
    threshold = total_steps * ratio
    kept = {}
    for key, val in d.items():
        if val > threshold:
            kept[key] = val
    return kept

def is_ipex_available():
    """Report whether Intel Extension for PyTorch (IPEX) is installed and usable.

    Returns:
        True only when the ``intel_extension_for_pytorch`` package can be
        found, its distribution metadata is available, and its major.minor
        version equals the major.minor version of the installed ``torch``.
        A version mismatch emits a warning and returns False.
    """
    # `import importlib` alone does not guarantee that the `metadata` and
    # `util` submodules are loaded, and `warnings` is not imported at module
    # level, so bring all three into scope here.
    import importlib.metadata
    import importlib.util
    import warnings

    def get_major_and_minor_from_version(full_version):
        parsed = version.parse(full_version)
        return str(parsed.major) + "." + str(parsed.minor)

    # Cheap probe first: without the package there is nothing to compare.
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False
    try:
        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
    except importlib.metadata.PackageNotFoundError:
        return False

    _torch_version = importlib.metadata.version("torch")
    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        warnings.warn(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True

def get_last_checkpoint(checkpoint_dir):
    """Locate the most recent ``checkpoint-<step>`` directory of a training run.

    Args:
        checkpoint_dir: Output directory of a training run.

    Returns:
        A ``(path, is_completed)`` tuple. ``path`` is the sub-directory with
        the highest positive step number, or None when no such directory
        exists. ``is_completed`` is True when ``checkpoint_dir`` contains an
        entry named ``completed``. A missing ``checkpoint_dir`` yields
        ``(None, False)``.
    """
    if not isdir(checkpoint_dir):
        return None, False  # first training

    is_completed = exists(join(checkpoint_dir, 'completed'))
    prefix = 'checkpoint-'
    max_step = 0
    latest_name = None
    for filename in os.listdir(checkpoint_dir):
        if not filename.startswith(prefix):
            continue
        if not isdir(join(checkpoint_dir, filename)):
            continue
        # Skip look-alike directories (e.g. "checkpoint-final") instead of
        # crashing on a non-numeric suffix.
        try:
            step = int(filename[len(prefix):])
        except ValueError:
            continue
        if step > max_step:
            # Remember the real directory name so the returned path exists
            # even when the step number is zero-padded on disk.
            max_step = step
            latest_name = filename

    if latest_name is None:
        return None, is_completed  # training started, but no checkpoint

    checkpoint_dir = join(checkpoint_dir, latest_name)
    print(f"Found a previous checkpoint at: {checkpoint_dir}")
    return checkpoint_dir, is_completed  # checkpoint found!

def get_accelerate_model(args):
    """Load the base causal LM, its tokenizer and, optionally, a PEFT adapter.

    Args:
        args: Parsed command-line namespace. This function reads
            ``model_name_or_path``, ``cache_dir``, ``fp16``, ``bf16``,
            ``bits``, ``max_memory_MB`` and ``peft_path``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is wrapped in a
        ``PeftModel`` when ``args.peft_path`` is non-empty.
    """
    # Default to zero devices so that a machine with neither CUDA nor XPU
    # does not fail with an unbound `n_gpus` below.
    n_gpus = 0
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
    if is_ipex_available() and torch.xpu.is_available():
        n_gpus = torch.xpu.device_count()

    # Per-device memory budget. It is built here but not currently passed to
    # `from_pretrained` (the model is pinned to a single device instead).
    max_memory = f'{args.max_memory_MB}MB'
    max_memory = {i: max_memory for i in range(n_gpus)}

    print(f'loading base model {args.model_name_or_path}...')
    compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        # Place the whole model on the device selected by LOCAL_RANK (0 if unset).
        device_map={'': int(os.environ.get("LOCAL_RANK") or 0)},
        torch_dtype=compute_dtype,
    )

    if compute_dtype == torch.float16 and args.bits == 4:
        if torch.cuda.is_bf16_supported():
            print('=' * 80)
            print('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            print('=' * 80)

    if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()):
        # NOTE(review): the model has already been loaded at this point, so
        # this reassignment only affects the message below -- confirm intent.
        compute_dtype = torch.bfloat16
        print('Intel XPU does not support float16 yet, so switching to bfloat16')

    # NOTE(review): fp16 maps to float32 here, unlike the load dtype above;
    # kept exactly as in the original.
    model.config.torch_dtype = (torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        padding_side="right",
        use_fast=False,
        add_eos_token=True,
        add_bos_token=True,
        add_prefix_space=True,
    )

    if tokenizer._pad_token is None and tokenizer._unk_token is not None:
        # Reuse the unknown token for padding rather than growing the vocabulary.
        tokenizer.pad_token_id = tokenizer.unk_token_id
        model.config.pad_token_id = tokenizer.unk_token_id
        # The generation config attribute is `pad_token_id` (singular).
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    if args.peft_path:
        print("Loading adapters from checkpoint.")
        model = PeftModel.from_pretrained(model, args.peft_path, is_trainable=False)

    if args.bf16:
        # Cast the LoRA layers to bfloat16.
        for _, module in model.named_modules():
            if isinstance(module, LoraLayer):
                module.to(torch.bfloat16)

    return model, tokenizer


def _extract_answer_text(output_ids, conv, tokenizer):
    """Convert the generated token ids of one turn into the final answer string."""
    # Truncate at the first stop token declared by the conversation template.
    if conv.stop_token_ids:
        first_stop = next(
            (
                pos
                for pos, token_id in enumerate(output_ids)
                if token_id in conv.stop_token_ids
            ),
            None,
        )
        if first_stop is not None:
            output_ids = output_ids[:first_stop]

    output = tokenizer.decode(
        output_ids,
        spaces_between_special_tokens=False,
    )

    # Truncate at the earliest stop string. A match at index 0 is not treated
    # as a stop (the comparison is `> 0`), as in the original implementation.
    if conv.stop_str and isinstance(conv.stop_str, list):
        hits = [
            pos
            for pos in (output.find(stop_str) for stop_str in conv.stop_str)
            if pos > 0
        ]
        if hits:
            output = output[: min(hits)]
    elif conv.stop_str:
        pos = output.find(conv.stop_str)
        if pos > 0:
            output = output[:pos]

    # Strip any special tokens that survived decoding.
    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, list):
            for special_tok in special_token:
                output = output.replace(special_tok, "")
        else:
            output = output.replace(special_token, "")
    output = output.strip()

    if conv.name == "xgen" and output.startswith("Assistant:"):
        output = output.replace("Assistant:", "", 1).strip()
    return output


@torch.inference_mode()
def get_model_answers(
    args,
    model_id,
    questions,
    answer_file,
    max_new_token,
    num_choices
):
    """Generate answers for ``questions`` and append them to ``answer_file``.

    Args:
        args: Parsed command-line namespace, forwarded to
            ``get_accelerate_model``.
        model_id: Name passed to ``get_conv_template`` and stored in each
            output record.
        questions: Iterable of dicts providing ``"question_id"``,
            ``"category"`` and ``"turns"``.
        answer_file: Path of the JSON-lines file to append to; ``~`` is
            expanded and the parent directory is created when needed.
        max_new_token: Maximum number of new tokens per generated turn.
        num_choices: Number of completions per question; choice ``i`` is
            generated after seeding torch with ``i``.

    Each appended line is a JSON object with the keys ``question_id``,
    ``answer_id``, ``model_id``, ``choices`` and ``tstamp``. A turn whose
    generation raises ``RuntimeError`` is recorded as the string ``"ERROR"``.
    """
    model, tokenizer = get_accelerate_model(args)
    model.eval()

    # Resolve the output location once. Use the same expanded path for the
    # directory and the file, and skip makedirs for a bare filename
    # (dirname == ""), which would otherwise raise.
    answer_path = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_path)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    for question in tqdm(questions):
        category = question["category"]
        temperature = temperature_config[category] if category in temperature_config else 0.7
        # A (near-)zero temperature means greedy decoding.
        do_sample = temperature >= 1e-4

        gen_kwargs = dict(
            do_sample=do_sample,
            max_new_tokens=max_new_token,
            pad_token_id=tokenizer.pad_token_id,
        )
        if do_sample:
            # Temperature only applies when sampling.
            gen_kwargs["temperature"] = temperature

        choices = []
        for i in range(num_choices):
            torch.manual_seed(i)
            conv = get_conv_template(model_id)
            turns = []
            for qs in question["turns"]:
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                # Put the inputs on the device that actually holds the model.
                model_inputs = tokenizer(
                    [prompt], return_tensors="pt", add_special_tokens=False
                ).to(model.device)
                prompt_len = len(model_inputs.input_ids[0])

                # some models may error out when generating long outputs
                try:
                    output_ids = model.generate(
                        **model_inputs,
                        generation_config=GenerationConfig(**gen_kwargs),
                    )
                    if model.config.is_encoder_decoder:
                        output_ids = output_ids[0]
                    else:
                        # Decoder-only output echoes the prompt; drop it.
                        output_ids = output_ids[0][prompt_len:]

                    output = _extract_answer_text(output_ids, conv, tokenizer)
                except RuntimeError as err:
                    print("ERROR question ID: ", question["question_id"], repr(err))
                    output = "ERROR"

                conv.update_last_message(output)
                turns.append(output)

            choices.append({"index": i, "turns": turns})

        # Dump answers
        with open(answer_path, "a") as fout:
            ans_json = {
                "question_id": question["question_id"],
                "answer_id": shortuuid.uuid(),
                "model_id": model_id,
                "choices": choices,
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")


def reorg_answer_file(answer_file):
    """Rewrite the answer file sorted by question id, one line per id.

    When a question id occurs on several lines, the last such line is the
    one kept. Lines are written back exactly as they were read.
    """
    with open(answer_file, "r") as src:
        latest_line_by_qid = {
            json.loads(raw_line)["question_id"]: raw_line for raw_line in src
        }

    ordered_lines = [latest_line_by_qid[qid] for qid in sorted(latest_line_by_qid)]

    with open(answer_file, "w") as dst:
        dst.writelines(ordered_lines)


def run_eval(
    args,
    model_id,
    question_file,
    question_begin,
    question_end,
    answer_file,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
):
    """Load the benchmark questions and generate answers, optionally in parallel.

    Args:
        args: Parsed command-line namespace, forwarded to ``get_model_answers``.
        model_id: Model identifier forwarded to ``get_model_answers``.
        question_file: Path handed to ``load_questions``.
        question_begin: Begin index handed to ``load_questions``.
        question_end: End index handed to ``load_questions``.
        answer_file: Output file forwarded to ``get_model_answers``.
        max_new_token: Maximum new tokens per generated turn.
        num_choices: Number of completions per question.
        num_gpus_per_model: GPUs reserved for each worker.
        num_gpus_total: Total number of GPUs; must be a multiple of
            ``num_gpus_per_model``.

    When ``num_gpus_total // num_gpus_per_model`` is greater than one, the
    question chunks are dispatched as Ray remote tasks and awaited; otherwise
    they are processed in this process.
    """
    questions = load_questions(question_file, question_begin, question_end)
    # random shuffle the questions to balance the loading
    random.shuffle(questions)

    # Split the questions into one chunk per worker
    assert num_gpus_total % num_gpus_per_model == 0
    num_workers = num_gpus_total // num_gpus_per_model
    use_ray = num_workers > 1

    if use_ray:
        # Imported here so the function does not rely on the `__main__` block
        # having bound `ray` as a module global.
        import ray

        get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
            get_model_answers
        ).remote
    else:
        get_answers_func = get_model_answers

    # Keep the step at least 1: fewer questions than workers (or none at all)
    # would otherwise give a zero step, and zero workers a division by zero.
    chunk_size = max(1, len(questions) // max(1, num_workers))
    ans_handles = []
    for start in range(0, len(questions), chunk_size):
        ans_handles.append(
            get_answers_func(
                args,
                model_id,
                questions[start : start + chunk_size],
                answer_file,
                max_new_token,
                num_choices
            )
        )

    if use_ray:
        ray.get(ans_handles)


# def generate_model_answer(args, model_id, question_file, answer_file, max_new_token, num_choices):
#     if args.num_gpus_total // args.num_gpus_per_model > 1:
#         import ray
#
#         ray.init()



def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into True. This parser maps the usual spellings explicitly.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"``, ``"on"``; case and surrounding
            whitespace are ignored.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Command-line entry point: parse options, generate answers for the
    # selected benchmark, then sort and de-duplicate the answer file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='huggyllama/llama-7b')
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='')
    parser.add_argument('--bits', type=int, default=16)
    # Boolean options accept `--flag True/False` as before, and also a bare
    # `--flag` (meaning True).
    parser.add_argument('--fp16', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--bf16', type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('--trust_remote_code', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--use_auth_token', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--max_memory_MB', type=int, default=80000)

    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
        default=None,
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="The model revision to load.",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="llama-7b",
        help="The model id.",
    )

    parser.add_argument(
        "--peft_path",
        type=str,
        default="",
        help="The path to the peft model.",
    )

    parser.add_argument(
        "--max_new_token",
        type=int,
        default=256,
        help="The maximum number of new generated tokens.",
    )

    args = parser.parse_args()

    # Ray is only needed when more than one model worker is requested. The
    # import stays at module scope so `ray` is also visible as a global.
    if args.num_gpus_total // args.num_gpus_per_model > 1:
        import ray

        ray.init()

    question_file = f"./data/{args.bench_name}/question.jsonl"
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"./data/{args.bench_name}/model_answer/{args.model_name}_{args.model_id}_{args.max_new_token}.jsonl"

    print(f"Output to {answer_file}")

    run_eval(
        args=args,
        model_id=args.model_id,
        question_file=question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        answer_file=answer_file,
        max_new_token=args.max_new_token,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
    )

    reorg_answer_file(answer_file)