import torch
import json
import os
import argparse
import random

import time
import shortuuid
from tqdm import tqdm

from os.path import exists, join, isdir
from packaging import version
from peft.tuners.lora import LoraLayer


from eval_mt_bench.common import load_questions, temperature_config
from eval_mt_bench.conversation import get_conv_template

import importlib

import transformers
from typing import Optional, Dict, Sequence

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaTokenizer,
    GenerationConfig
)

from peft import (
    PeftModel
)

# -100 is the default `ignore_index` of PyTorch's cross-entropy loss. This constant is
# not referenced in the visible part of this file.
# NOTE(review): presumably kept for parity with the training script — confirm or remove.
IGNORE_INDEX = -100
# Pad token string registered on the tokenizer when it does not define one
# (see get_accelerate_model).
DEFAULT_PAD_TOKEN = "[PAD]"

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Register special tokens and grow the model's embeddings to match.

    Rows created for newly added tokens (in both the input and the output
    embedding matrices) are initialised to the mean of the pre-existing rows.
    The model config's ``pad_token_id`` is then synchronised with the tokenizer.

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if added > 0:
        # Fetch both matrices up front so a missing embedding layer fails
        # before anything has been modified.
        matrices = (
            model.get_input_embeddings().weight.data,
            model.get_output_embeddings().weight.data,
        )
        for matrix in matrices:
            # The mean is taken over the old rows only; the new rows occupy
            # the tail of the matrix.
            matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)

    model.config.pad_token_id = tokenizer.pad_token_id

def is_ipex_available():
    """Report whether a usable Intel Extension for PyTorch (IPEX) is installed.

    Returns:
        bool: ``True`` only when ``intel_extension_for_pytorch`` is importable,
        its distribution metadata (and torch's) can be read, and its
        ``major.minor`` version equals that of the installed torch. In every
        other case ``False`` is returned; a version mismatch additionally
        emits a warning.
    """
    # Function-scope imports: a bare `import importlib` does not guarantee the
    # `metadata`/`util` submodules are loaded, and `warnings` is not imported
    # at file level.
    import importlib.metadata
    import importlib.util
    import warnings

    def get_major_and_minor_from_version(full_version):
        parsed = version.parse(full_version)
        return f"{parsed.major}.{parsed.minor}"

    # Cheapest check first: without the package there is nothing to compare.
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return False

    try:
        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
        _torch_version = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return False

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        warnings.warn(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True



def get_accelerate_model(args):
    """Load the base causal LM and tokenizer, then attach the PEFT adapter.

    Args:
        args: Parsed command-line namespace. Reads ``model_name_or_path``,
            ``bits``, ``fp16``, ``bf16``, ``cache_dir``, ``max_memory_MB`` and
            ``peft_path``.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the ``PeftModel``
        wrapping the base model, already switched to eval mode.
    """
    xpu_available = is_ipex_available() and torch.xpu.is_available()

    # Default to zero so a machine without CUDA/XPU does not hit an unbound name.
    n_gpus = 0
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
    if xpu_available:
        n_gpus = torch.xpu.device_count()

    per_device_memory = f'{args.max_memory_MB}MB'
    # With no accelerator, pass None so the loader applies its own defaults
    # instead of receiving an empty budget.
    max_memory = {i: per_device_memory for i in range(n_gpus)} or None
    device_map = "auto"

    # if we are in a distributed setting, we need to set the device map and max memory per device
    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        device_map = {'': int(local_rank)}
        # Every device gets the same budget, so no per-rank lookup is needed.
        max_memory = {'': per_device_memory}

    print(f'loading base model {args.model_name_or_path}...')
    compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))

    # Apply the XPU fallback before loading so that the announced switch
    # actually determines the dtype the weights are loaded in.
    if compute_dtype == torch.float16 and xpu_available:
        compute_dtype = torch.bfloat16
        print('Intel XPU does not support float16 yet, so switching to bfloat16')

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        load_in_4bit=args.bits == 4,
        load_in_8bit=args.bits == 8,
        cache_dir=args.cache_dir,
        device_map=device_map,
        max_memory=max_memory,
        torch_dtype=compute_dtype,
    )

    if compute_dtype == torch.float16 and args.bits == 4:
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            print('=' * 80)
            print('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            print('=' * 80)

    model.model_parallel = True
    model.is_parallelizable = True

    # NOTE(review): with --fp16 this records float32 in the config although the
    # weights were requested in float16 — kept as-is; confirm it is intended.
    model.config.torch_dtype = (torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))

    # Tokenizer (loaded once, with the options used for generation)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        padding_side="right",
        use_fast=False,
        add_eos_token=True,
        add_bos_token=True,
    )
    # Use the public attribute rather than the private `_pad_token` field.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    print("Loading adapters from checkpoint.")
    model = PeftModel.from_pretrained(model, args.peft_path, is_trainable=False)
    model.eval()

    return model, tokenizer


@torch.inference_mode()
def get_model_answers(
    args,
    model_id,
    questions,
    answer_file,
    max_new_token,
    num_choices
):
    """Generate answers for ``questions`` and append them to ``answer_file``.

    The model and tokenizer are loaded inside this function via
    ``get_accelerate_model(args)``. For every question, ``num_choices``
    independent multi-turn conversations are generated (seeded with the choice
    index) and one JSON line is appended per question.

    Args:
        args: Namespace forwarded to ``get_accelerate_model``.
        model_id: Name passed to ``get_conv_template`` and stored in the output.
        questions: Iterable of dicts with ``question_id``, ``category`` and
            ``turns`` keys.
        answer_file: Path of the JSON-lines output file (``~`` is expanded).
        max_new_token: Maximum number of tokens generated per turn.
        num_choices: Number of completions to generate per question.
    """
    model, tokenizer = get_accelerate_model(args)
    model.eval()

    # Resolve the path once so the directory that is created and the file that
    # is opened agree; a bare filename has no directory component to create.
    answer_path = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_path)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    for question in tqdm(questions):
        if question["category"] in temperature_config:
            temperature = temperature_config[question["category"]]
        else:
            temperature = 0.7

        # A (near-)zero temperature means greedy decoding; only pass the
        # temperature when sampling is actually enabled.
        do_sample = temperature >= 1e-4
        generation_kwargs = {"do_sample": do_sample, "max_new_tokens": max_new_token}
        if do_sample:
            generation_kwargs["temperature"] = temperature
        generation_config = GenerationConfig(**generation_kwargs)

        choices = []
        for i in range(num_choices):
            torch.manual_seed(i)
            conv = get_conv_template(model_id)
            turns = []
            for qs in question["turns"]:
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                # Follow the model's device instead of assuming CUDA.
                inputs = tokenizer(
                    [prompt], return_tensors="pt", add_special_tokens=False
                ).to(model.device)
                prompt_len = inputs.input_ids.shape[1]

                # some models may error out when generating long outputs
                try:
                    output_ids = model.generate(
                        **inputs,
                        generation_config=generation_config,
                    )
                except RuntimeError as e:
                    print("ERROR question ID: ", question["question_id"], repr(e))
                    output = "ERROR"
                else:
                    if model.config.is_encoder_decoder:
                        output_ids = output_ids[0]
                    else:
                        # Decoder-only models echo the prompt; keep only the continuation.
                        output_ids = output_ids[0][prompt_len:]

                    # be consistent with the template's stop_token_ids
                    if conv.stop_token_ids:
                        stop_ids = set(conv.stop_token_ids)
                        for pos, token_id in enumerate(output_ids.tolist()):
                            if token_id in stop_ids:
                                output_ids = output_ids[:pos]
                                break

                    output = tokenizer.decode(output_ids)

                    # Cut at the earliest stop string. A match at position 0 is
                    # deliberately ignored (`> 0`), as in the original logic.
                    if conv.stop_str and isinstance(conv.stop_str, list):
                        stop_positions = [
                            pos
                            for pos in (output.find(s) for s in conv.stop_str)
                            if pos > 0
                        ]
                        if stop_positions:
                            output = output[: min(stop_positions)]
                    elif conv.stop_str and output.find(conv.stop_str) > 0:
                        output = output[: output.find(conv.stop_str)]

                    # Strip any special tokens left in the decoded text.
                    for special_token in tokenizer.special_tokens_map.values():
                        if isinstance(special_token, list):
                            for special_tok in special_token:
                                output = output.replace(special_tok, "")
                        else:
                            output = output.replace(special_token, "")
                    output = output.strip()

                    if conv.name == "xgen" and output.startswith("Assistant:"):
                        output = output.replace("Assistant:", "", 1).strip()

                conv.update_last_message(output)
                turns.append(output)

            choices.append({"index": i, "turns": turns})

        # Dump answers
        with open(answer_path, "a") as fout:
            ans_json = {
                "question_id": question["question_id"],
                "answer_id": shortuuid.uuid(),
                "model_id": model_id,
                "choices": choices,
                "tstamp": time.time(),
            }
            fout.write(json.dumps(ans_json) + "\n")


def reorg_answer_file(answer_file):
    """Sort by question id and de-duplication

    Rewrites ``answer_file`` (JSON lines) in place, keeping the last record
    seen for each ``question_id`` and ordering records by that id. Blank lines
    are skipped, and every record is written with exactly one trailing newline
    so a final line lacking one cannot be merged into its neighbour.

    Args:
        answer_file: Path to the JSON-lines answer file.
    """
    answers = {}
    with open(answer_file, "r") as fin:
        for line in fin:
            record = line.rstrip("\r\n")
            if not record.strip():
                continue
            # Later duplicates overwrite earlier ones.
            answers[json.loads(record)["question_id"]] = record

    with open(answer_file, "w") as fout:
        for qid in sorted(answers):
            fout.write(answers[qid] + "\n")


def run_eval(
    args,
    model_id,
    question_file,
    question_begin,
    question_end,
    answer_file,
    max_new_token,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
):
    """Load questions, shard them across workers and generate answers.

    When ``num_gpus_total // num_gpus_per_model`` exceeds one, each shard is
    dispatched as a Ray remote task; otherwise ``get_model_answers`` is called
    directly in this process.

    Args:
        args: Namespace forwarded to ``get_model_answers``.
        model_id: Model/conversation-template identifier.
        question_file: Path handed to ``load_questions``.
        question_begin: Begin index handed to ``load_questions``.
        question_end: End index handed to ``load_questions``.
        answer_file: Output JSON-lines file.
        max_new_token: Maximum number of generated tokens per turn.
        num_choices: Number of completions per question.
        num_gpus_per_model: GPUs reserved for each worker.
        num_gpus_total: Total GPUs available.
    """
    questions = load_questions(question_file, question_begin, question_end)
    # random shuffle the questions to balance the loading
    random.shuffle(questions)

    # Split the question file into `num_gpus` files
    assert num_gpus_total % num_gpus_per_model == 0
    num_workers = num_gpus_total // num_gpus_per_model
    use_ray = num_workers > 1

    if use_ray:
        # Imported lazily: ray is only required for multi-worker runs and is
        # not imported at file level.
        import ray

        get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
            get_model_answers
        ).remote
    else:
        get_answers_func = get_model_answers

    # Clamp to at least 1: with fewer questions than workers the floor
    # division yields 0, which is an invalid step for range().
    chunk_size = max(1, len(questions) // max(1, num_workers))
    ans_handles = []
    for start in range(0, len(questions), chunk_size):
        ans_handles.append(
            get_answers_func(
                args,
                model_id,
                questions[start : start + chunk_size],
                answer_file,
                max_new_token,
                num_choices
            )
        )

    if use_ray:
        # Block until every remote shard has finished.
        ray.get(ans_handles)

if __name__ == "__main__":
    # Script entry point: load the adapted model, generate one response for
    # each built-in sample question and save all responses as a JSON list
    # inside the adapter directory.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn any non-empty string, including "False",
        into True; this converter interprets the text instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='huggyllama/llama-7b')
    parser.add_argument('--bits', type=int, default=32)
    parser.add_argument('--fp16', type=_str2bool, default=False)
    parser.add_argument('--bf16', type=_str2bool, default=False)
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('--trust_remote_code', type=_str2bool, default=False)
    parser.add_argument('--use_auth_token', type=_str2bool, default=False)
    parser.add_argument('--max_memory_MB', type=int, default=80000)

    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    # Both spellings were previously registered as separate options sharing the
    # dest `max_new_token`; argparse applied the first default (1024), so the
    # second default of 64 never took effect. One option with two spellings
    # keeps the effective behaviour without the ambiguity.
    parser.add_argument(
        "--max-new-token",
        "--max_new_token",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="Maxmum GPU memory used for model weights per GPU.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
        default=None,
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="main",
        help="The model revision to load.",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="llama-7b",
        help="The model id.",
    )

    parser.add_argument(
        "--peft_path",
        type=str,
        default="peft/peft",
        help="The path to the peft model.",
    )

    # Instruction-style prompt wrapped around every sample question.
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{user_question}\n\n ### Response: "
    )

    arg = parser.parse_args()
    model, tokenizer = get_accelerate_model(arg)

    user_question_list = ["What is Einstein's theory of relativity?",
                          "Explain the use of word embeddings in Natural Language Processing.",
                          "What does DNA stand for?",
                          "What do you think about ChatGPT?",
                          "How do I build a PC?",
                          "can you cook an egg only using durect sunlight in any place of our solar system?",
                          "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?",
                          "Implement a queue data structure using two stacks in Python.",
                          "How can you determine if a person is genuinely interested in a conversation or simply being polite?",
                          "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.",
                          "What if the Black Death had not occurred in the 14th century?",
                          "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.",
                          "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2)."]

    def generate(model, user_question, max_new_tokens=arg.max_new_token, top_p=0.9, temperature=0.7):
        """Sample one response for ``user_question``, print it and return it.

        The decoded text covers the whole generated sequence, so it includes
        the prompt as well as the response.
        """
        # Follow the model's device instead of assuming CUDA.
        inputs = tokenizer(
            prompt.format(user_question=user_question), return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            **inputs,
            generation_config=GenerationConfig(
                do_sample=True,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                temperature=temperature,
            )
        )

        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(text)
        return text

    save_path = join(arg.peft_path, "output_1000.json")

    save_file = [generate(model, user_question) for user_question in user_question_list]

    with open(save_path, "w") as fout:
        fout.write(json.dumps(save_file))
