from collections import OrderedDict
import collections
import abc
import functools
from typing import Callable, List, Mapping
from utils import pad_punctuation
from seq2seq.metrics import metrics
from .utils import round_stsb_target
import datasets
import logging
import numpy as np
import torch
import re
import random

# Module-level logger shared by every task class in this file.
logger = logging.getLogger(name=__name__)

# Multiple-choice suffix appended to the inputs of the *_ppt task variants.
ppt_question = (' A. yes. B. maybe. C. no.?'
                ' The correct one is ')

class teacherFewShot(object):
    """Placeholder for a few-shot teacher ensemble; it currently stores nothing."""

    def __init__(self, teachers):
        """Accept ``teachers`` (a list of models) without using it yet."""

class AbstractTask(abc.ABC):
    """Base class for a text-to-text task backed by a HuggingFace ``datasets`` dataset.

    Subclasses fill in the class attributes below and implement
    ``preprocessor`` (and usually ``load_dataset``). ``get`` loads a split,
    optionally subsamples it, and maps every example through ``preprocessor``
    into the dict produced by ``seq2seq_format``.
    """
    name = NotImplemented
    config = NotImplemented
    prefix = NotImplemented
    preprocessor: Callable = NotImplemented
    metric = NotImplemented
    metric_names = NotImplemented
    split_map = None
    labels_list = None
    split_to_data_split: Mapping[str, str] = \
        {"train": "train", "validation": "validation", "test": "test"}
    # With ``split_validation_test`` on a "source" split, these tasks have
    # their validation set halved into validation and test.
    small_datasets_without_all_splits = ["cola", "wnli", "wnli_ppt", "rte", "superglue-cb", "superglue-cb_ppt", "superglue-copa",
                                         "superglue-multirc", "superglue-multirc_ppt", "superglue-wic", "superglue-wic_ppt",
                                         "superglue-wsc.fixed", "superglue-wsc.fixed_ppt", "superglue-rte", "superglue-rte_ppt",
                                         "mrpc", "mrpc_ppt", "stsb", "superglue-boolq", "superglue-boolq_ppt", "xsum", "scitail"]
    # With ``split_validation_test`` on a "source" split, these tasks carve
    # their validation data out of the training set.
    large_data_without_all_splits = ["qqp", "qnli", "superglue-record", "sst2", "squad", "snli", "anli",
                                     "qqp_ppt", "qnli_ppt", "superglue-record_ppt", "sst2_ppt", "squad_ppt",
                                     "amazon_polarity", "yelp_polarity", "winogrande", "newsqa", "searchqa",
                                     "triviaqa", "nq", "hotpotqa"]

    def __init__(self, config, seed=42):
        self.config = config
        self.seed = seed
        # Seeds the global ``random`` module, which preprocessors that sample
        # distractor options rely on.
        random.seed(seed)

    def get_max_target_length(self, tokenizer, default_max_length):
        """Return the longest tokenized label, or ``default_max_length`` if no label set."""
        if self.labels_list is not None:
            return max(len(tokenizer.encode(label)) for label in self.labels_list)
        return default_max_length

    def seq2seq_format(self, sources: List[str],
                       targets: List[str],
                       add_prefix: bool = False,
                       prefix: str = None,
                       extra_fields=None):
        """Space-join ``sources`` / ``targets`` into the example dict used downstream.

        :param add_prefix: prepend ``prefix`` (default: the task name) to the source.
        :param extra_fields: copied into the result; a fresh ``{}`` when omitted.
        :return: dict with keys ``source``, ``target``, ``task`` and ``extra_fields``.
        """
        # A new dict per call: the former ``{}`` default was one object shared
        # by every example that did not pass ``extra_fields``.
        if extra_fields is None:
            extra_fields = {}
        src_prefix = self.name if prefix is None else prefix
        sources = [src_prefix] + sources if add_prefix else sources
        return {'source': ' '.join(sources),
                'target': ' '.join(targets),
                'task': self.name,
                'extra_fields': extra_fields}

    def check_n_obs(self, n_obs, total_size):
        """Clamp ``n_obs`` to ``total_size``, logging a warning when it is reduced."""
        if n_obs is not None and n_obs > total_size:
            n_obs = total_size
            logger.warning("n_obs is set to %s", n_obs)
        return n_obs

    def shuffled_indices(self, dataset):
        """Return a permutation of ``range(len(dataset))`` seeded by ``self.seed``."""
        num_samples = len(dataset)
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return torch.randperm(num_samples, generator=generator).tolist()

    def subsample(self, dataset, n_obs=None, indices=None, per_class=16):
        """
        Given a dataset returns the subsampled dataset.

        :param n_obs: maximum number of samples to keep. When falsy the dataset
            is returned unchanged (``indices`` is not applied).
        :param indices: candidate indices in priority order; when ``None`` or
            empty, a seeded shuffle of the whole dataset is used.
        :param per_class: when the dataset has a ``label`` column with fewer
            than 6 distinct values, keep at most this many samples per label
            (walking ``indices`` in order) -- a balanced few-shot sample.
        :return: subsampled dataset.
        """
        if not n_obs:
            # NOTE(review): precomputed ``indices`` (e.g. from get_split_indices)
            # are ignored here, so validation and train both get the full data
            # when n_obs is unset -- confirm this is intended.
            return dataset
        n_obs = self.check_n_obs(n_obs, len(dataset))
        if not indices:
            indices = self.shuffled_indices(dataset)

        labels = None
        if "label" in (getattr(dataset, "column_names", None) or ()):
            # Read the label column once. The former ``'label' in dataset``
            # iterated every row (comparing the string with each row dict), so
            # it was always False and balancing never ran.
            label_column = dataset["label"]
            labels = set(label_column)

        if not labels or len(labels) >= 6:
            return dataset.select(indices[:n_obs])

        remaining = dict.fromkeys(labels, per_class)
        selected = []
        for idx in indices:
            if len(selected) >= n_obs:
                break
            label = label_column[idx]
            if remaining[label] > 0:
                selected.append(idx)
                remaining[label] -= 1
        return dataset.select(selected)

    def load_dataset(self, split: str):
        """Load ``split`` of the hub dataset ``self.name`` / ``self.config``."""
        return datasets.load_dataset(self.name, self.config, split=split, script_version="master")

    def _load_raw_split(self, split, lang=None, file_name=None, csv_split="train"):
        """Load one raw split; a CSV ``file_name`` wins, then ``lang``, then the hub.

        ``lang`` is forwarded as ``lang_code``, so it only works for subclasses
        whose ``load_dataset`` accepts that keyword.
        """
        if file_name is not None:
            return datasets.load_dataset('csv', data_files=file_name, split=csv_split)
        if lang is not None:
            return self.load_dataset(split=split, lang_code=lang)
        return self.load_dataset(split=split)

    def get_split_indices(self, split, dataset, validation_size):
        """Return seeded, shuffled indices of ``dataset`` belonging to ``split``.

        ``"validation"`` gets the first part of the shuffle and every other
        split gets the rest, so the two parts are disjoint.

        NOTE(review): the ``validation_size`` argument is ignored; the
        validation part is always ``max(int(0.5 * len), 32)`` items. Confirm
        the override is intended before relying on the argument.
        """
        indices = self.shuffled_indices(dataset)
        validation_size = max(int(0.5 * len(indices)), 32)
        if split == "validation":
            return indices[:validation_size]
        return indices[validation_size:]

    def map_dataset(self, dataset, add_prefix):
        """Apply ``preprocessor`` to every example, dropping the original columns."""
        return dataset.map(functools.partial(self.preprocessor, add_prefix=add_prefix),
                           remove_columns=dataset.column_names)

    def get(self, split, add_prefix=True, n_obs=None, split_validation_test=False, lang=None, file_name=None):
        """Load ``split``, optionally subsample it, and map it to seq2seq format.

        Split names containing ``'source'`` are cut at the first ``_`` (e.g.
        a name like ``'train_source'`` becomes ``'train'``) and honour
        ``split_validation_test`` for the task lists declared on the class.
        Any other non-``'test'`` split is carved out of the training data
        (``'+'`` in the name loads ``train+validation``, ``'real'`` loads the
        ``test`` split) via ``get_split_indices`` and subsampled; ``'test'``
        loads ``split_to_data_split['test']`` without subsampling.

        :param lang: forwarded to ``load_dataset`` as ``lang_code``.
        :param file_name: CSV file read instead of the hub dataset; takes
            precedence over ``lang``.
        """
        if 'source' in split:
            logger.info('split >>> %s', split)
            split = split.split('_')[0]
            if split_validation_test and self.name in self.small_datasets_without_all_splits \
                    and "train" not in split:
                # Small datasets (n_samples < 10K) without a test set: split the
                # validation set in half, one half validation, one half test.
                dataset = self._load_raw_split(
                    self.split_to_data_split["validation"], lang, file_name)
                indices = self.get_split_indices(
                    split, dataset, validation_size=len(dataset) // 2)
                dataset = self.subsample(dataset, n_obs, indices)
            elif split_validation_test and self.name in self.large_data_without_all_splits \
                    and "test" not in split:
                # Larger datasets (n_samples > 10K): carve validation out of the
                # training set, keeping the original validation set as test.
                dataset = self._load_raw_split("train", lang, file_name)
                indices = self.get_split_indices(
                    split, dataset, validation_size=1000)
                dataset = self.subsample(dataset, n_obs, indices)
            else:
                dataset = self._load_raw_split(
                    self.split_to_data_split[split], lang, file_name)
                # Shuffles the data and samples it.
                if n_obs is not None:
                    dataset = self.subsample(dataset, n_obs)
            return self.map_dataset(dataset, add_prefix)

        if split != "test":
            if '+' in split:
                raw_split = "train+validation"
            elif 'real' in split:
                raw_split = "test"
            else:
                raw_split = "train"
            dataset = self._load_raw_split(raw_split, lang, file_name)
            indices = self.get_split_indices(
                split, dataset, validation_size=1000)
            # A second subsample() call used to follow; it could only permute
            # the already-sampled rows, so it has been dropped.
            dataset = self.subsample(dataset, n_obs, indices)
        else:
            logger.info('split-> %s', split)
            mapped_split = self.split_to_data_split[split]
            dataset = self._load_raw_split(
                mapped_split, lang, file_name, csv_split=mapped_split)
        res = self.map_dataset(dataset, add_prefix)
        logger.info('%s ---> %s', split, res)
        return res


class Squad(AbstractTask):
    """Extractive QA on SQuAD, framed as generating the answer text."""
    name = "squad"
    metric = [metrics.squad]
    # "test" is served from the validation split.
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}

    def load_dataset(self, split):
        """Load ``split``, retrying without ``script_version`` if that call fails."""
        try:
            return datasets.load_dataset(self.name, split=split, script_version="master")
        except Exception:
            # NOTE(review): presumably for `datasets` versions that reject
            # ``script_version``. A bare ``except`` used to also swallow
            # KeyboardInterrupt/SystemExit.
            return datasets.load_dataset(self.name, split=split)

    def preprocessor(self, example, add_prefix):
        """Build "context: ... question: ... The answer is" -> first gold answer."""
        question = pad_punctuation(example['question'])
        context = pad_punctuation(example['context'])
        source = ["context:", context, "question:", question, 'The answer is ']
        # Must be a list: seq2seq_format space-joins ``targets``, so the bare
        # string previously returned came out as "D e n v e r".
        target = [example['answers']['text'][0]]
        return self.seq2seq_format(source, target, add_prefix)

class Allsix_ppt(AbstractTask):
    """Three-way multiple-choice reformulation of an extractive QA dataset.

    The gold answer is placed at option A, B or C according to ``idx % 3``;
    the other two options are random words drawn from the context.
    """
    name = "allsix"
    metric = [metrics.squad]

    def load_dataset(self, split):
        return datasets.load_dataset(self.name, split=split, script_version="master")

    def preprocessor(self, example, add_prefix):
        # NOTE(review): ``answers`` is assumed to be a tab-separated string of
        # alternatives; the first alternative becomes the gold option.
        answer = pad_punctuation(example['answers']).split("\t")[0]
        question = pad_punctuation(example['question'])
        context = pad_punctuation(example['context'])
        # Fall back to an empty distractor for an empty context instead of
        # crashing; randint(0, len) used to index one past the end.
        words = context.split() or [""]

        def distractor():
            return random.choice(words)

        slot = example['idx'] % 3
        if slot == 0:
            options = (answer, distractor(), distractor())
        elif slot == 1:
            options = (distractor(), answer, distractor())
        else:
            options = (distractor(), distractor(), answer)
        # Append one option string: ``source += str`` previously extended the
        # list character by character, and ``str + list`` raised TypeError.
        source = ["context:", context + question,
                  'A. {}. B. {}. C. {}'.format(*options)]
        target = ["ABC"[slot]]
        return self.seq2seq_format(source, target, add_prefix)

class Squad_ppt(AbstractTask):
    """SQuAD recast as three-way multiple choice over the gold answer and two
    random context words; the answer position is ``idx % 3`` (A/B/C)."""
    name = "squad"
    metric = [metrics.squad]
    # "test" is served from the validation split.
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}

    def load_dataset(self, split):
        """Load ``split``, retrying without ``script_version`` if that call fails."""
        try:
            return datasets.load_dataset(self.name, split=split, script_version="master")
        except Exception:
            # A bare ``except`` used to also swallow KeyboardInterrupt/SystemExit.
            return datasets.load_dataset(self.name, split=split)

    def preprocessor(self, example, add_prefix):
        answer = example['answers']['text'][0]
        question = pad_punctuation(example['question'])
        context = pad_punctuation(example['context'])
        source = ["context:", context + question, "question:"]
        # Position selector built from the digits in the example id
        # (raw string: '\D' in a plain literal is an invalid escape warning).
        idx = int(re.sub(r'\D', '', example['id']))
        words = context.split()

        def distractor():
            # Same randint call as before, in the same order, so the sampled
            # options for a given seed are unchanged.
            return words[random.randint(0, len(words) - 1)]

        if idx % 3 == 0:
            source.append('A. ' + answer + '. B. ' + distractor() + '. C. ' + distractor())
            target = ['A']
        elif idx % 3 == 1:
            source.append('A. ' + distractor() + '. B. ' + answer + '. C. ' + distractor())
            target = ['B']
        else:
            source.append('A. ' + distractor() + '. B. ' + distractor() + '. C. ' + answer)
            target = ['C']
        return self.seq2seq_format(source, target, add_prefix)

class SciTail(AbstractTask):
    """SciTail entailment (SNLI format); targets are "0" (entailment) or "1" (neutral)."""
    name = "scitail"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    split_to_data_split = dict(train="train", validation="validation", test="test")

    def load_dataset(self, split):
        return datasets.load_dataset('scitail', "snli_format", split=split,
                                     script_version="master")

    def preprocessor(self, example, add_prefix=True):
        premise = example['sentence1']
        hypothesis = example["sentence2"]
        target = {"entailment": "0", "neutral": "1"}[example["gold_label"]]
        return self.seq2seq_format(["premise:", premise, "hypothesis:", hypothesis],
                                   [target], add_prefix)
class MRPC(AbstractTask):
    """GLUE MRPC paraphrase detection; targets are the label digits "0"/"1"."""
    name = "mrpc"
    labels_list = ["0", "1"]
    metric = [metrics.f1_score_with_invalid, metrics.accuracy]
    metric_names = ["f1", "accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        first, second = example['sentence1'], example["sentence2"]
        label = str(example['label'])
        return self.seq2seq_format(["sentence1:", first, "sentence2:", second],
                                   [label], add_prefix=add_prefix)

class MRPC_ppt(AbstractTask):
    """MRPC with both sentences merged and the PPT multiple-choice suffix.

    Label meaning: equivalent -> 1 -> yes -> A; not equivalent -> 0 -> no -> C.
    NOTE(review): the targets emitted are the raw label digits "0"/"1", not
    the letters above -- confirm this is intended.
    """
    name = "mrpc"
    labels_list = ["0", "1"]
    metric = [metrics.f1_score_with_invalid, metrics.accuracy]
    metric_names = ["f1", "accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        merged = '. '.join((example['sentence1'], example["sentence2"]))
        src_texts = ["sentence1:", merged, "sentence2:", ppt_question]
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix)

class COLA(AbstractTask):
    """GLUE CoLA acceptability; targets are the label digits "0"/"1"."""
    name = "cola"
    labels_list = ["0", "1"]
    metric = [metrics.matthews_corrcoef]
    metric_names = ["matthews_correlation"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'cola', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        sentence, label = example['sentence'], example['label']
        return self.seq2seq_format(["sentence:", sentence], [str(label)], add_prefix)


class SST2(AbstractTask):
    """GLUE SST-2 sentiment; targets are the label digits "0"/"1"."""
    name = "sst2"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'sst2', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        source = ["sentence:"]
        source.append(example['sentence'])
        target = [str(example['label'])]
        return self.seq2seq_format(source, target, add_prefix=add_prefix)

class SST2_ppt(AbstractTask):
    """SST-2 with the PPT multiple-choice suffix.

    Label mapping: 0 > negative > no > C; 1 > positive > yes > A; any other
    label value maps to B.
    """
    name = "sst2"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'sst2', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        src_texts = ["sentence:", example['sentence'] + ppt_question]
        letter = {0: 'C', 1: 'A'}.get(int(example['label']), 'B')
        return self.seq2seq_format(src_texts, [letter], add_prefix)

class YelpPolarity(AbstractTask):
    """Yelp review polarity; targets are the label digits "0"/"1"."""
    name = "yelp_polarity"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    split_to_data_split = dict(train="train", test="test")

    def load_dataset(self, split):
        # Loads every split, then picks the requested one by key.
        all_splits = datasets.load_dataset('yelp_polarity')
        return all_splits[split]

    def preprocessor(self, example, add_prefix=True):
        text, label = example['text'], example['label']
        return self.seq2seq_format(["sentence:", text], [str(label)], add_prefix)

class Amazon_Polarity(AbstractTask):
    """Amazon review polarity; targets are the label digits "0"/"1"."""
    name = "amazon_polarity"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    split_to_data_split = {"train": "train", "test": "test"}

    def load_dataset(self, split):
        # Previously loaded 'yelp_polarity', whose rows have no 'title'
        # field, so every example raised KeyError in ``preprocessor``.
        return datasets.load_dataset('amazon_polarity', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        # amazon_polarity stores the review body under 'content'; 'context'
        # (the column the previous code read) is kept as a fallback name.
        body = example['content'] if 'content' in example else example['context']
        src_texts = ["sentence:", "<title> {0} <context> {1}".format(
            example['title'], body)]
        tgt_texts = [str(example['label'])]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


class STSB(AbstractTask):
    """GLUE STS-B similarity, cast as generating a rounded score string."""
    name = "stsb"
    # Score strings rounded to one decimal, stepping by 0.2 from 0.
    labels_list = [str(np.round(label, decimals=1))
                   for label in np.arange(0, 5.2, 0.2)]
    metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
    metric_names = ["pearson", "spearmanr"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'stsb', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        src_texts = ["sentence1:", example['sentence1'], "sentence2:", example["sentence2"]]
        rounded = round_stsb_target(example['label'])
        return self.seq2seq_format(src_texts, [str(rounded)], add_prefix)


class QQP(AbstractTask):
    """GLUE QQP duplicate-question detection; targets are "0"/"1"."""
    name = "qqp"
    labels_list = ["0", "1"]
    metric = [metrics.f1_score_with_invalid, metrics.accuracy]
    metric_names = ["f1", "accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'qqp', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        q1, q2 = example['question1'], example["question2"]
        src_texts = ["question1:", q1, "question2:", q2]
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix)

class QQP_ppt(AbstractTask):
    """QQP with both questions merged and the PPT multiple-choice suffix.

    Label mapping: 0 > not same > wrong > C; 1 > same > yes > A; any other
    label value maps to B.
    """
    name = "qqp"
    labels_list = ["0", "1"]
    metric = [metrics.f1_score_with_invalid, metrics.accuracy]
    metric_names = ["f1", "accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'qqp', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        merged = example['question1'] + example["question2"] + ppt_question
        letter = {0: 'C', 1: 'A'}.get(int(example['label']), 'B')
        return self.seq2seq_format(["question1:", merged], [letter], add_prefix)

class MNLI(AbstractTask):
    """GLUE MNLI; targets are the label digits "0"/"1"/"2"."""
    name = "mnli"
    labels_list = ["0", "1", "2"]
    # Validation reads the mismatched set; test reads the matched set.
    split_to_data_split = dict(train="train",
                               validation="validation_mismatched",
                               test="validation_matched")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        premise = example['premise']
        hypothesis = example["hypothesis"]
        target = str(example['label'])
        return self.seq2seq_format(["premise:", premise, "hypothesis:", hypothesis],
                                   [target], add_prefix)

class MNLI_ppt(AbstractTask):
    """MNLI with premise and hypothesis merged and the PPT multiple-choice suffix.

    Label mapping: 0 > yes > A; 1 > maybe > B; any other label > no > C.
    """
    name = "mnli"
    labels_list = ["0", "1", "2"]
    # Validation reads the mismatched set; test reads the matched set.
    split_to_data_split = dict(train="train",
                               validation="validation_mismatched",
                               test="validation_matched")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        merged = example['premise'] + example["hypothesis"] + ppt_question
        letter = {0: 'A', 1: 'B'}.get(int(example['label']), 'C')
        return self.seq2seq_format(["premise:", merged], [letter], add_prefix)

class SNLI(AbstractTask):
    """SNLI; targets are the label digits "0"/"1"/"2".

    NOTE(review): SNLI presumably marks examples without a gold label as -1,
    which would yield the target "-1"; filter them upstream if so.
    """
    name = "snli"
    labels_list = ["0", "1", "2"]
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "test"}
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('snli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        # "hypothesis:" without a trailing space, matching the other NLI tasks;
        # the stray space produced a double space after seq2seq_format's join.
        src_texts = ["premise:", example['premise'],
                     "hypothesis:", example["hypothesis"]]
        tgt_texts = [str(example['label'])]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


class MultiNLI(AbstractTask):
    """MultiNLI via the 'multi_nli' hub dataset; targets are "0"/"1"/"2"."""
    name = "mnli"
    labels_list = ["0", "1", "2"]
    # Validation reads the mismatched set; test reads the matched set.
    split_to_data_split = dict(train="train",
                               validation="validation_mismatched",
                               test="validation_matched")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('multi_nli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        src_texts = ["premise:", example['premise']]
        src_texts += ["hypothesis:", example["hypothesis"]]
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix=add_prefix)


class ANLI(AbstractTask):
    """Adversarial NLI, round 3 splits; targets are "0"/"1"/"2"."""
    name = "anli"
    labels_list = ["0", "1", "2"]
    split_to_data_split = {"train": "train_r3",
                           "validation": "dev_r3",
                           "test": "test_r3"}
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        # Reuse the class-level mapping instead of a duplicated local copy;
        # names not in the mapping (e.g. an already-mapped "train_r3") pass
        # through unchanged.
        mapped_split = self.split_to_data_split.get(split, split)
        return datasets.load_dataset('anli', split=mapped_split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        src_texts = ["premise:", example['premise'],
                     "hypothesis:", example["hypothesis"]]
        tgt_texts = [str(example['label'])]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)

class QNLI(AbstractTask):
    """GLUE QNLI; targets are the label digits "0"/"1"."""
    name = "qnli"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        question, sentence = example['question'], example["sentence"]
        return self.seq2seq_format(["question:", question, "sentence:", sentence],
                                   [str(example['label'])], add_prefix)

class QNLI_ppt(AbstractTask):
    """QNLI with sentence and question merged and the PPT multiple-choice suffix.

    Label mapping: 0 > entailment > yes > A; 1 > not_entailment > no > C;
    any other label value maps to B.
    """
    name = "qnli"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        merged = example["sentence"] + example['question'] + ppt_question
        letter = {0: 'A', 1: 'C'}.get(int(example['label']), 'B')
        return self.seq2seq_format(["sentence:", merged], [letter], add_prefix)

class RTE(AbstractTask):
    """GLUE RTE entailment; targets are the label digits "0"/"1"."""
    name = "rte"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'rte', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        pairs = (("sentence1:", example['sentence1']), ("sentence2:", example["sentence2"]))
        src_texts = [part for pair in pairs for part in pair]
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix)


class WNLI(AbstractTask):
    """GLUE WNLI; targets are the label digits "0"/"1"."""
    name = "wnli"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # "test" is served from the validation split.
    split_to_data_split = dict(train="train", validation="validation", test="validation")

    def load_dataset(self, split):
        return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        first = example['sentence1']
        second = example["sentence2"]
        label_text = str(example['label'])
        return self.seq2seq_format(["sentence1:", first, "sentence2:", second],
                                   [label_text], add_prefix=add_prefix)

class WNLI_ppt(AbstractTask):
    """WNLI in the multiple-choice "ppt" prompt format.

    Both sentences are joined with '.' into the first field and the shared
    ``ppt_question`` suffix is supplied as the second field. The target is the
    raw label string ("0"/"1"). The intended letter mapping
    (not_entailment -> 0 -> no -> C, entailment -> 1 -> yes -> A) is NOT
    applied here.
    """
    name = "wnli"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")

    def load_dataset(self, split):
        """Load one split of GLUE/WNLI."""
        return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Build the ppt prompt; the target is the unmapped label."""
        joined = '.'.join((example['sentence1'], example["sentence2"]))
        src_texts = ["sentence1:", joined, "sentence2:", ppt_question]
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix)

class SuperGLUEBoolQ(AbstractTask):
    """SuperGLUE BoolQ: yes/no question answering over a passage."""
    name = "superglue-boolq"
    labels_list = ['0', '1']
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")

    def load_dataset(self, split):
        """Load one split of SuperGLUE/BoolQ."""
        return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``question: ... passage: ...`` with the label as target."""
        src_texts = []
        for field in ("question", "passage"):
            src_texts.extend((field + ":", example[field]))
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)

class SuperGLUEBoolQ_ppt(AbstractTask):
    """BoolQ in the multiple-choice "ppt" prompt format.

    The passage and the question are merged (joined by ' .') into the first
    field and the shared ``ppt_question`` suffix is supplied as the second
    field. The target is the raw label string ('1' for True, '0' for False);
    mapping to the letters A (yes) / C (no) is not applied here.
    """
    name = "superglue-boolq"
    labels_list = ['0', '1']
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")

    def load_dataset(self, split):
        """Load one split of SuperGLUE/BoolQ."""
        return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Build the ppt prompt; the target is the unmapped label."""
        # Label notes: True -> 1 -> 209 -> A -> 71; False -> 0 -> 3 -> 205
        # (the numbers are presumably tokenizer ids).
        # NOTE(review): the "question:"/"passage:" tags do not match the text
        # that follows them (passage+question, then the ppt suffix). They are
        # kept because changing them changes the model input.
        merged = ' .'.join((example["passage"], example["question"]))
        src_texts = ["question:", merged, "passage:", ppt_question]
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)

class SuperGLUERTE(AbstractTask):
    """SuperGLUE RTE: premise/hypothesis classification with targets '0'/'1'."""
    name = "superglue-rte"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/RTE."""
        return datasets.load_dataset('super_glue', 'rte', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``premise: ... hypothesis: ...`` with the label as target."""
        src_texts = []
        for field in ("premise", "hypothesis"):
            src_texts += [field + ":", example[field]]
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)

class SuperGLUERTE_ppt(AbstractTask):
    """SuperGLUE RTE in the multiple-choice "ppt" prompt format."""
    name = "superglue-rte"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/RTE."""
        return datasets.load_dataset('super_glue', 'rte', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        '''
        Build the ppt prompt: premise and hypothesis joined by ' .', followed by
        the ``ppt_question`` suffix as a separate field.

        Label notes (the letter mapping is NOT applied; the target is the raw label):
        0 -> entailment -> yes -> A
        1 -> not entailment -> no -> C
        '''
        merged = ' .'.join((example["premise"], example["hypothesis"]))
        src_texts = ["premise:", merged, "hypothesis:", ppt_question]
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)


class SuperGLUECB(AbstractTask):
    """SuperGLUE CommitmentBank: 3-way premise/hypothesis classification."""
    name = "superglue-cb"
    labels_list = ['0', '1', '2']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
    metric_names = ["f1_multiclass", "accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/CB."""
        return datasets.load_dataset('super_glue', 'cb', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``premise: ... hypothesis: ...`` with the label as target."""
        premise, hypothesis = example["premise"], example["hypothesis"]
        src_texts = ["premise:", premise, "hypothesis:", hypothesis]
        target = str(example["label"])
        return self.seq2seq_format(src_texts, [target], add_prefix)

class SuperGLUECB_ppt(AbstractTask):
    """SuperGLUE CommitmentBank in the multiple-choice "ppt" prompt format."""
    name = "superglue-cb"
    labels_list = ['0', '1', '2']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
    metric_names = ["f1_multiclass", "accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/CB."""
        return datasets.load_dataset('super_glue', 'cb', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        '''
        Render ``premise: ... hypothesis: ...`` followed by the ``ppt_question``
        suffix. The target is the raw label; the letter mapping is NOT applied.

        Label notes:
        entailment -> 0 -> yes -> A
        contradiction -> 1 -> no -> C
        neutral -> (presumably 2) -> maybe -> B
        '''
        src_texts = []
        for field in ("premise", "hypothesis"):
            src_texts.extend((field + ":", example[field]))
        src_texts.append(ppt_question)
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)


class SuperGLUECOPA(AbstractTask):
    """SuperGLUE COPA: choose which of two alternatives fits the premise."""
    name = "superglue-copa"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/COPA."""
        return datasets.load_dataset('super_glue', 'copa', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``premise: ... choice1: ... choice2: ...`` -> label."""
        src_texts = []
        for field in ("premise", "choice1", "choice2"):
            src_texts.extend((field + ":", example[field]))
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)


class SuperGLUEMultiRC(AbstractTask):
    """SuperGLUE MultiRC: judge each candidate answer to a paragraph question."""
    name = "superglue-multirc"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.multirc_f1_over_all_answers,
              metrics.mean_group_metric(metrics.exact_match)]
    metric_names = ["f1", "em"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/MultiRC."""
        return datasets.load_dataset('super_glue', 'multirc', split=split, script_version="master")

    def remove_markup(self, text):
        """Replace ``<br>`` with a space and drop ``<b>``/``</b>`` tags."""
        without_breaks = re.sub('<br>', ' ', text)
        return re.sub('<(/)?b>', '', without_breaks)

    def preprocessor(self, example, add_prefix=True):
        """Render question/answer/paragraph, grouped by the question index.

        T5 applies remove_markup to the joined string instead; per-field
        cleaning should give the same result. See
        https://github.com/google-research/text-to-text-transfer-transformer/blob/a1352e625db7ec114062f99d99b0565b9e45c155/t5/data/preprocessors.py#L797
        """
        src_texts = []
        for field in ("question", "answer", "paragraph"):
            src_texts.extend((field + ":", self.remove_markup(example[field])))
        extra = {"group": example['idx']['question']}
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix,
                                   extra_fields=extra)

class SuperGLUEMultiRC_ppt(AbstractTask):
    """SuperGLUE MultiRC in the multiple-choice "ppt" prompt format."""
    name = "superglue-multirc"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.multirc_f1_over_all_answers,
              metrics.mean_group_metric(metrics.exact_match)]
    metric_names = ["f1", "em"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/MultiRC."""
        return datasets.load_dataset('super_glue', 'multirc', split=split, script_version="master")

    def remove_markup(self, text):
        """Replace ``<br>`` with a space and drop ``<b>``/``</b>`` tags."""
        without_breaks = re.sub('<br>', ' ', text)
        return re.sub('<(/)?b>', '', without_breaks)

    def preprocessor(self, example, add_prefix=True):
        '''
        Render question/answer/paragraph with the ``ppt_question`` suffix
        appended to the answer. The target is the raw label; the letter mapping
        below is NOT applied.

        Label notes:
        0 > False > 3 > no > C > 205
        1 > True > 209 > yes > A > 71
        '''
        clean = self.remove_markup
        src_texts = [
            "question:", clean(example["question"]),
            "answer:", clean(example["answer"] + ppt_question),
            "paragraph:", clean(example["paragraph"]),
        ]
        extra = {"group": example['idx']['question']}
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix,
                                   extra_fields=extra)

class SuperGLUEWIC(AbstractTask):
    """SuperGLUE WiC: is ``word`` used with the same sense in both sentences?"""
    name = "superglue-wic"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/WiC."""
        return datasets.load_dataset('super_glue', 'wic', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``sentence1: ... sentence2: ... word: ...`` -> label."""
        src_texts = []
        for field in ("sentence1", "sentence2", "word"):
            src_texts += [field + ":", example[field]]
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)

class SuperGLUEWIC_ppt(AbstractTask):
    """SuperGLUE WiC registered under the "ppt" task name.

    NOTE(review): this currently produces exactly the same examples as
    SuperGLUEWIC; ``ppt_question`` is not appended. Confirm that this is intended.
    """
    name = "superglue-wic"
    labels_list = ['0', '1']
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/WiC."""
        return datasets.load_dataset('super_glue', 'wic', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``sentence1: ... sentence2: ... word: ...`` -> label."""
        s1, s2, word = example["sentence1"], example["sentence2"], example["word"]
        src_texts = ["sentence1:", s1, "sentence2:", s2, "word:", word]
        return self.seq2seq_format(src_texts, [str(example["label"])], add_prefix)

class SuperGLUEWSCFixed(AbstractTask):
    """SuperGLUE WSC (fixed) as a plain-text statement-verification task.

    Each example has a sentence ``text`` and two spans: ``span1_text`` (a noun
    phrase) and ``span2_text`` (a pronoun). The source is the sentence, a space,
    and the statement "<span1_text> is <span2_text>". The target is the label
    as a string. For example
     {
         'text': 'This is a test sentence .',
         'span1_text': 'test',
         'span1_index': 3,
         'span2_text': 'This',
         'span2_index': 0,
         'label': 0
     }
    becomes ``src_texts = ["text:", "This is a test sentence . test is This"]``
    and ``tgt_texts = ["0"]``.

    ``_mark_span`` implements T5's alternative encoding, which wraps the noun in
    ' * ' and the pronoun in ' # '. It was adapted from
    https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py
    but ``preprocessor`` does not call it.
    """
    name = "superglue-wsc.fixed"
    labels_list = ['0', '1']
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('super_glue', 'wsc.fixed', split=split, script_version="master")

    def _mark_span(self, text, span_str, span_idx, mark):
        """Surround ``span_str`` with ``mark`` when it starts at word ``span_idx``.

        The span is regex-escaped so punctuation inside it (e.g. '.', '(' or a
        backslash) is matched literally. Unescaped, it would be interpreted as
        regex syntax or as a replacement-string escape.
        """
        pattern = r'^((?:\S+\s){%s})(%s)' % (span_idx, re.escape(span_str))
        return re.sub(pattern, r'\1{0} \2 {0}'.format(mark), text)

    def preprocessor(self, example, add_prefix=True):
        """Build ``"<text> <span1_text> is <span2_text>"`` with the label as target."""
        text = example['text']
        statement = example['span1_text'] + ' is ' + example['span2_text']
        # Separate the sentence from the statement. Without a space the
        # sentence's final token fuses with span1 (e.g. "truthful.Mark").
        src_texts = ["text:", text + ' ' + statement]
        tgt_texts = [str(example["label"])]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)

class SuperGLUEWSCFixed_ppt(AbstractTask):
    """SuperGLUE WSC (fixed) in the multiple-choice "ppt" prompt format.

    The source is the sentence ``text``, a space, the statement
    "<span1_text> is <span2_text>", and then the shared ``ppt_question``
    suffix. The target is the raw label as a string; no letter mapping is
    applied. For example
     {
         'text': 'This is a test sentence .',
         'span1_text': 'test',
         'span1_index': 3,
         'span2_text': 'This',
         'span2_index': 0,
         'label': 0
     }
    becomes ``src_texts = ["text:", "This is a test sentence . test is This"
    + ppt_question]`` and ``tgt_texts = ["0"]``.

    ``_mark_span`` implements T5's alternative encoding, which wraps the noun in
    ' * ' and the pronoun in ' # '. It was adapted from
    https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py
    but ``preprocessor`` does not call it.
    """
    name = "superglue-wsc.fixed"
    labels_list = ['0', '1']
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('super_glue', 'wsc.fixed', split=split, script_version="master")

    def _mark_span(self, text, span_str, span_idx, mark):
        """Surround ``span_str`` with ``mark`` when it starts at word ``span_idx``.

        The span is regex-escaped so punctuation inside it (e.g. '.', '(' or a
        backslash) is matched literally. Unescaped, it would be interpreted as
        regex syntax or as a replacement-string escape.
        """
        pattern = r'^((?:\S+\s){%s})(%s)' % (span_idx, re.escape(span_str))
        return re.sub(pattern, r'\1{0} \2 {0}'.format(mark), text)

    def preprocessor(self, example, add_prefix=True):
        """Build the ppt prompt for one WSC example; the target is the raw label."""
        text = example['text']
        statement = example['span1_text'] + ' is ' + example['span2_text']
        # Separate the sentence from the statement. Without a space the
        # sentence's final token fuses with span1 (e.g. "truthful.Mark").
        src_texts = ["text:", text + ' ' + statement + ppt_question]
        tgt_texts = [str(example["label"])]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


class SuperGLUERecord(AbstractTask):
    """Convert ReCoRD examples to text2text examples.

    ReCoRD contains a passage, a query containing a '@placeholder' string, and a
    set of entities that are the possible values of the placeholder. Each train
    and validation example has a list of answers, any of which is considered
    correct. The source is emitted once per answer, and each copy uses one answer
    as its target. For example
    {
      'passage': 'This is the passage.',
      'query': 'A @placeholder is a bird.',
      'entities': ['penguin', 'potato', 'pigeon'],
      'answers': ['penguin', 'pigeon'],
    }
    becomes two examples sharing the source
      'record query: A @placeholder is a bird. entities: penguin, potato, '
      'pigeon passage: This is the passage.. The answer is '
    with targets 'penguin' and 'pigeon'. Rows without answers yield a single
    example whose target is "<unk>".
    """
    name = "superglue-record"
    # The "test" split is served from the validation data.
    split_to_data_split = dict(train="train", validation="validation",
                               test="validation")
    metric = [metrics.squad]
    metric_names = ["squad"]

    def load_dataset(self, split):
        """Load one split of SuperGLUE/ReCoRD."""
        return datasets.load_dataset('super_glue', 'record', split=split, script_version="master")

    def preprocessor(self, batch, add_prefix=True):
        """Expand a batched mapping of ReCoRD columns into one row per answer."""
        out = collections.defaultdict(list)
        columns = list(batch.keys())
        for row in zip(*batch.values()):
            ex = dict(zip(columns, row))
            # Flatten the "@highlight" bullet list into running text.
            passage = re.sub(r'(\.|\?|\!|\"|\')\n@highlight\n', r'\1 ', ex['passage'])
            passage = re.sub(r'\n@highlight\n', '. ', passage)
            entities = ', '.join(ex['entities'])
            prompt = f"record query: {ex['query']} entities: {entities} passage: {passage}. The answer is "
            if add_prefix:
                prompt = f"{self.name} {prompt}"
            answers = ex["answers"]
            copies = max(1, len(answers))
            out["source"].extend([prompt] * copies)
            out["target"].extend(answers if len(answers) > 0 else ["<unk>"])
            out["task"].extend([self.name] * copies)
            out["extra_fields"].extend([{"answers": answers}] * copies)
        return out

    def map_dataset(self, dataset, add_prefix=True):
        """Apply the batched preprocessor and drop the original columns."""
        to_text = functools.partial(self.preprocessor, add_prefix=add_prefix)
        return dataset.map(to_text, batched=True, remove_columns=dataset.column_names)

class SuperGLUERecord_ppt(AbstractTask):
    """Convert ReCoRD examples to three-way lettered multiple-choice examples.

    ReCoRD contains a passage, a query containing a '@placeholder' string, a set
    of candidate entities, and a list of gold answers. For every gold answer this
    preprocessor emits one example:

    * The source is the record prompt followed by three options " A. x. B. y. C. z.".
      One option is the gold answer, placed at a random position. The other two
      are distractors drawn with replacement (so they may coincide) from the
      entities that are not gold answers.
    * The target is the letter ('A', 'B' or 'C') of the gold option.

    For example, with
    {
      'passage': 'This is the passage.',
      'query': 'A @placeholder is a bird.',
      'entities': ['penguin', 'potato', 'pigeon'],
      'answers': ['penguin', 'pigeon'],
    }
    one of the two emitted examples could be (positions are random)
    {
      'source': 'superglue-record record query: A @placeholder is a bird. '
                'entities: penguin, potato, pigeon passage: This is the passage. '
                'A. penguin. B. potato. C. potato.',
      'target': 'A',
    }
    Rows without gold answers yield one example with the bare prompt and the
    target "<unk>". If every entity is a gold answer, the distractor slots are
    filled with "<unk>".
    """
    name = "superglue-record"
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metric = [metrics.squad]
    metric_names = ["squad"]

    def load_dataset(self, split):
        return datasets.load_dataset('super_glue', 'record', split=split, script_version="master")

    def preprocessor(self, batch, add_prefix=True):
        """Expand a batch of ReCoRD rows into lettered multiple-choice examples.

        Args:
            batch: a batched mapping (column name -> list of values) containing
                at least ``passage``, ``query``, ``entities`` and ``answers``.
            add_prefix: when True, prepend ``self.name`` to every source string.

        Returns:
            A dict of equal-length lists ``source``, ``target``, ``task`` and
            ``extra_fields``. There is one entry per gold answer, or one entry for
            a row without answers.
        """
        new_batch = collections.defaultdict(list)
        keys = batch.keys()
        for values in zip(*batch.values()):
            ex = dict(zip(keys, values))
            # Flatten the "@highlight" bullet list into running text.
            passage = re.sub(
                r'(\.|\?|\!|\"|\')\n@highlight\n', r'\1 ', ex['passage'])
            passage = re.sub(r'\n@highlight\n', '. ', passage)
            inputs = f"record query: {ex['query']} entities: {', '.join(ex['entities'])} passage: {passage}"
            if add_prefix:
                inputs = self.name + " " + inputs

            answers = ex["answers"]
            # Distractors are the candidate entities that are not gold answers.
            # They are computed once per row. When every entity is a gold answer
            # there is nothing to draw from, and random.randint(0, -1) would raise
            # and abort the whole map, so fall back to a placeholder.
            distractors = [e for e in ex['entities'] if e not in answers] or ["<unk>"]
            sources, targets = [], []
            for answer in answers:
                gold_idx = random.randint(0, 2)
                d0 = distractors[random.randint(0, len(distractors) - 1)]
                d1 = distractors[random.randint(0, len(distractors) - 1)]
                options = ((answer, d0, d1), (d0, answer, d1), (d1, d0, answer))[gold_idx]
                # Every branch uses the same " A. x. B. y. C. z." layout.
                sources.append(inputs + ''.join(
                    f" {letter}. {option}." for letter, option in zip("ABC", options)))
                targets.append("ABC"[gold_idx])
            if not sources:
                # No gold answer: keep the bare prompt with an "<unk>" target.
                sources.append(inputs)
                targets.append("<unk>")

            new_batch["source"].extend(sources)
            new_batch["target"].extend(targets)
            new_batch["task"].extend([self.name] * len(sources))
        # extra_fields is an empty dict for every emitted example.
        # NOTE(review): confirm downstream metrics do not need per-example answers.
        new_batch["extra_fields"] = [{}] * len(new_batch["task"])
        return new_batch

    def map_dataset(self, dataset, add_prefix=True):
        return dataset.map(functools.partial(self.preprocessor, add_prefix=add_prefix),
                           batched=True, remove_columns=dataset.column_names)


class IMDB(AbstractTask):
    """IMDB movie-review sentiment classification (targets '0'/'1')."""
    name = "imdb"
    labels_list = ['0', '1']
    # Only train/validation are mapped; validation reads the IMDB "test" data.
    split_to_data_split = dict(train="train", validation="test")
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        """Load an IMDB split, serving "validation" from the "test" data."""
        hf_split = "test" if split == "validation" else split
        return datasets.load_dataset('imdb', split=hf_split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``sentence: <review>`` with the label as target."""
        review = example['text']
        label_text = str(example['label'])
        return self.seq2seq_format(["sentence:", review], [label_text], add_prefix)


class WinoGrande(AbstractTask):
    """WinoGrande (XL): pick which of two options fills the blank in a sentence.

    The dataset's ``option1``/``option2`` fields are shown as ``option0:`` and
    ``option1:``, and the answer is shifted to a 0-based target
    (``int(answer) - 1``). Presumably the answer is "1" or "2"; verify this
    against the dataset card.
    """
    name = "winogrande"
    labels_list = ['0', '1']
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]

    def load_dataset(self, split):
        return datasets.load_dataset('winogrande', "winogrande_xl", split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        # Both candidates must be shown. Previously "option1" was emitted twice,
        # so the second candidate never reached the model.
        src_texts = ["sentence:", example["sentence"],
                     "option0:", example["option1"],
                     "option1:", example["option2"]]
        tgt_texts = [str(int(example["answer"]) - 1)]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


class PAWS(AbstractTask):
    """PAWS (labeled_final): paraphrase identification with targets "0"/"1"."""
    name = "paws"
    labels_list = ["0", "1"]
    metric = [metrics.accuracy]
    metric_names = ["accuracy"]
    split_to_data_split = dict(train="train", validation="validation", test="test")

    def load_dataset(self, split):
        """Load one split of PAWS ``labeled_final``."""
        return datasets.load_dataset('paws', 'labeled_final', split=split, script_version="master")

    def preprocessor(self, example, add_prefix=True):
        """Render ``sentence1: ... sentence2: ...`` with the label as target."""
        src_texts = []
        for field in ('sentence1', 'sentence2'):
            src_texts.extend((field + ":", example[field]))
        return self.seq2seq_format(src_texts, [str(example['label'])], add_prefix)


# Registry of task name -> task class, consumed by AutoTask.get. Each "_ppt"
# name maps to its prompt-format class. 'superglue-record_ppt' (the name used in
# AbstractTask's split lists) is an alias of the older 'record_ppt' key.
TASK_MAPPING = OrderedDict(
    [
        ('squad', Squad),
        ('squad_ppt', Squad_ppt),
        ('mrpc', MRPC),
        ('mrpc_ppt', MRPC_ppt),
        ('cola', COLA),
        ('sst2', SST2),
        ('sst2_ppt', SST2_ppt),
        ('qnli', QNLI),
        ('qnli_ppt', QNLI_ppt),
        ('rte', RTE),
        ('wnli', WNLI),
        ('wnli_ppt', WNLI_ppt),
        ('mnli', MNLI),
        ('mnli_ppt', MNLI_ppt),
        ('qqp', QQP),
        ('qqp_ppt', QQP_ppt),
        ('stsb', STSB),
        ('superglue-boolq', SuperGLUEBoolQ),
        ('superglue-boolq_ppt', SuperGLUEBoolQ_ppt),
        ('superglue-rte', SuperGLUERTE),
        ('superglue-rte_ppt', SuperGLUERTE_ppt),
        ('superglue-cb', SuperGLUECB),
        ('superglue-cb_ppt', SuperGLUECB_ppt),
        ('superglue-copa', SuperGLUECOPA),
        ('superglue-multirc', SuperGLUEMultiRC),
        ('superglue-multirc_ppt', SuperGLUEMultiRC_ppt),
        ('superglue-wic', SuperGLUEWIC),
        ('superglue-wic_ppt', SuperGLUEWIC_ppt),
        ('superglue-wsc.fixed', SuperGLUEWSCFixed),
        ('superglue-wsc.fixed_ppt', SuperGLUEWSCFixed_ppt),
        ('superglue-record', SuperGLUERecord),
        ('record_ppt', SuperGLUERecord_ppt),
        ('superglue-record_ppt', SuperGLUERecord_ppt),
        ('multi_nli', MultiNLI),
        ('snli', SNLI),
        ('newsqa', Squad),
        ('searchqa', Squad),
        ('triviaqa', Squad),
        ('nq', Squad),
        ('hotpotqa', Squad),
        ('anli', ANLI),
        ("imdb", IMDB),
        ("winogrande", WinoGrande),
        ("scitail", SciTail),
        ('yelp_polarity', YelpPolarity),
        ('amazon_polarity', Amazon_Polarity),
        ('paws', PAWS),
    ]
)


class AutoTask:
    """Factory that instantiates task objects by name from ``TASK_MAPPING``."""

    @classmethod
    def get(cls, task, config, seed=42):
        """Return ``TASK_MAPPING[task](config, seed)``.

        Args:
            task: a key of ``TASK_MAPPING``.
            config: passed through to the task constructor.
            seed: passed through to the task constructor.

        Raises:
            ValueError: if ``task`` is not registered in ``TASK_MAPPING``.
        """
        if task in TASK_MAPPING:
            return TASK_MAPPING[task](config, seed)
        # The message previously had three placeholders but received one
        # argument, so str.format raised IndexError instead of this ValueError.
        raise ValueError(
            "Unrecognized task {} for AutoTask.\n"
            "Task name should be one of {}.".format(
                task, ", ".join(TASK_MAPPING.keys())
            )
        )
