# huggingface datasets
import datasets
import random
from collections import Counter
import os


# Task names containing "***" are "<dataset>***<config>" pairs; the task
# loaders split on "***" and pass both parts to datasets.load_dataset.

# Row fields exposed for each task. DefaultNLPTask.get_valid_keys falls back
# to every column of the first training row for tasks not listed here.
VALID_KEY_LIST = {
	"app_reviews": ["review", "star"],
	"aslg_pc12": ["text", "gloss"],
	"break_data***QDMR": ["question_text", "decomposition"],
	"break_data***QDMR-high-level": ["question_text", "decomposition"],
	"e2e_nlg_cleaned": ["meaning_representation", "human_reference"],
	"gigaword": ["document", "summary"],
	"jeopardy": ["question", "category", "answer"],
	"multi_news": ["document", "summary"],
	"reddit_tifu***long": ["documents", "tldr", "title"],
	"samsum": ["dialogue", "summary"],
	"scitail***snli_format": ["sentence1", "sentence2", "gold_label"],
	"search_qa***train_test_val": ["question", "category", "answer"],
	"spider": ["question", "query"],
	"xsum": ["document", "summary"],
	"wiki_split": ["simple_sentence_1", "simple_sentence_2", "complex_sentence"],
	"ade_corpus_v2***Ade_corpus_v2_drug_dosage_relation": ["text", "drug", "dosage"],
	"ade_corpus_v2***Ade_corpus_v2_drug_ade_relation": ["text", "drug", "effect"],
}

# Subset of each task's valid fields reported by get_nlu_keys
# (unlisted tasks report no NLU keys).
NLU_KEY_LIST = {
	"app_reviews": ["star"],
	"aslg_pc12": ["text"],
	"break_data***QDMR": ["decomposition"],
	"break_data***QDMR-high-level": ["decomposition"],
	"e2e_nlg_cleaned": ["human_reference"],
	"gigaword": ["summary"],
	"jeopardy": ["answer"],
	"multi_news": ["summary"],
	"reddit_tifu***long": ["tldr", "title"],
	"samsum": ["summary"],
	"scitail***snli_format": ["gold_label"],
	"search_qa***train_test_val": ["answer"],
	# NOTE(review): the break_data "text to data" entries list their second
	# valid field, but spider lists its first ("question") — confirm intended.
	"spider": ["question"],
	"xsum": ["summary"],
	"wiki_split": ["complex_sentence"],
	"ade_corpus_v2***Ade_corpus_v2_drug_dosage_relation": ["dosage"],
	"ade_corpus_v2***Ade_corpus_v2_drug_ade_relation": ["effect"],
}

# Task category per task; DefaultNLPTask.get_task_type falls back to
# "NLP Task" for tasks not listed here.
TASK_TYPE = {
	"app_reviews": "sentence classification",
	"aslg_pc12": "text generation",
	"break_data***QDMR": "text to data",
	"break_data***QDMR-high-level": "text to data",
	"e2e_nlg_cleaned": "data to text",
	"gigaword": "text summarization",
	"jeopardy": "question and answer",
	"multi_news": "text summarization",
	"reddit_tifu***long": "text summarization",
	"samsum": "text summarization",
	"scitail***snli_format": "entailment",
	"search_qa***train_test_val": "question and answer",
	"spider": "text to data",
	"xsum": "text summarization",
	"wiki_split": "text generation",
	"ade_corpus_v2***Ade_corpus_v2_drug_dosage_relation": "text generation",
	"ade_corpus_v2***Ade_corpus_v2_drug_ade_relation": "text generation",
}

# Directory created by the loaders and passed as cache_dir to datasets.load_dataset.
CACHE_DIR = "huggingface_dataset_meta_task"

class DefaultNLPTask:
	"""Task backed by a HuggingFace dataset's own train and validation splits.

	Subclasses customise loading through ``prune``, ``discard_data`` and
	``process_split_name``, and rendering through ``get_valid_keys`` and
	``get_value_from_key``.
	"""

	def __init__(self, task_name):
		self.task_name = task_name

	def load_data(self):
		"""Load the dataset and populate ``train_data`` / ``val_data``.

		``train_data`` is the pruned train split without discarded rows.
		``val_data`` is the first 5% of the pruned validation split, again
		without discarded rows.
		"""
		os.makedirs(CACHE_DIR, exist_ok=True)
		if "***" in self.task_name:
			# "<dataset>***<config>" selects a dataset configuration.
			dataset_name, config_name = self.task_name.split("***")[:2]
			full_data = datasets.load_dataset(dataset_name, config_name, cache_dir=CACHE_DIR)
		else:
			full_data = datasets.load_dataset(self.task_name, cache_dir=CACHE_DIR)

		# Train is processed before validation so any randomness in prune()
		# is consumed in the same order as before.
		train_rows = self.prune(list(full_data[self.process_split_name("train")]), "train")
		self.train_data = [row for row in train_rows if not self.discard_data(row)]

		val_rows = self.prune(list(full_data[self.process_split_name("validation")]), "validation")
		val_rows = val_rows[:int(0.05 * len(val_rows))]
		self.val_data = [row for row in val_rows if not self.discard_data(row)]
		self.pre_load_train_data = None

	def get_nlu_keys(self):
		"""Return the NLU keys registered for this task, or an empty list."""
		return NLU_KEY_LIST.get(self.task_name, [])

	def get_task_type(self):
		"""Return the registered task category, or "NLP Task" if unknown."""
		return TASK_TYPE.get(self.task_name, "NLP Task")

	def prune(self, data_list, split):
		"""Hook for subsampling a split; the default keeps every row."""
		return data_list

	def get_split(self, split):
		"""Return training rows for "train", validation rows for anything else."""
		if split == "train":
			return self.train_data
		return self.val_data

	def discard_data(self, query_datapoint):
		"""Hook for dropping individual rows; the default keeps all of them."""
		return False

	def process_split_name(self, split):
		"""Hook mapping a logical split name to the dataset's split name."""
		return split

	def get_valid_keys(self):
		"""Return the registered field list, else the first training row's keys."""
		if self.task_name not in VALID_KEY_LIST:
			return list(self.train_data[0].keys())
		return VALID_KEY_LIST[self.task_name]

	def get_value_from_key(self, query_datapoint, key):
		"""Return ``query_datapoint[key]`` as a whitespace-stripped string."""
		value = query_datapoint[key]
		return str(value).strip()

class TrainSplitNLPTask:
	"""Task that carves both train and validation rows out of the train split.

	After pruning and discarding, the train split is shuffled with
	``random.shuffle``; the first 98% becomes ``train_data`` and the rest
	becomes ``val_data``.
	"""

	def __init__(self, task_name):
		self.task_name = task_name

	def load_data(self):
		"""Load the dataset and split its (processed) train split 98/2."""
		os.makedirs(CACHE_DIR, exist_ok=True)
		if "***" in self.task_name:
			# "<dataset>***<config>" selects a dataset configuration.
			dataset_name, config_name = self.task_name.split("***")[:2]
			raw = datasets.load_dataset(dataset_name, config_name, cache_dir=CACHE_DIR)
		else:
			raw = datasets.load_dataset(self.task_name, cache_dir=CACHE_DIR)

		train_split = raw[self.process_split_name("train")]
		rows = [row for row in self.prune(list(train_split), "train") if not self.discard_data(row)]
		random.shuffle(rows)
		cut = int(len(rows) * 0.98)
		self.train_data, self.val_data = rows[:cut], rows[cut:]
		self.pre_load_train_data = None

	def get_nlu_keys(self):
		"""Return the NLU keys registered for this task, or an empty list."""
		return NLU_KEY_LIST.get(self.task_name, [])

	def get_task_type(self):
		"""Return the registered task category, or "NLP Task" if unknown."""
		return TASK_TYPE.get(self.task_name, "NLP Task")

	def prune(self, data_list, split):
		"""Hook for subsampling a split; the default keeps every row."""
		return data_list

	def get_split(self, split):
		"""Return training rows for "train", validation rows for anything else."""
		if split == "train":
			return self.train_data
		return self.val_data

	def discard_data(self, query_datapoint):
		"""Hook for dropping individual rows; the default keeps all of them."""
		return False

	def process_split_name(self, split):
		"""Hook mapping a logical split name to the dataset's split name."""
		return split

	def get_valid_keys(self):
		"""Return the registered field list, else the first training row's keys."""
		if self.task_name not in VALID_KEY_LIST:
			return list(self.train_data[0].keys())
		return VALID_KEY_LIST[self.task_name]

	def get_value_from_key(self, query_datapoint, key):
		"""Return ``query_datapoint[key]`` as a whitespace-stripped string."""
		value = query_datapoint[key]
		return str(value).strip()


class AdeClassification(TrainSplitNLPTask):
	"""Binary sentence classification with labels "Not Related" / "Related"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "Not Related", 1: "Related"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the raw text for "text"; any other key is mapped through ``self.label``."""
		value = query_datapoint[key]
		if key != "text":
			return self.label[value]
		return str(value)

class AdversarialQA(DefaultNLPTask):
	"""Question answering; the target is the first entry of ``answers["text"]``."""

	def get_valid_keys(self):
		return ["question", "context", "answers"]

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "answers":
			return str(query_datapoint[key])
		answer_texts = query_datapoint[key]["text"]
		return answer_texts[0]


class AESLC(DefaultNLPTask):
	"""Summarisation of an email body into its subject line."""

	def get_valid_keys(self):
		return ["email body", "subject line"]

	def get_task_type(self):
		return "text summarization"

	def get_nlu_keys(self):
		return ["subject line"]

	def get_value_from_key(self, query_datapoint, key):
		# Display keys use spaces; the row columns read here use underscores.
		# Any key other than "email body" reads the subject line.
		column = "email_body" if key == "email body" else "subject_line"
		return query_datapoint[column]

# NOTE: a second, byte-identical definition of AdeClassification used to sit
# here and silently rebound the name. The definition that follows
# TrainSplitNLPTask is the single source of truth.

class AGNewsClassification(DefaultNLPTask):
	"""Four-way topic classification; validation rows come from the "test" split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "World",
			1: "Sports",
			2: "Business",
			3: "Sci/Tech",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def process_split_name(self, split):
		# Validation data is read from the dataset's "test" split.
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "text":
			return str(query_datapoint[key])
		return self.label[query_datapoint[key]]

class ARCChallengeChoices(DefaultNLPTask):
	"""Multiple-choice QA where choices are parallel ``label``/``text`` lists."""

	def get_valid_keys(self):
		return ["question", "choices", "answer"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Render a field.

		"question" is returned as-is; "choices" becomes " (L) text" segments
		concatenated together. Any other key returns the text of the choice
		whose label equals ``answerKey``.

		Raises:
			ValueError: if no choice label matches ``answerKey``.
		"""
		if key == "question":
			return str(query_datapoint[key])
		if key == "choices":
			choices = query_datapoint["choices"]
			return "".join(
				" (" + choice_label + ") " + choices["text"][position]
				for position, choice_label in enumerate(choices["label"])
			)
		answer_label = query_datapoint["answerKey"]
		choices = query_datapoint["choices"]
		for position, choice_label in enumerate(choices["label"]):
			if choice_label == answer_label:
				return choices["text"][position]
		raise ValueError("answer key not found")
			
class AmazonPolarity(DefaultNLPTask):
	"""Binary review sentiment; validation rows come from the "test" split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "negative", 1: "positive"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def prune(self, data_list, split):
		"""Keep a random 10% of every split (``data_list`` is shuffled in place)."""
		random.shuffle(data_list)
		keep = int(0.1 * len(data_list))
		return data_list[:keep]

	def process_split_name(self, split):
		# Validation data is read from the dataset's "test" split.
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["title", "content", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value)

class ANLI(DefaultNLPTask):
	"""Three-way NLI restricted to the round-1 ("_r1") splits."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "entailment",
			1: "neutral",
			2: "contradiction",
		}

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def process_split_name(self, split_name):
		# "validation" is read from "dev"; every split gets the round-1 suffix.
		base_name = "dev" if split_name == "validation" else split_name
		return base_name + "_r1"

	def get_valid_keys(self):
		return ["premise", "hypothesis", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value)

class AquaRat(DefaultNLPTask):
	"""Multiple-choice QA whose options are strings sliced as letter / text."""

	def get_valid_keys(self):
		return ["question", "answer", "options"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Render a field.

		Each option is sliced as ``option[0]`` (compared with ``correct``),
		``option[0:2]`` (shown in the choices string) and ``option[2:]`` (its text).

		Raises:
			ValueError: for "answer" when no option starts with ``correct``.
		"""
		if key == "question":
			return str(query_datapoint[key])
		if key == "answer":
			correct_letter = query_datapoint["correct"]
			for option in query_datapoint["options"]:
				if option[0] == correct_letter:
					return option[2:]
			raise ValueError("answer key not found")
		return "".join(" (" + option[0:2] + " " + option[2:] for option in query_datapoint["options"])

class ART(DefaultNLPTask):
	"""Pick which of two hypotheses fits two observations (labels 1 and 2)."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {1: "hypothesis 1", 2: "hypothesis 2"}

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["observation_1", "observation_2", "hypothesis_1", "hypothesis_2", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value)

class Circa(TrainSplitNLPTask):
	"""Five-way answer interpretation labelled from the ``goldstandard2`` column."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "Yes",
			1: "No",
			2: "In the middle, neither yes nor no",
			3: "Yes, subject to some conditions",
			4: "Other",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["gold_label"]

	def discard_data(self, query_datapoint):
		# Rows whose goldstandard2 is -1 have no usable label.
		return bool(query_datapoint["goldstandard2"] == -1)

	def process_split_name(self, split):
		return split

	def get_valid_keys(self):
		return ["context", "question-X", "answer-Y", "gold_label"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "gold_label":
			return str(query_datapoint[key])
		return self.label[query_datapoint["goldstandard2"]]

class ClimateFever(TrainSplitNLPTask):
	"""Four-way claim verification built from the dataset's "test" split."""

	def __init__(self, task_name):
		super(ClimateFever, self).__init__(task_name)
		self.label = {
			0: "Supports",
			1: "Refutes",
			2: "Not enough info",
			3: "Disputed",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["claim_label"]

	def process_split_name(self, split):
		"""Map a logical split name to the dataset split to read.

		"train" is read from the "test" split; TrainSplitNLPTask.load_data then
		carves both train and validation rows out of it. Any other name is
		passed through unchanged rather than implicitly becoming ``None``.
		"""
		if split == "train":
			return "test"
		return split

	def get_valid_keys(self):
		return ["claim", "claim_label"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "claim_label":
			return self.label[query_datapoint[key]]
		return str(query_datapoint[key])

class CODAH(DefaultNLPTask):
	"""Multiple-choice sentence completion with up to four candidates (A)-(D)."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "(A)", 1: "(B)", 2: "(C)", 3: "(D)"}

	def get_valid_keys(self):
		# NOTE(review): "question_propmt" is used verbatim to index rows; do not
		# "fix" the spelling without checking the dataset's column name.
		return ["question_propmt", "correct_answer", "candidate_answers"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["correct_answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Render a field.

		Raises:
			ValueError: when ``correct_answer_idx`` matches no candidate position.
		"""
		if key == "question_propmt":
			return str(query_datapoint[key])
		if key == "candidate_answers":
			return "".join(
				" " + self.label[position] + " " + candidate
				for position, candidate in enumerate(query_datapoint["candidate_answers"])
			)
		answer_index = query_datapoint["correct_answer_idx"]
		matches = [
			candidate
			for position, candidate in enumerate(query_datapoint["candidate_answers"])
			if position == answer_index
		]
		if not matches:
			raise ValueError("answer key not found")
		return matches[0]

class CommonGen(DefaultNLPTask):
	"""Generate a target sentence from a list of concepts."""

	def get_task_type(self):
		return "data to text"

	def get_nlu_keys(self):
		return ["target"]

	def get_valid_keys(self):
		return ["concepts", "target"]

	def get_value_from_key(self, query_datapoint, key):
		# Concepts are joined with ", "; every other field is stringified.
		value = query_datapoint[key]
		return ", ".join(value) if key == "concepts" else str(value)


class CommonsenseQA(DefaultNLPTask):
	"""Multiple-choice QA where choices are parallel ``label``/``text`` lists."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# NOTE(review): self.label is not used by get_value_from_key below.
		self.label = {0: "(A)", 1: "(B)", 2: "(C)", 3: "(D)"}

	def get_valid_keys(self):
		return ["question", "choices", "answer"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Render a field.

		Raises:
			ValueError: when no choice label equals ``answerKey``.
		"""
		if key == "question":
			return str(query_datapoint[key])
		if key == "choices":
			choices = query_datapoint["choices"]
			return "".join(
				" (" + choice_label + ") " + choices["text"][position]
				for position, choice_label in enumerate(choices["label"])
			)
		answer_label = query_datapoint["answerKey"]
		choices = query_datapoint["choices"]
		for position, choice_label in enumerate(choices["label"]):
			if choice_label == answer_label:
				return choices["text"][position]
		raise ValueError("answer key not found")

class CoS_E(DefaultNLPTask):
	"""Multiple-choice QA with up to five choices rendered as (A)-(E)."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "(A)", 1: "(B)", 2: "(C)", 3: "(D)", 4: "(E)"}

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["question", "choices", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "choices":
			return str(query_datapoint[key])
		return "".join(
			" " + self.label[position] + " " + candidate
			for position, candidate in enumerate(query_datapoint["choices"])
		)

class CosmosQA(DefaultNLPTask):
	"""Reading-comprehension multiple choice over fields answer0..answer3."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "(A)", 1: "(B)", 2: "(C)", 3: "(D)", 4: "(E)"}

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["question", "context", "choices", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "choices":
			# Exactly four candidate columns are rendered.
			return "".join(
				" " + self.label[position] + " " + query_datapoint[f"answer{position}"]
				for position in range(4)
			)
		if key == "answer":
			# The gold answer lives in the column named after the label index.
			return query_datapoint[f"answer{query_datapoint['label']}"]
		return str(query_datapoint[key])

class DBpedia14(DefaultNLPTask):
	"""14-way ontology classification; validation comes from a "test" sample."""

	def __init__(self, task_name):
		super().__init__(task_name)
		ontology_classes = [
			"Company",
			"EducationalInstitution",
			"Artist",
			"Athlete",
			"OfficeHolder",
			"MeanOfTransportation",
			"Building",
			"NaturalPlace",
			"Village",
			"Animal",
			"Plant",
			"Album",
			"Film",
			"WrittenWork",
		]
		self.label = dict(enumerate(ontology_classes))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def prune(self, data_list, split):
		"""Keep a random 5% of the validation split (shuffled in place); keep train whole."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		return data_list[:int(0.05 * len(data_list))]

	def process_split_name(self, split):
		# Validation data is read from the dataset's "test" split.
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["content", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value)

class DefinitePronounResolution(DefaultNLPTask):
	"""Choose which of two candidates a pronoun refers to."""

	def process_split_name(self, split):
		# Validation data is read from the dataset's "test" split.
		return "test" if split == "validation" else split

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["sentence", "pronoun", "candidates", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		candidates = query_datapoint["candidates"] if key in ("candidates", "answer") else None
		if key == "candidates":
			return f"(A) {candidates[0]} (B) {candidates[1]}"
		if key == "answer":
			# "label" indexes into the two-element candidates list.
			return candidates[query_datapoint["label"]]
		return str(query_datapoint[key])

class Discovery(DefaultNLPTask):
	"""Predict the discourse connective linking two sentences.

	``self.labels`` is indexed by the integer ``label`` column.
	"""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.labels = [
			"[no-conn]",
			"absolutely,",
			"accordingly",
			"actually,",
			"additionally",
			"admittedly,",
			"afterward",
			"again,",
			"already,",
			"also,",
			"alternately,",
			"alternatively",
			"although,",
			"altogether,",
			"amazingly,",
			"and",
			"anyway,",
			"apparently,",
			"arguably,",
			"as_a_result,",
			"basically,",
			"because_of_that",
			"because_of_this",
			"besides,",
			"but",
			"by_comparison,",
			"by_contrast,",
			"by_doing_this,",
			"by_then",
			"certainly,",
			"clearly,",
			"coincidentally,",
			"collectively,",
			"consequently",
			"conversely",
			"curiously,",
			"currently,",
			"elsewhere,",
			"especially,",
			"essentially,",
			"eventually,",
			"evidently,",
			"finally,",
			"first,",
			"firstly,",
			"for_example",
			"for_instance",
			"fortunately,",
			"frankly,",
			"frequently,",
			"further,",
			"furthermore",
			"generally,",
			"gradually,",
			"happily,",
			"hence,",
			"here,",
			"historically,",
			"honestly,",
			"hopefully,",
			"however",
			"ideally,",
			"immediately,",
			"importantly,",
			"in_contrast,",
			"in_fact,",
			"in_other_words",
			"in_particular,",
			"in_short,",
			"in_sum,",
			"in_the_end,",
			"in_the_meantime,",
			"in_turn,",
			"incidentally,",
			"increasingly,",
			"indeed,",
			"inevitably,",
			"initially,",
			"instead,",
			"interestingly,",
			"ironically,",
			"lastly,",
			"lately,",
			"later,",
			"likewise,",
			"locally,",
			"luckily,",
			"maybe,",
			"meaning,",
			"meantime,",
			"meanwhile,",
			"moreover",
			"mostly,",
			"namely,",
			"nationally,",
			"naturally,",
			"nevertheless",
			"next,",
			"nonetheless",
			"normally,",
			"notably,",
			"now,",
			"obviously,",
			"occasionally,",
			"oddly,",
			"often,",
			"on_the_contrary,",
			"on_the_other_hand",
			"once,",
			"only,",
			"optionally,",
			"or,",
			"originally,",
			"otherwise,",
			"overall,",
			"particularly,",
			"perhaps,",
			"personally,",
			"plus,",
			"preferably,",
			"presently,",
			"presumably,",
			"previously,",
			"probably,",
			"rather,",
			"realistically,",
			"really,",
			"recently,",
			"regardless,",
			"remarkably,",
			"sadly,",
			"second,",
			"secondly,",
			"separately,",
			"seriously,",
			"significantly,",
			"similarly,",
			"simultaneously",
			"slowly,",
			"so,",
			"sometimes,",
			"soon,",
			"specifically,",
			"still,",
			"strangely,",
			"subsequently,",
			"suddenly,",
			"supposedly,",
			"surely,",
			"surprisingly,",
			"technically,",
			"thankfully,",
			"then,",
			"theoretically,",
			"thereafter,",
			"thereby,",
			"therefore",
			"third,",
			"thirdly,",
			"this,",
			"though,",
			"thus,",
			"together,",
			"traditionally,",
			"truly,",
			"truthfully,",
			"typically,",
			"ultimately,",
			"undoubtedly,",
			"unfortunately,",
			"unsurprisingly,",
			"usually,",
			"well,",
			"yet,",
		]

	def prune(self, data_list, split):
		"""Keep a random 10% of the validation split (shuffled in place); keep train whole."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		return data_list[:int(0.1 * len(data_list))]

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sentence 1", "sentence 2", "label"]

	def get_value_from_key(self, query_datapoint, key):
		# Display keys contain a space; the row columns do not.
		column = {"sentence 1": "sentence1", "sentence 2": "sentence2"}.get(key)
		if column is not None:
			return query_datapoint[column]
		return self.labels[query_datapoint["label"]]

class Dream(DefaultNLPTask):
	"""Dialogue-based multiple choice with exactly three rendered options."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "(A)", 1: "(B)", 2: "(C)"}

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["question", "dialogue", "answer", "choice"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "choice":
			return str(query_datapoint[key])
		options = query_datapoint["choice"]
		return "".join(" " + self.label[position] + " " + options[position] for position in range(3))

class DuoRC(DefaultNLPTask):
	"""QA over movie plots; one answer is sampled with ``random.choice``."""

	def discard_data(self, query_datapoint):
		"""Drop rows whose plot starts with a maintenance banner, or that have no answer."""
		too_long_banner = "This article's plot summary may be too long or excessively detailed. Please help improve it by removing unnecessary details and making it more concise."
		return bool(query_datapoint["plot"].startswith(too_long_banner) or query_datapoint["no_answer"] == 1)

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "context", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "question":
			return str(query_datapoint["question"])
		if key == "context":
			# The plot text is flattened to a single line.
			return query_datapoint["plot"].replace("\n", " ")
		return random.choice(query_datapoint["answers"])

class ELI5(DefaultNLPTask):
	"""Long-form QA read from the "<split>_eli5" splits."""

	@staticmethod
	def _flatten(text):
		# Replace newlines, carriage returns and tabs with single spaces.
		return text.replace("\n", " ").replace("\r", " ").replace("\t", " ")

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def prune(self, data_list, split):
		"""Keep a random 10% of the validation split (shuffled in place); keep train whole."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		return data_list[:int(len(data_list) * 0.1)]

	def process_split_name(self, split):
		return split + "_eli5"

	def get_valid_keys(self):
		return ["title", "selftext", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "answers":
			return self._flatten(query_datapoint[key])
		# All answer texts are flattened and joined with spaces.
		return " ".join(self._flatten(text) for text in query_datapoint["answers"]["text"])

class Emo(DefaultNLPTask):
	"""Four-way emotion classification; validation comes from the "test" split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "others",
			1: "happy",
			2: "sad",
			3: "angry",
		}

	def process_split_name(self, split):
		# Validation data is read from the dataset's "test" split.
		return "test" if split == "validation" else split

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class Emotion(DefaultNLPTask):
	"""Six-way emotion classification; integer labels index ``self.label``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = [
			"sadness",
			"joy",
			"love",
			"anger",
			"fear",
			"surprise",
		]

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class EmpatheticDialogues(DefaultNLPTask):
	"""Dialogue response task over utterance / prompt / context fields."""

	# Translation table turning newlines, tabs and carriage returns into spaces.
	_WHITESPACE_TO_SPACE = str.maketrans({"\n": " ", "\t": " ", "\r": " "})

	def discard_data(self, query_datapoint):
		# Utterances containing "hit:" are dropped.
		return "hit:" in query_datapoint["utterance"]

	def prune(self, data_list, split):
		"""Keep a random 20% of the validation split (shuffled in place); keep train whole."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		return data_list[:int(0.2 * len(data_list))]

	def get_task_type(self):
		return "Dialogue"

	def get_nlu_keys(self):
		return ["utterance"]

	def get_valid_keys(self):
		return ["utterance", "prompt", "context"]

	def get_value_from_key(self, query_datapoint, key):
		# Strip first, then restore "_comma_" placeholders, then flatten whitespace.
		text = str(query_datapoint[key]).strip().replace("_comma_", ",")
		return text.translate(self._WHITESPACE_TO_SPACE)


class FinancialPhrasebank(TrainSplitNLPTask):
	"""Three-way sentiment classification of financial sentences."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "negative",
			1: "neutral",
			2: "positive",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sentence", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class FreebaseQA(DefaultNLPTask):
	"""Trivia QA; one answer name is sampled from the parses' answers."""

	def discard_data(self, query_datapoint):
		# Rows lacking either the raw question or the parses are unusable.
		return not ("RawQuestion" in query_datapoint and "Parses" in query_datapoint)

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["question", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "question":
			return query_datapoint["RawQuestion"]
		# Each answer_name is iterated and its elements collected; duplicates
		# are removed and the result sorted so random.choice sees a stable order.
		distinct_answers = set()
		for parse_answers in query_datapoint["Parses"]["Answers"]:
			for answer_name in parse_answers["AnswersName"]:
				distinct_answers.update(answer_name)
		return random.choice(sorted(distinct_answers))

class GlueCola(DefaultNLPTask):
	"""Binary grammatical-acceptability classification."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "unacceptable", 1: "acceptable"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sentence", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class GlueMNLI(DefaultNLPTask):
	"""Three-way NLI; validation uses the "validation_matched" split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "entailment",
			1: "neutral",
			2: "contradiction",
		}

	def process_split_name(self, split):
		return split + "_matched" if split == "validation" else split

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["premise", "label", "hypothesis"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class GlueMRPC(DefaultNLPTask):
	"""Binary paraphrase (equivalence) classification of sentence pairs."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "not_equivalent", 1: "equivalent"}

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sentence1", "label", "sentence2"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class GoogleWellformedQuery(DefaultNLPTask):
	"""Binary query well-formedness derived from a numeric ``rating``.

	Ratings in [0.4, 0.6] are discarded as ambiguous; ratings below 0.4 are
	"not well-formed" and the rest are "well-formed".
	"""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "not well-formed", 1: "well-formed"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def discard_data(self, query_datapoint):
		return 0.4 <= query_datapoint["rating"] <= 0.6

	def get_valid_keys(self):
		return ["label", "content"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "label":
			return str(query_datapoint[key]).strip()
		label_index = 0 if query_datapoint["rating"] < 0.4 else 1
		return self.label[label_index]

class HateSpeech18(TrainSplitNLPTask):
	"""Binary hate-speech classification; labels above 1 are discarded."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "not Hate", 1: "hate"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def discard_data(self, query_datapoint):
		# Only labels 0 and 1 are kept.
		return query_datapoint["label"] > 1

	def get_valid_keys(self):
		return ["label", "text"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class HateSpeechOffensive(TrainSplitNLPTask):
	"""Three-way tweet classification keyed on the ``class`` column."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "hate speech",
			1: "offensive language",
			2: "neither",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["class"]

	def get_valid_keys(self):
		return ["class", "tweet"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "class" else str(value).strip()

def get_majority(lst):
	"""Return the strictly most common element of ``lst``.

	Args:
		lst: iterable of hashable votes.

	Returns:
		The element with the highest count, or ``None`` when there is no unique
		winner: either ``lst`` is empty or the top two counts are tied.
	"""
	# Only the top two entries are needed to detect a tie for first place.
	ranking = Counter(lst).most_common(2)
	if not ranking:
		# Previously raised IndexError; an empty vote has no majority.
		return None
	if len(ranking) > 1 and ranking[0][1] == ranking[1][1]:
		return None
	return ranking[0][0]

class HatExplain(DefaultNLPTask):
	"""Three-way classification using the annotators' majority label.

	Rows with no unique majority (per ``get_majority``) are discarded.
	"""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {
			0: "hatespeech",
			1: "normal",
			2: "offensive",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["class"]

	def discard_data(self, query_datapoint):
		return get_majority(query_datapoint["annotators"]["label"]) is None

	def get_valid_keys(self):
		return ["class", "text"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "class":
			# Any non-"class" key returns the post tokens joined by spaces.
			return " ".join(query_datapoint["post_tokens"])
		majority_label = get_majority(query_datapoint["annotators"]["label"])
		return self.label[majority_label]

class HealthFact(DefaultNLPTask):
	"""Classify a ``claim`` as false / mixture / true / unproven."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["false", "mixture", "true", "unproven"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def discard_data(self, query_datapoint):
		"""Skip rows with a negative label id (no entry in ``self.label``)."""
		return bool(query_datapoint["label"] < 0)

	def get_valid_keys(self):
		return ["claim", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name for "label"; every other key yields the cleaned claim text."""
		if key == "label":
			return self.label[query_datapoint[key]]
		claim = query_datapoint["claim"].strip()
		# flatten newlines, carriage returns and tabs into single spaces
		for whitespace in ("\n", "\r", "\t"):
			claim = claim.replace(whitespace, " ")
		return claim


class HellaSwag(DefaultNLPTask):
	"""Multiple-choice completion of ``ctx`` with one of its ``endings``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# option position -> "(A)" .. "(E)"
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCDE")}

	def get_valid_keys(self):
		return ["context", "candidates", "answer"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the context, the lettered candidates string, or the correct ending."""
		if key == "context":
			return query_datapoint["ctx"]
		endings = query_datapoint["endings"]
		if key == "candidates":
			return "".join(" " + self.label[pos] + " " + ending for pos, ending in enumerate(endings))
		position = int(query_datapoint["label"])
		# a label outside the endings range (including negatives) yields None
		return endings[position] if 0 <= position < len(endings) else None

class HotpotQA(DefaultNLPTask):
	"""Question answering whose context keeps only paragraphs titled in ``supporting_facts``."""

	def get_valid_keys(self):
		return ["context", "question", "answer"]

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answer"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the question, the filtered context, or (for any other key) the answer."""
		if key == "question":
			return query_datapoint["question"]
		if key != "context":
			return query_datapoint["answer"]
		supporting_titles = query_datapoint["supporting_facts"]["title"]
		paragraphs = query_datapoint["context"]
		# sentences are concatenated as-is; each kept paragraph is followed by one space
		return "".join(
			"".join(sentences) + " "
			for sentences, title in zip(paragraphs["sentences"], paragraphs["title"])
			if title in supporting_titles
		)


class IMDB(DefaultNLPTask):
	"""Binary sentiment classification of ``text``; validation is served from the test split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "negative", 1: "positive"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["label", "text"]

	def process_split_name(self, split):
		"""Map "validation" to "test"; other split names pass through."""
		return "test" if split == "validation" else split

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		if key != "label":
			return str(value).strip()
		return self.label[value]

class Kilt(DefaultNLPTask):
	"""Text generation from ``input`` to one randomly chosen gold ``output`` answer."""

	def get_valid_keys(self):
		return ["input", "output"]

	def get_task_type(self):
		return "text generation"

	def get_nlu_keys(self):
		return ["output"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return a random answer for "output"; otherwise the stripped string value of the field."""
		if key != "output":
			return str(query_datapoint[key]).strip()
		answers = [entry["answer"] for entry in query_datapoint["output"]]
		return random.choice(answers)



class MathQA(DefaultNLPTask):
	"""Five-option math problems whose options arrive as one comma-separated string."""

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def process_line(self, dp):
		"""Return (" (A) ... (E) ...", correct option text) parsed from ``dp["options"]``.

		The fixed slice offsets assume each option looks like "a ) 38 " for the
		first option and " b ) 27 " after a comma, with no trailing space on the
		last one. Verify this against the data.
		"""
		parts = dp["options"].split(",")
		choices = ""
		for idx, letter in enumerate("abcde"):
			start = 4 if idx == 0 else 5
			text = parts[idx][start:-1] if idx < 4 else parts[idx][start:]
			choices += " (" + letter.upper() + ") " + text
			if dp["correct"] == letter:
				answer = text
		# NOTE(review): ``answer`` stays unbound (UnboundLocalError) when
		# ``dp["correct"]`` is not one of a-e.
		return choices, answer

	def get_valid_keys(self):
		return ["question", "options", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "question":
			return query_datapoint["Problem"]
		choices, answer = self.process_line(query_datapoint)
		return choices if key == "options" else answer

class Liar(DefaultNLPTask):
	"""Six-way truthfulness classification of a ``statement`` with speaker and context."""

	def __init__(self, task_name):
		super().__init__(task_name)
		names = ["false", "half-true", "mostly-true", "true", "barely-true", "pants-fire"]
		self.label = dict(enumerate(names))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["statement", "speaker", "context", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name, or the stripped field with newline/CR/tab turned into spaces."""
		if key == "label":
			return self.label[query_datapoint[key]]
		to_space = str.maketrans("\n\r\t", "   ")
		return query_datapoint[key].strip().translate(to_space)

class Limit(DefaultNLPTask):
	"""Label the moving entities of a ``sentence``; only rows with ``motion == "yes"`` are used."""

	def discard_data(self, query_datapoint):
		"""Drop rows whose ``motion`` field is not exactly "yes"."""
		return query_datapoint["motion"] != "yes"

	def process_split_name(self, split):
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["sentence", "motion_entities"]

	def get_task_type(self):
		return "sequence labelling"

	def get_nlu_keys(self):
		return ["motion_entities"]

	def get_value_from_key(self, query_datapoint, key):
		"""Join entity strings with " & " for "motion_entities"; otherwise strip the field."""
		if key != "motion_entities":
			return query_datapoint[key].strip()
		return " & ".join(entry["entity"] for entry in query_datapoint["motion_entities"])

class MCTACO(TrainSplitNLPTask):
	"""Yes/no judgement of a candidate ``answer`` to a ``question`` about a ``sentence``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "no", 1: "yes"}

	def process_split_name(self, split):
		"""Read "train" from "validation"; every other split reads from "test"."""
		return "validation" if split == "train" else "test"

	def get_valid_keys(self):
		return ["sentence", "question", "answer", "label"]

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value).strip()

class MedicalQuestionPairs(TrainSplitNLPTask):
	"""Decide whether ``question_1`` and ``question_2`` are similar."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# NOTE(review): id 0 -> "Similar", 1 -> "Dissimilar"; confirm against the dataset card.
		self.label = {0: "Similar", 1: "Dissimilar"}

	def get_valid_keys(self):
		return ["question_1", "question_2", "label"]

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		if key != "label":
			return str(value).strip()
		return self.label[value]


class Mocha(DefaultNLPTask):
	"""Regress an integer ``score`` for a candidate answer given question, context and reference."""

	def discard_data(self, query_datapoint):
		"""Keep only rows whose score is a whole number."""
		return query_datapoint["score"] % 1 != 0

	def get_valid_keys(self):
		return ["question", "context", "reference", "candidate", "score"]

	def get_task_type(self):
		return "text regression"

	def get_nlu_keys(self):
		return ["score"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return str(int(value)) if key == "score" else value.strip()

class OneStopEnglish(TrainSplitNLPTask):
	"""Classify the reading level (elementary / intermediate / advance) of ``text``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["elementary", "intermediate", "advance"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name; every other key yields the flattened, stripped text."""
		if key == "label":
			return self.label[query_datapoint[key]]
		text = query_datapoint["text"].replace("\n", " ")
		# Some texts start with a leaked "Intermediate  " header; drop it.
		stray_prefix = "Intermediate  "
		if text.startswith(stray_prefix):
			text = text[len(stray_prefix):]
		return text.strip()

class OpenbookQA(DefaultNLPTask):
	"""Four-option multiple-choice questions built from ``question_stem`` and ``choices``."""

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) ... (D) ...", text of the correct option).

		Options are always tagged A-D by position. The tags live in a local
		list rather than being written into ``datapoint["choices"]["label"]``,
		so the caller's row is no longer mutated.
		"""
		answer_key = datapoint["answerKey"]
		texts = datapoint["choices"]["text"]
		choices_string = ""
		for i, tag in enumerate(["A", "B", "C", "D"]):
			if tag == answer_key:
				answer_string = texts[i]
			choices_string += " (" + tag + ") " + texts[i]
		# NOTE(review): answer_string is unbound (UnboundLocalError) if answerKey is not A-D.
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ['answer']

	def get_valid_keys(self):
		return ['question', 'choices', 'answer']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the question stem, the lettered choices, or (any other key) the answer text."""
		if key == "question":
			return query_datapoint["question_stem"]
		else:
			choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
			if key == "choices":
				return choices_string
			else:
				return answer_string

class PAWS(DefaultNLPTask):
	"""Paraphrase detection between ``sentence1`` and ``sentence2``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "not duplicate", 1: "duplicate"}

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["question 1", "question 2", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name; "question 1" reads sentence1, any other key reads sentence2."""
		if key == "label":
			return self.label[query_datapoint[key]]
		column = "sentence1" if key == "question 1" else "sentence2"
		return query_datapoint[column]

class PIQA(DefaultNLPTask):
	"""Choose which of two solutions (``sol1`` / ``sol2``) achieves the ``goal``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "solution 1", 1: "solution 2"}

	def get_valid_keys(self):
		return ["goal", "solution 1", "solution 2", "label"]

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name, the goal, or the matching solution (sol2 for any other key)."""
		if key == "label":
			return self.label[query_datapoint[key]]
		if key == "goal":
			return query_datapoint["goal"]
		return query_datapoint["sol1" if key == "solution 1" else "sol2"]

class BioMRC(TrainSplitNLPTask):
	"""Reading comprehension over a ``title`` and ``abstract``."""

	def prune(self, data_list, split):
		"""Keep a random 10% of validation rows; other splits are returned unchanged.

		The shuffle is in place, so the caller's list is reordered.
		"""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		keep = int(0.1 * len(data_list))
		return data_list[:keep]

	def get_valid_keys(self):
		return ["title", "abstract", "answer"]

	def get_task_type(self):
		return "reading comprehension"

	def get_nlu_keys(self):
		return ["answer"]

class PoemSentiment(DefaultNLPTask):
	"""Four-way sentiment classification of ``verse_text``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["negative", "positive", "no_impact", "mixed"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name; every other key yields the stripped verse text."""
		if key != "label":
			return query_datapoint["verse_text"].strip()
		return self.label[query_datapoint[key]]

class QASRL(DefaultNLPTask):
	"""Generate an answer for a question about a predicate in a sentence."""

	def get_task_type(self):
		return "text generation"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "sentence", "answers", "predicate"]

	def get_value_from_key(self, query_datapoint, key):
		"""Pass sentence/predicate through, space-join question parts minus "_", else pick a random answer."""
		if key in ("sentence", "predicate"):
			return query_datapoint[key]
		if key == "question":
			return " ".join(part for part in query_datapoint[key] if part != "_")
		return random.choice(query_datapoint["answers"])



class QASC(DefaultNLPTask):
	"""Multiple-choice questions whose options carry their own letter tags."""

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answer"]

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (tag) text ...", text whose tag equals ``answerKey``)."""
		answer_key = datapoint["answerKey"]
		tags = datapoint["choices"]["label"]
		texts = datapoint["choices"]["text"]
		choices_string = ""
		for idx, tag in enumerate(tags):
			text = texts[idx]
			if tag == answer_key:
				answer_string = text
			choices_string += " (" + tag + ") " + text
		return choices_string, answer_string

	def get_valid_keys(self):
		return ["question", "choices", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "question":
			return query_datapoint["question"]
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class QUAIL(DefaultNLPTask):
	"""Four-option reading-comprehension questions over a ``context``."""

	def __init__(self, task_name):
		super(QUAIL, self).__init__(task_name)
		self.label = {0: "(A)", 1: "(B)", 2: "(C)", 3: "(D)"}

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) opt (B) opt ...", text of the correct option).

		Each option is rendered as " (X) text" with spaces around the tag.
		Previously tags and options were concatenated with no separator
		("(A)text(B)text").
		"""
		answer_index = datapoint["correct_answer_id"]
		choices_string = ""
		for i, answer in enumerate(datapoint["answers"]):
			if i == answer_index:
				answer_string = answer
			choices_string += " " + self.label[i] + " " + answer
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ['answer']

	def get_valid_keys(self):
		return ['question', 'choices', 'answer', 'context']

	def get_value_from_key(self, query_datapoint, key):
		"""Return question/context verbatim, the lettered choices, or (any other key) the answer."""
		if key in ["question", "context"]:
			return query_datapoint[key]
		else:
			choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
			if key == "choices":
				return choices_string
			else:
				return answer_string

class QUAREL(DefaultNLPTask):
	"""Two-option multiple choice whose options "(A) ..." / "(B) ..." are embedded in the question."""

	def get_answer_string(self, datapoint):
		"""Extract the text of the correct option from ``datapoint["question"]``.

		Options are located via the "(A)" and "(B)" tags. Option A runs up to
		"(B)", and a trailing "or " joining the two options is removed.
		"""
		answer_index = datapoint["answer_index"]
		st1 = datapoint["question"].find("(A)")
		st2 = datapoint["question"].find("(B)")

		if answer_index == 0:
			answer_string = datapoint["question"][st1+4: st2]
		else:
			answer_string = datapoint["question"][st2+4: ]

		if answer_string.endswith("or "):
			answer_string = answer_string[:-3]

		return answer_string

	def get_valid_keys(self):
		return ['question', 'answer']

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ['answer']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the raw question (options included) for "question"; otherwise the correct option text.

		Previously only 'sentence' was passed through, so the "question" key
		(the one listed in get_valid_keys) fell into the answer branch and the
		question slot was filled with the answer. 'sentence' is still passed
		through verbatim for backward compatibility.
		"""
		if key in ('question', 'sentence'):
			return query_datapoint[key]
		return self.get_answer_string(query_datapoint).strip()

class Quartz(DefaultNLPTask):
	"""Multiple-choice questions with a supporting knowledge paragraph (``para``)."""

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (tag) text ...", text whose tag equals ``answerKey``)."""
		answer_index = datapoint["answerKey"]
		choices_string = ""
		for i in range(len(datapoint["choices"]["label"])):
			if datapoint["choices"]["label"][i] == answer_index:
				answer_string = datapoint["choices"]["text"][i]
			choices_string += " (" + datapoint["choices"]["label"][i] + ") " + datapoint["choices"]["text"][i]
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ['answer']

	def get_valid_keys(self):
		return ['question', 'answer', 'knowledge', 'choices']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the question, the knowledge paragraph, the lettered choices, or the answer.

		Previously only 'sentence' was passed through, so the "question" key
		fell into the final branch and returned the answer text. 'sentence' is
		still passed through verbatim for backward compatibility.
		"""
		if key in ('question', 'sentence'):
			return query_datapoint[key]
		elif key == "knowledge":
			return query_datapoint["para"]
		else:
			choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
			if key == "choices":
				return choices_string
			else:
				return answer_string

class Quoref(DefaultNLPTask):
	"""Extractive question answering; all gold answer texts are tab-joined."""

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "context", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key not in ("context", "question"):
			return "\t".join(query_datapoint["answers"]["text"])
		return query_datapoint[key]

class Race(DefaultNLPTask):
	"""Four-option reading comprehension over an ``article``; ``answer`` is a letter."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCD")}

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) opt ...", correct option), with newline/tab/CR flattened to spaces."""
		answer_index = ord(datapoint["answer"]) - ord("A")
		flatten = str.maketrans("\n\t\r", "   ")
		choices_string = ""
		for pos, option in enumerate(datapoint["options"]):
			cleaned = option.translate(flatten)
			if pos == answer_index:
				answer_string = cleaned
			choices_string += " " + self.label[pos] + " " + cleaned
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "article", "answers", "choices"]

	def get_value_from_key(self, query_datapoint, key):
		if key in ("question", "article"):
			return query_datapoint[key]
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class ROPES(DefaultNLPTask):
	"""Question answering over a background passage and a situation."""

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "situation", "answers", "background"]

	def get_value_from_key(self, query_datapoint, key):
		"""Tab-join the answer texts for "answers"; pass other fields through."""
		if key != "answers":
			return query_datapoint[key]
		return "\t".join(query_datapoint["answers"]["text"])

class RottenTomatos(DefaultNLPTask):
	"""Binary sentiment classification of ``text``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["negative", "positive"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "label":
			return query_datapoint["text"]
		return self.label[query_datapoint[key]]


class SciCite(DefaultNLPTask):
	"""Classify the intent of a citation (method / background / result) in its sentence."""

	def __init__(self, task_name):
		super(SciCite, self).__init__(task_name)
		self.label = {
			0: "method",
			1: "background",
			2: "result",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['sentence', 'section name', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name, the sentence with the citation masked, or the section name.

		The citation span ``string[citeStart:citeEnd]`` is replaced by
		" [CITATION] " at those offsets only. If the span is empty the
		sentence is returned unchanged. Previously ``str.replace("", ...)``
		inserted the marker between every character, and any other occurrence
		of the same text was masked too.
		"""
		if key == "label":
			return self.label[query_datapoint[key]]
		elif key == "sentence":
			text = query_datapoint["string"]
			start = query_datapoint["citeStart"]
			end = query_datapoint["citeEnd"]
			if not text[start:end]:
				return text
			return text[:start] + " [CITATION] " + text[end:]
		else:
			return query_datapoint["sectionName"]

class SciQ(DefaultNLPTask):
	"""Four-option science questions; the correct answer is shuffled among three distractors."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCD")}

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) opt ...", correct answer); option order is randomized on every call."""
		answer_string = datapoint["correct_answer"]
		options = [datapoint["distractor" + str(n)] for n in (1, 2, 3)]
		options.append(answer_string)
		random.shuffle(options)
		choices_string = "".join(" " + self.label[pos] + " " + text for pos, text in enumerate(options))
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "support", "choices", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key in ("question", "support"):
			return query_datapoint[key]
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class Sick(DefaultNLPTask):
	"""Three-way entailment between ``sentence_A`` and ``sentence_B``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["entailment", "neutral", "contradiction"]))

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sentence_A", "sentence_B", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else value

class SMSSpam(TrainSplitNLPTask):
	"""Classify an ``sms`` message as ham or spam."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["ham", "spam"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["sms", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		if key != "label":
			return value.strip()
		return self.label[value]

class SocialIQA(DefaultNLPTask):
	"""Three-option questions about a social ``context``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCD")}

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answers"]

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) opt (B) opt (C) opt", correct option).

		``label`` is treated as 1-based: label 1 selects answerA.
		"""
		options = [datapoint["answerA"], datapoint["answerB"], datapoint["answerC"]]
		target = int(datapoint["label"]) - 1
		choices_string = ""
		for pos, option in enumerate(options):
			if pos == target:
				answer_string = option
			choices_string += " " + self.label[pos] + " " + option
		return choices_string, answer_string

	def get_valid_keys(self):
		return ["question", "context", "choices", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key in ("question", "context"):
			return query_datapoint[key]
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class SQuAD(DefaultNLPTask):
	"""Extractive reading comprehension; the first gold answer text is the target."""

	def get_task_type(self):
		return "reading comprehension"

	def get_nlu_keys(self):
		return ["answer"]

	def get_valid_keys(self):
		return ["question", "context", "answer"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "answer":
			return query_datapoint[key].strip()
		first_answer = query_datapoint["answers"]["text"][0]
		return first_answer

class Swag(DefaultNLPTask):
	"""Choose which of ``ending0``..``ending3`` continues the ``startphrase``."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCD")}

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) ending ...", ending at position ``label``)."""
		answer_index = datapoint["label"]
		endings = [datapoint["ending" + str(n)] for n in range(4)]
		choices_string = ""
		for pos, ending in enumerate(endings):
			if pos == answer_index:
				answer_string = ending
			choices_string += " " + self.label[pos] + " " + ending
		return choices_string, answer_string

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["startphrase", "choices", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "startphrase":
			return query_datapoint[key]
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class TabFact(DefaultNLPTask):
	"""Verify a ``statement`` against a table (caption plus flattened table text)."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["refuted", "entailed"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["statement", "table_caption", "table_text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label name, the table text with newlines as " [n] ", or a stripped field."""
		if key == "table_text":
			return query_datapoint["table_text"].replace("\n", " [n] ")
		if key == "label":
			return self.label[query_datapoint[key]]
		return query_datapoint[key].strip()

class TREC(TrainSplitNLPTask):
	"""Question-type classification with both a coarse and a fine label."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# coarse label id -> name
		self.label = dict(enumerate(["DESC", "ENTY", "ABBR", "HUM", "NUM", "LOC"]))
		# fine label id -> name (ids 0..46, in this order)
		self.f_label = dict(enumerate([
			"manner", "cremat", "animal", "exp", "ind", "gr", "title", "def",
			"date", "reason", "event", "state", "desc", "count", "other", "letter",
			"religion", "food", "country", "color", "termeq", "city", "body", "dismed",
			"mount", "money", "product", "period", "substance", "sport", "plant", "techmeth",
			"volsize", "instru", "abb", "speed", "word", "lang", "perc", "code",
			"dist", "temp", "symbol", "ord", "veh", "weight", "currency",
		]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label-coarse", "label-fine"]

	def get_valid_keys(self):
		return ["text", "label-coarse", "label-fine"]

	def get_value_from_key(self, query_datapoint, key):
		"""Map label ids through the matching table; other keys yield the stripped field."""
		table = {"label-coarse": self.label, "label-fine": self.f_label}.get(key)
		if table is None:
			return query_datapoint[key].strip()
		return table[query_datapoint[key]]

class TweetQA(DefaultNLPTask):
	"""Question answering over a ``Tweet``; gold answers are tab-joined."""

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["Answer"]

	def get_valid_keys(self):
		return ["Question", "Tweet", "Answer"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return "\t".join(value) if key == "Answer" else value

class YelpReviewFull(DefaultNLPTask):
	"""Regress a review's rating from its ``text``; validation is served from the test split."""

	def process_split_name(self, split):
		return "test" if split == "validation" else split

	def get_task_type(self):
		return "text regression"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the text with literal backslash-n sequences spaced out, or the label shifted by one."""
		if key != "text":
			# 0-based label id rendered as a 1-based value
			return str(query_datapoint[key] + 1)
		# "\\n" is the two-character sequence backslash + n, not a newline
		return query_datapoint["text"].replace("\\n", " ")

class WebQuestions(TrainSplitNLPTask):
	"""Open question answering; all gold answers are tab-joined."""

	def get_task_type(self):
		return "question and answer"

	def get_nlu_keys(self):
		return ["answers"]

	def get_valid_keys(self):
		return ["question", "answers"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		if key != "answers":
			return value
		return "\t".join(value)



class WIQA(DefaultNLPTask):
	"""Multiple-choice "what if" questions with the process steps as extra context."""

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ["answers"]

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {pos: "(" + letter + ")" for pos, letter in enumerate("ABCD")}

	def get_choices_and_answer_string(self, datapoint):
		"""Return (" (A) opt ...", option selected by the letter ``answer_label_as_choice``)."""
		target = ord(datapoint["answer_label_as_choice"]) - ord("A")
		choices_string = ""
		for pos, option in enumerate(datapoint["choices"]["text"]):
			if pos == target:
				answer_string = option
			choices_string += " " + self.label[pos] + " " + option
		return choices_string, answer_string

	def get_valid_keys(self):
		return ["question", "choices", "answers", "question_para_step"]

	def get_value_from_key(self, query_datapoint, key):
		if key == "question":
			return query_datapoint["question_stem"]
		if key == "question_para_step":
			return " ".join(query_datapoint["question_para_step"])
		choices_string, answer_string = self.get_choices_and_answer_string(query_datapoint)
		return choices_string if key == "choices" else answer_string

class NumerSense(TrainSplitNLPTask):
	"""Fill the masked number word in a ``sentence``."""

	def get_task_type(self):
		return "text generation"

	def get_nlu_keys(self):
		return ["target"]

	def get_valid_keys(self):
		return ["sentence", "target"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return the stripped field; in "sentence", "<mask>" is rewritten as "[MASK]"."""
		text = query_datapoint[key]
		if key == "sentence":
			text = text.replace("<mask>", "[MASK]")
		return text.strip()

class WikiQA(TrainSplitNLPTask):
	"""Judge whether ``answer`` answers ``question`` (labels "false"/"true").

	NOTE(review): a second class named WikiQA is defined later in this
	module and rebinds the name, so this definition is shadowed and
	unreachable by name. Confirm which one is intended.
	"""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["false", "true"]))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["question", "answer", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else value.strip()


class YelpPolarity(DefaultNLPTask):
	"""Binary review sentiment; validation is a random 20% of the test split."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = dict(enumerate(["negative", "positive"]))

	def prune(self, data_list, split):
		"""Keep a random 20% of validation rows (shuffled in place); other splits pass through."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		keep = int(len(data_list) * 0.2)
		return data_list[:keep]

	def process_split_name(self, split):
		return "test" if split == "validation" else split

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["label"]

	def get_valid_keys(self):
		return ["text", "label"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "label" else str(value)


class Sentiment140(DefaultNLPTask):
	"""Sentiment of ``text``; the ``sentiment`` ids are 0, 2 and 4."""

	def __init__(self, task_name):
		super().__init__(task_name)
		self.label = {0: "Negative", 2: "Neutral", 4: "Positive"}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["sentiment"]

	def process_split_name(self, split):
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["text", "sentiment"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		if key != "sentiment":
			return str(value)
		return self.label[value]


class YahooAnswersTopics(DefaultNLPTask):
	"""Ten-way topic classification of a question and its best answer."""

	def __init__(self, task_name):
		super().__init__(task_name)
		topics = [
			"Society & Culture",
			"Science & Mathematics",
			"Health",
			"Education & Reference",
			"Computers & Internet",
			"Sports",
			"Business & Finance",
			"Entertainment & Music",
			"Family & Relationships",
			"Politics & Government",
		]
		self.label = dict(enumerate(topics))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["topic"]

	def process_split_name(self, split):
		return "test" if split == "validation" else split

	def get_valid_keys(self):
		return ["question_title", "question_content", "topic", "best_answer"]

	def get_value_from_key(self, query_datapoint, key):
		value = query_datapoint[key]
		return self.label[value] if key == "topic" else str(value)

class WikiBio(DefaultNLPTask):
	"""Generate ``target_text`` from a context string plus a flattened header/content table."""

	def discard_data(self, query_datapoint):
		"""Drop rows whose table header and content lists differ in length."""
		table = query_datapoint["input_text"]["table"]
		return len(table["column_header"]) != len(table["content"])

	def prune(self, data_list, split):
		"""Keep a random 1% of validation rows (shuffled in place); other splits pass through."""
		if split != "validation":
			return data_list
		random.shuffle(data_list)
		keep = int(0.01 * len(data_list))
		return data_list[:keep]

	def get_task_type(self):
		return "text generation"

	def get_nlu_keys(self):
		return ["target_text"]

	def process_split_name(self, split):
		return "val" if split == "validation" else split

	def make_input_text(self, datapoint):
		"""Render "<context> and header: cell [n] header: cell [n] ..."."""
		source = datapoint["input_text"]
		table = source["table"]
		pieces = [source["context"].strip() + " and "]
		for header, cell in zip(table["column_header"], table["content"]):
			pieces.append("{}: {} [n] ".format(header, cell.strip().replace("\n", " ")))
		return "".join(pieces)

	def get_valid_keys(self):
		return ["input_text", "target_text"]

	def get_value_from_key(self, query_datapoint, key):
		if key != "input_text":
			return query_datapoint[key]
		return self.make_input_text(query_datapoint)


class WikiQA(DefaultNLPTask):
	"""wiki_qa: label a (question, answer) pair with the text "True" or "False"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# Integer class id -> label text used as the target.
		self.label = dict(enumerate(("False", "True")))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['question', 'answer', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Translate the label id to text; stringify every other field."""
		raw = query_datapoint[key]
		return self.label[raw] if key == 'label' else str(raw)

class GlueQNLI(DefaultNLPTask):
	"""glue/qnli: label a (question, sentence) pair "entailment" or "not entailment"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# Integer class id -> label text used as the target.
		self.label = dict(enumerate(("entailment", "not entailment")))

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['question', 'sentence', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Translate the label id to text; stringify every other field."""
		raw = query_datapoint[key]
		if key == 'label':
			return self.label[raw]
		return str(raw)

class GlueQQP(DefaultNLPTask):
	"""glue/qqp: label a pair of questions "duplicate" or "not duplicate"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# Integer class id -> label text used as the target.
		self.label = dict(enumerate(("not duplicate", "duplicate")))

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['question 1', 'question 2', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label text, or the raw question for the display key ``key``."""
		if key == 'label':
			return self.label[query_datapoint['label']]
		# Display keys carry a space ("question 1") while the stored columns do
		# not; any key other than "question 1" resolves to question2.
		column = "question1" if key == "question 1" else "question2"
		return query_datapoint[column]


class GlueSST2(DefaultNLPTask):
	"""glue/sst2: label a sentence "positive" or "negative"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# Integer class id -> label text used as the target.
		self.label = dict(enumerate(("negative", "positive")))

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['sentence', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label text for ``label``; any other key yields the sentence."""
		if key != 'label':
			return query_datapoint["sentence"]
		return self.label[query_datapoint['label']]

class GlueWNLI(DefaultNLPTask):
	"""glue/wnli: label a sentence pair "entailment" or "not entailment"."""

	def __init__(self, task_name):
		super().__init__(task_name)
		# Integer class id -> label text used as the target.
		self.label = dict(enumerate(("not entailment", "entailment")))

	def get_task_type(self):
		return "entailment"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['sentence 1', 'sentence 2', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the label text, or the raw sentence for the display key ``key``."""
		if key == 'label':
			return self.label[query_datapoint['label']]
		# "sentence 1" -> sentence1; every other key falls through to sentence2.
		return query_datapoint["sentence1" if key == "sentence 1" else "sentence2"]


class WikiSQL(DefaultNLPTask):
	"""wikisql: generate the human-readable SQL string for a question."""

	def get_task_type(self):
		return "text generation"

	def get_nlu_keys(self):
		return ['sql expression']

	def get_valid_keys(self):
		return ['question', 'sql expression']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the question, or (for any other key) the stripped human-readable SQL."""
		if key != 'question':
			return query_datapoint["sql"]["human_readable"].strip()
		return query_datapoint["question"]


class Winogrande(DefaultNLPTask):
	"""winogrande: choose which of two options is the answer for a sentence.

	The target text is the chosen option's text (not its index).
	"""

	def discard_data(self, query_datapoint):
		"""Return True for rows whose answer is not the label 1 or 2.

		Rows whose answer cannot be parsed as an integer (e.g. an empty answer
		string on unlabeled rows) are discarded rather than letting int() raise.
		"""
		try:
			answer = int(query_datapoint['answer'])
		except (TypeError, ValueError):
			return True
		return answer not in (1, 2)

	def get_task_type(self):
		return "multiple choices"

	def get_nlu_keys(self):
		return ['answer']

	def get_valid_keys(self):
		return ['sentence', 'option 1', 'option 2', 'answer']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the text for ``key``.

		``answer`` resolves to the selected option's text ("1" -> option1,
		otherwise option2); ``option 1``/``option 2`` map to the stored
		option1/option2 columns; other keys are stringified.
		"""
		if key == 'answer':
			chosen = "option1" if int(query_datapoint['answer']) == 1 else "option2"
			return query_datapoint[chosen]
		if key == "option 1":
			return query_datapoint["option1"]
		if key == "option 2":
			return query_datapoint["option2"]
		return str(query_datapoint[key])

class Ethos(TrainSplitNLPTask):
	"""Ethos: classify a text along several binary (0/1) attribute columns.

	``directed_vs_generalized`` is rendered as "generalized"/"directed"; every
	other attribute column is rendered as "false"/"true".
	"""

	def __init__(self, task_name):
		super(Ethos, self).__init__(task_name)
		# Label text for a "violence" column. "violence" is not returned by
		# get_valid_keys(), so this mapping is only used if a caller asks for it.
		self.violence_label = {
			0: "not violent",
			1: "violent",
		}

		self.directed_vs_generalized_label = {
			0: "generalized",
			1: "directed",
		}

		# Fallback 0/1 -> text mapping for the remaining attribute columns.
		self.label = {
			0: "false",
			1: "true",
		}

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ["directed_vs_generalized", "disability", "gender", "national_origin", "race", "religion", "sexual_orientation"]

	def get_valid_keys(self):
		return ["text", "directed_vs_generalized", "disability", "gender", "national_origin", "race", "religion", "sexual_orientation"]

	def get_value_from_key(self, query_datapoint, key):
		"""Return ``text`` verbatim; translate any attribute column's 0/1 value to text."""
		value = query_datapoint[key]
		if key == "text":
			return value
		if key == "violence":
			return self.violence_label[value]
		if key == "directed_vs_generalized":
			return self.directed_vs_generalized_label[value]
		return self.label[value]



class _TweetEvalTask(DefaultNLPTask):
	"""Shared behaviour for the single-label tweet_eval configurations.

	Subclasses only declare ``_LABELS`` (integer class id -> label text); the
	input field is ``text`` and the target field is ``label`` for every config.
	"""

	# Integer class id -> label text; overridden by each subclass.
	_LABELS = {}

	def __init__(self, task_name):
		super(_TweetEvalTask, self).__init__(task_name)
		# Per-instance copy so mutating one task's labels cannot affect another.
		self.label = dict(self._LABELS)

	def discard_data(self, query_datapoint):
		"""Return True for tweets that are empty or contain only whitespace.

		get_value_from_key turns newlines into spaces, so a newline- or
		whitespace-only tweet would otherwise produce an input with no content.
		"""
		return not query_datapoint["text"].strip()

	def get_task_type(self):
		return "sentence classification"

	def get_nlu_keys(self):
		return ['label']

	def get_valid_keys(self):
		return ['text', 'label']

	def get_value_from_key(self, query_datapoint, key):
		"""Return the tweet on one line for ``text``; any other key yields the label text."""
		if key == 'text':
			return query_datapoint["text"].replace("\n", " ")
		return self.label[query_datapoint["label"]]


class TweetEvalEmotion(_TweetEvalTask):
	"""tweet_eval/emotion: four-way emotion label."""

	_LABELS = {
		0: "anger",
		1: "joy",
		2: "optimism",
		3: "sadness",
	}


class TweetEvalHate(_TweetEvalTask):
	"""tweet_eval/hate: "hate" vs "not hate"."""

	_LABELS = {
		0: "not hate",
		1: "hate",
	}


class TweetEvalIrony(_TweetEvalTask):
	"""tweet_eval/irony: "irony" vs "not irony"."""

	_LABELS = {
		0: "not irony",
		1: "irony",
	}


class TweetEvalOffensive(_TweetEvalTask):
	"""tweet_eval/offensive: "offensive" vs "not offensive"."""

	_LABELS = {
		0: "not offensive",
		1: "offensive",
	}


class TweetEvalSentiment(_TweetEvalTask):
	"""tweet_eval/sentiment: negative / neutral / positive."""

	_LABELS = {
		0: "negative",
		1: "neutral",
		2: "positive",
	}


class TweetEvalStance(_TweetEvalTask):
	"""tweet_eval stance configs (abortion, atheism, climate, feminist): none / against / favor."""

	_LABELS = {
		0: "none",
		1: "against",
		2: "favor",
	}

# "amazon_polarity": AmazonPolarity,
# "yahoo_answers_topics": YahooAnswersTopics,
# "kilt_tasks***wow": Kilt,
TASK_NAME_TO_CLS = {
	"ade_corpus_v2***Ade_corpus_v2_drug_dosage_relation": TrainSplitNLPTask,
	"ade_corpus_v2***Ade_corpus_v2_drug_ade_relation": TrainSplitNLPTask,
	"wiki_bio": WikiBio,
	"jeopardy": TrainSplitNLPTask,
	"definite_pronoun_resolution": DefinitePronounResolution,
	"sentiment140": Sentiment140,
	"numer_sense": NumerSense,
	"wikisql": WikiSQL,
	"yelp_review_full": YelpReviewFull,
	"tweet_eval***emotion": TweetEvalEmotion,
	"tweet_eval***hate": TweetEvalHate,
	"tweet_eval***irony": TweetEvalIrony,
	"tweet_eval***offensive": TweetEvalOffensive,
	"tweet_eval***sentiment": TweetEvalSentiment,
	"tweet_eval***stance_abortion":TweetEvalStance,
	"tweet_eval***stance_atheism":TweetEvalStance,
	"tweet_eval***stance_climate":TweetEvalStance,
	"tweet_eval***stance_feminist":TweetEvalStance,
	"ade_corpus_v2***Ade_corpus_v2_classification": AdeClassification,
	"adversarial_qa***adversarialQA": AdversarialQA,
	"aeslc": AESLC,
	"ag_news": AGNewsClassification,
	"ai2_arc***ARC-Challenge": ARCChallengeChoices,
	"liar": Liar,
	"tweet_qa": TweetQA,
	"anli": ANLI,
	"app_reviews": TrainSplitNLPTask,
	"aqua_rat***raw": AquaRat,
	"art": ART,
	"aslg_pc12": TrainSplitNLPTask,
	"biomrc***biomrc_large_B": BioMRC,
	"break_data***QDMR": DefaultNLPTask,
	"break_data***QDMR-high-level": DefaultNLPTask,
	"circa": Circa,
	"climate_fever": ClimateFever,
	"codah***fold_0": CODAH,
	"common_gen": CommonGen,
	"commonsense_qa": CommonsenseQA,
	"cos_e***v1.11": CoS_E,
	"cosmos_qa": CosmosQA,
	"dbpedia_14": DBpedia14,
	"discovery***discovery": Discovery,
	"dream": Dream,
	"duorc***SelfRC": DuoRC,
	"e2e_nlg_cleaned": DefaultNLPTask,
	"eli5":  ELI5,
	"emo": Emo,
	"emotion": Emotion,
	"empathetic_dialogues": EmpatheticDialogues,
	"financial_phrasebank***sentences_allagree": FinancialPhrasebank,
	"freebase_qa": FreebaseQA,
	"gigaword": DefaultNLPTask,
	"glue***cola": GlueCola,
	"glue***mnli": GlueMNLI,
	"glue***mrpc": GlueMRPC,
	"glue***qnli": GlueQNLI,
	"glue***qqp": GlueQQP,
	"glue***sst2": GlueSST2,
	"glue***wnli": GlueWNLI,
	"google_wellformed_query": GoogleWellformedQuery,
	"hate_speech18": HateSpeech18,
	"hate_speech_offensive": HateSpeechOffensive,
	"hatexplain": HatExplain,
	"health_fact": HealthFact,
	"hellaswag": HellaSwag,
	"hotpot_qa***distractor": HotpotQA,
	"imdb": IMDB,
	"kilt_tasks***aidayago2": Kilt,
	"kilt_tasks***fever": Kilt,
	"kilt_tasks***hotpotqa": Kilt,
	"kilt_tasks***nq": Kilt,
	"kilt_tasks***trex": Kilt,
	"kilt_tasks***structured_zeroshot": Kilt,
	"limit": Limit,
	"math_qa": MathQA,
	"mc_taco": MCTACO,
	"medical_questions_pairs": MedicalQuestionPairs,
	"mocha": Mocha,
	"multi_news": DefaultNLPTask,
	"onestop_english": OneStopEnglish,
	"openbookqa***main": OpenbookQA,
	"paws***labeled_final": PAWS,
	"piqa": PIQA,
	"poem_sentiment": PoemSentiment,
	"qa_srl": QASRL,
	"qasc": QASC,
	"quail": QUAIL,
	"quarel": QUAREL,
	"quartz": Quartz,
	"quoref": Quoref,
	"race***middle": Race,
	"race***high": Race,
	"reddit_tifu***long": TrainSplitNLPTask,
	"ropes": ROPES,
	"rotten_tomatoes": RottenTomatos,
	"samsum": DefaultNLPTask,
	"scicite": SciCite,
	"sciq": SciQ,
	"scitail***snli_format": DefaultNLPTask,
	"search_qa***train_test_val": DefaultNLPTask,
	"sick": Sick,
	"sms_spam": SMSSpam,
	"social_i_qa": SocialIQA,
	"spider": DefaultNLPTask,
	"squad": SQuAD,
	"swag***regular": Swag,
	"tab_fact***tab_fact": TabFact,
	"trec": TREC,
	"web_questions": WebQuestions,
	"xsum": DefaultNLPTask,
	"wiqa": WIQA,
	"wiki_qa": WikiQA,
	"yelp_polarity": YelpPolarity,
	"wiki_split": DefaultNLPTask,
}

# Every registered task name, in registry order. Derived from TASK_NAME_TO_CLS
# (dicts preserve insertion order) so the list and the registry cannot drift
# apart when a task is added or removed.
TASK_NAME_LIST = list(TASK_NAME_TO_CLS)


