import random
from torch.utils.data import Dataset, ConcatDataset, DataLoader
import numpy as np
import torch
from NLPTasks_wo_SuperGLUE import *
from tqdm import tqdm
import itertools
import os
import h5py
import math
import copy
from pathlib import Path
import re
import json
import nltk


def nltk_line_tokenizer(line):
    """Return the word tokens of ``line`` as produced by ``nltk.word_tokenize``."""
    word_tokens = nltk.word_tokenize(line)
    return word_tokens

class MetaH5Construction:
	"""Tokenize every valid field of one split of a task into plain id lists.

	Attributes set by the constructor:
		key_list: whatever ``single_task.get_valid_keys()`` returns.
		key2ids: maps each key to the token ids of the prompt text ``"<Key>: "``.
		split_data_list: one dict per datapoint, mapping each key to the
			token ids of that datapoint's value for the key.
	"""

	def __init__(self, single_task, split, tokenizer):
		def encode(text):
			# The final id of the tokenizer output is dropped
			# (presumably the EOS token -- confirm for the tokenizer in use).
			return tokenizer(text, return_tensors="np")['input_ids'][0, :-1].tolist()

		self.key_list = single_task.get_valid_keys()
		self.key2ids = {
			key: encode(key.lower().capitalize() + ": ")
			for key in self.key_list
		}

		self.split_data_list = []
		for query_datapoint in tqdm(single_task.get_split(split)):
			encoded_fields = {
				key: encode(single_task.get_value_from_key(query_datapoint, key))
				for key in self.key_list
			}
			self.split_data_list.append(encoded_fields)

def get_random_span(seq, l, n):
	"""Pick ``n`` random, non-overlapping start positions for spans of length ``l`` in ``seq``.

	Starts are sampled from a range shortened by ``(l - 1) * n`` and then the
	k-th smallest start is shifted right by ``k * (l - 1)``, which keeps the
	resulting spans from overlapping. The starts are returned in ascending order.
	"""
	shrunk_range = range(len(seq) - (l - 1) * n)
	ordered_picks = sorted(random.sample(shrunk_range, n))
	return [pick + rank * (l - 1) for rank, pick in enumerate(ordered_picks)]

def get_chunk_type(tag_name):
    """Split a tag such as ``"B-PER"`` on ``'-'`` and return ``(first_part, last_part)``.

    A tag without a dash (e.g. ``"O"``) yields the same string twice.
    """
    parts = tag_name.split('-')
    return parts[0], parts[-1]

def get_chunks(seq):
    """Convert a tag sequence into ``(type, start, end)`` chunks, ``end`` exclusive.

    ``"O"`` marks tokens outside any chunk. A chunk is closed when an ``"O"``
    is met, when the tag type changes, or when a tag whose class is ``"B"``
    starts a new chunk; a chunk still open at the end runs to ``len(seq)``.
    """
    outside = "O"
    spans = []
    open_type, open_start = None, None

    for pos, tag in enumerate(seq):
        if tag == outside:
            # Leaving a chunk: emit it and reset the tracker.
            if open_type is not None:
                spans.append((open_type, open_start, pos))
                open_type, open_start = None, None
            continue

        tag_class, tag_type = get_chunk_type(tag)
        if open_type is None:
            open_type, open_start = tag_type, pos
        elif tag_class == "B" or tag_type != open_type:
            # A new chunk begins right where the previous one ends.
            spans.append((open_type, open_start, pos))
            open_type, open_start = tag_type, pos

    if open_type is not None:
        spans.append((open_type, open_start, len(seq)))
    return spans

def read_conll(file_path):
    """Read a two-column, blank-line-separated file into ``(tokens, tags)`` pairs.

    Sentences are separated by an empty line (optionally holding one tab).
    Within a sentence only lines that split into exactly two whitespace
    separated fields are kept; every other line is ignored.
    """
    text = Path(file_path).read_text().strip()

    sentences = []
    for block in re.split(r'\n\t?\n', text):
        fields_per_row = (row.split() for row in block.split('\n'))
        pairs = [fields for fields in fields_per_row if len(fields) == 2]
        sentences.append(([tok for tok, _ in pairs], [tag for _, tag in pairs]))

    return sentences

def read_cls_sen_data(_path):
	"""Read a tab-separated file and return the first two columns of each line as tuples.

	Each line is stripped before splitting; lines with fewer than two
	columns are skipped and columns beyond the second are ignored.
	"""
	with open(_path) as handle:
		split_rows = (raw.strip().split('\t') for raw in handle)
		return [(cols[0], cols[1]) for cols in split_rows if len(cols) >= 2]

def read_pair_data(_path):
	"""Read a tab-separated file and return the first three columns of each line as tuples.

	Each line is stripped before splitting; lines with fewer than three
	columns are skipped and columns beyond the third are ignored.
	"""
	with open(_path) as handle:
		split_rows = (raw.strip().split('\t') for raw in handle)
		return [(cols[0], cols[1], cols[2]) for cols in split_rows if len(cols) >= 3]

class DAInContextDataset(Dataset):
	"""Base dataset producing masked-infilling examples with optional in-context demonstrations.

	Subclasses supply the task-specific pieces: ``get_data_set`` (file
	loading), ``is_identical`` (duplicate detection) and one or both of the
	``gen_from_tag_sequence`` / ``gen_from_nlu`` generators, which are
	selected through the mode names ``"tag"`` and ``"nlu"`` listed in
	``config.training_da_mode`` / ``config.eval_da_mode``.
	"""

	SKIP_ATTRIBUTES = ['gt_x', 'gt_y']

	def __init__(self, config, data_path, train_path, tokenizer, is_training=False, is_root=True):
		"""Load the data and build the ``(index, prompt_id, instance)`` work list.

		Args:
			config: object exposing the options read below (prompt-ensemble,
				tri-training, oversampling, DA modes, max length, ...).
			data_path: file whose instances become the dataset items.
			train_path: file whose instances serve as in-context demonstrations.
			tokenizer: callable tokenizer that also exposes ``eos_token_id``.
			is_training: selects training vs. evaluation replication and DA modes.
			is_root: when true, the final dataset size is printed.
		"""
		self.tokenizer = tokenizer
		self.is_training = is_training
		self.config = config
		self.sep_token_id = 0

		self.data_list = []
		data_instance_list = self.get_data_set(data_path)

		use_prompt_subsets = self.config.enable_tri_training and self.config.enable_prompt_ensemble
		prompt_members = {}
		if use_prompt_subsets:
			# Each prompt keeps every instance with probability 0.8.
			self.train_data_index_prompt_map = {i: [] for i in range(self.config.prefix_set_number)}
			for data_index in range(len(data_instance_list)):
				for prompt_id in range(self.config.prefix_set_number):
					if random.random() < 0.8:
						self.train_data_index_prompt_map[prompt_id].append(data_index)
			# Set views give O(1) membership checks in the construction loop below.
			prompt_members = {
				prompt_id: set(indices)
				for prompt_id, indices in self.train_data_index_prompt_map.items()
			}

		for index, data_instance in enumerate(data_instance_list):
			if self.is_training:
				for _ in range(self.config.oversample):
					if not self.config.enable_prompt_ensemble:
						self.data_list.append((index, None, data_instance))
						continue
					for prompt_id in range(self.config.prefix_set_number):
						# Under tri-training a prompt only sees its sampled subset.
						if use_prompt_subsets and index not in prompt_members[prompt_id]:
							continue
						self.data_list.append((index, prompt_id, data_instance))
			else:
				for _ in range(self.config.eval_data_replication):
					if self.config.enable_prompt_ensemble:
						for prompt_id in range(self.config.prefix_set_number):
							self.data_list.append((index, prompt_id, data_instance))
					else:
						self.data_list.append((index, None, data_instance))

		self.train_data_list = self.get_data_set(train_path)
		self.train_data_index = list(range(len(self.train_data_list)))

		if is_training:
			assert len(self.config.training_da_mode) > 0
			self.da_mode = self.config.training_da_mode
		else:
			assert len(self.config.eval_da_mode) > 0
			self.da_mode = self.config.eval_da_mode

		self.mode_func = {
			"tag": self.gen_from_tag_sequence,
			"nlu": self.gen_from_nlu,
		}

		if is_training:
			# Extra LM-generated training files: either none, or one per prompt.
			assert self.config.prefix_set_number == 0 or len(config.lm_gen_train_path_list) == 0 or len(config.lm_gen_train_path_list) == self.config.prefix_set_number
			for pid, d_path in enumerate(config.lm_gen_train_path_list):
				for data_instance in self.get_data_set(d_path, filtering=True):
					index = len(self.data_list)
					self.data_list.append((index, 0 if self.config.prefix_set_number == 0 else pid, data_instance))

		if is_root:
			print("Data Size %d" % len(self.data_list))

	def __len__(self):
		"""Return the number of ``(index, prompt_id, instance)`` entries."""
		return len(self.data_list)

	def __getitem__(self, idx):
		"""Build one example, optionally prefixed with in-context demonstrations.

		Returns:
			``(input_np, output_np, gt_x, gt_y, prompt_id, index, mode_flag)``
			where the two arrays are int64 numpy arrays and ``mode_flag`` is
			0 for the ``"nlu"`` mode and 1 otherwise.
		"""
		mode = random.choice(self.da_mode)
		data_generator = self.mode_func[mode]

		(index, prompt_id, data_instance) = self.data_list[idx]
		x_ids, y_ids, gt_x, gt_y = data_generator(data_instance, mask_token=True)
		input_ids = x_ids
		total_length = len(input_ids)
		if self.config.in_context_instance_count > 0:
			selected_instances = []
			# Never ask for more demonstrations than there are training instances;
			# random.sample raises ValueError otherwise.
			sample_size = min(self.config.in_context_instance_count, len(self.train_data_index))
			for d_index in random.sample(self.train_data_index, k=sample_size):
				train_data_instance = self.train_data_list[d_index]
				if self.is_identical(train_data_instance, data_instance):
					continue
				if total_length > self.config.max_length - 1:
					break
				context_ids, _, _, _ = data_generator(train_data_instance, add_seperator=True, is_train_instance=True)
				# Keep one slot free for the EOS token appended below.
				if len(context_ids) + total_length <= self.config.max_length - 1:
					selected_instances.append(context_ids)
					total_length += len(context_ids)
			if len(selected_instances) > 0:
				# Use a random-sized prefix of the fitting demonstrations.
				selected_instance_num = random.choice(range(1, len(selected_instances) + 1))
				for instance in selected_instances[:selected_instance_num]:
					input_ids = instance + input_ids

		input_ids.append(self.tokenizer.eos_token_id)

		input_np = np.array(input_ids).astype(np.int64)
		output_np = np.array(y_ids).astype(np.int64)

		if prompt_id is None:
			if self.config.prefix_set_number > 1:
				prompt_id = random.choice(range(self.config.prefix_set_number))
			else:
				prompt_id = 0

		return input_np, output_np, gt_x, gt_y, prompt_id, index, 0 if mode == "nlu" else 1

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Generator for the ``"tag"`` mode; must return ``(x_ids, y_ids, gt_x, gt_y)``.

		``is_train_instance`` is passed as True by ``__getitem__`` for
		in-context demonstrations, so overrides must accept it.
		"""
		raise NotImplementedError

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Generator for the ``"nlu"`` mode; same contract as ``gen_from_tag_sequence``."""
		raise NotImplementedError

	def get_data_set(self, path, filtering=False):
		"""Load and return the list of instances stored at ``path``."""
		raise NotImplementedError

	def is_identical(self, instance_a, instance_b):
		"""Return True when the two instances should be treated as the same example."""
		raise NotImplementedError


class SeqLabelInContext(DAInContextDataset):
	"""Sequence-labelling variant: every instance is a ``(tokens, tags)`` pair."""

	def get_data_set(self, path, filtering=False):
		"""Read ``path`` with ``read_conll``; with ``filtering`` keep only sentences under 50 tokens."""
		data_list = read_conll(path)
		if filtering:
			data_list = [(token, tag) for (token, tag) in data_list if len(token) < 50]
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Return True when both instances have the same space-joined token text."""
		return ' '.join(instance_a[0]) == ' '.join(instance_b[0])

	def add_annotation(self, tokens, chunk_info):
		"""Mark one chunk inline: ``B-<type>`` before its first token, ``&&`` after its last.

		``chunk_info`` is a ``(type, start, end)`` triple with ``end``
		exclusive. ``tokens`` is modified in place and also returned.
		"""
		entity_label = "B-%s" % chunk_info[0]
		first, last = chunk_info[1], chunk_info[2] - 1

		if first == last:
			tokens[first] = "%s %s &&" % (entity_label, tokens[first])
		else:
			tokens[first] = "%s %s" % (entity_label, tokens[first])
			tokens[last] = "%s &&" % (tokens[last])

		return tokens

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the "labels -> sentence" example.

		Args:
			data_instance: ``(tokens, tags)`` pair.
			mask_token: when true the sentence is replaced by ``<extra_id_0>``
				in the input and becomes the target; otherwise the sentence
				is written into the input and the target is empty.
			add_seperator: append ``self.sep_token_id`` to the input ids.
			is_train_instance: accepted because the base ``__getitem__``
				passes it for in-context demonstrations; not used here.

		Returns:
			``(x_ids, y_ids, gt_x, gt_y)`` with ``gt_x`` the joined labels
			and ``gt_y`` the (optionally annotated) sentence.
		"""
		token, tag = data_instance
		label_list = []
		# A shallow copy suffices: add_annotation only reassigns list slots.
		copied_token = list(token)
		for v in get_chunks(tag):
			if self.config.nlg_with_annotation:
				copied_token = self.add_annotation(copied_token, v)
			label_list.append("B-%s" % v[0])

		random.shuffle(label_list)
		input_x = " and ".join(label_list)
		input_y = ' '.join(copied_token)

		gt_x, gt_y = input_x, input_y

		if mask_token:
			input_x = "Entity_Labels: %s Sentence: <extra_id_0>" % input_x
			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
		else:
			input_x = "Entity_Labels: %s Sentence: %s" % (input_x, input_y)
			input_y = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the "sentence -> entity list" example.

		The target lists every chunk as ``B-<type> <mention>`` (or
		``'no entities'``). Arguments and return value follow
		``gen_from_tag_sequence``; ``is_train_instance`` is accepted for
		compatibility with the base ``__getitem__`` and not used.
		"""
		token, tag = data_instance

		input_x = ' '.join(token)
		entities = [
			"B-%s" % v[0] + ' ' + ' '.join(token[v[1]: v[2]])
			for v in get_chunks(tag)
		]
		input_y = 'no entities' if len(entities) == 0 else ' '.join(entities)

		gt_x, gt_y = input_x, input_y

		if mask_token:
			input_x = "Entity_Labels: <extra_id_0> Sentence: %s" % input_x
			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
		else:
			input_x = "Entity_Labels: %s Sentence: %s" % (input_y, input_x)
			input_y = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

class SenCLSInContext(DAInContextDataset):
	"""Sentence-classification variant: every instance is a ``(sentence, label)`` pair."""

	def _encode(self, input_x, input_y, add_seperator):
		"""Tokenize target then input, truncate to ``config.max_length``, optionally append the separator id."""
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		if add_seperator:
			x_ids.append(self.sep_token_id)
		return x_ids, y_ids

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the "label -> sentence" example.

		With ``mask_token`` the sentence is the ``<extra_id_0>`` target;
		otherwise it is written into the input and the target is empty.
		``is_train_instance`` is accepted because the base ``__getitem__``
		passes it for in-context demonstrations; it is not used here.

		Returns ``(x_ids, y_ids, gt_x, gt_y)`` with ``gt_x`` the label and
		``gt_y`` the sentence.
		"""
		token, tag = data_instance
		gt_x, gt_y = tag, token

		if mask_token:
			input_x = "Label: %s Sentence: <extra_id_0>" % gt_x
			input_y = "<extra_id_0> %s <extra_id_1>" % gt_y
		else:
			input_x = "Label: %s Sentence: %s" % (gt_x, gt_y)
			input_y = ""

		x_ids, y_ids = self._encode(input_x, input_y, add_seperator)
		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the "sentence -> label" example.

		With ``mask_token`` the label is the ``<extra_id_0>`` target;
		otherwise it is written into the input and the target is empty.
		``is_train_instance`` is accepted for compatibility and not used.

		Returns ``(x_ids, y_ids, gt_x, gt_y)`` with ``gt_x`` the sentence
		and ``gt_y`` the label.
		"""
		token, tag = data_instance
		gt_x, gt_y = token, tag

		if mask_token:
			input_x = "Label: <extra_id_0> Sentence: %s" % gt_x
			input_y = "<extra_id_0> %s <extra_id_1>" % gt_y
		else:
			input_x = "Label: %s Sentence: %s" % (gt_y, gt_x)
			input_y = ""

		x_ids, y_ids = self._encode(input_x, input_y, add_seperator)
		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Load ``(sentence, label)`` pairs with ``read_cls_sen_data``; ``filtering`` is ignored."""
		return read_cls_sen_data(path)

	def is_identical(self, instance_a, instance_b):
		"""Return True when both instances carry the same sentence text."""
		return instance_a[0] == instance_b[0]

class DocumentInContext(DAInContextDataset):
	"""Question/article variant: every instance is a ``(question, article, tag)`` triple."""

	def _encode(self, input_x, input_y, add_seperator):
		"""Tokenize target then input, truncate to ``config.max_length``, optionally append the separator id."""
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		if add_seperator:
			x_ids.append(self.sep_token_id)
		return x_ids, y_ids

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the article-generation example.

		With ``mask_token`` the article is the ``<extra_id_0>`` target after
		a bare ``"Context:"`` prompt; otherwise the article is written into
		the input and the target is the tag. ``is_train_instance`` is
		accepted because the base ``__getitem__`` passes it for in-context
		demonstrations; it is not used here.

		Returns ``(x_ids, y_ids, gt_x, gt_y)`` where ``gt_x`` / ``gt_y`` are
		the formatted input and target strings.
		"""
		_question, article, tag = data_instance

		if mask_token:
			input_x = "Context: <extra_id_0>"
			input_y = "<extra_id_0> %s <extra_id_1>" % article
		else:
			input_x = "Context: %s" % article
			input_y = tag

		gt_x, gt_y = input_x, input_y
		x_ids, y_ids = self._encode(input_x, input_y, add_seperator)
		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Build the "question + article -> label" example.

		With ``mask_token`` the tag is the ``<extra_id_0>`` target; otherwise
		it is written into the input and the target is empty.
		``is_train_instance`` is accepted for compatibility and not used.

		Returns ``(x_ids, y_ids, gt_x, gt_y)`` with ``gt_x`` the
		``(question, article)`` tuple and ``gt_y`` the tag.
		"""
		question, article, tag = data_instance
		body = "Questions: %s Article: %s" % (question, article)

		gt_x, gt_y = (question, article), tag

		if mask_token:
			input_x = "Label: <extra_id_0> %s" % body
			input_y = "<extra_id_0> %s <extra_id_1>" % tag
		else:
			input_x = "Label: %s %s" % (tag, body)
			input_y = ""

		x_ids, y_ids = self._encode(input_x, input_y, add_seperator)
		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Load three-column rows with ``read_pair_data``; ``filtering`` is ignored."""
		return read_pair_data(path)

	def is_identical(self, instance_a, instance_b):
		"""Return True when the first two fields of both instances match."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

class RTEDocumentInContext(DAInContextDataset):
	"""Premise-generation examples built from ``(hypothesis, premise, label, idx)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the premise is the ``<extra_id_0>`` target;
		otherwise the premise is in the input and the target is the label.
		``gt_x`` / ``gt_y`` are the formatted input and target strings.
		"""
		_, premise, label, _ = data_instance
		limit = self.config.max_length

		if mask_token:
			source = "Premise: <extra_id_0>"
			target = "<extra_id_0> %s <extra_id_1>" % premise
		else:
			source = "Premise: %s" % premise
			target = label

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, source, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line into ``(hypothesis, premise, str(label), idx)`` tuples."""
		with open(path) as handle:
			return [
				(record['hypothesis'], record['premise'], str(record['label']), record['idx'])
				for record in map(json.loads, handle)
			]

	def is_identical(self, instance_a, instance_b):
		"""Return True when hypothesis and premise both match."""
		return instance_a[:2] == instance_b[:2]

class BoolQDocumentInContext(DAInContextDataset):
	"""Article-generation examples built from ``(hypothesis, premise, label, idx)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the ``premise`` field is the ``<extra_id_0>``
		target after an ``"Article:"`` prompt; otherwise it is in the input
		and the target is the label. ``gt_x`` / ``gt_y`` are the formatted
		input and target strings.
		"""
		_, premise, label, _ = data_instance
		limit = self.config.max_length

		if mask_token:
			source = "Article: <extra_id_0>"
			target = "<extra_id_0> %s <extra_id_1>" % premise
		else:
			source = "Article: %s" % premise
			target = label

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, source, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line into ``(hypothesis, premise, str(label), idx)`` tuples."""
		with open(path) as handle:
			return [
				(record['hypothesis'], record['premise'], str(record['label']), record['idx'])
				for record in map(json.loads, handle)
			]

	def is_identical(self, instance_a, instance_b):
		"""Return True when the first two fields of both instances match."""
		return instance_a[:2] == instance_b[:2]

class CBDocumentInContext(DAInContextDataset):
	"""Document-generation examples built from ``(hypothesis, premise, label, idx)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the ``premise`` field is the ``<extra_id_0>``
		target after a ``"Document:"`` prompt; otherwise it is in the input
		and the target is the label. ``gt_x`` / ``gt_y`` are the formatted
		input and target strings.
		"""
		_, premise, label, _ = data_instance
		limit = self.config.max_length

		if mask_token:
			source = "Document: <extra_id_0>"
			target = "<extra_id_0> %s <extra_id_1>" % premise
		else:
			source = "Document: %s" % premise
			target = label

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, source, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line into ``(hypothesis, premise, str(label), idx)`` tuples."""
		with open(path) as handle:
			return [
				(record['hypothesis'], record['premise'], str(record['label']), record['idx'])
				for record in map(json.loads, handle)
			]

	def is_identical(self, instance_a, instance_b):
		"""Return True when the first two fields of both instances match."""
		return instance_a[:2] == instance_b[:2]

class MultiRCDocumentInContext(DAInContextDataset):
	"""Document-generation examples; every instance is the passage text itself (a string)."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one passage.

		With ``mask_token`` the passage is the ``<extra_id_0>`` target after
		a ``"Document:"`` prompt; otherwise the passage is in the input and
		the target is empty. ``gt_x`` / ``gt_y`` are the formatted input and
		target strings.
		"""
		document = data_instance

		if mask_token:
			input_x = "Document: <extra_id_0>"
			input_y = "<extra_id_0> %s <extra_id_1>" % document
		else:
			input_x = "Document: %s" % document
			input_y = ""

		gt_y = input_y
		gt_x = input_x
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line and collect each ``passage.text`` string."""
		with open(path) as out:
			return [items['passage']['text'] for items in map(json.loads, out)]

	def is_identical(self, instance_a, instance_b):
		"""Return True when both passages are the same text.

		Instances are plain strings, so they are compared as a whole;
		indexing with ``[0]`` would compare only their first characters.
		"""
		return instance_a == instance_b

class MultiRCQuestionInContext(DAInContextDataset):
	"""Question-generation examples built from ``(document, question)`` pairs."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the question is the ``<extra_id_0>`` target and
		``gt_x`` is the input with ``|*|`` between its fields; otherwise the
		question is written into the input and target and both ground-truth
		strings are empty.
		"""
		document, question = data_instance
		limit = self.config.max_length

		if not mask_token:
			source = "Question: %s Document: %s" % (question, document)
			target, gt_x, gt_y = "", "", ""
		else:
			source = "Question: <extra_id_0> Document: %s" % document
			target = "<extra_id_0> %s <extra_id_1>" % question
			gt_x = "Question: <extra_id_0> |*| Document: %s" % document
			gt_y = question

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line and emit a ``(passage text, question)`` pair per question."""
		pairs = []
		with open(path) as handle:
			for record in map(json.loads, handle):
				passage_text = record['passage']['text']
				pairs.extend(
					(passage_text, entry['question'])
					for entry in record['passage']['questions']
				)
		return pairs

	def is_identical(self, instance_a, instance_b):
		"""Return True when document and question both match."""
		return instance_a[:2] == instance_b[:2]

class MultiRCAnswerInContext(DAInContextDataset):
	"""Answer-generation examples built from ``(document, question, answer, label)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		The label text is the real one (``'True'`` when ``label == 1``)
		while training or for in-context demonstrations, and drawn at
		random otherwise. With ``mask_token`` the answer is the
		``<extra_id_0>`` target, ``gt_x`` is the input with ``|*|`` between
		its fields and ``gt_y`` is the answer; otherwise the target and both
		ground-truth strings are empty.
		"""
		document, question, answer, label = data_instance

		if self.is_training or is_train_instance:
			label = 'True' if label == 1 else 'False'
		else:
			label = 'True' if random.random() < 0.5 else 'False'

		if mask_token:
			input_x = "Label: %s Question: %s Answer: <extra_id_0> Document: %s" % (label, question, document)
			input_y = "<extra_id_0> %s <extra_id_1>" % answer
			gt_x = "Label: %s |*| Question: %s |*| Answer: <extra_id_0> |*| Document: %s" % (label, question, document)
			# The masked span is the answer, so that is the ground truth
			# (the question is already part of gt_x).
			gt_y = answer
		else:
			input_x = "Question: %s Document: %s" % (question, document)
			input_y = ""
			gt_y = ""
			gt_x = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line and emit one tuple per candidate answer."""
		data_list = []
		with open(path) as out:
			for items in map(json.loads, out):
				doc = items['passage']['text']
				for q in items['passage']['questions']:
					for ans in q['answers']:
						data_list.append((doc, q['question'], ans['text'], ans['label']))
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Return True when document, question and answer all match."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1] and instance_a[2] == instance_b[2]

class WICGenerationInContext(DAInContextDataset):
	"""Word-in-context generation from ``(word, sentence1, sentence2, label)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		The "Sense match" value is ``str(label)`` while training or for
		in-context demonstrations and a random ``'True'``/``'False'``
		otherwise. With ``mask_token`` the word and both sentences are the
		``<extra_id_*>`` targets and ``gt_x`` uses ``|*|`` between fields;
		otherwise everything is written into the input, the target is the
		label and ``gt_x`` equals the input. ``gt_y`` is the target string.
		"""
		word, sen1, sen2, label = data_instance
		limit = self.config.max_length

		if self.is_training or is_train_instance:
			sense_match = str(label)
		else:
			sense_match = random.choice(['True', 'False'])

		if not mask_token:
			source = "Word: %s Sentence a: %s Sentence b: %s Sense match: %s" % (word, sen1, sen2, sense_match)
			target = "%s" % label
			gt_x = source
		else:
			source = "Word: <extra_id_0> Sentence a: <extra_id_1> Sentence b: <extra_id_2> Sense match: %s" % sense_match
			target = "<extra_id_0> %s <extra_id_1> %s <extra_id_2> %s <extra_id_3>" % (word, sen1, sen2)
			gt_x = "Word: <extra_id_0> |*| Sentence a: <extra_id_1> |*| Sentence b: <extra_id_2> |*| Sense match: %s" % sense_match

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, gt_x, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line, wrapping the target word of each sentence in ``**``."""
		def highlight(sentence, start, end):
			# Surround the character span [start, end) with ' **' and '** '.
			return sentence[:start] + ' **' + sentence[start: end] + '** ' + sentence[end:]

		records = []
		with open(path) as handle:
			for entry in map(json.loads, handle):
				first = highlight(entry['sentence1'], entry['start1'], entry['end1'])
				second = highlight(entry['sentence2'], entry['start2'], entry['end2'])
				records.append((entry['word'], first, second, entry['label']))
		return records

	def is_identical(self, instance_a, instance_b):
		"""Return True when the word and the first sentence both match."""
		return instance_a[:2] == instance_b[:2]

class WSCGenerationInContext(DAInContextDataset):
	"""Coreference text generation from ``(marked_text, label)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the text is the ``<extra_id_0>`` target and
		``gt_x`` uses ``|*|`` between fields; otherwise the text is written
		into the input, the target is the label and ``gt_x`` equals the
		input. ``gt_y`` is the target string.
		"""
		text, label = data_instance
		limit = self.config.max_length

		if not mask_token:
			source = "Text: %s Coreference: %s" % (text, label)
			target = "%s" % label
			gt_x = source
		else:
			source = "Text: <extra_id_0> Coreference: %s" % label
			target = "<extra_id_0> %s <extra_id_1>" % text
			gt_x = "Text: <extra_id_0> |*| Coreference: %s" % label

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, gt_x, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line and mark both target spans with ``**``.

		The token at each span index gets a ``**`` prefix (span2 first, then
		span1); afterwards every ``**<span text>`` occurrence in the joined
		text is closed with ``** `` and whitespace is normalised.
		"""
		span_keys = (('span2_index', 'span2_text'), ('span1_index', 'span1_text'))

		records = []
		with open(path) as handle:
			for entry in map(json.loads, handle):
				words = entry['text'].split()
				target = entry['target']

				mentions = []
				for index_key, text_key in span_keys:
					position = target[index_key]
					words[position] = '**' + words[position]
					mentions.append(target[text_key])

				marked = ' '.join(words)
				for mention in mentions:
					marked = marked.replace('**' + mention, '**' + mention + '** ')
				marked = ' '.join(marked.strip().split())

				records.append((marked, entry['label']))
		return records

	def is_identical(self, instance_a, instance_b):
		"""Return True when text and label both match."""
		return instance_a[:2] == instance_b[:2]

class HypothesisGenInContext(DAInContextDataset):
	"""Hypothesis generation from ``(hypothesis, premise, label, idx)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		Underscores in the label are replaced by spaces. With ``mask_token``
		the hypothesis is the ``<extra_id_0>`` target; the label shown in
		the input is the real one while training or for in-context
		demonstrations, and a random choice between ``'not entailment'``
		and ``'entailment'`` otherwise. Without ``mask_token`` the
		hypothesis and real label are written into the input, the target is
		the label and ``gt_x`` is empty. ``gt_y`` is the target string.
		"""
		hypothesis, premise, label, _ = data_instance
		label = ' '.join(label.split('_'))

		if mask_token:
			# Same rule as the sibling generators: keep the real label whenever
			# it is known to be correct. With `and` the real label could never
			# be used for masked examples, so training saw random labels.
			if self.is_training or is_train_instance:
				new_tag = label
			else:
				new_tag = 'not entailment' if random.random() < 0.5 else 'entailment'

			input_x = "Label: %s Hypothesis: <extra_id_0> Premise: %s" % (new_tag, premise)
			input_y = "<extra_id_0> %s <extra_id_1>" % hypothesis
			gt_x = "Hypothesis: <extra_id_0> |*| Label: %s |*| Premise: %s" % (new_tag, premise)
		else:
			input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
			input_y = label
			input_x = "Label: %s %s" % (label, input_x)
			gt_x = ""

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line into ``(hypothesis, premise, str(label), idx)`` tuples."""
		with open(path) as out:
			return [
				(items['hypothesis'], items['premise'], str(items['label']), items['idx'])
				for items in map(json.loads, out)
			]

	def is_identical(self, instance_a, instance_b):
		"""Return True when hypothesis and premise both match."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

class CopaDocumentInContext(DAInContextDataset):
	"""Premise generation from ``(premise, question, correct_choice, wrong_choice)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		With ``mask_token`` the premise is the ``<extra_id_0>`` target after
		a ``"Text:"`` prompt; otherwise the premise is in the input and the
		target is the question field. ``gt_x`` is the raw premise and
		``gt_y`` the target string.
		"""
		premise, question, _, _ = data_instance
		limit = self.config.max_length

		if mask_token:
			source = "Text: <extra_id_0>"
			target = "<extra_id_0> %s <extra_id_1>" % premise
		else:
			source = "Text: %s" % premise
			target = question

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, premise, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line, ordering the two choices as (correct, wrong).

		``label == 0`` means ``choice1`` is the correct one.
		"""
		records = []
		with open(path) as handle:
			for entry in map(json.loads, handle):
				if entry['label'] == 0:
					right_key, wrong_key = 'choice1', 'choice2'
				else:
					right_key, wrong_key = 'choice2', 'choice1'
				records.append((entry['premise'], entry['question'], entry[right_key], entry[wrong_key]))
		return records

	def is_identical(self, instance_a, instance_b):
		"""Return True when premise and question both match."""
		return instance_a[:2] == instance_b[:2]

class CopaChoicesGenInContext(DAInContextDataset):
	"""Choice generation from ``(premise, question, correct_choice, wrong_choice)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		The question type is the record's own while training or for
		in-context demonstrations and a random ``'effect'``/``'cause'``
		otherwise. With ``mask_token`` both choices are ``<extra_id_*>``
		targets and ``gt_x`` uses ``|*|`` between fields with the raw
		question type; otherwise everything is written into the input (with
		the record's own question), the target is ``"empty"`` and ``gt_x``
		equals the input. ``gt_y`` is the target string.
		"""
		premise, question, correct_choice, wrong_choice = data_instance
		limit = self.config.max_length

		if self.is_training or is_train_instance:
			question_type = question
		else:
			question_type = 'effect' if random.random() < 0.5 else 'cause'

		# Natural-language wording for each question type.
		wording = {
			"effect": "What is the effect of this ?",
			"cause": "What is the cause of this ?",
		}

		if not mask_token:
			source = "Premise: %s Question: %s Choice1: %s Choice2: %s" % (premise, wording[question], correct_choice, wrong_choice)
			target = "empty"
			gt_x = source
		else:
			source = "Premise: %s Question: %s Choice1: <extra_id_0> Choice2: <extra_id_1>" % (premise, wording[question_type])
			target = "<extra_id_0> %s <extra_id_1> %s <extra_id_2>" % (correct_choice, wrong_choice)
			gt_x = "Premise: %s |*| Question: %s |*| Choice1: <extra_id_0> |*| Choice2: <extra_id_1>" % (premise, question_type)

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, gt_x, target

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line, ordering the two choices as (correct, wrong).

		``label == 0`` means ``choice1`` is the correct one.
		"""
		records = []
		with open(path) as handle:
			for entry in map(json.loads, handle):
				if entry['label'] == 0:
					right_key, wrong_key = 'choice1', 'choice2'
				else:
					right_key, wrong_key = 'choice2', 'choice1'
				records.append((entry['premise'], entry['question'], entry[right_key], entry[wrong_key]))
		return records

	def is_identical(self, instance_a, instance_b):
		"""Return True when premise and question both match."""
		return instance_a[:2] == instance_b[:2]

class CopaGenInContext(DAInContextDataset):
	"""Premise generation from ``(premise, question, choice1, choice2, label)`` records."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Return ``(x_ids, y_ids, gt_x, gt_y)`` for one instance.

		The two choices are swapped with probability 0.5, flipping the
		answer label accordingly. With ``mask_token`` the premise is the
		``<extra_id_0>`` target and ``gt_x`` uses ``|*|`` between fields;
		otherwise the premise is written into the input, the target is
		``"empty"`` and ``gt_x`` equals the input. ``gt_y`` is always the
		raw premise.
		"""
		premise, question, c1, c2, label = data_instance
		limit = self.config.max_length
		wording = {
			"effect": "What is the effect of this ?",
			"cause": "What is the cause of this ?",
		}

		keep_order = random.random() < 0.5
		if keep_order:
			first, second, answer = c1, c2, label
		else:
			# Swapped choices: the answer label flips as well.
			first, second = c2, c1
			answer = 'Choice1' if label == 'Choice2' else 'Choice2'

		if not mask_token:
			source = "Premise: %s Question: %s Choice1: %s Choice2: %s Answer: %s" % (premise, wording[question], first, second, answer)
			target = "empty"
			gt_x = source
		else:
			source = "Premise: <extra_id_0> Question: %s Choice1: %s Choice2: %s Answer: %s" % (wording[question], first, second, answer)
			target = "<extra_id_0> %s <extra_id_1>" % premise
			gt_x = "Premise: <extra_id_0> |*| Question: %s |*| Choice1: %s |*| Choice2: %s |*| Answer: %s" % (wording[question], first, second, answer)

		target_ids = self.tokenizer(target, return_tensors="np")['input_ids'][0, :limit].tolist()
		source_ids = self.tokenizer(source, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			source_ids.append(self.sep_token_id)

		return source_ids, target_ids, gt_x, premise

	def get_data_set(self, path, filtering=False):
		"""Parse one JSON object per line; ``label == 0`` maps to ``'Choice1'``, anything else to ``'Choice2'``."""
		with open(path) as handle:
			return [
				(
					entry['premise'],
					entry['question'],
					entry['choice1'],
					entry['choice2'],
					'Choice1' if entry['label'] == 0 else 'Choice2',
				)
				for entry in map(json.loads, handle)
			]

	def is_identical(self, instance_a, instance_b):
		"""Return True when premise and question both match."""
		return instance_a[:2] == instance_b[:2]

class COPANLUIncontext(DAInContextDataset):
    """COPA as a classification task: predict 'Choice1' or 'Choice2'."""

    def get_data_set(self, path, filtering=False):
        """Read JSON lines into (premise, choice1, choice2, question, label, idx).

        When self.is_training is set, each record is followed by a second
        copy with the two choices swapped and the label flipped to match.
        """
        examples = []
        with open(path) as out:
            for raw_line in out:
                record = json.loads(raw_line)
                second_is_correct = int(record['label']) == 1
                examples.append((
                    record['premise'], record['choice1'], record['choice2'], record['question'],
                    'Choice2' if second_is_correct else 'Choice1', record['idx'],
                ))
                if self.is_training:
                    # Mirrored copy: the choices trade places, so does the label.
                    examples.append((
                        record['premise'], record['choice2'], record['choice1'], record['question'],
                        'Choice1' if second_is_correct else 'Choice2', record['idx'],
                    ))

        return examples

    def is_identical(self, instance_a, instance_b):
        """Instances match when premise, both choices and question (fields 0-3) agree."""
        return all(instance_a[pos] == instance_b[pos] for pos in range(4))

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
        """Encode one instance as (x_ids, y_ids, gt_x, gt_y).

        With mask_token the answer is replaced by <extra_id_0> in the input
        and becomes the target; otherwise the answer is written into the
        input and the target text is empty. gt_x is (data_instance, idx) and
        gt_y is the label.
        """
        premise, choice1, choice2, question, tag, idx = data_instance
        question_text_by_type = {
            "effect": "What is the effect of this ?",
            "cause": "What is the cause of this ?",
        }

        body = "Choice1: %s Choice2: %s Premise: %s Question: %s" % (
            choice1, choice2, premise, question_text_by_type[question])
        gt_x, gt_y = (data_instance, idx), tag

        if mask_token:
            input_x = "Answer: <extra_id_0> %s" % body
            input_y = "<extra_id_0> %s <extra_id_1>" % tag
        else:
            input_x = "Answer: %s %s" % (tag, body)
            input_y = ""

        limit = self.config.max_length
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :limit].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :limit].tolist()
        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y

class COPANLUFilteringIncontext(DAInContextDataset):
    """COPA classification examples that always include both choice orderings."""

    def get_data_set(self, path, filtering=False):
        """Read JSON lines into (premise, choice1, choice2, question, label, idx).

        Every record yields two tuples: the original ordering, then the one
        with the choices swapped and the label flipped accordingly.
        """
        examples = []
        with open(path) as out:
            for raw_line in out:
                record = json.loads(raw_line)

                original_label = 'Choice2' if int(record['label']) == 1 else 'Choice1'
                examples.append((
                    record['premise'], record['choice1'], record['choice2'],
                    record['question'], original_label, record['idx'],
                ))

                swapped_label = 'Choice1' if int(record['label']) == 1 else 'Choice2'
                examples.append((
                    record['premise'], record['choice2'], record['choice1'],
                    record['question'], swapped_label, record['idx'],
                ))

        return examples

    def is_identical(self, instance_a, instance_b):
        """Instances match when fields 0-3 (premise, choices, question) are all equal."""
        for pos in range(4):
            if not instance_a[pos] == instance_b[pos]:
                return False
        return True

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
        """Encode one instance as (x_ids, y_ids, gt_x, gt_y).

        mask_token moves the label behind <extra_id_0> into the target;
        without it the label is spelled out in the input and the target text
        is empty. gt_x is (data_instance, idx); gt_y is the label.
        """
        premise, choice1, choice2, question, tag, idx = data_instance
        question_text_by_type = {
            "effect": "What is the effect of this ?",
            "cause": "What is the cause of this ?",
        }

        prompt = "Choice1: %s Choice2: %s Premise: %s Question: %s" % (
            choice1, choice2, premise, question_text_by_type[question])
        gt_x = (data_instance, idx)
        gt_y = tag

        if not mask_token:
            input_x, input_y = "Answer: %s %s" % (tag, prompt), ""
        else:
            input_x = "Answer: <extra_id_0> %s" % prompt
            input_y = "<extra_id_0> %s <extra_id_1>" % tag

        max_len = self.config.max_length
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :max_len].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :max_len].tolist()
        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y

class CBHypothesisGenInContext(DAInContextDataset):
	"""CB (hypothesis, premise, label) examples for hypothesis generation and labelling."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Encode one instance for hypothesis generation.

		With mask_token the hypothesis is replaced by <extra_id_0> and becomes
		the target. The conditioning label is the true one when training (or
		for a train instance) and a random one otherwise. Without mask_token
		the label is both written into the input and used as the target, and
		gt_x is the empty string. Returns (x_ids, y_ids, gt_x, gt_y) where
		gt_y is the target text.
		"""
		hypothesis, premise, label, _ = data_instance
		label = label.replace('_', ' ')
		candidate_labels = ['neutral', 'entailment', 'contradiction']

		if not mask_token:
			input_x = "Label: %s %s" % (label, "Hypothesis: %s Premise: %s" % (hypothesis, premise))
			input_y = label
			gt_x = ""
		else:
			use_true_label = self.is_training or is_train_instance
			new_tag = label if use_true_label else random.choice(candidate_labels)

			input_x = "Label: %s Hypothesis: <extra_id_0> Premise: %s" % (new_tag, premise)
			input_y = "<extra_id_0> %s <extra_id_1>" % hypothesis
			gt_x = "Hypothesis: <extra_id_0> |*| Label: %s |*| Premise: %s" % (new_tag, premise)

		gt_y = input_y
		limit = self.config.max_length
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :limit].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for label prediction.

		gt_x is (hypothesis, premise) and gt_y the stored label. With
		mask_token the label sits behind <extra_id_0> in the target; otherwise
		it is written into the input and the target text is empty.
		"""
		hypothesis, premise, tag, _ = data_instance
		pair_text = "Premise: %s Hypothesis: %s" % (premise, hypothesis)
		gt_x, gt_y = (hypothesis, premise), tag

		if mask_token:
			input_x = "Label: <extra_id_0> %s" % pair_text
			input_y = "<extra_id_0> %s <extra_id_1>" % tag
		else:
			input_x = "Label: %s %s" % (tag, pair_text)
			input_y = ""

		limit = self.config.max_length
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :limit].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines into (hypothesis, premise, str(label), idx) tuples."""
		examples = []
		with open(path) as out:
			for raw_line in out:
				record = json.loads(raw_line)
				examples.append((
					record['hypothesis'],
					record['premise'],
					str(record['label']),
					record['idx'],
				))
		return examples

	def is_identical(self, instance_a, instance_b):
		"""Instances match when hypothesis and premise (fields 0 and 1) are equal."""
		if not instance_a[0] == instance_b[0]:
			return False
		return instance_a[1] == instance_b[1]

class ReCDocumentInContext(DAInContextDataset):
	"""Passage-only examples: each instance is the passage text as a plain string."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Encode one passage as (x_ids, y_ids, gt_x, gt_y).

		With mask_token the whole document is replaced by <extra_id_0> in the
		input and becomes the target; otherwise the document is the input and
		the target text is empty. gt_x and gt_y are the input and target
		strings themselves.
		"""
		document = data_instance

		if mask_token:
			input_x = "Document: <extra_id_0>"
			input_y = "<extra_id_0> %s <extra_id_1>" % document
		else:
			input_x = "Document: %s" % document
			input_y = ""

		gt_y = input_y
		gt_x = input_x
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines and return the list of items['passage']['text'] strings."""
		data_list = []
		with open(path) as out:
			for l in out:
				items = json.loads(l)
				data_list.append(items['passage']['text'])
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Return True when both instances hold the same document.

		Instances produced by get_data_set are bare strings, so indexing them
		with [0] would only compare their first characters (and fail on an
		empty document). Strings are therefore compared in full; for any other
		sequence type the first field is compared as before.
		"""
		if isinstance(instance_a, str) or isinstance(instance_b, str):
			return instance_a == instance_b
		return instance_a[0] == instance_b[0]


class ReCEntitiesInContext(DAInContextDataset):
	"""Passage + entity-list examples, used to generate the entity list from the passage."""

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines into (article, entity_list) tuples.

		Entity strings are cut out of the passage using each entity's
		inclusive start/end offsets and de-duplicated.
		"""
		data_list = []
		with open(path) as out:
			for l in out:
				items = json.loads(l)
				article = items['passage']['text']
				entity_texts = [
					article[entity['start']: entity['end'] + 1]
					for entity in items['passage']['entities']
				]
				# De-duplicate while keeping first-occurrence order. A set would
				# order the strings differently from run to run (string hashing
				# is randomized), making the generated prompts irreproducible.
				entity_list = list(dict.fromkeys(entity_texts))
				data_list.append((article, entity_list))
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Instances match when their articles (field 0) are equal."""
		return instance_a[0] == instance_b[0]

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Encode one instance as (x_ids, y_ids, gt_x, gt_y).

		With mask_token the entity list is replaced by <extra_id_0> in the
		input and becomes the target; otherwise the " , "-joined entities are
		written into the input and the target text is empty. gt_y is the
		target string.
		"""
		document, entities = data_instance

		entities_str = " , ".join(entities)
		if mask_token:
			input_x = "Entities: <extra_id_0> Document: %s" % document
			input_y = "<extra_id_0> %s <extra_id_1>" % entities_str
			gt_x = "Entities: <extra_id_0> |*| Document: %s" % document
		else:
			input_x = "Entities: %s Document: %s" % (entities_str, document)
			input_y = ""
			gt_x = input_x

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

class ReCQueryInContext(DAInContextDataset):
	"""Passage + entities + filled-in query examples, used to generate the query."""

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines into (article, entity_list, query) tuples.

		One tuple is produced per (query, answer) pair, with '@placeholder' in
		the query replaced by the answer text. All tuples from the same record
		share one entity_list object.
		"""
		data_list = []
		with open(path) as out:
			for l in out:
				items = json.loads(l)
				article = items['passage']['text']
				entity_texts = [
					article[entity['start']: entity['end'] + 1]
					for entity in items['passage']['entities']
				]
				# Order-preserving de-duplication: a set would give a different
				# entity order on every run because string hashing is randomized.
				entity_list = list(dict.fromkeys(entity_texts))
				for qas in items['qas']:
					query = qas['query']
					for ans in qas['answers']:
						# str.replace already returns a new string, so no copy is needed.
						filled_query = query.replace('@placeholder', ans['text'])
						data_list.append((article, entity_list, filled_query))
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Instances match when their articles (field 0) are equal."""
		return instance_a[0] == instance_b[0]

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
		"""Encode one instance as (x_ids, y_ids, gt_x, gt_y).

		With mask_token the query is replaced by <extra_id_0> in the input and
		becomes the target; the entity list appears only in gt_x, not in the
		model input. Otherwise the query is written into the input and the
		target text is empty. gt_y is the target string.
		"""
		document, entities, query = data_instance

		entities_str = " , ".join(entities)
		if mask_token:
			input_x = "Query: <extra_id_0> Document: %s" % document
			input_y = "<extra_id_0> %s <extra_id_1>" % query
			gt_x = "Entities: %s |*| Query: <extra_id_0> |*| Document: %s" % (entities_str, document)
		else:
			input_x = "Query: %s Document: %s" % (query, document)
			input_y = ""
			gt_x = input_x

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y


class DomainPairInContext(DAInContextDataset):
	"""Pairs of (question, article, label) examples with a same-domain label to predict."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
		"""Not implemented for this dataset; returns None."""
		return None

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode a pair as (x_ids, y_ids, gt_x, gt_y).

		The domain label is always the masked target (mask_token is not
		consulted). gt_x is a '|*|' / '|***|' delimited rendering of both
		examples and gt_y is the domain label.
		"""
		question_a, article_a, tag_a, question_b, article_b, tag_b, domain_label = data_instance
		first = (tag_a, question_a, article_a)
		second = (tag_b, question_b, article_b)

		input_x = "%s %s Same_Domain: <extra_id_0>" % (
			"Label 1: %s Questions 1: %s Article 1: %s" % first,
			"Label 2: %s Questions 2: %s Article 2: %s" % second,
		)
		input_y = "<extra_id_0> %s <extra_id_1>" % domain_label

		limit = self.config.max_length
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :limit].tolist()
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :limit].tolist()
		if add_seperator:
			x_ids.append(self.sep_token_id)

		gt_x = "%s |***| %s" % (
			"Label 1: %s |*| Questions 1: %s |*| Article 1: %s" % first,
			"Label 2: %s |*| Questions 2: %s |*| Article 2: %s" % second,
		)
		return x_ids, y_ids, gt_x, domain_label

	def get_data_set(self, path, filtering=False):
		"""Read a tab-separated file; keep the first 7 columns of rows that have at least 7."""
		rows = []
		with open(path) as out:
			for raw_line in out:
				columns = raw_line.strip().split('\t')
				if len(columns) >= 7:
					rows.append(tuple(columns[:7]))
		return rows

	def is_identical(self, instance_a, instance_b):
		"""Instances match only when all seven fields are equal."""
		return all(instance_a[pos] == instance_b[pos] for pos in range(7))

class BoolQInContext(DAInContextDataset):
	"""BoolQ (question, passage, label) examples for span infilling and label prediction."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for generation as (x_ids, y_ids, gt_x, gt_y).

		With mask_token, one span of the question and three spans of the
		article (about 30% of the tokens of each) are replaced by sentinel
		tokens and collected into the target. The conditioning label is the
		true one when self.is_training, otherwise 'True' or 'False' at random.

		Without mask_token, the label is written into the input and is also
		the target; gt_x is then the input text.

		The previous version referenced an undefined name `seed` in the masked
		branch (its results were overwritten further down anyway) and never
		assigned gt_x in the unmasked branch, so both paths raised.
		"""
		question, article, tag = data_instance

		if mask_token:
			if self.is_training:
				new_tag = tag
			else:
				new_tag = 'True' if random.random() < 0.5 else 'False'

			# Question: a single contiguous span behind <extra_id_0>.
			question_tokens = nltk_line_tokenizer(question)
			span_length = math.ceil((len(question_tokens) * 0.3))
			start_point = get_random_span(question_tokens, span_length, 1)[0]
			input_masked_question_str = ' '.join(question_tokens[:start_point]) + ' <extra_id_0> ' + ' '.join(question_tokens[start_point + span_length:])
			masked_question_text = ' '.join(question_tokens[start_point: start_point + span_length])

			# Article: three equal-length spans behind <extra_id_1..3>.
			# NOTE(review): get_random_span samples 3 start points, so an
			# article with fewer than 3 tokens will raise -- confirm inputs.
			article_tokens = nltk_line_tokenizer(article)
			span_length = math.ceil((len(article_tokens) * 0.3) / 3)
			start_points = get_random_span(article_tokens, span_length, 3)
			input_masked_article_str, masked_article_text = "", ""
			last_start_point = 0
			for index, start_point in enumerate(start_points):
				input_masked_article_str += ' '.join(article_tokens[last_start_point:start_point]) + ' <extra_id_%d> ' % (index + 1)
				masked_article_text += '<extra_id_%d> %s ' % (index + 1, ' '.join(article_tokens[start_point: start_point + span_length]))
				last_start_point = start_point + span_length
			# Closing sentinel after the last masked span.
			masked_article_text += "<extra_id_%d>" % (len(start_points) + 1)
			input_masked_article_str += ' '.join(article_tokens[last_start_point:])

			input_x = "Questions: %s Answers: %s Context: %s" % (input_masked_question_str, new_tag, input_masked_article_str)
			input_y = "<extra_id_0> %s %s" % (masked_question_text, masked_article_text)
			gt_x = "Questions: %s |*| Label: %s |*| Article: %s" % (input_masked_question_str, new_tag, input_masked_article_str)

		else:
			input_x = "Questions: %s Article: %s" % (question, article)
			input_y = tag
			input_x = "Label: %s %s" % (tag, input_x)
			# Nothing is masked here, so the readable ground truth is the input.
			gt_x = input_x

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for label prediction.

		gt_x is (question, article) and gt_y the label. With mask_token the
		label sits behind <extra_id_0> in the target; otherwise it is written
		into the input and the target text is empty.
		"""
		question, article, tag = data_instance
		input_x = "Question: %s Article: %s" % (question, article)
		input_y = tag

		gt_x, gt_y = (question, article), tag

		if mask_token:
			input_x = "Answer: <extra_id_0> %s" % input_x
			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
		else:
			input_x = "Answer: %s %s" % (input_y, input_x)
			input_y = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines into (question, passage, str(label)) tuples."""
		data_list = []
		with open(path) as out:
			for l in out:
				items = json.loads(l)
				data_list.append((items['question'], items['passage'], str(items['label'])))
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Instances match when question and passage (fields 0 and 1) are equal."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

class FlipDAInContext(DAInContextDataset):
	"""BoolQ-style (question, passage, label, idx) examples for question generation and labelling."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for generation as (x_ids, y_ids, gt_x, gt_y).

		With mask_token the question is replaced by <extra_id_0> and becomes
		the target; the conditioning label is the true one when
		self.is_training, otherwise 'True' or 'False' at random. Without
		mask_token the label is written into the input and is also the target.

		gt_x was previously left unassigned on the unmasked path, which made
		the return statement raise; it is now the input text.
		"""
		question, article, tag, data_index = data_instance

		if mask_token:
			if self.is_training:
				new_tag = tag
			else:
				new_tag = 'True' if random.random() < 0.5 else 'False'

			input_x = "Questions: <extra_id_0> Label: %s Article: %s" % (new_tag, article)
			input_y = "<extra_id_0> %s <extra_id_1>" % question
			gt_x = "Questions: <extra_id_0> |*| Label: %s |*| Article: %s" % (new_tag, article)
		else:
			input_x = "Questions: %s Article: %s" % (question, article)
			input_y = tag
			input_x = "Label: %s %s" % (tag, input_x)
			# Nothing is masked here, so the readable ground truth is the input.
			gt_x = input_x

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for label prediction.

		gt_x is (question, article, data_index) and gt_y the label. With
		mask_token the label sits behind <extra_id_0> in the target; otherwise
		it is written into the input and the target text is empty.
		"""
		question, article, tag, data_index = data_instance
		input_x = "Questions: %s Article: %s" % (question, article)
		input_y = tag

		gt_x, gt_y = (question, article, data_index), tag

		if mask_token:
			input_x = "Label: <extra_id_0> %s" % input_x
			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
		else:
			input_x = "Label: %s %s" % (input_y, input_x)
			input_y = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Read JSON lines into (question, passage, str(label), idx) tuples."""
		data_list = []
		with open(path) as out:
			for l in out:
				items = json.loads(l)
				data_list.append((items['question'], items['passage'], str(items['label']), items['idx']))
		return data_list

	def is_identical(self, instance_a, instance_b):
		"""Instances match when question and passage (fields 0 and 1) are equal."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

class QuestionGenInContext(DAInContextDataset):
	"""(question, article, label) triples for question generation and labelling."""

	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for generation as (x_ids, y_ids, gt_x, gt_y).

		With mask_token the question is replaced by <extra_id_0> and becomes
		the target; the conditioning label is the true one when
		self.is_training, otherwise 'True' or 'False' at random. Without
		mask_token the label is written into the input and is also the target.

		gt_x was previously left unassigned on the unmasked path, which made
		the return statement raise; it is now the input text.
		"""
		question, article, tag = data_instance

		if mask_token:
			if self.is_training:
				new_tag = tag
			else:
				new_tag = 'True' if random.random() < 0.5 else 'False'

			input_x = "Questions: <extra_id_0> Label: %s Article: %s" % (new_tag, article)
			input_y = "<extra_id_0> %s <extra_id_1>" % question
			gt_x = "Questions: <extra_id_0> |*| Label: %s |*| Article: %s" % (new_tag, article)
		else:
			input_x = "Questions: %s Article: %s" % (question, article)
			input_y = tag
			input_x = "Label: %s %s" % (tag, input_x)
			# Nothing is masked here, so the readable ground truth is the input.
			gt_x = input_x

		gt_y = input_y
		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False):
		"""Encode one instance for label prediction.

		gt_x is (question, article) and gt_y the label. Instances here are
		3-tuples with no index field, so the undefined `data_index` that used
		to be packed into gt_x (and raised NameError) has been dropped.
		"""
		question, article, tag = data_instance
		input_x = "Questions: %s Article: %s" % (question, article)
		input_y = tag

		gt_x, gt_y = (question, article), tag

		if mask_token:
			input_x = "Label: <extra_id_0> %s" % input_x
			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
		else:
			input_x = "Label: %s %s" % (input_y, input_x)
			input_y = ""

		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

		if add_seperator:
			x_ids.append(self.sep_token_id)

		return x_ids, y_ids, gt_x, gt_y

	def get_data_set(self, path, filtering=False):
		"""Load instances via read_pair_data (defined outside this class)."""
		return read_pair_data(path)

	def is_identical(self, instance_a, instance_b):
		"""Instances match when question and article (fields 0 and 1) are equal."""
		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

# class RTEInContext(DAInContextDataset):
# 	def get_data_set(self, path, filtering=False):
# 		return read_pair_data(path)

# 	def is_identical(self, instance_a, instance_b):
# 		return instance_a[0] == instance_b[0] and instance_a[1] == instance_b[1]

# 	def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
# 		premise, hypothesis, tag = data_instance
# 		input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
# 		input_y = tag

# 		gt_x, gt_y = input_x, input_y

# 		if mask_token:
# 			input_x = "Answer: <extra_id_0> %s" % input_x
# 			input_y = "<extra_id_0> %s <extra_id_1>" % input_y
# 		else:
# 			input_x = "Answer: %s %s" % (input_y, input_x)
# 			input_y = ""

# 		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
# 		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

# 		if add_seperator:
# 			x_ids.append(self.sep_token_id)

# 		return x_ids, y_ids, gt_x, gt_y

# 	def gen_from_tag_sequence(self, data_instance, mask_token=False, add_seperator=False):
# 		premise, hypothesis, tag = data_instance
# 		input_x = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
# 		input_y = tag

# 		gt_x, gt_y = input_x, input_y

# 		if mask_token:
# 			seed = random.random()
# 			if seed <= 0.33:
# 				input_x = "Hypothesis: %s Label: %s Premise: <extra_id_0>" % (hypothesis, tag)
# 				input_y = "<extra_id_0> %s <extra_id_1>" % premise
# 			elif seed <= 0.67:
# 				input_x = "Hypothesis: <extra_id_0> Label: %s Premise: %s" % (tag, premise)
# 				input_y = "<extra_id_0> %s <extra_id_1>" % hypothesis
# 			else:
# 				input_x = "Hypothesis: <extra_id_0> Label: %s Premise: <extra_id_1>" % tag
# 				input_y = "<extra_id_0> %s <extra_id_1> %s <extra_id_2>" % (hypothesis, premise)
			
# 		else:
# 			input_x = "Label: %s %s" % (input_y, input_x)
# 			input_y = ""

# 		y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()
# 		x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :self.config.max_length].tolist()

# 		if add_seperator:
# 			x_ids.append(self.sep_token_id)

# 		return x_ids, y_ids, gt_x, gt_y

class RTEInContext(DAInContextDataset):
    """RTE (premise, hypothesis, label) examples for label prediction."""

    def get_data_set(self, path, filtering=False):
        """Read JSON lines into (premise, hypothesis, label) tuples.

        Underscores in the label are turned into spaces.
        """
        examples = []
        with open(path) as out:
            for raw_line in out:
                record = json.loads(raw_line)
                readable_label = record['label'].replace('_', ' ')
                examples.append((record['premise'], record['hypothesis'], readable_label))
        return examples

    def is_identical(self, instance_a, instance_b):
        """Instances match when premise and hypothesis (fields 0 and 1) are equal."""
        if not instance_a[0] == instance_b[0]:
            return False
        return instance_a[1] == instance_b[1]

    def gen_from_nlu(self, data_instance, mask_token=False, add_seperator=False, is_train_instance=False):
        """Encode one instance as (x_ids, y_ids, gt_x, gt_y).

        gt_x is (hypothesis, premise) and gt_y the label. With mask_token the
        label sits behind <extra_id_0> in the target; otherwise it is written
        into the input and the target text is empty.
        """
        premise, hypothesis, tag = data_instance
        pair_text = "Hypothesis: %s Premise: %s" % (hypothesis, premise)
        gt_x, gt_y = (hypothesis, premise), tag

        if not mask_token:
            input_x, input_y = "Answer: %s %s" % (tag, pair_text), ""
        else:
            input_x = "Answer: <extra_id_0> %s" % pair_text
            input_y = "<extra_id_0> %s <extra_id_1>" % tag

        limit = self.config.max_length
        y_ids = self.tokenizer(input_y, return_tensors="np")['input_ids'][0, :limit].tolist()
        x_ids = self.tokenizer(input_x, return_tensors="np")['input_ids'][0, :limit].tolist()
        if add_seperator:
            x_ids.append(self.sep_token_id)

        return x_ids, y_ids, gt_x, gt_y


class MetaH5NLPTaskDynamic(Dataset):
	"""Build in-context meta-task instances on the fly from a per-task HDF5 file.

	Each item is one query instance rendered as "key: value" token ids in
	which some values are replaced by sentinel ids (starting at 32099 and
	counting down -- presumably T5's <extra_id_N> tokens, confirm against the
	tokenizer), optionally combined with fully visible training instances.
	"""

	def __init__(self, config, task, task_index, path_, split, tokenizer, max_length, is_train=True):
		"""Read split sizes and key names from the HDF5 file at path_.

		config supplies use_t5_span, nlu_only, nlg_only, key_replacement_path,
		meta_task_max_sample, mask_warmup_epochs and the other options read by
		the methods below. The file is only opened briefly here; __getitem__
		opens its own handle lazily.
		"""
		self.split = split
		self.eos_token_id = tokenizer.eos_token_id
		# Id appended after every rendered instance; hard-coded to 0.
		self.sep_token_id = 0
		self.max_length = max_length
		self.use_t5_span = config.use_t5_span
		self.path_ = path_
		self.config = config
		self.is_train = is_train
		self.task = task
		self.tokenizer = tokenizer
		self._task_index = task_index
		self.nlu_only = self.config.nlu_only
		self.nlg_only = self.config.nlg_only

		self.key_to_index = None
		self.key_replacement_dict = None
		if len(config.key_replacement_path) > 0:
			# Comma-separated file, first line skipped. The first column is the
			# key being replaced; every non-empty column (including the first)
			# is tokenized as a possible rendering of that key.
			self.key_replacement_dict = {}
			with open(config.key_replacement_path) as out:
				for idx, line in enumerate(out):
					if idx == 0: continue
					line = line.strip()
					items = line.split(',')
					key = self.get_query_key(items[0])
					if key not in self.key_replacement_dict:
						self.key_replacement_dict[key] = []
					for key_replacement in items:
						key_replacement = key_replacement.strip()
						if len(key_replacement) > 0:
							key_ids = tokenizer(key_replacement.lower().capitalize() + ": ", return_tensors="np")['input_ids'][0, :-1].tolist()
							self.key_replacement_dict[key].append(key_ids)


		assert not (self.nlg_only and self.nlu_only), "nlg_only and nlu_only are exclusive"

		with h5py.File(self.path_, 'r') as _hdf5:
			self.orginal_len = _hdf5['%s_instance_count' % split][0]
			self.data_len = min(self.config.meta_task_max_sample, self.orginal_len)
			self.train_data_index = [i for i in range(_hdf5['train_instance_count'][0])]
			self.nlu_key_list = self.task.get_nlu_keys()
			# Dataset names look like 'keyids_<key>'; key[7:] strips that prefix.
			# NLU keys are represented by the single placeholder 'nlu_key'.
			self.key_list = [key[7:] for key in _hdf5.keys() if key.startswith('keyids_') and key[7:] not in self.nlu_key_list]
			self.key_list.append('nlu_key')
			self.nlu_visible_keys = [key for key in self.key_list if not key == "nlu_key"]

			# Every subset of keys that leaves at least one key hidden.
			self.key_combinations = [[]]
			for i in range(1, len(self.key_list)):
				self.key_combinations += list(itertools.combinations(self.key_list, i))
			self.switch_prob = 1.0 / (2.0 - 2.0 * len(self.nlu_key_list) / len(self.key_combinations))

			# Same subsets minus the ones that show every non-NLU key.
			self.nlg_only_key_combinations = []
			for key_conbination in self.key_combinations:
				if all([key in key_conbination for key in self.nlu_visible_keys]):
					continue
				self.nlg_only_key_combinations.append(key_conbination)

		self.mask_warmup_steps = config.mask_warmup_epochs * self.data_len if self.is_train else -1
		self.current_steps = 0

	def __len__(self):
		"""Number of items exposed: min(meta_task_max_sample, rows in the split)."""
		return self.data_len

	def _get_t5_span(self, original_sequence, span_start_index):
		"""Mask part (or all) of a value with sentinel ids.

		Returns (input_seq, output_seq, next_sentinel). Values of at most 10
		ids are masked entirely with one sentinel; longer values get 1-3
		random spans covering about 15% of the ids in total. Sentinels count
		downward from span_start_index.
		"""
		if len(original_sequence) <= 10:
			input_seq = [span_start_index]
			output_seq = [span_start_index] + original_sequence
			# Sentinels are consumed in decreasing order everywhere else in
			# this class, so the next free one is one lower (this used to
			# return span_start_index + 1).
			return input_seq, output_seq, span_start_index - 1

		if len(original_sequence) <= 30 or span_start_index == 32097:
			span_count = 1
		elif len(original_sequence) <= 60 or span_start_index == 32098:
			span_count = 2
		else:
			span_count = 3

		span_length = math.ceil((len(original_sequence) * 0.15) / span_count)
		start_points = get_random_span(original_sequence, span_length, span_count)
		input_seq = []
		output_seq = []
		prev_start_point = 0
		for start_point in start_points:
			input_seq += original_sequence[prev_start_point:start_point] + [span_start_index]
			output_seq += [span_start_index] + original_sequence[start_point: start_point + span_length]
			span_start_index -= 1
			prev_start_point = start_point + span_length
		input_seq += original_sequence[prev_start_point:]
		return input_seq, output_seq, span_start_index

	def _get_instance_t5_span_representation(self, data_index, used_nlu_key, add_seperator=False):
		"""Render one instance with span masking.

		Returns (input_ids, output_ids, show_attributes, False). add_seperator
		appends sep_token_id to the input; it was added because __getitem__
		already passes it and previously failed with a TypeError.
		"""
		show_attributes = random.choice(self.key_combinations)

		input_ids = []
		output_ids = []
		span_index = 32099

		for key in self.key_list:
			if key == "nlu_key":
				real_key = used_nlu_key
			else:
				real_key = key

			input_ids += self.data_hdf5['keyids_%s' % real_key][0].tolist()
			value_ids = self.data_hdf5['%s_key_%s_values' % (self.split, real_key)][data_index].tolist()
			# NOTE(review): once span_index drops to 32096 or below, even keys
			# in show_attributes get masked -- confirm this is intended.
			if key in show_attributes and span_index > 32096:
				input_ids += value_ids
			else:
				value_input_seq, value_output_seq, span_index = self._get_t5_span(value_ids, span_index)
				input_ids += value_input_seq
				output_ids += value_output_seq

		# Close the target with the next unused sentinel and EOS.
		output_ids.append(span_index)
		output_ids.append(self.eos_token_id)

		if add_seperator:
			input_ids.append(self.sep_token_id)

		return input_ids, output_ids, show_attributes, False

	def get_query_key(self, key):
		"""Normalize a key name: lower-case it and remove all whitespace."""
		return ''.join(key.strip().lower().split())

	def curriculum_learning_mask(self, special_token_id, value_ids):
		"""Mask a value fully, or partially while mask warm-up is in progress.

		Returns (input_sub_ids, output_sub_ids). Outside warm-up, or for values
		of at most 10 ids, the whole value is hidden behind special_token_id.
		During warm-up one random span is hidden whose share of the value
		grows linearly from 15% to 100% with current_steps.
		"""
		if self.mask_warmup_steps <= 0 or self.current_steps >= self.mask_warmup_steps or len(value_ids) <= 10:
			return [special_token_id], [special_token_id] + value_ids
		else:
			mask_ratio = 0.15 + 0.85 * (self.current_steps / self.mask_warmup_steps)
			span_length = int(len(value_ids) * mask_ratio)
			start_point = get_random_span(value_ids, span_length, 1)[0]
			input_sub_ids = value_ids[:start_point] + [special_token_id] + value_ids[start_point + span_length:]
			output_sub_ids = [special_token_id] + value_ids[start_point: start_point + span_length]
			return input_sub_ids, output_sub_ids

	def _get_instance_representation(self, data_index, used_nlu_key, show_attributes=None, add_seperator=False, use_train_data=False, key_mapping=None):
		"""Render one instance, hiding every key not in show_attributes.

		Returns (input_ids, output_ids, show_attributes, enable_nlu). When
		show_attributes is None it is drawn at random according to the
		nlg_only / nlu_only / switch_prob settings; enable_nlu is True when
		all non-NLU keys end up visible. key_mapping, if given, supplies
		replacement token ids for each key name.
		"""
		enable_nlu = False
		if show_attributes is None:
			if self.nlg_only:
				show_attributes = random.choice(self.nlg_only_key_combinations)
			else:
				if (not self.nlu_only) and random.random() < self.switch_prob:
					show_attributes = random.choice(self.key_combinations)
					enable_nlu = all([key in show_attributes for key in self.nlu_visible_keys])
				else:
					enable_nlu = True
					show_attributes = self.nlu_visible_keys

		input_ids = []
		output_ids = []
		miss_span_count = 0
		for key in self.key_list:
			if key == "nlu_key":
				real_key = used_nlu_key
			else:
				real_key = key

			if key_mapping is None:
				input_ids += self.data_hdf5['keyids_%s' % real_key][0].tolist()
			else:
				query_key = self.get_query_key(real_key)
				input_ids += key_mapping[query_key]

			value_ids = self.data_hdf5['%s_key_%s_values' % (self.split if not use_train_data else 'train', real_key)][data_index].tolist()
			value_ids = value_ids[:self.config.meta_task_max_value_length]
			if key in show_attributes:
				input_ids += value_ids
			else:
				# One sentinel per hidden value, counting down from 32099.
				special_token_id = 32099 - miss_span_count
				input_sub_ids, output_sub_ids = self.curriculum_learning_mask(special_token_id, value_ids)
				input_ids += input_sub_ids
				output_ids += output_sub_ids
				miss_span_count += 1

		if len(output_ids) > 0:
			output_ids.append(32099 - miss_span_count)
			output_ids.append(self.eos_token_id)

		if add_seperator:
			input_ids.append(self.sep_token_id)

		return input_ids, output_ids, show_attributes, enable_nlu

	def _sample_source_index(self, index):
		"""Map a dataset index onto a row of the underlying split.

		When the split is subsampled (data_len < orginal_len) each dataset
		index owns a bucket of `scale` consecutive rows and a random row of
		that bucket is returned.
		"""
		if self.data_len >= self.orginal_len:
			return index % self.orginal_len

		scale = self.orginal_len // self.data_len
		min_v, max_v = index * scale, (index + 1) * scale - 1
		# Clamp to the last valid row. The clamp used to be self.data_len,
		# which mapped almost every bucket onto the same single row.
		return min(random.randint(min_v, max_v), self.orginal_len - 1)

	def __getitem__(self, index):
		"""Return (input_np, output_np, task_index, flag) for one item.

		flag is 0 when the instance was rendered with all non-NLU keys visible
		(enable_nlu) and 1 otherwise. Inputs and outputs are int64 arrays
		truncated to max_length.
		"""
		index = self._sample_source_index(index)

		# Opened lazily rather than in __init__, so the handle is created by
		# whichever process first reads an item.
		if not hasattr(self, 'data_hdf5'):
			self.data_hdf5 = h5py.File(self.path_, 'r')

		if len(self.nlu_key_list) > 1:
			show_nlu_key = random.sample(self.nlu_key_list, k=1)[0]
		else:
			show_nlu_key = self.nlu_key_list[0]

		# Key order is re-randomized for every item (mutates self.key_list).
		random.shuffle(self.key_list)

		key_mapping = None
		if self.key_replacement_dict is not None:
			key_mapping = {}
			for key in self.key_list + [show_nlu_key]:
				if key == 'nlu_key': continue
				query_key = self.get_query_key(key)
				key_mapping[query_key] = random.choice(self.key_replacement_dict[query_key])

		if self.use_t5_span:
			input_ids, output_ids, used_kv, enable_nlu = self._get_instance_t5_span_representation(index, show_nlu_key, add_seperator=True)
		else:
			input_ids, output_ids, used_kv, enable_nlu = self._get_instance_representation(index, show_nlu_key, add_seperator=True, key_mapping=key_mapping)

		saved_instances = []
		total_length = len(input_ids)

		# Collect fully visible training instances while they still fit.
		# At most 16 candidates are drawn; fewer if the training split is smaller
		# (random.sample raises when k exceeds the population).
		candidate_count = min(16, len(self.train_data_index))
		for d_index in random.sample(self.train_data_index, k=candidate_count):
			if self.split == "train" and d_index == index: continue
			if total_length <= self.max_length - 1:
				full_example_ids, _, _, _ = self._get_instance_representation(d_index, show_nlu_key, show_attributes=self.key_list, add_seperator=True, use_train_data=True, key_mapping=key_mapping)
				if len(full_example_ids) + total_length <= self.max_length - 1:
					saved_instances.append(full_example_ids)
					total_length += len(full_example_ids)
			else:
				break

		if len(saved_instances) > 0:
			low_bound = 1
			# if not all key-value pairs are masked, it is possible not to include full example
			if len(used_kv) > 0: low_bound = 0
			selected_instance_num = random.choice([i for i in range(low_bound, len(saved_instances) + 1)])

			candidates_list = saved_instances[:selected_instance_num] + [input_ids]
			if self.config.shuffle_example:
				random.shuffle(candidates_list)

			final_input_ids = []
			for candidate in candidates_list:
				final_input_ids += candidate
			input_ids = final_input_ids

		# The trailing separator of the last instance becomes EOS.
		input_ids[-1] = self.eos_token_id

		input_np = np.array(input_ids).astype(np.int64)[:self.max_length]
		output_np = np.array(output_ids).astype(np.int64)[:self.max_length]

		if self.config.enable_new_task_embeddings:
			_task_index = 0
		else:
			_task_index = self._task_index

		return input_np, output_np, _task_index, 0 if enable_nlu else 1


class MetaH5NLPTask(Dataset):
	"""Read-only dataset over pre-built instances stored in ``meta_task_h5/<task_name>.h5``.

	The file is expected to hold, per split, the datasets
	``<split>_instance_count``, ``<split>_instances_input`` and
	``<split>_instances_output``.
	"""

	def __init__(self, task_name, split, is_root=True):
		"""Record the file location and read the instance count for ``split``.

		``is_root`` is accepted but not used by this class.
		"""
		self.split = split
		self.folder = "meta_task_h5/"
		self.path_ = os.path.join(self.folder, task_name + ".h5")
		# Only the count is read here; the handle is closed again right away.
		with h5py.File(self.path_, 'r') as count_file:
			self.data_len = count_file['%s_instance_count' % split][0]

	def __len__(self):
		"""Return the number of stored instances for the configured split."""
		return self.data_len

	def __getitem__(self, index):
		"""Return the ``(input, output)`` rows at ``index`` as int64 numpy arrays."""
		# The data handle is opened on first access and then cached on the
		# instance (presumably so each DataLoader worker gets its own handle
		# -- confirm against how the loader is used).
		try:
			store = self.data_hdf5
		except AttributeError:
			store = self.data_hdf5 = h5py.File(self.path_, 'r')

		raw_input = store['%s_instances_input' % self.split][index]
		raw_output = store['%s_instances_output' % self.split][index]

		return raw_input.astype(np.int64), raw_output.astype(np.int64)



class MetaNLPTask(Dataset):
	"""Dataset that renders one task's datapoints as in-context "meta" instances.

	A query datapoint is written as ``Key: value`` text in which the keys that
	are not selected as visible are replaced by ``<extra_id_N>`` placeholders;
	the target text lists each placeholder followed by the hidden value.
	Fully rendered training examples (each ending in ``<AND>``) may be
	prepended to the query as long as the token budget allows.
	"""

	# Upper bound on how many training examples are drawn as candidate
	# demonstrations for a single query.
	_CANDIDATE_EXAMPLES = 32

	def __init__(self, single_task, split, tokenizer, max_length=512, pre_load_training=False):
		"""Set up the dataset and optionally pre-compute every instance.

		Args:
			single_task: task object providing ``get_valid_keys()``,
				``get_split(split)``, ``get_value_from_key(datapoint, key)``,
				``train_data`` and ``pre_load_train_data``.
			split: name of the split served by this dataset.
			tokenizer: callable returning ``{'input_ids': ...}`` for
				``return_tensors="np"`` and exposing ``eos_token_id``.
			max_length: token budget for one instance.
			pre_load_training: when true, tokenise all training examples and
				build every meta instance once, up front.
		"""
		self.split = split
		self.single_task = single_task
		self.tokenizer = tokenizer
		self.max_length = max_length
		self.key_list = single_task.get_valid_keys()
		# Candidate sets of visible keys: the empty set plus every combination
		# of 1 .. len(key_list) - 2 keys.
		self.key_combinations = [[]]
		for size in range(1, len(self.key_list) - 1):
			self.key_combinations.extend(itertools.combinations(self.key_list, size))
		self.pre_load_training = pre_load_training
		self.split_data_list = single_task.get_split(split)
		self.data_index = list(range(len(single_task.train_data)))

		if pre_load_training:
			if single_task.pre_load_train_data is not None:
				# Reuse the tokenised training examples cached on the task.
				self.train_data = single_task.pre_load_train_data
			else:
				print("prepare training instances")
				self.train_data = []
				for data_item in tqdm(single_task.train_data):
					example_input_ids, _, _ = self._get_instance_representation(data_item, show_attributes=self.key_list)
					self.train_data.append(example_input_ids)
				single_task.pre_load_train_data = self.train_data

			print("prepare meta instances")
			new_split_data_list = []
			for index in tqdm(range(len(self.split_data_list))):
				input_ids, output_ids = self.generate_meta_instance(index)
				new_split_data_list.append((input_ids, output_ids))
			# From here on the split holds ready-made (input, output) arrays.
			self.split_data_list = new_split_data_list
		else:
			self.train_data = single_task.train_data

	def __len__(self):
		"""Return the number of datapoints in the served split."""
		return len(self.split_data_list)

	def __getitem__(self, index):
		"""Return ``(input_ids, output_ids)``; cached if pre-loaded, else freshly sampled."""
		if self.pre_load_training:
			return self.split_data_list[index]
		return self.generate_meta_instance(index)

	def _get_instance_representation(self, query_datapoint, show_attributes=None, add_seperator=True):
		"""Tokenise one datapoint with some keys visible and the rest masked.

		Args:
			query_datapoint: datapoint understood by ``single_task.get_value_from_key``.
			show_attributes: keys whose values stay visible; a random entry of
				``self.key_combinations`` when ``None``.
			add_seperator: append ``" <AND> "`` after the rendered datapoint.

		Returns:
			``(input_ids, output_ids, show_attributes)`` where ``output_ids``
			is ``None`` when no key was masked and ``input_ids`` has a trailing
			EOS token removed.
		"""
		if show_attributes is None:
			show_attributes = random.choice(self.key_combinations)

		input_parts = []
		output_parts = []
		for key in self.key_list:
			label = key.lower().capitalize()
			value = self.single_task.get_value_from_key(query_datapoint, key)
			if key in show_attributes:
				input_parts.append(label + ": " + value + " ")
			else:
				# Placeholders are numbered in the order the masked keys appear.
				sentinel = "<extra_id_%d>" % len(output_parts)
				input_parts.append("%s: %s " % (label, sentinel))
				output_parts.append("%s %s " % (sentinel, value))
		if add_seperator:
			input_parts.append(" <AND> ")
		input_str = "".join(input_parts)
		output_str = "".join(output_parts)

		input_ids = self.tokenizer(input_str, return_tensors="np")['input_ids'][0, :self.max_length - 1].tolist()
		if len(output_str) > 0:
			output_ids = self.tokenizer(output_str, return_tensors="np")['input_ids'][0, :self.max_length - 1].tolist()
		else:
			output_ids = None

		# Drop the EOS so further text can be concatenated; the emptiness guard
		# covers a slice that truncated everything away.
		if input_ids and input_ids[-1] == self.tokenizer.eos_token_id:
			input_ids = input_ids[:-1]

		return input_ids, output_ids, show_attributes

	def generate_meta_instance(self, index):
		"""Build the meta instance for the datapoint at ``index`` of the split.

		Returns:
			``(input_ids, output_ids)`` as int64 numpy arrays; ``input_ids``
			ends with the tokenizer's EOS token.
		"""
		input_ids, output_ids, used_kv = self._get_instance_representation(self.split_data_list[index], add_seperator=False)

		budget = self.max_length - 1
		total_length = len(input_ids)
		saved_instances = []
		if total_length < budget:
			# random.sample raises ValueError when k exceeds the population,
			# so clamp the draw for tasks with few training items.
			sample_size = min(self._CANDIDATE_EXAMPLES, len(self.data_index))
			for d_index in random.sample(self.data_index, k=sample_size):
				# Never show the query itself as a demonstration.
				if self.split == "train" and d_index == index:
					continue
				if self.pre_load_training:
					example_input_ids = self.train_data[d_index]
				else:
					example_input_ids, _, _ = self._get_instance_representation(self.train_data[d_index], show_attributes=self.key_list)
				if len(example_input_ids) + total_length < budget:
					saved_instances.append(example_input_ids)
					total_length += len(example_input_ids)
				elif saved_instances:
					# Budget exhausted after at least one example was kept.
					break
			if saved_instances:
				# if not all key-value pairs are masked, it is possible not to include full example
				low_bound = 0 if len(used_kv) > 0 else 1
				selected_instance_num = random.randint(low_bound, len(saved_instances))
				for instance in saved_instances[:selected_instance_num]:
					input_ids = instance + input_ids
		input_ids.append(self.tokenizer.eos_token_id)
		return np.array(input_ids, dtype=np.int64), np.array(output_ids, dtype=np.int64)

def process_tensor(tensor_list, last_dim, output_mask=False):
    """Zero-pad a list of arrays along their first axis and stack them.

    Args:
        tensor_list: non-empty list of numpy arrays; the dtype of the first
            one is used for the padded result.
        last_dim: size of a trailing feature axis; when ``<= 0`` the result
            is 2-D ``(batch, max_len)``, otherwise ``(batch, max_len, last_dim)``.
        output_mask: also return a float32 mask with 1 on real positions.

    Returns:
        The padded torch tensor, or ``(padded, mask)`` when ``output_mask``.
    """
    lengths = [item.shape[0] for item in tensor_list]
    batch_size, longest = len(tensor_list), max(lengths)

    padded_shape = (batch_size, longest, last_dim) if last_dim > 0 else (batch_size, longest)
    padded = np.zeros(padded_shape, dtype=tensor_list[0].dtype)
    mask = np.zeros((batch_size, longest), dtype=np.float32)

    for row, (item, length) in enumerate(zip(tensor_list, lengths)):
        if length <= 0:
            # Empty entries stay all-zero with an all-zero mask.
            continue
        padded[row, :length] = item
        mask[row, :length] = 1

    padded_tensor = torch.from_numpy(padded)
    if not output_mask:
        return padded_tensor
    return padded_tensor, torch.from_numpy(mask)

def _data_wrapper(dataset):
    """Collate a batch of dataset tuples into a dict of padded tensors.

    Fields 0 and 1 of every item are the encoder and decoder id arrays. The
    meaning of the remaining fields depends on the tuple length of the first
    item:

    * 7: gt_x, gt_y, prefix/task id, data index, (unused), task type id
    * 6: gt_x, gt_y, prefix id, data index
    * 4: task id, task type id
    * 3: prefix id

    Anything not supplied stays at its default (zeros for the id tensors,
    ``None`` for ``gt_x``, ``gt_y`` and ``data_index``).
    """
    encoder_input_ids, encoder_mask = process_tensor([item[0] for item in dataset], 0, output_mask=True)
    decoder_input_ids, decoder_mask = process_tensor([item[1] for item in dataset], 0, output_mask=True)
    # Padded decoder positions become -100 (presumably so the loss skips
    # them -- confirm against the training loop).
    decoder_input_ids[decoder_mask == 0] = -100

    def zeros():
        # A fresh tensor per call so the three defaults never alias.
        return torch.tensor([0 for _ in dataset]).long()

    def column(position):
        return torch.tensor([item[position] for item in dataset]).long()

    task_index, task_type_index, prefix_ids = zeros(), zeros(), zeros()
    gt_x = gt_y = data_index = None

    field_count = len(dataset[0])
    if field_count in (6, 7):
        gt_x = [item[2] for item in dataset]
        gt_y = [item[3] for item in dataset]
        prefix_ids = column(4)
        data_index = [item[5] for item in dataset]
        if field_count == 7:
            # NOTE(review): field 4 feeds both prefix_ids and task_ids in the
            # 7-field layout -- confirm this is intended.
            task_index = column(4)
            task_type_index = column(6)
    elif field_count == 4:
        task_index = column(2)
        task_type_index = column(3)
    elif field_count == 3:
        prefix_ids = column(2)

    return {
        "encoder_input_ids": encoder_input_ids,
        "encoder_mask": encoder_mask,
        "decoder_input_ids": decoder_input_ids,
        "task_ids": task_index,
        "task_type_ids": task_type_index,
        "prefix_ids": prefix_ids,
        "gt_x": gt_x,
        "gt_y": gt_y,
        "data_index": data_index,
    }


def get_meta_task():
	"""Instantiate every task registered in ``TASK_NAME_TO_CLS``.

	Each registered class is called with its own task name; the instances are
	returned in the iteration order of the registry.
	"""
	return [TASK_NAME_TO_CLS[name](name) for name in TASK_NAME_TO_CLS]

def get_meta_nlp_data(tokenizer, task_list, split, batch_size, max_length=512, shuffle=False, pre_load_training=False, distributed=False, is_root=True, is_train=True):
	"""Build a DataLoader over the concatenation of one ``MetaNLPTask`` per task.

	Args:
		tokenizer: tokenizer handed to every ``MetaNLPTask``.
		task_list: task objects to wrap.
		split: split name served by the datasets.
		batch_size: DataLoader batch size.
		max_length: token budget per instance.
		shuffle: shuffle flag, given to the sampler when ``distributed`` and
			to the DataLoader otherwise.
		pre_load_training: forwarded to ``MetaNLPTask``.
		distributed: use a ``DistributedSampler``.
		is_root: print the dataset size when true.
		is_train: kept for signature compatibility; not used (see below).

	Returns:
		``(loader, tasks)`` where ``tasks`` are the ``single_task`` objects of
		the created datasets, in ``task_list`` order.
	"""
	# is_train is deliberately not forwarded: MetaNLPTask.__init__ takes no
	# such keyword, so passing it raised TypeError on the first task.
	meta_task_list = [
		MetaNLPTask(task_, split, tokenizer, max_length=max_length, pre_load_training=pre_load_training)
		for task_ in task_list
	]
	updated_single_task = [meta_task.single_task for meta_task in meta_task_list]

	combined_dataset = ConcatDataset(meta_task_list)

	if is_root:
		print("%s Data Size %d" % (split, len(combined_dataset)))

	if distributed:
		dist_sampler = torch.utils.data.distributed.DistributedSampler(combined_dataset, shuffle=shuffle)
		dist_loader = DataLoader(combined_dataset, batch_size=batch_size, num_workers=4, collate_fn=_data_wrapper, sampler=dist_sampler)
		return dist_loader, updated_single_task

	data_loader = DataLoader(combined_dataset, batch_size=batch_size, num_workers=4, collate_fn=_data_wrapper, shuffle=shuffle)
	return data_loader, updated_single_task

def get_h5py_nlp_data(config, split, batch_size, tokenizer, max_length, use_t5_span=False, mask_warmup_steps=-1, shuffle=False, distributed=False, is_root=True, is_train=True):
	"""Build a DataLoader over one ``MetaH5NLPTaskDynamic`` per registered task.

	Every task in ``TASK_NAME_TO_CLS`` is instantiated and paired with its
	``meta_task_h5_wo_SuperBLUE/<task_name>.h5`` file. Incomplete final
	batches are dropped. ``use_t5_span`` and ``mask_warmup_steps`` are
	accepted but not used in this function.

	Returns:
		The DataLoader (with a ``DistributedSampler`` when ``distributed``).
	"""
	per_task_datasets = []
	for name in TASK_NAME_TO_CLS:
		position = TASK_NAME_LIST.index(name)
		h5_path = os.path.join("meta_task_h5_wo_SuperBLUE/", name + ".h5")
		task_obj = TASK_NAME_TO_CLS[name](name)
		per_task_datasets.append(
			MetaH5NLPTaskDynamic(config, task_obj, position, h5_path, split, tokenizer, max_length, is_train=is_train)
		)

	combined_dataset = ConcatDataset(per_task_datasets)

	if is_root:
		print("Task Number %d" % len(TASK_NAME_TO_CLS))
		print("%s Data Size %d" % (split, len(combined_dataset)))

	# Options common to both loader flavours.
	loader_options = dict(pin_memory=True, batch_size=batch_size, num_workers=8, collate_fn=_data_wrapper, drop_last=True)

	if not distributed:
		return DataLoader(combined_dataset, shuffle=shuffle, **loader_options)

	sampler = torch.utils.data.distributed.DistributedSampler(combined_dataset, shuffle=shuffle)
	return DataLoader(combined_dataset, sampler=sampler, **loader_options)

def get_h5py_nlp_dataset(config, split, tokenizer, max_length, use_t5_span=False, mask_warmup_steps=-1, is_root=True, is_train=True):
	"""Return the concatenation of one ``MetaH5NLPTaskDynamic`` per registered task.

	Every task in ``TASK_NAME_TO_CLS`` is instantiated and paired with its
	``meta_task_h5_wo_SuperBLUE/<task_name>.h5`` file. ``use_t5_span`` and
	``mask_warmup_steps`` are accepted but not used in this function.
	"""
	per_task_datasets = []
	for name in TASK_NAME_TO_CLS:
		position = TASK_NAME_LIST.index(name)
		h5_path = os.path.join("meta_task_h5_wo_SuperBLUE/", name + ".h5")
		task_obj = TASK_NAME_TO_CLS[name](name)
		per_task_datasets.append(
			MetaH5NLPTaskDynamic(config, task_obj, position, h5_path, split, tokenizer, max_length, is_train=is_train)
		)

	combined_dataset = ConcatDataset(per_task_datasets)

	if is_root:
		print("Task Number %d" % len(TASK_NAME_TO_CLS))
		print("%s Data Size %d" % (split, len(combined_dataset)))

	return combined_dataset

	
def get_single_h5py_nlp_data(config, path, train_path, split, batch_size, tokenizer, max_length, shuffle=False, distributed=False, is_root=True, is_train=True):
	"""Build a DataLoader over the single in-context dataset selected by ``config``.

	Selection rules:

	* ``config.enable_pair_sentence_classification``: the class is chosen by
	  ``config.running_task``; unknown task names fall back to
	  ``FlipDAInContext`` when ``config.enable_flip_filtering`` is set and to
	  ``RTEInContext`` otherwise.
	* else ``config.enable_sentence_classification``: ``SenCLSInContext``.
	* otherwise: ``SeqLabelInContext``.

	``split`` is only used in the size message and ``max_length`` is not used
	in this function.

	Returns:
		The DataLoader (with a ``DistributedSampler`` when ``distributed``).
	"""
	# Pick the dataset class first; names are resolved only on the branch taken.
	if config.enable_pair_sentence_classification:
		running_task = config.running_task
		if running_task == 'question_generation':
			dataset_cls = QuestionGenInContext
		elif running_task == 'document_generation':
			dataset_cls = DocumentInContext
		elif running_task == "domain_pair":
			dataset_cls = DomainPairInContext
		elif running_task == "rte_document_generation":
			dataset_cls = RTEDocumentInContext
		elif running_task == "bool_document_generation":
			dataset_cls = BoolQDocumentInContext
		elif running_task == "cb_document_generation":
			dataset_cls = CBDocumentInContext
		elif running_task == "multi_rc_document_generation":
			dataset_cls = MultiRCDocumentInContext
		elif running_task == "multi_rc_question_generation":
			dataset_cls = MultiRCQuestionInContext
		elif running_task == "multi_rc_answer_generation":
			dataset_cls = MultiRCAnswerInContext
		elif running_task == "hypothesis_generation":
			dataset_cls = HypothesisGenInContext
		elif running_task in ("cb_hypothesis_generation", 'cb_nlu'):
			dataset_cls = CBHypothesisGenInContext
		elif running_task == "wsc_generation":
			dataset_cls = WSCGenerationInContext
		elif running_task == "wic_generation":
			dataset_cls = WICGenerationInContext
		elif running_task == "copa_generation":
			dataset_cls = CopaGenInContext
		elif running_task == "copa_doc_generation":
			dataset_cls = CopaDocumentInContext
		elif running_task == "copa_opt_generation":
			dataset_cls = CopaChoicesGenInContext
		elif running_task == "copa_nlu":
			dataset_cls = COPANLUIncontext
		elif running_task in ("copa_nlu_filtering", "copa_nlu_tagging"):
			dataset_cls = COPANLUFilteringIncontext
		elif running_task == "rec_document_generation":
			dataset_cls = ReCDocumentInContext
		elif running_task == "rec_entities_generation":
			dataset_cls = ReCEntitiesInContext
		elif running_task == "rec_query_generation":
			dataset_cls = ReCQueryInContext
		elif config.enable_flip_filtering:
			dataset_cls = FlipDAInContext
		else:
			dataset_cls = RTEInContext
	elif config.enable_sentence_classification:
		dataset_cls = SenCLSInContext
	else:
		dataset_cls = SeqLabelInContext

	# Every dataset class is constructed with the same argument list.
	combined_dataset = dataset_cls(config, path, train_path, tokenizer, is_training=is_train, is_root=is_root)

	if is_root:
		print("%s Data Size %d" % (split, len(combined_dataset)))

	# Options common to both loader flavours.
	loader_options = dict(pin_memory=True, batch_size=batch_size, num_workers=8, collate_fn=_data_wrapper)

	if not distributed:
		return DataLoader(combined_dataset, shuffle=shuffle, **loader_options)

	sampler = torch.utils.data.distributed.DistributedSampler(combined_dataset, shuffle=shuffle)
	return DataLoader(combined_dataset, sampler=sampler, **loader_options)



