import os
import sys
import json
import torch
import random
import logging
import dataclasses
import numpy as np
import pickle as pkl
import seaborn as sns
import matplotlib.pyplot as plt

from copy import deepcopy
from typing import Dict, Optional
from dataclasses import dataclass, field
from collections import defaultdict
from torch.utils.data.dataset import Dataset

from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    BertTokenizer,
    DistilBertTokenizer,
)
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers import AutoModelForSequenceClassification

from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DataCollator, DefaultDataCollator

from lang_exps.data.dataset.data import CLDataset
from lang_exps.data.processors.data import output_modes, tasks_num_labels, data_dir
from lang_exps.data.metrics import task_metrics
from lang_exps.data.util import never_split
from lang_exps.data.dataset.data import CLDataTrainingArguments as DataTrainingArguments
from lang_exps.data.metrics import compute_metrics as cl_compute_metrics

from lang_exps.models.distilbert import DistilBertForSequenceClassification
from lang_exps.models.bert import BertForSequenceClassification
from lang_exps.models.roberta import RobertaForSequenceClassification

from lang_exps.trainer.exemplars import ExemplarHandler
from lang_exps.trainer.trainer_replay import CLTrainingArguments
from lang_exps.trainer.trainer_replay import TrainerReplay as Trainer

from lang_exps.common.util import prepare_global_logging
from tqdm.std import tqdm

# Module-level logger; root handlers/format are configured via logging.basicConfig in setup_logging.
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Model- and analysis-related command-line arguments.

    Parsed by ``HfArgumentParser`` together with the data and CL-training
    argument dataclasses (see ``process_args``, which also attaches a derived
    ``use_adapter_fn`` attribute to the parsed instance at runtime).
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    # "distilbert" and "roberta" select dedicated model classes; any other value
    # uses the BERT implementation. "adapter"/"moa" enable ``use_adapter_fn``.
    model_type: str = field(
        default="vanilla",
        metadata={
            "help": "Choices: ['vanilla', 'adapter', 'moa', 'distilbert', 'roberta']"
        },
    )
    # NOTE(review): this script keys task sequences on ``data_args.task_name``;
    # ``task_seq`` is not read by the code in this file.
    task_seq: str = field(
        default="seq1", metadata={"help": "sequence of tasks from predefined sequences"}
    )
    output_hidden_states: bool = field(
        default=False, metadata={"help": "Whether output hidden states."}
    )
    # Freezing / re-initialisation variant applied after the model is loaded;
    # "none" leaves the model untouched.
    model_type2: str = field(
        default="none",
        metadata={
            "help": "Choices: ['linear', 'rndm_pooler_only', 'init_pooler_only', 'rndm_init', 'xyz_init']"
        },
    )
    pooler_type: str = field(
        default="first_token",
        metadata={"help": "Choices: ['first_token', 'average', 'max']"},
    )
    # NOTE(review): not read by the code in this file; presumably consumed by the
    # model classes or other scripts -- verify.
    dropout_prob: float = field(
        default=0.1, metadata={"help": "Hidden and attention dropout probability."}
    )
    # The analysis options below are not read by the code in this file;
    # presumably they are consumed by separate analysis tooling -- verify.
    analysis_split: str = field(
        default="validation",
        metadata={"help": "Dataset split to use for analysis like loss contours!"},
    )
    loss_contour: bool = field(
        default=False, metadata={"help": "Get loss contours data."}
    )
    lmi: bool = field(
        default=False, metadata={"help": "Get linear model interpolation data."}
    )
    evaluate_sharpness: bool = field(
        default=False, metadata={"help": "Evaluate sharpness"}
    )
    p_dim: int = field(default=0, metadata={"help": "Project to p-dimensions"})
    lmi_start_task_idx: int = field(
        default=0, metadata={"help": "w(start_idx) -> w(start_idx+1),..."}
    )
    loss_contour_start_task_idx: int = field(
        default=0, metadata={"help": "w(start_idx), w(start_idx+1), w(start_idx+2)"}
    )
    analysis: str = field(
        default="contour", metadata={"help": "Choices: 'contour', 'lmi', 'sharpness'"}
    )
    output_file: str = field(
        default="./out.json", metadata={"help": "Output file where data will be dumped"}
    )
    analysis_start_task_idx: int = field(
        default=0,
        metadata={
            "help": "For contour: w(start_idx), w(start_idx+1), w(start_idx+2). lmi: w(start_idx) -> w(start_idx+1),..."
        },
    )
    # When set, ``main`` only prints per-tensor weight statistics and exits.
    visualize_weights: bool = field(
        default=False, metadata={"help": "Visualize network weights."}
    )


# Named task orders for the continual-learning runs. ``data_args.task_name``
# selects one of these keys; tasks are trained/evaluated in list order.
task_sequences = {
    # 15-task sequences: the same 15 tasks in five different orders.
    "seq41": [
        "event",
        "boolq",
        "argument",
        "yelp",
        "dmarker",
        "sst-2",
        "qqp",
        "yahooqa",
        "qnli",
        "rocstory",
        "mnli",
        "scitail",
        "cola",
        "drelation",
        "emotion",
    ],
    "seq42": [
        "cola",
        "qqp",
        "mnli",
        "qnli",
        "emotion",
        "sst-2",
        "boolq",
        "event",
        "argument",
        "scitail",
        "rocstory",
        "yelp",
        "drelation",
        "yahooqa",
        "dmarker",
    ],
    "seq43": [
        "scitail",
        "boolq",
        "sst-2",
        "argument",
        "dmarker",
        "yahooqa",
        "qnli",
        "rocstory",
        "drelation",
        "emotion",
        "event",
        "mnli",
        "qqp",
        "cola",
        "yelp",
    ],
    "seq44": [
        "dmarker",
        "qnli",
        "cola",
        "yahooqa",
        "argument",
        "scitail",
        "drelation",
        "emotion",
        "event",
        "rocstory",
        "qqp",
        "yelp",
        "mnli",
        "boolq",
        "sst-2",
    ],
    "seq45": [
        "emotion",
        "sst-2",
        "rocstory",
        "yahooqa",
        "argument",
        "mnli",
        "cola",
        "dmarker",
        "qqp",
        "qnli",
        "event",
        "drelation",
        "scitail",
        "yelp",
        "boolq",
    ],
    # 5-task sequences: five orders of the same five tasks.
    "seq51": ["yelp", "agnews", "dbpedia", "amzn", "yahooqa"],
    "seq52": ["dbpedia", "yahooqa", "agnews", "amzn", "yelp"],
    "seq53": ["yelp", "yahooqa", "amzn", "dbpedia", "agnews"],
    "seq54": ["agnews", "yelp", "amzn", "yahooqa", "dbpedia"],
    "seq55": ["yahooqa", "yelp", "dbpedia", "agnews", "amzn"],
    # Five orders of "yahooqa1".."yahooqa5" -- presumably disjoint splits of
    # yahooqa defined in the data processors; verify there.
    "seq71": ["yahooqa1", "yahooqa2", "yahooqa3", "yahooqa4", "yahooqa5"],
    "seq72": ["yahooqa3", "yahooqa5", "yahooqa2", "yahooqa4", "yahooqa1"],
    "seq73": ["yahooqa1", "yahooqa5", "yahooqa4", "yahooqa3", "yahooqa2"],
    "seq74": ["yahooqa2", "yahooqa1", "yahooqa4", "yahooqa5", "yahooqa3"],
    "seq75": ["yahooqa5", "yahooqa1", "yahooqa3", "yahooqa2", "yahooqa4"],
}


def process_args():
    """Parse the command line and derive the run-specific output directory.

    The model, data and CL-training argument dataclasses are parsed with
    ``HfArgumentParser``. Two derived values are stored back on the parsed
    objects:

    * ``model_args.use_adapter_fn`` -- True for the "moa"/"adapter" model types;
    * ``training_args.output_dir`` -- extended with
      ``<model_type>/<task_name>/<size>_<epochs>_<bs>_<lr>_<write%>_<replay%>_<seed>``
      followed by one suffix per non-default experiment switch.

    When ``sparse_replay == "no"`` the replay rate is forced to 0.0 before it
    is encoded in the directory name.

    Returns:
        tuple: ``(data_args, model_args, training_args)`` -- note the order
        differs from the parser's (model, data, training) order.
    """
    arg_parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, CLTrainingArguments)
    )
    model_args, data_args, training_args = arg_parser.parse_args_into_dataclasses()

    model_size = "base"
    if "large" in model_args.model_name_or_path:
        model_size = "large"

    model_args.use_adapter_fn = model_args.model_type in ("moa", "adapter")

    # No sparse replay means nothing is replayed, so record a zero rate.
    if training_args.sparse_replay == "no":
        training_args.replay_rate = 0.0

    hparam_tag = "{}_{}_{}_{}_{}_{}_{}".format(
        model_size,
        training_args.num_train_epochs,
        training_args.per_gpu_train_batch_size,
        training_args.learning_rate,
        int(100 * training_args.write_rate),
        int(100 * training_args.replay_rate),
        training_args.seed,
    )

    # Each enabled / non-default switch contributes one suffix, in this order.
    suffixes = ["_" + str(data_args.task_type)]
    if training_args.max_examples_per_class != -1:
        suffixes.append("_" + str(training_args.max_examples_per_class))
    if training_args.supconloss:
        suffixes.append("_supcon")
    if training_args.replay_tiny:
        suffixes.append("_tiny")
    if training_args.enable_mtl_baseline:
        suffixes.append("_smtl_baseline")
    if training_args.freeze_embeddings:
        suffixes.append("_frznemb")
    if model_args.model_type2 != "none":
        suffixes.append("_" + model_args.model_type2)
    if training_args.write_strategy != "random":
        suffixes.append("_" + training_args.write_strategy)
    if training_args.skip_replay:
        suffixes.append("_skipped_replay")
    if training_args.replay_only_fraction == "yes":
        suffixes.append("_replay1epoch")
    if model_args.pooler_type != "first_token":
        suffixes.append("_" + model_args.pooler_type)
    if training_args.ewc:
        suffixes.append(
            "_ewc_" + str(training_args.lmbda) + "_" + str(training_args.gamma)
        )
    if training_args.enable_l2:
        suffixes.append("_l2_" + str(training_args.l2_weight))
    if training_args.optimizer != "adam":
        suffixes.append(f"_{training_args.optimizer}_{training_args.rho}")

    run_dir = os.path.join(model_args.model_type, data_args.task_name, hparam_tag)
    run_dir += "".join(suffixes)
    training_args.output_dir = os.path.join(training_args.output_dir, run_dir)

    return data_args, model_args, training_args


def setup_logging(
    training_args,
    tb_logname="train",
    stdout_file_name="stdout.log",
    stderr_file_name="stderr.log",
):
    """Configure the TensorBoard log directory and Python logging for this run.

    Sets ``training_args.logging_dir`` to ``<output_dir>/log/<tb_logname>``.
    Outside debug mode, stdout/stderr are handed to ``prepare_global_logging``
    with the given file names. The root logger is configured at INFO level on
    the main process (``local_rank`` -1 or 0) and WARN elsewhere, and a summary
    of the device/distribution setup is emitted as a warning.

    Returns:
        The same ``training_args`` object, mutated in place.
    """
    training_args.logging_dir = os.path.join(
        training_args.output_dir, "log", tb_logname
    )

    if training_args.debug is False or not training_args.debug:
        prepare_global_logging(
            args=training_args,
            stdout_file_name=stdout_file_name,
            stderr_file_name=stderr_file_name,
        )

    is_main_process = training_args.local_rank in [-1, 0]
    log_level = logging.INFO if is_main_process else logging.WARN
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=log_level,
    )

    is_distributed = bool(training_args.local_rank != -1)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        is_distributed,
        training_args.fp16,
    )

    return training_args


def get_label_map(task_seq):
    """Assign every (task, local label) pair a unique global label index.

    Tasks of ``task_sequences[task_seq]`` are visited in sorted-name order
    (lower-cased), and each task's ``tasks_num_labels`` local labels are
    mapped to consecutive global indices.

    Returns:
        tuple: ``(label_map, total_num_labels)`` where
        ``label_map[task][local_label] -> global_label``.

    Raises:
        KeyError: if ``task_seq`` is not a known sequence.
    """
    if task_seq not in task_sequences:
        raise KeyError(task_seq)

    label_map = {}
    next_global = 0
    for task_key in map(str.lower, sorted(task_sequences[task_seq])):
        n_local = tasks_num_labels[task_key]
        label_map[task_key] = {
            local: next_global + local for local in range(n_local)
        }
        # Advance by the number of labels actually assigned to this task.
        next_global += len(label_map[task_key])

    return label_map, next_global


def get_datasets(data_args, training_args, tokenizer):
    """Build train / validation / test ``CLDataset`` objects for each task.

    ``data_args.task_name`` names a sequence in ``task_sequences``. Every task
    gets its own deep copy of ``data_args`` with ``data_dir`` joined to the
    task's sub-directory (from ``data_dir``) and ``task_name`` set to the task.
    A split is only built when requested -- train: ``do_train`` or
    ``do_replay_train``; validation: ``do_val``; test: ``do_eval`` -- and is
    ``None`` otherwise.

    Returns:
        tuple: three dicts ``(train, val, test)`` keyed by task name.
    """
    predefined_subsampling = not data_args.disable_predefined_subsampling
    need_train = training_args.do_train or training_args.do_replay_train

    def build(task_args, **split_flag):
        # ``split_flag`` is exactly one of train=/validation=/evaluate=True.
        return CLDataset(
            task_args,
            tokenizer=tokenizer,
            local_rank=training_args.local_rank,
            use_predefined_subsampling=predefined_subsampling,
            **split_flag,
        )

    train_sets, val_sets, test_sets = {}, {}, {}
    for name in task_sequences[data_args.task_name]:
        task_args = deepcopy(data_args)
        task_args.data_dir = os.path.join(data_args.data_dir, data_dir[name])
        task_args.task_name = name

        train_sets[name] = build(task_args, train=True) if need_train else None
        val_sets[name] = (
            build(task_args, validation=True) if training_args.do_val else None
        )
        test_sets[name] = (
            build(task_args, evaluate=True) if training_args.do_eval else None
        )

    return train_sets, val_sets, test_sets


def train_evaluate_task(
    task_name,
    model,
    training_args,
    model_args,
    train_dataset,
    val_dataset,
    test_dataset,
    train=False,
    replay_train=False,
    exemplarHandler=None,
    est_fisher_info=None,
    est_mean_prev_task=None,
    task_idx=0,
    k=32,
):
    """Train ``model`` on one task, or evaluate it on ``test_dataset``.

    A ``TrainerReplay`` (imported here as ``Trainer``) is built for the task
    with the given datasets, exemplar memory and EWC statistics.

    Args:
        task_name: key into ``output_modes`` / ``task_metrics``; ``None`` falls
            back to "classification" with the "acc" metric.
        model, training_args, model_args: forwarded to the trainer
            (``model_args.use_adapter_fn`` becomes ``adapter_fn``).
        train_dataset, val_dataset, test_dataset: datasets for the trainer.
        train: if True, train and return the trainer; otherwise run
            ``trainer.predict(test_dataset)`` and return its metrics.
        exemplarHandler, est_fisher_info, est_mean_prev_task, task_idx:
            forwarded to the trainer unchanged.
        replay_train, k: accepted for interface compatibility; unused here.

    Returns:
        The trainer when ``train`` is True, else the prediction metrics dict.

    Raises:
        ValueError: from the metrics callback, if the task's output mode is
            neither "classification" nor "regression".
    """
    if task_name is None:
        output_mode = "classification"
        task_metric = "acc"
    else:
        output_mode = output_modes[task_name]
        task_metric = task_metrics[task_name]

    def compute_metrics(p: EvalPrediction) -> Dict:
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        else:
            # Without this branch ``preds`` would be unbound (UnboundLocalError).
            raise ValueError(
                f"Unsupported output mode {output_mode!r} for task {task_name!r}"
            )
        return cl_compute_metrics(task_name, preds, p.label_ids, p.guids)

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        task=task_name,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        compute_metrics=compute_metrics,
        adapter_fn=model_args.use_adapter_fn,
        exemplarHandler=exemplarHandler,
        task_metric=task_metric,
        est_fisher_info=est_fisher_info,
        est_mean_prev_task=est_mean_prev_task,
        task_idx=task_idx,
    )

    if train:
        logger.info("Training for task : %s", task_name)
        trainer.train()
        return trainer

    logger.info("Evaluating for task : %s", task_name)
    result = trainer.predict(test_dataset=test_dataset)
    return result.metrics


def get_flatten_params(model_params, param_name=None):
    """Concatenate (a subset of) parameter arrays into one flat 1-D array.

    Args:
        model_params: mapping from parameter name to an array with a
            ``flatten()`` method (e.g. numpy arrays built from
            ``model.named_parameters()``); arrays are taken in the mapping's
            iteration order.
        param_name: optional substring filter -- only parameters whose name
            contains it are included. ``None`` keeps every parameter.

    Returns:
        np.ndarray: 1-D concatenation of the selected arrays. If nothing is
        selected (empty mapping or no name matches), an empty array is
        returned instead of letting ``np.hstack([])`` raise ``ValueError``.
    """
    selected = [
        values.flatten()
        for name, values in model_params.items()
        if param_name is None or param_name in name
    ]
    if not selected:
        return np.empty(0)
    return np.hstack(selected)


def visualize_weights():
    """Print sign-split summary statistics for each weight tensor of the model.

    Arguments are parsed from the command line via ``process_args``. The model
    named by ``model_name_or_path`` is loaded; for distilbert, ``model_type2``
    "rndm_init"/"xyz_init" re-initialises it first.

    For every parameter whose name does not contain "layer_norm", "LayerNorm",
    "pre_classifier" or "bias", one line is printed. It gives the mean, std,
    min and max of the positive and the negative entries separately. For
    negative entries, "min"/"max" refer to magnitude: "min" prints the value
    closest to zero and "max" the most negative value. A sign with no entries
    is reported as ``n/a`` rather than crashing ``np.min``/``np.max``.
    """
    _, model_args, training_args = process_args()

    # Set seed
    set_seed(training_args.seed)

    model_name_or_path = model_args.model_name_or_path
    # Initialize model config
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # "distilbert" / "roberta" have dedicated classes; everything else is BERT.
    if model_args.model_type == "distilbert":
        model_cls = DistilBertForSequenceClassification
    elif model_args.model_type == "roberta":
        model_cls = RobertaForSequenceClassification
    else:
        model_cls = BertForSequenceClassification
    model = model_cls.from_pretrained(
        model_name_or_path,
        from_tf=False,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    if model_args.model_type == "distilbert" and model_args.model_type2 == "rndm_init":
        print("Initializing with random weights!")
        model.rndm_init_weights()
    elif model_args.model_type == "distilbert" and model_args.model_type2 == "xyz_init":
        print("Initializing with xyz random weights!")
        model.xyz_init_weights()

    def _stat(reduce_fn, values):
        # np.min/np.max raise on zero-size arrays (np.mean/np.std warn and give
        # nan), so a tensor with no entries of one sign reports "n/a".
        if values.size == 0:
            return "n/a"
        return np.round(reduce_fn(values), 3)

    skip_markers = ("layer_norm", "LayerNorm", "pre_classifier", "bias")
    for param_name, param in model.named_parameters():
        # Filter first so skipped tensors are never copied to numpy.
        if any(marker in param_name for marker in skip_markers):
            continue

        flat = param.data.cpu().numpy().flatten()
        pos_params = flat[flat > 0]
        neg_params = flat[flat < 0]

        print(
            f"{param_name} : "
            f"(pos|neg)mean {_stat(np.mean, pos_params)}|{_stat(np.mean, neg_params)} || "
            f"(pos|neg)std {_stat(np.std, pos_params)}|{_stat(np.std, neg_params)} || "
            f"(pos|neg)min {_stat(np.min, pos_params)}|{_stat(np.max, neg_params)} | "
            f"(pos|neg)max {_stat(np.max, pos_params)}|{_stat(np.min, neg_params)} "
        )


def _load_model(model_args, model_name_or_path, config):
    """Load the project's sequence-classification model for ``model_args.model_type``.

    "distilbert" and "roberta" map to their dedicated classes; any other model
    type uses the BERT implementation.
    """
    if model_args.model_type == "distilbert":
        model_cls = DistilBertForSequenceClassification
    elif model_args.model_type == "roberta":
        model_cls = RobertaForSequenceClassification
    else:
        model_cls = BertForSequenceClassification
    return model_cls.from_pretrained(
        model_name_or_path,
        from_tf=False,
        config=config,
        cache_dir=model_args.cache_dir,
    )


def _write_json(obj, path):
    """Serialise ``obj`` to ``path`` as 4-space-indented JSON, overwriting the file."""
    with open(path, "w") as outf:
        outf.write(json.dumps(obj, indent=4))


def _evaluate_seen_tasks(
    model,
    tasks,
    last_task_idx,
    task_handle,
    datasets,
    split_results,
    training_args,
    model_args,
    output_eval_dir,
    title,
    file_prefix,
    json_name,
):
    """Evaluate ``model`` on ``tasks[0..last_task_idx]`` and persist the metrics.

    For each task seen so far:
    * the metrics are stored in ``split_results[task_handle][task]``;
    * they are logged and appended, minus keys containing "gates", to
      ``<output_eval_dir>/<file_prefix>_<task>.txt``;
    * the whole ``split_results`` is rewritten to ``<output_eval_dir>/<json_name>``.
    """
    for eval_task_name in tasks[: last_task_idx + 1]:
        result = train_evaluate_task(
            task_name=eval_task_name,
            model=model,
            training_args=training_args,
            model_args=model_args,
            train_dataset=None,
            val_dataset=None,
            test_dataset=datasets[eval_task_name],
            train=False,
        )
        split_results.setdefault(task_handle, {})[eval_task_name] = result

        header = "***** {} results for {} task : {} *****".format(
            title, eval_task_name, task_handle
        )
        output_eval_file = os.path.join(
            output_eval_dir, f"{file_prefix}_{eval_task_name}.txt"
        )
        with open(output_eval_file, "a+") as writer:
            logger.info(header)
            writer.write(header + "\n")
            for key, value in result.items():
                if "gates" in key:
                    continue
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        # Keep overwriting results with updated ones
        _write_json(split_results, os.path.join(output_eval_dir, json_name))


def main():
    """Train and/or evaluate a model sequentially over a task sequence.

    Arguments come from ``process_args``. If ``model_args.visualize_weights``
    is set, only weight statistics are printed and the process exits.

    Otherwise, after logging, seed, tokenizer, datasets, config and model setup
    (including any ``model_type2`` freezing / re-initialisation), each task of
    ``task_sequences[data_args.task_name]`` is processed in order:

    * ``training_args.do_train``:
      - train on the task;
      - optionally estimate EWC Fisher statistics;
      - save the exemplar memory;
      - save a checkpoint under ``<output_dir>/<task_seq>-after-<task>``.
      After the last task the model is also saved to ``output_dir``.
    * ``training_args.do_val`` / ``do_eval``: evaluate the current model on the
      validation / test split of every task seen so far. Results are written
      under ``<output_dir>/eval_results``.

    Raises:
        ValueError: if training would write into a non-empty ``output_dir``
            without ``overwrite_output_dir``.
    """
    data_args, model_args, training_args = process_args()

    if model_args.visualize_weights:
        visualize_weights()
        sys.exit()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    training_args = setup_logging(training_args)

    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    # Set seed
    set_seed(training_args.seed)

    # Training starts from the pre-trained weights; evaluation-only runs reload
    # the final model that a training run saved into output_dir.
    if training_args.do_train:
        model_name_or_path = model_args.model_name_or_path
    else:
        model_name_or_path = training_args.output_dir

    if model_args.model_type == "roberta":
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name
            if model_args.tokenizer_name
            else model_name_or_path,
            cache_dir=model_args.cache_dir,
        )
    else:
        # Non-RoBERTa models use the bert-base-uncased vocabulary unless a
        # tokenizer is given; the project's ``never_split`` tokens are passed on.
        tokenizer = BertTokenizer.from_pretrained(
            model_args.tokenizer_name
            if model_args.tokenizer_name
            else "bert-base-uncased",
            cache_dir=model_args.cache_dir,
            never_split=never_split,
        )

    # Get training and evaluation datasets
    train_datasets, val_datasets, test_datasets = get_datasets(
        data_args=data_args, training_args=training_args, tokenizer=tokenizer
    )

    tasks = task_sequences[data_args.task_name]
    n_tasks = len(tasks)
    num_labels = [tasks_num_labels[task] for task in tasks]

    # Initialize model config
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    if training_args.do_train:
        # Store per-task label counts and the task order in the config so the
        # saved checkpoints carry the CL setup.
        config.update(
            config_dict={
                "n_labels": num_labels,
                "finetuning_task": tasks,
                "output_hidden_states": model_args.output_hidden_states,
            }
        )
        logger.info("Updated model config with CL-setup configurations %s", str(config))

    # Initialize model either for training (pre-trained distilbert/bert/roberta) or evaluation
    model = _load_model(model_args, model_name_or_path, config)

    if model_args.model_type2 == "linear":
        logger.info("Freezing whole BERT model except ``linear'' classifier!")
        for n, p in model.named_parameters():
            if "classifier" not in n:
                p.requires_grad = False

    elif model_args.model_type2 == "rndm_pooler_only":
        logger.info("Freezing whole BERT model except ``pooler'' layer!")
        logger.info("Random initialization of the ``pooler'' layer!")
        for n, p in model.named_parameters():
            if "classifier" not in n and "pooler" not in n:
                p.requires_grad = False

        # Randomly initializing the pooler layer (accesses ``model.bert``, so
        # this variant assumes a BERT model).
        model.bert.pooler.dense.weight.data.normal_(
            mean=0.0, std=model.bert.config.initializer_range
        )
        model.bert.pooler.dense.bias.data.zero_()
        for p in model.bert.pooler.parameters():
            p.requires_grad = True

    elif model_args.model_type2 == "init_pooler_only":
        logger.info("Freezing whole BERT model except ``pooler'' layer!")
        logger.info("Pre-trained initialization of the ``pooler'' layer!")
        for n, p in model.named_parameters():
            if "classifier" not in n and "pooler" not in n:
                p.requires_grad = False

    if training_args.freeze_embeddings:
        logger.info("Word/Position/Token-type embeddings freezed!")
        for n, p in model.named_parameters():
            if (
                "embeddings" in n
                or "word_embeddings" in n
                or "position_embeddings" in n
                or "token_type_embeddings" in n
            ):
                p.requires_grad = False

    if (
        model_args.model_type == "distilbert"
        and model_args.model_type2 == "rndm_init"
        and training_args.do_train
    ):
        logger.info("Initializing with random weights!")
        model.rndm_init_weights()
    elif (
        model_args.model_type == "distilbert"
        and model_args.model_type2 == "xyz_init"
        and training_args.do_train
    ):
        logger.info("Initializing with xyz random weights!")
        model.xyz_init_weights()

    # Only the multi-task baseline restarts every task from the initial
    # weights, so the full model copy is taken only in that case.
    init_model = deepcopy(model) if training_args.enable_mtl_baseline else None

    results = {}
    val_results = {}
    # Both validation and test evaluation write here, so create it if either
    # is enabled (it must exist for --do_val alone, too).
    output_eval_dir = os.path.join(training_args.output_dir, "eval_results")
    if training_args.do_eval or training_args.do_val:
        os.makedirs(output_eval_dir, exist_ok=True)

    if training_args.do_train:
        init_handle = "{}-after-{}".format(data_args.task_name, "init")
        init_checkpoint_dir = os.path.join(training_args.output_dir, init_handle)
        os.makedirs(init_checkpoint_dir, exist_ok=True)
        model.save_pretrained(init_checkpoint_dir)

    # Initialize exemplar handler to read/write exemplars to the memory for replay
    exemplarHandler = ExemplarHandler()

    est_fisher_info = None
    est_mean_prev_task = None

    for idx, task_name in enumerate(tasks):

        task_handle = "{}-after-{}".format(data_args.task_name, task_name)
        task_checkpoint_dir = os.path.join(training_args.output_dir, task_handle)

        # Training
        if training_args.do_train:

            if training_args.enable_mtl_baseline:
                logger.info("Setting the model to the init-model (pre-trained BERT)")
                model = deepcopy(init_model)

            trainer = train_evaluate_task(
                task_name=task_name,
                model=model,
                training_args=training_args,
                model_args=model_args,
                train_dataset=train_datasets[task_name],
                val_dataset=val_datasets[task_name],
                test_dataset=test_datasets[task_name],
                train=True,
                exemplarHandler=exemplarHandler,
                est_fisher_info=est_fisher_info,
                est_mean_prev_task=est_mean_prev_task,
                task_idx=idx,
            )
            model = trainer.update_to_best_model()

            if training_args.ewc:
                est_mean_prev_task, est_fisher_info = trainer.estimate_fisher(
                    model=model
                )

            exemplarHandler = trainer.exemplarHandler
            exemplarHandler.save_to_file(cached_exemplar_dir=training_args.output_dir)

            trainer.save_model(output_dir=task_checkpoint_dir, save_bertmodel=False)
            config.save_pretrained(save_directory=task_checkpoint_dir)
            if trainer.is_world_master():
                tokenizer.save_pretrained(training_args.output_dir)

            with open(os.path.join(task_checkpoint_dir, "all_logs.json"), "w") as log_file:
                json.dump(trainer.all_logs, log_file)

            if idx == n_tasks - 1:
                trainer.save_model(
                    output_dir=training_args.output_dir, save_bertmodel=False
                )
                config.save_pretrained(save_directory=training_args.output_dir)

        # Evaluation - validation split
        if training_args.do_val:
            _evaluate_seen_tasks(
                model=model,
                tasks=tasks,
                last_task_idx=idx,
                task_handle=task_handle,
                datasets=val_datasets,
                split_results=val_results,
                training_args=training_args,
                model_args=model_args,
                output_eval_dir=output_eval_dir,
                title="Val",
                file_prefix="val_results",
                json_name="val_results.json",
            )

        # Evaluation - Test set
        if training_args.do_eval:
            _evaluate_seen_tasks(
                model=model,
                tasks=tasks,
                last_task_idx=idx,
                task_handle=task_handle,
                datasets=test_datasets,
                split_results=results,
                training_args=training_args,
                model_args=model_args,
                output_eval_dir=output_eval_dir,
                title="Eval",
                file_prefix="eval_results",
                json_name="results.json",
            )

    if training_args.do_val:
        _write_json(val_results, os.path.join(output_eval_dir, "val_results.json"))

    if training_args.do_eval:
        _write_json(results, os.path.join(output_eval_dir, "results.json"))


if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly, not on import.
    main()
