import os
import sys
import json
import fire
import gradio as gr
import wandb

from tqdm import tqdm

import torch

import transformers
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

from utils.prompter import Prompter


# Module-level default device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The MPS backend, when available, overrides the choice above.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # This torch build has no `torch.backends.mps`; keep the CUDA/CPU choice.
    # Only AttributeError is caught so that KeyboardInterrupt/SystemExit and
    # genuine errors are no longer silently swallowed.
    pass


def _detect_device() -> str:
    """Return the torch device string to run on: "mps", "cuda" or "cpu"."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except AttributeError:
        # This torch build has no `torch.backends.mps`; keep the CUDA/CPU choice.
        pass
    return device


def _read_json(path: str):
    """Load and return the JSON document stored at *path*."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _write_json(obj, path: str) -> None:
    """Serialize *obj* as JSON to *path*; the file is closed before returning."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle)


def _load_model(base_model, lora_weights, device, load_8bit, cache_path):
    """Load the base LLaMA model for *device* and wrap it with the LoRA adapter."""
    if device == "cuda":
        # NOTE(review): `cache_path` is only honoured on this branch; the MPS
        # and CPU branches use the library's default cache location.
        base = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            cache_dir=cache_path,
        )
        return PeftModel.from_pretrained(
            base,
            lora_weights,
            torch_dtype=torch.float16,
        )

    if device == "mps":
        base = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        return PeftModel.from_pretrained(
            base,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )

    base = LlamaForCausalLM.from_pretrained(
        base_model, device_map={"": device}, low_cpu_mem_usage=True
    )
    return PeftModel.from_pretrained(
        base,
        lora_weights,
        device_map={"": device},
    )


def main(
    load_8bit: bool = True,  # set load_8bit=True, if you want to generate data samples that are similar to train data, load_8bit can provide more sample diversity
    base_model: str = 'decapoda-research/llama-7b-hf',
    data_path: str = "../processed/mcwiki_toy.json",
    cache_path: str = "/data/home/",
    result_path: str = "../results/",
    result_file: str = "mcwiki_toy_result.json",
    lora_weights: str = "/data/home/lora-alpaca",
    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
    num_beams: int = 1,
    max_new_tokens: int = 256,
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.0.0.0'.
    share_gradio: bool = False,
):
    """Generate a model response for every record of a JSON dataset.

    Loads the dataset at ``data_path``, the tokenizer and base model named by
    ``base_model`` and the LoRA adapter at ``lora_weights``, generates one
    response per record, and writes the collected results as JSON to
    ``os.path.join(result_path, result_file)``. The results file is rewritten
    every 10 records and once more at the end.

    Args:
        load_8bit: Passed as ``load_in_8bit`` when loading the base model on
            CUDA; not used on the MPS/CPU branches.
        base_model: Model identifier or path; falls back to the ``BASE_MODEL``
            environment variable when empty.
        data_path: JSON file with records that each provide the keys
            ``instruction``, ``input`` and ``output``.
        cache_path: Cache directory for the base model (CUDA branch only).
        result_path: Directory for the results file; created if missing.
        result_file: File name of the results JSON.
        lora_weights: Location of the LoRA adapter weights.
        prompt_template: Template name handed to ``Prompter``.
        num_beams: Beam count forwarded to generation.
        max_new_tokens: Generation length limit forwarded to generation.
        server_name: Not used by this function; kept for CLI compatibility.
        share_gradio: Not used by this function; kept for CLI compatibility.

    Raises:
        AssertionError: If no base model is given and ``BASE_MODEL`` is unset.
    """
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    device = _detect_device()

    # Dataset is read once, up front, so a bad path fails before model loading.
    test_json = _read_json(data_path)
    prompter = Prompter(prompt_template)

    # tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    # NOTE(review): pad, bos, eos and unk are all forced to id 0 here, so the
    # pad and eos ids coincide -- confirm this is intended.
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 0
    tokenizer.eos_token_id = 0
    tokenizer.unk_token_id = 0
    tokenizer.padding_side = "left"  # Allow batched inference

    model = _load_model(base_model, lora_weights, device, load_8bit, cache_path)

    # unwind broken decapoda-research config (same ids as the tokenizer above)
    model.config.pad_token_id = 0  # unk
    model.config.unk_token_id = 0
    model.config.bos_token_id = 0
    model.config.eos_token_id = 0

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input_text=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate and return the response for a single instruction/input pair.

        Extra keyword arguments are forwarded to ``GenerationConfig``.
        """
        prompt = prompter.generate_prompt(instruction, input_text)
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        # Mirror the special-token ids set on the model config.
        generation_config.pad_token_id = model.config.pad_token_id
        generation_config.unk_token_id = model.config.unk_token_id
        generation_config.bos_token_id = model.config.bos_token_id
        generation_config.eos_token_id = model.config.eos_token_id

        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        decoded = tokenizer.decode(generation_output.sequences[0])
        return prompter.get_response(decoded)

    # Make sure the output directory exists before the first checkpoint write.
    if result_path:
        os.makedirs(result_path, exist_ok=True)
    output_file = os.path.join(result_path, result_file)

    result_list = []
    for q_idx, sample in enumerate(tqdm(test_json)):
        result_list.append(
            {
                "Instruction": sample['instruction'],
                "Input": sample['input'],
                # Key spelling kept as-is: it is part of the output format.
                "GroudTruth": sample['output'],
                "Response": evaluate(
                    sample['instruction'],
                    sample['input'],
                    num_beams=num_beams,
                    max_new_tokens=max_new_tokens,
                ),
            }
        )

        # Periodic checkpoint so partial results survive an interrupted run.
        if q_idx % 10 == 0:
            _write_json(result_list, output_file)
    _write_json(result_list, output_file)


# Script entry point: hand `main` to python-fire so it can be driven from the
# command line (e.g. --base_model=...).
if __name__ == "__main__":
    fire.Fire(main)