import os
import sys

import fire
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, CodeLlamaTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

import json
import re

code_type = 'decompile'
#model_type = 'CodeLlama-34b-Instruct'
model_type = 'Llama-2-13b-chat-hf'

# Torch device used for inference: CUDA when available, otherwise CPU.
# Apple-silicon MPS, when reported available, overrides that choice.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older torch builds have no ``torch.backends.mps``; keep the device
    # chosen above. Catching only this (instead of a bare ``except``) avoids
    # swallowing KeyboardInterrupt/SystemExit and unrelated errors.
    pass

# A target line starts with VAR/FUNC/TYPE, then one or more digits, then ':'.
_TARGET_LINE_RE = re.compile(r'^(?:VAR|FUNC|TYPE)[0-9]+:')


def parse(result):
    """Parse model output lines such as ``VAR1: sum`` into a dict.

    Each line is stripped of surrounding whitespace first. A line counts only
    when it starts with ``VAR``, ``FUNC`` or ``TYPE``, followed by one or more
    digits and a colon. Such a line maps its target name (e.g. ``"VAR1"``) to
    the whitespace-split tokens after the first colon. Colons inside the
    value, e.g. ``std::string``, are kept. Other lines are ignored, and a
    repeated target keeps its last value.

    Args:
        result: Raw text produced by the model.

    Returns:
        dict mapping target name -> list of answer tokens.
    """
    res = {}
    # splitlines() also handles '\r\n' endings.
    for line in result.splitlines():
        line = line.strip()
        if _TARGET_LINE_RE.match(line) is None:
            continue
        target, _, answer = line.partition(':')
        res[target] = answer.split()
    return res

def _load_model(base_model, load_8bit):
    """Load ``base_model`` as a causal LM for the module-level ``device``.

    On CUDA the weights are loaded in float16 (8-bit when ``load_8bit``)
    with ``device_map="auto"``. On MPS (float16) and CPU the whole model is
    mapped to that single device, and ``load_8bit`` is ignored.
    """
    common = dict(cache_dir='models/', trust_remote_code=True)
    if device == "cuda":
        return AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            **common,
        )
    if device == "mps":
        return AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=torch.float16,
            device_map={"": device},
            **common,
        )
    return AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map={"": device},
        low_cpu_mem_usage=True,
        **common,
    )


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "tloen/alpaca-lora-7b",
    prompt_template: str = "deepseek",  # Template name passed to Prompter (e.g. "codellama").
    server_name: str = "0.0.0.0",
    share_gradio: bool = False,
    model_name: str = "",
    arch: str = "",
    opt: str = "",
    size: str = "",
    decomp: str = "",
):
    """Ask ``base_model`` to extract target values from earlier model outputs.

    For every entry of ``finetune_output/{model_name}_{arch}_{opt}.json``:
      1. The entry's text is joined.
      2. The model is asked for the value of each target name, taken from the
         keys of that entry's ``"answer"`` dict in
         ``inputs/{arch}_{opt}_input.json``.
      3. The response is stored under the same key.
    The collected responses are written to
    ``finetune_result/{model_name}_{arch}_{opt}.json``.

    Args:
        load_8bit: Load CUDA weights in 8-bit. When False, the model is cast
            to half precision on GPU backends.
        base_model: Hugging Face model id or path. Falls back to the
            ``BASE_MODEL`` environment variable.
        lora_weights: Currently unused (LoRA adapter loading is disabled).
        prompt_template: Template name passed to ``Prompter``.
        server_name, share_gradio, size, decomp: Unused; kept so existing
            command lines keep working.
        model_name, arch, opt: Components of the input/output file names.
    """
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    # Read the inputs before the (slow) model load so a missing file fails fast.
    with open(f"inputs/{arch}_{opt}_input.json", 'r', encoding='utf-8') as fp:
        funcs = json.load(fp)
    with open(f"finetune_output/{model_name}_{arch}_{opt}.json", 'r', encoding='utf-8') as fp:
        answers = json.load(fp)

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model = _load_model(base_model, load_8bit)
    if not load_8bit and device != "cpu":
        # Half precision "seems to fix bugs for some users". It is skipped on
        # CPU, where many float16 ops are unsupported or very slow.
        model.half()

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    # Prompts with this many tokens or more are skipped (evaluate returns "").
    max_prompt_tokens = 1024

    def evaluate(
        instructionList,
        inputList=None,
        temperature=0,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=256,
        stream_output=False,
        **kwargs,
    ):
        """Generate the model's answer for a single instruction/input pair.

        Despite their names, ``instructionList`` and ``inputList`` are passed
        unchanged to ``prompter.generate_prompt`` as one instruction and one
        input. ``temperature``, ``top_p``, ``top_k`` and ``stream_output``
        are accepted but currently unused. Only ``num_beams`` and ``kwargs``
        reach the ``GenerationConfig``.

        Returns:
            The text extracted by ``prompter.get_response``, or ``""`` when
            the prompt is ``max_prompt_tokens`` tokens or longer.
        """
        prompt = prompter.generate_prompt(instructionList, inputList)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        if input_ids.shape[-1] >= max_prompt_tokens:
            return ""
        attention_mask = inputs["attention_mask"].to(device)
        generation_config = GenerationConfig(num_beams=num_beams, **kwargs)

        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                generation_config=generation_config,
                repetition_penalty=1.1,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        output = tokenizer.decode(generation_output.sequences[0])
        # Llama-2-style templates: strip the prompt via the Prompter.
        # (For CodeLlama the raw ``output`` was returned instead.)
        return prompter.get_response(output)

    # NOTE(review): "prinft" in the example below looks like a typo for
    # "printf". It is left unchanged so prompts stay identical to earlier runs.
    instruction = "The text is provided and we want to extract the value of target components from it. Please print only the below format without any explanation. e.g., FUNC1: prinft\nVAR1: sum"

    output = {}
    for k, v in answers.items():
        print(k)
        inputs = "Here is the text:\n" + ''.join(v)
        inputs += "\nWhat is the value of " + ', '.join(funcs[str(k)]["answer"].keys()) + "?\n"
        output[k] = evaluate(instruction, inputs)

    fileoutput = f"finetune_result/{model_name}_{arch}_{opt}.json"
    os.makedirs(os.path.dirname(fileoutput), exist_ok=True)
    with open(fileoutput, 'w') as fout:
        json.dump(output, fout, indent=2)

if __name__ == "__main__":
    # Expose main()'s keyword parameters as command-line flags.
    fire.Fire(component=main)
