# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
from transformers import Seq2SeqTrainingArguments
from utils import modify_model_after_init, save_training_config
import shutil
import glob
from data import AutoPostProcessor
from third_party.models import T5Config, T5ForConditionalGeneration_vote
from dataclasses import dataclass, field
from training_args import AdapterTrainingArguments
from third_party.trainers import Seq2SeqTrainer_vote, MyModel
from data import TaskDataCollatorForSeq2Seq
from data import AutoTask
from utils import get_adapter_config
from transformers.trainer_utils import is_main_process, get_last_checkpoint
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    set_seed,
    AdamW,
    get_linear_schedule_with_warmup,
)
import transformers
from datasets import load_dataset, load_metric, concatenate_datasets
from typing import Optional, List
import subprocess
import sys
import functools
import logging
from pytz import common_timezones
import torch
import os
import pickle
from data.tasks import TASK_MAPPING
from datasets import Dataset
import argparse


# Configure Intel MKL through environment variables as soon as this module is
# imported.
# NOTE(review): presumably this works around MKL threading-layer conflicts in
# subprocesses started by this script -- confirm before removing.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

# Module-level logger; its level is set in `setup_logger`.
logger = logging.getLogger(__name__)


def run_command(command):
    """Run ``command`` via ``subprocess.getoutput`` and return the captured text."""
    return subprocess.getoutput(command)


# Metric names associated with each task name. Keys are task identifiers
# (including the `_ppt` variants); values are lists of metric names.
# NOTE(review): how these lists are consumed is not visible in this part of
# the file -- confirm against the evaluation code before renaming entries.
TASK_TO_METRICS = {
    "mrpc": ["accuracy", "f1"],
    "mrpc_ppt": ["accuracy", "f1"],
    "cola": ["matthews_correlation"],
    "stsb": ["pearson", "spearmanr"],
    "sst2": ["accuracy"],
    "mnli": ["accuracy"],
    "mnli_ppt": ["accuracy"],
    "mnli_mismatched": ["accuracy"],
    "mnli_matched": ["accuracy"],
    "qnli": ["accuracy"],
    "rte": ["accuracy"],
    "wnli": ["accuracy"],
    "wnli_ppt": ["accuracy"],
    "qqp": ["accuracy", "f1"],
    "superglue-boolq": ["accuracy"],
    "superglue-boolq_ppt": ["accuracy"],
    "superglue-rte": ["accuracy"],
    "superglue-rte_ppt": ["accuracy"],
    "superglue-cb": ["f1_multiclass", "accuracy"],
    "superglue-cb_ppt": ["f1_multiclass", "accuracy"],
    "superglue-copa": ["accuracy"],
    "superglue-multirc": ["f1", "em"],
    "superglue-multirc_ppt": ["f1", "em"],
    "superglue-wic": ["accuracy"],
    "superglue-wic_ppt": ["accuracy"],
    "superglue-wsc.fixed": ["accuracy"],
    "superglue-wsc.fixed_ppt": ["accuracy"],
    "superglue-record": ["f1", "em"],
    "multi_nli": ["accuracy"],
    "squad": ["em", "f1"],
    "snli": ["accuracy"],
    "nq": ["em", "f1"],
    "hotpotqa": ["em", "f1"],
    "searchqa": ["em", "f1"],
    "newsqa": ["em", "f1"],
    "triviaqa": ["em", "f1"],
    "imdb": ["accuracy"],
    "winogrande": ["accuracy"],
    "scitail": ["accuracy"],
    "amazon_polarity": ["accuracy"],
    "yelp_polarity": ["accuracy"],
    "paws": ["accuracy"],
}


# run_seq2seq parameters.


@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    """
    Script-specific training options layered on top of ``Seq2SeqTrainingArguments``.
    """

    # Project-specific options; add new command-line flags here.
    print_num_parameters: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, print the parameters of the model."},
    )
    do_test: Optional[bool] = field(
        default=False, metadata={"help": "If set, evaluates the test performance."}
    )
    # NOTE(review): this help text is identical to `do_test`; the intended
    # difference between the two flags is not visible here -- confirm and reword.
    do_eval_predict: Optional[bool] = field(
        default=False, metadata={"help": "If set, evaluates the test performance."}
    )
    split_validation_test: Optional[bool] = field(
        default=False,
        metadata={
            # Each fragment ends with a space so that the implicitly
            # concatenated help string reads as normal prose.
            "help": "If set, for the datasets which do not "
            "have the test set, we use validation set as their "
            "test set and make a validation set from either "
            "splitting the validation set into half (for smaller "
            "than 10K samples datasets), or by using 1K examples "
            "from training set as validation set (for larger "
            "datasets)."
        },
    )
    compute_time: Optional[bool] = field(
        default=False, metadata={"help": "If set measures the time."}
    )
    compute_memory: Optional[bool] = field(
        default=False, metadata={"help": "if set, measures the memory"}
    )
    prefix_length: Optional[int] = field(
        default=100, metadata={"help": "Defines the length for prefix tuning."}
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # --- Pretrained model / tokenizer location -------------------------------
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # --- Prefix / prompt embeddings -------------------------------------------
    load_prefix_embeddings: bool = field(
        default=False, metadata={"help": "load prefix embeddings or not"},
    )
    save_prefix_only: bool = field(
        default=False, metadata={"help": "save prefix embeddings only"},
    )

    prompt_embedding_path: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of the paths to prefix embeddings"}
    )

    target_prompt_embedding_path: Optional[str] = field(
        default=None, metadata={"help": "a path to the target prompt embedding"}
    )

    # --- Attention over prompts -------------------------------------------------
    attn_prefix_tuning: bool = field(
        default=False, metadata={"help": "Set true if you try ATTEMPT."},
    )

    attn_method: Optional[str] = field(
        default="sub",
        metadata={
            "help": "Attention model for attn_prefix. We currently support the following methods: linear, sub (our main method), and constant (gives the constant and equal weights to all of the prompts.)"
        },
    )

    shared_attn: bool = field(
        default=False, metadata={"help": "shared attention"},
    )

    load_attention: bool = field(
        default=False,
        metadata={"help": "Set true if you want to load pre-trained attention weights"},
    )

    attn_path: Optional[str] = field(
        default=None,
        metadata={"help": "path to attention weights (linear attentions). "},
    )

    attn_path_sub: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "list of the path to attention weights (sub attentions). [path_to_down_projection_weights, path_to_up_projection_weights]"
        },
    )

    # --- Ablation switches ------------------------------------------------------
    ignore_target: bool = field(
        default=False,
        metadata={
            "help": "Whether to ignore the new target tokens. Mainly for ablation."
        },
    )

    fix_attention: bool = field(
        default=False,
        metadata={
            "help": "this will make the attention weights frozen during training. Mainly for ablation."
        },
    )

    temperature: float = field(
        default=2000, metadata={"help": "set the soft max temperature of ATTEMPT."},
    )

    # Declared Optional because the default is None, like the other nullable
    # fields in this class.
    attn_learning_rate: Optional[float] = field(
        default=None,
        metadata={"help": "set the learning rate for the attention modules."},
    )

    # --- Layer norm -------------------------------------------------------------
    load_layer_norm: bool = field(
        default=False,
        metadata={
            "help": "Set true if you want to load pre-trained layer-norm weight and biases."
        },
    )

    layer_norm_dir: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Layer norm dir. [path_to_layer_norm_weight.pt, path_to_layer_norm_bias.pt]"
        },
    )

    prefix_num: Optional[int] = field(
        default=1, metadata={"help": "the number of prefix"}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``task_name`` is mandatory: ``__post_init__`` raises ``ValueError`` when it
    is missing, and fills ``val_max_target_length`` / ``test_max_target_length``
    from ``max_target_length`` when they are left unset.
    """

    # --- Dataset selection (one entry per task) -------------------------------
    task_name: Optional[List[str]] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    eval_dataset_name: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library)."
        },
    )
    eval_dataset_config_name: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library)."
        },
    )
    test_dataset_name: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The name of the test dataset to use (via the datasets library)."
        },
    )
    test_dataset_config_name: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The configuration name of the test dataset to use (via the datasets library)."
        },
    )
    # Forwarded as the `lang` argument when the task data is loaded.
    lang_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The language name passed as `lang` when loading each task's data."
        },
    )

    # --- Preprocessing ----------------------------------------------------------
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    test_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )

    # --- Sub-sampling -----------------------------------------------------------
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None, metadata={"help": "Number of beams to use for evaluation."}
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    task_adapters: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary from task adapters to the tasks."},
    )
    task_embeddings: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary from tasks to the tasks embeddings."},
    )
    data_seed: Optional[int] = field(
        default=42, metadata={"help": "seed used to shuffle the data."}
    )

    # --- Local data files (single file, or one file per task) -----------------
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "A csv or a json file containing the training data."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "A csv or a json file containing the validation data."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "A csv or a json file containing the test data."},
    )

    train_files: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of csv or json files containing the training data."},
    )
    validation_files: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "A list of csv or json files containing the validation data."
        },
    )
    test_files: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of csv or json files containing the test data."},
    )

    def __post_init__(self):
        """Validate ``task_name`` and default the eval/test target lengths.

        Raises:
            ValueError: if ``task_name`` was not provided.
        """
        # Only `task_name` is checked here, regardless of any file arguments.
        if self.task_name is None:
            raise ValueError(
                "Need either a dataset name or a training/validation file."
            )
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
        if self.test_max_target_length is None:
            self.test_max_target_length = self.max_target_length


def save_to_dict(path_save, save_stuff):
    """Pickle ``save_stuff`` into the file ``path_save`` with the highest protocol."""
    with open(path_save, mode="wb") as out_file:
        writer = pickle.Pickler(out_file, protocol=pickle.HIGHEST_PROTOCOL)
        writer.dump(save_stuff)


def load_args(args, seed_id, json_file=None):
    """Parse the four argument dataclasses and apply command-line overrides.

    The base values come from ``json_file`` if given, otherwise from
    ``args.json_file`` if set, otherwise from the process command line. Truthy
    attributes of ``args`` then override individual fields.

    Args:
        args: namespace with the override attributes read below
            (``json_file``, ``attn_method``, ``save_dir``, ``naked``,
            ``dataset_source``, ``dataset_target``,
            ``target_prompt_embedding_path``, ``load_attention``,
            ``attn_path_sub``, ``layer_norm_dir``, ``resume_from_checkpoint``,
            ``ignore_target``).
        seed_id: integer added to ``data_args.data_seed``.
        json_file: optional path to a JSON config; takes precedence over
            ``args.json_file``.

    Returns:
        Tuple ``(model_args, data_args, training_args, adapter_args)``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (
            ModelArguments,
            DataTrainingArguments,
            TrainingArguments,
            AdapterTrainingArguments,
        )
    )
    config_path = json_file or args.json_file
    if config_path:
        model_args, data_args, training_args, adapter_args = parser.parse_json_file(
            json_file=os.path.abspath(config_path)
        )
    else:
        (
            model_args,
            data_args,
            training_args,
            adapter_args,
        ) = parser.parse_args_into_dataclasses()

    # Command-line overrides on top of the parsed configuration.
    if args.attn_method:
        model_args.attn_method = args.attn_method
    if args.save_dir:
        training_args.output_dir = args.save_dir
    if args.naked:
        # `prompt_embedding_path` defaults to None; start from an empty list so
        # the append below cannot fail with AttributeError.
        if model_args.prompt_embedding_path is None:
            model_args.prompt_embedding_path = []
        model_args.prompt_embedding_path.append(
            "./source_prompts/" + args.dataset_source + "_prompt.pt"
        )
    if args.dataset_target:
        # A single target dataset replaces the train/eval/test task lists.
        data_args.task_name = [args.dataset_target]
        data_args.test_dataset_name = [args.dataset_target]
        data_args.eval_dataset_name = [args.dataset_target]
    if args.target_prompt_embedding_path:
        model_args.target_prompt_embedding_path = args.target_prompt_embedding_path
    if args.load_attention:
        model_args.load_attention = True
    if args.attn_path_sub:
        # The directory is expanded to [down-projection, up-projection] files.
        model_args.attn_path_sub = [
            args.attn_path_sub + "/attn_W_down.pt",
            args.attn_path_sub + "/attn_W_up.pt",
        ]
    if args.layer_norm_dir:
        model_args.layer_norm_dir = args.layer_norm_dir
    if args.resume_from_checkpoint:
        training_args.resume_from_checkpoint = args.resume_from_checkpoint
    if args.ignore_target:
        model_args.ignore_target = True
        training_args.do_test = False

    # Update data settings: offset the seed per run and expose the task list
    # under the `dataset_name` attribute used by the data-setup functions.
    data_args.data_seed += seed_id
    data_args.dataset_name = data_args.task_name

    # Every dataset list must be paired one-to-one with its config-name list.
    assert len(data_args.dataset_name) == len(data_args.dataset_config_name)
    if data_args.eval_dataset_name is not None:
        assert len(data_args.eval_dataset_name) == len(
            data_args.eval_dataset_config_name
        )
    if data_args.test_dataset_name is not None:
        assert len(data_args.test_dataset_name) == len(
            data_args.test_dataset_config_name
        )

    return model_args, data_args, training_args, adapter_args


def setup_logger(training_args):
    """Configure logging for this process and return the module logger.

    Root logging is sent to stdout. The module logger is set to INFO on the
    main process and WARNING on the others, and the Transformers logger is
    switched to info verbosity on the main process only.

    Args:
        training_args: object providing ``local_rank``, ``device``, ``n_gpu``
            and ``fp16``.

    Returns:
        The module-level ``logger``.
    """
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    on_main_process = is_main_process(training_args.local_rank)
    logger.setLevel(logging.INFO if on_main_process else logging.WARNING)

    # Log on each process the small summary. The ", " before "distributed
    # training" keeps the two halves of the message from running together.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)
    return logger


def detect_checkpoints(args, training_args):
    """Return the checkpoint to resume from, or ``None``.

    The output directory is searched only when it exists, training is
    requested and overwriting is disabled. A truthy ``args.checkpoint_dir``
    takes precedence over whatever was found there.
    """
    resume_from = None
    output_dir = training_args.output_dir
    may_resume = (
        os.path.isdir(output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    )
    if may_resume:
        resume_from = get_last_checkpoint(output_dir)
        print("#### last_checkpoint ", resume_from)
        # An output directory without any checkpoint is tolerated silently.
        if resume_from is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {resume_from}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # An explicitly supplied checkpoint directory wins.
    if args.checkpoint_dir:
        resume_from = get_last_checkpoint(args.checkpoint_dir)
    return resume_from


def setup_config_T5(data_args, model_args, training_args, adapter_args):
    """Load the T5 config, copy the run settings onto it and build the adapter config.

    The config is loaded from ``model_args.config_name`` when set, otherwise
    from ``model_args.model_name_or_path``.

    Returns:
        Tuple ``(config, adapter_config)``.
    """
    pretrained_source = model_args.config_name or model_args.model_name_or_path
    auth_token = True if model_args.use_auth_token else None
    config = T5Config.from_pretrained(
        pretrained_source,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=auth_token,
    )

    # Settings copied verbatim onto the config object, in this order.
    overrides = (
        ("train_task_adapters", adapter_args.train_task_adapters),
        ("prefix_tuning", adapter_args.prefix_tuning),
        ("attn_prefix_tuning", model_args.attn_prefix_tuning),
        ("attn_method", model_args.attn_method),
        ("ignore_target", model_args.ignore_target),
        ("shared_attn", model_args.shared_attn),
        ("prefix_num", model_args.prefix_num),
        # One target per configured task.
        ("num_target", len(data_args.task_name)),
        ("temperature", model_args.temperature),
        ("fix_attention", model_args.fix_attention),
    )
    for attr_name, attr_value in overrides:
        setattr(config, attr_name, attr_value)

    return config, get_adapter_config(adapter_args, data_args, training_args, config)


def update_model_memory(
    model, model_args, training_args, tokenizer, adapter_args, adapter_config, config
):
    """Load saved prompt/attention/layer-norm weights into ``model`` and finalize it.

    Depending on ``model_args`` this stores the source prompts on the model,
    initializes or restores the target prompt, and optionally loads attention
    and layer-norm weights. It then resizes the token embeddings to
    ``len(tokenizer)`` and applies ``modify_model_after_init``.

    NOTE(review): the two ``args.random_prefix`` lookups read a module-level
    ``args`` that is not a parameter of this function -- it must exist as a
    global when those branches run; confirm against the script entry point.

    Returns:
        The model returned by ``modify_model_after_init``.

    Raises:
        ValueError: in the attention inference branch when no
            ``target_prompt_embedding_path`` was provided.
    """
    if model_args.load_prefix_embeddings is True:
        prompts_from_disk = model_args.prompt_embedding_path is not None
        if prompts_from_disk:
            shared_params = [
                torch.load(path, map_location=training_args.device)
                for path in model_args.prompt_embedding_path
            ]
        else:
            # No prompt files: fall back to the model's own parameters. The
            # last parameter whose name contains "prefix" is the one kept.
            for name, param in model.named_parameters():
                if "prefix" in name:
                    shared_params = [param]

        # Load the target prompt whenever a path is given, independently of
        # where the source prompts came from, so the branches below never hit
        # an unbound name.
        target_prompt_embedding = None
        if model_args.target_prompt_embedding_path is not None:
            target_prompt_embedding = torch.load(
                model_args.target_prompt_embedding_path
            )

        if prompts_from_disk:
            model.store_prefix_weights(shared_params)

        if model_args.attn_prefix_tuning is True:
            # Every attention branch starts by storing all source prompts.
            model.store_prefix_weights(shared_params)
            if training_args.do_train is True and model_args.shared_attn is False:
                # Initialize the target task prompt embedding using the first prompt.
                model.update_prefix_weights_single(shared_params[0])
            elif training_args.do_train is True and model_args.shared_attn is True:
                # Initialize the target task prompt embeddings using the first prompt.
                if args.random_prefix:
                    model.update_prefix_weights([])
                else:
                    model.update_prefix_weights_multi(
                        shared_params[0], num_target=config.num_target
                    )
            else:
                # For inference: load the trained target task prompt.
                if target_prompt_embedding is None:
                    raise ValueError(
                        "`target_prompt_embedding_path` is required to load the "
                        "trained target prompt for inference."
                    )
                model.update_prefix_weights_single(target_prompt_embedding)
        elif model_args.target_prompt_embedding_path is None:
            if args.random_prefix:
                model.update_prefix_weights([])
            else:
                model.update_prefix_weights(shared_params)
        else:
            model.update_prefix_weights(shared_params, target_prompt_embedding)

    # Load linear attention
    if model_args.load_attention is True and model_args.attn_path is not None:
        model.update_attention_weights(torch.load(model_args.attn_path))

    # Load projection-based attentions
    if model_args.load_attention is True and model_args.attn_path_sub is not None:
        model.update_attention_weights_sub(model_args.attn_path_sub)

    # Load layer norm weights & biases
    if model_args.load_layer_norm is True and model_args.layer_norm_dir is not None:
        model.update_layer_norm_weights(model_args.layer_norm_dir)

    model.resize_token_embeddings(len(tokenizer))
    model = modify_model_after_init(model, training_args, adapter_args, adapter_config)

    return model


def _load_train_task_datasets(data_args, training_args, adapter_args, split, n_obs):
    """Load one raw dataset per configured task for ``split``."""
    if data_args.train_files is not None:
        per_task_files = data_args.train_files
    else:
        # A single `train_file` (possibly None) is shared by every task.
        per_task_files = [data_args.train_file] * len(data_args.dataset_name)
    return [
        AutoTask.get(dataset_name, dataset_config_name, seed=data_args.data_seed).get(
            split=split,
            split_validation_test=training_args.split_validation_test,
            add_prefix=not adapter_args.train_task_adapters,
            n_obs=n_obs,
            lang=data_args.lang_name,
            file_name=train_file,
        )
        for dataset_name, dataset_config_name, train_file in zip(
            data_args.dataset_name, data_args.dataset_config_name, per_task_files
        )
    ]


def _train_task_max_target_lengths(data_args, tokenizer):
    """Return the maximum target length of every configured task."""
    return [
        AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length,
        )
        for dataset_name, dataset_config_name in zip(
            data_args.dataset_name, data_args.dataset_config_name
        )
    ]


def _preprocess_and_merge_train_datasets(
    task_datasets, max_target_lengths, preprocess_function, data_args, model_args
):
    """Apply ``preprocess_function`` to each task dataset and concatenate them."""
    column_names = ["source", "target", "extra_fields"]
    for task_index, task_dataset in enumerate(task_datasets):
        bound_kwargs = {"max_target_length": max_target_lengths[task_index]}
        if model_args.shared_attn is True:
            # With shared attention the task index is also passed to preprocessing.
            bound_kwargs["task_id"] = task_index
        task_datasets[task_index] = task_dataset.map(
            functools.partial(preprocess_function, **bound_kwargs),
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
    return concatenate_datasets(task_datasets)


def train_data_setup(
    data_args,
    tokenizer,
    model_args,
    preprocess_function,
    training_args,
    adapter_args,
    seed_id,
    args,
):
    """Build the training dataset and, optionally, the full pseudo-label dataset.

    Args:
        data_args, tokenizer, model_args, preprocess_function, training_args,
        adapter_args: forwarded to the task loading / preprocessing helpers.
        seed_id: selects the pickle file read when ``args.self_train_attn`` is set.
        args: namespace providing ``get_pseudo_labels``, ``self_train_attn``
            and, for the pickle path, ``self_train_source_dir``.

    Returns:
        Tuple ``(train_dataset, train_dataset_save)``. ``train_dataset_save`` is
        the preprocessed "train+test+validation" data when
        ``args.get_pseudo_labels`` is set, otherwise ``None``.
    """
    # Full-size data for pseudo-label training.
    train_dataset_save = None
    if args.get_pseudo_labels:
        full_task_datasets = _load_train_task_datasets(
            data_args,
            training_args,
            adapter_args,
            split="train+test+validation",
            n_obs=None,
        )
        full_max_target_lengths = _train_task_max_target_lengths(data_args, tokenizer)
        train_dataset_save = _preprocess_and_merge_train_datasets(
            full_task_datasets,
            full_max_target_lengths,
            preprocess_function,
            data_args,
            model_args,
        )

    if not args.self_train_attn:
        # Few-shot: the train split, capped at `max_train_samples`.
        train_task_datasets = _load_train_task_datasets(
            data_args,
            training_args,
            adapter_args,
            split="train",
            n_obs=data_args.max_train_samples,
        )
        train_max_target_lengths = _train_task_max_target_lengths(
            data_args, tokenizer
        )
        train_dataset = _preprocess_and_merge_train_datasets(
            train_task_datasets,
            train_max_target_lengths,
            preprocess_function,
            data_args,
            model_args,
        )
    else:
        # SECURITY NOTE: pickle.load executes arbitrary code from the file;
        # only point `self_train_source_dir` at trusted, locally produced data.
        with open(
            args.self_train_source_dir
            + "/train_dataset_fs_"
            + str(seed_id)
            + ".pickle",
            "rb",
        ) as handle:
            train_dataset = pickle.load(handle)
    print('train_dataset', train_dataset[0])
    return train_dataset, train_dataset_save


def eval_data_setup(
    data_args,
    tokenizer,
    model_args,
    preprocess_function,
    training_args,
    adapter_args,
    seed_id,
    args,
):
    """Build the evaluation datasets and, optionally, the pseudo-labelling ones.

    Two products are assembled independently:

    * When ``args.get_pseudo_labels`` is set, every task listed in
      ``data_args.eval_dataset_name`` is loaded with the
      ``"train+validation+test"`` split (no sample cap) and tokenised.
    * Unless ``args.self_train_attn`` is set, every task is loaded with the
      ``"validation"`` split, capped at ``data_args.max_val_samples``, and
      tokenised.  With ``args.self_train_attn`` the evaluation datasets are
      instead unpickled from
      ``<args.self_train_source_dir>/eval_dataset_fs_<seed_id>.pickle`` and any
      pseudo-labelling datasets built above are discarded.

    Returns:
        ``(eval_datasets, eval_datasets_self_train)``.  The second item is
        ``None`` when pseudo-labelling data was not requested or was discarded
        by the ``self_train_attn`` branch.
    """
    column_names = ["source", "target", "extra_fields"]

    def load_splits(split, n_obs):
        # Load one raw dataset per evaluation task, keyed by task name.
        def fetch(task_name, task_config, file_name):
            task = AutoTask.get(task_name, task_config, seed=data_args.data_seed)
            return task.get(
                split=split,
                split_validation_test=training_args.split_validation_test,
                add_prefix=not adapter_args.train_task_adapters,
                n_obs=n_obs,
                lang=data_args.lang_name,
                file_name=file_name,
            )

        loaded = {}
        if data_args.validation_files is not None:
            # One validation file per task.
            for task_name, task_config, task_file in zip(
                data_args.eval_dataset_name,
                data_args.eval_dataset_config_name,
                data_args.validation_files,
            ):
                loaded[task_name] = fetch(task_name, task_config, task_file)
        else:
            # A single validation file shared by every task.
            for task_name, task_config in zip(
                data_args.eval_dataset_name, data_args.eval_dataset_config_name
            ):
                loaded[task_name] = fetch(
                    task_name, task_config, data_args.validation_file
                )
        return loaded

    def tokenize_all(splits):
        # Run ``preprocess_function`` over every loaded dataset, in place.
        target_lengths = [
            AutoTask.get(task_name, task_config).get_max_target_length(
                tokenizer=tokenizer, default_max_length=data_args.max_target_length
            )
            for task_name, task_config in zip(
                data_args.eval_dataset_name, data_args.eval_dataset_config_name
            )
        ]
        for position, task_name in enumerate(splits):
            partial_kwargs = {"max_target_length": target_lengths[position]}
            if model_args.shared_attn is True:
                # The task index is only forwarded in shared-attention mode.
                partial_kwargs["task_id"] = position
            splits[task_name] = splits[task_name].map(
                functools.partial(preprocess_function, **partial_kwargs),
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
            )
        return splits

    pseudo_label_sets = None
    if args.get_pseudo_labels:
        pseudo_label_sets = tokenize_all(load_splits("train+validation+test", None))

    if args.self_train_attn:
        # Pre-built few-shot evaluation data replaces everything computed above.
        pseudo_label_sets = None
        pickle_path = (
            args.self_train_source_dir + "/eval_dataset_fs_" + str(seed_id) + ".pickle"
        )
        with open(pickle_path, "rb") as handle:
            few_shot_sets = pickle.load(handle)
    else:
        # few shot
        few_shot_sets = tokenize_all(
            load_splits("validation", data_args.max_val_samples)
        )

    return few_shot_sets, pseudo_label_sets


def test_data_setup(
    data_args,
    tokenizer,
    model_args,
    preprocess_function,
    training_args,
    adapter_args,
    seed_id,
    args,
):
    """Load and tokenise one test dataset per task in ``data_args.test_dataset_name``.

    With ``args.test_with_val`` the ``"validation"`` split is used, capped at
    ``data_args.max_val_samples``; otherwise the full ``"test"`` split is used.
    ``seed_id`` is accepted for signature parity but not read here.

    Returns:
        A dict mapping each test task name to its preprocessed dataset.
    """
    if args.test_with_val:
        split, n_obs = 'validation', data_args.max_val_samples
    else:
        split, n_obs = 'test', None
    column_names = ["source", "target", "extra_fields"]

    def fetch(task_name, task_config, file_name):
        # Load the raw split for a single task.
        task = AutoTask.get(task_name, task_config, seed=data_args.data_seed)
        return task.get(
            split=split,
            split_validation_test=training_args.split_validation_test,
            add_prefix=not adapter_args.train_task_adapters,
            n_obs=n_obs,
            lang=data_args.lang_name,
            file_name=file_name,
        )

    test_datasets = {}
    if data_args.test_files is not None:
        # One test file per task.
        for task_name, task_config, task_file in zip(
            data_args.test_dataset_name,
            data_args.test_dataset_config_name,
            data_args.test_files,
        ):
            test_datasets[task_name] = fetch(task_name, task_config, task_file)
    else:
        # A single test file shared by every task.
        for task_name, task_config in zip(
            data_args.test_dataset_name, data_args.test_dataset_config_name
        ):
            test_datasets[task_name] = fetch(
                task_name, task_config, data_args.test_file
            )

    target_lengths = [
        AutoTask.get(task_name, task_config).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length
        )
        for task_name, task_config in zip(
            data_args.test_dataset_name, data_args.test_dataset_config_name
        )
    ]

    for position, task_name in enumerate(test_datasets):
        partial_kwargs = {"max_target_length": target_lengths[position]}
        if model_args.shared_attn is True:
            # The task index is only forwarded in shared-attention mode.
            partial_kwargs["task_id"] = position
        test_datasets[task_name] = test_datasets[task_name].map(
            functools.partial(preprocess_function, **partial_kwargs),
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )
    return test_datasets


def def_trainer(
    t5_model,
    tokenizer,
    model_args,
    training_args,
    data_args,
    train_dataset,
    eval_datasets,
    test_datasets,
    compute_metrics,
    config,
    args,
    seed_id,
):
    """Assemble the voting model and wrap it in a ``Seq2SeqTrainer_vote``.

    The ``extra_fields`` columns are captured into ``data_info`` first, using
    the ``do_train`` / ``do_test`` flags as they are on entry.  Afterwards
    ``args.test_only`` and ``args.eval_acc`` each switch ``do_train`` and
    ``do_eval`` off on ``training_args`` (mutating the caller's object) and
    drop the train / eval datasets.

    Returns:
        The constructed ``Seq2SeqTrainer_vote`` instance.
    """
    if data_args.ignore_pad_token_for_loss:
        label_pad_token_id = -100
    else:
        label_pad_token_id = tokenizer.pad_token_id

    if not data_args.pad_to_max_length:
        data_collator = TaskDataCollatorForSeq2Seq(
            tokenizer,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )
    else:
        data_collator = default_data_collator

    # Only the first eval / test task contributes its extra fields.
    eval_extra = eval_datasets[data_args.eval_dataset_name[0]]["extra_fields"]
    test_extra = None
    if training_args.do_test:
        test_extra = test_datasets[data_args.test_dataset_name[0]]["extra_fields"]
    train_extra = None
    if training_args.do_train:
        train_extra = train_dataset["extra_fields"]
    data_info = {"eval": eval_extra, "test": test_extra, "train": train_extra}

    # Optional comma-separated list of integer indices.
    idx_lst = None
    if args.idx_lst:
        idx_lst = [int(piece) for piece in args.idx_lst.split(",")]

    # Pick the mode string handed to ``MyModel``.
    inference_only = False
    if args.test_only:
        trainer_type = "test"
        inference_only = True
    elif args.attention_uw:
        trainer_type = 'uw-input-attention_at'
    else:
        trainer_type = "uw-input"

    file_path = None
    if args.eval_acc:
        # ``eval_acc`` overrides whatever mode was chosen above.
        inference_only = True
        if args.load_idxs:
            trainer_type, file_path = "load_idxs", args.load_idxs
        else:
            trainer_type, file_path = "load", args.load_eval_vote_file

    if inference_only:
        training_args.do_train = False
        training_args.do_eval = False
        train_dataset = None
        eval_datasets = None

    if args.load_idxs_diff_num:
        file_path = args.load_idxs

    print('Trainer type ->>>>', trainer_type)
    my_model = MyModel(
        embed_size=t5_model.model_dim,
        type=trainer_type,
        vocab_size=config.vocab_size,
        target_task=data_args.eval_dataset_name[0],
        mapping=True,
        file_path=file_path,
        seed_id=seed_id,
        neural_deal=args.neural_deal,
        idx_lst=idx_lst,
        attention_size_input_key=args.attention_size_input_key,
        attention_size_input_query=args.attention_size_input_query,
        attention_size_output=args.attention_size_output,
        hloss=args.hloss,
        dropout=args.dropout,
        update_idx=args.load_idxs_diff_num,
    )

    trainer_train_set = train_dataset if training_args.do_train else None
    trainer_eval_set = None
    if training_args.do_eval:
        # The trainer itself only receives the first evaluation dataset.
        trainer_eval_set = list(eval_datasets.values())[0]
    metrics_fn = compute_metrics if training_args.predict_with_generate else None
    evaluation_metrics = TASK_TO_METRICS[data_args.dataset_name[0]]

    return Seq2SeqTrainer_vote(
        model=my_model,
        t5_model=t5_model,
        args=training_args,
        train_dataset=trainer_train_set,
        eval_dataset=trainer_eval_set,
        data_info=data_info,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=metrics_fn,
        evaluation_metrics=evaluation_metrics,
        shared=model_args.shared_attn,
    )


def train_once(trainer, training_args, last_checkpoint=None):
    """Run one training pass, then log and save its metrics.

    The checkpoint to resume from is ``training_args.resume_from_checkpoint``
    when that is not ``None``, otherwise ``last_checkpoint`` (which may itself
    be ``None``, meaning a fresh start).

    Returns:
        The same ``trainer`` object that was passed in.
    """
    resume_point = training_args.resume_from_checkpoint
    if resume_point is None:
        resume_point = last_checkpoint

    # train
    outcome = trainer.train(resume_from_checkpoint=resume_point)
    metrics = outcome.metrics

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    return trainer


def update_path_sp(args, seed_id):
    """Return the self-trained prefix-embedding path for each source task.

    One path is produced per source task, in a fixed order, for the target
    task ``args.dataset_target`` and the given ``seed_id``.
    """
    base = "/export/home/OpenPrompt/mixture_prompt/outputs/vote_attn_teachers/self_fs_s_"
    # Everything after the source name is identical for all six entries.
    tail = (
        "_t_"
        + args.dataset_target
        + "_st_attn/prefix_embeddings_"
        + str(seed_id)
        + ".pt"
    )
    sources = ("mnli", "sst2", "qnli", "qqp", "squad", "record")
    return [base + source + tail for source in sources]

def update_path_fs(args, seed_id): # _2
    """Return the few-shot prefix-embedding path for each source task.

    Each path has the form
    ``<prompt_embedding_path_prefix>/<dataset_target>_<source>/prefix_embeddings_<seed_id>.pt``.
    """
    # Parts before and after the source name do not change between entries.
    head = args.prompt_embedding_path_prefix + '/' + args.dataset_target + "_"
    tail = '/prefix_embeddings_' + str(seed_id) + '.pt'
    sources = ("mnli", "sst2", "qnli", "qqp", "squad", "record")
    return [head + source + tail for source in sources]

def update_path_single_seed(args, seed_id):
    """Return a one-element list with the per-seed prefix-embedding path.

    Every ``'_ppt'`` occurrence is removed from ``args.dataset_target`` before
    it is used as the directory name.
    """
    target_dir = args.prompt_embedding_path_prefix + '/' + args.dataset_target.replace('_ppt', '')
    return [target_dir + '/prefix_embeddings_' + str(seed_id) + '.pt']

def update_path_source(args):
    """Return the source-model prefix-embedding path for each source task.

    Each path has the form
    ``<prompt_embedding_path_prefix>/<prompt_embedding_path_model>_<source>/prefix_embeddings.pt``.
    """
    head = args.prompt_embedding_path_prefix + '/' + args.prompt_embedding_path_model + '_'
    sources = ("mnli", "sst2", "qnli", "qqp", "squad", "record")
    return [head + source + "/prefix_embeddings.pt" for source in sources]

def update_path_original(args):
    """Return the original prompt file path for each source task.

    Each path has the form ``<prompt_embedding_path_prefix>/<source>_prompt.pt``.
    """
    sources = ("mnli", "sst2", "qnli", "qqp", "squad", "record")
    return [
        args.prompt_embedding_path_prefix + '/' + source + "_prompt.pt"
        for source in sources
    ]

def predict(trainer, eval_datasets_self_train, train_datasets_save, pad_to_length=256):
    """Build a dataset whose labels are the trainer's predictions (pseudo labels).

    Predictions are obtained for the first dataset in
    ``eval_datasets_self_train``; row ``i`` of the predictions becomes the
    label of row ``i`` of ``train_datasets_save``.

    Args:
        trainer: Object whose ``predict(dataset)`` result yields the
            predictions as its first item.
        eval_datasets_self_train: Mapping of task name to dataset; only the
            first value is used.
        train_datasets_save: Dataset providing ``input_ids``,
            ``attention_mask`` and ``labels`` columns; its rows are reused
            with their labels replaced.
        pad_to_length: Length to which ``input_ids`` and ``attention_mask`` are
            right-padded with zeros.  Rows already at least this long are left
            untouched.  Defaults to the previously hard-coded 256.

    Returns:
        A ``Dataset`` built from the padded rows.

    Raises:
        ValueError: If ``train_datasets_save`` has no rows.
        IndexError: If there are fewer prediction rows than dataset rows.
    """
    first_split = list(eval_datasets_self_train.values())[0]
    raw_predictions = list(list(trainer.predict(first_split))[0])

    # Keep only strictly positive ids from each predicted row.
    # NOTE(review): presumably this strips padding (0) and ignore-index (-100)
    # entries produced by generation -- confirm against the trainer.
    prediction_list = [
        [token for token in row if token > 0] for row in raw_predictions
    ]

    dataset = train_datasets_save
    original_label_max = max(len(labels) for labels in dataset["labels"])

    dataset_res = []
    for i, element in enumerate(dataset):
        element["labels"] = prediction_list[i]
        # make the size the same for batch training.
        element["attention_mask"] += [0] * (
            pad_to_length - len(element["attention_mask"])
        )
        element["input_ids"] += [0] * (pad_to_length - len(element["input_ids"]))
        dataset_res.append(element)

    # The padding width must also cover the predicted labels: sizing it from
    # the original labels alone left longer predictions unpadded, so label rows
    # could end up with different lengths.
    label_max = max(
        original_label_max,
        max(len(element["labels"]) for element in dataset_res),
    )
    for element in dataset_res:
        labels = list(element["labels"])
        # -100 is the value used elsewhere in this file for label positions
        # that should be ignored in the loss.
        element["labels"] = tuple(labels + [-100] * (label_max - len(labels)))

    dataset_dict = {
        column: [row[column] for row in dataset_res] for column in dataset_res[0]
    }
    return Dataset.from_dict(dataset_dict)


def main(seed_id, args, json_file=None):
    """Run one train / eval / pseudo-label / test pass for a single seed.

    Steps, in order: load the argument objects via ``load_args``; apply the
    command-line overrides for prompt-embedding paths, model name and batch
    sizes; seed the RNGs with ``training_args.seed + seed_id``; build the T5
    config, tokenizer and model; prepare the datasets selected by
    ``training_args.do_train`` / ``do_eval`` / ``do_test``; build the trainer;
    then train, evaluate, optionally save pseudo labels, and test.

    Args:
        seed_id: Integer added to ``training_args.seed`` and used in per-seed
            file names.
        args: Parsed command-line namespace.
        json_file: Forwarded to ``load_args``.  The training config that gets
            saved is taken from ``args.json_file``, not from this parameter.

    Returns:
        A dict that may contain ``"eval"`` and ``"test"`` keys, each holding
        the metrics of the last dataset evaluated in the corresponding loop.

    NOTE(review): ``train_dataset`` / ``train_datasets_save``,
    ``eval_datasets`` / ``eval_datasets_self_train`` and ``test_datasets`` are
    only bound when the matching ``do_*`` flag is truthy, but they are passed
    to ``def_trainer`` (and ``predict``) unconditionally, so a falsy flag leads
    to a ``NameError`` further down.
    """
    print(args)
    # Build the model / data / training / adapter argument objects for this seed.
    model_args, data_args, training_args, adapter_args = load_args(
        args, seed_id, json_file
    )

    # update soft prompts to fine-tuned ones.
    # The flags are checked one after another, so when several are set the last
    # matching one decides the final ``prompt_embedding_path``.
    if args.update_sp:
        model_args.prompt_embedding_path = update_path_sp(args, seed_id)
    if args.fewshot_update_prefix:
        model_args.prompt_embedding_path = update_path_fs(args, seed_id)
    if args.fs_update_one_prompt_seed:
        model_args.prompt_embedding_path = update_path_single_seed(args, seed_id)
    if args.source_update_prefix:
        model_args.prompt_embedding_path = update_path_source(args)
    if args.update_path_original:
        model_args.prompt_embedding_path = update_path_original(args)
    # if args.update_path_fs:
    #     model_args.prompt_embedding_path = update_path_fs(args)
    if args.ppt_update_prefix:
        model_args.prompt_embedding_path = [args.prompt_embedding_path_prefix]
    if args.fs_sd_update_prefix:
        model_args.prompt_embedding_path = [args.prompt_embedding_path_prefix + '/prefix_embeddings_'+str(seed_id)+'.pt']
    # update model
    # A model name given on the command line also replaces the tokenizer name.
    if args.model_name:
        model_args.model_name_or_path = args.model_name
        model_args.tokenizer_name = args.model_name
    if args.per_device_train_batch_size:
        training_args.per_device_train_batch_size = args.per_device_train_batch_size
    if args.per_device_eval_batch_size:
        training_args.per_device_eval_batch_size = args.per_device_eval_batch_size



    # setup logger
    # This local ``logger`` shadows the module-level one for the rest of main().
    logger = setup_logger(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed + seed_id)
    # The bare string below is a no-op expression statement, not a docstring.
    """
    @author-3
    For six tachers we use six models. 
    """
    # These three settings are forced here regardless of what load_args produced.
    model_args.ignore_target = True
    model_args.attn_method = "vote"
    model_args.attn_prefix_tuning = False

    # setup T5 model config
    config, adapter_config = setup_config_T5(
        data_args, model_args, training_args, adapter_args
    )
    print('data_args', data_args)
    # setup tokenizer
    # Falls back to the model path when no separate tokenizer name is configured.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    print('model_args>>>', model_args)
    # setup T5 model
    # ``from_tf`` is enabled when the model path contains ".ckpt".
    model = T5ForConditionalGeneration_vote.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        adapter_config=adapter_config,
    )
    # update T5 model memory.
    # print('model_args.load_prefix_embeddings', model_args.load_prefix_embeddings)

    model = update_model_memory(
        model,
        model_args,
        training_args,
        tokenizer,
        adapter_args,
        adapter_config,
        config,
    )

    # print('>>>', model.generate(tokenizer('I likte you')))
    # Pad to max length at tokenisation time only when requested; otherwise
    # tokenise without padding.
    padding = "max_length" if data_args.pad_to_max_length else False

    # Batched tokenisation callback handed to the *_data_setup helpers.
    # Tokenises "source" as model inputs and "target" as labels, carries
    # "extra_fields" through, and adds a "task_ids" column when ``task_id`` is given.
    def preprocess_function(examples, max_target_length, task_id=None):
        model_inputs = tokenizer(
            examples["source"],
            max_length=data_args.max_source_length,
            padding=padding,
            truncation=True,
        )
        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                examples["target"],
                max_length=max_target_length,
                padding=padding,
                truncation=True,
            )
        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label]
                for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        model_inputs["extra_fields"] = examples["extra_fields"]
        if task_id is not None:
            # One task id per example in the batch.
            model_inputs["task_ids"] = [task_id for _ in examples["extra_fields"]]

        return model_inputs

    # datasets
    # debug
    # ``full_dataset`` removes the sample caps; ``fewshot_size`` is applied
    # afterwards and therefore wins when both are set.
    if args.full_dataset:
        data_args.max_train_samples = None
        data_args.max_val_samples = None

    if args.fewshot_size:
        data_args.max_train_samples = args.fewshot_size
        data_args.max_val_samples = args.fewshot_size

    if training_args.do_train:
        train_dataset, train_datasets_save = train_data_setup(
            data_args,
            tokenizer,
            model_args,
            preprocess_function,
            training_args,
            adapter_args,
            seed_id,
            args,
        )

    if training_args.do_eval:
        eval_datasets, eval_datasets_self_train = eval_data_setup(
            data_args,
            tokenizer,
            model_args,
            preprocess_function,
            training_args,
            adapter_args,
            seed_id,
            args,
        )

    if training_args.do_test:
        test_datasets = test_data_setup(
            data_args,
            tokenizer,
            model_args,
            preprocess_function,
            training_args,
            adapter_args,
            seed_id,
            args,
        )

    # Metric, we assume we have only one training task.
    # Only the metric list of the first training task is kept.
    eval_metrics = [
        AutoTask.get(dataset_name, dataset_config_name).metric
        for dataset_name, dataset_config_name in zip(
            data_args.dataset_name, data_args.dataset_config_name
        )
    ][0]

    # Post-processes predictions/labels with the first training task's
    # post-processor and merges the result of every metric into one dict.
    def compute_metrics(eval_preds):
        preds, labels, data_info = eval_preds
        post_processor = AutoPostProcessor.get(
            data_args.dataset_name[0], tokenizer, data_args.ignore_pad_token_for_loss
        )
        decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
        result = {}
        for metric in eval_metrics:
            result.update(metric(decoded_preds, decoded_labels))
        return result

    # define trainer:
    # NOTE(review): the three dataset variables are unbound here if the
    # corresponding do_train / do_eval / do_test flag was falsy above.
    trainer = def_trainer(
        model,
        tokenizer,
        model_args,
        training_args,
        data_args,
        train_dataset,
        eval_datasets,
        test_datasets,
        compute_metrics,
        config,
        args,
        seed_id,
    )

    # Saves training config.
    if trainer.is_world_process_zero():
        os.makedirs(training_args.output_dir, exist_ok=True)
        save_training_config(args.json_file, training_args.output_dir)

    # Training -- few shot UW baseline
    # debug
    if args.test_only:
        training_args.do_train = False
        training_args.do_eval = False

    if training_args.do_train:
        trainer = train_once(trainer, training_args)

    # Evaluation
    # NOTE(review): ``ignore_target`` was set to True above, so this branch
    # looks unreachable unless a helper resets it; ``learned_embeddings`` is
    # otherwise only referenced in commented-out code below.
    if model_args.shared_attn is True and model_args.ignore_target is False:
        learned_embeddings = trainer.t5_model.encoder.prefix_emb.clone().detach()

    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        # Both branches run the same evaluation per dataset; the first one
        # merely also enumerates the datasets.
        if model_args.shared_attn is True:
            for idx, (task, eval_dataset) in enumerate(eval_datasets.items()):
                metrics = trainer.evaluate(
                    eval_dataset=eval_dataset,
                    max_length=data_args.val_max_target_length,
                    num_beams=data_args.num_beams,
                )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)

        else:
            for task, eval_dataset in eval_datasets.items():
                metrics = trainer.evaluate(
                    eval_dataset=eval_dataset,
                    max_length=data_args.val_max_target_length,
                    num_beams=data_args.num_beams,
                )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        # Only the metrics of the last evaluated dataset are returned.
        results["eval"] = metrics

    # get pseudo labels
    # Saved to <args.save_dir>/pseudo_label_dataset_<seed_id>.pickle.
    # NOTE(review): needs both do_train and do_eval to have run above, since it
    # reads ``train_datasets_save`` and ``eval_datasets_self_train``.
    if args.get_pseudo_labels:
        dataset_save = predict(trainer, eval_datasets_self_train, train_datasets_save)
        save_to_dict(
            path_save=args.save_dir
            + "/pseudo_label_dataset_"
            + str(seed_id)
            + ".pickle",
            save_stuff=dataset_save,
        )

    # Test
    if training_args.do_test:
        logger.info("*** Test ***")
        # As for evaluation, the two branches differ only in the enumerate().
        if model_args.shared_attn is True:
            for idx, (task, test_dataset) in enumerate(test_datasets.items()):
                # trainer.t5_model.encoder.prefix_emb[0].data = learned_embeddings[idx]
                metrics = trainer.evaluate(
                    eval_dataset=test_dataset,
                    max_length=data_args.test_max_target_length,
                    num_beams=data_args.num_beams,
                    metric_key_prefix="test",
                )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)

        else:
            for task, test_dataset in test_datasets.items():
                metrics = trainer.evaluate(
                    eval_dataset=test_dataset,
                    max_length=data_args.test_max_target_length,
                    num_beams=data_args.num_beams,
                    metric_key_prefix="test",
                )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)
        # Only the metrics of the last tested dataset are returned.
        results["test"] = metrics

    return results


def _mp_fn(index):
    """Call ``main()`` with no arguments; ``index`` is accepted but unused.

    NOTE(review): presumably the per-process hook expected by multi-process
    launchers (e.g. ``xla_spawn``) -- confirm. The ``__main__`` block of this
    file invokes ``main(i, args)`` with two positional arguments, so this
    zero-argument call looks stale; verify it against ``main``'s signature
    before relying on this entry point.
    """
    main()


if __name__ == "__main__":
    # Command-line driver: parse the experiment options, then run `main` once
    # per seed in [seed_start, seed_start + fewshot_num) and append each run's
    # results as one line of a text file inside --save_dir.
    parser = argparse.ArgumentParser("")
    parser.add_argument("--dataset_target", type=str)
    parser.add_argument("--save_dir", type=str)
    parser.add_argument("--json_file", type=str)
    parser.add_argument("--fewshot", action="store_true")
    parser.add_argument("--self_train", action="store_true")
    parser.add_argument("--naked", action="store_true")
    parser.add_argument("--random_prefix", action="store_true")
    parser.add_argument("--fewshot_num", type=int, default=20)
    parser.add_argument("--seed_start", type=int, default=0)
    parser.add_argument("--attn_method", type=str)
    parser.add_argument(
        "--dataset_source", type=str, default=None, help="source datasets of teachers"
    )
    parser.add_argument("--target_prompt_embedding_path", type=str, default=None)
    parser.add_argument("--attn_path_sub", type=str, default=None)
    parser.add_argument("--layer_norm_dir", type=str, default=None)
    parser.add_argument("--checkpoint_dir", type=str, default=None)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--load_layer_norm", action="store_true")
    parser.add_argument("--load_attention", action="store_true")
    parser.add_argument("--ignore_target", action="store_true")
    parser.add_argument(
        "--update_sp", action="store_true", help="whether update sp files"
    )
    parser.add_argument("--self_train_attn", action="store_true")
    parser.add_argument("--test_only", action="store_true")
    parser.add_argument("--eval_acc", action="store_true")
    parser.add_argument("--full_dataset", action="store_true")
    parser.add_argument("--get_pseudo_labels", action="store_true")
    parser.add_argument("--neural_deal", action="store_true")
    parser.add_argument("--self_teach_num", type=int)
    parser.add_argument("--idx_lst", type=str, default=None)
    parser.add_argument("--attention_size_input_key", type=int, default=100)
    parser.add_argument("--attention_size_input_query", type=int, default=100)
    parser.add_argument("--attention_size_output", type=int, default=100)
    parser.add_argument("--hloss", type=int, default=None)
    # NOTE(review): declared as int, so fractional values such as 0.1 are
    # rejected by argparse -- confirm whether this is meant to be a float rate.
    parser.add_argument("--dropout", type=int, default=0)

    # for few-shot: new prefix embedding
    parser.add_argument(
        "--load_idxs_diff_num",
        action="store_true",
        help="num of teachers change?",
    )
    parser.add_argument(
        "--load_eval_vote_file",
        default=None,
        type=str,
        help="file path of pickle of eval acc",
    )
    parser.add_argument("--load_idxs", default=None, help="path of file of idx")
    parser.add_argument(
        "--update_idx",
        default=None,
        help="use different idx for different seeds, only for testing",
    )
    parser.add_argument(
        "--fewshot_update_prefix",
        action="store_true",
        help="True when you want to update it for fewshot",
    )
    parser.add_argument(
        "--update_path_original",
        action="store_true",
        help="True when you want to update it for original one",
    )
    parser.add_argument(
        "--source_update_prefix",
        action="store_true",
        help="True when you want to update it for larger model",
    )
    parser.add_argument(
        "--fs_update_one_prompt_seed",
        action="store_true",
        help="True when you want to update it for fewshot, only one, with seed_id",
    )
    # NOTE(review): the help texts of the next two flags duplicate the ones of
    # --fs_update_one_prompt_seed / --source_update_prefix -- confirm wording.
    parser.add_argument(
        "--fs_sd_update_prefix",
        action="store_true",
        help="True when you want to update it for fewshot, only one, with seed_id",
    )
    parser.add_argument(
        "--ppt_update_prefix",
        action="store_true",
        help="True when you want to update it for larger model",
    )
    parser.add_argument("--prompt_embedding_path_prefix", type=str, default=None)
    parser.add_argument("--prompt_embedding_path_model", type=str, default="t5_base")
    parser.add_argument("--per_device_train_batch_size", type=int, default=None)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=None)
    # parser.add_argument("--prompt_embedding_path", type=str, default=None)
    parser.add_argument(
        "--fewshot_size", type=int, default=None, help="# of fewshot size"
    )

    parser.add_argument("--model_name", type=str, default=None)
    parser.add_argument(
        "--test_with_val",
        action="store_true",
        help="True when you want to use eval to test",
    )
    parser.add_argument(
        "--attention_uw",
        action="store_true",
        help="True for ablation study of attention between prefix and input",
    )

    args = parser.parse_args()

    # The results file lives under --save_dir; fail with a clear usage error
    # instead of a TypeError from concatenating None with a string.
    if not args.save_dir:
        parser.error("--save_dir is required: per-seed results are written there")
    # The file is opened before any run starts, so the directory must exist.
    os.makedirs(args.save_dir, exist_ok=True)

    # Read the mode once so every iteration follows the same branch.
    is_fewshot = bool(args.fewshot)
    results_name = "results.txt" if is_fewshot else "results_st.txt"

    with open(os.path.join(args.save_dir, results_name), "w") as f:
        for i in range(args.seed_start, args.seed_start + args.fewshot_num):
            if is_fewshot:
                results = main(i, args)
            else:
                # NOTE(review): this branch unpacks a pair, while the visible
                # end of main() returns a single `results` object -- confirm
                # that main() returns a 2-tuple on the non-fewshot path.
                results, _ = main(i, args)
            f.write(str(results) + "\n")
            # Flush after every seed so finished runs survive a later crash.
            f.flush()
            if is_fewshot:
                print("results --->", results)
