import os 
os.environ["TOKENIZERS_PARALLELISM"] = "1"

import argparse 
from tqdm.auto import tqdm

import torch 
from torch.utils.data import DataLoader 
from accelerate import Accelerator 
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    default_data_collator, 
    get_linear_schedule_with_warmup, 
    set_seed,
)
from utils.data_preprocess import get_dataset
from peft import LoraConfig, get_peft_model, TaskType


# Integer WNLI label -> target string appended after the prompt as the completion
# the model is trained on (0 -> "False", 1 -> "True").
TARGET_MAPPING = dict(enumerate(("False", "True")))


def parse_args(argv=None):
    """Parse command-line options for LoRA fine-tuning on GLUE WNLI.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            tests and notebooks build the namespace without touching sys.argv.

    Returns:
        argparse.Namespace with dataset, model/LoRA and training options.
    """
    parser = argparse.ArgumentParser()

    # Dataset
    parser.add_argument("--max_length", type=int, default=256, help="Maximum Length of Tokens")
    parser.add_argument("--batch_size", default=4, type=int, help="Batch Size of the Data loaders [train, val, test]")
    parser.add_argument("--num_workers", default=8, type=int, help="Number of workers for data loaders.")

    # Model
    parser.add_argument("--model_name_or_path", default="mistralai/Mistral-7B-v0.1", help="Tokenizer and Model you would like to finetune.")
    parser.add_argument("--rank", type=int, default=8, help="Rank for the LoRA Finetuning.")
    parser.add_argument("--lora_alpha", type=int, default=16, help="Alpha for LoRA finetuning.")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="Dropout for LoRA")

    # Training Parameters
    parser.add_argument("--lr", default=1e-4, type=float, help="Learning Rate for the trainer.")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs you would like to finetune this for.")
    parser.add_argument("--num_warmup_steps", type=int, default=0, help="Warmup steps for training.")
    parser.add_argument("--seed", type=int, default=4232001, help="Seed for Reproducibility")
    # main() passes this value to `push_to_hub`, so it is a Hub repo id, not a local path.
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Hugging Face Hub repo id the trained LoRA adapter is pushed to.",
    )

    return parser.parse_args(argv)





def _build_wnli_prompt(sentence1, sentence2):
    """Format one WNLI sentence pair as a True/False question ending in ``Answer:``."""
    return f"{sentence1}\nQuestion: {sentence2} True or False?\nAnswer:"


def _left_pad_and_truncate(seq, max_length, pad_value):
    """Return ``seq`` left-padded with ``pad_value`` to exactly ``max_length`` items.

    Sequences longer than ``max_length`` keep their LAST ``max_length`` items.
    Truncating from the left preserves the ``Answer:`` cue and, for training
    rows, the target tokens + EOS. Right truncation would cut those off and
    leave the sample with no supervised label.
    """
    padded = [pad_value] * (max_length - len(seq)) + list(seq)
    return padded[len(padded) - max_length:]


def _preprocess_wnli_train(samples, tokenizer, max_length):
    """Tokenize a batch of WNLI rows into fixed-length causal-LM training examples.

    Each example is ``prompt + target + EOS``. Prompt and padding positions get
    label ``-100`` so only the target tokens contribute to the loss. All three
    outputs are left-padded / left-truncated to ``max_length``.

    Args:
        samples: Batched rows with ``sentence1``, ``sentence2`` and ``label`` columns.
        tokenizer: Tokenizer with ``pad_token_id`` and ``eos_token_id`` set.
        max_length: Fixed output length.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` lists.
    """
    prompts = [_build_wnli_prompt(s1, s2) for s1, s2 in zip(samples["sentence1"], samples["sentence2"])]
    targets = [TARGET_MAPPING[label] for label in samples["label"]]

    prompt_ids = tokenizer(prompts)["input_ids"]
    target_ids = tokenizer(targets, add_special_tokens=False)["input_ids"]

    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    for p_ids, t_ids in zip(prompt_ids, target_ids):
        answer_ids = t_ids + [tokenizer.eos_token_id]
        input_ids = p_ids + answer_ids
        # Mask the prompt so the loss covers only the answer tokens + EOS.
        labels = [-100] * len(p_ids) + answer_ids
        model_inputs["input_ids"].append(_left_pad_and_truncate(input_ids, max_length, tokenizer.pad_token_id))
        model_inputs["attention_mask"].append(_left_pad_and_truncate([1] * len(input_ids), max_length, 0))
        model_inputs["labels"].append(_left_pad_and_truncate(labels, max_length, -100))
    return model_inputs


def _preprocess_wnli_eval(samples, tokenizer, max_length):
    """Tokenize a batch of WNLI rows into fixed-length prompts (no answer) for generation.

    NOTE: not currently called by ``main``; kept (fixed) for evaluation.
    """
    prompts = [_build_wnli_prompt(s1, s2) for s1, s2 in zip(samples["sentence1"], samples["sentence2"])]
    encoded = tokenizer(prompts)
    return {
        "input_ids": [_left_pad_and_truncate(ids, max_length, tokenizer.pad_token_id) for ids in encoded["input_ids"]],
        "attention_mask": [_left_pad_and_truncate(mask, max_length, 0) for mask in encoded["attention_mask"]],
    }


def main(args):
    """Fine-tune ``args.model_name_or_path`` with LoRA on GLUE WNLI and push the adapter.

    Training metrics are logged to Weights & Biases through accelerate. After
    training, the main process pushes the LoRA adapter to the Hub repo
    ``args.output_dir``.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: if ``args.output_dir`` is not set. This is checked before
            training so a multi-epoch run does not end by pushing to a repo
            literally named "None".
    """
    # The script already relies on the `datasets` package (`Dataset.map` below);
    # only the import of `load_dataset` was missing.
    from datasets import load_dataset

    if not args.output_dir:
        raise ValueError("--output_dir (Hub repo id for the trained adapter) is required")

    set_seed(args.seed)

    accelerator = Accelerator(log_with="wandb")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token_id is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    name = f"mistral7b-lora-r{args.rank}-wnli-epochs{args.num_epochs}-adapter"

    accelerator.init_trackers(
        project_name=name,
        config={
            "model_name": args.model_name_or_path,
            "seed": args.seed,
            "num_epochs": args.num_epochs,
            "batch_size": args.batch_size,
            "lora_rank": args.rank,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
            "learning_rate": args.lr,
        },
    )

    # load_dataset returns a DatasetDict (train/validation/test). It must not be
    # tuple-unpacked: iterating it yields its split names.
    ds = load_dataset("glue", "wnli")
    ds_train = ds["train"]

    # Let the main process tokenize first so other ranks can reuse its work.
    with accelerator.main_process_first():
        processed_train_ds = ds_train.map(
            _preprocess_wnli_train,
            batched=True,
            num_proc=8,
            remove_columns=ds_train.column_names,
            load_from_cache_file=False,
            fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
            desc="Running Tokenizer on the Train Dataset",
        )

    accelerator.wait_for_everyone()

    train_dataloader = DataLoader(
        processed_train_ds,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj", "up_proj", "down_proj", "gate_proj", "k_proj", "o_proj"],
        inference_mode=False,
        r=args.rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )

    # Sanity-check the collated batch once, from the main process only.
    accelerator.print(f"First Batch Of Training Set : {next(iter(train_dataloader))}")

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs),
    )

    model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, optimizer, lr_scheduler
    )

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        epoch_loss = 0.0
        num_steps = 0
        for batch in tqdm(train_dataloader, disable=not accelerator.is_local_main_process):
            outputs = model(**batch)
            loss = outputs.loss
            # Detach once; log plain floats rather than 0-d tensors.
            iter_loss = loss.detach().float().item()
            epoch_loss += iter_loss
            num_steps += 1
            global_step += 1

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            accelerator.log({"iter_loss": iter_loss}, step=global_step)

        # Mean loss per step, so the value is comparable across dataloader sizes.
        accelerator.log(
            {"epoch": epoch, "epoch_loss": epoch_loss / max(num_steps, 1)},
            step=global_step,
        )

    accelerator.wait_for_everyone()

    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.is_main_process:
        # Uses the locally cached Hub token; the deprecated `use_auth_token`
        # argument is not passed.
        unwrapped_model.push_to_hub(args.output_dir)

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point (skipped on import): read the CLI flags into the
    # module-level `args`, then run training with them.
    args = parse_args()
    main(args)



    

