# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script to your own sequence-to-sequence task; pointers for doing so are left as comments.
from transformers import Seq2SeqTrainingArguments
from utils import modify_model_after_init, save_training_config
import shutil
import glob
from data import AutoPostProcessor
from third_party.models import T5Config, T5ForConditionalGeneration
from dataclasses import dataclass, field
from training_args import AdapterTrainingArguments
from third_party.trainers import Seq2SeqTrainer
from data import TaskDataCollatorForSeq2Seq
from data import AutoTask
from utils import get_adapter_config
from transformers.trainer_utils import is_main_process, get_last_checkpoint
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    set_seed,
    AdamW,
    get_linear_schedule_with_warmup
)
import transformers
from datasets import load_dataset, load_metric, concatenate_datasets
from typing import Optional, List
import subprocess
import sys
import functools
import logging
from pytz import common_timezones
import torch
import os
import pickle
import argparse

from data.tasks import TASK_MAPPING
# Export the MKL-related environment variables for this process.
# NOTE(review): presumably a workaround for an MKL / OpenMP threading-layer
# conflict -- confirm against the runtime environment this script targets.
os.environ.update(
    {
        'MKL_THREADING_LAYER': 'GNU',
        'MKL_SERVICE_FORCE_INTEL': '1',
    }
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)



def run_command(command):
    """Execute ``command`` via ``subprocess.getoutput`` and return its output.

    Args:
        command: The command string handed to ``subprocess.getoutput``.

    Returns:
        The text produced by the command, exactly as returned by
        ``subprocess.getoutput``.
    """
    return subprocess.getoutput(command)


# Maps each task name to the list of metric names associated with it.
# Every task owns its own list object, so editing one entry never affects another.
TASK_TO_METRICS = {
    "mrpc": ["accuracy", "f1"],
    "cola": ["matthews_correlation"],
    "stsb": ["pearson", "spearmanr"],
    "sst2": ["accuracy"],
    "sst2_ppt": ["accuracy"],
    "mnli": ["accuracy"],
    "mnli_ppt": ["accuracy"],
    "mnli_mismatched": ["accuracy"],
    "mnli_matched": ["accuracy"],
    "qnli": ["accuracy"],
    "qnli_ppt": ["accuracy"],
    "rte": ["accuracy"],
    "wnli": ["accuracy"],
    "qqp": ["accuracy", "f1"],
    "qqp_ppt": ["accuracy", "f1"],
    "superglue-boolq": ["accuracy"],
    "superglue-rte": ["accuracy"],
    "superglue-cb": ["f1_multiclass", "accuracy"],
    "superglue-copa": ["accuracy"],
    "superglue-multirc": ["f1", "em"],
    "superglue-wic": ["accuracy"],
    "superglue-wsc.fixed": ["accuracy"],
    "superglue-record": ["em", "f1"],
    "record_ppt": ["f1", "em"],
    "multi_nli": ["accuracy"],
    "squad": ["em", "f1"],
    "squad_ppt": ["em", "f1"],
    "snli": ["accuracy"],
    "nq": ["em", "f1"],
    "hotpotqa": ["em", "f1"],
    "searchqa": ["em", "f1"],
    "newsqa": ["em", "f1"],
    "triviaqa": ["em", "f1"],
    "imdb": ["accuracy"],
    "winogrande": ["accuracy"],
    "scitail": ["accuracy"],
    "amazon_polarity": ["accuracy"],
    "yelp_polarity": ["accuracy"],
    "paws": ["accuracy"],
}

# run_seq2seq parameters.


@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    """Script-specific training options added on top of ``Seq2SeqTrainingArguments``.

    Each field stores its help text under ``metadata["help"]``. New
    script-level training options should be declared here.
    """

    print_num_parameters: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, print the parameters of the model."},
    )
    do_test: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, evaluates the test performance."},
    )
    # NOTE(review): this help text is identical to ``do_test``'s; confirm the
    # intended meaning of ``do_eval_predict`` and reword if it differs.
    do_eval_predict: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, evaluates the test performance."},
    )
    # The help text is assembled from adjacent literals; each piece ends with a
    # space so the words do not run together in the rendered help.
    split_validation_test: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If set, for the datasets which do not "
                    "have the test set, we use validation set as their "
                    "test set and make a validation set from either "
                    "splitting the validation set into half (for smaller "
                    "than 10K samples datasets), or by using 1K examples "
                    "from training set as validation set (for larger "
                    "datasets)."
        },
    )
    compute_time: Optional[bool] = field(
        default=False, metadata={"help": "If set measures the time."})
    compute_memory: Optional[bool] = field(
        default=False, metadata={"help": "if set, measures the memory"})
    prefix_length: Optional[int] = field(
        default=100, metadata={"help": "Defines the length for prefix tuning."})


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    load_prefix_embeddings: bool = field(
        default=False,
        metadata={
            "help": "load prefix embeddings or not"
        },
    )
    save_prefix_only: bool = field(
        default=False,
        metadata={
            "help": "save prefix embeddings only"
        },
    )

    prompt_embedding_path: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of the paths to prefix embeddings"}
    )

    target_prompt_embedding_path: Optional[str] = field(
        default=None,
        metadata={"help": "a path to the target prompt embedding"}
    )
    
    attn_prefix_tuning: bool = field(
        default=False,
        metadata={
            "help": "Set true if you try ATTEMPT."
        },
    )

    attn_method: Optional[str] = field(
        default="sub",
        metadata={
            "help": "Attention model for attn_prefix. We currently support the following methods: linear, sub (our main method), and constant (gives the constant and equal weights to all of the prompts.)"
        },
    )

    shared_attn: bool = field(
        default=False,
        metadata={
            "help": "shared attention"
        },
    )

    load_attention: bool = field(
        default=False,
        metadata={
            "help": "Set true if you want to load pre-trained attention weights"
        },
    )

    attn_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to attention weights (linear attentions). "
        },
    )

    attn_path_sub: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "list of the path to attention weights (sub attentions). [path_to_down_projection_weights, path_to_up_projection_weights]"
        },
    )

    ignore_target: bool = field(
        default=False,
        metadata={
            "help": "Whether to ignore the new target tokens. Mainly for ablation."
        },
    )

    fix_attention: bool = field(
        default=False,
        metadata={
            "help": "this will make the attention weights frozen during training. Mainly for ablation."
        },
    )

    temperature: float = field(
        default=2000,
        metadata={
            "help": "set the soft max temperature of ATTEMPT."
        },
    )

    attn_learning_rate: float = field(
        default=None,
        metadata={
            "help": "set the learning rate for the attention modules."
        },
    )

    load_layer_norm: bool = field(
        default=False,
        metadata={
            "help": "Set true if you want to load pre-trained layer-norm weight and biases."
        },
    )

    layer_norm_dir: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Layer norm dir. [path_to_layer_norm_weight.pt, path_to_layer_norm_bias.pt]"
        },
    )

    prefix_num: Optional[int] = field(
        default=1, metadata={"help": "the number of prefix"})


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``task_name`` is mandatory: ``__post_init__`` raises ``ValueError`` when it
    is ``None``. ``val_max_target_length`` and ``test_max_target_length`` fall
    back to ``max_target_length`` when left unset.
    """

    # --- Dataset / task selection ---------------------------------------
    task_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    eval_dataset_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the evaluation dataset to use (via the datasets library)."}
    )
    eval_dataset_config_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The configuration name of the evaluation dataset to use (via the datasets library)."}
    )
    test_dataset_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The name of the test dataset to use (via the datasets library)."}
    )
    test_dataset_config_name: Optional[List[str]] = field(
        default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
    )
    # NOTE(review): this help text is the same as ``task_name``'s; confirm
    # what ``lang_name`` is meant to describe.
    lang_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )

    # --- Preprocessing ---------------------------------------------------
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    test_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
                    "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                    "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )

    # --- Sub-sampling ----------------------------------------------------
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
                  "value if set."}
    )

    # --- Generation / loss -----------------------------------------------
    num_beams: Optional[int] = field(
        default=None, metadata={"help": "Number of beams to use for evaluation."})
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )

    # --- Adapters / task embeddings ---------------------------------------
    task_adapters: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary from task adapters to the tasks."}
    )
    task_embeddings: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Defines a dictionary from tasks to the tasks embeddings."}
    )
    data_seed: Optional[int] = field(
        default=42, metadata={"help": "seed used to shuffle the data."})

    # --- Local data files --------------------------------------------------
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the test data."}
    )

    train_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the training data."}
    )
    validation_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the validation data."}
    )
    test_files: Optional[List[str]] = field(
        default=None, metadata={"help": "A list of csv or json files containing the test data."}
    )

    def __post_init__(self):
        """Validate required fields and fill in dependent defaults.

        Raises:
            ValueError: If ``task_name`` was not provided.
        """
        # Only ``task_name`` is checked here, so the message names that field.
        if self.task_name is None:
            raise ValueError(
                "Need a task name: `task_name` must be provided.")
        # Validation and test target lengths inherit the training value when unset.
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
        if self.test_max_target_length is None:
            self.test_max_target_length = self.max_target_length


def main(seed_id, output_dir=None, fewshot=None, args=None):
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments,
                               AdapterTrainingArguments))
    if args.json_file:
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, adapter_args = parser.parse_json_file(
            json_file=os.path.abspath(args.json_file))
    else:
        model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()
    # Detecting last checkpoint.

    # multiple gpu
    if args.multiGpu:
        torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)

    if args.selfTrain:  # self_train
        training_args.do_eval = True
        training_args.do_test = True
    if output_dir:
        training_args.output_dir = output_dir
    if args.save_dir:
        training_args.output_dir = args.save_dir
    if fewshot:
        data_args.max_val_samples = fewshot
        data_args.max_val_samples = fewshot
        data_args.max_train_samples = fewshot
    if args.dataset_source:
        data_args.task_name = [args.dataset_source]
        data_args.test_dataset_name = [args.dataset_source]
        data_args.eval_dataset_name = [args.dataset_source]
    if args.dataset_target:
        data_args.task_name = [args.dataset_target]
        data_args.test_dataset_name = [args.dataset_target]
        data_args.eval_dataset_name = [args.dataset_target]
    if args.model_name:
        model_args.model_name_or_path = args.model_name
        model_args.tokenizer_name = args.model_name

    if args.test_only:
        model_args.prompt_embedding_path = [args.prompt_embedding_path_prefix]

    if args.dataset_target:
        data_args.task_name = [args.dataset_target]
        data_args.test_dataset_name = [args.dataset_target]
        data_args.eval_dataset_name = [args.dataset_target]

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            '''
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
            '''
            pass
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(
        training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed+seed_id)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
    # second column for the summaries (unless you specify column names for this with the `text_column` and
    # `summary_column` arguments).
    # For translation, only JSON files are supported, with one field named "translation" containing two keys for the
    # source and target languages (unless you adapt what follows).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = T5Config.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.train_task_adapters = adapter_args.train_task_adapters
    config.prefix_tuning = adapter_args.prefix_tuning
    config.attn_prefix_tuning = model_args.attn_prefix_tuning
    config.attn_method = model_args.attn_method
    config.ignore_target = model_args.ignore_target
    config.shared_attn = model_args.shared_attn
    config.prefix_num = model_args.prefix_num
    config.num_target = len(data_args.task_name)
    config.temperature = model_args.temperature
    config.fix_attention = model_args.fix_attention
    adapter_config = get_adapter_config(
        adapter_args, data_args, training_args, config)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )


    model = T5ForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        adapter_config=adapter_config
    )

    if model_args.load_prefix_embeddings is True:
        if model_args.prompt_embedding_path is None:
            for name, param in model.named_parameters():
                if "prefix_shared" in name or "prefix" in name:
                    shared_params = [param]
        else:
            shared_params = []
            for path in model_args.prompt_embedding_path:
                shared_param = torch.load(path)
                shared_params.append(shared_param)
            if model_args.target_prompt_embedding_path is not None:
                target_prompt_embedding = torch.load(
                    model_args.target_prompt_embedding_path)

        if model_args.attn_prefix_tuning is True:
            if training_args.do_train is True and model_args.shared_attn is False:
                # Load all of the source prompts
                #author - prefix
                model.store_prefix_weights(shared_params)
                # Initialize the target task prompt embedding using the first prompts
                model.update_prefix_weights_single(shared_params[0])
            elif training_args.do_train is True and model_args.shared_attn is True:
                # Load all of the source prompts
                model.store_prefix_weights(shared_params)
                # Initialize the target task prompt embeddings using the first prompts.
                model.update_prefix_weights_multi(
                    shared_params[0], num_target=config.num_target)
            else:
                # For inference
                # Load all of the source prompts
                model.store_prefix_weights(shared_params)
                # Load the trained target task prompt.
                model.update_prefix_weights_single(target_prompt_embedding)

        else:
            if model_args.target_prompt_embedding_path is None:
                model.update_prefix_weights(shared_params)
            else:
                model.update_prefix_weights(
                    shared_params, target_prompt_embedding)

    # Load linear attention
    if model_args.load_attention is True and model_args.attn_path is not None:
        model.update_attention_weights(torch.load(model_args.attn_path))

    # Load projection-based attentions 
    if model_args.load_attention is True and model_args.attn_path_sub is not None:
        model.update_attention_weights_sub(model_args.attn_path_sub)

    # Load layer norm weights & biases 
    if model_args.load_layer_norm is True and model_args.layer_norm_dir is not None:
        model.update_layer_norm_weights(model_args.layer_norm_dir)

    model.resize_token_embeddings(len(tokenizer))
    model = modify_model_after_init(
        model, training_args, adapter_args, adapter_config)

    # model = torch.nn.DataParallel(model)

    data_args.data_seed += seed_id
    data_args.dataset_name = data_args.task_name
    data_args.eval_dataset_name = data_args.eval_dataset_name
    data_args.test_dataset_name = data_args.test_dataset_name
    data_args.dataset_config_name = data_args.dataset_config_name
    data_args.eval_dataset_config_name = data_args.eval_dataset_config_name
    data_args.test_dataset_config_name = data_args.test_dataset_config_name
    assert len(data_args.dataset_name) == len(data_args.dataset_config_name)
    if data_args.eval_dataset_name is not None:
        assert len(data_args.eval_dataset_name) == len(
            data_args.eval_dataset_config_name)
    if data_args.test_dataset_name is not None:
        assert len(data_args.test_dataset_name) == len(
            data_args.test_dataset_config_name)

    padding = "max_length" if data_args.pad_to_max_length else False

    def preprocess_function(examples, max_target_length, task_id=None):
        model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
                                 padding=padding, truncation=True)
        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                examples['target'], max_length=max_target_length, padding=padding, truncation=True)
        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        model_inputs["extra_fields"] = examples['extra_fields']
        if task_id is not None:
            model_inputs["task_ids"] = [
                task_id for _ in examples['extra_fields']]

        return model_inputs

    column_names = ['source', 'target', 'extra_fields']
    performance_metrics = {}
    if training_args.do_train:
        if args.selfTrain or args.multiGpu:
            print('data_args.dataset_name>>>>', data_args.dataset_name)
            if data_args.train_files is not None:
                train_datasets = [AutoTask.get(dataset_name,
                                               dataset_config_name,
                                               seed=data_args.data_seed).get(
                    split="train_source",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_train_samples, lang=data_args.lang_name, file_name=train_file)
                    for dataset_name, dataset_config_name, train_file
                    in zip(data_args.dataset_name, data_args.dataset_config_name, data_args.train_files)]

            else:
                for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name):
                    print('dataset_name, dataset_config_name', dataset_name, dataset_config_name)
                train_datasets = [AutoTask.get(dataset_name,
                                               dataset_config_name,
                                               seed=data_args.data_seed).get(
                    split="train_source",
                    split_validation_test=training_args.split_validation_test,
                    add_prefix=False if adapter_args.train_task_adapters else True,
                    n_obs=data_args.max_train_samples, lang=data_args.lang_name, file_name=data_args.train_file)
                    for dataset_name, dataset_config_name
                    in zip(data_args.dataset_name, data_args.dataset_config_name)]

            max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
                tokenizer=tokenizer, default_max_length=data_args.max_target_length, )
                for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)]

            for i, train_dataset in enumerate(train_datasets):
                # print('train_dataset>>>', train_dataset, train_dataset['extra_fields'][0])
                if model_args.shared_attn is True:
                    train_datasets[i] = train_datasets[i].map(
                        functools.partial(
                            preprocess_function, max_target_length=max_target_lengths[i], task_id=i),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if train_dataset != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
                else:
                    train_datasets[i] = train_datasets[i].map(
                        functools.partial(preprocess_function,
                                          max_target_length=max_target_lengths[i]),
                        batched=True,
                        num_proc=data_args.preprocessing_num_workers,
                        # if train_dataset != "superglue-record" else column_names+["answers"],
                        remove_columns=column_names,
                        load_from_cache_file=not data_args.overwrite_cache,
                    )
            # print('train_datasets', train_datasets)
            train_dataset = concatenate_datasets(train_datasets)

        else:
            with open(args.save_dir + '/pseudo_label_dataset_' + str(seed_id) + '.pickle', 'rb') as handle:
                train_dataset = pickle.load(handle)
    # print('train_dataset>>>>>', train_dataset)
    # self-train add one train file for few-shot
    if fewshot:
        with open(args.svae_dir + '/dataset_train_' + str(seed_id) + '.pickle', 'rb') as handle:
            train_dataset_fewshot = pickle.load(handle)

    print('data_args.eval_dataset_name', data_args.eval_dataset_name)
    if training_args.do_eval:
        # if ar:
        #     with open(sys.argv[3] + '/dataset_eval_' + str(seed_id) + '.pickle', 'rb') as handle:
        #         eval_datasets = pickle.load(handle)
        # else:
        if data_args.validation_files is not None:
            eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="validation_source",
                split_validation_test=training_args.split_validation_test,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=validation_file)
                for eval_dataset, eval_dataset_config, validation_file in zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name, data_args.validation_files)}
        else:
            eval_datasets = {eval_dataset: AutoTask.get(eval_dataset, eval_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="validation_source",
                split_validation_test=training_args.split_validation_test,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=data_args.max_val_samples, lang=data_args.lang_name, file_name=data_args.validation_file)
                for eval_dataset, eval_dataset_config in zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)}


        max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length)
            for dataset_name, dataset_config_name in zip(data_args.eval_dataset_name, data_args.eval_dataset_config_name)]

        for k, name in enumerate(eval_datasets):
            if model_args.shared_attn is True:
                eval_datasets[name] = eval_datasets[name].map(
                    functools.partial(
                        preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    # if name != "superglue-record" else column_names+["answers"],
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )
            else:
                eval_datasets[name] = eval_datasets[name].map(
                    functools.partial(preprocess_function,
                                      max_target_length=max_target_lengths[k]),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    # if name != "superglue-record" else column_names+["answers"],
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )

    if training_args.do_test:
        if data_args.test_files is not None:
            test_datasets = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="test_source",
                split_validation_test=training_args.split_validation_test,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=test_file)
                for test_dataset, test_dataset_config, test_file in zip(data_args.test_dataset_name, data_args.test_dataset_config_name, data_args.test_files)}
        else:
            test_datasets = {test_dataset: AutoTask.get(test_dataset, test_dataset_config,
                                                        seed=data_args.data_seed).get(
                split="test_source",
                split_validation_test=training_args.split_validation_test,
                add_prefix=False if adapter_args.train_task_adapters else True,
                n_obs=None, lang=data_args.lang_name, file_name=data_args.test_file)
                for test_dataset, test_dataset_config in zip(data_args.test_dataset_name, data_args.test_dataset_config_name)}
        max_target_lengths = [AutoTask.get(dataset_name, dataset_config_name).get_max_target_length(
            tokenizer=tokenizer, default_max_length=data_args.max_target_length)
            for dataset_name, dataset_config_name in zip(data_args.test_dataset_name, data_args.test_dataset_config_name)]
        for k, name in enumerate(test_datasets):
            if model_args.shared_attn is True:
                test_datasets[name] = test_datasets[name].map(
                    functools.partial(
                        preprocess_function, max_target_length=max_target_lengths[k], task_id=k),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )
            else:
                test_datasets[name] = test_datasets[name].map(
                    functools.partial(preprocess_function,
                                      max_target_length=max_target_lengths[k]),
                    batched=True,
                    num_proc=data_args.preprocessing_num_workers,
                    remove_columns=column_names,
                    load_from_cache_file=not data_args.overwrite_cache,
                )

    # Data collator
    label_pad_token_id = - \
        100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = TaskDataCollatorForSeq2Seq(
            tokenizer,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Metric, we assume we have only one training task.
    eval_metrics = [AutoTask.get(dataset_name, dataset_config_name).metric
                    for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)][0]

    # Extracts the extra information needed to evaluate on each dataset.
    # These information are only used in the compute_metrics.
    # We will assume that the test/eval dataloader does not change the order of
    # the data.
    data_info = {"eval": eval_datasets[data_args.eval_dataset_name[0]]['extra_fields']if training_args.do_eval else None,
                 "test": test_datasets[data_args.test_dataset_name[0]]['extra_fields'] if training_args.do_test else None,
                 "train": train_dataset['extra_fields'] if training_args.do_train else None}

    def compute_metrics(eval_preds):
        preds, labels, data_info = eval_preds
        post_processor = AutoPostProcessor.get(data_args.dataset_name[0], tokenizer,
                                               data_args.ignore_pad_token_for_loss)
        decoded_preds, decoded_labels = post_processor.process(
            preds, labels, data_info)
        result = {}
        for metric in eval_metrics:
            result.update(metric(decoded_preds, decoded_labels))
        return result

    # print('trainer >>> train_dataset>>>', train_dataset)
    # print('model_args.attn_learning_rate', model_args.attn_learning_rate)

    if model_args.attn_learning_rate is not None:
        # Initialize a customized optimizer to set a different learning rate for the attention module.
        all_parameters = set(model.parameters())
        attn_params = []
        for name, param in model.named_parameters():
            if name == "encoder.attn_W_up" or name == "encoder.attn_W_down" or name == "encoder.layer_norm":
                attn_params += list(param)
        attn_params = set(attn_params)
        attn_params = set(attn_params)
        non_attn_params = all_parameters - attn_params
        non_attn_params = list(non_attn_params)
        attn_params = list(attn_params)

        optim = AdamW([
            {'params': non_attn_params},
            {'params': attn_params, 'lr': model_args.attn_learning_rate},
        ], lr=training_args.learning_rate,)
        scheduler = get_linear_schedule_with_warmup(
            optim, num_warmup_steps=training_args.warmup_steps, num_training_steps=len(
                train_dataset) * training_args.num_train_epochs // (training_args.gradient_accumulation_steps * training_args.per_device_train_batch_size)
        )

        # Initialize our Trainer
        # print('trainer >>> train_dataset>>>', train_dataset)
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=list(eval_datasets.values())[
                0] if training_args.do_eval else None,
            data_info=data_info,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
            evaluation_metrics=TASK_TO_METRICS[data_args.dataset_name[0]],
            shared=model_args.shared_attn,
            optimizers=(optim, scheduler)
        )

    else:
        # print('trainer >>> train_dataset>>>', train_dataset)
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=list(eval_datasets.values())[
                0] if training_args.do_eval else None,
            data_info=data_info,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
            evaluation_metrics=TASK_TO_METRICS[data_args.dataset_name[0]],
            shared=model_args.shared_attn)

    # Saves training config.
    if trainer.is_world_process_zero():
        os.makedirs(training_args.output_dir, exist_ok=True)
        save_training_config(args.json_file, training_args.output_dir)

    # Training
    # print('trainer >>> train_dataset>>>', train_dataset)
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint

        if training_args.compute_time:
            torch.cuda.synchronize()  # wait for move to complete
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        if training_args.compute_time:
            end.record()
            torch.cuda.synchronize()  # wait for all_reduce to complete
            total_time = start.elapsed_time(end)/(1000*60)
            performance_metrics.update({"total_time in minutes ": total_time})

        if model_args.save_prefix_only:
            for name, param in trainer.model.named_parameters():
                if model_args.attn_prefix_tuning is False and ("prefix_shared" in name or "prefix" in name):
                    shared_params = param
                    torch.save(shared_params, os.path.join(
                        training_args.output_dir, "prefix_embeddings.pt"))
                elif model_args.attn_prefix_tuning is True and name == "prefix_shared":
                    shared_params = param
                    if model_args.shared_attn is True:
                        for i in range(config.num_target):
                            torch.save(shared_params[i], os.path.join(
                                training_args.output_dir, "prefix_embeddings_{}.pt".format(i)))
                    else:
                        torch.save(shared_params, os.path.join(
                            training_args.output_dir, "prefix_embeddings.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_Wa.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_Wa_weights.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_down.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_W_down.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_up.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "attn_W_up.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "layer_norm_weight.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.bias" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        training_args.output_dir, "layer_norm_bias.pt"))
        else:
            trainer.save_model()  # Saves the tokenizer too for easy upload

        train_metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(
                train_dataset)
        )
        train_metrics["train_samples"] = min(
            max_train_samples, len(train_dataset))
        trainer.log_metrics("train", train_metrics)
        trainer.save_metrics("train", train_metrics)

        if not model_args.save_prefix_only:
            trainer.save_state()

    if torch.cuda.is_available() and training_args.compute_memory:
        peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2)/1000
        print(
            "Memory utilization",
            peak_memory,
            "GB"
        )
        performance_metrics.update({"peak_memory": peak_memory})
    if training_args.compute_memory or training_args.compute_time:
        trainer.save_metrics("performance", performance_metrics)

    if fewshot:
        if training_args.do_train:
            print('HERE~')
            trainer = Seq2SeqTrainer(
                model=trainer.model,
                args=training_args,
                train_dataset=train_dataset_fewshot if training_args.do_train else None,
                eval_dataset=list(eval_datasets.values())[
                    0] if training_args.do_eval else None,
                data_info=data_info,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_metrics if training_args.predict_with_generate else None,
                evaluation_metrics=TASK_TO_METRICS[data_args.dataset_name[0]],
                shared=model_args.shared_attn)


            train_result = trainer.train()

            # if model_args.save_prefix_only:
            #     for name, param in trainer.model.named_parameters():
            #         if model_args.attn_prefix_tuning is False and ("prefix_shared" in name or "prefix" in name):
            #             shared_params = param
            #             torch.save(shared_params, os.path.join(
            #                 training_args.output_dir, "prefix_embeddings.pt"))
            #         elif model_args.attn_prefix_tuning is True and name == "prefix_shared":
            #             shared_params = param
            #             if model_args.shared_attn is True:
            #                 for i in range(config.num_target):
            #                     torch.save(shared_params[i], os.path.join(
            #                         training_args.output_dir, "prefix_embeddings_{}.pt".format(i)))
            #             else:
            #                 torch.save(shared_params, os.path.join(
            #                     training_args.output_dir, "prefix_embeddings.pt"))
            #         if model_args.attn_prefix_tuning is True and "encoder.attn_Wa.weight" == name:
            #             attn_weights_params = param
            #             torch.save(attn_weights_params, os.path.join(
            #                 training_args.output_dir, "attn_Wa_weights.pt"))
            #         if model_args.attn_prefix_tuning is True and "encoder.attn_W_down.weight" == name:
            #             attn_weights_params = param
            #             torch.save(attn_weights_params, os.path.join(
            #                 training_args.output_dir, "attn_W_down.pt"))
            #         if model_args.attn_prefix_tuning is True and "encoder.attn_W_up.weight" == name:
            #             attn_weights_params = param
            #             torch.save(attn_weights_params, os.path.join(
            #                 training_args.output_dir, "attn_W_up.pt"))
            #         if model_args.attn_prefix_tuning is True and "encoder.layer_norm.weight" == name:
            #             attn_weights_params = param
            #             torch.save(attn_weights_params, os.path.join(
            #                 training_args.output_dir, "layer_norm_weight.pt"))
            #         if model_args.attn_prefix_tuning is True and "encoder.layer_norm.bias" == name:
            #             attn_weights_params = param
            #             torch.save(attn_weights_params, os.path.join(
            #                 training_args.output_dir, "layer_norm_bias.pt"))
            # else:
            #     trainer.save_model()  # Saves the tokenizer too for easy upload

            train_metrics = train_result.metrics
            max_train_samples = (
                data_args.max_train_samples if data_args.max_train_samples is not None else len(
                    train_dataset)
            )
            train_metrics["train_samples"] = min(
                max_train_samples, len(train_dataset))
            trainer.log_metrics("train", train_metrics)
            trainer.save_metrics("train", train_metrics)

            if not model_args.save_prefix_only:
                trainer.save_state()

        if torch.cuda.is_available() and training_args.compute_memory:
            peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2) / 1000
            print(
                "Memory utilization",
                peak_memory,
                "GB"
            )
            performance_metrics.update({"peak_memory": peak_memory})
        if training_args.compute_memory or training_args.compute_time:
            trainer.save_metrics("performance", performance_metrics)

    # Evaluation
    if model_args.shared_attn is True and model_args.ignore_target is False:
        learned_embeddings = trainer.model.encoder.prefix_emb.clone().detach()
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        if model_args.shared_attn is True:
            for idx, (task, eval_dataset) in enumerate(eval_datasets.items()):
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        else:
            for task, eval_dataset in eval_datasets.items():
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        results['eval'] = metrics

    # prediction!!!
    if training_args.do_eval_predict:
        logger.info("*** Predict ***")
        if model_args.shared_attn is True:
            for idx, (task, eval_dataset) in enumerate(eval_datasets.items()):
                metrics = trainer.evaluate(eval_dataset=eval_dataset,
                                           max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                                           )
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)
        else:
            # for task, eval_dataset in eval_datasets.items():
            print(list(eval_datasets.values())[0]['labels'])
            prediction = trainer.predict(list(eval_datasets.values())[0])
            print('prediction', prediction)
                # metrics = trainer.evaluate(eval_dataset=eval_dataset,
                #                            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams,
                # #                            )
                # trainer.log_metrics("eval", metrics)
                # trainer.save_metrics("eval", metrics)

    # remove checkpoint for the save_prefix_only setting to avoid overly saving models.
    if model_args.save_prefix_only:
        checkpoints = glob.glob(os.path.join(
            training_args.output_dir, "checkpoint-*"))
        for checkpoint_dir in checkpoints:
            # save models
            if not os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
                continue
            checkpoint_model = torch.load(os.path.join(
                os.path.join(checkpoint_dir, "pytorch_model.bin")))
            new_dir = "{}_prompt_only".format(checkpoint_dir)
            if not os.path.exists(new_dir):
                os.mkdir(new_dir)
            for name, param in checkpoint_model.items():
                if model_args.attn_prefix_tuning is False and ("prefix_shared" in name or "prefix" in name):
                    shared_params = param
                    torch.save(shared_params, os.path.join(
                        training_args.output_dir, "prefix_embeddings.pt"))
                elif model_args.attn_prefix_tuning is True and name == "prefix_shared":
                    shared_params = param
                    if model_args.shared_attn is True:
                        for i in range(config.num_target):
                            torch.save(shared_params[i], os.path.join(
                                new_dir, "prefix_embeddings_" + str(seed_id) + "{}.pt".format(i)))
                    else:
                        torch.save(shared_params, os.path.join(
                            new_dir, "prefix_embeddings" + str(seed_id) +" .pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_Wa.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_Wa_weights.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_down.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_W_down.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.attn_W_up.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "attn_W_up.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.weight" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "layer_norm_weight.pt"))
                if model_args.attn_prefix_tuning is True and "encoder.layer_norm.bias" == name:
                    attn_weights_params = param
                    torch.save(attn_weights_params, os.path.join(
                        new_dir, "layer_norm_bias.pt"))
            # after saving prompts, we will remove unnecessary checkpoint dir.
            try:
                shutil.rmtree(checkpoint_dir)
            except OSError as e:
                print("Error: %s : %s" % (checkpoint_dir, e.strerror))

    # Test
    if training_args.do_test:
        logger.info("*** Test ***")
        if model_args.shared_attn is True:
            for idx, (task, test_dataset) in enumerate(test_datasets.items()):
                trainer.model.encoder.prefix_emb[0].data = learned_embeddings[idx]
                metrics = trainer.evaluate(eval_dataset=test_dataset,
                                           max_length=data_args.test_max_target_length, num_beams=data_args.num_beams,
                                           metric_key_prefix="test"
                                           )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)

        else:
            for task, test_dataset in test_datasets.items():
                metrics = trainer.evaluate(eval_dataset=test_dataset,
                                           max_length=data_args.test_max_target_length, num_beams=data_args.num_beams,
                                           metric_key_prefix="test"
                                           )
                trainer.log_metrics("test", metrics)
                trainer.save_metrics("test", metrics)
        results['test'] = metrics

    return results

def _mp_fn(index):
    """Per-process entry point that simply forwards to ``main``.

    Args:
        index: Accepted for the caller's benefit but never read here.

    NOTE(review): presumably the hook expected by a multi-process/TPU
    launcher (e.g. ``xla_spawn``) -- confirm. ``main`` is invoked without any
    arguments here, whereas the ``__main__`` block always passes a run index;
    verify that ``main``'s defaults allow this call.
    """
    del index  # intentionally unused
    main()


if __name__ == "__main__":
    # Command-line entry point.
    #
    # Parses the script-level options, picks one of four run modes
    # (--selfTrain, --multiGpu, --test_only, or the default multi-run loop)
    # and appends the stringified result of each recorded `main` call as one
    # line of <save_dir>/results.txt.
    parser = argparse.ArgumentParser("")
    parser.add_argument("--dataset_target", type=str, default=None)
    parser.add_argument("--dataset_source", type=str, default=None)
    # Every run mode writes <save_dir>/results.txt, so the option is
    # mandatory; omitting it previously crashed later with a TypeError
    # (None + str) instead of a usage message.
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--json_file", type=str)
    parser.add_argument("--fewshot", action="store_true")
    parser.add_argument("--naked", action="store_true")
    parser.add_argument("--random_prefix", action="store_true")
    parser.add_argument("--fewshot_num", type=int)
    parser.add_argument("--attn_method", type=str)
    parser.add_argument("--target_prompt_embedding_path", type=str, default=None)
    parser.add_argument("--attn_path_sub", type=str, default=None)
    parser.add_argument("--layer_norm_dir", type=str, default=None)
    parser.add_argument("--checkpoint_dir", type=str, default=None)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--multiGpu", action="store_true")
    parser.add_argument("--selfTrain", action="store_true")
    parser.add_argument("--model_name", type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--test_only", action="store_true")
    parser.add_argument("--prompt_embedding_path_prefix", type=str, default=None)
    parser.add_argument("--prompt_embedding_path_model", type=str, default='t5_base')

    # `args` is deliberately module-level: `main` reads it (e.g.
    # args.save_dir, args.json_file) in addition to what is passed explicitly.
    args = parser.parse_args()

    # Make sure the results file can be created on a fresh run.
    os.makedirs(args.save_dir, exist_ok=True)
    results_path = os.path.join(args.save_dir, 'results.txt')

    def _record_results(handle, results):
        # Persist one result line immediately (flush) so partial progress
        # survives a crash of a later run, and echo it to stdout.
        handle.write(str(results) + '\n')
        handle.flush()
        print('results --->', results)

    if args.test_only:
        print('args.save_dir', args.save_dir)

    with open(results_path, 'w') as f:
        if args.selfTrain:  # self_train
            # 20 runs; each run index calls `main` twice: first without
            # `fewshot`, then with fewshot=32. Only the second call's result
            # is recorded.
            for i in range(20):
                main(i, output_dir=args.save_dir)
                _record_results(
                    f, main(i, output_dir=args.save_dir, fewshot=32))

        elif args.multiGpu:  # ppt train source with multiple gpu
            # Single run with the fixed run index 4.
            _record_results(f, main(4, output_dir=args.save_dir, args=args))

        elif args.test_only:  # single test-only run
            # Single run with the fixed run index 3.
            _record_results(f, main(3, output_dir=args.save_dir, args=args))

        else:
            # Default: five runs with indices 0-4, using `main`'s defaults
            # for everything else.
            for i in range(5):
                _record_results(f, main(i))
