# Copyright 2021 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import os
import re
import time

from absl import logging
import mesh_tensorflow.transformer.dataset as transformer_dataset
import data
from evaluation.eval_utils import run_eval
import models.utils
from models.t5_model import T5Model
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import torch
import torch.utils.tensorboard
import models.utils as model_utils

import globals

# File-name template for full model checkpoints; `{}` is filled with the
# training step (see `save_checkpoint` / `load_checkpoint`).
CHECKPOINT_FILE_FORMAT = "model-{}.checkpoint"
# File-name template for classification-head checkpoints.
# NOTE(review): not referenced in the visible part of this file; `train_head`
# builds the same "enc_cls_head.checkpoint<step>" name by hand -- confirm
# whether other code relies on this constant.
Head_FILE_FORMAT = "enc_cls_head.checkpoint{}"
#Head_FILE_FORMAT = "_enc_cls_head.checkpoint14000_"


def tokens_to_batches(dataset,
                      sequence_length,
                      batch_size,
                      output_features,
                      mixture_or_task=None):
    """Pad token sequences, attach per-feature masks and batch the result.

    Args:
      dataset: tf.data.Dataset containing examples with token sequences.
      sequence_length: dict of int, a dict mapping feature name to length.
      batch_size: int, the number of padded sequences in each batch.
      output_features: list of str, features to include in the dataset.
      mixture_or_task: a Task or Mixture object. When given, EOS is only
        ensured for the features whose `add_eos` flag is set; otherwise
        `ensure_eos=True` is passed for all features.

    Returns:
      The result of `tfds.as_numpy` on the batched dataset, i.e. an iterable
      of batches of numpy examples. Each feature `<key>` is accompanied by a
      `<key>_mask` entry that is 1 where the token id is > 0 and 0 elsewhere.
    """

    # Decide for which features pack_or_pad has to make sure an EOS exists.
    ensure_eos = (
        {name for name, feature in mixture_or_task.output_features.items()
         if feature.add_eos}
        if mixture_or_task else True
    )

    def _attach_masks(example):
        # The mask keeps the dtype of the token tensor it belongs to.
        for feature_name in output_features:
            tokens = example[feature_name]
            example[feature_name + "_mask"] = tf.cast(
                tf.greater(tokens, 0), tokens.dtype)
        return example

    padded = transformer_dataset.pack_or_pad(
        dataset,
        sequence_length,
        pack=False,
        feature_keys=output_features,
        ensure_eos=ensure_eos,
    )
    masked = padded.map(
        _attach_masks,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    # Keep the final, possibly smaller, batch.
    batched = masked.batch(batch_size, drop_remainder=False)
    return tfds.as_numpy(batched)


def _get_dataset(mixture_or_task_or_name,
                 sequence_length,
                 split,
                 shuffle=True):
    """Get a tf.data.Dataset for a given Task or Mixture.

    Args:
      mixture_or_task_or_name: Task or Mixture or str, the name of the Mixture or
        Task to train on or the Tasks or Mixture object itself.
        Must be pre-registered in the global `t5.data.TaskRegistry` or
        `t5.data.MixtureRegistry.`
      sequence_length: dict of int, a dict mapping feature name to length.
      split: str or `tensorflow_datasets.Split`, the data split to load.
      shuffle: boolean, whether to shuffle the dataset.

    Returns:
      A generator that produces batches of numpy examples.
    """
    if isinstance(mixture_or_task_or_name, str):
        task = data.get_mixture_or_task(mixture_or_task_or_name)
    else:
        task = mixture_or_task_or_name

    return task.get_dataset(sequence_length, split, shuffle=shuffle)#, seed=100)#, use_cached=True)


import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoConfig


class HfPyTorchModel(T5Model):
    """Wrapper class for Hugging Face Transformers PyTorch T5 model."""

    def __init__(self, model_spec, model_dir, device, teacher_model_spec=None, teacher_model_dir=None,
                 thresholds_list=None):
        """Constructor for HfModel class.

        Args:
          model_spec: A str to pass into the `pretrained_model_name_or_path`
            argument of `transformers.T5ForConditionalGeneration.from_pretrained`
            (e.g. `"t5-base"` or a path to a previously trained model) or an
            instance of the `transformers.configuration_t5.T5Config` class to use
            to directly construct the `transformers.T5ForConditionalGeneration`
            object.
          model_dir: str, directory to save and load model checkpoints. It is
            created if missing and also receives the TensorBoard summaries.
          device: `torch.device` on which the model should be run.
          teacher_model_spec: optional str or `T5Config` describing a teacher
            model, interpreted like `model_spec`.
          teacher_model_dir: optional str, directory holding the teacher's
            checkpoints. A teacher model is only built when both
            `teacher_model_spec` and `teacher_model_dir` are given.
          thresholds_list: optional list stored as `self.thresholds_list`;
            defaults to `[0.0]`.

        Raises:
          ValueError: if `model_spec` (or, when a teacher is requested,
            `teacher_model_spec`) is neither a str nor a `T5Config`.
        """
        # We have to import transformers here because it has a side effect of
        # creating a TensorFlow graph, which prevents eager execution from being
        # enabled in files that import hf_model.py

        import transformers  # pylint: disable=import-outside-toplevel,g-import-not-at-top

        if isinstance(model_spec, str):
            self._model = transformers.T5ForConditionalGeneration.from_pretrained(model_spec)
        elif isinstance(model_spec, transformers.T5Config):
            self._model = transformers.T5ForConditionalGeneration(model_spec)
        else:
            raise ValueError("model_spec should be a string or T5Config.")

        tf.io.gfile.makedirs(model_dir)
        self._writer = torch.utils.tensorboard.writer.SummaryWriter(model_dir)
        self._model_dir = model_dir
        self._device = device
        self._step = 0
        self.to_tensor = functools.partial(torch.as_tensor, device=self._device)

        self.teacher_model_dir = teacher_model_dir
        if teacher_model_spec is not None and teacher_model_dir is not None:
            if isinstance(teacher_model_spec, str):
                self.teacher_model = transformers.T5ForConditionalGeneration.from_pretrained(teacher_model_spec)
            # The type check must look at the teacher spec itself, not at the
            # student's `model_spec`.
            elif isinstance(teacher_model_spec, transformers.T5Config):
                self.teacher_model = transformers.T5ForConditionalGeneration(teacher_model_spec)
            else:
                raise ValueError("teacher_model_spec should be a string or T5Config.")
            self.teacher_step = 0
        else:
            self.teacher_model = None

        # A fresh list per instance avoids sharing one default list object
        # between all models.
        self.thresholds_list = [0.0] if thresholds_list is None else thresholds_list
        self.selector_threshold = None
        self.enc_cls_head = None

    @property
    def model(self):
        """Return the object currently stored in `self._model`."""
        return self._model

    @property
    def step(self):
        """Return the current value of the step counter `self._step`."""
        return self._step

    def save_checkpoint(self, step):
        """Write the model's state dict to a checkpoint file in `model_dir`.

        Args:
          step: int, the training step used to name the checkpoint file.
        """
        destination = os.path.join(self._model_dir, CHECKPOINT_FILE_FORMAT.format(step))

        # (Distributed)DataParallel wrappers hold the real network in
        # `.module`; its state dict is the one that gets saved.
        is_wrapped = isinstance(self._model, (DistributedDataParallel, nn.DataParallel))
        network = self._model.module if is_wrapped else self._model
        torch.save(network.state_dict(), destination)

    def load_checkpoint(self, step, model_dir=None, local_rank=None, teacher_step=None, teacher_model_dir=None):
        """Load model (and optionally teacher) parameters from checkpoints.

        A step of 0 means "no checkpoint": nothing is loaded for that model.

        Args:
          step: int, load the checkpoint from this training step.
          model_dir: str, the directory of the checkpoint to load or None to use
            this model's directory.
          local_rank: currently unused by this method.
          teacher_step: int or None, step of the teacher checkpoint to load.
            Ignored when None or 0, or when no teacher directory is known.
          teacher_model_dir: str, the directory of the teacher checkpoint or
            None to use the directory given at construction time.
        """
        if step != 0:
            source_dir = model_dir or self._model_dir
            checkpoint_path = os.path.join(source_dir, CHECKPOINT_FILE_FORMAT.format(step))
            logging.info("Loading from %s", checkpoint_path)
            # Weights are always materialised on the CPU first.
            self._model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
            self._step = step

        ### load teacher
        no_teacher_dir = self.teacher_model_dir is None and teacher_model_dir is None
        if no_teacher_dir or teacher_step is None or teacher_step == 0:
            return

        teacher_source_dir = teacher_model_dir or self.teacher_model_dir
        teacher_checkpoint_path = os.path.join(
            teacher_source_dir, CHECKPOINT_FILE_FORMAT.format(teacher_step))
        logging.info("Loading from %s", teacher_checkpoint_path)
        self.teacher_model.load_state_dict(torch.load(teacher_checkpoint_path, map_location='cpu'))
        self.teacher_step = teacher_step


    def get_all_checkpoint_steps(self, model_dir=None, teacher_model_dir=None):
        """Retrieve the steps of all model and teacher checkpoints.

        Args:
          model_dir: str, the directory of the checkpoints or None to use this
            model's directory.
          teacher_model_dir: str, the directory of the teacher checkpoints or
            None to use the teacher directory given at construction time.

        Returns:
          A tuple `(steps, teacher_steps)` of sorted lists of ints. A list is
          `[0]` when no checkpoint with a numeric step was found (or, for the
          teacher, when no teacher directory is known).
        """
        step_regex = re.compile(".*" + CHECKPOINT_FILE_FORMAT.format(r"(\d+)"))

        def _steps_in(directory):
            """Return the sorted checkpoint steps found in `directory`, or [0]."""
            found = []
            for path in tf.io.gfile.glob(os.path.join(directory, CHECKPOINT_FILE_FORMAT.format("*"))):
                matched = step_regex.match(path)
                if matched is None:
                    # The glob also matches non-numeric names such as
                    # "model-best.checkpoint"; those carry no step.
                    logging.warning("Ignoring checkpoint file without a numeric step: %s", path)
                    continue
                found.append(int(matched.group(1)))
            return sorted(found) if found else [0]

        steps = _steps_in(model_dir or self._model_dir)

        ### teacher
        teacher_steps = [0]
        if teacher_model_dir is not None or self.teacher_model_dir is not None:
            teacher_steps = _steps_in(teacher_model_dir or self.teacher_model_dir)

        return steps, teacher_steps

    def get_latest_checkpoint_step(self, model_dir=None, teacher_model_dir=None):
        """Retrieve the step corresponding to the most recent checkpoint.

        Args:
          model_dir: str, the directory of the checkpoints or None to use this
            model's directory.

        Returns:
          An integer corresponding to the most recent step, or None if there are no
          checkpoints in the model directory.
        """

        steps, teacher_steps = self.get_all_checkpoint_steps(model_dir, teacher_model_dir)
        return max(steps), max(teacher_steps)

    def load_latest_checkpoint(self, local_rank=None):
        """Load the most recent checkpoint and update the model's current step."""
        latest_step, teacher_latest_step = self.get_latest_checkpoint_step()

        self.load_checkpoint(step=latest_step, teacher_step=teacher_latest_step, local_rank=local_rank)

    def train_head(self, mixture_or_task_name, steps, save_steps, sequence_length, split, batch_size, optimizer,
                   learning_rate_scheduler=None):
        """Train a classification head on the encoder output of the frozen model.

        The latest checkpoint is loaded first. The T5 model itself is put in
        eval mode and run under `torch.no_grad()`; only the parameters of a
        newly created `enc_cls_head` are handed to the optimizer. The head's
        state dict is written to `model_dir` as
        `enc_cls_head.checkpoint<self._step>`.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to train on.
            Must be pre-registered in the global `t5.data.TaskRegistry` or
            `t5.data.MixtureRegistry.`
          steps: int, the number of training iterations to run in this call.
          save_steps: int, the number of steps between head checkpoint saves.
          sequence_length: dict of int, a dict mapping feature name to length.
            The keys "inputs" and "targets" are read when sizing the head and
            padding the encoded label names.
          split: str or `tensorflow_datasets.Split`, the data split to load.
          batch_size: int, the number of padded sequences in each batch.
          optimizer: function that takes the parameters to optimize (here: the
            head's parameters) as its sole argument.
            For example, to use an AdamW optimizer with a learning rate of 1e-4,
            you could pass in `functools.partial(transformers.AdamW, lr=1e-4)`.
          learning_rate_scheduler: optional function that takes in an optimizer as
            its sole argument. For example, to use a schedule that warms up the
            optimizer's learning rate after 100 steps, you could pass in
            `functools.partial(transformers.get_constant_schedule_with_warmup,
           num_warmup_steps=100)`.
        """

        self.load_latest_checkpoint()

        #### model parallelism
        # The placement is chosen purely from the number of visible GPUs. Every
        # device map below distributes the indices 0..23, so it presumably
        # assumes a 24-block model -- TODO confirm for other model sizes.
        if torch.cuda.device_count() == 8:
            # for t5-11b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5],
                          2: [6, 7, 8],
                          3: [9, 10, 11],
                          4: [12, 13, 14],
                          5: [15, 16, 17],
                          6: [18, 19, 20],
                          7: [21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 6:
            # for t5-11b
            device_map = {0: [0, 1, 2, 3],
                          1: [4, 5, 6, 7],
                          2: [8, 9, 10, 11],
                          3: [12, 13, 14, 15],
                          4: [16, 17, 18, 19],
                          5: [20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 4:
            # for t5-3b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5, 6, 7, 8, 9],
                          2: [10, 11, 12, 13, 14, 15, 16],
                          3: [17, 18, 19, 20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 2:
            # for t5-large (2 gpus)
            device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          1: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        else:
            # Any other GPU count: keep the whole model on the default CUDA device.
            self._model.cuda()
        #############################

        ds = _get_dataset(mixture_or_task_name, sequence_length, split)
        task = data.get_mixture_or_task(mixture_or_task_name)
        ds = tokens_to_batches(ds, sequence_length, batch_size,
                               tuple(task.output_features), task)
        # Repeat dataset forever
        ds = itertools.cycle(ds)

        vocabulary = model_utils.get_vocabulary(mixture_or_task_name)
        if task.name in ['wmt16_enro_v003','wmt15_enfr_v003','wmt14_ende_v003','super_glue_wsc_v102_simple_train']:
            # the head is very big for Translation tasks!
            # It maps each encoder position to vocabulary[1].vocab_size logits
            # (presumably the targets vocabulary -- verify against
            # `model_utils.get_vocabulary`).
            enc_cls_head = torch.nn.Sequential(
                torch.nn.Linear(self._model.model_dim, vocabulary[1].vocab_size, bias=False),
            ).cuda()
        else:
            #### if finetuning enc_cls_head with nn.Linear
            # The head consumes the flattened encoder output (inputs length x
            # model_dim) and emits one logit per label name of the task's first
            # preprocessor.
            original_labels = task.preprocessors[0].keywords['label_names']
            enc_cls_head = torch.nn.Linear(sequence_length["inputs"] * self._model.model_dim, len(original_labels), bias=False).cuda()
            # dec_cls_head = torch.nn.Linear(sequence_length["targets"] * self._model.model_dim, len(original_labels),bias=False).cuda()
            # Encode every label name the way targets appear in a batch: token
            # ids, then id 1, then zeros up to the targets length (presumably
            # 1 is EOS and 0 is padding -- TODO confirm).
            encoded_original_labels = []
            for original_label in original_labels:
                encoded_original_label = vocabulary[1].encode(original_label) + [1]
                encoded_original_label += [0] * (sequence_length["targets"] - len(encoded_original_label))
                encoded_original_labels.append(encoded_original_label)
            encoded_original_labels = self.to_tensor(encoded_original_labels).cuda()

        ### freeze the entire model except the new head
        enc_cls_head.train()
        # dec_cls_head.train()
        self._model.eval()

        # Only the head's parameters are optimized.
        optimizer = optimizer(list(enc_cls_head.parameters()))  # +list(self._model.parameters()))
        if learning_rate_scheduler:
            learning_rate_scheduler = learning_rate_scheduler(optimizer)

        # NOTE(review): `now` is assigned here and in the loop but never read
        # in this method.
        now = time.time()
        head_criterion = torch.nn.CrossEntropyLoss()
        #if task.name == 'wmt16_enro_v003':
        #    head_criterion = torch.nn.MSELoss()

        for train_step, batch in enumerate(itertools.islice(ds, steps)):
            # Saves whenever train_step is a multiple of save_steps, which
            # includes the very first iteration (train_step == 0).
            if not train_step % save_steps:
                # TODO(craffel): Consider saving optimizer and scheduler state.
                # NOTE(review): the message reports `train_step` while the file
                # name below uses `self._step`; they differ when training
                # resumes from a checkpoint.
                logging.info("Saving head checkpoint for step %s", train_step)
                # self.save_checkpoint(self._step)
                torch.save(enc_cls_head.state_dict(), self._model_dir + '/enc_cls_head.checkpoint' + str(self._step))

            # All batch tensors are moved to 'cuda:0' regardless of how the
            # model was placed above.
            input_ids = self.to_tensor(batch["inputs"]).type(torch.long).to('cuda:0')
            attention_mask = self.to_tensor(batch["inputs_mask"]).to('cuda:0')
            decoder_attention_mask = self.to_tensor(batch["targets_mask"]).to('cuda:0')
            labels = self.to_tensor(batch["targets"]).type(torch.long).to('cuda:0')

            enc_cls_head.zero_grad()
            # self._model.zero_grad()

            # The base model only provides features; no gradients are tracked
            # for it.
            with torch.no_grad():
                outputs = self._model(input_ids=input_ids, attention_mask=attention_mask,
                                      decoder_attention_mask=decoder_attention_mask, labels=labels,
                                      output_hidden_states=True)

            ### if finetuning enc_cls_head with nn.Linear
            # outputs.last_hidden_state # decoder
            # dec_logits = dec_cls_head(outputs.decoder_hidden_states[-1].view(outputs.decoder_hidden_states[-1].shape[0], -1))

            ### convert encoded_labels to numeric_labels
            if task.name in ['wmt16_enro_v003','wmt15_enfr_v003','wmt14_ende_v003','super_glue_wsc_v102_simple_train']:
                # Per-position prediction: flattened logits are scored against
                # the flattened target token ids.
                # NOTE(review): this appears to require the inputs and targets
                # lengths to be equal -- confirm.
                enc_logits = enc_cls_head(
                    outputs.encoder_last_hidden_state)
                enc_cls_head_labels = labels.view(-1)
                cls_loss = head_criterion(enc_logits.view(-1, enc_logits.size(-1)),
                                          enc_cls_head_labels)
            else:
                enc_logits = enc_cls_head(
                    outputs.encoder_last_hidden_state.view(outputs.encoder_last_hidden_state.shape[0], -1))
                # Class index j is assigned to every row whose target tokens
                # equal the j-th encoded label name. Rows matching no label
                # keep the initial value 0.
                enc_cls_head_labels = torch.zeros(labels.shape[0]).cuda()
                for j in range(len(encoded_original_labels)):
                    indices = torch.where((labels == encoded_original_labels[j]).all(dim=1))[0]
                    enc_cls_head_labels.index_fill_(0, indices, j)
                enc_cls_head_labels = enc_cls_head_labels.long()

                cls_loss = head_criterion(enc_logits, enc_cls_head_labels)  # - (head_criterion(enc_logits, (~enc_cls_head_labels.bool()).long()))
                # cls_loss = head_criterion(dec_logits, enc_cls_head_labels)

            ########### energy loss ###########
            # The string literal below is disabled experimental code; it is a
            # no-op expression statement.
            '''
            mask_gt = torch.zeros(enc_logits.shape)
            for b in range(enc_logits.shape[0]):
              mask_gt[b,enc_cls_head_labels[b]] = 1
            enc_logits_gt = torch.masked_select(enc_logits, mask_gt.bool())
            enc_logits_others = torch.masked_select(enc_logits, ~mask_gt.bool())
            # m_in = -23  # based on the paper
            # m_out = -5  # based on the paper
            Ec_in = -torch.logsumexp(enc_logits_gt,dim=1)
            Ec_out = -torch.logsumexp(enc_logits_others,dim=1)
            # energy_loss += 0.001 * (torch.pow(F.relu(Ec_in - m_in), 2).mean() + torch.pow(F.relu(m_out - Ec_out), 2).mean())
            '''
            # energy_loss = (-torch.logsumexp(enc_logits, dim=1)).mean()
            # print(energy_loss)
            ###################################
            ###

            loss = cls_loss  # + outputs[0] #+ 0.01 * energy_loss #+ outputs[0]
            # The loss is printed to stdout on every iteration.
            print(loss)
            # lm_logits = outputs[1] # outputs.logit

            loss.backward()

            optimizer.step()

            if learning_rate_scheduler:
                learning_rate_scheduler.step()

            # optimizer.zero_grad()

            now = time.time()
            self._step += 1

        # logging.info("Saving final checkpoint for step %s", self._step)
        logging.info("Saving final head checkpoint for step %s", str(self._step))

        # self.save_checkpoint(self._step)
        torch.save(enc_cls_head.state_dict(), self._model_dir + '/enc_cls_head.checkpoint' + str(self._step))

    def train(
            self,
            argv,
            mixture_or_task_name,
            steps,
            save_steps,
            sequence_length,
            split,
            batch_size,
            optimizer,
            learning_rate_scheduler=None,
            train_split_ratio=None,
    ):
        """Train the model on the given Mixture or Task.

        The latest checkpoint is loaded first. A checkpoint is written whenever
        the iteration index is a multiple of `save_steps` (including the first
        iteration) and once more after the last step. If `argv.eval` is set,
        `self.eval` is run right after every periodic checkpoint.

        Args:
          argv: object holding run options. The attributes read here are
            `eval`, `task`, `distill`, `sequence_length_inputs`,
            `sequence_length_targets`, `eval_batch_size`, `eval_split`, `head`,
            `head_prediction` and `model_type`. `argv.task` is switched to the
            WSC eval task for the duration of the evaluation and restored
            afterwards.
          mixture_or_task_name: str, the name of the Mixture or Task to train on.
            Must be pre-registered in the global `t5.data.TaskRegistry` or
            `t5.data.MixtureRegistry.`
          steps: int, the number of training iterations to run in this call.
          save_steps: int, the number of steps between checkpoint saves.
          sequence_length: dict of int, a dict mapping feature name to length.
          split: str or `tensorflow_datasets.Split`, the data split to load.
          batch_size: int, the number of padded sequences in each batch.
          optimizer: function that takes the model parameters as its sole argument.
            For example, to use an AdamW optimizer with a learning rate of 1e-4,
            you could pass in `functools.partial(transformers.AdamW, lr=1e-4)`.
          learning_rate_scheduler: optional function that takes in an optimizer as
            its sole argument. For example, to use a schedule that warms up the
            optimizer's learning rate after 100 steps, you could pass in
            `functools.partial(transformers.get_constant_schedule_with_warmup,
           num_warmup_steps=100)`.
          train_split_ratio: optional float. When set, the dataset is loaded
            without shuffling and only a part of it is used: for a ratio
            >= 0.5 the first `int(size * ratio)` examples are kept, otherwise
            the first `int(size * (1 - ratio))` examples are skipped.
        """

        self.load_latest_checkpoint()

        #### model parallelism
        # The placement is chosen purely from the number of visible GPUs. Every
        # device map distributes the indices 0..23.
        if torch.cuda.device_count() == 8:
            # for t5-11b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5],
                          2: [6, 7, 8],
                          3: [9, 10, 11],
                          4: [12, 13, 14],
                          5: [15, 16, 17],
                          6: [18, 19, 20],
                          7: [21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 6:
            # for t5-11b
            device_map = {0: [0, 1, 2, 3],
                          1: [4, 5, 6, 7],
                          2: [8, 9, 10, 11],
                          3: [12, 13, 14, 15],
                          4: [16, 17, 18, 19],
                          5: [20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 4:
            # for t5-3b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5, 6, 7, 8, 9],
                          2: [10, 11, 12, 13, 14, 15, 16],
                          3: [17, 18, 19, 20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 2:
            # for t5-large (2 gpus)
            device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          1: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
            self._model.parallelize(device_map)  # Splits the model across several devices
        else:
            self._model.cuda()
        #############################

        # self._model = self._model.half()
        if False:  # torch.cuda.device_count() > 1:
            self._model = nn.DataParallel(self._model, device_ids=[0, 1, 2, 3])  # .cuda()

        #tf.set_random_seed(100)
        #ds = ds.shuffle(buffer_size=1000)#, seed=100)
        if train_split_ratio:
            # dont shuffle, so that the kept/skipped part is well defined
            ds = _get_dataset(mixture_or_task_or_name=mixture_or_task_name, sequence_length=sequence_length,
                              split=split, shuffle=False)
            ds_size = tf.data.experimental.cardinality(ds).numpy()
            # take/skip expect an integer element count, so the fractional
            # product is truncated explicitly.
            if train_split_ratio >= 0.5:
                ds = ds.take(int(ds_size * train_split_ratio))
            else:
                ds = ds.skip(int(ds_size * (1.0 - train_split_ratio)))
        else:
            ds = _get_dataset(mixture_or_task_or_name=mixture_or_task_name, sequence_length=sequence_length,
                              split=split, shuffle=True)

        task = data.get_mixture_or_task(mixture_or_task_name)
        ds = tokens_to_batches(ds, sequence_length, batch_size,
                               tuple(task.output_features), task)
        # Repeat dataset forever
        ds = itertools.cycle(ds)

        self._model.train()

        optimizer = optimizer(self._model.parameters())
        if learning_rate_scheduler:
            learning_rate_scheduler = learning_rate_scheduler(optimizer)

        #### for fp16
        # scaler = torch.cuda.amp.GradScaler()

        now = time.time()

        for train_step, batch in enumerate(itertools.islice(ds, steps)):
            if not train_step % save_steps:
                # TODO(craffel): Consider saving optimizer and scheduler state.
                logging.info("Saving checkpoint for step %s", self._step)
                self.save_checkpoint(self._step)
                ################ evaluate after each checkpoint saved ##############
                if argv.eval:
                    if argv.task == 'super_glue_wsc_v102_simple_train':
                        argv.task = 'super_glue_wsc_v102_simple_eval'
                    # only consider Teacher in evaluation for joint-inference (not distillation)
                    if argv.distill:
                        # Clearing the directory on this instance stops the
                        # checkpoint helpers from looking for teacher checkpoints.
                        self.teacher_model_dir = None
                    # For translation, we need to pass the max_length, otherwise, it does not work
                    sequence_length = {"inputs": argv.sequence_length_inputs, "targets": argv.sequence_length_targets}
                    self.eval(
                        argv.task,
                        sequence_length=sequence_length,
                        # we cannot enable the following option when head used. Because head has been trained with a fixed length (e.g., 256). The exact same length should be passed in this case.
                        # compute_sequence_length=True,
                        batch_size=argv.eval_batch_size,
                        checkpoint_steps=None,
                        split=argv.eval_split,
                        head=argv.head,
                        head_prediction=argv.head_prediction,
                        model_type=argv.model_type,
                    )
                    if argv.task == 'super_glue_wsc_v102_simple_eval':
                        argv.task = 'super_glue_wsc_v102_simple_train'
                ####################################################################

            # All batch tensors are moved to 'cuda:0' regardless of how the
            # model was placed above.
            input_ids = self.to_tensor(batch["inputs"]).type(torch.long).to('cuda:0')
            attention_mask = self.to_tensor(batch["inputs_mask"]).to('cuda:0')
            decoder_attention_mask = self.to_tensor(batch["targets_mask"]).to('cuda:0')
            labels = self.to_tensor(batch["targets"]).type(torch.long).to('cuda:0')

            self._model.zero_grad()

            #### for fp16
            # with torch.cuda.amp.autocast():
            outputs = self._model(input_ids=input_ids, attention_mask=attention_mask,
                                  decoder_attention_mask=decoder_attention_mask, labels=labels)

            # With `labels` given, the first output is used as the training loss.
            loss = outputs[0]
            # lm_logits = outputs[1] # outputs.logit

            if False:  # torch.cuda.device_count() > 1:
                loss.sum().backward()
            else:
                loss.backward()
                #### for fp16
                # scaler.scale(loss).backward()

            # loss.backward(torch.ones_like(loss))

            optimizer.step()
            #### for fp16
            # scaler.step(optimizer)
            # scaler.update()

            if learning_rate_scheduler:
                learning_rate_scheduler.step()

            # torch.cuda.empty_cache()

            # self._writer.add_scalar("loss", loss.detach().cpu().numpy(), self._step)
            # self._writer.add_scalar("step/s", 1 / (time.time() - now), self._step)
            now = time.time()
            self._step += 1

        logging.info("Saving final checkpoint for step %s", self._step)

        self.save_checkpoint(self._step)

    def eval(
            self,
            mixture_or_task_name,
            sequence_length,
            batch_size,
            checkpoint_steps=None,
            summary_dir=None,
            split="validation",
            compute_sequence_length=False,
            head_prediction=False,
            head=False,
            router='energy',
            model_type='-',
            **generate_kwargs,
    ):
        """Evaluate the model on the given Mixture or Task.

        *Note*: If a checkpoint step is provided (i.e. `checkpoint_steps is not
        None`), the model's state will be replaced by the state in those
        checkpoints. If you have not saved your model before calling `eval`, you
        should call `save_checkpoint` before `eval` to avoid losing its parameter
        values and state.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to evaluate
            on.  Must be pre-registered in the global `t5.data.TaskRegistry` or
            `t5.data.MixtureRegistry.`
          sequence_length: dict of int, a dict mapping feature name to length.
          batch_size: int, the number of padded sequences in each batch.
          checkpoint_steps: int, list of ints, "all", or None. If None, eval in the
            model in its current state without loading any checkpoints. If an int
            or list of ints, evaluation will be run on the checkpoint files in
            `model_dir` whose global steps are those provided. If -1, eval on the
            latest checkpoint from the model directory. If "all", evaluate all
            checkpoints in the model directory.
          summary_dir: str, path to write TensorBoard events file summaries for
            eval. If None, use model_dir/{split}_eval.
          split: str, the mixture/task split to evaluate on.
          compute_sequence_length: bool, automatically compute sequence length
            during eval mode.
          **generate_kwargs: Additional keyword arguments to pass to
            `transformers.PretrainedModel.generate()`, for example to change the
            decoding strategy. See the documentation for
            `transformers.PretrainedModel.generate()` for options.
        """

        self.load_latest_checkpoint()

        #### model parallelism
        if torch.cuda.device_count() == 8:
            # for t5-11b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5],
                          2: [6, 7, 8],
                          3: [9, 10, 11],
                          4: [12, 13, 14],
                          5: [15, 16, 17],
                          6: [18, 19, 20],
                          7: [21, 22, 23]}
            # don't parallelize the student in the joint mode
            if self._model.config.num_layers == 24 and not self.teacher_model:
                self._model.parallelize(device_map)  # Splits the model across several devices
            else:
                self._model.cuda()
            if self.teacher_model:
                self.teacher_model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 6:
            # for t5-11b
            device_map = {0: [0, 1, 2, 3],
                          1: [4, 5, 6, 7],
                          2: [8, 9, 10, 11],
                          3: [12, 13, 14, 15],
                          4: [16, 17, 18, 19],
                          5: [20, 21, 22, 23]}
            # don't parallelize the student in the joint mode
            if self._model.config.num_layers == 24 and not self.teacher_model:
                self._model.parallelize(device_map)  # Splits the model across several devices
            else:
                self._model.cuda()
            if self.teacher_model:
                self.teacher_model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 4:
            # for t5-3b
            device_map = {0: [0, 1, 2],
                          1: [3, 4, 5, 6, 7, 8, 9],
                          2: [10, 11, 12, 13, 14, 15, 16],
                          3: [17, 18, 19, 20, 21, 22, 23]}
            # don't parallelize the student in the joint mode
            if self._model.config.num_layers == 24 and not self.teacher_model:
                self._model.parallelize(device_map)  # Splits the model across several devices
            else:
                self._model.cuda()
            if self.teacher_model:
                self.teacher_model.parallelize(device_map)  # Splits the model across several devices
        elif torch.cuda.device_count() == 2:
            # for t5-large (2 gpus)
            device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          1: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
            if self._model.config.num_layers == 24 and not self.teacher_model:
                self._model.parallelize(device_map)  # Splits the model across several devices
            else:
                self._model.cuda()
            if self.teacher_model:
                self.teacher_model.parallelize(device_map)  # Splits the model across several devices
        else:
            self._model.cuda()
            if self.teacher_model:
                self.teacher_model.cuda()
        #############################

        ###### for now - just to allow model parallelism for 11B
        # if torch.cuda.device_count() >= 1:
        #  self._model.cuda()
        #  if self.teacher_model:
        #      self.teacher_model.cuda()

        # self._model = self._model.half()

        def _predict_from_tasks(tasks, vocabulary, checkpoint_step, sequence_length,
                                datasets, **unused_kwargs):

            if isinstance(vocabulary, tuple):
                vocab = vocabulary[1]

            if checkpoint_step != self._step:
                # self.load_checkpoint(checkpoint_step)
                self.load_checkpoint(step=checkpoint_step, teacher_step=None)

            self._model.eval()
            if self.teacher_model:
                self.teacher_model.eval()

            outputs = []
            for task in tasks:
                if compute_sequence_length:
                    # it seems the tokenization happens in the following function:
                    ds = _get_dataset(task.name, sequence_length, split, shuffle=False)
                else:
                    ds = datasets[task.name]

                ds_size = tf.data.experimental.cardinality(ds).numpy()
                print('**** number of samples: ' + str(ds_size))

                ds = list(tokens_to_batches(ds, sequence_length, batch_size, tuple(task.output_features), task))

                import torch.nn.functional as F
                energy_score_file = router + "-t5-" + mixture_or_task_name + "-"+("enc" if head else "dec")+"-logits-" + model_type + ".txt"
                f = open(energy_score_file, "w")

                teacher_count = 0
                student_correct = 0
                score = 0
                t1 = 0
                import random
                for batch in ds:
                    if router == 'random':
                        rnd_value = random.uniform(0, 1)

                    input_ids = self.to_tensor(batch["inputs"]).type(torch.long).to('cuda:0')
                    labels = self.to_tensor(batch["targets"]).type(torch.long).to('cuda:0')
                    if (router != 'random') or (router == 'random' and rnd_value <= self.selector_threshold):
                        t0 = time.time()
                        generated_outputs = self._model.generate(input_ids=input_ids, return_dict_in_generate=True,
                                                                 output_attentions=True, output_hidden_states=True,
                                                                 output_scores=True, max_length=sequence_length["targets"], **generate_kwargs)
                        t1 = t1 + (time.time() - t0)
                        predicted_tokens = generated_outputs.sequences

                    if head:
                        if task.name in ['wmt16_enro_v003','wmt15_enfr_v003','wmt14_ende_v003','super_glue_wsc_v102_simple_eval']:
                            enc_logits = self.enc_cls_head(generated_outputs.encoder_hidden_states[-1].to('cuda:'+str(torch.cuda.device_count()-1)))
                            if router == 'energy':
                                score = torch.logsumexp(enc_logits, dim=-1).mean()
                            enc_softmax = F.softmax(enc_logits, dim=-1)
                            if router == 'entropy':
                                score = -torch.sum(torch.multiply(enc_softmax, torch.log(enc_softmax)), axis=1)
                            if router == 'softmax':
                                score = enc_softmax.max()
                            #softmax_indices = enc_softmax.argmax(dim=-1)
                            #energy_score = torch.logsumexp(enc_logits[:,softmax_indices.nonzero()[:,1],:], dim=-1).mean()
                            #print(energy_score)
                            if head_prediction:
                                #predictions = []
                                predicted_tokens_head = enc_softmax.argmax(dim=-1).cpu().numpy().tolist()
                                predictions = [vocab.decode(p) for p in predicted_tokens_head]
                        else:
                            enc_logits = self.enc_cls_head(generated_outputs.encoder_hidden_states[-1].view(
                                generated_outputs.encoder_hidden_states[-1].shape[0], -1).to('cuda:'+str(torch.cuda.device_count()-1)))
                            ### following does not work. the shape is tuple(6)*tuple(13)*tensor(64*768), but it should be tuple(13)*tensor(64*6*768)
                            # dec_logits = dec_cls_head(generated_outputs.decoder_hidden_states[-1][-1].view(generated_outputs.decoder_hidden_states[-1][-1].shape[0],-1))
                            if router=='energy':
                                score = torch.logsumexp(enc_logits, dim=-1)[0]
                            enc_softmax = F.softmax(enc_logits, dim=-1)
                            if router == 'entropy':
                                score = -torch.sum(torch.multiply(enc_softmax, torch.log(enc_softmax)), axis=1)
                            if router == 'softmax':
                                score = enc_softmax.max()
                            if head_prediction:
                                predictions = []
                                for prediction in enc_softmax.argmax(dim=-1):
                                    predictions.append(task.preprocessors[0].keywords['label_names'][prediction])

                    if self.teacher_model:
                        if (not head) and router != 'random':
                            logit_scores = generated_outputs.scores
                            ##### energy calculation
                            # NOTE: len(logit_scores) == len(sequence_length_targets)
                            score = 0.0
                            for logits in logit_scores:#[0:1]:
                                # input(logits)
                                ############ consider only the logits related to the used words
                                ### mrpc: 'not_equivalent' with [59, 834, 15, 1169, 15592, 1] AND 'equivalent' with [7072, 1]
                                # indices = torch.tensor([7072, 59,   834,    15,  1169, 15592, 1]).cuda()
                                # indices = torch.tensor([7072, 59, 834]).cuda()
                                # logits = torch.index_select(input=logits, dim=1, index=indices).cuda()

                                ### max over logits
                                # energy_score = logits.max()

                                ### energy over logits
                                score = score + torch.logsumexp(logits, dim=1)

                                ### softmax
                                #softmax_score = F.softmax(logits, dim=1)
                                #energy_score = softmax_score.max()

                                ### log_softmax (based on the generate function)
                                # log_softmax_score = F.log_softmax(logits, dim=1)
                                # energy_score=log_softmax_score.max()

                            # print(energy_score)
                            score = score / len(logit_scores)
                            # if predictions==targets:
                            #    res = 'True'
                            # else:
                            #    res = 'False'

                        #if (predictions==['entailment'] and energy_score < -4.915) or (predictions==['not_entailment'] and energy_score < 3.214):
                        #if router=='soft' and score < self.selector_threshold:
                        if ((router=='energy' or router=='softmax') and score < self.selector_threshold) or (router=='entropy' and score > self.selector_threshold) or (router=='random' and rnd_value > self.selector_threshold):
                            generated_outputs = self.teacher_model.generate(input_ids=input_ids,
                                                                            return_dict_in_generate=True,
                                                                            output_attentions=False,
                                                                            output_hidden_states=False,
                                                                            output_scores=True, max_length=sequence_length["targets"], **generate_kwargs)
                            predicted_tokens = generated_outputs.sequences
                            teacher_count = teacher_count + 1
                        else: # this block is just for calculating student's accuracy over its own samples
                            predictions = [vocab.decode(p) for p in predicted_tokens.cpu().numpy().tolist()]
                            targets = [vocab.decode(p) for p in labels.cpu().numpy().tolist()]
                            if (targets == predictions):
                                student_correct += 1
                            #    f.write('1' + '\t')
                            #else:
                            #    f.write('0' + '\t')

                        #f.write(str(score.cpu().detach().numpy()) + '\n')
                        if router != 'random':
                            f.write(str(score.cpu().detach().numpy()) + '\n')

                    if not head or not head_prediction:
                        predicted_tokens = predicted_tokens.cpu().numpy().tolist()
                        predictions = [vocab.decode(p) for p in predicted_tokens]

                    outputs.extend(predictions)
            f.close()
            if head and head_prediction:
                print('********** Predictions by Head ***********')
            exit_rate = (len(ds) - teacher_count) / len(ds)
            avg_time = t1 / len(ds)
            print('********** selector_threshold: ' + str(self.selector_threshold) + ' ***********')
            print('********** exit_rate: ' + str(exit_rate) + ' ***********')
            print('********** avg. inference time (sec): ' + str(avg_time) + ' ***********')

            print('**** student accuracy ***')
            print(student_correct / (len(ds) - teacher_count))

            return outputs

        if checkpoint_steps is None:
            checkpoint_steps = [self._step]
        elif isinstance(checkpoint_steps, int):
            checkpoint_steps = [checkpoint_steps]
        elif checkpoint_steps == "all":
            checkpoint_steps = self.get_all_checkpoint_steps()[0]
        elif not isinstance(checkpoint_steps, (list, tuple)):
            raise ValueError(
                f"checkpoint_steps must be None, int or list; got {checkpoint_steps}"
            )

        summary_dir = summary_dir or os.path.join(self._model_dir, f"{split}_eval")
        tf.io.gfile.makedirs(summary_dir)

        head_files = tf.io.gfile.glob(os.path.join(self._model_dir, Head_FILE_FORMAT.format("*")))
        if not head or len(head_files) == 0:
            head_files = [None]
        for head_path in head_files:
            if head_path:
                vocabulary = model_utils.get_vocabulary(mixture_or_task_name)
                task = data.get_mixture_or_task(mixture_or_task_name)
                if task.name in ['wmt16_enro_v003','wmt15_enfr_v003','wmt14_ende_v003','super_glue_wsc_v102_simple_eval']:
                    # the head is very big for Translation tasks!
                    self.enc_cls_head = torch.nn.Sequential(
                        torch.nn.Linear(self._model.model_dim, vocabulary[1].vocab_size, bias=False),
                    ).to('cuda:'+str(torch.cuda.device_count()-1))
                else:
                    original_labels = task.preprocessors[0].keywords['label_names']
                    self.enc_cls_head = torch.nn.Linear(sequence_length["inputs"] * self._model.model_dim,
                                                        len(original_labels),
                                                        bias=False).to('cuda:'+str(torch.cuda.device_count()-1))
                self.enc_cls_head.load_state_dict(torch.load(head_path))
                self.enc_cls_head.eval()
                print('********** Using Head: ' + str(head_path) + ' ***********')
            for threshold in self.thresholds_list.split(','):
                self.selector_threshold = float(threshold)
                run_eval(
                    mixture_or_task_name=mixture_or_task_name,
                    predict_or_score_fn=_predict_from_tasks,
                    checkpoint_steps=checkpoint_steps,
                    dataset_fn=functools.partial(_get_dataset, shuffle=False),
                    summary_dir=summary_dir,
                    split=split,
                    sequence_length=None if compute_sequence_length else sequence_length,
                    batch_size=batch_size)

        print('*****************************************************************')
        print('Best checkpoint: ' + str(globals.best_checkpoint))
        print('Best result: ' + str(globals.best_result))
        print('*****************************************************************')

    def predict(
            self,
            inputs,
            sequence_length,
            batch_size,
            output_file=None,
            vocabulary=None,
            **generate_kwargs,
    ):
        """Generate model outputs for raw text inputs, log them, and optionally save them.

        Each input string is encoded with the "inputs" vocabulary, batched with
        `tokens_to_batches`, passed to `self._model.generate()`, and the generated
        token ids are decoded with the "targets" vocabulary. Every input/prediction
        pair is logged; predictions are written to `output_file` when it is given.

        Args:
          inputs: list of str or str, either a list of inputs to feed into the
            model or the path to a text file that contains a single input on each
            line (lines are stripped of surrounding whitespace).
          sequence_length: dict of int, a dict mapping feature name to length.
          batch_size: int, the number of padded sequences in each batch.
          output_file: str or None, path to write out predictions or None to skip
            writing.
          vocabulary: t5.data.vocabularies.Vocabulary or dict or None. Either the
            Vocabulary to use for processing inputs and targets, a dict mapping
            "inputs" to a Vocabulary for encoding the inputs and "targets" for
            decoding the predictions, or None (default) to use a
            t5.data.SentencePieceVocabulary loaded from './spiece.model'.
          **generate_kwargs: Additional keyword arguments to pass to
            `transformers.PretrainedModel.generate()`, for example to change the
            decoding strategy. See the documentation for
            `transformers.PretrainedModel.generate()` for options.

        Returns:
          None. Predictions are only logged and (optionally) written to disk.

        Raises:
          ValueError: if `inputs` is a str that is not an existing path, or if
            `vocabulary` is not a dict, a Vocabulary, or None.
        """

        if isinstance(inputs, str):
            if not tf.io.gfile.exists(inputs):
                # Every piece must be an f-string, otherwise "{inputs}" is emitted
                # literally instead of the offending value.
                raise ValueError(
                    f"A str was provided for `inputs`, but the path {inputs} does not "
                    f"exist. If you want the model's output for {inputs}, you should "
                    f"feed in inputs=['{inputs}']"
                )
            with tf.io.gfile.GFile(inputs) as f:
                inputs = [line.strip() for line in f]

        if vocabulary is None:
            import t5
            # Default vocabulary is read from the current working directory.
            vocab = t5.data.SentencePieceVocabulary('./spiece.model')
            vocabs = {"inputs": vocab, "targets": vocab}
        elif isinstance(vocabulary, data.vocabularies.Vocabulary):
            vocabs = {"inputs": vocabulary, "targets": vocabulary}
        elif isinstance(vocabulary, dict):
            vocabs = vocabulary
        else:
            raise ValueError("vocabulary must be a dict, a Vocabulary, or None")

        dataset = tf.data.Dataset.from_tensor_slices(inputs)
        dataset = dataset.map(
            lambda x: {"inputs": tf.cast(vocabs["inputs"].encode_tf(x), tf.int64)},
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        dataset = tokens_to_batches(
            dataset, sequence_length, batch_size, ["inputs"]
        )

        predictions = []
        for batch in dataset:
            predicted_tokens = self._model.generate(
                input_ids=self.to_tensor(batch["inputs"]), **generate_kwargs
            )
            predicted_tokens = predicted_tokens.cpu().numpy().tolist()
            predictions.extend(
                vocabs["targets"].decode(p) for p in predicted_tokens
            )

        for inp, pred in zip(inputs, predictions):
            logging.info("%s\n  -> %s", inp, pred)

        if output_file is not None:
            # This module imports `models.utils` as `model_utils`; a bare `utils`
            # name is not bound by the file's imports.
            # NOTE(review): assumes `models.utils` provides `write_lines_to_file`
            # (as the upstream t5 utils module does) -- confirm.
            model_utils.write_lines_to_file(predictions, output_file)

    ############################################################## distillation

    def calc_ce_loss(self, mask, s_logits, t_logits, temperature=2.0):
        """Temperature-scaled KL-divergence distillation loss over unmasked positions.

        Adapted from DistilBERT (transformers/examples/distillation/).

        Args:
          mask: 2-D tensor indexed like the first two dimensions of `s_logits`;
            entries that are zero/False (padding) are excluded from the loss.
          s_logits: student logits; the last dimension is the vocabulary.
          t_logits: teacher logits, same layout as `s_logits`.
          temperature: float, softmax temperature applied to both sets of logits.
            The loss is multiplied by `temperature ** 2`. Defaults to 2.0.

        Returns:
          Scalar tensor: KL(teacher || student) with "batchmean" reduction over
          the selected positions, scaled by `temperature ** 2`.
        """
        # Broadcast the position mask over the vocabulary dimension so that whole
        # logit rows are kept or dropped together.
        sel_mask = mask[:, :, None].expand_as(s_logits).bool()
        vocab_size = s_logits.size(-1)
        # masked_select flattens its result; reshape back to (kept positions, vocab).
        s_logits_slct = torch.masked_select(s_logits, sel_mask).view(-1, vocab_size)
        t_logits_slct = torch.masked_select(t_logits, sel_mask).view(-1, vocab_size)
        assert t_logits_slct.size() == s_logits_slct.size()
        # `torch.nn` is referenced through `torch`: this module imports `torch`
        # but never binds a bare `nn` name.
        ce_loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        student_log_probs = torch.nn.functional.log_softmax(s_logits_slct / temperature, dim=-1)
        teacher_probs = torch.nn.functional.softmax(t_logits_slct / temperature, dim=-1)
        loss_ce = ce_loss_fct(student_log_probs, teacher_probs) * temperature ** 2
        return loss_ce

    def distill(
            self,
            argv,
            mixture_or_task_name,
            steps,
            save_steps,
            sequence_length,
            split,
            batch_size,
            optimizer,
            learning_rate_scheduler=None,
            train_split_ratio=None,
    ):
        """Train the student (`self._model`) to imitate `self.teacher_model`.

        Each step minimises `0.8 * KL(teacher || student) + 0.2 * student LM loss`,
        where the KL term comes from `calc_ce_loss` over the decoder logits and the
        LM loss is the loss the student returns for the given labels.

        Args:
          argv: parsed flags object. Read only when `argv.eval` is truthy, in which
            case `task`, `sequence_length_inputs`, `sequence_length_targets`,
            `eval_batch_size`, `eval_split`, `head`, `head_prediction` and
            `model_type` are forwarded to `self.eval` after each periodic save.
          mixture_or_task_name: str, name of the Mixture or Task to train on.
          steps: int, number of training steps (batches) to run.
          save_steps: int, a checkpoint is saved whenever the loop index is a
            multiple of this value (including the very first step) and once more
            after the last step.
          sequence_length: dict of int, a dict mapping feature name to length.
          split: str, the mixture/task split to train on.
          batch_size: int, the number of padded sequences in each batch.
          optimizer: callable that receives the student's parameters and returns
            an optimizer.
          learning_rate_scheduler: optional callable that receives the optimizer
            and returns a scheduler; it is stepped once per training step.
          train_split_ratio: optional float. When set, the dataset is loaded
            unshuffled; a ratio >= 0.5 keeps the first `ratio` fraction of the
            examples, a smaller ratio skips the first `1 - ratio` fraction.
            Otherwise the dataset is loaded shuffled.

        Raises:
          ValueError: if no teacher model is available after loading checkpoints.
        """
        # load both Teacher and Student
        self.load_latest_checkpoint()
        if not self.teacher_model:
            # Fail before moving anything to the GPUs rather than with an
            # AttributeError on `self.teacher_model.eval()` further down.
            raise ValueError("distill() requires a teacher model, but none is loaded.")

        self._place_models_for_distill()

        if train_split_ratio:
            # Unshuffled, so take/skip select a contiguous slice of the split.
            ds = _get_dataset(mixture_or_task_or_name=mixture_or_task_name,
                              sequence_length=sequence_length,
                              split=split, shuffle=False)
            ds_size = tf.data.experimental.cardinality(ds).numpy()
            # take/skip need an integer count; `ds_size * ratio` is a float.
            if train_split_ratio >= 0.5:
                ds = ds.take(int(ds_size * train_split_ratio))
            else:
                ds = ds.skip(int(ds_size * (1.0 - train_split_ratio)))
        else:
            ds = _get_dataset(mixture_or_task_or_name=mixture_or_task_name,
                              sequence_length=sequence_length,
                              split=split, shuffle=True)

        task = data.get_mixture_or_task(mixture_or_task_name)
        ds = tokens_to_batches(ds, sequence_length, batch_size,
                               tuple(task.output_features), task)
        # Repeat dataset forever
        ds = itertools.cycle(ds)

        self._model.train()
        self.teacher_model.eval()

        optim = optimizer(self._model.parameters())
        scheduler = learning_rate_scheduler(optim) if learning_rate_scheduler else None

        # Fixed weights of the distillation (KL) term and the student's own LM loss.
        alpha_ce = 0.8
        alpha_mlm = 0.2

        for train_step, batch in enumerate(itertools.islice(ds, steps)):
            if not train_step % save_steps:
                logging.info("Saving checkpoint for step %s", self._step)
                self.save_checkpoint(self._step)
                ################ evaluate after each checkpoint saved ##############
                if argv.eval:
                    # Hide the teacher so that eval measures the student alone
                    # (no joint inference); restore it even if eval raises.
                    saved_teacher_dir = self.teacher_model_dir
                    saved_teacher = self.teacher_model
                    self.teacher_model_dir = None
                    self.teacher_model = None
                    eval_sequence_length = {"inputs": argv.sequence_length_inputs,
                                            "targets": argv.sequence_length_targets}
                    try:
                        self.eval(
                            argv.task,
                            sequence_length=eval_sequence_length,
                            batch_size=argv.eval_batch_size,
                            checkpoint_steps=None,
                            split=argv.eval_split,
                            head=argv.head,
                            head_prediction=argv.head_prediction,
                            model_type=argv.model_type,
                        )
                    finally:
                        self.teacher_model_dir = saved_teacher_dir
                        self.teacher_model = saved_teacher
                        # eval switches the student to eval mode; go back to
                        # training mode before the next optimisation step.
                        self._model.train()
                ####################################################################

            input_ids = self.to_tensor(batch["inputs"]).type(torch.long).to('cuda:0')
            attention_mask = self.to_tensor(batch["inputs_mask"]).to('cuda:0')
            decoder_attention_mask = self.to_tensor(batch["targets_mask"]).to('cuda:0')
            labels = self.to_tensor(batch["targets"]).type(torch.long).to('cuda:0')

            self._model.zero_grad()

            student_outputs = self._model(input_ids=input_ids, attention_mask=attention_mask,
                                          decoder_attention_mask=decoder_attention_mask, labels=labels)
            lm_logits = student_outputs["logits"]
            student_lm_loss = student_outputs[0]

            # Only decoder logits are distilled; per the original author's note,
            # encoder outputs are not used because the student and teacher
            # encoders have different shapes.
            with torch.no_grad():
                teacher_outputs = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask,
                                                     decoder_attention_mask=decoder_attention_mask, labels=labels)

            loss_ce = self.calc_ce_loss(decoder_attention_mask, lm_logits, teacher_outputs["logits"])
            loss = alpha_ce * loss_ce + alpha_mlm * student_lm_loss

            loss.backward()
            optim.step()

            if scheduler:
                scheduler.step()

            self._step += 1

        logging.info("Saving final checkpoint for step %s", self._step)

        self.save_checkpoint(self._step)

    def _place_models_for_distill(self):
        """Put the student and teacher on GPU(s) according to the visible GPU count.

        For 8, 6, 4 or 2 GPUs a 24-layer device map is built and the teacher (when
        present) is split across the GPUs with `parallelize`; the student is only
        parallelized when it has 24 layers and there is no teacher, otherwise it
        is moved with `.cuda()`. For any other GPU count both models are moved
        with `.cuda()`.
        """
        # Layer boundaries per GPU count: GPU i hosts layers [bounds[i], bounds[i+1]).
        layer_bounds_by_gpu_count = {
            8: (0, 3, 6, 9, 12, 15, 18, 21, 24),  # for t5-11b
            6: (0, 4, 8, 12, 16, 20, 24),  # for t5-11b
            4: (0, 3, 10, 17, 24),  # for t5-3b
            2: (0, 10, 24),  # for t5-large (2 gpus)
        }
        bounds = layer_bounds_by_gpu_count.get(torch.cuda.device_count())
        if bounds is None:
            self._model.cuda()
            if self.teacher_model:
                self.teacher_model.cuda()
            return

        device_map = {
            device: list(range(first, last))
            for device, (first, last) in enumerate(zip(bounds, bounds[1:]))
        }
        # don't parallelize the student in the joint mode
        if self._model.config.num_layers == 24 and not self.teacher_model:
            self._model.parallelize(device_map)  # Splits the model across several devices
        else:
            self._model.cuda()
        if self.teacher_model:
            self.teacher_model.parallelize(device_map)  # Splits the model across several devices
