import os
import sys

import fire
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, CodeLlamaTokenizer, LogitsProcessorList

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

import json


# Descriptive labels for this run; not referenced elsewhere in this script.
code_type = 'decompile'
model_type = 'CodeLlama-34b-Instruct'

# Pick the inference device once at import time: CUDA if available, otherwise
# CPU; Apple-Silicon MPS overrides either choice whenever it reports available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:
    # Best effort: presumably guards PyTorch builds without torch.backends.mps.
    # `Exception` rather than a bare except so Ctrl-C / SystemExit still propagate.
    pass

def modify(orig, marker='MASK'):
    """Drop every line that precedes the first line containing ``marker``.

    Args:
        orig: Newline-separated text (in this script, the ``replaced`` field
            of an input record).
        marker: Substring to search for; defaults to ``'MASK'``.

    Returns:
        The text starting at the first line that contains ``marker``, or
        ``orig`` unchanged when no line contains it.
    """
    lines = orig.split('\n')
    for start, line in enumerate(lines):
        if marker in line:
            return '\n'.join(lines[start:])
    return orig


def _load_model(base_model, lora_weights, load_8bit):
    """Load the Llama base model for the module-level ``device`` and wrap it with a LoRA adapter.

    Args:
        base_model: Hugging Face model id or local path of the base checkpoint.
        lora_weights: Path (or hub id) of the PEFT adapter to apply.
        load_8bit: Forwarded as ``load_in_8bit`` on CUDA; ignored on MPS/CPU.

    Returns:
        The model returned by ``PeftModel.from_pretrained``.
    """
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            cache_dir='models/',
        )
        return PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )

    if device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
            cache_dir='models/',
        )
        return PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )

    # CPU fallback: no explicit torch_dtype; low_cpu_mem_usage limits peak RAM while loading.
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        device_map={"": device},
        low_cpu_mem_usage=True,
        cache_dir='models/',
    )
    return PeftModel.from_pretrained(
        model,
        lora_weights,
        device_map={"": device},
    )


def main(
    load_8bit: bool = False,
    base_model: str = "codellama/CodeLlama-34b-Instruct-hf",
    lora_weights: str = "finetune/",
    prompt_template: str = "codellama",  # Template name passed to Prompter.
    server_name: str = "0.0.0.0",  # Unused; kept so existing command lines still work.
    share_gradio: bool = False,  # Unused; kept so existing command lines still work.
    arch: str = "",
    opt: str = "",
    max_new_tokens: int = 64,
):
    """Run the LoRA-tuned model over a JSONL file and write one result per line.

    Reads ``var/{arch}_{opt}_var_replaced.jsonl``. For each record, trims the
    ``replaced`` field with ``modify``, generates a completion from the record's
    ``instruction`` plus that trimmed input, and appends
    ``{"final": <generated>, "output": <record["output"]>}`` as one JSON line to
    ``output/{arch}_{opt}_result.jsonl``.

    Args:
        load_8bit: Load the base model in 8-bit (CUDA only); also skips the
            ``model.half()`` cast.
        base_model: Base checkpoint id/path; falls back to ``$BASE_MODEL`` when empty.
        lora_weights: Root directory of LoRA adapters; ``arch`` and ``opt`` are
            appended as sub-directories.
        prompt_template: Template name passed to ``Prompter``.
        server_name: Unused (leftover Gradio option).
        share_gradio: Unused (leftover Gradio option).
        arch: Dataset selector used in the adapter, input and output paths.
        opt: Dataset selector used in the adapter, input and output paths.
        max_new_tokens: Generation length limit per record (default 64, as before).

    Raises:
        ValueError: If no base model is given and ``$BASE_MODEL`` is unset or empty.
    """
    lora_weights = os.path.join(lora_weights, arch, opt)
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
        )

    prompter = Prompter(prompt_template)
    tokenizer = CodeLlamaTokenizer.from_pretrained(base_model, add_prefix_space=True)

    model = _load_model(base_model, lora_weights, load_8bit)

    # unwind broken decapoda-research config: hard-code pad=0 (unk), bos=1, eos=2.
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit and device != "cpu":
        # fp16 cast (kept from the original); skipped on CPU, where half-precision
        # matmuls are often unsupported or very slow.
        model.half()

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instructionList,
        inputList=None,
        temperature=0.2,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=64,
        stream_output=False,
        **kwargs,
    ):
        """Generate a response for one instruction/input pair.

        NOTE: ``temperature``, ``top_p``, ``top_k`` and ``stream_output`` are
        accepted but unused. They are not forwarded to ``GenerationConfig``,
        so decoding is controlled only by ``num_beams`` and ``**kwargs``.
        """
        prompt = prompter.generate_prompt(instructionList, inputList)
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)
        generation_config = GenerationConfig(num_beams=num_beams, **kwargs)

        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        # The decoded text includes the prompt; get_response presumably strips it (llama2 format).
        output = tokenizer.decode(generation_output.sequences[0])
        return prompter.get_response(output)

    input_path = f"var/{arch}_{opt}_var_replaced.jsonl"
    output_path = f"output/{arch}_{opt}_result.jsonl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(input_path, 'r', encoding="utf-8") as fp, \
            open(output_path, 'w', encoding="utf-8") as fout:
        for i, line in enumerate(fp):
            if not line.strip():
                continue  # tolerate blank lines in the JSONL input
            record = json.loads(line)
            print(i)
            res = evaluate(
                record['instruction'],
                modify(record["replaced"]),
                max_new_tokens=max_new_tokens,
            )
            json.dump({"final": res, "output": record['output']}, fout)
            fout.write('\n')
            # Flush per record so partial results survive an interrupted run.
            fout.flush()


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)
