import os
import wandb
import torch
import sys
import yaml

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)

from utils.set_random_seed import set_random_seed
from utils.json_reader import json_loader,jsonl_loader
from utils.wandb_tracker import setup_wandb
from peft import LoraConfig
from datasets import load_dataset
from trl import SFTTrainer
from reader.llama_prompt_generator import prompt_generator, finetune_prompt_generator

# When use_deepspeed is enabled, launch through the DeepSpeed launcher, e.g. "deepspeed --num_gpus=<N> finetune.py".
def _load_config(path_to_yml: str) -> dict:
    """Read the YAML fine-tuning configuration at *path_to_yml* into a dict."""
    # FullLoader is kept for compatibility with existing configs; the config
    # file is assumed to be trusted, local input.
    with open(path_to_yml, "r", encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _visible_devices_value(devices) -> str:
    """Format the ``visible_devices`` config entry for ``CUDA_VISIBLE_DEVICES``.

    Strings and ints are passed through ``str`` unchanged; a YAML list such as
    ``[0, 1]`` is joined into ``"0,1"`` rather than the invalid ``"[0, 1]"``.
    """
    if isinstance(devices, (list, tuple)):
        return ",".join(str(d) for d in devices)
    return str(devices)


def _load_base_model(model_name: str, use_deepspeed: bool):
    """Load *model_name* as a bfloat16 causal LM with ``use_cache=False``.

    Without DeepSpeed the weights are dispatched with ``device_map="auto"``;
    with DeepSpeed, ``device_map`` is omitted so the Trainer/DeepSpeed
    integration handles placement.
    """
    load_kwargs = {"use_cache": False, "torch_dtype": torch.bfloat16}
    if not use_deepspeed:
        load_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    # Llama-style config attribute; presumably set to 1 to avoid the slower
    # pretraining tensor-parallel linear path — confirm for non-Llama models.
    model.config.pretraining_tp = 1
    return model


def _load_train_dataset(dataset_cfg: dict):
    """Load the JSON/JSONL training split, optionally shuffled with seed 42."""
    train_dataset = load_dataset("json", data_files=dataset_cfg["path_to_dataset"], split="train")
    if dataset_cfg["shuffle"]:
        train_dataset = train_dataset.shuffle(seed=42)
    return train_dataset


def _load_tokenizer(config: dict):
    """Load the base model's tokenizer, padding on the right with the EOS token."""
    # NOTE(review): `max_length` does not appear to be a recognised tokenizer
    # init option (`model_max_length`, or SFTTrainer's max sequence length,
    # may be what is intended) — confirm whether sequences are truncated.
    tokenizer = AutoTokenizer.from_pretrained(
        config["base_model"],
        trust_remote_code=True,
        max_length=config["params"]["max_length"],
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
    return tokenizer


def _print_chat_template_preview(tokenizer, train_dataset, index: int = 4) -> None:
    """Print one training example rendered through the tokenizer's chat template.

    The row at *index* is used when it exists, otherwise the last row, so
    small datasets do not raise ``IndexError`` just for a preview.
    """
    if len(train_dataset) == 0:
        print("Dataset is empty; skipping chat-template preview.")
        return
    sample = train_dataset[min(index, len(train_dataset) - 1)]["messages"]
    print(tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False))


def _build_peft_config(peft_cfg: dict) -> LoraConfig:
    """Build the LoRA adapter configuration from the ``peft`` config section."""
    return LoraConfig(
        lora_alpha=peft_cfg["lora_alpha"],
        lora_dropout=peft_cfg["lora_dropout"],
        r=peft_cfg["lora_r"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=peft_cfg["target_modules"],
        modules_to_save=peft_cfg["modules_to_save"],
    )


def _build_training_arguments(config: dict) -> TrainingArguments:
    """Build the Trainer hyper-parameters from the ``params`` config section."""
    params = config["params"]
    # Only hand the DeepSpeed config to the Trainer when DeepSpeed is enabled;
    # otherwise it would be activated on top of a device_map="auto" model.
    ds_config = config["path_to_ds_config"] if config["use_deepspeed"] else None
    return TrainingArguments(
        output_dir=config["output_dir"],
        num_train_epochs=params["num_train_epochs"],
        per_device_train_batch_size=params["per_device_train_batch_size"],
        gradient_accumulation_steps=params["gradient_accumulation_steps"],
        optim=params["optim"],
        save_steps=params["save_steps"],
        logging_steps=params["logging_steps"],
        learning_rate=params["learning_rate"],
        weight_decay=params["weight_decay"],
        bf16=params["bf16"],
        max_grad_norm=params["max_grad_norm"],
        max_steps=params["max_steps"],
        warmup_ratio=params["warmup_ratio"],
        warmup_steps=params["warmup_steps"],
        group_by_length=params["group_by_length"],
        lr_scheduler_type=params["lr_scheduler_type"],
        deepspeed=ds_config,
        report_to="wandb",
    )


def _trainer_tokenizer_kwargs(tokenizer) -> dict:
    """Return the SFTTrainer keyword argument that carries *tokenizer*.

    Newer TRL releases name it ``processing_class``, older ones ``tokenizer``.
    If neither is present, nothing is passed (the trainer falls back to its
    own default tokenizer, as before).
    """
    import inspect

    params = inspect.signature(SFTTrainer.__init__).parameters
    for key in ("processing_class", "tokenizer"):
        if key in params:
            return {key: tokenizer}
    return {}


def finetune(path_to_yml: str):
    """Fine-tune ``base_model`` with a LoRA adapter using TRL's SFTTrainer.

    Args:
        path_to_yml: Path to the YAML config. Keys read: ``visible_devices``,
            ``wandb.{project_name,run_name}``, ``base_model``, ``new_model``,
            ``use_deepspeed``, ``path_to_ds_config`` (only when
            ``use_deepspeed`` is true), ``output_dir``,
            ``dataset.{path_to_dataset,shuffle}``, ``peft.*`` and ``params.*``.
            Dataset rows are expected to carry a ``messages`` list usable with
            ``tokenizer.apply_chat_template``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES``, calls ``setup_wandb``
    (presumably starting a wandb run), prints dataset info and a rendered
    sample, trains, and saves ``trainer.model`` to ``config["new_model"]``.

    Raises:
        KeyError: if a required config key is missing.
    """
    config = _load_config(path_to_yml)

    # Set the visible GPUs before anything (including the seeding helper,
    # which may touch torch.cuda) can initialise CUDA in this process.
    # NOTE(review): the deepspeed launcher also sets CUDA_VISIBLE_DEVICES;
    # overriding it here may conflict under DeepSpeed — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = _visible_devices_value(config["visible_devices"])
    set_random_seed(42)
    setup_wandb(project_name=config["wandb"]["project_name"], run_name=config["wandb"]["run_name"])

    model_name = config["base_model"]
    new_model = config["new_model"]
    model = _load_base_model(model_name, config["use_deepspeed"])

    train_dataset = _load_train_dataset(config["dataset"])
    print(f"Using a sample size of {len(train_dataset)} for fine-tuning.")
    print(train_dataset)

    tokenizer = _load_tokenizer(config)
    _print_chat_template_preview(tokenizer, train_dataset)

    peft_config = _build_peft_config(config["peft"])
    training_arguments = _build_training_arguments(config)

    # Pass the customised tokenizer explicitly; otherwise SFTTrainer loads its
    # own and the pad-token / padding-side settings above are lost.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        peft_config=peft_config,
        args=training_arguments,
        **_trainer_tokenizer_kwargs(tokenizer),
    )

    trainer.train()

    # With peft_config set, trainer.model is presumably the PEFT-wrapped
    # model, so this writes the adapter rather than merged weights.
    trainer.model.save_pretrained(new_model)

if __name__ == "__main__":
    import argparse

    # Config used when no path is given on the command line (previous behavior).
    _DEFAULT_CONFIG = "configs/finetune/v6.2_mistral_qa.yml"

    parser = argparse.ArgumentParser(description="LoRA fine-tuning driven by a YAML config.")
    parser.add_argument(
        "config",
        nargs="?",
        default=_DEFAULT_CONFIG,
        help="path to the YAML config (default: %(default)s)",
    )
    # parse_known_args tolerates extra flags injected by launchers, such as
    # the --local_rank argument the deepspeed launcher can append.
    args, _unknown = parser.parse_known_args()
    finetune(args.config)