import os
import wandb
import torch
import sys
import yaml

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from argparse import ArgumentParser
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)

from utils.set_random_seed import set_random_seed
from utils.wandb_tracker import setup_wandb
from peft import LoraConfig
from datasets import load_dataset
from trl import SFTTrainer
from reader.llama_prompt_generator import prompt_generator, finetune_prompt_generator
from trl import DataCollatorForCompletionOnlyLM
from torch.utils.data import DataLoader
from datasets import Dataset

def _load_config(path_to_yml: str) -> dict:
    """Read the YAML fine-tuning configuration stored at ``path_to_yml``."""
    with open(path_to_yml, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept from the original; yaml.safe_load
        # would suffice if the config only holds plain scalars/lists/dicts.
        return yaml.load(f, Loader=yaml.FullLoader)


def _load_model(config: dict):
    """Load the base causal LM named by ``config["base_model"]`` in bfloat16."""
    load_kwargs = {"use_cache": False, "torch_dtype": torch.bfloat16}
    if not config["use_deepspeed"]:
        # device_map="auto" is only requested on the non-DeepSpeed path.
        load_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(config["base_model"], **load_kwargs)
    model.config.pretraining_tp = 1
    return model


def _load_tokenizer(config: dict, model):
    """Load the tokenizer, add a ``[PAD]`` token and resize ``model`` to match."""
    tokenizer = AutoTokenizer.from_pretrained(
        config["base_model"],
        trust_remote_code=True,
        max_length=config["params"]["max_length"],
    )
    tokenizer.add_special_tokens({"pad_token": '[PAD]'})
    tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
    # The vocabulary may have grown by one token, so the embedding matrix
    # has to be resized to the new tokenizer length.
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer


def _build_training_arguments(config: dict) -> TrainingArguments:
    """Translate ``config["params"]`` into Hugging Face ``TrainingArguments``."""
    params = config["params"]
    # Only hand a DeepSpeed config to the trainer when DeepSpeed is enabled;
    # otherwise the model was loaded with device_map="auto" and must not be
    # combined with a DeepSpeed engine.
    ds_config = config.get("path_to_ds_config") if config["use_deepspeed"] else None

    return TrainingArguments(
        output_dir=config["output_dir"],
        num_train_epochs=params["num_train_epochs"],
        per_device_train_batch_size=params["per_device_train_batch_size"],
        gradient_accumulation_steps=params["gradient_accumulation_steps"],
        optim=params["optim"],
        save_steps=params["save_steps"],
        logging_steps=params["logging_steps"],
        learning_rate=params["learning_rate"],
        weight_decay=params["weight_decay"],
        bf16=params["bf16"],
        max_grad_norm=params["max_grad_norm"],
        max_steps=params["max_steps"],
        warmup_ratio=params["warmup_ratio"],
        warmup_steps=params["warmup_steps"],
        group_by_length=params["group_by_length"],
        lr_scheduler_type=params["lr_scheduler_type"],
        deepspeed=ds_config,
        report_to="wandb",
    )


def finetune(path_to_yml: str):
    """Run LoRA supervised fine-tuning as described by a YAML config file.

    The function loads the base model and tokenizer, tokenizes the ``text``
    field of a JSON dataset, trains with ``SFTTrainer`` using a
    completion-only collator keyed on the ``[INST]`` / ``[/INST]`` markers,
    and finally saves the trained model and tokenizer to
    ``config["new_model"]``.

    Args:
        path_to_yml: Path to the YAML configuration. Keys read here:
            ``visible_devices``, ``wandb.project_name``, ``wandb.run_name``,
            ``base_model``, ``new_model``, ``use_deepspeed``,
            ``path_to_ds_config`` (only when ``use_deepspeed`` is true),
            ``output_dir``, ``dataset.path_to_dataset``, ``peft.*`` and
            ``params.*``.

    Raises:
        KeyError: If a required configuration key is missing.
    """
    # Load the configuration file
    config = _load_config(path_to_yml)

    # setup the environment
    # CUDA_VISIBLE_DEVICES is exported first: it is only honoured if set
    # before CUDA is queried, and set_random_seed (defined elsewhere) may
    # touch torch.cuda -- TODO confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config["visible_devices"])
    set_random_seed(42)
    setup_wandb(
        project_name=config["wandb"]["project_name"],
        run_name=config["wandb"]["run_name"],
    )

    # Read the output location up front so a missing key fails before training.
    new_model = config["new_model"]

    # load the base model and the tokenizer
    model = _load_model(config)
    tokenizer = _load_tokenizer(config, model)

    # load dataset
    train_dataset = load_dataset(
        "json",
        data_files=config["dataset"]["path_to_dataset"],
        split="train",
    )
    print(train_dataset[0])

    # Truncate to the configured sequence length instead of a hard-coded
    # 4096, so the config value is the single source of truth.
    max_length = config["params"]["max_length"]

    def tokenize(example):
        # return_length=True adds a "length" entry to each tokenized example.
        return tokenizer(
            example["text"],
            return_length=True,
            max_length=max_length,
            truncation=True,
        )

    # Drop the raw columns so only the tokenizer outputs remain.
    tokenized_dataset = train_dataset.map(tokenize, remove_columns=train_dataset.column_names)

    # define collator
    collator = DataCollatorForCompletionOnlyLM(
        instruction_template='[INST]',
        response_template='[/INST]',
        tokenizer=tokenizer,
    )

    print(f"Using a sample size of {len(train_dataset)} for fine-tuning.")
    print(train_dataset)

    # set the peft config
    peft_config = LoraConfig(
        lora_alpha=config["peft"]["lora_alpha"],
        lora_dropout=config["peft"]["lora_dropout"],
        r=config["peft"]["lora_r"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config["peft"]["target_modules"],
        modules_to_save=config["peft"]["modules_to_save"],
    )

    # build the training arguments
    training_arguments = _build_training_arguments(config)

    # Use the SFTTrainer to train the model
    trainer = SFTTrainer(
        model=model,
        train_dataset=tokenized_dataset,
        peft_config=peft_config,
        args=training_arguments,
        data_collator=collator,
        formatting_func=None,
        processing_class=tokenizer,
    )

    # start training
    trainer.train()

    # save the model, plus the tokenizer: it was extended with [PAD] above,
    # so consumers need the same vocabulary to resize the base model.
    trainer.model.save_pretrained(new_model)
    tokenizer.save_pretrained(new_model)

if __name__ == "__main__":
    # CLI entry point: read the config path from --config and start fine-tuning.
    parser = ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="configs/finetune/llama.yml",
        help="Please input the path to the configuration file for fine-tuning",
    )
    args = parser.parse_args()
    finetune(args.config)