import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset
import json
import random
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer, Gemma3ForCausalLM

from utils.prompter import Prompter

# Module-level TorchDynamo configuration, applied once at import time before
# train() runs. Two config flags are set (suppress_errors, disable) and
# torch._dynamo.disable() is called once with no arguments.
# NOTE(review): train() still calls torch.compile(model) on torch >= 2 /
# non-Windows; with `config.disable = True` that call presumably becomes a
# pass-through -- confirm this is the intended combination.
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# NOTE(review): the return value is discarded; called without a function this
# looks like it only builds a decorator/context object and has no global
# effect -- the config flag on the next line appears to be what actually turns
# dynamo off. Confirm against the torch._dynamo documentation.
torch._dynamo.disable()
torch._dynamo.config.disable = True


def train(
    # model/data params
    base_model: str = "google/gemma-3-12b-it",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "finetune/gemma",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    # NOTE: list default is shared between calls; it is only read, never mutated.
    lora_target_modules: List[str] = [
        "v_proj",
        "q_proj",
    ],
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "gemma",  # The prompt template to use, will default to alpaca.
):
    """Fine-tune a Gemma 3 causal LM with LoRA adapters and save the adapter.

    The base model is loaded in 8-bit, wrapped with a PEFT LoRA adapter, and
    trained with ``transformers.Trainer`` on examples built from the local
    file ``finetune_script/whole.json``. That file is read as a JSON object
    whose values each provide ``"funcbody"`` (a string) and ``"answer"`` (a
    mapping from placeholder name to original name). Examples are shuffled
    with a fixed seed and split 80% train / 20% validation.

    Args:
        base_model: Model id or path passed to ``from_pretrained`` for both
            the model and the tokenizer.
        data_path: Echoed in the startup banner only; the dataset location is
            currently hard-coded (see above).
        output_dir: Where Trainer checkpoints and the final adapter are saved.
        batch_size: Target effective batch size; together with
            ``micro_batch_size`` (and the DDP world size) it determines the
            number of gradient-accumulation steps.
        micro_batch_size: Per-device train batch size.
        num_epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        cutoff_len: Maximum tokenized sequence length (longer prompts are
            truncated).
        val_set_size: Only its sign is used: ``> 0`` enables periodic
            evaluation and ``load_best_model_at_end``. The size of the
            validation split is always 20% of the examples.
        lora_target_modules: Module names that receive LoRA adapters.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout.
        train_on_inputs: If False, label positions covering the prompt (the
            part without the expected output) are set to -100.
        add_eos_token: Used only when ``train_on_inputs`` is False, to decide
            whether the prompt-only tokenization appends EOS (and then to
            subtract it from the masked length).
        group_by_length: Forwarded to ``TrainingArguments``.
        wandb_project: If non-empty, exported as ``WANDB_PROJECT``.
        wandb_run_name: Run name used when wandb reporting is active.
        wandb_watch: If non-empty, exported as ``WANDB_WATCH``.
        wandb_log_model: If non-empty, exported as ``WANDB_LOG_MODEL``.
        resume_from_checkpoint: Directory containing either a full
            ``pytorch_model.bin`` training checkpoint or an
            ``adapter_model.bin`` LoRA adapter to warm-start from.
        prompt_template_name: Template name handed to ``Prompter``.

    Raises:
        AssertionError: If ``base_model`` is empty.
    """
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    # Clamp to 1 so that batch_size < micro_batch_size (or a large world size
    # below) can never hand Trainer a zero accumulation count.
    gradient_accumulation_steps = max(1, batch_size // micro_batch_size)

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Under DDP each process must hold the whole model on its own GPU.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = max(
            1, gradient_accumulation_steps // world_size
        )

    # wandb is active if a project was passed or is already set in the environment.
    use_wandb = bool(wandb_project or os.environ.get("WANDB_PROJECT"))
    # Only overwrite the environment when the corresponding parameter was passed.
    if wandb_project:
        os.environ["WANDB_PROJECT"] = wandb_project
    if wandb_watch:
        os.environ["WANDB_WATCH"] = wandb_watch
    if wandb_log_model:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = Gemma3ForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        # Use the computed map (not a hard-coded "auto") so DDP ranks do not
        # each try to shard the model across every visible GPU.
        device_map=device_map,
        cache_dir="models/",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    print(model)

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Padding must use a token different from EOS.
    # NOTE(review): id 0 was chosen as "unk" for LLaMA tokenizers; confirm
    # which token id 0 is for the Gemma tokenizer.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to ``cutoff_len``) and add labels.

        EOS is appended only when requested, not already present, and there
        is still room below ``cutoff_len``. Labels start as a copy of the
        input ids.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        """Render one example through the prompter and tokenize it.

        When ``train_on_inputs`` is False, the labels for the prompt-only
        prefix are replaced with -100 so only the response contributes to
        the loss.
        """
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["input"]
            )
            tokenized_user_prompt = tokenize(
                user_prompt, add_eos_token=add_eos_token
            )
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            if add_eos_token:
                # Do not count the EOS that was appended to the prompt-only text.
                user_prompt_len -= 1

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )
        return tokenized_full_prompt

    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    # NOTE: the dataset location is hard-coded; `data_path` is not consulted.
    fileinput = "finetune_script/whole.json"
    with open(fileinput, "r", encoding="utf-8") as fp:
        funcs = json.load(fp)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # SECURITY: torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    instruction = "Let's assume you are a programmer. A decompiled C function is given, and the name of the function and the types and names of variables are changed to FUNC, VAR, and TYPE. Understand the function and infer original names of replacements without explanation. Output the result format as follows. e.g., \"FUNC1: printf\", \"VAR1: count\", \"VAR2: index\", \"TYPE1: int\", \"TYPE2: char\""

    examples = []
    for v in funcs.values():
        temp = {
            "instruction": instruction,
            "input": "Now here is a decompiled code of a function: \n" + v["funcbody"] + "\n. Please provide the original name of " + ", ".join(v["answer"].keys()) + "?",
            # One "NAME: original" line per placeholder.
            "output": "".join(
                f"{name}: {answer}\n" for name, answer in v["answer"].items()
            ),
        }
        examples.append(generate_and_tokenize_prompt(temp))

    # Fixed seed: every DDP rank (and every resumed run) must see the same
    # ordering and the same train/validation split.
    random.Random(42).shuffle(examples)
    split_idx = int(len(examples) * 0.80)
    train_data = examples[:split_idx]
    val_data = examples[split_idx:]

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            eval_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=val_set_size > 0,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # NOTE(review): depending on the transformers version, None may
            # mean "all installed integrations" rather than "no reporting";
            # confirm and consider "none" if reporting should be off.
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    # The trainer above was constructed with the un-compiled model object;
    # only the local name is rebound to the compiled wrapper here.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


# Script entry point: when the file is executed directly (not imported),
# hand train() to fire.Fire so its keyword parameters can be supplied from the
# command line (e.g. --base_model=..., as in train()'s assertion message).
if __name__ == "__main__":
    fire.Fire(train)
