""" Data processors and helpers """

import os
import csv
import json
import torch
import random
import logging
import itertools
import collections
import numpy as np
import pandas as pd
import pickle as pkl

from enum import Enum
from nlp import load_dataset
from sklearn.model_selection import KFold
from typing import List, Optional, Union

from transformers import PreTrainedTokenizer, BertTokenizer
from lang_exps.data.processors.utils import (
    DataProcessor,
    InputExample,
    InputFeatures,
    get_splits,
)


logger = logging.getLogger(__name__)


def convert_examples_to_features(
    examples: List[InputExample],
    tokenizer: PreTrainedTokenizer,
    max_length: Optional[int] = None,
    task=None,
    label_list=None,
    output_mode=None,
):
    """
    Tokenize a list of ``InputExample`` objects into a list of ``InputFeatures``.

    Args:
        examples: List of ``InputExample`` objects to convert.
        tokenizer: Instance of a tokenizer that will tokenize the examples.
        max_length: Maximum sequence length used for padding/truncation.
            Defaults to the tokenizer's ``max_len``.
        task: Optional task name. When given, it is used to look up a default
            label list and output mode for any of ``label_list`` /
            ``output_mode`` left as ``None``.
        label_list: List of labels. Can be obtained from the processor using
            the ``processor.get_labels()`` method.
        output_mode: String indicating the output mode. Either ``regression``
            or ``classification``.

    Returns:
        A list of task-specific ``InputFeatures`` (one per example, in input
        order) which can be fed to the model.
    """
    return _convert_examples_to_features(
        examples,
        tokenizer,
        max_length=max_length,
        task=task,
        label_list=label_list,
        output_mode=output_mode,
    )


def _convert_examples_to_features(
    examples: List[InputExample],
    tokenizer: PreTrainedTokenizer,
    max_length: Optional[int] = None,
    task=None,
    label_list=None,
    output_mode=None,
):
    """Tokenize ``examples`` into padded ``InputFeatures``.

    Args:
        examples: Examples to encode; ``text_a``/``text_b`` form the input pair.
        tokenizer: Tokenizer providing ``batch_encode_plus``.
        max_length: Padding/truncation length; defaults to ``tokenizer.max_len``.
        task: Optional task key used to fill in ``label_list`` and
            ``output_mode`` from ``data_processors`` / ``output_modes``.
        label_list: Labels for classification; index in this list becomes the
            integer label. Not needed for regression.
        output_mode: ``"classification"`` or ``"regression"`` (plain strings
            or the matching ``OutputMode`` members).

    Returns:
        One ``InputFeatures`` per example, in input order.

    Raises:
        ValueError: classification requested but no label list is available.
        KeyError: ``output_mode`` is neither classification nor regression, or
            a classification label is not in ``label_list``.
    """
    if max_length is None:
        max_length = tokenizer.max_len

    if task is not None:
        processor = data_processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s", label_list, task)
        if output_mode is None:
            output_mode = output_modes[task]
            logger.info("Using output mode %s for task %s", output_mode, task)

    # Accept either the plain strings or enum members such as OutputMode.classification.
    mode = getattr(output_mode, "value", output_mode)

    if mode == "classification" and label_list is None:
        raise ValueError(
            "label_list is required for classification; pass it explicitly or give a known task"
        )

    # Regression labels go through float() and never need the label map.
    label_map = (
        {} if label_list is None else {label: i for i, label in enumerate(label_list)}
    )

    def label_from_example(example: InputExample) -> Union[int, float]:
        if mode == "classification":
            return label_map[example.label]
        if mode == "regression":
            return float(example.label)
        raise KeyError(output_mode)

    labels = [label_from_example(example) for example in examples]

    encode_kwargs = {"max_length": max_length, "pad_to_max_length": True}
    if task in ["boolq"]:
        # Truncate only the first sequence (text_a, the passage for BoolQ)
        # so the second sequence (the question) is kept intact.
        encode_kwargs["truncation_strategy"] = "only_first"

    batch_encoding = tokenizer.batch_encode_plus(
        [(example.text_a, example.text_b) for example in examples], **encode_kwargs
    )

    features = []
    for i, (example, label) in enumerate(zip(examples, labels)):
        # Pick the i-th entry of every encoded field (input_ids, attention_mask, ...).
        inputs = {k: batch_encoding[k][i] for k in batch_encoding}
        features.append(InputFeatures(**inputs, guid=example.guid, label=label))

    for example, feature in zip(examples[:5], features):
        logger.info("*** Example ***")
        logger.info("guid: %s", example.guid)
        logger.info("features: %s", feature)

    return features


class OutputMode(Enum):
    """How a task's labels are interpreted by the feature converter."""

    # Labels are discrete class names mapped to integer indices.
    classification = "classification"
    # Labels are real-valued targets.
    regression = "regression"


def get_corpus_filtered_examples(data_dir, datafile, metadatafile, split, corpus=None):
    """Load one split's examples, keeping only those that have metadata.

    Both files live at ``<data_dir>/<split>/<file>`` and hold a JSON list of
    objects. Examples and metadata entries are matched on ``"pair-id"``.

    Args:
        data_dir: Root directory of the data set.
        datafile: Name of the JSON file with the examples.
        metadatafile: Name of the JSON file with per-example metadata.
        split: Sub-directory name, e.g. ``"train"``, ``"dev"`` or ``"test"``.
        corpus: If given, keep only examples whose metadata ``"corpus"`` equals
            this value; if None, keep every example that has metadata.

    Returns:
        The matching example dicts from ``datafile``, in file order.
    """
    split_dir = os.path.join(data_dir, split)
    with open(os.path.join(split_dir, datafile), encoding="utf-8") as f:
        data = json.load(f)
    with open(os.path.join(split_dir, metadatafile), encoding="utf-8") as f:
        metadata = json.load(f)

    # On duplicate pair-ids the later metadata entry wins.
    metadata_by_id = {entry["pair-id"]: entry for entry in metadata}

    return [
        example
        for example in data
        if example["pair-id"] in metadata_by_id
        and (corpus is None or metadata_by_id[example["pair-id"]]["corpus"] == corpus)
    ]


def get_label_dist(labels, filename=None):
    """Report the size, label counts and label proportions of ``labels``.

    Args:
        labels: Iterable-with-length of hashable labels.
        filename: Despite the name, a writable text file object. When given,
            the three report lines are written to it; otherwise they are
            printed to stdout.
    """
    total = len(labels)
    counts = dict(collections.Counter(labels))
    # Empty input yields an empty dict, so there is no division by zero.
    distribution = {label: count / total for label, count in counts.items()}

    if filename is None:
        print("No. of examples : ", total)
        print("Label counts : ", counts)
        print("Label distribution : ", distribution)
        return

    filename.write("No. of examples : {}\n".format(total))
    filename.write("Label counts : {}\n".format(counts))
    filename.write("Label distribution : {}\n".format(distribution))


#### 15-dataset-NLP

# 1
class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version).

    Rows are tab-separated; column 1 holds the label and column 3 the
    sentence. No header row is skipped.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        guid = tensor_dict["idx"].numpy()
        sentence = tensor_dict["sentence"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(guid, sentence, None, label)

    def _load_cola_rows(self, data_dir, processed_data, processed_name, raw_name):
        """Read ``processed/<processed_name>`` if ``processed_data`` else ``<raw_name>``."""
        if processed_data:
            return self._read_tsv(os.path.join(data_dir, "processed", processed_name))
        return self._read_tsv(os.path.join(data_dir, raw_name))

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        rows = self._load_cola_rows(
            data_dir, processed_data, "train_{}.tsv".format(version), "train.tsv"
        )
        return self._create_examples(rows, "train", task_type, processed_data)

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        rows = self._load_cola_rows(
            data_dir, processed_data, "dev_{}.tsv".format(version), "dev.tsv"
        )
        return self._create_examples(rows, "dev", task_type, processed_data)

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        Without ``processed_data`` the GLUE dev file doubles as the test set.
        """
        rows = self._load_cola_rows(data_dir, processed_data, "test.tsv", "dev.tsv")
        return self._create_examples(rows, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def get_label(self, input):
        """Return the label column (index 1) of a raw row."""
        return input[1]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build single-sentence InputExamples; ``processed_data`` is unused here."""

        def with_token(sentence):
            # Optionally prepend the dataset- or task-identifying token.
            if task_type == "dataset":
                return dataset_token["cola"] + " " + sentence
            if task_type == "task":
                return task_token["cola"] + " " + sentence
            return sentence

        return [
            InputExample(
                guid="%s-%s" % (set_type, row_no),
                text_a=with_token(row[3]),
                text_b=None,
                label=row[1],
            )
            for row_no, row in enumerate(lines)
        ]


# 2
class BoolQProcessor(DataProcessor):
    """Processor for the Boolean Questions (BoolQ) data set.

    Data is JSON lines; ``_create_examples`` reads the ``idx``, ``passage``,
    ``question`` and ``label`` fields of each row.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["guid"].numpy(),
            tensor_dict["passage"].numpy().decode("utf-8"),
            tensor_dict["question"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    @staticmethod
    def _load_jsonl(path):
        """Return the JSON objects stored one per line in ``path``."""
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f]

    @staticmethod
    def _dump_jsonl(path, rows):
        """Write ``rows`` to ``path``, one JSON object per line."""
        with open(path, "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row))
                f.write("\n")

    def _labelled_rows(self, path):
        """Read a JSONL file, dropping rows for which ``get_label`` is None.

        Returns:
            ``(rows, labels)`` as two parallel lists.
        """
        rows, labels = [], []
        for row in self._load_jsonl(path):
            label = self.get_label(input=row)
            if label is None:
                continue
            rows.append(row)
            labels.append(label)
        return rows, labels

    def process_train_dev(
        self,
        data_dir,
        task_name,
        train_filename="train.jsonl",
        dev_filename="val.jsonl",
        test_filename="test.jsonl",
        n_dev_examples=943,
        n_splits=5,
        random_state=42,
        delimiter="\t",
        quotechar=None,
    ):
        """Write train/dev re-splits and a test file under ``<data_dir>/processed``.

        The original train file is split ``n_splits`` times via ``get_splits``
        into ``train_v<k>.jsonl`` / ``dev_v<k>.jsonl``; the original dev file
        is copied to ``test_filename``. Split indices are pickled to
        ``split_indices.pkl`` and label distributions appended to ``data.log``.
        ``task_name``, ``delimiter`` and ``quotechar`` are accepted but unused.
        """
        lines, labels = self._labelled_rows(os.path.join(data_dir, train_filename))

        splits = get_splits(
            n_samples=len(lines),
            n_dev_examples=n_dev_examples,
            n_splits_to_select=n_splits,
            random_state=random_state,
        )
        split_indices = {}

        processed_dir = os.path.join(data_dir, "processed")
        os.makedirs(processed_dir, exist_ok=True)

        with open(os.path.join(processed_dir, "data.log"), "a+") as log_f:
            log_f.write("Actual train label distribution\n")
            self.get_label_dist(labels=labels, filename=log_f)
            log_f.write("-------------------------------\n")

            for idx, split_idx in enumerate(splits):
                train_idx, dev_idx = split_idx[0], split_idx[1]
                split_indices[idx] = split_idx

                self._dump_jsonl(
                    os.path.join(processed_dir, "train_v{}.jsonl".format(idx + 1)),
                    [lines[i] for i in train_idx],
                )
                self._dump_jsonl(
                    os.path.join(processed_dir, "dev_v{}.jsonl".format(idx + 1)),
                    [lines[i] for i in dev_idx],
                )

                log_f.write("Train label distribution: {}\n".format(idx + 1))
                self.get_label_dist(
                    labels=[labels[i] for i in train_idx], filename=log_f
                )

                log_f.write("Dev label distribution: {}\n".format(idx + 1))
                self.get_label_dist(labels=[labels[i] for i in dev_idx], filename=log_f)
                log_f.write("-------------------------------\n")

            with open(os.path.join(processed_dir, "split_indices.pkl"), "wb") as pkl_f:
                pkl.dump(split_indices, pkl_f)

            # The original dev file becomes the held-out test set.
            test_rows, test_labels = self._labelled_rows(
                os.path.join(data_dir, dev_filename)
            )
            self._dump_jsonl(os.path.join(processed_dir, test_filename), test_rows)

            log_f.write("Test label distribution\n")
            self.get_label_dist(labels=test_labels, filename=log_f)
            log_f.write("-------------------------------\n")

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        if processed_data:
            rows = self._load_jsonl(
                os.path.join(data_dir, "processed", "train_{}.jsonl".format(version))
            )
            return self._create_examples(rows, "train", task_type, processed_data)
        rows = self._load_jsonl(os.path.join(data_dir, "train.jsonl"))
        return self._create_examples(rows, "train", task_type)

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        if processed_data:
            rows = self._load_jsonl(
                os.path.join(data_dir, "processed", "dev_{}.jsonl".format(version))
            )
            return self._create_examples(rows, "dev", task_type, processed_data)
        rows = self._load_jsonl(os.path.join(data_dir, "val.jsonl"))
        return self._create_examples(rows, "dev", task_type)

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        NOTE(review): without ``processed_data`` this reads ``test.jsonl``;
        ``_create_examples`` requires a ``label`` field, so an unlabeled test
        file would raise KeyError — confirm the file carries labels.
        """
        if processed_data:
            rows = self._load_jsonl(os.path.join(data_dir, "processed", "test.jsonl"))
            return self._create_examples(rows, "test", task_type, processed_data)
        rows = self._load_jsonl(os.path.join(data_dir, "test.jsonl"))
        return self._create_examples(rows, "test", task_type)

    def get_labels(self):
        """See base class."""
        return [0, 1]

    def get_label(self, input):
        """Return the row's ``label`` value, or None if the row has none."""
        return input.get("label")

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for line in lines:
            guid = "%s-%s" % (set_type, line["idx"])

            if task_type == "dataset":
                text_a = dataset_token["boolq"] + " " + line["passage"]
            elif task_type == "task":
                text_a = task_token["boolq"] + " " + line["passage"]
            else:
                text_a = line["passage"]
            text_b = line["question"]
            # Boolean answers become integer class ids matching get_labels().
            label = 1 if line["label"] else 0
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples


# 3
class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version).

    Rows are tab-separated with the sentence in column 0 and the label in
    column 1. Files are read from ``processed_lll/`` when ``lll_mode`` is set,
    else from ``processed/`` when ``processed_data`` is set, else directly
    from ``data_dir``.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        guid = tensor_dict["guid"].numpy()
        sentence = tensor_dict["sentence"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(guid, sentence, None, label)

    def _sst2_path(self, data_dir, lll_mode, processed_data, split_name, raw_name):
        """Pick the file for one split according to the mode flags."""
        if lll_mode:
            return os.path.join(data_dir, "processed_lll", split_name)
        if processed_data:
            return os.path.join(data_dir, "processed", split_name)
        return os.path.join(data_dir, raw_name)

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        path = self._sst2_path(
            data_dir, lll_mode, processed_data, "train_{}.tsv".format(version), "train.tsv"
        )
        return self._create_examples(self._read_tsv(path), "train", task_type, processed_data)

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        path = self._sst2_path(
            data_dir, lll_mode, processed_data, "dev_{}.tsv".format(version), "dev.tsv"
        )
        return self._create_examples(self._read_tsv(path), "dev", task_type, processed_data)

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class.

        In raw mode the GLUE dev file doubles as the test set.
        """
        path = self._sst2_path(data_dir, lll_mode, processed_data, "test.tsv", "dev.tsv")
        return self._create_examples(self._read_tsv(path), "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def get_label(self, input):
        """Return the label column (index 1) of a row, as a string."""
        return str(input[1])

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build single-sentence InputExamples.

        The first row is treated as a header and dropped unless
        ``processed_data`` is truthy; guids use the row's position in ``lines``.
        """

        def with_token(sentence):
            # Optionally prepend the dataset- or task-identifying token.
            if task_type == "dataset":
                return dataset_token["sst-2"] + " " + sentence
            if task_type == "task":
                return task_token["sst-2"] + " " + sentence
            return sentence

        return [
            InputExample(
                guid="%s-%s" % (set_type, row_no),
                text_a=with_token(row[0]),
                text_b=None,
                label=str(row[1]),
            )
            for row_no, row in enumerate(lines)
            if row_no or processed_data
        ]


# 4
class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version).

    Rows are tab-separated: column 0 is used for the guid, columns 3 and 4
    are the two texts and column 5 the label. Files are read from
    ``processed_lll/`` when ``lll_mode`` is set, else from ``processed/`` when
    ``processed_data`` is set, else directly from ``data_dir``.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def _qqp_path(self, data_dir, lll_mode, processed_data, split_name, raw_name):
        """Pick the file for one split according to the mode flags."""
        if lll_mode:
            return os.path.join(data_dir, "processed_lll", split_name)
        if processed_data:
            return os.path.join(data_dir, "processed", split_name)
        return os.path.join(data_dir, raw_name)

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        path = self._qqp_path(
            data_dir, lll_mode, processed_data, "train_{}.tsv".format(version), "train.tsv"
        )
        return self._create_examples(self._read_tsv(path), "train", task_type, processed_data)

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        path = self._qqp_path(
            data_dir, lll_mode, processed_data, "dev_{}.tsv".format(version), "dev.tsv"
        )
        return self._create_examples(self._read_tsv(path), "dev", task_type, processed_data)

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class.

        In raw mode the GLUE dev file doubles as the test set.
        """
        path = self._qqp_path(data_dir, lll_mode, processed_data, "test.tsv", "dev.tsv")
        return self._create_examples(self._read_tsv(path), "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def get_label(self, input):
        """Return the label column (index 5), or None if the row is too short."""
        try:
            return input[5]
        except IndexError:
            return None

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build sentence-pair InputExamples.

        The first row is treated as a header and dropped unless
        ``processed_data`` is truthy. Rows too short to provide columns 0, 3,
        4 and 5 are skipped.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0 and not processed_data:
                continue
            try:
                # The guid column is read inside the try so that short rows
                # (e.g. a blank line, which csv-style readers yield as [])
                # are skipped rather than crashing.
                guid = "%s-%s" % (set_type, line[0])
                first_text = line[3]
                text_b = line[4]
                label = line[5]
            except IndexError:
                continue
            if task_type == "dataset":
                text_a = dataset_token["qqp"] + " " + first_text
            elif task_type == "task":
                text_a = task_token["qqp"] + " " + first_text
            else:
                text_a = first_text
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples


# 5
class YahooQAProcessor(DataProcessor):
    """Processor for the Yahoo question and answer classification data set.

    All files (raw ``*.csv`` and the ``processed``/``processed_lll`` ``*.tsv``
    splits) are read comma-separated with double-quote quoting. Column 0 is
    the label; the remaining columns are joined with spaces into the text.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def _read_yahoo_split(self, data_dir, lll_mode, processed_data, split_name, raw_name):
        """Read the comma-separated file for one split according to the mode flags."""
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", split_name)
        elif processed_data:
            path = os.path.join(data_dir, "processed", split_name)
        else:
            path = os.path.join(data_dir, raw_name)
        return self._read_tsv(path, delimiter=",", quotechar='"')

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        lines = self._read_yahoo_split(
            data_dir, lll_mode, processed_data, "train_{}.tsv".format(version), "train.csv"
        )
        return self._create_examples(lines, "train", task_type, processed_data)

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class.

        In raw mode the test CSV doubles as the dev set.
        """
        lines = self._read_yahoo_split(
            data_dir, lll_mode, processed_data, "dev_{}.tsv".format(version), "test.csv"
        )
        return self._create_examples(lines, "dev", task_type, processed_data)

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        lines = self._read_yahoo_split(
            data_dir, lll_mode, processed_data, "test.tsv", "test.csv"
        )
        # Tagged "test" (was "dev") so guids match every other processor.
        return self._create_examples(lines, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]

    def get_label(self, input):
        """Return the label column (index 0) of a row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build single-text InputExamples; ``processed_data`` is unused here."""
        examples = []
        for (i, row) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text = " ".join(row[1:])
            if task_type == "dataset":
                text_a = dataset_token["yahooqa"] + " " + text
            elif task_type == "task":
                text_a = task_token["yahooqa"] + " " + text
            else:
                text_a = text
            label = row[0]

            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label)
            )
        return examples


# 6
class YelpProcessor(DataProcessor):
    """Processor for the Yelp data set.

    All files (raw ``*.csv`` and the ``processed``/``processed_lll`` ``*.tsv``
    splits) are read comma-separated with double-quote quoting. Column 0 is
    the label; the remaining columns are joined with spaces into the text.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def _read_yelp_split(self, data_dir, lll_mode, processed_data, split_name, raw_name):
        """Read the comma-separated file for one split according to the mode flags."""
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", split_name)
        elif processed_data:
            path = os.path.join(data_dir, "processed", split_name)
        else:
            path = os.path.join(data_dir, raw_name)
        return self._read_tsv(path, delimiter=",", quotechar='"')

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        lines = self._read_yelp_split(
            data_dir, lll_mode, processed_data, "train_{}.tsv".format(version), "train.csv"
        )
        return self._create_examples(lines, "train", task_type, processed_data)

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class.

        In raw mode the test CSV doubles as the dev set.
        """
        lines = self._read_yelp_split(
            data_dir, lll_mode, processed_data, "dev_{}.tsv".format(version), "test.csv"
        )
        return self._create_examples(lines, "dev", task_type, processed_data)

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        lines = self._read_yelp_split(
            data_dir, lll_mode, processed_data, "test.tsv", "test.csv"
        )
        # Tagged "test" (was "dev") so guids match every other processor.
        return self._create_examples(lines, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["1", "2", "3", "4", "5"]

    def get_label(self, input):
        """Return the label column (index 0) of a row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build single-text InputExamples; ``processed_data`` is unused here."""
        examples = []
        for (i, row) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text = " ".join(row[1:])

            # Separate the prefix token from the text with a space, as every
            # other processor does (it was previously glued to the first word).
            if task_type == "dataset":
                text_a = dataset_token["yelp"] + " " + text
            elif task_type == "task":
                text_a = task_token["yelp"] + " " + text
            else:
                text_a = text

            label = row[0]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label)
            )
        return examples


# 7
class EventFactualityProcessor(DataProcessor):
    """Processor for the recast event-factuality (Decomp corpus) NLI data."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        premise = tensor_dict["premise"].numpy().decode("utf-8")
        hypothesis = tensor_dict["hypothesis"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, premise, hypothesis, label)

    def _load_split(self, data_dir, split, task_type):
        """Fetch the Decomp records of *split* and turn them into examples.

        The split name doubles as the guid prefix.
        """
        records = get_corpus_filtered_examples(
            data_dir=data_dir,
            datafile="recast_factuality_data.json",
            metadatafile="recast_factuality_metadata.json",
            split=split,
            corpus="Decomp",
        )
        return self._create_examples(records, split, task_type)

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` and ``processed_data`` are unused."""
        return self._load_split(data_dir, "train", task_type)

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` and ``processed_data`` are unused."""
        return self._load_split(data_dir, "dev", task_type)

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` and ``processed_data`` are unused."""
        return self._load_split(data_dir, "test", task_type)

    def get_labels(self):
        """See base class."""
        return ["entailed", "not-entailed"]

    def _create_examples(self, lines, set_type, task_type):
        """Convert record dicts into premise/hypothesis ``InputExample`` objects.

        Each record is read through the keys ``"pair-id"``, ``"context"``,
        ``"hypothesis"`` and ``"label"``. The guid is ``<set_type>-<pair-id>``.
        """
        examples = []
        for record in lines:
            guid = "%s-%s" % (set_type, record["pair-id"])
            if task_type == "dataset":
                premise = dataset_token["event"] + " " + record["context"]
            elif task_type == "task":
                premise = task_token["event"] + " " + record["context"]
            else:
                premise = record["context"]
            examples.append(
                InputExample(
                    guid=guid,
                    text_a=premise,
                    text_b=record["hypothesis"],
                    label=record["label"],
                )
            )
        return examples


# 8
class ArgumentProcessor(DataProcessor):
    """Processor for the Argument Aspect Mining data set.

    Every split is read from the same per-topic TSV files. Judging from the
    column accesses below, each row holds:
    - column 0: the topic;
    - column 4: the sentence, or the literal string ``"null"`` when missing;
    - column 5: the label;
    - column 6: the split name (``"train"``, ``"val"`` or ``"test"``).
    """

    # Per-topic files scanned for every split.
    _TOPIC_FILES = (
        "abortion.tsv",
        "cloning.tsv",
        "death_penalty.tsv",
        "gun_control.tsv",
        "marijuana_legalization.tsv",
        "minimum_wage.tsv",
        "nuclear_energy.tsv",
        "school_uniforms.tsv",
    )

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["topic"].numpy().decode("utf-8"),
            tensor_dict["argument"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def _read_split(self, data_dir, split):
        """Collect the rows of all topic files whose split column equals *split*.

        The first row of each file is skipped as a header. Rows whose sentence
        column is ``"null"`` are dropped, and their per-file count is logged.
        Blank rows (as produced by trailing empty lines) are ignored instead
        of raising ``IndexError``.

        Args:
            data_dir: Directory containing the per-topic TSV files.
            split: Value of column 6 to keep ("train", "val" or "test").

        Returns:
            List of raw rows, in file order.
        """
        rows = []
        for filename in self._TOPIC_FILES:
            null_count = 0
            lines = self._read_tsv(os.path.join(data_dir, filename))
            for line in itertools.islice(lines, 1, None):
                if not line or line[6] != split:
                    continue
                if line[4] == "null":
                    null_count += 1
                else:
                    rows.append(line)
            logger.info(
                "No. of null sentences for file %s: %d", filename, null_count
            )
        return rows

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` and ``processed_data`` are unused."""
        return self._create_examples(
            self._read_split(data_dir, "train"), "train", task_type
        )

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. Dev rows are those tagged ``"val"`` in the files."""
        return self._create_examples(
            self._read_split(data_dir, "val"), "dev", task_type
        )

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` and ``processed_data`` are unused."""
        return self._create_examples(
            self._read_split(data_dir, "test"), "test", task_type
        )

    def get_labels(self):
        """See base class."""
        return ["Argument_for", "NoArgument", "Argument_against"]

    def _create_examples(self, lines, set_type, task_type):
        """Build topic/sentence pair examples from filtered raw rows.

        ``text_a`` is column 0 (optionally prefixed by the "argument" token),
        ``text_b`` is column 4 and the label is column 5.
        """
        examples = []
        for i, line in enumerate(lines):
            guid = "%s-%s" % (set_type, i)

            if task_type == "dataset":
                text_a = dataset_token["argument"] + " " + line[0]
            elif task_type == "task":
                text_a = task_token["argument"] + " " + line[0]
            else:
                text_a = line[0]
            text_b = line[4]
            label = line[5]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )

        return examples


# 9
class PDTB2Explicit8Processor(DataProcessor):
    """Processor for the PDTB v2.0 explicit connective classification data set.

    Each example pairs two discourse arguments. The label is one of the eight
    connectives returned by ``get_labels``.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        first_arg = tensor_dict["argument1"].numpy().decode("utf-8")
        second_arg = tensor_dict["argument2"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, first_arg, second_arg, label)

    def _examples_from_file(self, path, set_type, task_type, processed_data):
        """Read the tab-separated file at *path* and build examples from it."""
        rows = self._read_tsv(path, delimiter="\t")
        if processed_data:
            return self._create_examples(rows, set_type, task_type, processed_data)
        return self._create_examples(rows, set_type, task_type)

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        Reads ``processed/train.tsv`` when ``processed_data`` is set, otherwise
        ``processed/train.csv``. Both are parsed as tab-separated.
        """
        file_name = "train.tsv" if processed_data else "train.csv"
        return self._examples_from_file(
            os.path.join(data_dir, "processed", file_name),
            "train",
            task_type,
            processed_data,
        )

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        Reads ``processed/dev.tsv`` when ``processed_data`` is set, otherwise
        ``processed/dev.csv``. Both are parsed as tab-separated.
        """
        file_name = "dev.tsv" if processed_data else "dev.csv"
        return self._examples_from_file(
            os.path.join(data_dir, "processed", file_name),
            "dev",
            task_type,
            processed_data,
        )

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        Reads ``processed/test.tsv`` when ``processed_data`` is set, otherwise
        ``test.csv`` directly under *data_dir*. Both are parsed as tab-separated.
        """
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
        else:
            path = os.path.join(data_dir, "test.csv")
        return self._examples_from_file(path, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return "and but because if when also while as".split()

    def get_label(self, input):
        """Return the label of a raw row, taken from its last column."""
        return input[-1]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build argument-pair examples, skipping the first row of *lines*.

        Columns 5 and 6 hold the two arguments and column 7 the connective
        label. ``processed_data`` is unused.
        """
        examples = []
        # Row 0 is always skipped (presumably a header row).
        for row_no, row in itertools.islice(enumerate(lines), 1, None):
            guid = "%s-%s" % (set_type, row_no)
            first_arg, second_arg = row[5], row[6]

            if task_type == "dataset":
                first_arg = dataset_token["dmarker"] + " " + first_arg
            elif task_type == "task":
                first_arg = task_token["dmarker"] + " " + first_arg

            examples.append(
                InputExample(
                    guid=guid, text_a=first_arg, text_b=second_arg, label=row[7]
                )
            )

        return examples


# 10
class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        question = tensor_dict["question"].numpy().decode("utf-8")
        sentence = tensor_dict["sentence"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, question, sentence, label)

    def _split_examples(
        self,
        data_dir,
        set_type,
        prepared_name,
        raw_name,
        task_type,
        processed_data,
        lll_mode,
    ):
        """Load one split from ``processed_lll/``, ``processed/`` or the raw files.

        ``lll_mode`` takes precedence over ``processed_data``. Only the
        prepared (lll/processed) sources forward ``processed_data`` to
        ``_create_examples``. The raw source relies on that method's default.
        """
        if not (lll_mode or processed_data):
            raw_rows = self._read_tsv(os.path.join(data_dir, raw_name))
            return self._create_examples(raw_rows, set_type, task_type)
        subdir = "processed_lll" if lll_mode else "processed"
        rows = self._read_tsv(os.path.join(data_dir, subdir, prepared_name))
        return self._create_examples(rows, set_type, task_type, processed_data)

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        return self._split_examples(
            data_dir,
            "train",
            "train_{}.tsv".format(version),
            "train.tsv",
            task_type,
            processed_data,
            lll_mode,
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class."""
        return self._split_examples(
            data_dir,
            "dev",
            "dev_{}.tsv".format(version),
            "dev.tsv",
            task_type,
            processed_data,
            lll_mode,
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class.

        Without lll/processed data, the test split is read from the raw
        ``dev.tsv`` file.
        """
        return self._split_examples(
            data_dir,
            "test",
            "test.tsv",
            "dev.tsv",
            task_type,
            processed_data,
            lll_mode,
        )

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def get_label(self, input):
        """Return the label of a raw row, taken from its last column."""
        return input[-1]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build question/sentence pair examples.

        Column 0 supplies the guid suffix, columns 1 and 2 the texts, and the
        last column the label. Row 0 is skipped only when ``processed_data`` is
        falsy.
        """
        examples = []
        for position, line in enumerate(lines):
            if not processed_data and position == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            question = line[1]
            if task_type == "dataset":
                question = dataset_token["qnli"] + " " + question
            elif task_type == "task":
                question = task_token["qnli"] + " " + question
            examples.append(
                InputExample(
                    guid=guid, text_a=question, text_b=line[2], label=line[-1]
                )
            )
        return examples


# 11
class RocStoryBinarySentenceOrderingProcessor(DataProcessor):
    """Processor for the ROCStories binary sentence-order classification data.

    Every raw row yields two examples: the sentence pair in its original order
    (label ``True``) and the same pair swapped (label ``False``).
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def _subsample_file(
        self, src_path, dst_path, num_samples, random_state, delimiter, quotechar
    ):
        """Copy *num_samples* randomly chosen rows of *src_path* to *dst_path*.

        Rows are drawn without replacement from a fresh
        ``np.random.RandomState(random_state)``, so the selection is
        reproducible. NumPy raises ``ValueError`` if the source has fewer than
        *num_samples* rows.
        """
        rows = list(
            self._read_tsv(src_path, delimiter=delimiter, quotechar=quotechar)
        )
        sampled_indices = np.random.RandomState(random_state).choice(
            len(rows), num_samples, replace=False
        )
        # The csv module requires ``newline=""`` so it can control line
        # endings itself; without it Windows gets an empty row after every
        # record. ``with`` closes the file even if a write fails.
        with open(dst_path, "w", encoding="utf-8-sig", newline="") as f:
            writer = csv.writer(f, delimiter=delimiter, quotechar=quotechar)
            for idx in sampled_indices:
                writer.writerow(rows[idx])

    def process_train_dev(
        self,
        data_dir,
        task_name,
        train_filename="train.txt",
        dev_filename="dev.txt",
        test_filename="test.txt",
        random_state=42,
        delimiter="\t",
        quotechar=None,
        num_train=5000,
        num_dev=1200,
        num_test=1200,
    ):
        """Write reproducible random subsets of each split to ``processed/``.

        Outputs are always named ``processed/train.txt``, ``processed/dev.txt``
        and ``processed/test.txt``, whatever the source file names are.

        Args:
            data_dir: Directory holding the raw split files.
            task_name: Unused; kept for interface compatibility.
            train_filename, dev_filename, test_filename: Raw source files.
            random_state: Seed; each split is sampled with its own
                ``RandomState`` seeded with this value.
            delimiter, quotechar: Passed to both the reader and ``csv.writer``.
            num_train, num_dev, num_test: Rows sampled per split (the defaults
                match the previously hard-coded sizes).
        """
        processed_dir = os.path.join(data_dir, "processed")
        os.makedirs(processed_dir, exist_ok=True)
        for src_name, dst_name, num_samples in (
            (train_filename, "train.txt", num_train),
            (dev_filename, "dev.txt", num_dev),
            (test_filename, "test.txt", num_test),
        ):
            self._subsample_file(
                os.path.join(data_dir, src_name),
                os.path.join(processed_dir, dst_name),
                num_samples,
                random_state,
                delimiter,
                quotechar,
            )

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        if processed_data:
            return self._create_examples(
                self._read_tsv(
                    os.path.join(data_dir, "processed", "train.txt"), delimiter="\t"
                ),
                "train",
                task_type,
                processed_data,
            )
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.txt"), delimiter="\t"),
            "train",
            task_type,
        )

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        if processed_data:
            return self._create_examples(
                self._read_tsv(
                    os.path.join(data_dir, "processed", "dev.txt"), delimiter="\t"
                ),
                "dev",
                task_type,
                processed_data,
            )
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.txt"), delimiter="\t"),
            "dev",
            task_type,
        )

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        if processed_data:
            return self._create_examples(
                self._read_tsv(
                    os.path.join(data_dir, "processed", "test.txt"), delimiter="\t"
                ),
                "test",
                task_type,
                processed_data,
            )
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.txt"), delimiter="\t"),
            "test",
            task_type,
        )

    def get_labels(self):
        """See base class. Labels are booleans: ``True`` means original order."""
        return [True, False]

    def get_label(self, input):
        """Return the label of a raw row, taken from its last column."""
        return input[-1]

    @staticmethod
    def _with_task_prefix(text, task_type):
        """Prepend the "rocstory" dataset/task token when *task_type* asks for it."""
        if task_type == "dataset":
            return dataset_token["rocstory"] + " " + text
        if task_type == "task":
            return task_token["rocstory"] + " " + text
        return text

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Create a forward (``True``) and a reversed (``False``) example per row.

        Columns 0 and 1 of each row are the two sentences. ``processed_data``
        is unused.
        """
        examples = []
        for i, row in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            first, second = row[0], row[1]

            examples.append(
                InputExample(
                    guid=guid,
                    text_a=self._with_task_prefix(first, task_type),
                    text_b=second,
                    label=True,
                )
            )
            # The reversed pair is the negative example. Give it its own guid
            # so guids stay unique; both examples previously shared one.
            examples.append(
                InputExample(
                    guid=guid + "-rev",
                    text_a=self._with_task_prefix(second, task_type),
                    text_b=first,
                    label=False,
                )
            )

        return examples


# 12
class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version, matched splits)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        premise = tensor_dict["premise"].numpy().decode("utf-8")
        hypothesis = tensor_dict["hypothesis"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, premise, hypothesis, label)

    @staticmethod
    def _prepared_path(data_dir, file_name, lll_mode):
        """Path of a pre-split file under ``processed_lll/`` or ``processed/``."""
        subdir = "processed_lll" if lll_mode else "processed"
        return os.path.join(data_dir, subdir, file_name)

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class. ``lll_mode`` takes precedence over ``processed_data``."""
        if lll_mode or processed_data:
            path = self._prepared_path(
                data_dir, "train_{}.tsv".format(version), lll_mode
            )
        else:
            path = os.path.join(data_dir, "train.tsv")
        # Every training source, raw included, forwards processed_data.
        return self._create_examples(
            self._read_tsv(path), "train", task_type, processed_data
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class. ``lll_mode`` takes precedence over ``processed_data``."""
        set_type = "dev_matched"
        if not (lll_mode or processed_data):
            raw_rows = self._read_tsv(os.path.join(data_dir, "dev_matched.tsv"))
            return self._create_examples(raw_rows, set_type, task_type)
        path = self._prepared_path(data_dir, "dev_{}.tsv".format(version), lll_mode)
        return self._create_examples(
            self._read_tsv(path), set_type, task_type, processed_data
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """See base class. ``version`` is unused for the test split."""
        set_type = "test_matched"
        if not (lll_mode or processed_data):
            raw_rows = self._read_tsv(os.path.join(data_dir, "test_matched.tsv"))
            return self._create_examples(raw_rows, set_type, task_type)
        path = self._prepared_path(data_dir, "test_m.tsv", lll_mode)
        return self._create_examples(
            self._read_tsv(path), set_type, task_type, processed_data
        )

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def get_label(self, input):
        """Return the label of a raw row, taken from its last column."""
        return input[-1]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build premise/hypothesis examples from GLUE-style MNLI rows.

        Column 0 supplies the guid suffix, columns 8 and 9 the premise and
        hypothesis, and the last column the label. Row 0 is skipped only when
        ``processed_data`` is falsy.
        """
        examples = []
        for row_no, line in enumerate(lines):
            if not processed_data and row_no == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            premise = line[8]
            if task_type == "dataset":
                premise = dataset_token["mnli"] + " " + premise
            elif task_type == "task":
                premise = task_token["mnli"] + " " + premise
            examples.append(
                InputExample(
                    guid=guid, text_a=premise, text_b=line[9], label=line[-1]
                )
            )
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version).

    Only the dev and test readers are overridden; training data and example
    construction are inherited from ``MnliProcessor``.
    """

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Return the mismatched dev examples.

        Accepts ``lll_mode`` like ``MnliProcessor.get_dev_examples`` so callers
        can treat both processors uniformly (previously passing it raised
        ``TypeError``). In lll mode the processed file name is looked up under
        ``processed_lll/`` instead of ``processed/``.
        """
        if lll_mode or processed_data:
            subdir = "processed_lll" if lll_mode else "processed"
            return self._create_examples(
                self._read_tsv(
                    os.path.join(data_dir, subdir, "dev_{}.tsv".format(version))
                ),
                "dev_mismatched",
                task_type,
                processed_data,
            )
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
            "dev_mismatched",
            task_type,
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Return the mismatched test examples.

        Accepts ``lll_mode`` for parity with ``MnliProcessor``. In lll mode
        ``test_mm.tsv`` is read from ``processed_lll/``.
        NOTE(review): confirm that file is produced by the lll preprocessing.
        """
        if lll_mode or processed_data:
            subdir = "processed_lll" if lll_mode else "processed"
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, subdir, "test_mm.tsv")),
                "test_mismatched",
                task_type,
                processed_data,
            )
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")),
            "test_mismatched",
            task_type,
        )


# 13
class SciTAILProcessor(DataProcessor):
    """Processor for the SciTAIL data set."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        premise = tensor_dict["premise"].numpy().decode("utf-8")
        hypothesis = tensor_dict["hypothesis"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, premise, hypothesis, label)

    def _read_split(self, data_dir, set_type, processed_name, task_type, processed_data):
        """Load *set_type* from ``processed/<processed_name>`` or the raw file.

        The raw file is ``scitail_1.0_<set_type>.tsv``. Only the processed
        source forwards ``processed_data`` to ``_create_examples``.
        """
        if processed_data:
            rows = self._read_tsv(os.path.join(data_dir, "processed", processed_name))
            return self._create_examples(rows, set_type, task_type, processed_data)
        raw_name = "scitail_1.0_{}.tsv".format(set_type)
        rows = self._read_tsv(os.path.join(data_dir, raw_name))
        return self._create_examples(rows, set_type, task_type)

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        return self._read_split(
            data_dir, "train", "train_{}.tsv".format(version), task_type, processed_data
        )

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        return self._read_split(
            data_dir, "dev", "dev_{}.tsv".format(version), task_type, processed_data
        )

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class. ``version`` is unused for the test split."""
        return self._read_split(data_dir, "test", "test.tsv", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return ["entails", "neutral"]

    def get_label(self, input):
        """Return the label of a raw row, taken from its last column."""
        return input[-1]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build premise/hypothesis examples for any split.

        Columns 0 and 1 are the two texts, and the last column is the label.
        No row is skipped. ``processed_data`` is unused.
        """
        examples = []
        for row_no, line in enumerate(lines):
            guid = "%s-%s" % (set_type, row_no)
            premise = line[0]
            if task_type == "dataset":
                premise = dataset_token["scitail"] + " " + premise
            elif task_type == "task":
                premise = task_token["scitail"] + " " + premise
            examples.append(
                InputExample(
                    guid=guid, text_a=premise, text_b=line[1], label=line[-1]
                )
            )
        return examples


# 14
class PDTB2Level1Processor(DataProcessor):
    """Processor for the PDTB v2.0 implicit discourse relation classification data set."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        idx = tensor_dict["idx"].numpy()
        first_arg = tensor_dict["argument1"].numpy().decode("utf-8")
        second_arg = tensor_dict["argument2"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(idx, first_arg, second_arg, label)

    def _build(self, path, set_type, task_type, processed_data):
        """Read the tab-separated file at *path* and build examples from it."""
        rows = self._read_tsv(path, delimiter="\t")
        if processed_data:
            return self._create_examples(rows, set_type, task_type, processed_data)
        return self._create_examples(rows, set_type, task_type)

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        ``processed/train.tsv`` is read whether or not ``processed_data`` is set.
        """
        return self._build(
            os.path.join(data_dir, "processed", "train.tsv"),
            "train",
            task_type,
            processed_data,
        )

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        ``processed/dev.tsv`` is read whether or not ``processed_data`` is set.
        """
        return self._build(
            os.path.join(data_dir, "processed", "dev.tsv"),
            "dev",
            task_type,
            processed_data,
        )

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class.

        Reads ``processed/test.tsv`` when ``processed_data`` is set, otherwise
        ``test.tsv`` directly under *data_dir*.
        """
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
        else:
            path = os.path.join(data_dir, "test.tsv")
        return self._build(path, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        return "Expansion Comparison Contingency Temporal".split()

    def get_label(self, input):
        """Return the label of a raw row, stored in column 4."""
        return input[4]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build argument-pair examples, skipping the first row of *lines*.

        The label is column 4. Training rows keep the two arguments in columns
        6-7; all other splits keep them in columns 7-8. ``processed_data`` is
        unused.
        """
        first_col, second_col = (6, 7) if set_type == "train" else (7, 8)
        examples = []
        # Row 0 is always skipped (presumably a header row).
        for row_no, row in itertools.islice(enumerate(lines), 1, None):
            guid = "%s-%s" % (set_type, row_no)
            first_arg, second_arg = row[first_col], row[second_col]

            if task_type == "dataset":
                first_arg = dataset_token["drelation"] + " " + first_arg
            elif task_type == "task":
                first_arg = task_token["drelation"] + " " + first_arg

            examples.append(
                InputExample(
                    guid=guid, text_a=first_arg, text_b=second_arg, label=row[4]
                )
            )

        return examples


# 15
class EmotionClassificationProcessor(DataProcessor):
    """Processor for the emotion classification data set.

    Each split is a tab-separated file with the text in column 0 and the label
    in column 1. Every row, including the first, becomes an example.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        guid = tensor_dict["idx"].numpy()
        sentence = tensor_dict["sentence"].numpy().decode("utf-8")
        label = str(tensor_dict["label"].numpy())
        return InputExample(guid, sentence, None, label)

    def _examples_for_split(self, data_dir, split, task_type, processed_data):
        """Read ``<split>.tsv`` (under ``processed/`` if requested) into examples."""
        tsv_name = split + ".tsv"
        if not processed_data:
            rows = self._read_tsv(os.path.join(data_dir, tsv_name), delimiter="\t")
            return self._create_examples(rows, split, task_type)
        rows = self._read_tsv(
            os.path.join(data_dir, "processed", tsv_name), delimiter="\t"
        )
        return self._create_examples(rows, split, task_type, processed_data)

    def get_train_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        return self._examples_for_split(data_dir, "train", task_type, processed_data)

    def get_dev_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        return self._examples_for_split(data_dir, "dev", task_type, processed_data)

    def get_test_examples(
        self, data_dir, task_type=None, version="v1", processed_data=False
    ):
        """See base class."""
        return self._examples_for_split(data_dir, "test", task_type, processed_data)

    def get_labels(self):
        """See base class."""
        emotions = ("joy", "sadness", "anger", "fear", "love", "surprise")
        return list(emotions)

    def get_label(self, input):
        """Return the last column of a raw row."""
        last = input[-1]
        return last

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build one single-sentence ``InputExample`` per row.

        For ``task_type`` ``"dataset"`` / ``"task"`` the ``"emotion"`` entry of
        the corresponding token table plus a space is prepended to the text.
        """
        examples = []
        for idx, row in enumerate(lines):
            sentence = row[0]
            if task_type == "dataset":
                sentence = " ".join((dataset_token["emotion"], sentence))
            elif task_type == "task":
                sentence = " ".join((task_token["emotion"], sentence))
            examples.append(
                InputExample(
                    guid="{}-{}".format(set_type, idx),
                    text_a=sentence,
                    text_b=None,
                    label=row[1],
                )
            )
        return examples


## 5-dataset-NLP (+ YahooQA[5], Yelp[6] from above processors)
## Datasets from Episodic Memory for LLL


class AGNewsProcessor(DataProcessor):
    """Processor for the AG news classification data set.

    Rows are comma-separated with ``"`` quoting. Column 0 holds the label
    (``"1"``-``"4"``); the remaining columns are joined with single spaces to
    form ``text_a``. No header row is skipped.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        # ``None`` fills the ``text_b`` slot so the label is passed as ``label``
        # rather than being mistaken for a second text segment.
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def _examples_from_csv(self, path, set_type, task_type, processed_data=False):
        """Read a comma-separated, ``"``-quoted file at ``path`` into examples."""
        return self._create_examples(
            self._read_tsv(path, delimiter=",", quotechar='"'),
            set_type,
            task_type,
            processed_data,
        )

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the training split.

        Source file, by precedence: ``processed_lll/train_<version>.tsv`` when
        ``lll_mode``; ``processed/train_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``train.csv``.
        """
        if lll_mode:
            path = os.path.join(
                data_dir, "processed_lll", "train_{}.tsv".format(version)
            )
            return self._examples_from_csv(path, "train", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "train_{}.tsv".format(version))
            return self._examples_from_csv(path, "train", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "train.csv"), "train", task_type
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the dev split.

        Source file, by precedence: ``processed_lll/dev_<version>.tsv`` when
        ``lll_mode``; ``processed/dev_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``test.csv`` is reused as the dev split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "dev", task_type
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the test split.

        Source file, by precedence: ``processed_lll/test.tsv`` when ``lll_mode``;
        ``processed/test.tsv`` when ``processed_data``; otherwise ``test.csv``.
        ``version`` is not used for the test split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        # Tag raw test examples as "test" (was "dev"), matching the other branches.
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "test", task_type
        )

    def get_labels(self):
        """See base class."""
        # Label ids as they appear in column 0: World, Sports, Business, Sci/Tech.
        return ["1", "2", "3", "4"]

    def get_label(self, input):
        """Return the label column (index 0) of a raw row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build one single-text ``InputExample`` per row.

        For ``task_type`` ``"dataset"`` / ``"task"`` the ``"agnews"`` entry of
        the corresponding token table is prepended, separated by a space so the
        token is not fused with the first word of the text.
        """
        examples = []
        for (i, row) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = " ".join(row[1:])
            if task_type == "dataset":
                text_a = dataset_token["agnews"] + " " + text_a
            elif task_type == "task":
                text_a = task_token["agnews"] + " " + text_a

            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=row[0])
            )
        return examples


class DBPediaProcessor(DataProcessor):
    """Processor for the DBPedia classification data set.

    Rows are comma-separated with ``"`` quoting. Column 0 holds the label
    (``"1"``-``"14"``); the remaining columns are joined with single spaces to
    form ``text_a``. No header row is skipped.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        # ``None`` fills the ``text_b`` slot so the label is passed as ``label``
        # rather than being mistaken for a second text segment.
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def _examples_from_csv(self, path, set_type, task_type, processed_data=False):
        """Read a comma-separated, ``"``-quoted file at ``path`` into examples."""
        return self._create_examples(
            self._read_tsv(path, delimiter=",", quotechar='"'),
            set_type,
            task_type,
            processed_data,
        )

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the training split.

        Source file, by precedence: ``processed_lll/train_<version>.tsv`` when
        ``lll_mode``; ``processed/train_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``train.csv``.
        """
        if lll_mode:
            path = os.path.join(
                data_dir, "processed_lll", "train_{}.tsv".format(version)
            )
            return self._examples_from_csv(path, "train", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "train_{}.tsv".format(version))
            return self._examples_from_csv(path, "train", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "train.csv"), "train", task_type
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the dev split.

        Source file, by precedence: ``processed_lll/dev_<version>.tsv`` when
        ``lll_mode``; ``processed/dev_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``test.csv`` is reused as the dev split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "dev", task_type
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the test split.

        Source file, by precedence: ``processed_lll/test.tsv`` when ``lll_mode``;
        ``processed/test.tsv`` when ``processed_data``; otherwise ``test.csv``.
        ``version`` is not used for the test split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        # Tag raw test examples as "test" (was "dev"), matching the other branches.
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "test", task_type
        )

    def get_labels(self):
        """See base class."""
        # Label ids "1".."14" as they appear in column 0.
        return [str(label_id) for label_id in range(1, 15)]

    def get_label(self, input):
        """Return the label column (index 0) of a raw row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build one single-text ``InputExample`` per row.

        For ``task_type`` ``"dataset"`` / ``"task"`` the ``"dbpedia"`` entry of
        the corresponding token table is prepended, separated by a space so the
        token is not fused with the first word of the text.
        """
        examples = []
        for (i, row) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = " ".join(row[1:])
            if task_type == "dataset":
                text_a = dataset_token["dbpedia"] + " " + text_a
            elif task_type == "task":
                text_a = task_token["dbpedia"] + " " + text_a

            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=row[0])
            )
        return examples


class AMZNProcessor(DataProcessor):
    """Processor for the Amazon review data set.

    Rows are comma-separated with ``"`` quoting. Column 0 holds the label
    (``"1"``-``"5"``); the remaining columns are joined with single spaces to
    form ``text_a``. No header row is skipped.
    """

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def _examples_from_csv(self, path, set_type, task_type, processed_data=False):
        """Read a comma-separated, ``"``-quoted file at ``path`` into examples."""
        return self._create_examples(
            self._read_tsv(path, delimiter=",", quotechar='"'),
            set_type,
            task_type,
            processed_data,
        )

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the training split.

        Source file, by precedence: ``processed_lll/train_<version>.tsv`` when
        ``lll_mode``; ``processed/train_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``train.csv``.
        """
        if lll_mode:
            path = os.path.join(
                data_dir, "processed_lll", "train_{}.tsv".format(version)
            )
            return self._examples_from_csv(path, "train", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "train_{}.tsv".format(version))
            return self._examples_from_csv(path, "train", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "train.csv"), "train", task_type
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the dev split.

        Source file, by precedence: ``processed_lll/dev_<version>.tsv`` when
        ``lll_mode``; ``processed/dev_<version>.tsv`` when ``processed_data``;
        otherwise the raw ``test.csv`` is reused as the dev split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "dev", task_type
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=False,
        lll_mode=False,
    ):
        """Load the test split.

        Source file, by precedence: ``processed_lll/test.tsv`` when ``lll_mode``;
        ``processed/test.tsv`` when ``processed_data``; otherwise ``test.csv``.
        ``version`` is not used for the test split.
        """
        if lll_mode:
            path = os.path.join(data_dir, "processed_lll", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        # Tag raw test examples as "test" (was "dev"), matching the other branches.
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "test", task_type
        )

    def get_labels(self):
        """See base class."""
        # Label ids "1".."5" as they appear in column 0.
        return ["1", "2", "3", "4", "5"]

    def get_label(self, input):
        """Return the label column (index 0) of a raw row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build one single-text ``InputExample`` per row.

        For ``task_type`` ``"dataset"`` / ``"task"`` the ``"amzn"`` entry of the
        corresponding token table is prepended, separated by a space so the
        token is not fused with the first word of the text.
        """
        examples = []
        for (i, row) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = " ".join(row[1:])
            if task_type == "dataset":
                text_a = dataset_token["amzn"] + " " + text_a
            elif task_type == "task":
                text_a = task_token["amzn"] + " " + text_a

            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=row[0])
            )
        return examples


## Split YahooQA (5 2-way classification tasks)


class SplitYahooQAProcessor(DataProcessor):
    """Processor for the Split Yahoo question and answer classification data set.

    Rows are comma-separated with ``"`` quoting; column 0 holds the label and the
    remaining columns are joined with single spaces to form ``text_a``.

    Args:
        filter_labels: Optional list of label strings to keep. When given, rows
            whose label is not in it are dropped and ``get_labels`` returns
            exactly this list; otherwise all ten labels ``"1"``-``"10"`` are used.
    """

    def __init__(self, filter_labels=None):
        super().__init__()
        self.filter_labels = filter_labels

    def _examples_from_csv(self, path, set_type, task_type, processed_data=False):
        """Read a comma-separated, ``"``-quoted file at ``path`` into examples."""
        return self._create_examples(
            self._read_tsv(path, delimiter=",", quotechar='"'),
            set_type,
            task_type,
            processed_data,
        )

    def get_train_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=True,
        lll_mode=False,
    ):
        """Load the training split.

        Reads ``processed/train_<version>.tsv`` when ``processed_data`` (the
        default), otherwise the raw ``train.csv``. ``lll_mode`` is accepted for
        signature parity with the other processors but is ignored.
        """
        if processed_data:
            path = os.path.join(data_dir, "processed", "train_{}.tsv".format(version))
            return self._examples_from_csv(path, "train", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "train.csv"), "train", task_type
        )

    def get_dev_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=True,
        lll_mode=False,
    ):
        """Load the dev split.

        Reads ``processed/dev_<version>.tsv`` when ``processed_data`` (the
        default), otherwise reuses the raw ``test.csv``. ``lll_mode`` is ignored.
        """
        if processed_data:
            path = os.path.join(data_dir, "processed", "dev_{}.tsv".format(version))
            return self._examples_from_csv(path, "dev", task_type, processed_data)
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "dev", task_type
        )

    def get_test_examples(
        self,
        data_dir,
        task_type=None,
        version="v1",
        processed_data=True,
        lll_mode=False,
    ):
        """Load the test split.

        Reads ``processed/test.tsv`` when ``processed_data`` (the default),
        otherwise the raw ``test.csv``. ``version`` and ``lll_mode`` are ignored.
        """
        if processed_data:
            path = os.path.join(data_dir, "processed", "test.tsv")
            return self._examples_from_csv(path, "test", task_type, processed_data)
        # Tag raw test examples as "test" (was "dev"), matching the other branch.
        return self._examples_from_csv(
            os.path.join(data_dir, "test.csv"), "test", task_type
        )

    def get_labels(self):
        """See base class."""
        if self.filter_labels is not None:
            return self.filter_labels
        return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]

    def get_label(self, input):
        """Return the label column (index 0) of a raw row."""
        return input[0]

    def _create_examples(self, lines, set_type, task_type, processed_data=False):
        """Build ``InputExample`` objects, dropping rows outside ``filter_labels``.

        Guids keep the row's position in ``lines`` (filtered rows leave gaps).
        For ``task_type`` ``"dataset"`` / ``"task"`` the ``"yahooqa"`` token plus
        a space is prepended to the text.
        """
        examples = []
        for (i, row) in enumerate(lines):
            label = row[0]
            # Filter first so discarded rows do not pay for text construction.
            if self.filter_labels is not None and label not in self.filter_labels:
                continue

            guid = "%s-%s" % (set_type, i)
            text_a = " ".join(row[1:])
            if task_type == "dataset":
                text_a = dataset_token["yahooqa"] + " " + text_a
            elif task_type == "task":
                text_a = task_token["yahooqa"] + " " + text_a

            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label)
            )
        return examples


# 1
class Split1YahooQAProcessor(SplitYahooQAProcessor):
    """Split Yahoo QA task 1: keeps only rows labelled "1" or "2"."""

    def __init__(self):
        kept_labels = ["1", "2"]
        super().__init__(filter_labels=kept_labels)


# 2
class Split2YahooQAProcessor(SplitYahooQAProcessor):
    """Split Yahoo QA task 2: keeps only rows labelled "3" or "4"."""

    def __init__(self):
        kept_labels = ["3", "4"]
        super().__init__(filter_labels=kept_labels)


# 3
class Split3YahooQAProcessor(SplitYahooQAProcessor):
    """Split Yahoo QA task 3: keeps only rows labelled "5" or "6"."""

    def __init__(self):
        kept_labels = ["5", "6"]
        super().__init__(filter_labels=kept_labels)


# 4
class Split4YahooQAProcessor(SplitYahooQAProcessor):
    """Split Yahoo QA task 4: keeps only rows labelled "7" or "8"."""

    def __init__(self):
        kept_labels = ["7", "8"]
        super().__init__(filter_labels=kept_labels)


# 5
class Split5YahooQAProcessor(SplitYahooQAProcessor):
    """Split Yahoo QA task 5: keeps only rows labelled "9" or "10"."""

    def __init__(self):
        kept_labels = ["9", "10"]
        super().__init__(filter_labels=kept_labels)


# Number of output labels per task key. The keys mirror those of
# ``data_processors``, so every registered task can be looked up here.
tasks_num_labels = {
    "cola": 2,
    "boolq": 2,
    "sst-2": 2,
    "qqp": 2,
    "yahooqa": 10,
    "yelp": 5,
    "event": 2,
    "argument": 3,
    "dmarker": 8,
    "qnli": 2,
    "rocstory": 2,
    "mnli": 3,
    # Mismatched MNLI shares the MNLI label set.
    "mnli-mm": 3,
    "scitail": 2,
    "drelation": 4,
    "emotion": 6,
    "agnews": 4,
    "dbpedia": 14,
    "amzn": 5,
    # Unfiltered SplitYahooQAProcessor exposes all ten labels.
    "splityahooqa": 10,
    "yahooqa1": 2,
    "yahooqa2": 2,
    "yahooqa3": 2,
    "yahooqa4": 2,
    "yahooqa5": 2,
}

# Task key -> processor class that loads that task's examples.
data_processors = dict(
    (
        ("cola", ColaProcessor),
        ("boolq", BoolQProcessor),
        ("sst-2", Sst2Processor),
        ("qqp", QqpProcessor),
        ("yahooqa", YahooQAProcessor),
        ("yelp", YelpProcessor),
        ("event", EventFactualityProcessor),
        ("argument", ArgumentProcessor),
        ("dmarker", PDTB2Explicit8Processor),
        ("qnli", QnliProcessor),
        ("rocstory", RocStoryBinarySentenceOrderingProcessor),
        ("mnli", MnliProcessor),
        ("mnli-mm", MnliMismatchedProcessor),
        ("scitail", SciTAILProcessor),
        ("drelation", PDTB2Level1Processor),
        ("emotion", EmotionClassificationProcessor),
        ("agnews", AGNewsProcessor),
        ("dbpedia", DBPediaProcessor),
        ("amzn", AMZNProcessor),
        ("splityahooqa", SplitYahooQAProcessor),
        ("yahooqa1", Split1YahooQAProcessor),
        ("yahooqa2", Split2YahooQAProcessor),
        ("yahooqa3", Split3YahooQAProcessor),
        ("yahooqa4", Split4YahooQAProcessor),
        ("yahooqa5", Split5YahooQAProcessor),
    )
)

# Task key -> output mode. Every task listed here is a classification task.
output_modes = dict.fromkeys(
    (
        "cola",
        "boolq",
        "sst-2",
        "qqp",
        "yahooqa",
        "yelp",
        "event",
        "argument",
        "dmarker",
        "qnli",
        "rocstory",
        "mnli",
        "mnli-mm",
        "scitail",
        "drelation",
        "emotion",
        "agnews",
        "dbpedia",
        "amzn",
        "splityahooqa",
        "yahooqa1",
        "yahooqa2",
        "yahooqa3",
        "yahooqa4",
        "yahooqa5",
    ),
    "classification",
)

# Task key -> data directory, relative to the data root used by callers.
data_dir = {
    "cola": "glue_data/CoLA",
    "boolq": "superglue_data/BoolQ",
    "sst-2": "glue_data/SST-2",
    "qqp": "glue_data/QQP",
    "yahooqa": "other_data/yahoo_answers_csv",
    "yelp": "other_data/yelp_review_full_csv",
    "event": "DNC",
    "argument": "other_data/UKP_Sentential_Argument_Mining_Corpus/data/complete",
    "dmarker": "other_data/pdtb2_explicit_8_tsv",
    "qnli": "glue_data/QNLI",
    "rocstory": "other_data/rocstory_binary_sentence_ordering",
    "mnli": "glue_data/MNLI",
    # Mismatched MNLI is registered in every other table; it reads from the
    # same MNLI directory (NOTE(review): confirm MnliMismatchedProcessor's files).
    "mnli-mm": "glue_data/MNLI",
    "scitail": "glue_data/SciTAIL",
    "drelation": "other_data/pdtb2_implicit_rel_tsv",
    "emotion": "other_data/emotion",
    "agnews": "other_data/ag_news_csv",
    "dbpedia": "other_data/dbpedia_csv",
    "amzn": "other_data/amazon_review_full_csv",
    "splityahooqa": "other_data/split_yahoo_answers_csv",
    "yahooqa1": "other_data/split_yahoo_answers_csv",
    "yahooqa2": "other_data/split_yahoo_answers_csv",
    "yahooqa3": "other_data/split_yahoo_answers_csv",
    "yahooqa4": "other_data/split_yahoo_answers_csv",
    "yahooqa5": "other_data/split_yahoo_answers_csv",
}

# Task key -> fraction of that dataset to keep when subsampling.
dataset_subsampling = dict(
    (
        ("cola", 1.0),
        ("boolq", 1.0),
        ("sst-2", 0.15),
        ("qqp", 0.03),
        ("yahooqa", 0.01),
        ("yelp", 0.02),
        ("event", 0.3),
        ("argument", 0.6),
        ("dmarker", 1.0),
        ("qnli", 0.1),
        ("rocstory", 1.0),
        ("mnli", 0.03),
        ("mnli-mm", 1.0),
        ("scitail", 0.5),
        ("drelation", 1.0),
        ("emotion", 0.6),
        ("agnews", 1.0),
        ("dbpedia", 1.0),
        ("amzn", 1.0),
        ("splityahooqa", 1.0),
        ("yahooqa1", 1.0),
        ("yahooqa2", 1.0),
        ("yahooqa3", 1.0),
        ("yahooqa4", 1.0),
        ("yahooqa5", 1.0),
    )
)

# Task key -> dataset-identifying token that processors prepend to the input
# text when task_type == "dataset".
dataset_token = dict(
    (
        ("cola", "lingacc"),
        ("boolq", "boolq"),
        ("sst-2", "sst2"),
        ("qqp", "qqp"),
        ("yahooqa", "yahooqa"),
        ("yelp", "yelp"),
        ("event", "event"),
        ("argument", "argument"),
        ("dmarker", "dmarker"),
        ("qnli", "qnli"),
        ("rocstory", "rocstory"),
        ("mnli", "mnli"),
        ("mnli-mm", "mnli"),
        ("scitail", "scitail"),
        ("drelation", "drelation"),
        ("emotion", "emotion"),
        ("agnews", "agnews"),
        ("dbpedia", "dbpedia"),
        ("amzn", "amzn"),
        ("splityahooqa", "yahooqa"),
        ("yahooqa1", "yahooqa1"),
        ("yahooqa2", "yahooqa2"),
        ("yahooqa3", "yahooqa3"),
        ("yahooqa4", "yahooqa4"),
        ("yahooqa5", "yahooqa5"),
    )
)

# Task key -> task-family token that processors prepend to the input text when
# task_type == "task"; related datasets share a token.
task_token = dict(
    (
        ("cola", "lingacc"),
        ("boolq", "qans"),
        ("sst-2", "sentanalysis"),
        ("qqp", "semsimilarity"),
        ("yahooqa", "yahooqa"),
        ("yelp", "sentanalysis"),
        ("event", "factuality"),
        ("argument", "nlinference"),
        ("dmarker", "discourse"),
        ("qnli", "nlinference"),
        ("rocstory", "discourse"),
        ("mnli", "nlinference"),
        ("mnli-mm", "nlinference"),
        ("scitail", "nlinference"),
        ("drelation", "discourse"),
        ("emotion", "emotion"),
        ("agnews", "agnews"),
        ("dbpedia", "dbpedia"),
        ("amzn", "amzn"),
        ("splityahooqa", "yahooqa"),
        ("yahooqa1", "yahooqa"),
        ("yahooqa2", "yahooqa"),
        ("yahooqa3", "yahooqa"),
        ("yahooqa4", "yahooqa"),
        ("yahooqa5", "yahooqa"),
    )
)
