"""
Fine-tuning the library models for sequence-to-sequence tasks.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
from transformers import Seq2SeqTrainingArguments
from utils import modify_model_after_init, save_training_config
import shutil
import glob
from data import AutoPostProcessor
from third_party.models import T5Config, T5ForConditionalGeneration
from dataclasses import dataclass, field
from training_args import AdapterTrainingArguments
from third_party.trainers import Seq2SeqTrainer
from data import TaskDataCollatorForSeq2Seq
from data import AutoTask
from utils import get_adapter_config
from transformers.trainer_utils import is_main_process, get_last_checkpoint
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    set_seed,
    AdamW,
    get_linear_schedule_with_warmup
)
import transformers
from datasets import load_dataset, load_metric, concatenate_datasets
from typing import Optional, List
import subprocess
import sys
import functools
import logging
from pytz import common_timezones
import torch
import os
import pickle
from data.tasks import TASK_MAPPING
from datasets import Dataset
import argparse
import math

# Environment overrides for Intel MKL, applied at import time of this script.
# NOTE(review): presumably a workaround for MKL threading-layer conflicts —
# confirm. `torch` is already imported above, so these assignments may only
# take effect for libraries loaded later or for child processes.
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'

# Module-level logger; its level is configured inside `main`.
logger = logging.getLogger(__name__)


def run_command(command):
    """Execute ``command`` via ``subprocess.getoutput`` and return its output.

    Args:
        command: The command line to execute, given as a single string.

    Returns:
        The string returned by ``subprocess.getoutput(command)``.
    """
    return subprocess.getoutput(command)


# Lookup table from task identifier to the list of metric names associated
# with that task. Keys include GLUE-style names (e.g. "mrpc"), SuperGLUE names
# prefixed with "superglue-", QA datasets, and "_ppt"-suffixed variants that
# carry the same metrics as their base task.
# NOTE(review): the consumer of this table is not visible in this part of the
# file — confirm how the metric names are resolved before renaming any entry.
TASK_TO_METRICS = {
    "mrpc": ["accuracy", "f1"],
    "mrpc_ppt": ["accuracy", "f1"],
    "cola": ["matthews_correlation"],
    "stsb": ["pearson", "spearmanr"],
    "sst2": ["accuracy"],
    "mnli": ["accuracy"],
    "mnli_ppt": ["accuracy"],
    "mnli_mismatched": ["accuracy"],
    "mnli_matched": ["accuracy"],
    "qnli": ["accuracy"],
    "rte": ["accuracy"],
    "wnli": ["accuracy"],
    "wnli_ppt": ["accuracy"],
    "qqp": ["accuracy", "f1"],
    "superglue-boolq": ["accuracy"],
    "superglue-boolq_ppt": ["accuracy"],
    "superglue-rte": ["accuracy"],
    "superglue-rte_ppt": ["accuracy"],
    "superglue-cb": ["f1_multiclass", "accuracy"],
    "superglue-cb_ppt": ["f1_multiclass", "accuracy"],
    "superglue-copa": ["accuracy"],
    "superglue-multirc": ["f1", "em"],
    "superglue-multirc_ppt": ["f1", "em"],
    "superglue-wic": ["accuracy"],
    "superglue-wic_ppt": ["accuracy"],
    "superglue-wsc.fixed": ["accuracy"],
    "superglue-wsc.fixed_ppt": ["accuracy"],
    "superglue-record": ["f1", "em"],
    "multi_nli": ["accuracy"],
    "squad": ["em", "f1"],
    "snli": ["accuracy"],
    "nq": ["em", "f1"],
    "hotpotqa": ["em", "f1"],
    "searchqa": ["em", "f1"],
    "newsqa": ["em", "f1"],
    "triviaqa": ["em", "f1"],
    "imdb": ["accuracy"],
    "winogrande": ["accuracy"],
    "scitail": ["accuracy"],
    "amazon_polarity": ["accuracy"],
    "yelp_polarity": ["accuracy"],
    "paws": ["accuracy"],
}

# run_seq2seq parameters.


def update_path_original(args):
    """Build the list of source-prompt file paths under the configured prefix.

    Each path has the form
    ``<args.prompt_embedding_path_prefix>/<task>_prompt.pt`` for a fixed,
    hard-coded sequence of six source task names.

    Args:
        args: Object exposing a ``prompt_embedding_path_prefix`` string
            attribute.

    Returns:
        A list of path strings, one per source task, in the fixed task order.
    """
    source_tasks = ("mnli", "sst2", "qnli", "qqp", "squad", "record")
    return [
        args.prompt_embedding_path_prefix + '/' + task + "_prompt.pt"
        for task in source_tasks
    ]

@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    """Script-specific switches layered on top of ``Seq2SeqTrainingArguments``.

    Every field is optional and carries a ``help`` entry in its metadata, which
    is the text shown by the argument parser built in ``main``.
    """
    print_num_parameters: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, print the parameters of the model."})
    do_test: Optional[bool] = field(default=False, metadata={
        "help": "If set, evaluates the test performance."})
    # NOTE(review): this help text is identical to `do_test`'s; confirm what
    # `do_eval_predict` actually controls and update the description.
    do_eval_predict: Optional[bool] = field(default=False, metadata={
        "help": "If set, evaluates the test performance."})
    # The help fragments below are joined by implicit string concatenation, so
    # each fragment must end with a space to keep the words separated.
    split_validation_test: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, for the datasets which do not "
                          "have the test set, we use validation set as their "
                          "test set and make a validation set from either "
                          "splitting the validation set into half (for smaller "
                          "than 10K samples datasets), or by using 1K examples "
                          "from training set as validation set (for larger "
                          "datasets)."})
    compute_time: Optional[bool] = field(
        default=False, metadata={"help": "If set measures the time."})
    compute_memory: Optional[bool] = field(
        default=False, metadata={"help": "if set, measures the memory"})
    prefix_length: Optional[int] = field(
        default=100, metadata={"help": "Defines the length for prefix tuning."})


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Besides the model/tokenizer identifiers, this dataclass also carries the
    prompt/prefix-embedding, attention and layer-norm loading options declared
    below. ``model_name_or_path`` is the only field without a default and must
    therefore always be supplied.
    """
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                    "with private models)."
        },
    )

    # --- Prompt / prefix embedding options ---
    load_prefix_embeddings: bool = field(
        default=False,
        metadata={
            "help": "load prefix embeddings or not"
        },
    )
    save_prefix_only: bool = field(
        default=False,
        metadata={
            "help": "save prefix embeddings only"
        },
    )

    prompt_embedding_path: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of the paths to prefix embeddings"}
    )

    target_prompt_embedding_path: Optional[str] = field(
        default=None,
        metadata={"help": "a path to the target prompt embedding"}
    )

    # --- Attention options ---
    attn_prefix_tuning: bool = field(
        default=False,
        metadata={
            "help": "Set true if you try ATTEMPT."
        },
    )

    attn_method: Optional[str] = field(
        default="sub",
        metadata={
            "help": "Attention model for attn_prefix. We currently support the following methods: linear, sub (our main method), and constant (gives the constant and equal weights to all of the prompts.)"
        },
    )

    shared_attn: bool = field(
        default=False,
        metadata={
            "help": "shared attention"
        },
    )

    load_attention: bool = field(
        default=False,
        metadata={
            "help": "Set true if you want to load pre-trained attention weights"
        },
    )

    attn_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to attention weights (linear attentions). "
        },
    )

    attn_path_sub: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "list of the path to attention weights (sub attentions). [path_to_down_projection_weights, path_to_up_projection_weights]"
        },
    )

    ignore_target: bool = field(
        default=False,
        metadata={
            "help": "Whether to ignore the new target tokens. Mainly for ablation."
        },
    )

    fix_attention: bool = field(
        default=False,
        metadata={
            "help": "this will make the attention weights frozen during training. Mainly for ablation."
        },
    )

    temperature: float = field(
        default=2000,
        metadata={
            "help": "set the soft max temperature of ATTEMPT."
        },
    )

    # Annotated Optional because the default is None (no separate learning
    # rate configured for the attention modules).
    attn_learning_rate: Optional[float] = field(
        default=None,
        metadata={
            "help": "set the learning rate for the attention modules."
        },
    )

    # --- Layer-norm loading options ---
    load_layer_norm: bool = field(
        default=False,
        metadata={
            "help": "Set true if you want to load pre-trained layer-norm weight and biases."
        },
    )

    layer_norm_dir: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Layer norm dir. [path_to_layer_norm_weight.pt, path_to_layer_norm_bias.pt]"
        },
    )

    prefix_num: Optional[int] = field(
        default=1, metadata={"help": "the number of prefix"})


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``task_name`` is mandatory: ``__post_init__`` raises ``ValueError`` when it
    is left as ``None``. The validation and test target lengths fall back to
    ``max_target_length`` when they are not given.
    """
    task_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    eval_dataset_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the evaluation dataset to use (via the datasets library)."}
    )
    eval_dataset_config_name: Optional[List[str]] = field(
        default=None,
        metadata={"help": "The configuration name of the evaluation dataset to use (via the datasets library)."}
    )
    test_dataset_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the test dataset to use (via the datasets library)."}
    )
    test_dataset_config_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
    )
    # NOTE(review): this help text duplicates `task_name`'s; confirm the
    # intended description for `lang_name`.
    lang_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                    "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                    "during ``evaluate`` and ``predict``."
        },
    )
    test_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                    "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                    "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
                    "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                    "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
                    "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
                          "value if set."}
    )
    num_beams: Optional[int] = field(
        default=None, metadata={"help": "Number of beams to use for evaluation."})
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    task_adapters: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary from task adapters to the tasks."}
    )
    task_embeddings: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Defines a dictionary from tasks to the tasks embeddings."}
    )
    data_seed: Optional[int] = field(
        default=42, metadata={"help": "seed used to shuffle the data."})

    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the test data."}
    )

    train_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the training data."}
    )
    validation_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the validation data."}
    )
    test_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the test data."}
    )

    def __post_init__(self):
        """Validate required fields and fill in derived defaults.

        Raises:
            ValueError: If ``task_name`` is ``None``. Only ``task_name`` is
                checked; supplying a train/validation file alone does not
                satisfy this requirement.
        """
        if self.task_name is None:
            raise ValueError(
                "Need either a dataset name or a training/validation file.")
        # Evaluation/test generation lengths default to the training target length.
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
        if self.test_max_target_length is None:
            self.test_max_target_length = self.max_target_length


def main(seed_id, args):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments,
                               AdapterTrainingArguments))
    if args.json_file:
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, adapter_args = parser.parse_json_file(
            json_file=os.path.abspath(args.json_file))
    else:
        model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()

    # freeze all the weights/ uniform
    if args.self_train_attn:
        model_args.attn_method = args.attn_method
        training_args.do_test = False
    else:
        training_args.do_test = True

    if args.save_dir:
        training_args.output_dir = args.save_dir

    if args.naked:
        model_args.prompt_embedding_path.append('./source_prompts/' + args.dataset_source + '_prompt.pt')

    if args.init_prompt:
        model_args.prompt_embedding_path = []
        model_args.load_prefix_embeddings = False

    if args.self_train or args.naked or args.self_train_attn:
        training_args.output_dir = args.save_dir
        data_args.test_dataset_name[0] = args.dataset_target
        data_args.eval_dataset_name[0] = args.dataset_target
        data_args.task_name[0] = args.dataset_target
        model_args.save_prefix_only = False

    if args.target_prompt_embedding_path:
        model_args.prompt_embedding_path = None
        model_args.target_prompt_embedding_path = args.target_prompt_embedding_path

    if args.load_attention:
        model_args.load_attention = True
    if args.attn_path_sub:
        model_args.attn_path_sub = [args.attn_path_sub+'/attn_W_down.pt', args.attn_path_sub+'/attn_W_up.pt']
    if args.layer_norm_dir:
        model_args.layer_norm_dir = args.layer_norm_dir
    if args.resume_from_checkpoint:
        training_args.resume_from_checkpoint = args.resume_from_checkpoint
    if args.ignore_target:
        model_args.ignore_target = True

    if args.zeroshot:
        training_args.do_train = False

    if args.model_name:
        model_args.model_name_or_path = args.model_name
        model_args.tokenizer_name = args.model_name

    if args.dataset_target:
        data_args.task_name = [args.dataset_target]
        data_args.test_dataset_name = [args.dataset_target]
        data_args.eval_dataset_name = [args.dataset_target]

    if args.update_path_original:
        model_args.prompt_embedding_path = update_path_original(args)

    if args.data_full:
        data_args.max_train_samples = None
        data_args.max_val_samples = None

    if args.fewshot_size:
        data_args.max_train_samples = args.fewshot_size
        data_args.max_val_samples = args.fewshot_size

    if args.deepspeed:
        training_args.deepspeed = args.deepspeed
        from transformers.integrations import DeepSpeedConfigHF

        # will be used later by the Trainer (leave self.deepspeed unmodified in case a user relies on it not to be modified)
        training_args.deepspeed_config_hf = DeepSpeedConfigHF(training_args)
        # training_args.deepspeed_config_hf = args.deepspeed
    if args.per_device_eval_batch_size:
        training_args.per_device_eval_batch_size = int(args.per_device_eval_batch_size)
    if args.per_device_train_batch_size:
        training_args.per_device_train_batch_size = int(args.per_device_train_batch_size)
    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        print("#### last_checkpoint ", last_checkpoint)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            '''
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
            '''
            pass
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    if args.checkpoint_dir:
        last_checkpoint = get_last_checkpoint(args.checkpoint_dir)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(
        training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed + seed_id)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
    # second column for the summaries (unless you specify column names for this with the `text_column` and
    # `summary_column` arguments).
    # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
    # source and target languages (unless you adapt what follows).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = T5Config.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.train_task_adapters = adapter_args.train_task_adapters
    config.prefix_tuning = adapter_args.prefix_tuning
    config.attn_prefix_tuning = model_args.attn_prefix_tuning
    config.attn_method = model_args.attn_method
    config.ignore_target = model_args.ignore_target
    config.shared_attn = model_args.shared_attn
    config.prefix_num = model_args.prefix_num
    config.num_target = len(data_args.task_name)
    config.temperature = model_args.temperature
    config.fix_attention = model_args.fix_attention
    adapter_config = get_adapter_config(
        adapter_args, data_args, training_args, config)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = T5ForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        adapter_config=adapter_config
    )

    if model_args.load_prefix_embeddings is True:
        if model_args.prompt_embedding_path is None:
            for name, param in model.named_parameters():
                if "prefix_shared" in name or "prefix" in name:
                    shared_params = [param]
        else:
            shared_params = []
            for path in model_args.prompt_embedding_path:
                shared_param = torch.load(path)
                shared_params.append(shared_param)
            if model_args.target_prompt_embedding_path is not None:
                target_prompt_embedding = torch.load(
                    model_args.target_prompt_embedding_path)

        if model_args.attn_prefix_tuning is True:
            if training_args.do_train is True and model_args.shared_attn is False:
                # Load all of the source prompts
                # author - prefix
                model.store_prefix_weights(shared_params)
                # Initialize the target task prompt embedding using the first prompts
                model.update_prefix_weights_single(shared_params[0])
            elif training_args.do_train is True and model_args.shared_attn is True:
                # Load all of the source prompts
                model.store_prefix_weights(shared_params)
                # Initialize the target task prompt embeddings using the first prompts.
                if args.random_prefix:
                    model.update_prefix_weights([])
                else:
                    model.update_prefix_weights_multi(shared_params[0], num_target=config.num_target)
            else:
                # For inference
                # Load all of the source prompts
                model.store_prefix_weights(shared_params)
                # Load the trained target task prompt.
                model.update_prefix_weights_single(target_prompt_embedding)

        else:
            if model_args.target_prompt_embedding_path is None:
                if args.random_prefix:
                    model.update_prefix_weights([])
                else:
                    model.update_prefix_weights(shared_params)
            else:
                if model_args.target_prompt_embedding_path is not None:

                    target_prompt_embedding = torch.load(
                        model_args.target_prompt_embedding_path, map_location=training_args.device)
                if args.fewshot_source:
                    shared_params = []
                model.update_prefix_weights(
                    shared_params, target_prompt_embedding)
        # print(model)

    # Load linear attention
    if model_args.load_attention is True and model_args.attn_path is not None:
        model.update_attention_weights(torch.load(model_args.attn_path))

    # Load projection-based attentions
    if model_args.load_attention is True and model_args.attn_path_sub is not None:
        model.update_attention_weights_sub(model_args.attn_path_sub)

    # Load layer norm weights & biases
    if model_args.load_layer_norm is True and model_args.layer_norm_dir is not None:
        model.update_layer_norm_weights(model_args.layer_norm_dir)

    model.resize_token_embeddings(len(tokenizer))
    model = modify_model_after_init(
        model, training_args, adapter_args, adapter_config)

    data_args.data_seed += seed_id
    data_args.dataset_name = data_args.task_name
    data_args.eval_dataset_name = data_args.eval_dataset_name
    data_args.test_dataset_name = data_args.test_dataset_name
    data_args.dataset_config_name = data_args.dataset_config_name
    data_args.eval_dataset_config_name = data_args.eval_dataset_config_name
    data_args.test_dataset_config_name = data_args.test_dataset_config_name
    assert len(data_args.dataset_name) == len(data_args.dataset_config_name)
    if data_args.eval_dataset_name is not None:
        assert len(data_args.eval_dataset_name) == len(
            data_args.eval_dataset_config_name)
    if data_args.test_dataset_name is not None:
        assert len(data_args.test_dataset_name) == len(
            data_args.test_dataset_config_name)

    padding = "max_length" if data_args.pad_to_max_length else False

    def preprocess_function(examples, max_target_length, task_id=None):
        model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
                                 padding=padding, truncation=True)
        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                examples['target'], max_length=max_target_length, padding=padding, truncation=True)
        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        model_inputs["extra_fields"] = examples['extra_fields']
        if task_id is not None:
            model_inputs["task_ids"] = [
                task_id for _ in examples['extra_fields']]

        return model_inputs

    column_names = ['source', 'target', 'extra_fields']
    performance_metrics = {}
    if 1:  # training_args.do_train
        print('args.self_train_attn', args.self_train_attn)
        if not args.self_train_attn:  # few-shot samples
            if data_args.train_files is not None:
                train_datasets = [AutoTask.get(dataset_name,
                                               dataset_config_name,
                                               seed=data_args.data_seed).get(
                    split="train",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_train_samples, lang=data_args.lang_name, file_name=train_file)
                    for dataset_name, dataset_config_name, train_file
                    in zip(data_args.dataset_name, data_args.dataset_config_name, data_args.train_files)]
            else:
                train_datasets = [AutoTask.get(dataset_name,
                                               dataset_config_name,
                                               seed=data_args.data_seed).get(
                    split="train",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_train_samples, lang=data_args.lang_name, file_name=data_args.train_file)
                    for dataset_name, dataset_config_name
                    in zip(data_args.dataset_name, data_args.dataset_config_name)]
            max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                tokenizer=tokenizer, default_max_length=data_args.max_target_length, )
                for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)]

            for i, train_dataset in enumerate(train_datasets):
                if model_args.shared_attn is True:
                    train_datasets[i] = train_datasets[i].map(
                        functools.partial(
                            preprocess_function, max_target_length=max_target_lengths[i], task_id=i),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if train_dataset != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
                else:
                    train_datasets[i] = train_datasets[i].map(
                        functools.partial(preprocess_function,
                                          max_target_length=max_target_lengths[i]),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if train_dataset != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
            train_dataset = concatenate_datasets(train_datasets)
        else:
            with open(args.self_train_source_dir+'/pseudo_label_dataset_' + str(seed_id) + '.pickle', 'rb') as handle:
                train_dataset = pickle.load(handle)
            print('train_datasets>>>>', train_dataset['labels'])
    if training_args.do_eval:
        if not args.self_train_attn:
            if data_args.validation_files is not None:
                eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                            seed=data_args.data_seed).get(
                    split="validation",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=validation_file)
                    for eval_dataset, eval_dataset_config, validation_file in
                    zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name, data_args.validation_files)}
            else:
                eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                            seed=data_args.data_seed).get(
                    split="validation",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=data_args.validation_file)
                    for eval_dataset, eval_dataset_config in
                    zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}
            max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                tokenizer=tokenizer, default_max_length=data_args.max_target_length)
                for dataset_name, dataset_config_name in
                zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)]

            for k, name in enumerate(eval_datasets):
                if model_args.shared_attn is True:
                    eval_datasets[name] = eval_datasets[name].map(
                        functools.partial(
                            preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if name != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
                else:
                    eval_datasets[name] = eval_datasets[name].map(
                        functools.partial(preprocess_function,
                                          max_target_length=max_target_lengths[k]),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if name != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )

            if args.self_train:
                if data_args.train_files is not None:
                    train_datasets_save = [AutoTask.get(dataset_name,
                                                   dataset_config_name,
                                                   seed=data_args.data_seed).get(
                        split="train+validation+test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=train_file)
                        for dataset_name, dataset_config_name, train_file
                        in zip(data_args.dataset_name, data_args.dataset_config_name, data_args.train_files)]
                    test_datasets_save = [AutoTask.get(dataset_name,
                                                        dataset_config_name,
                                                        seed=data_args.data_seed).get(
                        split="test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=train_file)
                        for dataset_name, dataset_config_name, train_file
                        in zip(data_args.dataset_name, data_args.dataset_config_name, data_args.train_files)]
                    train_datasets_predict = [AutoTask.get(dataset_name,
                                                       dataset_config_name,
                                                       seed=data_args.data_seed).get(
                        split="train",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=train_file)
                        for dataset_name, dataset_config_name, train_file
                        in zip(data_args.dataset_name, data_args.dataset_config_name, data_args.train_files)]
                else:
                    train_datasets_save = [AutoTask.get(dataset_name,
                                                   dataset_config_name,
                                                   seed=data_args.data_seed).get(
                        split="train+validation+test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=data_args.train_file)
                        for dataset_name, dataset_config_name
                        in zip(data_args.dataset_name, data_args.dataset_config_name)]
                    test_datasets_save = [AutoTask.get(dataset_name,
                                                        dataset_config_name,
                                                        seed=data_args.data_seed).get(
                        split="test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=data_args.train_file)
                        for dataset_name, dataset_config_name
                        in zip(data_args.dataset_name, data_args.dataset_config_name)]
                    train_datasets_predict = [AutoTask.get(dataset_name,
                                                       dataset_config_name,
                                                       seed=data_args.data_seed).get(
                        split="train",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=data_args.train_file)
                        for dataset_name, dataset_config_name
                        in zip(data_args.dataset_name, data_args.dataset_config_name)]

                max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                    tokenizer=tokenizer, default_max_length=data_args.max_target_length, )
                    for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)]

                for i, train_dataset_save in enumerate(train_datasets_save):
                    if model_args.shared_attn is True:
                        train_datasets_save[i] = train_datasets_save[i].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths[i], task_id=i),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        test_datasets_save[i] = test_datasets_save[i].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths[i], task_id=i),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        train_datasets_predict[i] = train_datasets_predict[i].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths[i], task_id=i),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                    else:
                        train_datasets_save[i] = train_datasets_save[i].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths[i]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        test_datasets_save[i] = test_datasets_save[i].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths[i]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        train_datasets_predict[i] = train_datasets_predict[i].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths[i]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if train_dataset != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )

                train_dataset_save = concatenate_datasets(train_datasets_save)
                test_dataset_save = concatenate_datasets(test_datasets_save)
                train_datasets_predict = concatenate_datasets(train_datasets_predict)

                # for prediction
                if data_args.validation_files is not None:
                    eval_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                seed=data_args.data_seed).get(
                        split="train+validation+test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=validation_file)
                        for eval_dataset, eval_dataset_config, validation_file in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name, data_args.validation_files)}

                    test_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                           seed=data_args.data_seed).get(
                        split="test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=validation_file)
                        for eval_dataset, eval_dataset_config, validation_file in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name,
                            data_args.validation_files)}
                    train_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                           seed=data_args.data_seed).get(
                        split="train",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=validation_file)
                        for eval_dataset, eval_dataset_config, validation_file in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name,
                            data_args.validation_files)}
                else:
                    eval_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                seed=data_args.data_seed).get(
                        split="train+validation+test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=data_args.validation_file)
                        for eval_dataset, eval_dataset_config in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}

                    test_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                           seed=data_args.data_seed).get(
                        split="test",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=None, lang=data_args.lang_name, file_name=data_args.validation_file)
                        for eval_dataset, eval_dataset_config in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}

                    train_datasets_self_train = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                                           seed=data_args.data_seed).get(
                        split="train",
                        split_validation_test=training_args.split_validation_test,
                        add_prefix=False if adapter_args.train_task_adapters else True,
                        n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=data_args.validation_file)
                        for eval_dataset, eval_dataset_config in
                        zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}

                max_target_lengths_self_train  = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                    tokenizer=tokenizer, default_max_length=data_args.max_target_length)
                    for dataset_name, dataset_config_name in
                    zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)]

                for k, name in enumerate(eval_datasets_self_train):
                    if model_args.shared_attn is True:
                        eval_datasets_self_train[name] = eval_datasets_self_train[name].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths_self_train[k], task_id=k),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        test_datasets_self_train[name] = test_datasets_self_train[name].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths_self_train[k], task_id=k),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        train_datasets_self_train[name] = train_datasets_self_train[name].map(
                            functools.partial(
                                preprocess_function, max_target_length=max_target_lengths_self_train[k], task_id=k),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                    else:
                        eval_datasets_self_train[name] = eval_datasets_self_train[name].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths_self_train[k]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        test_datasets_self_train[name] = test_datasets_self_train[name].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths_self_train[k]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )
                        train_datasets_self_train[name] = train_datasets_self_train[name].map(
                            functools.partial(preprocess_function,
                                              max_target_length=max_target_lengths_self_train[k]),
                            batched=True,
                            num_proc=data_args.preprocessing_num_workers,
                            # if name != "superglue-record" else column_names+["answers"],
                            remove_columns=column_names,
                            load_from_cache_file=not data_args.overwrite_cache,
                        )

        else:  #self_train_attn
            # with open(args.self_train_source_dir + '/dataset_eval_' + str(seed_id) + '.pickle', 'rb') as handle:
            #     eval_datasets = pickle.load(handle)
            if data_args.validation_files is not None:
                eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                            seed=data_args.data_seed).get(
                    split="validation",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=validation_file)
                    for eval_dataset, eval_dataset_config, validation_file in
                    zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name, data_args.validation_files)}
            else:
                eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                            seed=data_args.data_seed).get(
                    split="validation",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=data_args.validation_file)
                    for eval_dataset, eval_dataset_config in
                    zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}
            max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                tokenizer=tokenizer, default_max_length=data_args.max_target_length)
                for dataset_name, dataset_config_name in
                zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)]

            for k, name in enumerate(eval_datasets):
                if model_args.shared_attn is True:
                    eval_datasets[name] = eval_datasets[name].map(
                        functools.partial(
                            preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if name != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
                else:
                    eval_datasets[name] = eval_datasets[name].map(
                        functools.partial(preprocess_function,
                                          max_target_length=max_target_lengths[k]),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if name != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )

    if training_args.do_test:  # full eval is test dataset
        if data_args.test_files is not None:
            test_datasets = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="test",
                split_validation_test=False,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=test_file)
                for test_dataset, test_dataset_config, test_file in
                zip(data_args.test_dataset_name, data_args.test_dataset_config_name, data_args.test_files)}
        else:
            test_datasets = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="test",
                split_validation_test=False,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=data_args.test_file)
                for test_dataset, test_dataset_config in
                zip(data_args.test_dataset_name, data_args.test_dataset_config_name)}

        max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length)
            for dataset_name, dataset_config_name in
            zip(data_args.test_dataset_name, data_args.test_dataset_config_name)]
        for k, name in enumerate(test_datasets):
            if model_args.shared_attn is True:
                test_datasets[name] = test_datasets[name].map(
                    functools.partial(
                        preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )
            else:
                test_datasets[name] = test_datasets[name].map(
                    functools.partial(preprocess_function,
                                      max_target_length=max_target_lengths[k]),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )

        # add real test
        if data_args.test_files is not None:
            test_datasets_real = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="real_test",
                split_validation_test=False,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=test_file)
                for test_dataset, test_dataset_config, test_file in
                zip(data_args.test_dataset_name, data_args.test_dataset_config_name, data_args.test_files)}
        else:
            test_datasets_real = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="real_test",
                split_validation_test=False,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=data_args.test_file)
                for test_dataset, test_dataset_config in
                zip(data_args.test_dataset_name, data_args.test_dataset_config_name)}

        max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length)
            for dataset_name, dataset_config_name in
            zip(data_args.test_dataset_name, data_args.test_dataset_config_name)]
        for k, name in enumerate(test_datasets):
            if model_args.shared_attn is True:
                test_datasets_real[name] = test_datasets_real[name].map(
                    functools.partial(
                        preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )
            else:
                test_datasets_real[name] = test_datasets_real[name].map(
                    functools.partial(preprocess_function,
                                      max_target_length=max_target_lengths[k]),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )

    # Data collator
    label_pad_token_id = - \
        100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = TaskDataCollatorForSeq2Seq(
            tokenizer,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Metric, we assume we have only one training task.
    print(data_args.dataset_name, data_args.dataset_config_name)
    eval_metrics = [AutoTask.get(dataset_name, dataset_config_name).metric
                    for dataset_name, dataset_config_name in
                    zip(data_args.dataset_name, data_args.dataset_config_name)][0]

    # Extract the extra information needed to evaluate on each dataset.
    # This information is only used in compute_metrics.
    # We assume that the test/eval dataloader does not change the order of
    # the data.
    data_info = {"eval": eval_datasets[data_args.eval_dataset_name[0]]['extra_fields'],
                 "test": test_datasets[data_args.test_dataset_name[0]][
                     'extra_fields'] if training_args.do_test else None,
                 "train": train_dataset['extra_fields'] if training_args.do_train else None}

    def compute_metrics(eval_preds):
        """Post-process raw predictions and score them with every metric in ``eval_metrics``.

        Args:
            eval_preds: a 3-item sequence unpacked as (predictions, labels,
                extra per-example info).

        Returns:
            A single dict built by merging, in order, the mapping returned by
            each callable in ``eval_metrics``; later metrics overwrite earlier
            ones on key collisions.
        """
        predictions, references, extra_info = eval_preds

        # The post-processor is looked up for the first training dataset only.
        first_task = data_args.dataset_name[0]
        processor = AutoPostProcessor.get(
            first_task, tokenizer, data_args.ignore_pad_token_for_loss)
        decoded_preds, decoded_labels = processor.process(
            predictions, references, extra_info)

        scores = {}
        for metric_fn in eval_metrics:
            metric_values = metric_fn(decoded_preds, decoded_labels)
            scores.update(metric_values)
        return scores


    if model_args.attn_learning_rate is not None:
        # Initialize a customized optimizer to set a different learning rate for the attention module.
        all_parameters = set(model.parameters())
        attn_params = []
        for name, param in model.named_parameters():
            if name == "encoder.attn_W_up" or name == "encoder.attn_W_down" or name == "encoder.layer_norm":
                attn_params += list(param)
        attn_params = set(attn_params)
        non_attn_params = all_parameters - attn_params
        non_attn_params = list(non_attn_params)
        attn_params = list(attn_params)

        optim = AdamW([
            {'params': non_attn_params},
            {'params': attn_params, 'lr': model_args.attn_learning_rate},
        ], lr=training_args.learning_rate, )
        scheduler = get_linear_schedule_with_warmup(
            optim, num_warmup_steps=training_args.warmup_steps, num_training_steps=len(
                train_dataset) * training_args.num_train_epochs // (
                                                                                               training_args.gradient_accumulation_steps * training_args.per_device_train_batch_size)
        )

        # Initialize our Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=list(eval_datasets.values())[
                0] if training_args.do_eval else None,
            data_info=data_info,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
            evaluation_metrics=TASK_TO_METRICS[data_args.dataset_name[0]],
            shared=model_args.shared_attn,
            optimizers=(optim, scheduler),
        )

    else:
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=list(eval_datasets.values())[
                0] if training_args.do_eval else None,
            data_info=data_info,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
            evaluation_metrics=TASK_TO_METRICS[data_args.dataset_name[0]],
            shared=model_args.shared_attn,
        )

    # Saves training config.
    if trainer.is_world_process_zero():
        os.makedirs(training_args.output_dir, exist_ok=True)
        save_training_config(args.json_file, training_args.output_dir)

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint

        if training_args.compute_time:
            torch.cuda.synchronize()  # wait for move to complete
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        print('training_args->', training_args)
        train_result = trainer.train()
        # train_result = trainer.train(resume_from_checkpoint=checkpoint)

        if training_args.compute_time:
            end.record()
            torch.cuda.synchronize()  # wait for all_reduce to complete
            total_time = start.elapsed_time(end) / (1000 * 60)
            performance_metrics.update({"total_time in minutes ": total_time})

        model_args.save_prefix_only = True
        if model_args.save_prefix_only:
            for name, param in trainer.model.named_parameters():
                if args.self_train_attn:
                    if "encoder.attn_W.weight" == name:
                        attn_weights_params = param
                        torch.save(attn_weights_params, os.path.join(
                            training_args.output_dir, "attn_W_st_"+str(seed_id)+".pt"))
                if model_args.attn_prefix_tuning is False and ("prefix_shared" in name or "prefix" in name):
                    shared_params = param
                    print('shared_params', shared_params)
                    torch.save(shared_params, os.path.join(
                        training_args.output_dir, "prefix_embeddings_"+str(seed_id)+".pt"))
                elif model_args.attn_prefix_tuning is True and name == "prefix_shared":
                    shared_params = param
                    if model_args.shared_attn is True:
                        for i in range(config.num_target):
                            print('shared_params', shared_params[i])
                            torch.save(shared_params[i], os.path.join(
                                training_args.output_dir, "prefix_embeddings_"+str(seed_id)+"{}.pt".format(i)))
                    else:
                        print('shared_params', shared_params)
                        torch.save(shared_params, os.path.join(
                            training_args.output_dir, "prefix_embeddings"+str(seed_id)+".pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_Wa.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_Wa_weights.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_down.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_W_down.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_up.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_W_up.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "layer_norm_weight.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.bias" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "layer_norm_bias.pt"))
        # else:
        # if not os.path.exists(args.save_dir+'/seed_' + str(seed_id)):
        #     os.makedirs(args.save_dir+'/seed_' + str(seed_id))
        # trainer.save_model(output_dir=args.save_dir+'/seed_' + str(seed_id))  # Saves the tokenizer too for easy upload

        train_metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(
                train_dataset)
        )
        train_metrics["train_samples"] = min(
            max_train_samples, len(train_dataset))
        trainer.log_metrics("train", train_metrics)
        trainer.save_metrics("train", train_metrics)

        if not model_args.save_prefix_only:
            if not os.path.exists(args.save_dir+'/seed_' + str(seed_id)):
                os.makedirs(args.save_dir+'/seed_' + str(seed_id))
            trainer.save_state()

    if torch.cuda.is_available() and training_args.compute_memory:
        peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2) / 1000
        print(
            "Memory utilization",
            peak_memory,
            "GB"
        )
        performance_metrics.update({"peak_memory": peak_memory})
    if training_args.compute_memory or training_args.compute_time:
        trainer.save_metrics("performance", performance_metrics)

    # Evaluation
    if model_args.shared_attn is True and model_args.ignore_target is False:
        learned_embeddings = trainer.model.encoder.prefix_emb.clone().detach()
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        if model_args.shared_attn is True:
            for idx, (task, eval_dataset) in enumerate(eval_datasets.items()):
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)

        else:
            for task, eval_dataset in eval_datasets.items():
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        results['eval'] = metrics

    # prediction!!!
    if args.self_train:
        logger.info("*** Predict ***")
        if model_args.shared_attn is True:
            for idx, (task, eval_dataset) in enumerate(eval_datasets_self_train.items()):
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        else:
            '''
            predict the pseudo labels for the entire dataset (validation + train)
            '''
            # for task, eval_dataset in eval_datasets.items():
            prediction = list(list(trainer.predict(list(eval_datasets_self_train.values())[0]))[0])
            prediction_list = []
            # print(prediction)
            for j in range(len(prediction)): # row
                prediction_list.append([])
                for k in range(len(prediction[j])): # column
                    if prediction[j][k] > 0:
                        prediction_list[-1].append(prediction[j][k])

            # save to pickle files
            # for key in eval_datasets_self_train:
            #     dataset_save = {}
            #     for k in eval_datasets_self_train[key][0]:
            #         dataset_save[k] = eval_datasets_self_train[key][k]
            #     dataset_save["labels"] = prediction_list
            #     eval_datasets_self_train[key] = Dataset.from_dict(dataset_save)

            # get perplexity of each sample
            def score(tensor_input):
                """Return ``exp(loss)`` (perplexity) for one sequence of token ids.

                ``tensor_input`` is a list of ids. If it contains a ``0``, the
                sequence is cut just before the first one (0 is presumably the
                padding id -- TODO confirm against the tokenizer). The model is
                called with the ids as both inputs and labels, and its loss is
                exponentiated.
                """
                device = torch.device(str(training_args.device))
                if 0 in tensor_input:
                    tensor_input = tensor_input[:tensor_input.index(0)]
                # Wrap in a list to add the batch dimension (batch size 1).
                input_ids = torch.tensor([tensor_input]).to(device)
                # Only the forward loss is needed; disabling autograd avoids
                # building a graph for every scored sample.
                with torch.no_grad():
                    loss = model(input_ids, labels=input_ids).loss.cpu()
                return math.exp(loss)

            if args.perplexity_calc:
                perplexity_scores = [score(ids) for ids in list(eval_datasets_self_train.values())[0]['input_ids']]
                with open(training_args.output_dir + '/perplexity_' + str(seed_id) + '.pickle',
                          'wb') as handle:
                    pickle.dump(perplexity_scores, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # save datasets
            with open(training_args.output_dir+'/pseudo_label_dataset_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(eval_datasets_self_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

            with open(training_args.output_dir+'/train_datasets_self_train_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(train_datasets_self_train, handle, protocol=pickle.HIGHEST_PROTOCOL)

            with open(training_args.output_dir+'/train_dataset_fs_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(train_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)

            with open(training_args.output_dir + '/eval_dataset_fs_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(eval_datasets, handle, protocol=pickle.HIGHEST_PROTOCOL)

            with open(training_args.output_dir + '/labels_pred_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(prediction_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # train dataset -- whole
            with open(training_args.output_dir + '/train_dataset_save_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(train_dataset_save, handle, protocol=pickle.HIGHEST_PROTOCOL)  # for pseudo label

            # test dataset -- whole
            with open(training_args.output_dir + '/test_dataset_save_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(test_dataset_save, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # eval train dataset -- whole
            with open(training_args.output_dir + '/train_datasets_predict_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(train_datasets_predict, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # save eval preds
            prediction_eval = list(list(trainer.predict(list(eval_datasets.values())[0]))[0])
            prediction_eval_list = []
            # print(prediction)
            for j in range(len(prediction_eval)):  # row
                prediction_eval_list.append([])
                for k in range(len(prediction_eval[j])):  # column
                    if prediction_eval[j][k] > 0:
                        prediction_eval_list[-1].append(prediction_eval[j][k])
            with open(training_args.output_dir + '/labels_pred_eval_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(prediction_eval_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # predict for test set
            prediction = list(list(trainer.predict(list(test_datasets_self_train.values())[0]))[0])
            prediction_list = []
            # print(prediction)
            for j in range(len(prediction)):  # row
                prediction_list.append([])
                for k in range(len(prediction[j])):  # column
                    if prediction[j][k] > 0:
                        prediction_list[-1].append(prediction[j][k])
            with open(training_args.output_dir + '/labels_pred_test_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(prediction_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # save train preds
            prediction_train = list(list(trainer.predict(list(train_datasets_self_train.values())[0]))[0])
            prediction_train_list = []
            print(prediction_train)
            for j in range(len(prediction_train)):  # row
                prediction_train_list.append([])
                for k in range(len(prediction_train[j])):  # column
                    print(prediction_train_list[-1], prediction_train[j][k])
                    if prediction_train[j][k] > 0:
                        prediction_train_list[-1].append(prediction_train[j][k])
            with open(training_args.output_dir + '/labels_pred_train_' + str(seed_id) + '.pickle', 'wb') as handle:
                pickle.dump(prediction_train_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # In the save_prefix_only setting, remove checkpoint directories so that full model weights are not kept on disk.
    model_args.save_prefix_only = True
    if model_args.save_prefix_only:
        checkpoints = glob.glob(os.path.join(
            training_args.output_dir, "checkpoint-" + str(seed_id)))
        for checkpoint_dir in checkpoints:
            # save models
            if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
                continue
            checkpoint_model = torch.load(os.path.join(
                os.path.join(checkpoint_dir, "pytorch_model.bin")))
            new_dir = "{}_prompt_only".format(checkpoint_dir)
            os.mkdir(new_dir)
            for name, param in checkpoint_model.items():
                if model_args.attn_prefix_tuning is False and ("prefix_shared" in name or "prefix" in name):
                    shared_params = param
                    torch.save(shared_params, os.path.join(
                        training_args.output_dir, "prefix_embeddings.pt"))
                elif model_args.attn_prefix_tuning is True and name == "prefix_shared":
                    shared_params = param
                    if model_args.shared_attn is True:
                        for i in range(config.num_target):
                            torch.save(shared_params[i], os.path.join(
                                new_dir, "prefix_embeddings_{}"+str(seed_id)+".pt".format(i)))
                    else:
                        torch.save(shared_params, os.path.join(
                            new_dir, "prefix_embeddings"+str(seed_id)+".pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_Wa.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_Wa_weights.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_down.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_W_down.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_up.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_W_up.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "layer_norm_weight.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.bias" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "layer_norm_bias.pt"))
            # after saving prompts, we will remove unnecessary checkpoint dir.
            try:
                shutil.rmtree(checkpoint_dir)
            except OSError as e:
                print("Error: %s : %s" % (checkpoint_dir, e.strerror))

    # Test
    if training_args.do_test:
        logger.info("*** Test ***")
        if model_args.shared_attn is True:
            for idx, (task, test_dataset) in enumerate(test_datasets.items()):
                trainer.model.encoder.prefix_emb[0].data = learned_embeddings[idx]
                metrics = trainer.evaluate(eval_dataset=test_dataset,
                                           max_length=data_args.test_max_target_length, num_beams=data_args.num_beams,
                                           metric_key_prefix="test"
                                           )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)

        else:
            for task, test_dataset in test_datasets.items():
                metrics = trainer.evaluate(eval_dataset=test_dataset,
                                           max_length=data_args.test_max_target_length, num_beams=data_args.num_beams,
                                           metric_key_prefix="test"
                                           )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)
        results['test'] = metrics

    return results


def _mp_fn(index):
    """Call ``main()``; the ``index`` argument is accepted but not used."""
    # NOTE(review): the ``__main__`` block of this file invokes ``main(i, args)``,
    # whereas this wrapper passes no arguments -- confirm this still matches
    # ``main``'s signature (presumably a leftover multi-process launcher hook).
    main()


if __name__ == "__main__":
    # Script entry point: parse the command-line options, then call
    # ``main(i, args)`` once for every ``i`` in
    # [--fewshot_num_begin, --fewshot_num) and append each returned result
    # as one line of a results file under --save_dir.
    parser = argparse.ArgumentParser("")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--model_name", type=str, default=None)
    parser.add_argument("--dataset_target", type=str)
    parser.add_argument("--save_dir", type=str)
    parser.add_argument("--json_file", type=str)
    parser.add_argument("--fewshot", action='store_true')
    parser.add_argument("--self_train", action='store_true', help='whether for self-train, b/c you need to do prediction and save pickles')
    parser.add_argument("--naked", action='store_true')
    parser.add_argument("--random_prefix", action='store_true')
    parser.add_argument("--fewshot_num", type=int, default=20)
    parser.add_argument("--fewshot_num_begin", type=int, default=0)
    parser.add_argument("--attn_method", type=str)
    parser.add_argument("--dataset_source", type=str, default=None)
    parser.add_argument("--target_prompt_embedding_path", type=str, default=None)
    parser.add_argument("--attn_path_sub", type=str, default=None)
    parser.add_argument("--layer_norm_dir", type=str, default=None)
    parser.add_argument("--checkpoint_dir", type=str, default=None)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--load_layer_norm", action='store_true')
    parser.add_argument("--load_attention", action='store_true')
    parser.add_argument("--ignore_target", action='store_true')
    parser.add_argument("--self_train_attn", action='store_true')
    parser.add_argument("--fewshot_source", action='store_true')
    parser.add_argument("--zeroshot", action='store_true')
    parser.add_argument("--self_train_source_dir", type=str)
    parser.add_argument("--eval_dataset_type", type=str, default='few-shot')
    parser.add_argument("--perplexity_calc", action='store_true', help='whether to calc perplexity when prediction')
    parser.add_argument("--update_path_original", action="store_true", help='True when you want to update it for original one')
    # NOTE(review): the help text below duplicates --update_path_original.
    parser.add_argument("--fs_sd_update_prefix", action="store_true", help='True when you want to update it for original one')

    parser.add_argument("--prompt_embedding_path_prefix", type=str, default=None)
    parser.add_argument("--data_full", action="store_true",
                        help='use all the data not fewshot')
    parser.add_argument("--init_prompt", action="store_true",
                        help='init target')
    # NOTE(review): the 'init target' help text on the next four options looks
    # copy-pasted from --init_prompt; left unchanged pending confirmation of
    # what each option actually controls.
    parser.add_argument("--deepspeed", default=None,
                        help='init target')
    parser.add_argument("--per_device_train_batch_size", default=None,
                        help='init target')
    parser.add_argument("--per_device_eval_batch_size", default=None,
                        help='init target')
    parser.add_argument("--fewshot_size", type=int, default=None,
                        help='init target')

    args = parser.parse_args()
    print('>>>')

    # --save_dir has no default; fail with a usage error instead of a
    # TypeError when building the results path, and make sure the directory
    # exists before opening the results file.
    if not args.save_dir:
        parser.error("--save_dir is required: the results file is written there")
    os.makedirs(args.save_dir, exist_ok=True)

    # Few-shot runs write to results.txt; all other runs write to a file
    # tagged with the first seed index.
    if args.fewshot:
        results_path = os.path.join(args.save_dir, 'results.txt')
    else:
        results_path = os.path.join(
            args.save_dir, 'results_st_' + str(args.fewshot_num_begin) + '.txt')

    with open(results_path, 'w') as f:
        for i in range(args.fewshot_num_begin, args.fewshot_num):
            results = main(i, args)
            f.write(str(results) + '\n')
            # Flush after every run so completed results survive a later crash.
            f.flush()
            if args.fewshot:
                print('results --->', results)

