import argparse
import os
import math
import json
from datetime import datetime
from pathlib import Path
from random import randint
from typing import Any, Dict, List, Union
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    default_data_collator,
    set_seed,
    SchedulerType,
)
import logging
from accelerate import Accelerator, dispatch_model
from accelerate.logging import get_logger
from datasets import load_dataset
import evaluate
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import transformers
import datasets 
import wandb
from peft.optimizers.sparse_optimizer_multiply_lr import SparseAdamW

from peft import get_peft_model, TaskType, LoraConfig, SoraConfig, PeftModel, SoraModel, prepare_model_for_kbit_training
from peft.utils import _freeze_adapter, get_peft_model_state_dict

from transformers import get_scheduler
from peft.utils import _get_submodules
# Module-level logger created through accelerate's `get_logger`, at INFO level.
logger = get_logger(__name__, log_level="INFO")


def parse_args():
    """Parse the command-line arguments of the fine-tuning script.

    Returns:
        argparse.Namespace: the parsed arguments. `output_dir` and `task` are
        guaranteed to be set.

    Raises:
        SystemExit: raised by argparse when a required argument is missing or
        a value cannot be parsed (e.g. an unrecognised boolean string).
    """

    def _str_to_bool(value):
        """Convert a CLI string such as 'true'/'false' into a real bool.

        `type=bool` would treat every non-empty string (including "False")
        as True, so the flag could never be switched off from the CLI.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser(description="Sequence classification task")

    # ----- task / model -----
    parser.add_argument(
        "--task",
        type=str,
        # main() indexes its task table with this value unconditionally, so a
        # missing task can only fail later with a less helpful error.
        required=True,
        help="which dataset to perform.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )

    # ----- data / batching -----
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="The maximum total input sequence length after tokenization.",
    )
    parser.add_argument(
        "--pad_to_max_length",
        type=_str_to_bool,
        default=True,
        help="Whether to pad all samples to `max_seq_length` (true/false).",
    )

    # ----- optimisation -----
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )

    # ----- output / reproducibility -----
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Where to store the final model.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )

    # ----- adapter configuration -----
    parser.add_argument(
        "--peft_type",
        type=str,
        default="sora",
        help="LoRA via Sora",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha value.",
    )
    parser.add_argument(
        "--r",
        type=int,
        default=8,
        help="LoRA rank.",
    )
    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="LoRA dropout value.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=1e-4,
        help="Sparsity threshold for SORA",
    )
    parser.add_argument(
        "--num_retrain",
        type=int,
        default=1,
        help="total number of retrains in SORA",
    )
    parser.add_argument(
        "--target_modules",
        nargs="+",
        default=[],
        help="target modules for PEFT",
    )

    # ----- logging / evaluation cadence -----
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=100,
        help="Log the running training loss every n update steps.",
    )
    parser.add_argument(
        "--evaluation_steps",
        type=int,
        default=100,
        help="Run evaluation on the validation split every n update steps.",
    )
    # NOTE: this value is parsed but not read anywhere else in this file.
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help="Number of update steps between checkpoints.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument(
        "--load_best_model",
        action="store_true",
        help="Whether to load the best model at the end of training",
    )

    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is removed under `python -O`.
    if args.output_dir is None:
        parser.error("Need an `output_dir` to store the finetune model and verify.")

    return args

def save_model_hook(models, weights, output_dir):
    """Save-state pre-hook: write every model with `save_pretrained`.

    Args:
        models: iterable of objects exposing `save_pretrained(path)`.
        weights: list of pending weight entries; one entry is popped per model
            so the corresponding model is not saved a second time.
        output_dir: directory handed to each `save_pretrained` call.
    """
    for wrapped_model in models:
        wrapped_model.save_pretrained(output_dir)
        # Drop the matching pending entry so this model is not written again.
        weights.pop()
def load_model_hook(models, input_dir):
    """Load-state pre-hook: reload every adapter of every model from `input_dir`.

    Args:
        models: list of PEFT-wrapped models; it is emptied by this hook so the
            models are not loaded again afterwards.
        input_dir: checkpoint directory containing one sub-directory per
            adapter name.
    """
    while models:
        wrapped_model = models.pop()
        # Snapshot the adapter names before calling into from_pretrained.
        for adapter in list(wrapped_model.peft_config.keys()):
            adapter_dir = os.path.join(input_dir, adapter)
            PeftModel.from_pretrained(wrapped_model.base_model.model, adapter_dir, adapter_name=adapter)

def eval_loop(model, task, device, eval_dataloader, accelerator):
    """Evaluate `model` on `eval_dataloader` with the GLUE metric for `task`.

    Args:
        model: classifier whose forward output exposes `.logits`.
        task: GLUE task name; the script alias "mnli-mm" is also accepted.
        device: device every batch tensor is moved to before the forward pass.
        eval_dataloader: iterable of dict batches containing a "labels" entry.
        accelerator: accepted for call-site compatibility; not used here.

    Returns:
        The dict produced by the metric's `compute()`.
    """
    model.eval()

    # "mnli-mm" is this script's alias for the mismatched MNLI validation set;
    # the GLUE metric itself only knows the official config names.
    metric_config = "mnli_mismatched" if task == "mnli-mm" else task
    metrics = evaluate.load("glue", metric_config)

    # STS-B is a regression task (a single output logit). Taking argmax over a
    # single logit always yields 0, which makes the correlation metrics
    # meaningless, so the raw score is used as the prediction instead.
    is_regression = task == "stsb"

    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        if is_regression:
            predictions = logits.squeeze(-1)
        else:
            predictions = torch.argmax(logits, dim=-1)
        metrics.add_batch(predictions=predictions, references=batch["labels"])

    return metrics.compute()


def main():
    """Fine-tune a sequence classifier on a GLUE task with a SoRA adapter.

    Steps performed:
      1. Parse the CLI arguments and build the `Accelerator`.
      2. Load and tokenize the GLUE dataset for `--task`.
      3. Wrap the base model with a SoRA adapter named "default_0".
      4. Train for `--num_retrain` rounds of `--num_train_epochs` epochs each.
         Between rounds the adapter is extended with
         `SoraModel.extend_modules` and the optimizers/schedulers are rebuilt.
      5. Whenever the task's headline metric improves on the best value seen
         so far, save the state to `best_checkpoint_<round>` and the metrics
         to `all_results_<round>.json`; finally write `all_results.json`.

    Raises:
        ValueError: if no `--target_modules` are given and no default is known
            for the chosen model.
    """
    args = parse_args()

    args.output_dir = os.path.join(args.output_dir, f"{args.peft_type}-{args.task}-{args.r}-{args.seed}")

    accelerator_kwargs = {"gradient_accumulation_steps": args.gradient_accumulation_steps}
    if args.with_tracking:
        accelerator_kwargs["log_with"] = args.report_to
        accelerator_kwargs["project_dir"] = args.output_dir
    accelerator = Accelerator(**accelerator_kwargs)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # Create the output directory once, then let every process continue.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    task_to_keys = {
        "cola": ("sentence", None),
        "mnli": ("premise", "hypothesis"),
        "mnli-mm": ("premise", "hypothesis"),
        "mrpc": ("sentence1", "sentence2"),
        "qnli": ("question", "sentence"),
        "qqp": ("question1", "question2"),
        "rte": ("sentence1", "sentence2"),
        "sst2": ("sentence", None),
        "stsb": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
    }

    task = args.task
    batch_size = args.per_device_train_batch_size
    num_epochs = args.num_train_epochs
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    rank = args.r
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay

    sentence1_key, sentence2_key = task_to_keys[task]

    # "mnli-mm" is a script-level alias: the data lives in the "mnli" GLUE
    # config (split "validation_mismatched") and the metric is registered as
    # "mnli_mismatched". Passing the alias itself to either library fails.
    dataset_config = "mnli" if task == "mnli-mm" else task
    metric_task = "mnli_mismatched" if task == "mnli-mm" else task
    dataset = load_dataset("glue", dataset_config)

    if any(k in args.model_name_or_path for k in ("gpt", "opt", "bloom")):
        padding_side = "left"
    else:
        padding_side = "right"

    # Padding strategy: either pad everything to max_seq_length up front, or
    # leave padding to batch creation.
    padding = "max_length" if args.pad_to_max_length else False

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, padding_side=padding_side)

    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def tokenize_function(examples):
        """Tokenize one batch of single sentences or sentence pairs."""
        if sentence2_key is None:
            return tokenizer(examples[sentence1_key], padding=padding, max_length=max_seq_length, truncation=True)
        return tokenizer(
            examples[sentence1_key],
            examples[sentence2_key],
            padding=padding,
            max_length=max_seq_length,
            truncation=True,
        )

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['idx', sentence1_key, sentence2_key] if sentence2_key is not None else ['idx', sentence1_key],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    data_collator = default_data_collator
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
    )

    validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
    # Honour --per_device_eval_batch_size (previously the train batch size was reused).
    eval_dataloader = DataLoader(
        tokenized_datasets[validation_key], batch_size=args.per_device_eval_batch_size, collate_fn=data_collator
    )

    metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy"
    num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2

    # Resolve the modules to adapt; fail loudly instead of hitting an unbound
    # local when the model family has no built-in default.
    if args.target_modules:
        target_modules = args.target_modules
    elif "roberta" in args.model_name_or_path:
        target_modules = ["query", "value"]
    elif "deberta" in args.model_name_or_path:
        target_modules = ["query_proj", "key_proj", "value_proj"]
    else:
        raise ValueError(
            f"No default target modules known for model {args.model_name_or_path!r}; "
            "pass --target_modules explicitly."
        )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_training_steps = num_epochs * num_update_steps_per_epoch
    # NOTE(review): --num_warmup_steps is multiplied by the updates per epoch,
    # i.e. it is effectively interpreted as a number of epochs — confirm intent.
    num_warmup_steps = args.num_warmup_steps * num_update_steps_per_epoch

    peft_config = SoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=rank,
        target_modules=target_modules,
    )

    repeat = args.num_retrain

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, num_labels=num_labels, return_dict=True
    )
    model = get_peft_model(model, peft_config, adapter_name="default_0")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    def _split_param_groups():
        """Return (dense, sparse) trainable params; "lora_E" params are the sparse group."""
        dense = [p for n, p in model.named_parameters() if "lora_E" not in n and p.requires_grad]
        sparse = [p for n, p in model.named_parameters() if "lora_E" in n and p.requires_grad]
        return dense, sparse

    def _new_scheduler(opt):
        """Build a scheduler of the CLI-selected type for optimizer `opt`."""
        return get_scheduler(
            args.lr_scheduler_type,
            optimizer=opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    adamW_group, sparse_adamW_group = _split_param_groups()

    sparse_optimizer = SparseAdamW(
        params=sparse_adamW_group, sparse_lambda=args.threshold, lr=learning_rate, weight_decay=weight_decay
    )
    optimizer = torch.optim.AdamW(adamW_group, lr=learning_rate, weight_decay=weight_decay)

    lr_scheduler = _new_scheduler(optimizer)
    # NOTE(review): this scheduler is not passed through accelerator.prepare,
    # unlike `lr_scheduler`; it is stepped manually on real update steps below.
    sparse_lr_scheduler = _new_scheduler(sparse_optimizer)

    # Prepare everything with our `accelerator`.
    model, sparse_optimizer, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, sparse_optimizer, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if args.with_tracking:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        # Built from adjacent f-strings so that no source indentation leaks
        # into the run name (backslash continuations inside a string do that).
        run_name = (
            f"run-{args.peft_type}-SOPT-{args.task}-"
            f"{args.r}-{args.threshold}-{args.lora_alpha}-"
            f"{args.seed}-{args.learning_rate}-{args.weight_decay}-"
            f"{args.max_seq_length}-{timestamp}"
        )
        experiment_config = vars(args)
        accelerator.init_trackers(
            "PEFT Fine-Tuning", config=experiment_config, init_kwargs={"wandb": {"name": run_name}}
        )

    # saving and loading checkpoints for resuming training
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {num_training_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(num_training_steps * repeat), disable=not accelerator.is_local_main_process)
    starting_epoch = 0
    best_metric = None
    # Resuming from a checkpoint is not wired up; these stay at zero.
    resume_step = 0
    iteration_step = 0
    global_step = 0

    for r in range(repeat):
        progress_bar.update(resume_step)
        for epoch in range(starting_epoch, num_epochs):
            if args.with_tracking:
                total_loss = 0
                running_loss = 0

            for step, batch in enumerate(accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)):
                model.train()
                with accelerator.accumulate(model):
                    outputs = model(**batch)
                    loss = outputs.loss
                    # accelerator.backward() already scales the loss by the
                    # configured accumulation steps, so it is not divided here.
                    accelerator.backward(loss)
                    # The prepared optimizers/scheduler skip their step on
                    # non-sync micro-batches, so they are called every time.
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    if r < repeat - 1:
                        # The sparse optimizer is only stepped before the final round.
                        sparse_optimizer.step()
                        if accelerator.sync_gradients:
                            # Unprepared scheduler: advance only on real updates.
                            sparse_lr_scheduler.step()
                        sparse_optimizer.zero_grad()
                    # True when this micro-batch completed an optimizer update.
                    did_update = accelerator.sync_gradients
                    if did_update:
                        progress_bar.update(1)
                        iteration_step += 1
                        global_step += 1

                if args.with_tracking:
                    step_loss = accelerator.reduce(loss.detach().clone()).item()
                    total_loss += step_loss
                    # Scaled so that running_loss / logging_steps is a
                    # per-micro-batch mean over the logging window.
                    running_loss += step_loss / args.gradient_accumulation_steps

                # Logging and evaluation are keyed on update steps; skip
                # micro-batches that did not advance `global_step`.
                if not did_update:
                    continue

                if global_step % args.logging_steps == 0:
                    if args.with_tracking:
                        accelerator.log({"train/running_loss": running_loss / args.logging_steps}, step=global_step)
                        running_loss = 0

                if global_step % args.evaluation_steps == 0:
                    eval_metrics = eval_loop(model, metric_task, device, eval_dataloader, accelerator)
                    if args.with_tracking:
                        logger.info(f"Step {iteration_step} eval metrics: {eval_metrics}")
                        accelerator.log(eval_metrics, step=global_step)
                    if best_metric is None or eval_metrics[metric_name] > best_metric:
                        best_metric = eval_metrics[metric_name]
                        accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
                        with open(os.path.join(args.output_dir, f"all_results_{r}.json"), "w") as f:
                            json.dump(eval_metrics, f)

            if args.with_tracking:
                train_epoch_loss = total_loss / (step + 1)
                logger.info(f"Epoch {epoch} train loss: {train_epoch_loss}")
                accelerator.log({"epoch/train_loss": train_epoch_loss}, step=epoch)

            print("==============END OF EPOCH================")
            eval_metrics = eval_loop(model, metric_task, device, eval_dataloader, accelerator)
            print(best_metric, eval_metrics[metric_name])

            if best_metric is None or eval_metrics[metric_name] > best_metric:
                best_metric = eval_metrics[metric_name]
                accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
                with open(os.path.join(args.output_dir, f"all_results_{r}.json"), "w") as f:
                    json.dump(eval_metrics, f)

            # Append "<round> <epoch>" and the headline metric to a plain-text log.
            with open(os.path.join(args.output_dir, "metric"), "a+") as f:
                f.write(str(r) + " " + str(epoch) + "\n")
                f.write(str(eval_metrics[metric_name]) + "\n")

        if r < repeat - 1:
            # ---- prepare the next retrain round ----
            if args.load_best_model:
                # load the best model
                accelerator.load_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))

            model.peft_config["default_0"].inference_mode = False
            for n, p in model.named_parameters():
                if "default_0" in n and "lora" in n:
                    p.requires_grad = True

            adapter_name = "default_0"
            SoraModel.extend_modules(model, adapter_name, r)
            starting_epoch = 0
            resume_step = 0
            iteration_step = 0

            if r < repeat - 2:
                # The next round is not the last: keep the dense/sparse split.
                adamW_group, sparse_adamW_group = _split_param_groups()

                sparse_optimizer = SparseAdamW(
                    params=sparse_adamW_group, sparse_lambda=args.threshold, lr=learning_rate, weight_decay=weight_decay
                )
                optimizer = torch.optim.AdamW(adamW_group, lr=learning_rate, weight_decay=weight_decay)

                lr_scheduler = _new_scheduler(optimizer)
                sparse_lr_scheduler = _new_scheduler(sparse_optimizer)

                model, sparse_optimizer, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                    model, sparse_optimizer, optimizer, train_dataloader, eval_dataloader, lr_scheduler
                )
            else:
                # The next round is the last: one AdamW over all parameters.
                optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

                lr_scheduler = _new_scheduler(optimizer)

                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
                )

            # Seed the next round's checkpoint so it exists even if the metric
            # never improves during that round.
            accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r+1}"))

    if args.load_best_model:
        # load the best model
        accelerator.load_state(os.path.join(args.output_dir, f"best_checkpoint_{repeat-1}"))

        eval_metrics = eval_loop(model, metric_task, device, eval_dataloader, accelerator)

        if args.with_tracking:
            best_metrics = {"best_" + k: v for k, v in eval_metrics.items()}
            accelerator.log(best_metrics, step=global_step)

    with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
        json.dump(eval_metrics, f)
        print(eval_metrics)

    if args.with_tracking:
        # Flush and close the experiment trackers.
        accelerator.end_training()

# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
