import json
import logging
import os
import random
import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, NamedTuple

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
import torch.nn.functional as F
from tqdm.auto import tqdm, trange

from transformers.data.data_collator import DataCollator, DefaultDataCollator
from transformers.modeling_utils import PreTrainedModel

import transformers
from transformers.optimization import (
    AdamW,
    get_linear_schedule_with_warmup,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    EvalPrediction,
    PredictionOutput,
    TrainOutput,
)
from transformers.training_args import TrainingArguments, is_tpu_available

from transformers import Trainer

import dataclasses
from dataclasses import dataclass, field
from copy import deepcopy

from lang_exps.data.processors.data import output_modes, tasks_num_labels
from lang_exps.trainer.exemplars import ExemplarHandler, ReplayDataset
from lang_exps.trainer.sam import SAM

try:
    from apex import amp

    _has_apex = True
except ImportError:
    _has_apex = False


def is_apex_available():
    """Return True if ``apex.amp`` was importable when this module loaded."""
    apex_importable = _has_apex
    return apex_importable


try:
    from torch.utils.tensorboard import SummaryWriter

    _has_tensorboard = True
except ImportError:
    try:
        from tensorboardX import SummaryWriter

        _has_tensorboard = True
    except ImportError:
        _has_tensorboard = False


def is_tensorboard_available():
    """Return True if a ``SummaryWriter`` (torch or tensorboardX) was importable."""
    tensorboard_importable = _has_tensorboard
    return tensorboard_importable


try:
    import wandb

    _has_wandb = True
except ImportError:
    _has_wandb = False


def is_wandb_available():
    """Return True if ``wandb`` was importable when this module loaded."""
    wandb_importable = _has_wandb
    return wandb_importable


logger = logging.getLogger(__name__)


def set_seed(seed: int):
    """Seed every RNG the trainer relies on with the same value.

    Args:
        seed: passed unchanged to Python's ``random``, NumPy, PyTorch's CPU
            generator and the generators of all CUDA devices.
    """
    # Same order as before: stdlib, NumPy, torch CPU, then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used
    to compute metrics.

    NOTE: this definition shadows the ``EvalPrediction`` imported from
    ``transformers.trainer_utils`` at the top of the module; every later
    reference in this file resolves to this version, which carries an extra
    optional ``guids`` field.
    """

    predictions: np.ndarray
    label_ids: np.ndarray
    # Optional third field; None unless the caller supplies it
    # (presumably per-example identifiers — confirm against callers).
    guids: Optional[np.ndarray] = None


@dataclass
class CLTrainingArguments(TrainingArguments):
    """Continual-learning extension of ``transformers.TrainingArguments``.

    Re-declares a few inherited options with project-specific defaults and adds
    settings for experience replay, online EWC, L2 regularisation, supervised
    contrastive loss and the SAM/ASAM optimizers. Field order matters for the
    dataclass, so new fields must be appended rather than inserted.
    """

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_val: bool = field(
        default=False, metadata={"help": "Whether to run evaluation on the dev set."}
    )
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run evaluation on the test set."}
    )
    do_replay_train: bool = field(
        default=False, metadata={"help": "Whether to use replay to run few epochs."}
    )

    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )
    warmup_ratio: float = field(
        default=0.1,
        metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."},
    )
    logging_first_step: bool = field(
        default=False, metadata={"help": "Log and eval the first global_step"}
    )
    logging_steps: int = field(
        default=100, metadata={"help": "Log every X updates steps."}
    )
    num_loggings: int = field(
        default=0, metadata={"help": "Total amount of evaluations in training."}
    )
    save_steps: int = field(
        default=500, metadata={"help": "Save checkpoint every X updates steps."}
    )

    learning_rate: float = field(
        default=2e-5, metadata={"help": "The initial learning rate for Adam."}
    )
    weight_decay: float = field(
        default=0.0, metadata={"help": "Weight decay if we apply some."}
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}
    )
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    max_steps: int = field(
        default=-1,
        metadata={
            "help": "If > 0: set total number of training steps to perform. Override num_train_epochs."
        },
    )
    num_train_epochs: float = field(
        default=3.0, metadata={"help": "Total number of training epochs to perform."}
    )

    # --- Replay / exemplar memory ---
    write_rate: float = field(
        default=0.0, metadata={"help": "how frequently to write examples to memory."}
    )
    replay_rate: float = field(
        default=0.0, metadata={"help": "how frequently to replay examples from memory."}
    )
    # "yes"/"no" string flag (compared against "yes" in the training loop).
    sparse_replay: str = field(
        default="no", metadata={"help": "Whether to sparse replay examples."}
    )
    replay_only_fraction: str = field(
        default="no",
        metadata={"help": "Whether to replay only fraction of iterations!"},
    )
    log_active_gates: bool = field(
        default=False, metadata={"help": "Whether to log gate distribution."}
    )
    min_examples_per_class: int = field(
        default=1, metadata={"help": "Minimum examples per class."}
    )
    max_examples_per_class: int = field(
        default=-1, metadata={"help": "Maximum examples per class."}
    )
    patience: int = field(default=-1, metadata={"help": "Patience for early stopping."})
    write_strategy: str = field(
        default="random", metadata={"help": "Choices: random, mof, kmeans"}
    )

    reinit_head: bool = field(
        default=False,
        metadata={"help": "Whether to re-initialize the classifier head."},
    )
    num_train_head_epochs: float = field(
        default=1.0, metadata={"help": "Total number of training epochs to perform."}
    )
    finetuning_task: Optional[str] = field(
        default=None, metadata={"help": "Fine-tuning task!"}
    )
    replay_tiny: bool = field(
        default=False, metadata={"help": "Replay in Multi-task fashion."}
    )
    replay_tiny_type: str = field(
        default="unweighted", metadata={"help": "choices: ['unweighted', 'gradnorm']"}
    )
    enable_mtl_baseline: bool = field(
        default=False, metadata={"help": "Enable MTL equivalent baseline."}
    )
    debug: bool = field(default=False, metadata={"help": "Enable debug mode."})
    freeze_embeddings: bool = field(
        default=False,
        metadata={"help": "Freeze pre-trained/ random initialized embeddings."},
    )
    log_grad_norm: bool = field(
        default=False, metadata={"help": "Whether to log grad norm!"}
    )
    skip_replay: bool = field(
        default=False, metadata={"help": "Whether to skip replay!"}
    )

    # --- Online EWC ---
    ewc: bool = field(default=False, metadata={"help": "Whether to enable online EWC."})
    gamma: float = field(
        default=0.98,
        metadata={
            "help": "hyperparam (online EWC): decay-term for old tasks' contribution to quadratic term"
        },
    )
    # Multiplies the EWC penalty before it is back-propagated in the training loop.
    lmbda: float = field(
        default=10000.0,
        metadata={
            "help": "hyperparam (online EWC): weight of the EWC penalty added to the loss"
        },
    )
    fisher_n: int = field(
        default=1001,
        metadata={
            "help": "max number of samples for computing fisher information. A value of -1 means use all examples."
        },
    )

    # --- Other regularisers / objectives ---
    enable_l2: bool = field(
        default=False, metadata={"help": "Whether to enable l2 reg."}
    )
    l2_weight: float = field(default=0.01, metadata={"help": "hyperparam for l2 reg."})
    supconloss: bool = field(
        default=False,
        metadata={"help": "Whether to enable supervised contrastive loss."},
    )
    lmbda_sc: float = field(
        default=0.5, metadata={"help": "hyperparam for supervised contrastive."}
    )

    # --- Optimisation ---
    optimizer: str = field(
        default="adam", metadata={"help": "Choices: adam, sam, asam"},
    )
    rho: float = field(
        default=0.05, metadata={"help": "rho hyperparameter for the SAM/ASAM optimizers."}
    )
    # When True, train() obtains only an optimizer (no LR scheduler).
    disable_scheduler: bool = field(
        default=False,
        metadata={"help": "Whether to disable the learning-rate scheduler."},
    )
    batch_mode: bool = field(
        default=False, metadata={"help": "Whether to enable batch mode."}
    )


class TrainerReplay:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.

    This variant adds continual-learning machinery: experience replay from an
    ``ExemplarHandler`` memory, online-EWC regularisation and SAM/ASAM
    optimizer steps, all configured through ``CLTrainingArguments``.
    """

    model: PreTrainedModel
    args: CLTrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    val_dataset: Optional[Dataset]
    # NOTE(review): __init__ stores ``test_dataset``, not ``eval_dataset``;
    # this annotation appears to be left over from the upstream Trainer.
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    # Class-level default: __init__ only assigns ``self.tb_writer`` when a writer
    # is passed in or one is created, so this None is the fallback otherwise.
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Optional[
        Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
    ] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        task: Optional[str] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[
            torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR
        ] = None,
        adapter_fn: bool = False,
        exemplarHandler: ExemplarHandler = None,
        task_metric: str = "acc",
        est_fisher_info=None,
        est_mean_prev_task=None,
        task_idx=0,
    ):
        """Set up a replay-capable trainer.

        Stores the model, arguments, datasets and continual-learning state, then
        performs start-up side effects: it may create a TensorBoard writer, logs
        dependency warnings, seeds all RNGs with ``args.seed`` and creates
        ``args.output_dir`` on the local master.

        Args:
            model: model to train.
            args: training arguments; ``local_rank``, ``logging_dir``, ``seed``
                and ``output_dir`` are read here.
            task: name of the current task.
            data_collator: batch collator; ``DefaultDataCollator()`` when None.
            train_dataset: default dataset for ``get_train_dataloader``.
            val_dataset: default dataset for ``get_val_dataloader``.
            test_dataset: default dataset for ``get_test_dataloader``.
            compute_metrics: metric callback taking an ``EvalPrediction``.
            prediction_loss_only: stored unchanged.
            tb_writer: TensorBoard writer. When None and TensorBoard is importable,
                one is created on local rank -1/0 using ``args.logging_dir``.
            optimizers: optional (optimizer, scheduler) pair, stored unchanged.
            adapter_fn: stored unchanged.
            exemplarHandler: replay memory read by the ``do_replay*`` methods.
            task_metric: metric name, ``"acc"`` by default.
            est_fisher_info: EWC Fisher estimate from earlier tasks (may be None).
            est_mean_prev_task: EWC parameter means from earlier tasks (may be None).
            task_idx: index of the current task; the training loop only applies
                EWC when it is greater than 0.
        """
        # Core objects and task identity.
        self.model = model
        self.args = args
        self.data_collator = (
            DefaultDataCollator() if data_collator is None else data_collator
        )
        self.task = task
        self.task_idx = task_idx
        self.task_metric = task_metric

        # Datasets.
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

        # Evaluation / optimisation configuration.
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        self.adapter_fn = adapter_fn

        # Continual-learning state: replay memory and EWC statistics.
        self.exemplarHandler = exemplarHandler
        self.est_fisher_info = est_fisher_info
        self.est_mean_prev_task = est_mean_prev_task

        # Best-checkpoint bookkeeping and per-epoch logs.
        self.best_model = None
        self.best_epoch_optimizer = None
        self.best_epoch_scheduler = None
        self.best_score = 0.0
        self.best_epoch = 0
        self.best_global_step = 0
        self.all_logs = {}

        # Build a default writer on the local master only when none was given;
        # otherwise the class-level ``tb_writer = None`` remains in effect.
        if (
            tb_writer is None
            and is_tensorboard_available()
            and self.args.local_rank in [-1, 0]
        ):
            logger.info(f"Setting tensorboard logging dir to {self.args.logging_dir}")
            tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if tb_writer is not None:
            self.tb_writer = tb_writer

        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if not is_wandb_available():
            logger.info(
                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
            )

        set_seed(self.args.seed)
        # Only the local master creates the output directory.
        if self.is_local_master():
            os.makedirs(self.args.output_dir, exist_ok=True)

    def get_train_dataloader(self, train_dataset=None, batch_size=None) -> DataLoader:
        """Build a shuffling DataLoader for training.

        Args:
            train_dataset: dataset to load; falls back to ``self.train_dataset``.
            batch_size: batch size; falls back to ``args.train_batch_size``.

        Returns:
            DataLoader with a ``RandomSampler`` (single process) or a
            ``DistributedSampler`` (``args.local_rank != -1``).

        Raises:
            ValueError: if no dataset is given and ``self.train_dataset`` is None.
        """
        dataset = self.train_dataset if train_dataset is None else train_dataset
        if dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        if self.args.local_rank == -1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        # The sampler is built before the batch size is resolved (same order as before).
        effective_batch_size = (
            self.args.train_batch_size if batch_size is None else batch_size
        )
        return DataLoader(
            dataset,
            batch_size=effective_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_val_dataloader(self, val_dataset: Optional[Dataset] = None) -> DataLoader:
        """Build a deterministic (unshuffled) DataLoader over the validation set.

        Args:
            val_dataset: dataset to load; falls back to ``self.val_dataset``.

        Returns:
            DataLoader with no sampler, ``args.eval_batch_size`` and the
            trainer's collator.

        Raises:
            ValueError: if no dataset is given and ``self.val_dataset`` is None.
        """
        dataset = self.val_dataset if val_dataset is None else val_dataset
        if dataset is None:
            raise ValueError("Trainer: evaluation requires an val_dataset.")

        return DataLoader(
            dataset,
            sampler=None,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_test_dataloader(self, test_dataset: Optional[Dataset] = None) -> DataLoader:
        """Build a deterministic (unshuffled) DataLoader over the test set.

        Uses the same batch size as validation (``args.eval_batch_size``).

        Args:
            test_dataset: dataset to load; falls back to ``self.test_dataset``.

        Returns:
            DataLoader with no sampler and the trainer's collator.

        Raises:
            ValueError: if no dataset is given and ``self.test_dataset`` is None.
        """
        dataset = test_dataset if test_dataset is not None else self.test_dataset
        if dataset is None:
            raise ValueError("Trainer: evaluation requires a test_dataset.")

        return DataLoader(
            dataset,
            sampler=None,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_replay_dataloader(self, replay_dataset=None) -> DataLoader:
        """Build a shuffling DataLoader over a batch of replayed exemplars.

        Args:
            replay_dataset: dataset read from the exemplar memory (required).

        Returns:
            DataLoader with ``args.train_batch_size`` and a ``RandomSampler``
            (single process) or ``DistributedSampler`` (``local_rank != -1``).

        Raises:
            ValueError: if ``replay_dataset`` is None.
        """
        if replay_dataset is None:
            raise ValueError("Trainer: training requires a replay_dataset.")

        sampler_cls = (
            DistributedSampler if self.args.local_rank != -1 else RandomSampler
        )
        # Built before train_batch_size is read, matching the original order.
        sampler = sampler_cls(replay_dataset)

        return DataLoader(
            replay_dataset,
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator.collate_batch,
        )

    def get_replay_dataloader_v2(self, replay_dataset=None) -> DataLoader:
        """Build an order-preserving DataLoader over replayed exemplars.

        Unlike ``get_replay_dataloader`` there is no sampler and no shuffling, so
        batches follow dataset order. ``do_replay_multiple_tasks_batch_mode``
        relies on this to align each batch with its slice of task labels.

        Args:
            replay_dataset: dataset read from the exemplar memory (required).

        Returns:
            DataLoader with ``args.train_batch_size`` and the trainer's collator.

        Raises:
            ValueError: if ``replay_dataset`` is None.
        """
        if replay_dataset is None:
            raise ValueError("Trainer: training requires a replay_dataset.")

        loader_kwargs = dict(
            batch_size=self.args.train_batch_size,
            sampler=None,
            shuffle=False,
            collate_fn=self.data_collator.collate_batch,
        )
        return DataLoader(replay_dataset, **loader_kwargs)

    def do_replay(self, model, optimizer, update_replay=True):
        """Run one read of the exemplar memory through the model.

        Reads ``args.train_batch_size`` exemplars via ``self.exemplarHandler.read``
        and calls ``self._training_step`` on every batch of them.

        Args:
            model: model to run.
            optimizer: stepped after each batch when ``update_replay`` is true.
            update_replay: when true, each batch is back-propagated
                (``do_backward=True``), gradients are clipped to
                ``args.max_grad_norm``, ``optimizer`` is stepped and the grads are
                zeroed. When false, losses are only collected (``do_backward=False``).

        Returns:
            ``(None, None, None)`` if the memory is empty. Otherwise, with
            ``update_replay``: ``(mean loss, mean gradient norm as returned by
            clip_grad_norm_, replay_task)``; without it: ``(list of losses,
            list of supervised-contrastive losses or None, replay_task)``.
        """
        replay_dataset, replay_task = self.exemplarHandler.read(
            n_examples_replay=self.args.train_batch_size
        )
        if replay_dataset is None:  # Memory is empty so no-op
            return None, None, None

        batches = tqdm(
            self.get_replay_dataloader(replay_dataset=replay_dataset),
            desc="Iteration",
            disable=True,
        )

        losses, sc_losses, grad_norms = [], [], []
        for batch in batches:
            # With supconloss the step yields (loss, supcon_loss); otherwise a loss.
            result = self._training_step(
                model, batch, do_backward=bool(update_replay)
            )
            if self.args.supconloss:
                batch_loss, batch_sc_loss = result
                losses.append(batch_loss)
                sc_losses.append(batch_sc_loss)
            else:
                losses.append(result)

            if not update_replay:
                continue
            grad_norms.append(
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.args.max_grad_norm
                ).item()
            )
            optimizer.step()
            model.zero_grad()

        if not update_replay:
            return losses, (sc_losses if self.args.supconloss else None), replay_task
        return np.mean(losses), np.mean(grad_norms), replay_task

    def do_replay_multiple_tasks(self, model, replay_datasets=None):
        """Accumulate replay gradients over per-task exemplar datasets.

        For each task, every batch of its exemplars goes through
        ``self._training_step(..., task=task, replay=True, do_backward=False)``.
        The resulting loss is divided by the total number of replayed examples
        (across all tasks) and back-propagated. No optimizer step is taken here.

        Args:
            model: model passed to ``self._training_step``.
            replay_datasets: mapping ``task -> dataset``. When None it is read
                from ``self.exemplarHandler.read_multiple_tasks`` with
                ``n_examples_replay=args.train_batch_size``.

        Returns:
            ``(avg_loss, n_examples, task_loss, task_examples)``:
            ``task_loss[task]`` is the sum of that task's batch losses,
            ``task_examples[task]`` is its dataset length, ``n_examples`` is the
            total length, and ``avg_loss`` is the summed loss divided by
            ``n_examples``. Returns four ``None`` when there is nothing to replay
            (no memory, or no stored examples).
        """
        if replay_datasets is None:
            replay_datasets = self.exemplarHandler.read_multiple_tasks(
                n_examples_replay=self.args.train_batch_size
            )

        if replay_datasets is None:  # Memory is empty so no-op
            return None, None, None, None

        task_examples = {
            replay_task: len(replay_datasets[replay_task])
            for replay_task in replay_datasets
        }
        n_examples = sum(task_examples.values())
        # An empty mapping (or only empty datasets) has nothing to replay. Treat
        # it like an empty memory instead of failing on the division below.
        if n_examples == 0:
            return None, None, None, None

        task_loss = {}
        for replay_task in replay_datasets:
            task_loss[replay_task] = 0.0
            if task_examples[replay_task] == 0:
                # Nothing to replay for this task (RandomSampler rejects empty datasets).
                continue

            replay_dataloader = self.get_replay_dataloader(
                replay_dataset=replay_datasets[replay_task]
            )
            for inputs in replay_dataloader:
                loss = self._training_step(
                    model, inputs, task=replay_task, replay=True, do_backward=False
                )
                task_loss[replay_task] += loss.item()
                # Normalise by the example count across ALL tasks so the accumulated
                # gradient is an average over everything replayed.
                # NOTE(review): assumes ``loss`` is summed over the batch — confirm.
                (loss / n_examples).backward()

        tr_loss = sum(task_loss.values()) / n_examples
        return tr_loss, n_examples, task_loss, task_examples

    def do_replay_multiple_tasks_batch_mode(self, model,
                                            replay_datasets=None,
                                            replay_tasks=None):
        """Accumulate replay gradients over one mixed-task exemplar dataset.

        The exemplars of all tasks are stored in a single dataset, and
        ``replay_tasks`` holds one task label per example in dataset order. The
        loader from ``get_replay_dataloader_v2`` does not shuffle, so the i-th
        batch covers the examples ``replay_tasks[start:end]``. Those labels are
        passed to the model as ``task``. No optimizer step is taken here.

        Args:
            model: model called as ``model(**inputs, task=..., replay=False)``.
            replay_datasets: mixed-task dataset. When None, it and
                ``replay_tasks`` are read from
                ``self.exemplarHandler.read_multiple_tasks_batch_mode``.
            replay_tasks: per-example task labels aligned with ``replay_datasets``.

        Returns:
            ``(avg_loss, n_examples, None, None)`` where ``avg_loss`` is the sum
            of batch losses divided by ``len(replay_tasks)``. Returns four
            ``None`` when there is nothing to replay.
        """
        if replay_datasets is None:
            replay_datasets, replay_tasks = self.exemplarHandler.read_multiple_tasks_batch_mode(
                n_examples_replay=self.args.train_batch_size
            )

        if replay_datasets is None:  # Memory is empty so no-op
            return None, None, None, None

        n_examples = len(replay_tasks)
        if n_examples == 0:  # nothing stored; avoid dividing by zero below
            return None, None, None, None

        replay_dataloader = self.get_replay_dataloader_v2(replay_dataset=replay_datasets)

        start = 0
        tr_loss = 0.0

        for inputs in replay_dataloader:

            model.train()
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            n_examples_batch = len(inputs['input_ids'])
            end = start + n_examples_batch

            outputs = model(**inputs, task=replay_tasks[start:end], replay=False)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            tr_loss += loss.item()
            # NOTE(review): normalised per batch here, whereas
            # do_replay_multiple_tasks divides by the total example count. The two
            # modes only agree when everything fits in one batch — confirm intent.
            loss = loss / n_examples_batch
            loss.backward()

            # Advance the label window. Without this, every batch after the first
            # would reuse the task labels of the first batch.
            start = end

        tr_loss /= n_examples

        return tr_loss, n_examples, None, None

    def get_grad_norm(self, loss, model, named_model_params, task=None, num_layers=12):
        """Compute gradient norms of ``loss`` grouped by model component.

        Gradients are taken with ``torch.autograd.grad(..., retain_graph=True,
        allow_unused=True)``, so the graph stays usable and parameters outside
        it are ignored. Each parameter is assigned to the first component whose
        name is a substring of its name. The candidates are the embedding
        tables, ``layer.0.`` ... ``layer.{num_layers - 1}.``, ``pooler``, and the
        classifier.

        Args:
            loss: scalar tensor to differentiate.
            model: module whose ``named_parameters()`` provide the tensors.
            named_model_params: names of the parameters to differentiate w.r.t.
            task: when given, the classifier component is ``"classifier_{task}"``;
                otherwise ``"classifier"``.
            num_layers: number of encoder layers to group separately. The default
                of 12 matches the previously hard-coded list.

        Returns:
            ``(grad_norms, unit_norm, unit_grad_norms)``:

            * ``grad_norms`` maps each matched component to the L2 norm of its
              gradients. It also holds ``"whole"`` (all non-classifier gradients)
              and ``"whole_w_classifier"`` (all gradients).
            * ``unit_norm`` maps each non-classifier parameter that received a
              gradient to that gradient divided by ``grad_norms["whole"]``.
            * ``unit_grad_norms`` maps component -> {parameter: gradient divided
              by that component's norm}.
        """
        named_params = dict(model.named_parameters())
        params = [named_params[name] for name in named_model_params]

        task_grad = torch.autograd.grad(
            loss, params, retain_graph=True, allow_unused=True
        )
        # allow_unused=True yields None for parameters outside the graph; drop them.
        param_grad = {
            n: grad.detach().clone()
            for n, grad in zip(named_model_params, task_grad)
            if grad is not None
        }

        names = ["word_embeddings", "position_embeddings", "token_type_embeddings"]
        # Trailing "." keeps "layer.1." from matching "layer.10." etc.
        names += ["layer.{}.".format(i) for i in range(num_layers)]
        names.append("pooler")
        names.append("classifier_{}".format(task) if task is not None else "classifier")

        def get_name(param):
            # First matching substring wins, so the order of ``names`` matters.
            for name in names:
                if name in param:
                    return name
            return None

        grad_norms = {}
        unit_grad_norms = {}
        for param, grad in param_grad.items():
            name = get_name(param)
            if name is None:
                continue
            grad_norms.setdefault(name, []).append(torch.norm(grad, 2))
            unit_grad_norms.setdefault(name, {})[param] = grad

        # Every group above received at least one entry, so stacking is safe.
        for name in grad_norms:
            group_norm = torch.norm(torch.stack(grad_norms[name]), 2)
            grad_norms[name] = group_norm
            for _param in unit_grad_norms[name]:
                unit_grad_norms[name][_param] = unit_grad_norms[name][_param] / group_norm

        grad_norms["whole"] = torch.norm(
            torch.stack(
                [
                    torch.norm(grad, 2)
                    for param, grad in param_grad.items()
                    if "classifier" not in param
                ]
            ),
            2,
        )
        grad_norms["whole_w_classifier"] = torch.norm(
            torch.stack([torch.norm(grad, 2) for grad in param_grad.values()]), 2
        )

        # Iterate param_grad rather than named_model_params. Parameters whose
        # gradient was None were dropped above, and indexing them raised KeyError.
        unit_norm = {
            param: grad / grad_norms["whole"]
            for param, grad in param_grad.items()
            if "classifier" not in param
        }

        return grad_norms, unit_norm, unit_grad_norms

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.
        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps
                // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                + 1
            )
        else:
            t_total = int(
                len(train_dataloader)
                // self.args.gradient_accumulation_steps
                * self.args.num_train_epochs
            )
            num_train_epochs = self.args.num_train_epochs


        if self.args.disable_scheduler:
            optimizer = self.get_optimizers(num_training_steps=t_total)
        else:
            optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # If we specify the num_logging...update the task-specific logging based upon total updates for the required task
        if self.args.num_loggings > 0:
            logging_steps = t_total // self.args.num_loggings
        else:
            logging_steps = self.args.logging_steps

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"))
            )
            if not self.args.disable_scheduler:
                scheduler.load_state_dict(
                    torch.load(os.path.join(model_path, "scheduler.pt"))
                )

        model = self.model
        model.to(self.args.device)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
        if is_wandb_available():
            self._setup_wandb()

        # Train!
        num_examples = len(train_dataloader.dataset)
        total_train_batch_size = (
            self.args.train_batch_size
            * self.args.gradient_accumulation_steps
            * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
        )

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info(
            "  Instantaneous batch size per device = %d",
            self.args.per_gpu_train_batch_size,
        )
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size,
        )
        logger.info(
            "  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps
        )
        logger.info("  Total optimization steps = %d", t_total)

        if self.args.batch_mode:
            logger.info("Replay in batch mode!")

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch,
                )
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        tr_ewc_loss = 0.0
        prev_ewc_loss = 0.0
        tr_replay_loss = 0.0
        tr_max_loss = 0.0
        tr_max_replay_loss = 0.0
        prev_tr_replay_loss = 0.0
        prev_tr_max_replay_loss = 0.0
        logging_loss = 0.0
        logging_replay_loss = 0.0
        patience = 0
        ran_out_of_patience = False
        n_replay_steps = 0
        prev_tr_loss = 0.0
        prev_tr_max_loss = 0.0

        tr_loss_sc = 0.0
        prev_tr_loss_sc = 0.0
        tr_replay_loss_sc = 0.0
        prev_tr_replay_loss_sc = 0.0
        tr_grad_norm = 0.0
        logging_grad_norm = 0.0

        model.zero_grad()
        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=not self.is_local_master(),
        )

        # update_replay=True

        named_model_params = [
            n for n, p in self.model.named_parameters() if p.requires_grad
        ]

        for epoch in train_iterator:
            epoch_iterator = tqdm(
                train_dataloader, desc="Iteration", disable=not self.is_local_master()
            )
            self.all_logs[epoch] = {}

            for step, inputs in enumerate(epoch_iterator):

                r = np.random.uniform(0, 1)

                if self.args.optimizer == "sam" or self.args.optimizer == "asam":

                    if self.args.replay_tiny and self.args.sparse_replay == "yes" and self.args.replay_rate > 0.0 and r <= self.args.replay_rate:
                        curr_task_loss, replay_loss, max_current_task_loss, max_replay_loss, total_grad_norm = self._er_sam_training_step(model=model,
                                                                                                                   optimizer=optimizer,
                                                                                                                   inputs=inputs)
                        tr_loss += curr_task_loss.item()
                        tr_max_loss += max_current_task_loss.item()

                        if replay_loss is not None:
                            # tr_replay_loss += replay_loss.item()
                            # tr_max_replay_loss += max_replay_loss.item()
                            tr_replay_loss += replay_loss
                            tr_max_replay_loss += max_replay_loss
                            n_replay_steps += 1

                    else:
                        curr_task_loss, max_current_task_loss, total_grad_norm = self._sam_training_step(model=model,
                                                                                                         optimizer=optimizer,
                                                                                                         inputs=inputs)
                        tr_loss += curr_task_loss.item()
                        tr_max_loss += max_current_task_loss.item()

                elif self.args.replay_tiny:

                    curr_task_loss = self._training_step(
                        model, inputs, replay=True, do_backward=False
                    )
                    n_curr_task_examples = inputs["input_ids"].size()[0]

                    if (
                        self.args.sparse_replay == "yes"
                        and self.args.replay_rate > 0.0
                        and r <= self.args.replay_rate
                    ):
                        curr_task_loss = curr_task_loss / n_curr_task_examples
                        curr_task_loss.backward()

                        if self.args.batch_mode:
                            (
                                replay_loss,
                                n_replay_examples,
                                _,
                                _,
                            ) = self.do_replay_multiple_tasks_batch_mode(model=model)
                        else:
                            (
                                replay_loss,
                                n_replay_examples,
                                task_loss,
                                task_examples,
                            ) = self.do_replay_multiple_tasks(model=model)

                        if replay_loss is not None:

                            tr_replay_loss += replay_loss
                            n_replay_steps += 1
                            # tr_loss += curr_task_loss.item() / n_curr_task_examples
                            tr_loss += curr_task_loss.item()

                            # for replay_task in task_loss:
                            #     curr_task_loss += task_loss[replay_task]
                            # curr_task_loss = curr_task_loss / (
                            #     n_curr_task_examples + n_replay_examples
                            # )
                            # if self.args.ewc and self.task_idx > 0:
                            #     ewc_loss = self.get_ewc_loss(model=model)
                            #     curr_task_loss += self.args.lmbda * ewc_loss
                            #     tr_ewc_loss += ewc_loss.item()
                            #
                            # curr_task_loss.backward()

                            # curr_task_loss = curr_task_loss / n_curr_task_examples
                            if self.args.ewc and self.task_idx > 0:
                                ewc_loss = self.get_ewc_loss(model=model)
                                tr_ewc_loss += ewc_loss.item()
                                # curr_task_loss += self.args.lmbda * ewc_loss
                                ewc_loss = self.args.lmbda * ewc_loss
                                ewc_loss.backward()


                            # curr_task_loss.backward()

                            # replay_loss.backward()

                            # for replay_task in task_loss:
                            #     curr_task_loss += task_loss[replay_task]
                            # curr_task_loss = curr_task_loss / (
                            #     n_curr_task_examples + n_replay_examples
                            # )

                        else:
                            # If we are with first task then we don't have replay loss
                            # curr_task_loss = curr_task_loss / n_curr_task_examples
                            tr_loss += curr_task_loss.item()
                            if self.args.ewc and self.task_idx > 0:
                                ewc_loss = self.get_ewc_loss(model=model)
                                tr_ewc_loss += ewc_loss.item()
                                # curr_task_loss += self.args.lmbda * ewc_loss
                                ewc_loss = self.args.lmbda * ewc_loss
                                ewc_loss.backward()

                            # curr_task_loss.backward()

                    else:
                        # If sparse replay and we don't decide to do replay then proceed as normal
                        curr_task_loss = curr_task_loss / n_curr_task_examples
                        tr_loss += curr_task_loss.item()

                        if self.args.ewc and self.task_idx > 0:
                            ewc_loss = self.get_ewc_loss(model=model)
                            curr_task_loss += self.args.lmbda * ewc_loss
                            tr_ewc_loss += ewc_loss.item()

                        curr_task_loss.backward()

                elif (
                    self.args.sparse_replay == "yes"
                    and self.args.replay_rate > 0.0
                    and r <= self.args.replay_rate
                ):
                    # If not tiny replay and sparse replay enabled then proceed as normal replay
                    if self.args.supconloss:
                        loss, loss_sc = self._training_step(
                            model, inputs, do_backward=False
                        )
                        tr_loss_sc += loss_sc.item()

                        total_loss = 0.0
                        total_loss += loss
                        if loss_sc.item() != 0.0:
                            total_loss += self.args.lmbda_sc * loss_sc

                        replay_loss, replay_loss_sc, replay_task = self.do_replay(
                            model, optimizer, update_replay=False
                        )

                        if replay_loss is None:
                            total_loss.backward()
                        else:
                            replay_loss = (
                                torch.mean(torch.cat(replay_loss))
                                if len(replay_loss) > 1
                                else replay_loss[0]
                            )
                            replay_loss_sc = (
                                torch.mean(torch.cat(replay_loss_sc))
                                if len(replay_loss_sc) > 1
                                else replay_loss_sc[0]
                            )
                            n_replay_steps += 1

                            total_loss = 0.5 * total_loss + 0.5 * (
                                replay_loss + self.args.lmbda_sc * replay_loss_sc
                            )
                            total_loss.backward()
                            tr_replay_loss += replay_loss.item()
                            tr_replay_loss_sc += replay_loss_sc.item()

                    else:
                        loss = self._training_step(model, inputs, do_backward=False)
                        replay_loss, _, replay_task = self.do_replay(
                            model, optimizer, update_replay=False
                        )
                        if replay_loss is None:
                            loss.backward()
                        else:
                            replay_loss = (
                                torch.mean(torch.cat(replay_loss))
                                if len(replay_loss) > 1
                                else replay_loss[0]
                            )
                            n_replay_steps += 1
                            total_loss = 0.0
                            total_loss += 0.5 * loss
                            total_loss += 0.5 * replay_loss
                            total_loss.backward()
                            tr_replay_loss += replay_loss.item()

                    tr_loss += loss.item()

                else:
                    # If not tiny replay OR normal replay then proceed as normal
                    if self.args.supconloss:
                        loss, loss_sc = self._training_step(
                            model, inputs, do_backward=False
                        )
                        tr_loss_sc += loss_sc.item()
                        total_loss = 0.0
                        total_loss += loss + self.args.lmbda_sc * loss_sc
                        total_loss.backward()
                    else:
                        loss = self._training_step(model, inputs, do_backward=False)
                        loss.backward()

                    if self.args.ewc and self.task_idx > 0:
                        ewc_loss = self.get_ewc_loss(model=model)
                        loss += self.args.lmbda * ewc_loss
                        tr_ewc_loss += ewc_loss.item()

                    tr_loss += loss.item()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.optimizer != "sam" and self.args.optimizer != "asam":
                        total_grad_norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), self.args.max_grad_norm
                        ).item()

                        tr_grad_norm += total_grad_norm

                        optimizer.step()
                    else:
                        tr_grad_norm += total_grad_norm

                    if not self.args.disable_scheduler:
                        scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    if self.tb_writer and self.task is not None:
                        key = "{}/train_grad_norm".format(self.task)
                        self.tb_writer.add_scalar(key, total_grad_norm, global_step)
                        key1 = "{}/avg_train_loss".format(self.task)
                        self.tb_writer.add_scalar(
                            key1, tr_loss / global_step, global_step
                        )
                        key2 = "{}/inst_train_loss".format(self.task)
                        self.tb_writer.add_scalar(
                            key2, (tr_loss - prev_tr_loss), global_step
                        )
                        prev_tr_loss = tr_loss

                        if self.args.optimizer == "sam" or self.args.optimizer == "asam":
                            key1 = "{}/avg_max_train_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key1, tr_max_loss / global_step, global_step
                            )
                            key2 = "{}/inst_max_train_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key2, (tr_max_loss - prev_tr_max_loss), global_step
                            )
                            prev_tr_max_loss = tr_max_loss


                        if self.args.supconloss:
                            key1 = "{}/avg_supcon_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key1, tr_loss_sc / global_step, global_step
                            )
                            key2 = "{}/inst_supcon_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key2, (tr_loss_sc - prev_tr_loss_sc), global_step
                            )
                            prev_tr_loss_sc = tr_loss_sc

                        if self.args.ewc and self.task_idx > 0:
                            key1 = "{}/avg_ewc_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key1, tr_ewc_loss / global_step, global_step
                            )
                            key2 = "{}/inst_ewc_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key2, (tr_ewc_loss - prev_ewc_loss), global_step
                            )
                            prev_ewc_loss = tr_ewc_loss

                        if (
                            self.args.replay_tiny
                            and n_replay_steps > 0
                            and self.args.replay_rate > 0.0
                            and r <= self.args.replay_rate
                        ):
                            key1 = "{}/avg_replay_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key1, tr_replay_loss / n_replay_steps, global_step
                            )
                            key2 = "{}/inst_replay_loss".format(self.task)
                            self.tb_writer.add_scalar(
                                key2,
                                (tr_replay_loss - prev_tr_replay_loss),
                                global_step,
                            )
                            prev_tr_replay_loss = tr_replay_loss

                            if self.args.optimizer == "sam" or self.args.optimizer == "asam":
                                key1 = "{}/avg_max_replay_loss".format(self.task)
                                self.tb_writer.add_scalar(
                                    key1, tr_max_replay_loss / n_replay_steps, global_step
                                )
                                key2 = "{}/inst_max_replay_loss".format(self.task)
                                self.tb_writer.add_scalar(
                                    key2,
                                    (tr_max_replay_loss - prev_tr_max_replay_loss),
                                    global_step,
                                )
                                prev_tr_max_replay_loss = tr_max_replay_loss

                        elif (
                            self.args.sparse_replay == "yes"
                            and self.args.replay_rate > 0.0
                            and r <= self.args.replay_rate
                            and n_replay_steps > 0
                        ):
                            key1 = "{}/{}/avg_replay_loss".format(
                                self.task, replay_task
                            )
                            self.tb_writer.add_scalar(
                                key1, tr_replay_loss / n_replay_steps, global_step
                            )
                            key2 = "{}/{}/inst_replay_loss".format(
                                self.task, replay_task
                            )
                            self.tb_writer.add_scalar(
                                key2,
                                (tr_replay_loss - prev_tr_replay_loss),
                                global_step,
                            )
                            prev_tr_replay_loss = tr_replay_loss

                            if self.args.supconloss:
                                key1 = "{}/{}/avg_replay_supcon_loss".format(
                                    self.task, replay_task
                                )
                                self.tb_writer.add_scalar(
                                    key1,
                                    tr_replay_loss_sc / n_replay_steps,
                                    global_step,
                                )
                                key2 = "{}/{}/inst_replay_supcon_loss".format(
                                    self.task, replay_task
                                )
                                self.tb_writer.add_scalar(
                                    key2,
                                    (tr_replay_loss_sc - prev_tr_replay_loss_sc),
                                    global_step,
                                )
                                prev_tr_replay_loss_sc = tr_replay_loss_sc

                    if self.is_local_master():
                        if (logging_steps > 0 and global_step % logging_steps == 0) or (
                            global_step == 1 and self.args.logging_first_step
                        ):
                            logs = {}
                            if self.args.evaluate_during_training:
                                results = self.evaluate()
                                for key, value in results.items():
                                    if self.task is not None:
                                        eval_key = "{}/eval_{}".format(self.task, key)
                                    else:
                                        eval_key = "eval_{}".format(key)
                                    logs[eval_key] = value

                            loss_scalar = (tr_loss - logging_loss) / logging_steps
                            if self.args.replay_tiny and n_replay_steps > 0:
                                replay_loss_scalar = (
                                    tr_replay_loss - logging_replay_loss
                                ) / logging_steps

                            if not self.args.disable_scheduler:
                                learning_rate_scalar = scheduler.get_last_lr()[0]
                            else:
                                learning_rate_scalar = optimizer.param_groups[0]['lr']

                            if self.task is not None:
                                logs[
                                    "{}/learning_rate".format(self.task)
                                ] = learning_rate_scalar
                                logs["{}/train_loss".format(self.task)] = loss_scalar
                                logs["{}/grad_norm".format(self.task)] = (
                                    tr_grad_norm - logging_grad_norm
                                ) / logging_steps
                                if self.args.replay_tiny and n_replay_steps > 0:
                                    logs[
                                        "{}/replay_loss".format(self.task)
                                    ] = replay_loss_scalar
                            else:
                                logs["learning_rate"] = learning_rate_scalar
                                logs["train_loss"] = loss_scalar
                                logs["grad_norm".format(self.task)] = (
                                    tr_grad_norm / logging_steps
                                )
                                if self.args.replay_tiny and n_replay_steps > 0:
                                    logs["replay_loss"] = replay_loss_scalar
                            logging_loss = tr_loss
                            tr_grad_norm = 0.0
                            if self.args.replay_tiny and n_replay_steps > 0:
                                logging_replay_loss = tr_replay_loss

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(k, v, global_step)
                            if is_wandb_available():
                                wandb.log(logs, step=global_step)

                            epoch_iterator.write(
                                json.dumps({**logs, **{"step": global_step}})
                            )
                            self.all_logs[epoch][global_step] = logs

                            if self.args.evaluate_during_training:
                                key = (
                                    "{}/eval_{}".format(self.task, self.task_metric)
                                    if self.task is not None
                                    else "eval_{}".format(self.task_metric)
                                )
                                if logs[key] > self.best_score:
                                    self.best_score = logs[key]
                                    patience = 0
                                    self.best_epoch = epoch + 1
                                    self.best_global_step = global_step
                                else:
                                    patience += 1

                                logger.info(
                                    "Best score: {}, epoch: {}, global step: {}".format(
                                        self.best_score,
                                        self.best_epoch,
                                        self.best_global_step,
                                    )
                                )

                                if (
                                    self.args.patience != -1
                                    and patience >= self.args.patience
                                ):
                                    ran_out_of_patience = True
                                    logger.info("Ran out of the patience!")
                                    epoch_iterator.close()
                                    break

                        if (
                            self.args.save_steps > 0
                            and global_step % self.args.save_steps == 0
                        ):
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.model
                            else:
                                assert model is self.model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}",
                            )

                            self.save_model(output_dir)
                            # self._rotate_checkpoints()
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"),
                            )
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"),
                            )
                            logger.info(
                                "Saving optimizer and scheduler states to %s",
                                output_dir,
                            )

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break

            logs = {}
            results = self.evaluate()
            for key, value in results.items():
                if self.task is not None:
                    eval_key = "epoch/{}/eval_{}".format(self.task, key)
                else:
                    eval_key = "epoch/eval_{}".format(key)
                logs[eval_key] = value

            if self.tb_writer:
                for k, v in logs.items():
                    self.tb_writer.add_scalar(k, v, (epoch + 1))

            train_iterator.write(json.dumps({**logs, **{"epoch": epoch + 1}}))

            if self.args.replay_only_fraction == "yes":
                update_replay = False
                logger.info(
                    "  Replay only fraction enabled! Therefore disabled replay loss after epoch = {}".format(
                        epoch
                    )
                )

            if ran_out_of_patience:
                logger.info("Ran out of the patience!")
                train_iterator.close()
                break

            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.task is not None:
            logger.info(
                "  Best model for task {} found in epoch = {}".format(
                    self.task, self.best_epoch
                )
            )
        else:
            logger.info("  Best model found in epoch = {}".format(self.best_epoch))

        if self.args.write_rate > 0.0:
            if self.args.write_strategy == "random":
                self.exemplarHandler.write(
                    dataset=train_dataloader.dataset,
                    write_rate=self.args.write_rate,
                    task=self.task,
                    min_examples_per_class=self.args.min_examples_per_class,
                    max_examples_per_class=self.args.max_examples_per_class,
                )
            elif self.args.write_strategy == "mof":
                dataloader = self.get_val_dataloader(
                    val_dataset=self.train_dataset
                )  # Note that we just want dataloader here!
                self.exemplarHandler.write_mof(
                    training_args=self.args,
                    model=self.best_model,
                    dataloader=dataloader,
                    write_rate=self.args.write_rate,
                    task=self.task,
                    min_examples_per_class=self.args.min_examples_per_class,
                    max_examples_per_class=self.args.max_examples_per_class,
                )
            elif self.args.write_strategy == "kmeans":
                dataloader = self.get_val_dataloader(
                    val_dataset=self.train_dataset
                )  # Note that we just want dataloader here!
                self.exemplarHandler.write_kmeans(
                    training_args=self.args,
                    model=self.best_model,
                    dataloader=dataloader,
                    write_rate=self.args.write_rate,
                    task=self.task,
                    min_examples_per_class=self.args.min_examples_per_class,
                    max_examples_per_class=self.args.max_examples_per_class,
                )

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)

    def update_to_best_model(self):
        """Promote ``self.best_model`` (if any) to be the trainer's model.

        When a best model is stored, ``self.model`` is replaced by a deep copy
        of it and ``self.best_model`` is reset to ``None`` so it is not reused.
        When no best model is stored, ``self.model`` is left untouched.

        Returns:
            ``self.model`` after the (possible) replacement.
        """
        best = self.best_model
        if best is None:
            logger.info(
                "Best model after training on task {} is not available. So returning current model".format(
                    self.task
                )
            )
            return self.model

        logger.info(
            "Updating to the best model after training on task {}".format(self.task)
        )
        self.model = deepcopy(best)
        self.best_model = None
        return self.model

    def _sam_training_step(self,
                           model: nn.Module,
                           optimizer: torch.optim.Optimizer,
                           inputs: Dict[str, torch.Tensor],
                           task: str = None,
                           replay: bool = False):
        model.train()
        model.zero_grad()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        if task is None:
            task = self.task

        # first forward-backward step
        # model.apply(disable_running_stats)
        # enable_running_stats(model)  # <- this is the important line
        outputs = model(**inputs, task=task, replay=replay)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        loss.mean().backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward step
        # model.apply(enable_running_stats)
        # disable_running_stats(model)  # <- this is the important line
        outputs = model(**inputs, task=task, replay=replay)
        max_loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        max_loss.mean().backward()
        total_grad_norm = optimizer.second_step(zero_grad=True, max_grad_norm=self.args.max_grad_norm)

        return loss, max_loss, total_grad_norm

    def _er_sam_training_step(self,
                              model: nn.Module,
                              optimizer: torch.optim.Optimizer,
                              inputs: Dict[str, torch.Tensor],
                              task: str = None,
                              replay: bool = False):
        model.train()
        model.zero_grad()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        if task is None:
            task = self.task

        if self.args.batch_mode:
            replay_datasets, replay_tasks = self.exemplarHandler.read_multiple_tasks_batch_mode(
                n_examples_replay=self.args.train_batch_size
            )
        else:
            replay_datasets = self.exemplarHandler.read_multiple_tasks(
                n_examples_replay=self.args.train_batch_size
            )

        if replay_datasets is None:  # Memory is empty so no-op

            # first forward-backward step
            # model.apply(disable_running_stats)
            # enable_running_stats(model)  # <- this is the important line
            outputs = model(**inputs, task=task, replay=replay)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            # model.apply(enable_running_stats)
            # disable_running_stats(model)  # <- this is the important line
            outputs = model(**inputs, task=task, replay=replay)
            max_loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            max_loss.mean().backward()
            total_grad_norm = optimizer.second_step(zero_grad=True, max_grad_norm=self.args.max_grad_norm)

            replay_loss = None
            max_replay_loss = None

        else:

            # first forward-backward step
            # model.apply(disable_running_stats)
            # enable_running_stats(model)  # <- this is the important line
            outputs = model(**inputs, task=task, replay=replay)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss.mean().backward()

            # Evaluate replay loss and take backward!
            if self.args.batch_mode:
                (
                    replay_loss,
                    n_replay_examples,
                    task_loss,
                    task_examples,
                ) = self.do_replay_multiple_tasks_batch_mode(model=model,
                                                             replay_datasets=replay_datasets,
                                                             replay_tasks=replay_tasks)
                # replay_loss.backward()
            else:
                (
                    replay_loss,
                    n_replay_examples,
                    task_loss,
                    task_examples,
                ) = self.do_replay_multiple_tasks(model=model, replay_datasets=replay_datasets)
                # replay_loss.backward()

            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            # model.apply(enable_running_stats)
            # disable_running_stats(model)  # <- this is the important line
            outputs = model(**inputs, task=task, replay=replay)
            max_loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            max_loss.mean().backward()

            # Evaluate replay loss and take backward!
            if self.args.batch_mode:
                (
                    max_replay_loss,
                    n_replay_examples,
                    task_loss,
                    task_examples,
                ) = self.do_replay_multiple_tasks_batch_mode(model=model,
                                                             replay_datasets=replay_datasets,
                                                             replay_tasks=replay_tasks)
                # max_replay_loss.backward()
            else:
                (
                    max_replay_loss,
                    n_replay_examples,
                    task_loss,
                    task_examples,
                ) = self.do_replay_multiple_tasks(model=model, replay_datasets=replay_datasets)
                # max_replay_loss.backward()

            total_grad_norm = optimizer.second_step(zero_grad=True, max_grad_norm=self.args.max_grad_norm)

        return loss, replay_loss, max_loss, max_replay_loss, total_grad_norm

    def _training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, torch.Tensor],
        task: str = None,
        replay: bool = False,
        do_backward: bool = True,
    ):
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        if task is None:
            task = self.task
        # task = None
        outputs = model(**inputs, task=task, replay=replay)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if replay:
            return loss
        loss_sc = torch.tensor(0)
        if do_backward:
            if self.args.supconloss:
                total_loss = 0.0
                total_loss += loss
                if loss_sc.item() != 0.0:
                    total_loss += self.args.lmbda_sc * loss_sc
                total_loss.backward()
                return loss, loss_sc
            else:
                loss.backward()

        if self.args.supconloss:
            return loss, loss_sc

        return loss

    def estimate_fisher(self, model):
        """Estimate the diagonal empirical Fisher Information of ``model`` on the current task.

        The model is run over ``self.train_dataset`` one example at a time (batch size 1).
        For every trainable parameter the squared gradient of the model loss is summed
        and then divided by the number of examples actually processed. At most
        ``self.args.fisher_n`` examples are used when that argument is not None.

        If an estimate from earlier tasks exists (``self.est_fisher_info`` is not None),
        it is added with weight ``self.args.gamma`` (running-sum / online-EWC style).

        Args:
            model: network to evaluate. It is moved to ``self.args.device`` and left in
                eval mode with its gradients cleared. Its forward must accept ``task``
                and ``replay`` keyword arguments and return the loss as first output.

        Returns:
            ``(est_mean_prev_task, est_fisher_info)``: two dicts keyed by parameter name
            with ``.`` replaced by ``__``. The first holds a detached copy of each
            trainable parameter; the second holds its Fisher estimate.
        """
        old_est_fisher_info = self.est_fisher_info

        # Move the model first so the stored parameter copies and the zero buffers live
        # on the same device as the gradients accumulated below.
        model.to(self.args.device)
        model.eval()

        est_mean_prev_task = {}
        est_fisher_info = {}
        for n, p in model.named_parameters():
            if p.requires_grad:
                n = n.replace(".", "__")
                est_mean_prev_task[n] = p.detach().clone()
                est_fisher_info[n] = p.detach().clone().zero_()

        # Create data-loader to give batches of size 1
        dataloader = self.get_train_dataloader(
            train_dataset=self.train_dataset, batch_size=1
        )
        max_samples = self.args.fisher_n
        task = self.task

        # Count processed examples explicitly. The last enumerate index is one less than
        # the count whenever the loader is exhausted, and is undefined for an empty loader.
        n_used = 0
        for inputs in tqdm(dataloader, desc="EWC-Iteration"):
            if max_samples is not None and n_used >= max_samples:
                break
            model.zero_grad()
            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)
            outputs = model(**inputs, task=task, replay=False)
            loss = outputs[0]
            loss.backward()

            # Square gradients and keep running sum
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    est_fisher_info[n.replace(".", "__")] += p.grad.detach() ** 2
            n_used += 1

        # Do not leave the last example's gradients behind for a later optimizer step.
        model.zero_grad()

        # Normalize by the number of examples used for estimation
        if n_used > 0:
            est_fisher_info = {n: p / n_used for n, p in est_fisher_info.items()}
        else:
            logger.warning(
                "No examples available for Fisher estimation; Fisher information left at zero."
            )

        # Online EWC: add the previous estimate, weighted by gamma.
        if old_est_fisher_info is not None:
            est_fisher_info = {
                n: p + self.args.gamma * old_est_fisher_info[n]
                for n, p in est_fisher_info.items()
            }

        return est_mean_prev_task, est_fisher_info

    def get_ewc_loss(self, model):
        """Return the EWC penalty of ``model`` against the stored previous-task state.

        While ``self.task_idx`` is not greater than 0 there is no stored state, and a
        zero tensor on ``self.args.device`` is returned. Otherwise the penalty is::

            0.5 * sum over trainable params of (weight * (param - anchor) ** 2).sum()

        ``anchor`` comes from ``self.est_mean_prev_task``. ``weight`` is
        ``self.args.l2_weight`` when ``self.args.enable_l2`` is set, and otherwise
        ``self.args.gamma * self.est_fisher_info[key]``. Both stored dicts are keyed by
        parameter name with ``.`` replaced by ``__``.
        """
        has_previous_task = self.task_idx > 0
        if not has_previous_task:
            # Nothing anchored yet, so the regulariser contributes nothing.
            return torch.tensor(0.0).to(self.args.device)

        def _param_penalty(name, param):
            key = name.replace(".", "__")
            anchor = self.est_mean_prev_task[key]
            fisher = self.est_fisher_info[key]
            if self.args.enable_l2:
                weight = self.args.l2_weight  # uniform L2 anchoring instead of Fisher
            else:
                weight = self.args.gamma * fisher  # decayed running Fisher estimate
            return (weight * (param - anchor) ** 2).sum()

        return 0.5 * sum(
            _param_penalty(name, param)
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    def evaluate(
        self,
        val_dataset: Optional[Dataset] = None,
        prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Evaluate on a validation set and return only the metrics.

        The metric computation itself is task-specific and is supplied by the caller
        (through ``compute_metrics``); this method just drives the shared loop.

        Args:
            val_dataset: optional dataset, passed to ``get_val_dataloader``. Pass one
                to override the dataset held by the instance.
            prediction_loss_only: forwarded to ``_prediction_loop``; None means use the
                instance default.

        Returns:
            The metrics dict from ``_prediction_loop``: the eval loss plus any metrics
            computed from the predictions.
        """
        loop_output = self._prediction_loop(
            self.get_val_dataloader(val_dataset),
            description="Validation",
            prediction_loss_only=prediction_loss_only,
        )
        return loop_output.metrics

    def predict(self, test_dataset: Optional[Dataset] = None) -> PredictionOutput:
        """
        Run the model over a test set and return predictions plus any metrics.

        ``test_dataset`` is handed to ``get_test_dataloader``. The loop is the same one
        ``evaluate()`` uses, so if the test data carries labels the metrics are filled
        in just as they are for evaluation.
        """
        loader = self.get_test_dataloader(test_dataset)
        result = self._prediction_loop(loader, description="Test")
        return result

    def _prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
        Works both with or without labels.

        Args:
            dataloader: yields dict batches. Each batch is moved to ``self.args.device``
                and passed to the model as keyword arguments, plus ``task=self.task``.
            description: label used in the log lines and on the progress bar.
            prediction_loss_only: when True, only losses are collected (no logits,
                labels or guids). None falls back to ``self.prediction_loss_only``.

        Returns:
            A ``PredictionOutput``. ``predictions`` holds the logits concatenated along
            axis 0 (or None); ``label_ids`` holds the concatenated ``labels`` (or None).
            ``metrics`` contains the ``self.compute_metrics`` result when predictions and
            labels are both available, plus:
            - ``"loss"``: the mean per-batch loss, when batches carried labels;
            - ``"active_gates1"`` / ``"active_gates2"``: when ``self.args.log_active_gates``
              is set.
        """

        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else self.prediction_loss_only
        )

        # multi-gpu eval
        if self.args.n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(self.model)
        else:
            model = self.model
        model.to(self.args.device)

        # (Distributed)DataParallel wrappers do not forward attribute access, so the
        # config must come from the wrapped module. Reading model.config on the wrapper
        # raises AttributeError.
        if isinstance(
            model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        ):
            config = model.module.config
        else:
            config = model.config

        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", len(dataloader.dataset))
        logger.info("  Batch size = %d", dataloader.batch_size)
        eval_losses: List[float] = []
        # Per-batch numpy chunks, concatenated once after the loop. Calling np.append on
        # every batch re-copies all previously gathered data, which is quadratic.
        pred_chunks: List[np.ndarray] = []
        label_chunks: List[np.ndarray] = []
        guid_chunks: List[np.ndarray] = []
        active_gates_to_log1: List[np.ndarray] = []
        active_gates_to_log2: List[np.ndarray] = []
        model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None for k in ["labels", "masked_lm_labels"]
            )

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs, task=self.task)

                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses.append(step_eval_loss.mean().item())
                    outputs = outputs[2:]
                else:
                    logits = outputs[0]
                    outputs = outputs[1:]

                # Drop the optional hidden-state / attention outputs so the gate
                # outputs (if any) come first.
                if config.output_hidden_states:
                    outputs = outputs[1:]

                if config.output_attentions:
                    outputs = outputs[1:]

                if self.args.log_active_gates:
                    # assumes the next two outputs are per-layer sequences of gate
                    # activations that numpy can concatenate along axis 0 — TODO confirm
                    gates_to_log1 = outputs[0]
                    gates_to_log2 = outputs[1]

                    if len(active_gates_to_log1) == 0:
                        for i in range(len(gates_to_log1)):
                            active_gates_to_log1.append(gates_to_log1[i])
                            active_gates_to_log2.append(gates_to_log2[i])
                    else:
                        for i in range(len(gates_to_log1)):
                            active_gates_to_log1[i] = np.append(
                                active_gates_to_log1[i], gates_to_log1[i], axis=0
                            )
                            active_gates_to_log2[i] = np.append(
                                active_gates_to_log2[i], gates_to_log2[i], axis=0
                            )

            if not prediction_loss_only:
                pred_chunks.append(logits.detach().cpu().numpy())
                if inputs.get("labels") is not None:
                    label_chunks.append(inputs["labels"].detach().cpu().numpy())
                if inputs.get("guid") is not None:
                    guid_chunks.append(inputs["guid"].detach().cpu().numpy())

        def _concat(chunks):
            # None when nothing was collected; a lone chunk is returned as-is (no copy).
            if not chunks:
                return None
            if len(chunks) == 1:
                return chunks[0]
            return np.concatenate(chunks, axis=0)

        preds = _concat(pred_chunks)
        label_ids = _concat(label_chunks)
        ex_ids = _concat(guid_chunks)

        if (
            self.compute_metrics is not None
            and preds is not None
            and label_ids is not None
        ):
            # NOTE(review): the EvalPrediction imported from transformers.trainer_utils
            # may not accept a `guids` field; confirm the installed version supports it.
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids, guids=ex_ids)
            )
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["loss"] = np.mean(eval_losses)

        if self.args.log_active_gates:
            # Mean activation per layer, averaged over all evaluated examples.
            metrics["active_gates1"] = np.array(
                [np.mean(layer_gates, axis=0) for layer_gates in active_gates_to_log1]
            ).tolist()
            metrics["active_gates2"] = np.array(
                [np.mean(layer_gates, axis=0) for layer_gates in active_gates_to_log2]
            ).tolist()

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def get_optimizers(
        self,
        num_training_steps: int,
        only_update_head: bool = False,
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Build the optimizer and (unless disabled) a linear warmup/decay LR scheduler.

        If optimizers were supplied to the Trainer (``self.optimizers`` is not None),
        they are returned unchanged. Override this method in a subclass to customise it.

        Args:
            num_training_steps: total optimisation steps. Used by the scheduler and to
                derive warmup steps from ``warmup_ratio`` when ``warmup_steps`` is 0.
            only_update_head: if True, only trainable parameters whose name contains
                ``classifier_<task>`` (or ``classifier`` when ``self.task`` is None) are
                optimised, all with ``self.args.weight_decay``. Otherwise every trainable
                parameter is used, and ``bias`` / ``LayerNorm.weight`` parameters get no
                weight decay.

        Returns:
            ``(optimizer, scheduler)``. When ``self.args.disable_scheduler`` is set,
            only the optimizer is returned, despite the tuple annotation.
            ``self.args.optimizer`` chooses SAM ("sam"), adaptive SAM ("asam"; both wrap
            transformers' AdamW) or plain AdamW.
        """
        if self.optimizers is not None:
            return self.optimizers

        trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]

        if only_update_head:
            if self.task is None:
                head_keys = ["classifier"]
            else:
                head_keys = ["classifier_{}".format(self.task)]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        param
                        for name, param in trainable
                        if any(key in name for key in head_keys)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
            ]
        else:
            no_decay = ["bias", "LayerNorm.weight"]
            decayed, undecayed = [], []
            for name, param in trainable:
                bucket = undecayed if any(nd in name for nd in no_decay) else decayed
                bucket.append(param)
            optimizer_grouped_parameters = [
                {"params": decayed, "weight_decay": self.args.weight_decay},
                {"params": undecayed, "weight_decay": 0.0},
            ]

        total_params = 0
        for group in optimizer_grouped_parameters:
            total_params += np.sum([np.prod(param.size()) for param in group["params"]])
        logger.info("Total no. of tunable params: %s", str(total_params))

        # An explicit warmup_steps wins; warmup_ratio is only used when it is 0.
        warmup_steps = self.args.warmup_steps
        if warmup_steps == 0 and self.args.warmup_ratio != 0:
            warmup_steps = int(self.args.warmup_ratio * num_training_steps)

        if self.args.optimizer in ("sam", "asam"):
            is_adaptive = self.args.optimizer == "asam"
            sam_kwargs = dict(
                rho=self.args.rho,
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )
            if is_adaptive:
                sam_kwargs["adaptive"] = True
            # AdamW is the base optimizer used for the "sharpness-aware" update.
            optimizer = SAM(
                optimizer_grouped_parameters,
                transformers.optimization.AdamW,
                **sam_kwargs,
            )
            if is_adaptive:
                logger.info("Initialized with Adaptive SAM optimizer.")
            else:
                logger.info("Initialized with SAM optimizer.")
        else:
            optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )

        if self.args.disable_scheduler:
            logger.info("Disabling LR Scheduler!")
            return optimizer

        return optimizer, get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

    def save_model(
        self, output_dir: Optional[str] = None, save_bertmodel: bool = False
    ):
        """
        Save the model so it can be reloaded with ``from_pretrained()``.

        Only the world-master process writes anything; every other process returns
        without doing work. The arguments are forwarded to ``_save`` unchanged.
        """
        if not self.is_world_master():
            return
        self._save(output_dir=output_dir, save_bertmodel=save_bertmodel)

    def _save(self, output_dir: Optional[str] = None, save_bertmodel: bool = False):
        """Write the model and the training arguments to ``output_dir``.

        Saves ``self.best_model`` when one is set, otherwise ``self.model``, using
        ``save_pretrained()`` so it can be reloaded with ``from_pretrained()``. With
        ``save_bertmodel=True`` only the chosen model's ``.bert`` sub-module is saved.
        ``self.args`` is stored alongside as ``training_args.bin``. Optimizer and
        scheduler state are NOT saved here.

        Args:
            output_dir: destination directory; defaults to ``self.args.output_dir``.
                Created if missing.
            save_bertmodel: save only the ``.bert`` sub-module.

        Raises:
            ValueError: if ``self.model`` is not a ``PreTrainedModel``. This is checked
                before anything is written to disk.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        # Validate first so a bad call does not leave an empty directory behind.
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)

        model_to_save = self.best_model if self.best_model is not None else self.model
        if save_bertmodel:
            model_to_save = model_to_save.bert
        model_to_save.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model.
        # (The previous log line claimed optimizer/scheduler states were saved; they are not.)
        logger.info("Saving training arguments to %s", output_dir)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _setup_wandb(self):
        """
        Start the optional Weights & Biases (`wandb`) run for this trainer.

        The run is named after ``self.args.logging_dir`` and records ``vars(self.args)``
        as its config. The model is then registered with ``wandb.watch`` to track
        topology and gradients. Override to customise the integration.
        """
        run_config = vars(self.args)
        wandb.init(name=self.args.logging_dir, config=run_config)
        wandb.watch(self.model)

    def is_local_master(self) -> bool:
        """Return True when ``self.args.local_rank`` is -1 or 0."""
        rank = self.args.local_rank
        return rank in (-1, 0)

    def is_world_master(self) -> bool:
        """
        Return True in exactly one process, even with distributed training across
        several machines.

        A non-distributed run (``local_rank == -1``) is always the master. Otherwise the
        process whose global ``torch.distributed`` rank is 0 is the master.
        """
        if self.args.local_rank == -1:
            return True
        return torch.distributed.get_rank() == 0
