# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import math
import os
from contextlib import nullcontext
from pathlib import Path

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator, DistributedType
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler


# NOTE(review): not referenced by any code in this script; confirm it is still needed before relying on it.
MAX_GPU_BATCH_SIZE = 16
# Batch size of the validation DataLoader built in `get_dataloaders`.
EVAL_BATCH_SIZE = 32


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"):
    """
    Build the training and validation `DataLoader`s for the GLUE MRPC dataset.

    Args:
        accelerator (`Accelerator`):
            Only its `distributed_type` is read, to choose the padding strategy used when collating a batch.
        batch_size (`int`, *optional*, defaults to 16):
            Batch size of the training DataLoader. The validation DataLoader always uses `EVAL_BATCH_SIZE`.
        model_name (`str`, *optional*, defaults to `"bert-base-cased"`):
            Name or path handed to `AutoTokenizer.from_pretrained`.

    Returns:
        `tuple`: `(train_dataloader, eval_dataloader)`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    raw_datasets = load_dataset("glue", "mrpc")

    def encode_pairs(examples):
        # Leaving max_length unset (None) makes truncation fall back to the model's maximum length.
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    # Tokenize every split, drop the raw text columns, and expose the targets under the `labels` name that the
    # transformers models expect.
    tokenized = raw_datasets.map(
        encode_pairs,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
        load_from_cache_file=False,
    ).rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU every batch is padded to one fixed length, otherwise training would be very slow; everywhere else
        # a batch is only padded up to its longest example.
        if accelerator.distributed_type == DistributedType.XLA:
            pad_kwargs = {"padding": "max_length", "max_length": 128}
        else:
            pad_kwargs = {"padding": "longest"}
        return tokenizer.pad(examples, return_tensors="pt", **pad_kwargs)

    train_dataloader = DataLoader(
        tokenized["train"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    eval_dataloader = DataLoader(
        tokenized["validation"],
        batch_size=EVAL_BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
    )
    return train_dataloader, eval_dataloader


def training_function(config, args):
    """
    Fine-tune a sequence classification model on GLUE MRPC and sanity-check the run end to end.

    Besides training and evaluating, the function asserts that the learning rate has the expected value after the
    first optimizer step and is zero at the end (linear decay scheduler without mixed precision only), optionally
    enforces a lower bound on the best evaluation accuracy, writes the per-epoch accuracies to `all_results.json`
    inside `args.output_dir`, and checks that `Accelerator.save_model` produced a weights file.

    Args:
        config (`dict`):
            Hyper-parameters; the keys `"lr"`, `"num_epochs"`, `"seed"` and `"batch_size"` are read.
        args (`argparse.Namespace`):
            Parsed command line; the attributes `model_name_or_path`, `output_dir`, `performance_lower_bound`,
            `add_pad_token`, `tp_plan` and `tp_size` are read.

    Raises:
        AssertionError: If a learning rate check fails, if the best accuracy is below
            `args.performance_lower_bound`, or if the saved weights file is missing.
    """
    accelerator_kwargs = {}
    # need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
    if args.tp_size is not None:
        accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)

    # Initialize accelerator
    accelerator = Accelerator(**accelerator_kwargs)

    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])
    model_name = args.model_name_or_path

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name)

    # Add TP related kwargs if provided
    model_kwargs = {}
    if args.tp_plan is not None:
        model_kwargs["tp_plan"] = args.tp_plan
    if args.tp_size is not None:
        model_kwargs["tp_size"] = args.tp_size

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, **model_kwargs)

    if args.add_pad_token:
        if model.config.pad_token_id is None:
            model.config.pad_token_id = 0

    # Instantiate optimizer: a real AdamW unless the DeepSpeed config defines its own optimizer, in which case a
    # placeholder is handed to `accelerator.prepare`.
    optimizer_cls = (
        AdamW
        if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
        else DummyOptim
    )
    optimizer = optimizer_cls(params=model.parameters(), lr=lr)

    max_training_steps = len(train_dataloader) * num_epochs

    # Instantiate scheduler; `linear_decay_scheduler` records whether the learning rate checks below apply.
    linear_decay_scheduler = False
    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=max_training_steps,
        )
        linear_decay_scheduler = True
    else:
        lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0)

    # Prepare everything
    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
    # prepare method.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We also need to keep track of the starting epoch so files are named properly
    starting_epoch = 0

    # The context manager wrapped around `optimizer.step()` only depends on `args.tp_plan`, so it is selected once
    # here rather than on every training step.
    optimizer_step_context = nullcontext
    if args.tp_plan is not None:
        from torch.distributed._tensor.experimental import implicit_replication

        optimizer_step_context = implicit_replication

    # Now we train the model
    metric = evaluate.load("glue", "mrpc")
    best_performance = 0
    performance_metric = {}
    # Learning rate expected after one optimizer step of the linear decay, with the total number of steps scaled
    # down by the number of processes and the gradient accumulation steps.
    expected_lr_after_first_optim_step = lr * (
        1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps)
    )
    lr_scheduler_check_completed = False
    for epoch in range(starting_epoch, num_epochs):
        model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                with optimizer_step_context():
                    optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # assert the learning rate after first optimizer step
                if (
                    accelerator.sync_gradients
                    and not lr_scheduler_check_completed
                    and linear_decay_scheduler
                    and accelerator.state.mixed_precision == "no"
                ):
                    current_lr = lr_scheduler.get_last_lr()[0]
                    # Both values come out of floating point arithmetic, so compare with a tight relative
                    # tolerance instead of exact equality; a wrong step count is off by far more than that.
                    assert math.isclose(current_lr, expected_lr_after_first_optim_step, rel_tol=1e-9), (
                        f"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {current_lr}"
                    )
                    lr_scheduler_check_completed = True

        model.eval()
        samples_seen = 0
        for step, batch in enumerate(eval_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # It is slightly faster to call this once, than multiple times
            predictions, references = accelerator.gather(
                (predictions, batch["labels"])
            )  # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.use_distributed:
                if step == len(eval_dataloader) - 1:
                    # Keep only the samples of the last batch that have not been counted yet.
                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                    references = references[: len(eval_dataloader.dataset) - samples_seen]
                else:
                    samples_seen += references.shape[0]
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()
        # Use accelerator.print to print only on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)
        performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"]

        if best_performance < eval_metric["accuracy"]:
            best_performance = eval_metric["accuracy"]

    # check that the LR is 0
    if linear_decay_scheduler and accelerator.state.mixed_precision == "no":
        assert lr_scheduler.get_last_lr()[0] == 0, (
            f"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}"
        )

    if args.performance_lower_bound is not None:
        assert args.performance_lower_bound <= best_performance, (
            f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}"
        )

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # Make sure the output directory exists: opening the results file inside a missing directory would raise
        # `FileNotFoundError`.
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump(performance_metric, f)

    # TODO: skip saving of the model test for TP until the feature lands
    if args.tp_plan is None:
        # Finally try saving the model
        accelerator.save_model(model, args.output_dir)
    accelerator.wait_for_everyone()
    if args.tp_plan is None:
        assert Path(args.output_dir, SAFE_WEIGHTS_NAME).exists(), (
            "Model was not saved when calling `Accelerator.save_model`"
        )
    accelerator.end_training()


def _str_to_bool(value):
    """
    Convert a command line string to a boolean.

    `argparse` with `type=bool` calls `bool()` on the raw string, which is `True` for every non-empty value,
    including `"False"`. This converter maps the usual spellings explicitly instead.

    Args:
        value (`str` or `bool`):
            The raw command line value. A `bool` is returned unchanged.

    Returns:
        `bool`: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main():
    """
    Parse the command line arguments and run `training_function`.

    The learning rate (2e-5), seed (42) and batch size (16) are fixed here; only the number of epochs comes from
    the command line.
    """
    parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bert-base-cased",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--performance_lower_bound",
        type=float,
        default=None,
        help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=3,
        help="Number of train epochs.",
    )
    parser.add_argument(
        "--add_pad_token",
        # Not `type=bool`: that would turn `--add_pad_token False` into True.
        type=_str_to_bool,
        default=False,
        help="To add pad token if not exists.",
    )
    parser.add_argument(
        "--tp_plan",
        type=str,
        default=None,
        help="pass 'auto' to use TP",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help="TP size to be used to shard the model",
    )
    args = parser.parse_args()
    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
    training_function(config, args)


# Run the script's entry point only when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
