import argparse
import torch
import torch.nn as nn
from config import Config
import os, sys, math
from MetaPreTrainingDataset import get_single_h5py_nlp_data, _data_wrapper, read_conll, TASK_NAME_LIST
import t5_model
import numpy as np
import random
import utils
from tqdm import tqdm
import deepspeed
import re
from seqeval.metrics import f1_score as seq_f1_score
from sklearn.metrics import f1_score as sen_f1_score
from fast_bleu import SelfBLEU
import json
import edit_distance
import copy

def _average_all(tensor):
	"""Return the mean of `tensor` over every process in the default group.

	The caller's tensor is left untouched: the reduction runs in place on a
	detached clone, which is then divided by the world size.
	"""
	# all_reduce overwrites its argument, so work on a private copy.
	summed = tensor.detach().clone()
	# `all_reduce` is used because it is better supported than `reduce`.
	torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM)
	world_size = torch.distributed.get_world_size()
	return summed / world_size

compiler = re.compile(r'<extra_id_0>(.*)<extra_id_1>')
compiler_second = re.compile(r'<extra_id_1>(.*)<extra_id_2>')
def _extract(text):
	match = compiler.search(text)
	if match is None:
		text = text.strip().replace("<extra_id_0>", "").replace("<extra_id_1>", "")
		return text
	else:
		return match.group(1).strip().replace("<extra_id_0>", "").replace("<extra_id_1>", "")

def _extract_document(text):
	"""Return the span between `<extra_id_0>` and `<extra_id_1>`, or None.

	The captured span is stripped first and then cleared of any leftover
	`<extra_id_0>` / `<extra_id_1>` tokens. None signals that the sentinel
	pair was not found in `text`.
	"""
	found = compiler.search(text)
	if found is None:
		return None
	body = found.group(1).strip()
	return body.replace("<extra_id_0>", "").replace("<extra_id_1>", "")

def _extract_domain_label(text):
	"""Return the label generated between `<extra_id_0>` and `<extra_id_1>`.

	Leftover sentinel tokens are removed before the result is stripped.
	Returns None when the sentinel pair is not present in `text`.
	"""
	found = compiler.search(text)
	if found is not None:
		cleaned = found.group(1)
		for sentinel in ("<extra_id_0>", "<extra_id_1>"):
			cleaned = cleaned.replace(sentinel, "")
		return cleaned.strip()
	return None

def _extract_pair(input_seq, output_text):
	"""Extract the generated text for up to two masked slots.

	For each of `<extra_id_0>` and `<extra_id_1>` that occurs in `input_seq`,
	the matching span is pulled from `output_text`. Returns the list of
	recovered spans in slot order, or None as soon as a required span is
	missing from the output.
	"""
	slot_patterns = (
		('<extra_id_0>', compiler),
		('<extra_id_1>', compiler_second),
	)
	chunks = []
	for sentinel, pattern in slot_patterns:
		if sentinel not in input_seq:
			continue
		found = pattern.search(output_text)
		if found is None:
			return None
		chunks.append(found.group(1).strip().replace("<extra_id_0>", "").replace("<extra_id_1>", ""))
	return chunks

def _extract_multiple_span(input_seq, output_text):
	text_chunks = []
	for i in range(100):
		mask = "<extra_id_%d>" % i
		next_mask = "<extra_id_%d>" % (i + 1)
		if mask in input_seq:
			compiler = re.compile(r'%s(.*)%s' % (mask, next_mask))
			match = compiler.search(output_text)
			if match is None:
				return None
			else:
				text_chunks.append(match.group(1).replace(mask, "").replace(next_mask, "").strip())
		else:
			break
	return text_chunks


def get_label_list(label_path):
    """Read `label_path` and return its lines, each stripped of whitespace."""
    with open(label_path) as label_file:
        return [line.strip() for line in label_file]

def is_match(sub_sentence, generated_entity):
	"""Return True when `sub_sentence` equals, starts with, or ends with `generated_entity`."""
	return (
		sub_sentence == generated_entity
		or sub_sentence.startswith(generated_entity)
		or sub_sentence.endswith(generated_entity)
	)

def get_tag_result(g_ner_str, input_sen, label_list):
	"""Project a generated entity listing onto BIO tags for `input_sen`.

	`g_ner_str` is expected to look like "B-TYPE word word B-TYPE word ...",
	or the literal string "no entities". The spellings "B- " and "B_" are
	normalised to "B-" first.

	Args:
		g_ner_str: generated entity string.
		input_sen: whitespace-tokenisable source sentence.
		label_list: container of valid tags; an entity is only accepted when
			"B-" + its type is in it.

	Returns:
		A tuple (words, tags, has_error): the sentence tokens, one tag per
		token, and a flag that is True when any generated entity was rejected
		or when no token ended up tagged. The "no entities" input returns
		all-'O' tags with has_error False.
	"""
	normalized = g_ner_str.replace("B- ", "B-").replace("B_", "B-")
	sentence_words = input_sen.split()
	tags = ['O'] * len(sentence_words)

	if normalized == "no entities":
		return sentence_words, tags, False

	generated = normalized.split()
	# A token opens an entity when it looks like a tag and is not sentence text itself.
	boundaries = [
		pos for pos, token in enumerate(generated)
		if token.startswith("B-") and token not in input_sen
	]
	boundaries.append(len(generated))

	has_error = False
	for begin, end in zip(boundaries[:-1], boundaries[1:]):
		entity_label = generated[begin][2:]
		entity_words = generated[begin + 1: end]
		entity_text = ' '.join(entity_words)

		rejected = (
			'B-' + entity_label not in label_list
			or entity_text not in input_sen
			or len(entity_words) == 0
		)
		if rejected:
			has_error = True
			continue

		width = len(entity_words)
		# Only the first occurrence of the entity in the sentence is tagged.
		for pos, word in enumerate(sentence_words):
			if word != entity_words[0]:
				continue
			if not is_match(' '.join(sentence_words[pos: pos + width]), entity_text):
				continue
			tags[pos] = 'B-' + entity_label
			for inner in range(pos + 1, pos + width):
				tags[inner] = 'I-' + entity_label
			break

	has_error = has_error or all(tag == 'O' for tag in tags)
	return sentence_words, tags, has_error

def clean_up_t5_tokenizer(tokenizer, sen_np):
    """Turn token ids into plain text.

    The ids are converted with `tokenizer.convert_ids_to_tokens`, the
    word-boundary marker (U+2581) becomes a space, the pieces are joined, and
    `</s>` / `<pad>` are removed before the result is stripped.
    """
    pieces = (tok.replace(chr(9601), ' ') for tok in tokenizer.convert_ids_to_tokens(sen_np))
    joined = "".join(pieces)
    return joined.replace('</s>', '').replace('<pad>', '').strip()

def process_generated_ner_data(output_lines, output_path):
	"""Report diversity statistics for generated sentences and optionally save them.

	The generated text is taken from element 1 of every record and passed
	through `_extract`. The trigram Self-BLEU over all sentences and the number
	of distinct sentences with 4 to 99 tokens are printed. When `output_path`
	is given, those distinct sentences are written one token per line, each
	tagged 'O', with a blank line after every sentence.
	"""
	total = len(output_lines)
	sentences = [_extract(record[1]) for record in output_lines]
	tokenized = [sentence.split() for sentence in sentences]
	unique_sentences = {
		sentence for sentence, tokens in zip(sentences, tokenized)
		if 3 < len(tokens) < 100
	}

	ngram_weights = {'bigram': (1/2., 1/2.), 'trigram': (1/3., 1/3., 1/3.)}
	trigram_scores = SelfBLEU(tokenized, ngram_weights).get_score()['trigram']
	self_b3_value = 100 * sum(trigram_scores) / len(trigram_scores)

	print("Self BLEU %.2f" % self_b3_value)
	print("Distincted Sentences Count %d (total %d)" % (len(unique_sentences), total))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for sentence in unique_sentences:
			for word in sentence.split():
				out.write("%s %s\n" % (word, 'O'))
			out.write("\n")

def process_generated_pair_data_v2(output_lines, output_path, to_json=False):
	"""Group generated questions and articles by index and save their cross product.

	Each record is (index, enc_input, gen). When the masks in `enc_input` can
	be recovered from `gen` they are substituted in; otherwise the input is
	used as is. The first '|*|' field is read as the question and the third as
	the article.

	When `output_path` is given, every distinct (question, article) pair of an
	index is written as two JSON lines, one labelled 'True' and one 'False'.
	`to_json` is accepted but not consulted.
	"""
	grouped = {}
	for index, enc_input, gen in output_lines:
		filled = enc_input
		spans = _extract_multiple_span(enc_input, gen)
		if spans is not None:
			for slot, span_text in enumerate(spans):
				filled = filled.replace("<extra_id_%d>" % slot, span_text)

		fields = filled.split('|*|')
		question = fields[0].replace('Questions:', '').strip()
		article = fields[2].replace('Article:', '').strip()

		entry = grouped.setdefault(index, {})
		entry.setdefault('question', []).append(question)
		entry.setdefault('article', []).append(article)

	if output_path is None:
		return

	instance_count = 0
	with open(output_path, 'w') as out:
		for index, entry in grouped.items():
			questions = set(entry['question'])
			articles = set(entry['article'])
			for q in questions:
				for a in articles:
					for l in ['True', 'False']:
						instance_count += 1
						out.write(json.dumps({"question": q, "passage": a, "label": l, "idx": index}) + '\n')
	print("Save Instances %d" % instance_count)
				

def process_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked question/label/article prompts and save the distinct results.

	Each record is (index, enc_input, gen). Records whose masks cannot all be
	recovered from `gen`, or that still contain a sentinel after filling, are
	dropped. The recovery rate and distinct count are printed.

	A filled prompt must have exactly three '|*|' fields holding 'Questions:',
	'Label:' and 'Article:' in that order (enforced with assertions). When
	`output_path` is given the rows are written as "question<TAB>article<TAB>label"
	lines, or as JSON lines when `to_json` is true.
	"""
	correct = 0
	filled_inputs = set()
	for index, enc_input, gen in output_lines:
		spans = _extract_multiple_span(enc_input, gen)
		if spans is None:
			continue
		correct += 1
		for slot, span_text in enumerate(spans):
			enc_input = enc_input.replace("<extra_id_%d>" % slot, span_text)
		if '<extra_id_' not in enc_input:
			filled_inputs.add((index, enc_input))

	enc_input_list = list(filled_inputs)
	print("Correct %.2f" % (100 * correct / len(output_lines)))
	print("Distincted Sentences Count %d (total %d)" % (len(enc_input_list), len(output_lines)))

	generated_data_list = []
	for index, enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 3, items
		question_part, label_part, article_part = items
		assert 'Questions:' in question_part
		assert 'Label:' in label_part
		assert 'Article:' in article_part
		generated_data_list.append((
			index,
			question_part.replace('Questions:', '').strip(),
			article_part.replace('Article:', '').strip(),
			label_part.replace('Label:', '').strip(),
		))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for index, question, article, label in generated_data_list:
			if to_json:
				json_item = {"question": question, "passage": article, "label": label, "idx": index}
				out.write(json.dumps(json_item) + '\n')
			else:
				out.write("%s\t%s\t%s\n" % (question, article, label))

def process_rte_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked hypothesis/label/premise prompts and save the distinct results.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Records whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. The recovery rate and distinct count
	are printed.

	A filled prompt must have exactly three '|*|' fields holding 'Hypothesis:',
	'Label:' and 'Premise:' in that order (enforced with assertions). When
	`output_path` is given the rows are written as
	"hypothesis<TAB>premise<TAB>label" lines, or as JSON lines when `to_json`
	is true; the JSON form rewrites the label 'not entailment' to
	'not_entailment'.
	"""
	correct = 0
	filled_inputs = set()
	for index, enc_input, gen, _ in output_lines:
		spans = _extract_multiple_span(enc_input, gen)
		if spans is None:
			continue
		correct += 1
		for slot, span_text in enumerate(spans):
			enc_input = enc_input.replace("<extra_id_%d>" % slot, span_text)
		if '<extra_id_' not in enc_input:
			filled_inputs.add((index, enc_input))

	enc_input_list = list(filled_inputs)
	print("Correct %.2f" % (100 * correct / len(output_lines)))
	print("Distincted Sentences Count %d (total %d)" % (len(enc_input_list), len(output_lines)))

	generated_data_list = []
	for index, enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 3, items
		hypothesis_part, label_part, premise_part = items
		assert 'Hypothesis:' in hypothesis_part
		assert 'Label:' in label_part
		assert 'Premise:' in premise_part
		generated_data_list.append((
			index,
			hypothesis_part.replace('Hypothesis:', '').strip(),
			premise_part.replace('Premise:', '').strip(),
			label_part.replace('Label:', '').strip(),
		))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for index, hypothesis, premise, label in generated_data_list:
			if not to_json:
				out.write("%s\t%s\t%s\n" % (hypothesis, premise, label))
				continue
			json_label = "not_entailment" if label == 'not entailment' else label
			json_item = {"hypothesis": hypothesis, "premise": premise, "label": json_label, "idx": index}
			out.write(json.dumps(json_item) + '\n')



def process_copa_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked COPA-style prompts and save them as JSON lines.

	Each record is (index, enc_input, gen, gt_y). Prompts whose masks cannot
	all be recovered, or that still contain a sentinel after filling, are
	dropped, as are prompts whose two choices are identical. A filled prompt
	must have five '|*|' fields holding 'Premise:', 'Question:', 'Choice1:',
	'Choice2:' and 'Answer:' in that order (enforced with assertions).

	When `output_path` is given, one JSON object per instance is written,
	skipping instances whose premise equals one of the `gt_y` values. The
	written 'idx' is the group number of the (choice1, choice2) pair, 'question'
	is 'cause' if that word occurs in the question field and 'effect'
	otherwise, and 'label' is 0 when the answer is 'Choice1' and 1 otherwise.
	`to_json` is accepted but not consulted.
	"""
	gt_y_set = set()
	filled_inputs = set()
	for _index, enc_input, gen, gt_y in output_lines:
		gt_y_set.add(gt_y)
		spans = _extract_multiple_span(enc_input, gen)
		if spans is None:
			continue
		for slot, span_text in enumerate(spans):
			enc_input = enc_input.replace("<extra_id_%d>" % slot, span_text)
		if '<extra_id_' not in enc_input:
			filled_inputs.add(enc_input)

	enc_input_list = list(filled_inputs)
	print("Distincted Sentences Count %d (total %d)" % (len(enc_input_list), len(output_lines)))

	key_list = ['Premise:', 'Question:', 'Choice1:', 'Choice2:', 'Answer:']
	generated_data_list = []
	for enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 5, items

		value_list = []
		for key, item in zip(key_list, items):
			assert key in item
			value_list.append(item.replace(key, '').strip())

		# Instances offering the same text for both choices are discarded.
		if value_list[2] != value_list[3]:
			generated_data_list.append(value_list)

	print("remaining length %d" % len(generated_data_list))

	# Number each distinct (choice1, choice2) pair in order of first appearance.
	data_index = {}
	for _premise, _question, choice1, choice2, _ans in generated_data_list:
		data_index.setdefault("%s-%s" % (choice1, choice2), len(data_index))

	print("Total Group Count %d" % len(data_index))

	if output_path is None:
		return

	final_saved_count = 0
	with open(output_path, 'w') as out:
		for premise, question, choice1, choice2, ans in generated_data_list:
			if premise in gt_y_set:
				continue
			json_item = {
				'premise': premise,
				'question': 'cause' if 'cause' in question else 'effect',
				'idx': data_index["%s-%s" % (choice1, choice2)],
				'label': 0 if ans == 'Choice1' else 1,
				'choice1': choice1,
				'choice2': choice2,
			}
			out.write(json.dumps(json_item) + '\n')
			final_saved_count += 1
	print("Final count %d" % final_saved_count)

def process_wic_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked WiC-style prompts, filter them, and save them as JSON lines.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Prompts whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. A filled prompt must have four '|*|'
	fields holding 'Word:', 'Sentence a:', 'Sentence b:' and 'Sense match:' in
	that order (enforced with assertions).

	An instance is kept only when:
	  * the word is a single token;
	  * each sentence marks exactly one keyword as ``**keyword**``;
	  * the two sentences differ, with a word-level edit distance above 2;
	  * the same sentence pair (in either order) never appears with
	    conflicting labels.

	`start`/`end` are character offsets of the keyword in the sentence after
	the ``**`` markers have been removed (`end` is exclusive). When
	`output_path` is given one JSON object per kept instance is written, with
	'label' True exactly when the sense-match field is 'True'. `to_json` is
	accepted but not consulted.
	"""
	enc_input_list = []
	for index, enc_input, gen, _ in output_lines:
		text_chunks = _extract_multiple_span(enc_input, gen)
		if text_chunks is None:
			continue
		for i, text in enumerate(text_chunks):
			enc_input = enc_input.replace("<extra_id_%d>" % i, text)
		if '<extra_id_' not in enc_input:
			enc_input_list.append((index, enc_input))

	enc_input_list = list(set(enc_input_list))
	print("Distincted Sentences Count %d (total %d)" % (len(enc_input_list), len(output_lines)))

	# Non-greedy on purpose: a greedy `.*` would swallow "**a** ... **b**" as a
	# single match, letting sentences with several marked words pass the
	# "exactly one keyword" filter with a wrong end offset.
	keyword_pattern = re.compile(r'\*\*(.*?)\*\*')

	generated_data_list = []
	for index, enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 4, items

		assert 'Word:' in items[0]
		word = items[0].replace('Word:', '').strip()
		assert 'Sentence a:' in items[1]
		sen1 = items[1].replace('Sentence a:', '').strip()
		assert 'Sentence b:' in items[2]
		sen2 = items[2].replace('Sentence b:', '').strip()
		assert 'Sense match:' in items[3]
		label = items[3].replace('Sense match:', '').strip()

		if len(word.split()) != 1:
			continue

		sen1_keywords = list(keyword_pattern.finditer(sen1))
		if len(sen1_keywords) != 1:
			continue
		# The match starts at the opening "**", which is exactly where the
		# keyword begins once the markers are stripped.
		start1 = sen1_keywords[0].start()
		end1 = start1 + len(sen1_keywords[0].group(1))
		sen1 = sen1.replace("**", '')

		sen2_keywords = list(keyword_pattern.finditer(sen2))
		if len(sen2_keywords) != 1:
			continue
		start2 = sen2_keywords[0].start()
		end2 = start2 + len(sen2_keywords[0].group(1))
		sen2 = sen2.replace("**", '')

		if sen1 == sen2:
			continue

		# Near-duplicate sentence pairs are not useful as WiC instances.
		sm = edit_distance.SequenceMatcher(a=sen1.split(), b=sen2.split())
		if sm.distance() <= 2:
			continue

		generated_data_list.append((word, sen1, sen2, start1, end1, start2, end2, label))

	print("Filtered Sentences Count %d (total %d)" % (len(generated_data_list), len(output_lines)))

	# Collect every label seen for a sentence pair, in both orders.
	label_mapping = {}
	for item in generated_data_list:
		sen1, sen2 = item[1], item[2]
		label_mapping.setdefault(sen1 + '\t' + sen2, []).append(item[-1])
		label_mapping.setdefault(sen2 + '\t' + sen1, []).append(item[-1])

	# Keep only pairs that were labelled consistently.
	final_generated_data_list = []
	for item in generated_data_list:
		sen1, sen2 = item[1], item[2]
		key1 = sen1 + '\t' + sen2
		key2 = sen2 + '\t' + sen1
		if len(set(label_mapping[key1])) == 1 and len(set(label_mapping[key2])) == 1:
			final_generated_data_list.append(item)

	print("remaining length %d" % len(final_generated_data_list))

	if output_path is not None:
		with open(output_path, 'w') as out:
			for index, (word, sen1, sen2, start1, end1, start2, end2, label) in enumerate(final_generated_data_list):
				json_item = {"word": word, "sentence1": sen1, "sentence2": sen2, "start1": start1, "end1": end1, "start2": start2, "end2": end2, "label": label == 'True', "idx": index}
				out.write(json.dumps(json_item) + '\n')
		
def process_wsc_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked WSC-style prompts, validate their spans, and save JSON lines.

	Each record is (index, enc_input, gen, input_seq). Prompts whose masks
	cannot all be recovered, or that still contain a sentinel after filling,
	are dropped. A filled prompt must have two '|*|' fields holding 'Text:' and
	'Coreference:' in that order (enforced with assertions).

	The text must mark exactly two spans as ``**span**``: the first must not be
	one of the listed pronouns, the second must be one, and they must differ.
	The second span must also be a single token that still sits at the same
	word index once the markers are removed. Duplicates (same text, spans and
	label) are written once.

	When `output_path` is given one JSON object per kept instance is written;
	'label' is True when the coreference field reads "true" (case-insensitive)
	and False otherwise. `to_json` is accepted but not consulted.
	"""
	pronouns = ('their', 'her', 'it', 'them', 'it,', 'him', 'they', 'he', 'his', 'she')

	enc_input_list = []
	for index, enc_input, gen, input_seq in output_lines:
		text_chunks = _extract_multiple_span(enc_input, gen)
		if text_chunks is None:
			continue
		for i, text in enumerate(text_chunks):
			enc_input = enc_input.replace("<extra_id_%d>" % i, text)
		if '<extra_id_' not in enc_input:
			enc_input_list.append((index, enc_input))

	enc_input_list = list(set(enc_input_list))
	print("Distincted Sentences Count %d (total %d)" % (len(enc_input_list), len(output_lines)))

	generated_data_list = []
	key_set = set()
	for _, enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 2, items
		assert 'Text:' in items[0]
		text = items[0].replace('Text:', '').strip()
		assert 'Coreference:' in items[1]
		label = items[1].replace('Coreference:', '').strip()

		# Temporarily drop a final period so a closing "**" directly before it
		# is still seen as the end of a word.
		removed_punc = text.endswith('.')
		if removed_punc:
			text = text[:-1]

		words = text.split()
		span1_index, span1_text = -1, ""
		span2_index, span2_text = -1, ""
		text_span_index = 0  # number of spans closed so far
		is_correct = True
		for word_index, w in enumerate(words):
			if w.startswith('**'):
				if text_span_index == 0 and span1_index == -1:
					span1_index = word_index
				elif text_span_index != 0 and span2_index == -1:
					span2_index = word_index
				else:
					# A span opened twice, or a third span appeared.
					is_correct = False
					break

			if w.endswith("**"):
				if text_span_index == 0:
					if span1_index == -1:
						is_correct = False
						break
					span1_text = ' '.join(words[span1_index: word_index + 1]).replace("**", "")
					text_span_index += 1
					# The first span is the referent and must not be a pronoun.
					if span1_text.lower() in pronouns or span1_text == span2_text:
						is_correct = False
						break
				else:
					if span2_index == -1:
						is_correct = False
						break
					span2_text = ' '.join(words[span2_index: word_index + 1]).replace("**", "")
					text_span_index += 1
					# The second span is the pronoun being resolved.
					if span2_text.lower() not in pronouns or span1_text == span2_text:
						is_correct = False
						break

		if removed_punc:
			text = text + '.'

		new_text = text.replace("**", "")
		new_tokens = new_text.split()
		# Guard the lookup: span2_index is -1 when no second span was found,
		# and removing markers can shorten the token list.
		if not (0 <= span2_index < len(new_tokens)) or new_tokens[span2_index] != span2_text:
			is_correct = False

		if text_span_index == 2 and is_correct:
			key = "%s-%s-%s-%s" % (new_text, span1_text, span2_text, label)
			if key not in key_set:
				key_set.add(key)
				generated_data_list.append((new_text, span1_index, span1_text, span2_index, span2_text, label))

	print("Final Sentences Count %d (total %d)" % (len(generated_data_list), len(output_lines)))

	if output_path is not None:
		with open(output_path, 'w') as out:
			for index, (new_text, span1_index, span1_text, span2_index, span2_text, label) in enumerate(generated_data_list):
				target = {"span2_index": span2_index, "span2_text": span2_text, "span1_index": span1_index, "span1_text": span1_text}
				# `label` is text; bool() of any non-empty string is True, so it
				# has to be compared against the literal instead.
				json_item = {"text": new_text, "label": label.lower() == 'true', "target": target, "idx": index}
				out.write(json.dumps(json_item) + '\n')

def process_document_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated documents and save them with placeholder fields.

	Each record is (index, enc_input, gen). A document is kept when it can be
	extracted from `gen`, contains no '<unk>', and has more than 30
	whitespace-separated tokens. The kept count is printed.

	When `output_path` is given each document is written with the question
	"Empty Question" and label "True", as a tab-separated line or, when
	`to_json` is true, as a JSON line.
	"""
	kept = set()
	for _index, _enc_input, gen in output_lines:
		document = _extract_document(gen)
		if document is None or '<unk>' in document:
			continue
		if len(document.split()) > 30:
			kept.add(document)
	valid_document_list = list(kept)

	print("saved document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for doc_index, document in enumerate(valid_document_list):
			if to_json:
				json_item = {"question": "Empty Question", "passage": document, "label": "True", "idx": doc_index}
				out.write(json.dumps(json_item) + '\n')
			else:
				out.write("Empty Question\t%s\tTrue\n" % document)

def process_copa_doc_generated_pair_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated premises and save them as placeholder COPA items.

	Each record is a 4-tuple whose third element is the generated text. A
	premise is kept when it can be extracted and contains neither '<unk>' nor
	a leftover '<extra_id_' sentinel. The kept count is printed.

	When `output_path` is given each premise is written as a JSON line with
	question 'cause', both choices 'empty' and label 0. `to_json` is accepted
	but not consulted.
	"""
	kept = set()
	for _index, _enc_input, gen, _ in output_lines:
		document = _extract_document(gen)
		if document is None:
			continue
		if '<unk>' in document or '<extra_id_' in document:
			continue
		kept.add(document)
	valid_document_list = list(kept)

	print("saved document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for index, premise in enumerate(valid_document_list):
			json_item = {'premise': premise, 'question': 'cause', 'idx': index, 'choice1': 'empty', 'choice2': "empty", 'label': 0}
			out.write(json.dumps(json_item) + '\n')

def process_rec_document_generated_pair_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated documents and save them as passage-only items.

	Each record is a 4-tuple whose third element is the generated text. A
	document is kept when it can be extracted, contains neither '<unk>' nor a
	leftover '<extra_id_' sentinel, and has more than 20 whitespace-separated
	tokens. The kept count is printed.

	When `output_path` is given each document is written as a JSON line with
	source "Daily mail" and an empty entity list. `to_json` is accepted but
	not consulted.
	"""
	kept = set()
	for _index, _enc_input, gen, _input_seq in output_lines:
		document = _extract_document(gen)
		if document is None:
			continue
		if '<unk>' in document or '<extra_id_' in document:
			continue
		if len(document.split()) > 20:
			kept.add(document)
	valid_document_list = list(kept)

	print("saved REC document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for document in valid_document_list:
			json_item = {"source": "Daily mail", "passage": {"text": document, 'entities': []}}
			out.write(json.dumps(json_item) + '\n')

def process_rec_entities_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked entities/document prompts and save documents with entity spans.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Prompts whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. A filled prompt must have two '|*|'
	fields holding 'Entities:' and 'Document:' in that order (enforced with
	assertions).

	The entities field is a comma-separated list. Every whitespace-delimited
	occurrence of an entity in the document is recorded with its character
	offsets ('end' is the index of the entity's last character). Documents
	with more than 5 recorded occurrences are kept.

	When `output_path` is given each kept document is written as a JSON line
	with source "Daily mail" and a placeholder query. `to_json` is accepted
	but not consulted.
	"""
	enc_input_list = []
	for index, enc_input, gen, input_seq in output_lines:
		text_chunks = _extract_multiple_span(enc_input, gen)
		if text_chunks is None:
			continue
		for i, text in enumerate(text_chunks):
			enc_input = enc_input.replace("<extra_id_%d>" % i, text)
		if '<extra_id_' not in enc_input:
			enc_input_list.append(enc_input)

	enc_input_list = list(set(enc_input_list))
	print("saved ReC entities %d out of %d" % (len(enc_input_list), len(output_lines)))

	key_list = ['Entities:', 'Document:']
	generated_data_list = []
	for enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 2, items

		value_list = []
		for key, item in zip(key_list, items):
			assert key in item
			value_list.append(item.replace(key, '').strip())
		generated_data_list.append(value_list)

	final_generated_data_list = []
	for entities, document in generated_data_list:
		final_entity_list = []
		for entity in (e.strip() for e in entities.split(',')):
			# An empty entity (e.g. from a trailing comma) would match any
			# run of two whitespace characters.
			if not entity:
				continue
			# The entity is generated text, not a pattern: escape it so that
			# characters such as '.', '(' or '+' are matched literally and
			# cannot raise re.error.
			pattern = r'\s%s\s' % re.escape(entity)
			for span in re.finditer(pattern, document):
				start, end = span.span()
				# The match includes one whitespace character on each side.
				final_entity_list.append({'start': start + 1, "end": end - 2})

		if len(final_entity_list) > 5:
			final_generated_data_list.append((document, final_entity_list))

	print("saved REC document and entities %d out of %d" % (len(final_generated_data_list), len(output_lines)))

	if output_path is not None:
		with open(output_path, 'w') as out:
			for document, final_entity_list in final_generated_data_list:
				doc_json = {"text": document, "entities": final_entity_list}
				json_item = {"source": "Daily mail", "passage": doc_json, "qas": [{'query': 'empty', 'answers': [{'text': 'empty'}]}]}
				out.write(json.dumps(json_item) + '\n')

def process_rec_queries_generated_pair_data(output_lines, output_path, to_json=False):
	"""Fill masked entities/query/document prompts and save cloze-style items.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Prompts whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. A filled prompt must have three '|*|'
	fields holding 'Entities:', 'Query:' and 'Document:' in that order
	(enforced with assertions).

	The entities field is a comma-separated list. Every whitespace-delimited
	occurrence of an entity in the document is recorded with its character
	offsets ('end' is the index of the entity's last character). For each
	entity that occurs space-delimited in the query, one instance is produced
	in which that entity is replaced by "@placeholder" and its occurrences in
	the document are the answers.

	When `output_path` is given each instance is written as a JSON line with
	source "Daily mail"; document and query ids both count up from 0.
	`to_json` is accepted but not consulted.
	"""
	enc_input_list = []
	for index, enc_input, gen, input_seq in output_lines:
		text_chunks = _extract_multiple_span(enc_input, gen)
		if text_chunks is None:
			continue
		for i, text in enumerate(text_chunks):
			enc_input = enc_input.replace("<extra_id_%d>" % i, text)
		if '<extra_id_' not in enc_input:
			enc_input_list.append(enc_input)

	enc_input_list = list(set(enc_input_list))
	print("saved ReC entities %d out of %d" % (len(enc_input_list), len(output_lines)))

	key_list = ['Entities:', 'Query:', 'Document:']
	generated_data_list = []
	for enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 3, items

		value_list = []
		for key, item in zip(key_list, items):
			assert key in item
			value_list.append(item.replace(key, '').strip())
		generated_data_list.append(value_list)

	final_generated_data_list = []
	for entities, query, document in generated_data_list:
		# Empty entries (e.g. from a trailing comma) are skipped: they would
		# match any double whitespace in the document and query.
		entity_list = [e for e in (part.strip() for part in entities.split(',')) if e]

		# Locate each entity once; the offsets serve both the passage-level
		# entity list and the per-query answers. The entity is escaped because
		# it is generated text, not a regular expression.
		spans_by_entity = {}
		final_entity_list = []
		for entity in entity_list:
			pattern = r'\s%s\s' % re.escape(entity)
			# The match includes one whitespace character on each side.
			offsets = [(m.start() + 1, m.end() - 2) for m in re.finditer(pattern, document)]
			spans_by_entity[entity] = offsets
			for start, end in offsets:
				final_entity_list.append({'start': start, "end": end})

		for entity in entity_list:
			delimited = ' ' + entity + ' '
			if delimited not in query:
				continue
			masked_query = query.replace(delimited, " @placeholder ")
			answers = [
				{'start': start, "end": end, 'text': entity}
				for start, end in spans_by_entity[entity]
			]
			final_generated_data_list.append((document, final_entity_list, masked_query, answers))

	print("saved REC document, entities and query %d out of %d" % (len(final_generated_data_list), len(output_lines)))

	if output_path is not None:
		with open(output_path, 'w') as out:
			# One query per written document, so both ids equal the row number.
			for row_idx, (document, final_entity_list, query, answers) in enumerate(final_generated_data_list):
				qas = [{"query": query, "answers": answers, "idx": row_idx}]
				json_item = {"source": "Daily mail", "passage": {"text": document, "entities": final_entity_list}, "idx": row_idx, "qas": qas}
				out.write(json.dumps(json_item) + '\n')


def process_rte_document_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated documents and save them as placeholder RTE items.

	Each record is (index, enc_input, gen, input_seq). A document is kept when
	it can be extracted from `gen`, contains neither '<unk>' nor a leftover
	'<extra_id_' sentinel, and has more than 20 whitespace-separated tokens.
	Each kept document is paired with a flag telling whether `input_seq`
	contains 'Text:' more than once; de-duplication is on that pair. The kept
	count is printed.

	When `output_path` is given each document is written with the hypothesis
	"Empty Hypothesis" and label "not_entailment", as a tab-separated line or,
	when `to_json` is true, as a JSON line. The flag is not written.
	"""
	valid_document_list = []
	for index, enc_input, gen, input_seq in output_lines:
		document = _extract_document(gen)
		if document is None or '<unk>' in document or '<extra_id_' in document:
			continue
		if len(document.split()) > 20:
			valid_document_list.append((document, input_seq.count('Text:') > 1))
	valid_document_list = list(set(valid_document_list))

	print("saved RTE document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is not None:
		with open(output_path, 'w') as out:
			# Entries are (document, flag) pairs; unpack them in both branches
			# so only the document text is formatted into the output.
			for doc_index, (document, has_example) in enumerate(valid_document_list):
				if not to_json:
					out.write("Empty Hypothesis\t%s\tnot_entailment\n" % document)
				else:
					json_item = {"hypothesis": "Empty Hypothesis", "premise": document, "label": "not_entailment", "idx": doc_index}
					out.write(json.dumps(json_item) + '\n')

def process_cb_document_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated documents and save them as placeholder CB items.

	Each record is (index, enc_input, gen, input_seq). A document is kept when
	it can be extracted from `gen`, contains neither '<unk>' nor a leftover
	'<extra_id_' sentinel, and has more than 20 whitespace-separated tokens.
	Each kept document is paired with a flag telling whether `input_seq`
	contains 'Text:' more than once; de-duplication is on that pair. The kept
	count is printed.

	When `output_path` is given each document is written with the hypothesis
	"Empty Hypothesis", as a tab-separated line or, when `to_json` is true, as
	a JSON line. The flag is not written.
	"""
	valid_document_list = []
	for index, enc_input, gen, input_seq in output_lines:
		document = _extract_document(gen)
		if document is None or '<unk>' in document or '<extra_id_' in document:
			continue
		if len(document.split()) > 20:
			valid_document_list.append((document, input_seq.count('Text:') > 1))
	valid_document_list = list(set(valid_document_list))

	print("saved CB document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is not None:
		with open(output_path, 'w') as out:
			# Entries are (document, flag) pairs; unpack them in both branches
			# so only the document text is formatted into the output.
			for doc_index, (document, has_example) in enumerate(valid_document_list):
				if not to_json:
					# NOTE(review): the TSV placeholder label is "not_entailment"
					# while the JSON one is "entailment" - confirm which is intended.
					out.write("Empty Hypothesis\t%s\tnot_entailment\n" % document)
				else:
					json_item = {"hypothesis": "Empty Hypothesis", "premise": document, "label": "entailment", "idx": doc_index}
					out.write(json.dumps(json_item) + '\n')

def process_multi_rc_document_data(output_lines, output_path, to_json=False):
	"""Collect distinct generated documents and save them as placeholder MultiRC items.

	Each record is a 4-tuple whose third element is the generated text. A
	document is kept when it can be extracted, contains neither '<unk>' nor a
	leftover '<extra_id_' sentinel, and has more than 20 whitespace-separated
	tokens. The kept count is printed.

	When `output_path` is given each document is written as a JSON line whose
	passage carries a single question 'empty'. `to_json` is accepted but not
	consulted.
	"""
	kept = set()
	for _index, _enc_input, gen, _input_seq in output_lines:
		document = _extract_document(gen)
		if document is None:
			continue
		if '<unk>' in document or '<extra_id_' in document:
			continue
		if len(document.split()) > 20:
			kept.add(document)
	valid_document_list = list(kept)

	print("saved MultiRC document %d out of %d" % (len(valid_document_list), len(output_lines)))

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for doc_index, document in enumerate(valid_document_list):
			json_item = {
				"passage": {"text": document, "questions": [{'question': 'empty'}]},
				"idx": doc_index,
			}
			out.write(json.dumps(json_item) + '\n')


def process_multi_rc_question_data(output_lines, output_path, to_json=False):
	"""Fill masked question/document prompts and save questions grouped by document.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Prompts whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. A filled prompt must have two '|*|'
	fields holding 'Question:' and 'Document:' in that order (enforced with
	assertions).

	When `output_path` is given one JSON line per document is written, listing
	its distinct questions, each with a single placeholder answer
	(text 'empty', label 0). `to_json` is accepted but not consulted.
	"""
	filled_inputs = set()
	for _index, enc_input, gen, _input_seq in output_lines:
		spans = _extract_multiple_span(enc_input, gen)
		if spans is None:
			continue
		for slot, span_text in enumerate(spans):
			enc_input = enc_input.replace("<extra_id_%d>" % slot, span_text)
		if '<extra_id_' not in enc_input:
			filled_inputs.add(enc_input)

	enc_input_list = list(filled_inputs)
	print("saved MultiRC questions %d out of %d" % (len(enc_input_list), len(output_lines)))

	key_list = ['Question:', 'Document:']
	generated_data_list = []
	for enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 2, items

		value_list = []
		for key, item in zip(key_list, items):
			assert key in item
			value_list.append(item.replace(key, '').strip())
		generated_data_list.append(value_list)

	document2questions = {}
	for question, doc in generated_data_list:
		document2questions.setdefault(doc, set()).add(question)

	if output_path is None:
		return
	with open(output_path, 'w') as out:
		for doc_index, (document, questions) in enumerate(document2questions.items()):
			question_items = [
				{'question': question, 'answers': [{'text': 'empty', 'label': 0}]}
				for question in questions
			]
			json_item = {"passage": {"text": document, "questions": question_items}, "idx": doc_index}
			out.write(json.dumps(json_item) + '\n')

def process_multi_rc_answer_data(output_lines, output_path, to_json=False):
	"""Fill masked label/question/answer/document prompts and save them grouped.

	Each record is a 4-tuple whose first three elements are (index, enc_input,
	gen). Prompts whose masks cannot all be recovered, or that still contain a
	sentinel after filling, are dropped. A filled prompt must have four '|*|'
	fields holding 'Label:', 'Question:', 'Answer:' and 'Document:' in that
	order (enforced with assertions).

	Answers are grouped by document and then by question. When `output_path`
	is given one JSON line per document is written; each answer has label 0
	when its label field is 'False' and 1 otherwise, and answers and questions
	receive running ids across the whole file. `to_json` is accepted but not
	consulted.
	"""
	enc_input_list = []
	for index, enc_input, gen, input_seq in output_lines:
		text_chunks = _extract_multiple_span(enc_input, gen)
		if text_chunks is None:
			continue
		for i, text in enumerate(text_chunks):
			enc_input = enc_input.replace("<extra_id_%d>" % i, text)
		if '<extra_id_' not in enc_input:
			enc_input_list.append(enc_input)

	enc_input_list = list(set(enc_input_list))
	print("saved MultiRC questions %d out of %d" % (len(enc_input_list), len(output_lines)))

	key_list = ['Label:', 'Question:', 'Answer:', 'Document:']
	generated_data_list = []
	for enc_input in enc_input_list:
		items = enc_input.split('|*|')
		assert len(items) == 4, items

		value_list = []
		for key, item in zip(key_list, items):
			assert key in item
			value_list.append(item.replace(key, '').strip())
		generated_data_list.append(value_list)

	# document -> question -> [(answer, label), ...]. The question has to be
	# looked up in the per-document dict; checking the outer dict instead
	# would start a fresh list for every answer and keep only the last one.
	document2questions = {}
	for label, question, answer, doc in generated_data_list:
		document2questions.setdefault(doc, {}).setdefault(question, []).append((answer, label))

	if output_path is not None:
		ans_idx = 0
		question_idx = 0
		with open(output_path, 'w') as out:
			for doc_index, (document, questions) in enumerate(document2questions.items()):
				passage_json_item = {"text": document, "questions": [], "idx": doc_index}
				for question, answers in questions.items():
					answer_list = []
					for answer, label in answers:
						answer_list.append({'label': 0 if label == 'False' else 1, 'text': answer, 'idx': ans_idx})
						ans_idx += 1
					passage_json_item['questions'].append({'question': question, 'answers': answer_list, 'idx': question_idx})
					question_idx += 1
				json_item = {"passage": passage_json_item, "idx": doc_index}
				out.write(json.dumps(json_item) + '\n')


def process_domain_pair_data(output_lines, output_path, to_json=False):
	"""Score Yes/No domain-pair predictions and save the uniquely matched sentences.

	Args:
		output_lines: sequence of ``(enc_input, gt_y, label, _)`` tuples where
			``label`` is the predicted answer (or ``None``) and ``enc_input``
			holds two instances separated by ``|***|``.
		output_path: destination file, or ``None`` to skip writing.
		to_json: write JSON lines when true, tab-separated text otherwise.

	Returns:
		Percentage of ``output_lines`` whose prediction is ``No``.
	"""
	well_formed = 0
	no_answers = 0
	same_domain_list = []

	for enc_input, gt_y, label, _ in output_lines:
		if label is None:
			continue

		# Collapse verbose answers ("Yes, ...", "No ...") to the bare keyword.
		for keyword in ('Yes', 'No'):
			if label.startswith(keyword):
				label = keyword
				break

		if label == 'Yes':
			well_formed += 1
			halves = enc_input.split('|***|')
			same_domain_list.append(halves[0].strip())
			same_domain_list.append(halves[1].strip())
		elif label == 'No':
			well_formed += 1
			no_answers += 1

	line_total = len(output_lines)
	print("output format correct ratio %.2f" % (100 * well_formed / line_total))
	print("ans correct ratio %.2f" % (100 * no_answers / line_total))

	# How many "Yes" pairs each sentence took part in.
	sen_count_dict = {}
	for sen in same_domain_list:
		sen_count_dict[sen] = sen_count_dict.get(sen, 0) + 1

	# Only sentences seen exactly once are kept.
	selected_sen_list = []
	for sen, seen in sen_count_dict.items():
		if seen != 1:
			continue
		fields = sen.split('|*|')
		label, question, article = fields[0], fields[1], fields[2]
		for prefix in ('Label 1:', 'Label 2:'):
			label = label.replace(prefix, '')
		for prefix in ('Questions 1:', 'Questions 2:'):
			question = question.replace(prefix, '')
		for prefix in ('Article 1:', 'Article 2:'):
			article = article.replace(prefix, '')
		selected_sen_list.append((question.strip(), article.strip(), label.strip()))

	if output_path is not None:
		print("saved %d instances" % len(selected_sen_list))
		with open(output_path, 'w') as out:
			if to_json:
				for index, (question, article, label) in enumerate(selected_sen_list):
					record = {"question": question, "passage": article, "label": label, "idx": index}
					out.write(json.dumps(record) + '\n')
			else:
				for row in selected_sen_list:
					out.write("%s\t%s\t%s\n" % row)

	return 100 * no_answers / line_total

def process_consistency_data(output_lines, output_path, to_json=False):
	"""Keep the instances whose predicted label equals the reference label.

	Args:
		output_lines: iterable of ``((question, article), gt_y, label, _)``.
		output_path: destination file, or ``None`` to skip writing.
		to_json: write JSON lines when true, tab-separated text otherwise.

	Returns:
		Always ``1``.
	"""
	consistent = [
		(question, article, label)
		for (question, article), gt_y, label, _ in output_lines
		if gt_y == label
	]

	if output_path is None:
		return 1

	print("saved %d instances" % len(consistent))
	with open(output_path, 'w') as out:
		if to_json:
			for index, (question, article, label) in enumerate(consistent):
				record = {"question": question, "passage": article, "label": label, "idx": index}
				out.write(json.dumps(record) + '\n')
		else:
			for row in consistent:
				out.write("%s\t%s\t%s\n" % row)

	return 1

def gen_data_loss(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Return the negated mean loss of `model` over `eval_data`.

	The sign is flipped so that, like the other evaluation functions, a larger
	return value means a better model.

	Args:
		_C: unused; kept for signature parity with the other evaluators.
		eval_data: iterable of batch dicts. Every entry except 'gt_x', 'gt_y'
			and 'data_index' is moved to `device`.
		model: callable returning an object with a scalar `.loss`.
		device: device the batch tensors are moved to.
		tokenizer: unused.
		output_path: unused.
		is_root: when true, show a progress bar and print the mean loss.

	Returns:
		``-1 * mean(loss per batch)``; the mean is over this process's batches
		only (no cross-rank reduction is performed here).

	Raises:
		ZeroDivisionError: if `eval_data` yields no batches.
	"""
	model.eval()

	loss_list = []
	with torch.no_grad():
		pbar = tqdm(eval_data) if is_root else eval_data

		for batch in pbar:

			for n in batch:
				if n in ['gt_x', 'gt_y', 'data_index']: continue
				batch[n] = batch[n].to(device)

			# Use the `model` argument, as `flipDA_filtering` does. The previous
			# code called the global `dist_model`, which only exists when this
			# file is run as a script.
			# NOTE(review): assumes the raw module and the DeepSpeed engine give
			# the same forward result (no ZeRO partitioning) -- confirm.
			outputs = model(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				labels=batch['decoder_input_ids'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				prefix_ids=batch['prefix_ids'],
			)
			loss_list.append(outputs.loss.item())

	final_loss = sum(loss_list) / len(loss_list)

	if is_root:
		print("EVAL LOSS %.2f" % final_loss)

	return -1 * final_loss



def gen_data(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Sample generations for every batch and hand them to a task post-processor.

	For each input, `_C.sample_num` sequences are drawn with nucleus sampling
	(`top_p=_C.top_p`, `top_k=0`). The per-rank results are exchanged with
	`torch.distributed.all_gather_object`, so every rank ends up with the
	lines of all ranks; only the root rank post-processes and writes them.

	Args:
		_C: config providing `max_length`, `min_length`, `sample_num`, `top_p`,
			`running_task` and `output_as_json`.
		eval_data: iterable of batch dicts. Every entry except 'gt_x', 'gt_y'
			and 'data_index' is moved to `device`.
		model: model exposing `eval()` and `generate(...)`.
		device: device the batch tensors are moved to.
		tokenizer: provides `eos_token_id` and `decode`.
		output_path: forwarded to the post-processor.
		is_root: when true, show a progress bar and run the post-processor.

	Returns:
		Always ``1.0``.
	"""
	model.eval()

	output_lines = []
	with torch.no_grad():
		pbar = tqdm(eval_data) if is_root else eval_data

		for batch in pbar:

			for n in batch:
				if n in ['gt_x', 'gt_y', 'data_index']: continue
				batch[n] = batch[n].to(device)

			batch_size = batch['encoder_input_ids'].size(0)

			outputs = model.generate(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				prefix_ids=batch['prefix_ids'],
				max_length=_C.max_length,
				min_length=_C.min_length,
				eos_token_id=tokenizer.eos_token_id,
				num_return_sequences=_C.sample_num,
				do_sample=True,
				top_p=_C.top_p,
				top_k=0,
				early_stopping=True
			)

			# Regroup the flat result so outputs[i][j] is sample j of input i.
			outputs = outputs.view(batch_size, _C.sample_num, -1)

			batch_lines = []
			for i in range(batch_size):
				encoder_seq = batch['gt_x'][i]
				data_index = batch['data_index'][i]
				gt_y = batch['gt_y'][i]
				for j in range(_C.sample_num):
					# The encoder input used to be decoded here for every sample
					# although the result was never used; that call is gone.
					output_seq = tokenizer.decode(outputs[i][j]).replace("<pad>", "").replace("</s>", "")
					batch_lines.append((data_index, encoder_seq, output_seq, gt_y))

			# Collective call: every rank must reach it once per batch.
			all_batch_lines = [None for _ in range(torch.distributed.get_world_size())]
			torch.distributed.all_gather_object(all_batch_lines, batch_lines)
			for rank_batch_line in all_batch_lines:
				output_lines += rank_batch_line

	score = 1.0
	if is_root:
		if _C.running_task == 'document_generation':
			process_document_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'rte_document_generation':
			process_rte_document_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'cb_document_generation':
			process_cb_document_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'multi_rc_document_generation':
			process_multi_rc_document_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'multi_rc_question_generation':
			process_multi_rc_question_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == "multi_rc_answer_generation":
			process_multi_rc_answer_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'hypothesis_generation' or _C.running_task == 'cb_hypothesis_generation':
			process_rte_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'wsc_generation':
			process_wsc_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'wic_generation':
			process_wic_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'copa_generation':
			process_copa_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'copa_doc_generation':
			process_copa_doc_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == 'copa_opt_generation':
			# Shares the post-processor of 'copa_generation'.
			process_copa_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == "rec_document_generation":
			process_rec_document_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == "rec_entities_generation":
			process_rec_entities_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		elif _C.running_task == "rec_query_generation":
			process_rec_queries_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)
		else:
			process_generated_pair_data(output_lines, output_path, to_json=_C.output_as_json)

	return score

def gen_for_nlu(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Evaluate sequence tagging by generation and optionally export the predictions.

	Every rank decodes its batches, the decoded lines of all ranks are merged
	with `all_gather_object`, predictions sharing a `data_index` are combined
	by a per-position majority vote, and an F1 score is computed with seqeval.
	On the root rank the voted tags can be written to `output_path` as
	"token tag" lines, and, when `_C.enable_tri_training` is set, one file per
	model/prompt id is written under `_C.tri_path`.

	Args:
		_C: config; reads `label_path`, `enable_filtering_error`,
			`enable_tri_training`, `prefix_set_number`,
			`tri_training_additional_model_output_path` and `tri_path`.
		eval_data: iterable of batch dicts. Every entry except 'gt_x', 'gt_y'
			and 'data_index' is moved to `device`.
		model: model exposing `eval()` and `generate(...)`.
		device: device the batch tensors are moved to.
		tokenizer: provides `eos_token_id`; also passed to `clean_up_t5_tokenizer`.
		output_path: file for the voted predictions, or None.
		is_root: enables the progress bar, printing and file output.

	Returns:
		The F1 score multiplied by 100.
	"""
	# NOTE(review): `local_output` and `eval_iter` are never read in this function.
	local_output = []
	model.eval()
	eval_iter = iter(eval_data)

	label_list = get_label_list(_C.label_path)

	output_lines = []
	with torch.no_grad():

		if is_root:
			pbar = tqdm(eval_data)
		else:
			pbar = eval_data

		for batch in pbar:
			
			for n in batch:
				if n in ['gt_x', 'gt_y', 'data_index']: continue
				batch[n] = batch[n].to(device)

			# num_beams=1 without do_sample: one greedy sequence per input.
			outputs = model.generate(
				input_ids=batch['encoder_input_ids'], 
				attention_mask=batch['encoder_mask'], 
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				prefix_ids=batch['prefix_ids'],
				max_length=256,
				min_length=10,
				eos_token_id=tokenizer.eos_token_id,
				num_return_sequences=1, 
				num_beams=1,
				early_stopping=True,
			)

			batch_lines = []
			prompt_id = batch['prefix_ids'].cpu().numpy().tolist()
			for i in range(batch['decoder_input_ids'].size(0)):
				gt_x, gt_y, data_index, pid = batch['gt_x'][i], batch['gt_y'][i], batch['data_index'][i], prompt_id[i]
				output_seq = _extract(clean_up_t5_tokenizer(tokenizer, outputs[i]))
				batch_lines.append((data_index, pid, gt_x, gt_y, output_seq))

			# Collective call: every rank must reach it once per batch.
			all_batch_lines = [None for _ in range(torch.distributed.get_world_size())]
			torch.distributed.all_gather_object(all_batch_lines, batch_lines)
			for rank_batch_line in all_batch_lines:
				output_lines += rank_batch_line

	# Group predictions by data_index; one instance may have several
	# predictions, each stored together with the prefix id that produced it.
	gen_labels, gt_labels, input_sentences = {}, {}, {}
	total_instances, correct_instances = 0, 0
	for (data_index, pid, gt_x, gt_y, output_seq) in output_lines:
		# NOTE(review): `get_tag_result` is defined elsewhere; from its use here
		# it returns a 3-tuple whose second item is a tag sequence and whose
		# third item is an error flag -- confirm against its definition.
		_, label, has_error = get_tag_result(output_seq, gt_x, label_list)
		_, gt_label, _ = get_tag_result(gt_y, gt_x, label_list)
		if data_index not in gen_labels:
			gen_labels[data_index] = []
		gen_labels[data_index].append((label, pid))
		gt_labels[data_index] = gt_label
		input_sentences[data_index] = gt_x
		total_instances += 1
		correct_instances += 1 if not has_error else 0

	def most_frequent(List):
		# Ties are broken by the iteration order of the set.
		return max(set(List), key = List.count)
	
	eval_gen_labels, eval_gt_labels, eval_sentences = [], [], []
	for data_index in gt_labels:
		all_gen_seq = gen_labels[data_index]
		final_tags = []
		# Per-position majority vote; the first prediction's length is used for
		# every prediction of the instance.
		for i in range(len(all_gen_seq[0][0])):
			candidates = []
			for j in range(len(all_gen_seq)):
				candidates.append(all_gen_seq[j][0][i])
			final_tags.append(most_frequent(candidates))
		# With filtering enabled, keep an instance only when all of its
		# predictions are identical.
		output_set = set(['-'.join([t for t in seq[0]]) for seq in all_gen_seq])
		if (not _C.enable_filtering_error) or (_C.enable_filtering_error and len(output_set) == 1):
			eval_gen_labels.append(final_tags)
			eval_gt_labels.append(gt_labels[data_index])
			eval_sentences.append(input_sentences[data_index])

	new_F = seq_f1_score(eval_gt_labels, eval_gen_labels) * 100
	
	if is_root:
		# "Correct" counts predictions for which get_tag_result reported no error.
		print("Correct Ratio %.2f" % (100 * correct_instances / total_instances))
		print("saved Instances %d" % len(eval_gen_labels))
		print("Instances %d ==> F1 %.2f" % (len(output_lines), new_F))

		if output_path is not None:
			# One "token tag" pair per line, blank line between sentences.
			with open(output_path, 'w') as out:
				for sen, tag in zip(eval_sentences, eval_gen_labels):
					words = sen.split()
					for g, l in zip(words, tag):
						out.write("%s %s\n" % (g, l))
					out.write("\n")

		if _C.enable_tri_training:
			model_count = _C.prefix_set_number + len(_C.tri_training_additional_model_output_path)
			sen2index = {sen: index for (index, sen) in input_sentences.items()}
			# Add the tags read from the additional model output files; each
			# file gets an id following the prefix ids.
			for path_index, path in enumerate(_C.tri_training_additional_model_output_path):
				token_docs, tag_docs = read_conll(path)
				for token, tag in zip(token_docs, tag_docs):
					sen = ' '.join(token)
					if sen in sen2index:
						data_index = sen2index[sen]
						gen_labels[data_index].append((tag, _C.prefix_set_number + path_index))

			# For id i, write the instances on which all *other* ids agree.
			for i in range(model_count):
				instance_count = 0
				tri_output_path = _C.tri_path + "prompt_%d" % i
				with open(tri_output_path, 'w') as out:
					for data_index in gt_labels:
						all_gen_seq = gen_labels[data_index]
						# NOTE: rebinds `label_list`; the label inventory loaded
						# above is not used after this point.
						label_list = []
						for (label, pid) in all_gen_seq:
							if pid == i: continue
							label_list.append(' '.join(label))
						if len(set(label_list)) == 1:
							selected_label = label_list[0].split()
							seleccted_sen = input_sentences[data_index].split()
							# Sentences tagged entirely 'O' are counted but not written.
							if not all([t == 'O' for t in selected_label]):
								for g, l in zip(seleccted_sen, selected_label):
									out.write("%s %s\n" % (g, l))
								out.write("\n")
							instance_count += 1
				print("prompt %d instance %d" % (i, instance_count))

	return new_F

def get_labels(label_list, gen_label):
	"""Map generated text onto a known label.

	Returns `gen_label` itself when it is an element of `label_list`;
	otherwise the first label contained in `gen_label`; otherwise None.
	"""
	if gen_label not in label_list:
		return next((candidate for candidate in label_list if candidate in gen_label), None)
	return gen_label

def gen_for_sen_cls_nlu(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Run (pair-)sentence classification by generation and post-process per task.

	Every rank decodes its batches greedily, the decoded lines of all ranks
	are merged with `all_gather_object`, and the merged
	``(gt_x, gt_y, output_seq, data_index)`` tuples are handled according to
	`_C.running_task`:

	* 'domain_pair' / 'consistency_filtering': delegated to the matching
	  `process_*` helper.
	* 'copa_nlu': prints format-match and exact-match accuracy.
	* 'copa_nlu_filtering' / 'copa_nlu_tagging', 'boolq_nlu_*', 'rte_nlu_*':
	  write JSON lines with the predicted labels (the '*_filtering' variants
	  keep only predictions that agree with the reference).
	* anything else: majority vote per `data_index` and micro F1.

	Args:
		_C: config; reads `running_task`, `output_as_json`, `label_path`,
			`enable_label_constrained_decode`, `enable_consistency_filtering`
			and `enable_pair_sentence_classification`.
		eval_data: iterable of batch dicts. Every entry except 'gt_x', 'gt_y'
			and 'data_index' is moved to `device`.
		model: model exposing `eval()` and `generate(...)`.
		device: device the batch tensors are moved to.
		tokenizer: provides `eos_token_id`; also passed to `clean_up_t5_tokenizer`.
		output_path: destination file, or None to skip writing.
		is_root: enables the progress bar and, in the default branch, printing
			and file output.

	Returns:
		The task score: a percentage for 'domain_pair', 'copa_nlu' and the
		default branch, a constant for the export-only tasks.
	"""
	model.eval()

	output_lines = []
	with torch.no_grad():
		pbar = tqdm(eval_data) if is_root else eval_data

		for batch in pbar:

			for n in batch:
				if n in ['gt_x', 'gt_y', 'data_index']: continue
				batch[n] = batch[n].to(device)

			# num_beams=1 without do_sample: one greedy sequence per input.
			outputs = model.generate(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				prefix_ids=batch['prefix_ids'],
				max_length=15,
				min_length=1,
				eos_token_id=tokenizer.eos_token_id,
				num_return_sequences=1,
				num_beams=1,
				early_stopping=True,
				prefix_allowed_tokens_fn=_next_step_candidate if _C.enable_label_constrained_decode else None,
			)

			batch_lines = []
			for i in range(batch['decoder_input_ids'].size(0)):
				gt_x, gt_y, data_index = batch['gt_x'][i], batch['gt_y'][i], batch['data_index'][i]
				output_seq = _extract(clean_up_t5_tokenizer(tokenizer, outputs[i]))
				batch_lines.append((gt_x, gt_y, output_seq, data_index))

			# Collective call: every rank must reach it once per batch.
			all_batch_lines = [None for _ in range(torch.distributed.get_world_size())]
			torch.distributed.all_gather_object(all_batch_lines, batch_lines)
			for rank_batch_line in all_batch_lines:
				output_lines += rank_batch_line

	if _C.running_task == 'domain_pair':
		new_F = process_domain_pair_data(output_lines, output_path, to_json=_C.output_as_json)
	elif _C.running_task == 'consistency_filtering':
		new_F = process_consistency_data(output_lines, output_path, to_json=_C.output_as_json)
	elif _C.running_task == 'copa_nlu':
		correct = 0
		match = 0
		total = 0
		for gt_x, gt_y, out, _ in output_lines:
			# `match` counts well-formed answers, `correct` exact agreement.
			if out in ['Choice1', 'Choice2']:
				match += 1
			if out == gt_y:
				correct += 1
			total += 1
		new_F = 100 * correct / total
		new_match_acc = 100 * match / total
		print("match acc %.2f" % new_match_acc)
		print("acc %.2f" % new_F)
	elif _C.running_task == 'copa_nlu_filtering' or _C.running_task == 'copa_nlu_tagging':
		idx_to_data = {}
		idx_to_ans = {}
		for gt_x, gt_y, out, _ in output_lines:
			(data_instance, idx) = gt_x
			if idx not in idx_to_data:
				idx_to_data[idx] = data_instance
			if idx not in idx_to_ans:
				idx_to_ans[idx] = []
			c1, c2 = data_instance[1], data_instance[2]
			# Anything other than 'Choice1' is treated as the second choice.
			if out == 'Choice1':
				idx_to_ans[idx].append(c1)
			else:
				idx_to_ans[idx].append(c2)
			if _C.running_task.endswith('filtering'):
				# NOTE(review): the reference answer joins the vote, so gt_y is
				# presumably the choice text here -- confirm against the loader.
				idx_to_ans[idx].append(gt_y)

		# Keep an instance only when all collected answers are identical.
		selected_instance = []
		for idx in idx_to_ans:
			if len(set(idx_to_ans[idx])) == 1:
				ans = list(set(idx_to_ans[idx]))[0]
				data_instance = idx_to_data[idx]
				c1, c2 = data_instance[1], data_instance[2]
				# The second-to-last field holds the label (0: choice1, 1: choice2).
				if ans == c1:
					data_instance[-2] = 0
				else:
					data_instance[-2] = 1
				selected_instance.append(data_instance)

		print("Selected instances %d" % len(selected_instance))

		if output_path is not None:
			with open(output_path, 'w') as out:
				for index, (premise, choice1, choice2, question, tag, idx) in enumerate(selected_instance):
					json_item = {'premise': premise, 'question': question, 'idx': index, 'choice1': choice1, 'choice2': choice2, 'label': tag}
					out.write(json.dumps(json_item) + '\n')
		new_F = 1.0
	elif _C.running_task == 'boolq_nlu_tagging' or _C.running_task == 'boolq_nlu_filtering':
		label_list = get_label_list(_C.label_path)
		final_data = []
		for (gt_x, gt_y, output_seq, data_index) in output_lines:
			question, article = gt_x
			gen_label = get_labels(label_list, output_seq)
			if gen_label is not None:
				if _C.running_task.endswith('filtering'):
					if not gen_label == gt_y:
						continue

				final_data.append((question, article, gen_label))

		if output_path is not None:
			with open(output_path, 'w') as out:
				for index, (question, article, gen_label) in enumerate(final_data):
					json_item = {'question': question, 'passage': article, 'label': gen_label == 'True', 'idx': index}
					out.write(json.dumps(json_item) + '\n')
		new_F = 1.0
	elif _C.running_task == 'rte_nlu_tagging' or _C.running_task == 'rte_nlu_filtering':
		label_list = get_label_list(_C.label_path)
		final_data = []
		for (gt_x, gt_y, output_seq, data_index) in output_lines:
			hypothesis, premise = gt_x
			gen_label = get_labels(label_list, output_seq)
			if gen_label is not None:
				if _C.running_task.endswith('filtering'):
					if not gen_label == gt_y:
						continue

				final_data.append((hypothesis, premise, gen_label))

		if output_path is not None:
			with open(output_path, 'w') as out:
				for index, (hypothesis, premise, gen_label) in enumerate(final_data):
					gen_label = 'not_entailment' if gen_label == 'not entailment' else 'entailment'
					json_item = {'hypothesis': hypothesis, 'premise': premise, 'label': gen_label, 'idx': index}
					out.write(json.dumps(json_item) + '\n')
		new_F = 1.0
	else:
		label_list = get_label_list(_C.label_path)
		gen_labels, gt_labels = {}, {}
		total_instances, correct_instances = 0, 0
		for (gt_x, gt_y, output_seq, data_index) in output_lines:
			total_instances += 1
			gen_label = get_labels(label_list, output_seq)
			if data_index not in gen_labels:
				gen_labels[data_index] = []
				gt_labels[data_index] = None
			if gen_label is not None:
				correct_instances += 1
				gen_labels[data_index].append((gt_x, label_list.index(gen_label)))
			else:
				# Unparseable generation: fall back to a uniformly random label
				# index. The earlier expression chose from a one-element list
				# holding a boolean, so it never produced a real random index.
				gen_labels[data_index].append((gt_x, random.randrange(len(label_list))))

			gt_labels[data_index] = label_list.index(gt_y)

		def most_frequent(List):
			# Ties are broken by the iteration order of the set.
			return max(set(List), key = List.count)

		# Majority vote over all predictions made for the same instance.
		final_gt, final_gen, final_x = [], [], []
		for data_index in gt_labels:
			gt_x = gen_labels[data_index][0][0]
			pred_list = [x[1] for x in gen_labels[data_index]]
			final_gt.append(gt_labels[data_index])
			final_gen.append(most_frequent(pred_list))
			final_x.append(gt_x)

		new_F = sen_f1_score(final_gt, final_gen, average='micro') * 100

		if is_root:
			# "Correct" counts generations that could be mapped onto a label.
			print("Correct Ratio %.2f" % (100 * correct_instances / total_instances))
			print("Instances %d ==> F1 %.2f" % (len(output_lines), new_F))
			if output_path is not None:
				with open(output_path, 'w') as out:
					for input_x, gen_y, gt_y in zip(final_x, final_gen, final_gt):
						if _C.enable_consistency_filtering and (not gen_y == gt_y): continue
						# Only pair inputs are exported; for single-sentence
						# tasks the file is created but left empty.
						if _C.enable_pair_sentence_classification:
							text1, text2 = input_x
							gen_label = label_list[gen_y]
							out.write('%s\t%s\t%s\n' % (text1, text2, gen_label))


	return new_F

def flipDA_filtering(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Rank candidate instances by model loss and save the best ones.

	With `_C.running_task == 'full_rank'` the `_C.full_ranking_top_k`
	lowest-loss candidates overall are kept; otherwise the lowest-loss
	candidate of every (data_index, tag) group is kept. The selection is
	written to `output_path` as tab-separated "question, article, tag" lines.

	Each `gt_x` entry is a (question, article, data_index) triple and each
	`gt_y` entry is the tag. `outputs.loss` is zipped with the batch entries,
	so it must provide one value per example.

	Returns:
		Always ``1``.
	"""
	model.eval()
	# NOTE(review): this iterator is never consumed; it is created only to keep
	# the call sequence on `eval_data` unchanged -- presumably safe to drop.
	_unused_iter = iter(eval_data)

	non_tensor_keys = ('gt_x', 'gt_y', 'data_index')
	by_index_and_tag = {}
	all_candidates = []

	with torch.no_grad():
		batches = tqdm(eval_data) if is_root else eval_data

		for batch in batches:
			for key in batch:
				if key not in non_tensor_keys:
					batch[key] = batch[key].to(device)

			outputs = model(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				labels=batch['decoder_input_ids'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				prefix_ids=batch['prefix_ids'],
			)
			losses = outputs.loss.cpu().numpy().tolist()

			for (question, article, data_index), tag, loss_value in zip(batch['gt_x'], batch['gt_y'], losses):
				group = by_index_and_tag.setdefault(data_index, {}).setdefault(tag, [])
				group.append((question, article, loss_value))
				all_candidates.append((question, article, tag, loss_value))

	def by_loss(entry):
		# The loss is always the last field of an entry.
		return entry[-1]

	selected_instance = []
	if _C.running_task == 'full_rank':
		ranked = sorted(all_candidates, key=by_loss)
		for question, article, tag, _ in ranked[:_C.full_ranking_top_k]:
			selected_instance.append((question, article, tag))
	else:
		for tag_groups in by_index_and_tag.values():
			for tag, candidates in tag_groups.items():
				question, article, _ = sorted(candidates, key=by_loss)[0]
				selected_instance.append((question, article, tag))

	if output_path is not None:
		print("save %d instances" % len(selected_instance))
		with open(output_path, 'w') as out:
			for row in selected_instance:
				out.write('%s\t%s\t%s\n' % row)

	return 1


# Command-line interface. The parser is built at import time; arguments are
# only parsed inside the `__main__` block below.
# NOTE(review): the description and the --output-path help text mention machine
# translation / captions, which does not match the NLU and data-generation code
# in this file -- presumably inherited from another script.
parser = argparse.ArgumentParser("Train a MT5 for Machine Translation")
parser.add_argument(
    "--config", required=True, help="Path to a config file with all configuration parameters."
)
parser.add_argument(
    "--config-override",
    default=[],
    nargs="*",
    help="A sequence of key-value pairs specifying certain config arguments (with dict-like "
    "nesting) using a dot operator. The actual config will be updated and recorded in "
    "the serialization directory.",
)
parser.add_argument(
    "--serialization-dir",
    default=None,
    help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.",
)
parser.add_argument(
    "--start-from-checkpoint",
    default=None,
    help="Path to load checkpoint and continue training [only supported for module_training].",
)
parser.add_argument(
    "--output-path",
    default=None,
    help="Path to save output captions",
)
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')
# Run mode: at most one of --train / --validation / --test may be given.
group = parser.add_mutually_exclusive_group()
group.add_argument('--train', action='store_true')
group.add_argument('--validation', action='store_true')
group.add_argument('--test', action='store_true')
# Let DeepSpeed register its own options; `_A.deepspeed` is read further down.
parser = deepspeed.add_config_arguments(parser)

if __name__ == "__main__":
	_A = parser.parse_args()
	_C = Config(_A.config, _A.config_override)

	# Seed every RNG in use and force deterministic cuDNN kernels.
	np.random.seed(_C.random_seed)
	random.seed(_C.random_seed)
	torch.manual_seed(_C.random_seed)
	torch.cuda.manual_seed_all(_C.random_seed)
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

	os.environ["TOKENIZERS_PARALLELISM"] = "false"
	os.environ["NCCL_DEBUG"] = "WARN"
	local_rank = _A.local_rank
	torch.cuda.set_device(local_rank)
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

	if _A.deepspeed:
		deepspeed.init_distributed()

	is_root = (not _A.deepspeed) or torch.distributed.get_rank() == 0

	# Size of the task-embedding table depends on whether pretrained task
	# embeddings are reused or new ones are created for this run.
	if _C.load_from_pretrained:
		if _C.enable_pretrain_task_embeddings:
			_C.task_embed_count = len(TASK_NAME_LIST)
	else:
		if _C.enable_new_task_embeddings:
			_C.task_embed_count = _C.prefix_set_number if _C.prefix_set_number > 0 else 1

	if _C.enable_full_finetune:
		tokenizer, model = t5_model.get_full_finetune_t5_model(_C)
	elif _C.enable_full_pretrain:
		tokenizer, model = t5_model.get_full_pretrain_t5_model(_C)
	else:
		tokenizer, model = t5_model.get_t5_model(_C)

	# The configured batch sizes are global; derive the per-process size.
	if _A.deepspeed:
		val_batch_size = _C.val_batch_size // (torch.distributed.get_world_size() * _C.gradient_accumulation_steps)
	else:
		val_batch_size = _C.val_batch_size
	dev_loader = get_single_h5py_nlp_data(_C, _C.dev_path, _C.train_path, "validation", val_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=False)
	# `test_loader` only exists for NLU runs; every use below is guarded by
	# `_C.enable_nlu`.
	if _C.enable_nlu:
		test_loader = get_single_h5py_nlp_data(_C, _C.test_path, _C.train_path, "test", val_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=False)

	if _A.deepspeed:
		train_batch_size = _C.batch_size // (torch.distributed.get_world_size() * _C.gradient_accumulation_steps)
	else:
		train_batch_size = _C.batch_size
	train_loader = get_single_h5py_nlp_data(_C, _C.train_path, _C.train_path, "train", train_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=True)

	ds_config = {
		"train_batch_size": _C.batch_size,
		"gradient_accumulation_steps": _C.gradient_accumulation_steps,
		"steps_per_print": 100,
		"fp16": {
		  "enabled": False
		},
	}

	if _C.enable_adam_opt:
		optimizer = utils.build_optimizer(_C, model)
	elif _C.enable_full_finetune:
		optimizer = utils.build_adam_optimizer(_C, model)
	else:
		optimizer = utils.build_t5_optimizer(_C, model)

	# Only the trainable parameters are handed to DeepSpeed.
	dist_model, _, _, _ = deepspeed.initialize(args=_A, model=model, model_parameters=[p for p in model.parameters() if p.requires_grad], config=ds_config, optimizer=optimizer)
	if _A.start_from_checkpoint is not None:
		dist_model.load_checkpoint(_A.start_from_checkpoint, load_module_strict=False, load_module_only=True)
	if _C.enable_new_task_embeddings and _C.load_from_pretrained:
		dist_model.module.update_task_embedding(_C.prefix_set_number if _C.prefix_set_number > 0 else 1)

	if is_root:
		total_parameter_count = 0
		trainable_parameter_count = 0
		for p in model.parameters():
			total_parameter_count += p.numel()
			if p.requires_grad:
				trainable_parameter_count += p.numel()
		print('Total Parameter Count %d' % total_parameter_count)
		print('Trainable Parameter Count %d' % trainable_parameter_count)

		print(_C)
		for arg in vars(_A):
			print("{:<20}: {}".format(arg, getattr(_A, arg)))

	if _A.validation or _A.test:
		if _C.enable_nlu:
			if _C.enable_sentence_classification or _C.enable_pair_sentence_classification:
				if _C.enable_flip_filtering:
					_score = flipDA_filtering(_C, dev_loader if _A.validation else test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
				else:
					_score = gen_for_sen_cls_nlu(_C, dev_loader if _A.validation else test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
			else:
				_score = gen_for_nlu(_C, dev_loader if _A.validation else test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
		else:
			# Generation tasks always read from the dev loader.
			gen_data(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)

	if _A.train:
		train_iter = iter(train_loader)

		# Derive the step budget from the epoch count when it is not set.
		if _C.num_training_steps == 0:
			length_train_data = len(train_iter) // (_C.batch_size // torch.distributed.get_world_size())
			_C.num_training_steps = int(length_train_data * _C.max_epoch / _C.gradient_accumulation_steps)
		# An "epoch" here is one evaluation interval, not a pass over the data.
		epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step)

		os.makedirs(_A.serialization_dir, exist_ok=True)
		_C.dump(os.path.join(_A.serialization_dir, "config.yml"))

		eval_every = _C.checkpoint_every_step * _C.gradient_accumulation_steps
		total_step = 0
		# Best validation score so far; larger is better (losses are negated).
		lowest_loss = -1e10
		best_test_performance = 0

		for epoch in range(epoch_num):
			# The last interval is shortened so the total never exceeds the budget.
			run_step = eval_every if total_step + eval_every < _C.num_training_steps * _C.gradient_accumulation_steps else  _C.num_training_steps * _C.gradient_accumulation_steps - total_step
			dist_model.train()

			if is_root:
				print('EPOCH %d / %d' % (epoch + 1, epoch_num))
				pbar = tqdm(total=math.ceil(run_step / _C.gradient_accumulation_steps), file=sys.stdout)

			for step in range(run_step):
				try:
					batch = next(train_iter)
				except StopIteration:
					# Loader exhausted: start another pass. Only StopIteration is
					# caught so real data-loading errors and Ctrl-C still surface.
					train_iter = iter(train_loader)
					batch = next(train_iter)

				for n in batch:
					if n in ['gt_x', 'gt_y', 'data_index']: continue
					batch[n] = batch[n].to(dist_model.local_rank)
				total_step += 1

				outputs = dist_model(
					input_ids=batch['encoder_input_ids'],
					attention_mask=batch['encoder_mask'],
					labels=batch['decoder_input_ids'],
					task_ids=batch['task_ids'],
					task_type_ids=batch['task_type_ids'],
					prefix_ids=batch['prefix_ids'],
				)
				loss = outputs.loss
				dist_model.backward(loss)
				dist_model.step()

				# Averaged across ranks for display only.
				ave_loss = _average_all(loss).item()

				if is_root:
					pbar.set_description("loss %.2f" % (ave_loss * _C.gradient_accumulation_steps))
					pbar.update(1)
					pbar.refresh()

			if is_root:
				pbar.close()

			# Validation score for this interval.
			if _C.enable_nlu:
				if _C.enable_sentence_classification or _C.enable_pair_sentence_classification:
					_score = gen_for_sen_cls_nlu(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
				else:
					_score = gen_for_nlu(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
			else:
				if _C.eval_by_loss:
					_score = gen_data_loss(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
				else:
					_score = gen_data(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)

			if _C.save_model_each_epoch:
				# Checkpoint whenever validation does not get worse; for NLU runs
				# also record the test score at that point.
				if _score >= lowest_loss:
					lowest_loss = _score
					dist_model.save_checkpoint(_A.serialization_dir, "model_epoch_%d" % (epoch + 1))

					if _C.enable_nlu:
						if _C.enable_sentence_classification or _C.enable_pair_sentence_classification:
							_score = gen_for_sen_cls_nlu(_C, test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
						else:
							_score = gen_for_nlu(_C, test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
						best_test_performance = _score

				if _C.enable_nlu:
					print("Best Test Perforamnce %.2f" % best_test_performance)


			
				