import os
import sys

import fire
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, CodeLlamaTokenizer, RobertaTokenizer, T5ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, Gemma3ForCausalLM

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

import json

# TorchDynamo configuration, applied at import time before any model is built.
import torch._dynamo
# Ask Dynamo to suppress its own errors instead of raising them.
torch._dynamo.config.suppress_errors = True
# NOTE(review): called without a function argument and with the return value
# discarded -- this likely has no global effect on its own; confirm against the
# PyTorch docs. The config flag on the next line is the global switch.
torch._dynamo.disable()
# Turn Dynamo off globally. Presumably this makes the torch.compile() call in
# main() a pass-through -- TODO confirm.
torch._dynamo.config.disable = True

# Experiment labels. They are not referenced in the visible part of this file.
code_type = 'decompile'
model_type = 'Llama-2-13b-chat-hf'  # alternative: 'CodeLlama-34b-Instruct'

# Default device: CUDA when PyTorch reports it, CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Best-effort MPS probe: any failure while querying the MPS backend is
# deliberately ignored and the device chosen above is kept.
try:
    device = "mps" if torch.backends.mps.is_available() else device
except:  # noqa: E722
    pass


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "",
    prompt_template: str = "gemma",  # Template name handed to Prompter.
    server_name: str = "0.0.0.0",  # Not used by this function.
    share_gradio: bool = False,  # Not used by this function.
    arch: str = "",
    opt: str = "",
):
    """Infer function names for a JSON file of assembly listings.

    Loads ``base_model`` as a ``Gemma3ForCausalLM``, applies the LoRA adapter
    at ``lora_weights``, then prompts the model once per entry of
    ``inputs/{arch}_{opt}_input.json`` and writes the responses to
    ``asm_finetune_output/gemma_{arch}_{opt}_asm.json``.

    Args:
        load_8bit: Forwarded as ``load_in_8bit`` when loading the base model.
            When False the model is additionally cast to bfloat16.
        base_model: Model name or path. Falls back to the ``BASE_MODEL``
            environment variable when empty.
        lora_weights: Adapter name or path passed to
            ``PeftModel.from_pretrained``.
        prompt_template: Template name passed to ``Prompter``.
        server_name: Accepted but not used by this function.
        share_gradio: Accepted but not used by this function.
        arch: First component of the input/output file names.
        opt: Second component of the input/output file names.

    Raises:
        ValueError: If no base model is given via argument or environment.
        RuntimeError: If the detected device is not CUDA; this script only
            contains a CUDA loading path.
    """
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
        )

    # Only a CUDA loading path exists. Fail with a clear message instead of
    # hitting a NameError on ``model`` further down.
    if device != "cuda":
        raise RuntimeError(
            f"This script only supports loading the model on CUDA, "
            f"but the detected device is {device!r}."
        )

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model = Gemma3ForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
        cache_dir='models/',
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model = PeftModel.from_pretrained(
        model,
        lora_weights,
        torch_dtype=torch.float16,
    )

    # Weights were requested as float16 above; without 8-bit loading the whole
    # model is cast to bfloat16 here.
    if not load_8bit:
        model.bfloat16()

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    # Prompts whose token count reaches this limit are skipped.
    max_prompt_tokens = 1024

    def evaluate(
        instructionList,
        inputList=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=256,
        stream_output=False,
        **kwargs,
    ):
        """Generate one response for an instruction/input pair.

        ``temperature``, ``top_p``, ``top_k`` and ``stream_output`` are
        accepted but not used; only ``num_beams`` and ``**kwargs`` reach the
        ``GenerationConfig``.

        Returns:
            The text extracted by ``prompter.get_response`` from the decoded
            sequence, or ``""`` when the tokenized prompt has
            ``max_prompt_tokens`` tokens or more.
        """
        prompt = prompter.generate_prompt(instructionList, inputList)
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        # Over-long prompts are not sent to the model at all.
        if len(input_ids[0]) >= max_prompt_tokens:
            return ""

        generation_config = GenerationConfig(
            num_beams=num_beams,
            **kwargs,
        )

        # Without streaming
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                generation_config=generation_config,
                repetition_penalty=1.1,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        sequence = generation_output.sequences[0]
        decoded = tokenizer.decode(sequence)
        return prompter.get_response(decoded)  # llama2

    fileinput = "inputs/" + arch + "_" + opt + "_input.json"
    # Context manager so the input handle is closed once parsing is done.
    with open(fileinput, 'r', encoding="utf-8") as fp:
        funcs = json.load(fp)

    # Create the output directory before the (long) inference loop so a missing
    # directory cannot discard the results at the very end.
    output_dir = "asm_finetune_output"
    os.makedirs(output_dir, exist_ok=True)
    fileoutput = output_dir + "/gemma_" + arch + "_" + opt + "_asm.json"

    # The instruction is identical for every function, so build it once.
    instruction = "Let's assume you are a programmer. An assembly code is given, and the name of the function is unknown. Could you infer the name of the function? Please give the function name as follows: e.g., \"FUNC: printf\"."

    output = {}
    for k, v in funcs.items():
        inputs = "Now here is an assembly code: \n" + v["assembly"]
        # Progress line: entry key and whitespace-separated word count.
        print(k, len(inputs.split()))
        output[k] = evaluate(instruction, inputs)

    with open(fileoutput, 'w', encoding="utf-8") as fout:
        json.dump(output, fout, indent=2)

# Script entry point: python-fire exposes main()'s parameters as command-line
# flags (e.g. --base_model, --lora_weights, --arch, --opt).
if __name__ == "__main__":
    fire.Fire(main)
