import argparse
import os
import math
import json
from datetime import datetime
from pathlib import Path
from random import randint
from typing import Any, Dict, List, Union
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    default_data_collator,
    set_seed,
    SchedulerType,
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForQuestionAnswering,
    EvalPrediction,
    get_scheduler
)

import collections
from typing import Optional, Tuple
import logging
from accelerate import Accelerator, dispatch_model
from accelerate.logging import get_logger
from datasets import load_dataset, load_metric
import evaluate
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import transformers
import datasets 
import wandb
import math 

from peft import get_peft_model, TaskType, LoraConfig, RoLoraConfig, AdaLoraConfig, AdaLoraModel, PeftModel, RoLoraModel, prepare_model_for_kbit_training
from peft.utils import _freeze_adapter, get_peft_model_state_dict

from transformers import get_scheduler
from peft.utils import _get_submodules
# Module-level logger obtained through accelerate's `get_logger`; the functions below log through it.
logger = get_logger(__name__, log_level="INFO")

def postprocess_qa_predictions(
    examples,
    features,
    predictions: Tuple[np.ndarray, np.ndarray],
    version_2_with_negative: bool = False,
    n_best_size: int = 20,
    max_answer_length: int = 30,
    null_score_diff_threshold: float = 0.0,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
    log_level: Optional[int] = logging.WARNING,
):
    """
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
    original contexts. This is the base postprocessing function for models that only return start and end logits.

    Args:
        examples: The non-preprocessed dataset. It is indexed with ``examples["id"]`` and iterated row by row; every
            row must provide the ``"id"`` and ``"context"`` keys.
        features: The processed dataset. Every feature must provide ``"example_id"`` and ``"offset_mapping"`` and may
            provide ``"token_is_max_context"``.
        predictions: A pair ``(start_logits, end_logits)``; the first axis of each must match ``len(features)``.
        version_2_with_negative: Whether an empty (null) answer is a valid prediction.
        n_best_size: Number of top start/end logits to combine and number of candidates kept per example.
        max_answer_length: Maximum answer length, counted in tokens.
        null_score_diff_threshold: The null answer is selected when ``null_score - best_non_null_score`` is strictly
            greater than this value. Only used when ``version_2_with_negative`` is true.
        output_dir: If given, must be an existing directory; the predictions, the n-best candidates and (when
            ``version_2_with_negative`` is true) the null odds are written there as JSON files.
        prefix: If given, prepended (followed by ``_``) to the names of the JSON files.
        log_level: Level applied to the module-level ``logger``. ``None`` leaves the level untouched.
            NOTE: the level is not restored when the function returns.

    Returns:
        An ``OrderedDict`` mapping every example id to its predicted answer text.

    Raises:
        ValueError: If ``predictions`` does not hold two elements or does not match ``features`` in length.
        EnvironmentError: If ``output_dir`` is given but is not a directory.
    """
    if len(predictions) != 2:
        raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    all_start_logits, all_end_logits = predictions

    if len(predictions[0]) != len(features):
        raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    # Build a map from each example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    if version_2_with_negative:
        scores_diff_json = collections.OrderedDict()

    # Logging. The logger is shared by the whole module, so this also affects later messages emitted elsewhere.
    if log_level is not None:
        logger.setLevel(log_level)
    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example (possibly none at all, e.g. when
        # the features were truncated but the examples were not).
        feature_indices = features_per_example[example_index]

        min_null_prediction = None
        # Reset for every example so that a null score can never leak from the previous example.
        null_score = None
        prelim_predictions = []

        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map the positions in our logits to spans of text in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]
            # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
            # available in the current feature.
            token_is_max_context = features[feature_index].get("token_is_max_context", None)

            # Update minimum null prediction (position 0 of the logits is used as the "no answer" position).
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greatest start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answers that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue

                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )
        if version_2_with_negative and min_null_prediction is not None:
            # Add the minimum null prediction
            prelim_predictions.append(min_null_prediction)
            null_score = min_null_prediction["score"]

        # Only keep the best `n_best_size` predictions. A dedicated name is used so that the `predictions` argument
        # is not shadowed.
        nbest = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

        # Add back the minimum null prediction if it was removed because of its low score.
        if (
            version_2_with_negative
            and min_null_prediction is not None
            and not any(p["offsets"] == (0, 0) for p in nbest)
        ):
            nbest.append(min_null_prediction)

        # Use the offsets to gather the answer text in the original context.
        context = example["context"]
        for pred in nbest:
            offsets = pred.pop("offsets")
            pred["text"] = context[offsets[0] : offsets[1]]

        # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
        # failure. Checking every candidate (instead of only lists of length 0 or 1) guarantees that the search for
        # the best non-empty prediction below always succeeds.
        if all(pred["text"] == "" for pred in nbest):
            nbest.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

        # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
        # the LogSumExp trick).
        scores = np.array([pred.pop("score") for pred in nbest])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        # Include the probabilities in our predictions.
        for prob, pred in zip(probs, nbest):
            pred["probability"] = prob

        # Pick the best prediction. If the null answer is not possible, this is easy.
        if not version_2_with_negative:
            all_predictions[example["id"]] = nbest[0]["text"]
        else:
            # Otherwise we first need to find the best non-empty prediction (one always exists, see above).
            best_non_null_pred = next(pred for pred in nbest if pred["text"] != "")

            # Without any feature there is no null prediction for this example; use a neutral null score so that the
            # comparison below is well defined and independent of the other examples.
            if null_score is None:
                null_score = 0.0

            # Then we compare to the null prediction using the threshold.
            score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
            scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
            if score_diff > null_score_diff_threshold:
                all_predictions[example["id"]] = ""
            else:
                all_predictions[example["id"]] = best_non_null_pred["text"]

        # Make the candidates JSON-serializable by casting numpy floats back to float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, np.floating) else v) for k, v in pred.items()}
            for pred in nbest
        ]

    # If we have an output_dir, let's save all those dicts.
    if output_dir is not None:
        if not os.path.isdir(output_dir):
            raise EnvironmentError(f"{output_dir} is not a directory.")

        prediction_file = os.path.join(
            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
        nbest_file = os.path.join(
            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
        )
        if version_2_with_negative:
            null_odds_file = os.path.join(
                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
            )

        logger.info(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")
        logger.info(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
        if version_2_with_negative:
            logger.info(f"Saving null_odds to {null_odds_file}.")
            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions



def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no`` or ``1``/``0`` (case-insensitive)."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept falsy, as it was with the former `type=bool` conversion.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args():
    """
    Build the command-line parser of the script and parse ``sys.argv``.

    Boolean options (``--pad_to_max_length``, ``--overwrite_cache``) accept an explicit value
    (``--overwrite_cache false``) or can be given bare (``--overwrite_cache``), which means ``True``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse when an argument is invalid, when ``--model_name_or_path`` is missing, or
            when ``--output_dir`` is not provided.
    """
    parser = argparse.ArgumentParser(description="Question answering task")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )

    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=384,
        help="The maximum total input sequence length after tokenization.",
    )

    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=4,
        help="The number of processes to use for the preprocessing.",
    )

    # `type=bool` would turn every non-empty string (including "False") into True, hence the explicit parser.
    parser.add_argument(
        "--pad_to_max_length",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to pad all samples to `max_seq_length`.",
    )

    parser.add_argument(
        "--doc_stride",
        type=int,
        default=128,
        help="When splitting up a long document into chunks how much stride to take between chunks.",
    )

    parser.add_argument(
        "--n_best_size",
        type=int,
        default=20,
        help="The total number of n-best predictions to generate when looking for an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help=(
            "The threshold used to select the null answer: if the best answer has a score that is less than "
            "the score of the null answer minus this threshold, the null answer is selected for this example. "
            "Only useful when `version_2_with_negative=True`."
        ),
    )
    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, some of the examples do not have an answer.",
    )

    parser.add_argument(
        "--max_answer_length",
        type=int,
        default=30,
        help=(
            "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        ),
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--overwrite_cache",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Overwrite the cached training and evaluation sets",
    )

    parser.add_argument(
        "--weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Where to store the final model.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )

    parser.add_argument(
        "--peft_type",
        type=str,
        default="lora",
        help="type of adapter: Lora, Rolora, or AdaLoRA.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha value.",
    )
    parser.add_argument(
        "--r",
        type=int,
        default=8,
        help="LoRA rank.",
    )

    parser.add_argument(
        "--lora_dropout",
        type=float,
        default=0.1,
        help="LoRA dropout value.",
    )

    # NOTE(review): this help text looks copy-pasted from `--checkpointing_steps`, and `type=int` rejects the
    # advertised 'epoch' value -- confirm the intended wording.
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=100,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )

    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )

    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )

    # parser.add_argument(
    #     "--resume_from_checkpoint",
    #     type=str,
    #     default=None,
    #     help="If the training should continue from a checkpoint folder.",
    # )

    parser.add_argument(
        "--load_best_model",
        action="store_true",
        help="Whether to load the best model at the end of training",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )

    # NOTE(review): `type=int` rejects the 'epoch' value mentioned in the help text -- confirm the intended wording.
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )

    # NOTE(review): this help text looks copy-pasted from `--checkpointing_steps` -- confirm the intended wording.
    parser.add_argument(
        "--evaluation_steps",
        type=int,
        default=100,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )

    # ROLORA
    parser.add_argument(
        "--target_r",
        type=int,
        default=4,
        help="Target rank in the first round of ROLORA",
    )

    parser.add_argument(
        "--orth_reg_weight",
        type=float,
        default=0.1,
        help="regularization of ROLORA",
    )

    parser.add_argument(
        "--init_r",
        type=int,
        default=8,
        help="Initial rank in the first round of ROLORA",
    )

    parser.add_argument(
        "--tinit",
        type=int,
        default=10,
        help="Number of warm up steps in ROLORA (via AdaLoRA)",
    )

    parser.add_argument(
        "--tfinal",
        type=int,
        default=10,
        help="Number of final steps in ROLORA (via AdaLoRA)",
    )

    parser.add_argument(
        "--deltaT",
        type=int,
        default=10,
        help="delta T in ROLORA  (via AdaLoRA)",
    )

    parser.add_argument(
        "--num_retrain",
        type=int,
        default=1,
        help="total number of retrains in ROLORA",
    )

    # parser.add_argument(
    #     "--only_train_new",
    #     type=int,
    #     default=0,
    #     help="total number of retrains in ROLORA",
    # )

    parser.add_argument(
        "--target_modules",
        nargs="+",
        default=[],
        help="target modules for PEFT",
    )

    args = parser.parse_args()

    # An `assert` is stripped under `python -O`; report the missing option through argparse instead.
    if args.output_dir is None:
        parser.error("Need an `output_dir` to store the finetune model and verify.")

    return args

def save_model_hook(models, weights, output_dir):
    """
    Save every model of ``models`` into ``output_dir`` through its ``save_pretrained`` method.

    One entry is removed from the end of ``weights`` for each saved model. The signature matches a save-state
    hook (presumably registered on the accelerator elsewhere in this file -- confirm).

    Args:
        models: Iterable of objects exposing ``save_pretrained(output_dir)``.
        weights: Mutable list holding one pending entry per model; it is consumed in place.
        output_dir: Destination passed to every ``save_pretrained`` call.
    """
    for current_model in models:
        current_model.save_pretrained(output_dir)
        # Drop the matching pending entry so that this model is not saved a second time.
        weights.pop()
def load_model_hook(models, input_dir):
    """
    Reload the adapters of every model in ``models`` from ``input_dir`` and empty the list.

    For each adapter name found in the model's ``peft_config``, ``PeftModel.from_pretrained`` is called on
    ``model.base_model.model`` with the sub-directory ``input_dir/<adapter name>``. The return value of that
    call is discarded.

    Args:
        models: Mutable list of models exposing ``peft_config`` and ``base_model.model``; consumed in place.
        input_dir: Directory containing one sub-directory per adapter name.
    """
    # Models are popped as they are processed so that they are not loaded again.
    while models:
        current_model = models.pop()
        # Snapshot the adapter names before any adapter is loaded.
        names_snapshot = list(current_model.peft_config.keys())
        for adapter_name in names_snapshot:
            adapter_dir = os.path.join(input_dir, adapter_name)
            PeftModel.from_pretrained(current_model.base_model.model, adapter_dir, adapter_name=adapter_name)

def eval_loop(
    model,
    accelerator,
    eval_dataloader,
    eval_dataset,
    eval_examples,
    post_processing_function,
    metric,
    create_and_fill_np_array,
    pad_to_max_length,
):
    """
    Run ``model`` over ``eval_dataloader`` without gradients and return the computed metric.

    Args:
        model: Callable as ``model(**batch)``; its output must expose ``start_logits`` and ``end_logits``.
        accelerator: Object providing ``gather`` and ``pad_across_processes``.
        eval_dataloader: Iterable of batches (mappings unpacked as keyword arguments of the model).
        eval_dataset: Forwarded to ``create_and_fill_np_array`` and ``post_processing_function``.
        eval_examples: Forwarded to ``post_processing_function``.
        post_processing_function: Called as ``f(eval_examples, eval_dataset, (start, end))``; its result must
            expose ``predictions`` and ``label_ids``.
        metric: Object whose ``compute(predictions=..., references=...)`` result is returned.
        create_and_fill_np_array: Called as ``f(list_of_arrays, eval_dataset, max_len)`` to merge the per-batch
            logits into one array.
        pad_to_max_length: When false, the logits are padded across processes (value ``-100`` along dim 1)
            before being gathered.

    Returns:
        Whatever ``metric.compute`` returns.
    """
    model.eval()

    start_chunks, end_chunks = [], []
    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch)
            start_logits, end_logits = outputs.start_logits, outputs.end_logits

            # Sequences of different lengths have to be padded before they can be gathered.
            if not pad_to_max_length:
                start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
                end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)

            start_chunks.append(accelerator.gather(start_logits).cpu().numpy())
            end_chunks.append(accelerator.gather(end_logits).cpu().numpy())

    # Widest second dimension among the collected start-logit arrays.
    widest = max(chunk.shape[1] for chunk in start_chunks)

    # Merge the per-batch arrays into one array per kind of logits.
    start_logits_concat = create_and_fill_np_array(start_chunks, eval_dataset, widest)
    end_logits_concat = create_and_fill_np_array(end_chunks, eval_dataset, widest)

    # The per-batch arrays are not needed any more.
    del start_chunks, end_chunks

    prediction = post_processing_function(eval_examples, eval_dataset, (start_logits_concat, end_logits_concat))
    return metric.compute(predictions=prediction.predictions, references=prediction.label_ids)


def main():
    args = parse_args()

    args.output_dir = os.path.join(args.output_dir, f"{args.peft_type}-{args.dataset_name}-{args.r}-{args.seed}")

    accelerator_kwargs = {"gradient_accumulation_steps": args.gradient_accumulation_steps}
    if args.with_tracking:
        accelerator_kwargs["log_with"] = args.report_to
        accelerator_kwargs["project_dir"] = args.output_dir
    accelerator = Accelerator(**accelerator_kwargs)


    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    
    
    if args.target_modules == []:
        if "roberta" in args.model_name_or_path:
            target_modules = ["query", "value"]
        elif "deberta" in args.model_name_or_path:
            target_modules = ["query_proj", "key_proj", "value_proj"]
    else:
        target_modules = args.target_modules

    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    

    dataset_name = args.dataset_name
    batch_size = args.per_device_train_batch_size
    num_epochs = args.num_train_epochs
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    rank = args.r
    learning_rate = args.learning_rate
    weight_decay = args.weight_decay
    metric_name = "exact_match" if args.dataset_name == "squad" else "exact"

    # Prepare the dataset

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    raw_datasets = load_dataset(dataset_name, None)
    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.

    column_names = raw_datasets["train"].column_names

    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        """
        Tokenize a batch of training examples and label every resulting feature with its answer span.

        The question column of ``examples`` is left-stripped in place. Question and context are tokenized together
        with truncation and a stride, so one example may produce several features. ``overflow_to_sample_mapping``
        and ``offset_mapping`` are removed from the tokenizer output, and ``start_positions`` / ``end_positions``
        are added with one entry per feature. Only the first listed answer of an example is used. Features whose
        example has no answer, or whose span does not fully contain the answer, are labeled with the index of the
        CLS token.

        Returns:
            The tokenizer output extended with ``start_positions`` and ``end_positions``.
        """
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
        # left whitespace
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length" if args.pad_to_max_length else False,
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Start/end character index of the answer in the text (only the first answer is considered).
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text: first token that belongs to the context.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1

                # End token index of the current span in the text: last token that belongs to the context.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    # The loop stops one token past the answer start, hence the `- 1`.
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    # The loop stops one token before the answer end, hence the `+ 1`.
                    tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    train_dataset = raw_datasets["train"]
    with accelerator.main_process_first():
        train_dataset = train_dataset.map(
            prepare_train_features,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            # Reuse the cached features unless `--overwrite_cache` is set. The flag used to be passed without
            # negation, which re-tokenized the training set on every default run and reused the cache precisely
            # when an overwrite was requested.
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )
    # Validation preprocessing
    def prepare_validation_features(examples):
        """
        Tokenize a batch of evaluation examples and keep what is needed to map predictions back to text.

        The question column of ``examples`` is left-stripped in place. Question and context are tokenized together
        with truncation and a stride, so one example may produce several features. Every feature receives the id of
        the example it comes from (``example_id``), and the entries of its ``offset_mapping`` that do not belong to
        the context are replaced by ``None``.

        Returns:
            The tokenizer output, without ``overflow_to_sample_mapping`` and with ``example_id`` added.
        """
        # Leading whitespace in a question is useless and eats into the room left for the context when the pair
        # is truncated, so it is removed first.
        stripped_questions = [question.lstrip() for question in examples[question_column_name]]
        examples[question_column_name] = stripped_questions

        # The padding side decides whether the pair is (question, context) or (context, question); only the
        # context is ever truncated.
        if pad_on_right:
            first_column, second_column, truncation_mode = question_column_name, context_column_name, "only_second"
        else:
            first_column, second_column, truncation_mode = context_column_name, question_column_name, "only_first"

        # Long contexts overflow into several overlapping features thanks to the stride.
        encoded = tokenizer(
            examples[first_column],
            examples[second_column],
            truncation=truncation_mode,
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length" if args.pad_to_max_length else False,
        )

        # Index of the originating example for every feature.
        feature_to_example = encoded.pop("overflow_to_sample_mapping")

        # Sequence id carried by the context tokens.
        context_index = 1 if pad_on_right else 0

        # The example id and the offsets are kept so that predictions can later be turned into substrings of
        # the context.
        encoded["example_id"] = []

        feature_count = len(encoded["input_ids"])
        for feature_idx in range(feature_count):
            # Tells, token by token, which of the two sequences the token belongs to.
            sequence_ids = encoded.sequence_ids(feature_idx)

            source_example = feature_to_example[feature_idx]
            encoded["example_id"].append(examples["id"][source_example])

            # Offsets outside of the context become None, which makes it easy to tell whether a token position
            # is part of the context.
            context_only_offsets = []
            for position, offset in enumerate(encoded["offset_mapping"][feature_idx]):
                if sequence_ids[position] == context_index:
                    context_only_offsets.append(offset)
                else:
                    context_only_offsets.append(None)
            encoded["offset_mapping"][feature_idx] = context_only_offsets

        return encoded

    eval_examples = raw_datasets["validation"]
    with accelerator.main_process_first():
        eval_dataset = eval_examples.map(
                prepare_validation_features,
                batched=True,
                num_proc=args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
    data_collator = default_data_collator
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
    )

    eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
    eval_dataloader = DataLoader(
        eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
    )
    
    # Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        """
        Convert model predictions into the prediction/reference pair handed to the metric.

        Args:
            examples: The examples being scored. Each one is read for its `"id"` and for its
                `answer_column_name` entry.
            features: The features derived from `examples`; forwarded to `postprocess_qa_predictions`.
            predictions: The `(start_logits, end_logits)` pair expected by `postprocess_qa_predictions`.
            stage (:obj:`str`): Forwarded to `postprocess_qa_predictions` as its `prefix`.

        Returns:
            An `EvalPrediction` whose `predictions` holds one dict per predicted answer and whose
            `label_ids` holds one reference dict per example.
        """
        with_null_answers = args.version_2_with_negative

        # Match the start logits and end logits to answers in the original context.
        answers_by_id = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=with_null_answers,
            n_best_size=args.n_best_size,
            max_answer_length=args.max_answer_length,
            null_score_diff_threshold=args.null_score_diff_threshold,
            output_dir=args.output_dir,
            prefix=stage,
        )

        # The v2 metric format carries one extra field per prediction; the plain format carries none.
        extra_fields = {"no_answer_probability": 0.0} if with_null_answers else {}
        formatted_predictions = [
            {"id": example_id, "prediction_text": answer_text, **extra_fields}
            for example_id, answer_text in answers_by_id.items()
        ]

        references = [
            {"id": example["id"], "answers": example[answer_column_name]}
            for example in examples
        ]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
    metric = load_metric("squad_v2" if args.version_2_with_negative else "squad", trust_remote_code=True)

    # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
    def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
        """
        Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

        Args:
            start_or_end_logits(:obj:`tensor`):
                This is the output predictions of the model. We can only enter either start or end logits.
                It is iterated batch by batch; each batch is indexed as (rows, columns).
            dataset: Evaluation dataset. Only its length is used, to size the returned array.
            max_len(:obj:`int`):
                The maximum length of the output tensor. ( See the model.eval() part for more details )

        Returns:
            :obj:`np.ndarray` of shape `(len(dataset), max_len)` and dtype float64. Cells that never receive a
            logit keep the filler value -100. Rows supplied beyond `len(dataset)` are dropped.
        """
        total_rows = len(dataset)
        # Create a numpy array and fill it with -100.
        logits_concat = np.full((total_rows, max_len), -100, dtype=np.float64)

        # `step` is the index of the next row to populate with the gathered outputs.
        step = 0
        for output_logit in start_or_end_logits:
            remaining = total_rows - step
            if remaining <= 0:
                # Every dataset row is already filled. Stopping here avoids a negative slice bound, which
                # would select real rows and fail to broadcast into an empty destination.
                break

            batch_size = output_logit.shape[0]
            cols = output_logit.shape[1]

            # The last useful batch may carry more rows than are left; keep only what fits.
            rows = min(batch_size, remaining)
            logits_concat[step : step + rows, :cols] = output_logit[:rows]

            step += batch_size

        return logits_concat

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_training_steps = num_epochs * num_update_steps_per_epoch
    num_warmup_steps = args.num_warmup_steps * num_update_steps_per_epoch



    # NOTE(review): the step counts above are computed before `accelerator.prepare` and are not recalculated
    # afterwards, although the size of the training dataloader may change once it has been prepared.
    
    # model
    if args.peft_type == "lora":
        peft_config = LoraConfig(
            task_type=TaskType.QUESTION_ANS, inference_mode=False, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 
            r=rank,
            target_modules=target_modules,
        )
    elif args.peft_type == "adalora":
        peft_config = AdaLoraConfig(
            task_type=TaskType.QUESTION_ANS, inference_mode=False, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 
            target_r=args.target_r,
            tinit=args.tinit, tfinal=args.tfinal, 
            init_r=args.init_r, total_step=num_training_steps,
            orth_reg_weight=args.orth_reg_weight,
            deltaT=args.deltaT,
            target_modules=target_modules
        )
    elif args.peft_type == "rolora":
        peft_config = RoLoraConfig(
            task_type=TaskType.QUESTION_ANS, inference_mode=False, lora_alpha=lora_alpha, lora_dropout=lora_dropout, 
            target_r=args.target_r,
            tinit=0, tfinal=args.tfinal, # To avoid a bug, set the first tinit to 0
            init_r=args.init_r, total_step=num_training_steps,
            orth_reg_weight=args.orth_reg_weight,
            deltaT=args.deltaT,
            target_modules=target_modules
        )
    repeat = args.num_retrain


    model = AutoModelForQuestionAnswering.from_pretrained(
            args.model_name_or_path,
            config=config,
        )
    model = get_peft_model(model, peft_config, adapter_name="default_0")


    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)   
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Scheduler and math around the number of training steps.

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if args.with_tracking:
        run_name = f"run-\
            {args.peft_type}-{args.dataset_name}-{args.r}-\
                {args.lora_alpha}-{args.seed}-{args.learning_rate}-\
                    {args.weight_decay}-{args.max_seq_length}-\
                        {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        # experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers(
            "PEFT Fine-Tuning", config=experiment_config, init_kwargs={"wandb": {"name": run_name}}
        )

    # saving and loading checkpoints for resuming training
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {num_training_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(num_training_steps*repeat), disable=not accelerator.is_local_main_process)
    starting_epoch = 0
    best_metric = None
    resume_step = 0
    iteration_step = 0
    global_step = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        accelerator.load_state(args.resume_from_checkpoint)
        starting_epoch = resume_step // len(train_dataloader)
        resume_step -= starting_epoch * len(train_dataloader)

    for r in range(repeat):
        progress_bar.update(resume_step)

        
        for epoch in range(starting_epoch, num_epochs):

            
            if args.with_tracking:
                total_loss = 0
                running_loss = 0

            for step, batch in enumerate(accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)):
                model.train()
                with accelerator.accumulate(model):
                    outputs = model(**batch)
                    loss = outputs.loss
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        optimizer.step()
                        lr_scheduler.step()

                        if (args.peft_type == 'rolora' and r < repeat-1) or (args.peft_type == 'adalora'):
                            model.update_and_allocate(iteration_step)

                        optimizer.zero_grad()
                        progress_bar.update(1)
                        iteration_step += 1
                        global_step += 1

                if args.with_tracking:
                    step_loss = accelerator.reduce(loss.detach().clone()).item()
                    total_loss += step_loss
                    running_loss += step_loss

                if global_step % args.logging_steps == 0:
                    if args.with_tracking:
                        accelerator.log({"train/running_loss": running_loss / args.logging_steps}, step=global_step)
                        running_loss = 0

                if global_step % args.evaluation_steps == 0:
                    eval_metrics = eval_loop(model, 
                        accelerator, 
                        eval_dataloader, eval_dataset, eval_examples,
                        post_processing_function,
                        metric,
                        create_and_fill_np_array, 
                        args.pad_to_max_length)

                    if args.with_tracking:
                        logger.info(f"Step {iteration_step} eval metrics: {eval_metrics}")
                        accelerator.log(eval_metrics, step=global_step)
                    if best_metric is None or eval_metrics[metric_name] > best_metric:
                        best_metric = eval_metrics[metric_name]
                        accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
                        with open(os.path.join(args.output_dir, f"all_results_{r}.json"), "w") as f:
                            json.dump(eval_metrics, f)
                    

            if args.with_tracking:
                train_epoch_loss = total_loss / (step + 1)
                logger.info(f"Epoch {epoch} train loss: {train_epoch_loss}")
                accelerator.log({"epoch/train_loss": train_epoch_loss}, step=epoch)


            print("==============END OF EPOCH================")
            eval_metrics = eval_loop(model, 
                accelerator, 
                eval_dataloader, eval_dataset, eval_examples,
                post_processing_function,
                metric,
                create_and_fill_np_array, 
                args.pad_to_max_length)
            print(best_metric, eval_metrics[metric_name])

            if best_metric is None or eval_metrics[metric_name] > best_metric:
                best_metric = eval_metrics[metric_name]
                accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
                with open(os.path.join(args.output_dir, f"all_results_{r}.json"), "w") as f:
                    json.dump(eval_metrics, f)
            
            with open(os.path.join(args.output_dir, "metric"), "a+") as f:
                f.write(str(r) + " " + str(epoch) + "\n")
                f.write(str(eval_metrics[metric_name]) + "\n")
    
        if args.peft_type == 'rolora' and r < repeat-1:

            if args.load_best_model:
                # load the best model
                accelerator.load_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
            if r == repeat-2:
                orth_reg_weight = 0
            else:  
                orth_reg_weight = args.orth_reg_weight

            adapter_name = "default_0"
            RoLoraModel.extend_modules(model, adapter_name, r, repeat)
            model.peft_config[adapter_name].orth_reg_weight = orth_reg_weight
            
            starting_epoch = 0
            resume_step = 0
            iteration_step = 0

            # Optimizer
            # Split weights in two groups, one with weight decay and the other not.
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]

            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

            lr_scheduler = get_scheduler(
                "linear",
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
            
            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )

            accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r+1}"))
        # after each iteration, reinitialize  
        elif r < repeat-1:
            if args.load_best_model:
                # load the best model
                accelerator.load_state(os.path.join(args.output_dir, f"best_checkpoint_{r}"))
                
            model.peft_config["default_0"].inference_mode = False 
            for n, p in model.named_parameters():
                if "default_0" in n and "lora" in n:
                    p.requires_grad = True
            starting_epoch = 0
            resume_step = 0
            iteration_step = 0

            # Optimizer
            # Split weights in two groups, one with weight decay and the other not.
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

            # Scheduler and math around the number of training steps.
            num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

            lr_scheduler = get_scheduler(
                name=args.lr_scheduler_type,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )

            model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
            )
            accelerator.save_state(os.path.join(args.output_dir, f"best_checkpoint_{r+1}"))

    
    if args.load_best_model:
        # load the best model
        accelerator.load_state(os.path.join(args.output_dir, f"best_checkpoint_{repeat-1}"))

        eval_metrics = eval_loop(model, 
                accelerator, 
                eval_dataloader, eval_dataset, eval_examples,
                post_processing_function,
                metric,
                create_and_fill_np_array, 
                args.pad_to_max_length)
        
        if args.with_tracking:
            best_metrics = {"best_" + k: v for k, v in eval_metrics.items()}
            accelerator.log(best_metrics, step=global_step)

    # accelerator.wait_for_everyone()
    # unwrapped_model = accelerator.unwrap_model(model)
    # unwrapped_model.save_pretrained(args.output_dir, is_main_process=accelerator.is_main_process)

    with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
        json.dump(eval_metrics, f)
        print(eval_metrics)


# Script entry point: call `main()` only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
