from logging import error
from datasets import load_dataset
import transformers
from random import sample
import random
import torch
import json
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu
import pandas as pd
import re


'''
data format
{text_a, text_b, label:None or 0_1, }
'''
DATASET_HUGGINGFACE = {
    'cnndm': ['cnn_dailymail', '3.0.0', 'train'],
    'mnli': ['multi_nli', 'default', 'train'],
    'squad': ['squad', 'plain_text', 'train'],
    'squad_v2': ['squad_v2', 'squad_v2', 'train'],
    'paws': ['paws', 'labeled_final', 'train'],
    'vitaminc': ['tals/vitaminc', 'v1.0', 'train'],
    'xsum': ['xsum', 'default', 'train'],
    'stsb': ['glue', 'stsb', 'train'],
    'sick': ['sick', 'default', 'train'],
    'race': ['race', 'all', 'train'],
    'race_val': ['race', 'all', 'validation'],
    'anli_r1': ['anli', 'plain_text', 'train_r1'],
    'anli_r2': ['anli', 'plain_text', 'train_r2'],
    'anli_r3': ['anli', 'plain_text', 'train_r3'],
    'snli': ['snli', 'plain_text', 'train'],
    'wikihow': ['wikihow', 'all', 'train'],
    'mrpc': ['glue', 'mrpc', 'train'],
    'msmarco': ['ms_marco', 'v2.1', 'train'],
    'mrpc_val': ['glue', 'mrpc', 'validation'],
    'paws_val': ['paws', 'labeled_final', 'validation'],
    'paws_unlabeled': ['paws', 'unlabeled_final', 'train'],
    'qqp': ['glue', 'qqp', 'train'],
    'qqp_val': ['glue', 'qqp', 'validation'],
    'squad_v2_new': ['squad_v2', 'squad_v2', 'train'],
    'adversarial_qa': ['adversarial_qa', 'adversarialQA', 'train'],
    'drop': ['drop', 'train'],
    'duorc_self': ['duorc', 'SelfRC', 'train'],
    'duorc_paraphrase': ['duorc', 'ParaphraseRC', 'train'],
    'quoref': ['quoref', 'train'],
    'hotpot_qa_distractor': ['hotpot_qa', 'distractor', 'train'],
    'hotpot_qa_fullwiki': ['hotpot_qa', 'fullwiki', 'train'],
    'ropes': ['ropes', 'train'],
    'boolq': ['boolq', 'train'],
    'eraser_multi_rc': ['eraser_multi_rc', 'train'],
    'quail': ['quail', 'train'],
    'sciq': ['sciq', 'train'],
    'strategy_qa': ['metaeval/strategy-qa', 'train'],
    'gap': ['gap', 'train'],
}

DATASET_CONFIG = {
    # 'cnndm': {'task': 'summarization', 'text_a': 'article', 'text_b': 'highlights', 'label': None, 'huggingface': True},
    # 'mnli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    # 'nli_fever': {'task': 'fact_checking', 'text_a': 'context', 'text_b': 'query', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/nli_fever/train_fitems.jsonl' },
    # 'doc_nli': {'task': 'bin_nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/DocNLI_dataset/train.json' },
    # 'squad': {'task': 'extractive_qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
    # 'squad_v2': {'task': 'qa', 'text_a': 'context', 'text_b': ['question', 'answers'], 'label': None, 'huggingface': True},
    # 'paws': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    # 'vitaminc': {'task': 'fact_checking', 'text_a': 'evidence', 'text_b': 'claim', 'label': 'label', 'huggingface': True},
    # 'xsum': {'task': 'summarization', 'text_a': 'document', 'text_b': 'summary', 'label': None, 'huggingface': True, 'cliff_path': 'data/model_generated_data/cliff_summ/xsum_train.jsonl'},
    # 'stsb': {'task': 'sts', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    # 'sick': {'task': 'sts', 'text_a': 'sentence_A', 'text_b': 'sentence_B', 'label': 'relatedness_score', 'huggingface': True},
    # 'race': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
    # 'race_val': {'task': 'qa', 'text_a': 'article', 'text_b': ['question', 'options'], 'label': 'answer', 'huggingface': True},
    # 'anli_r1': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    # 'anli_r2': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    # 'anli_r3': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    # 'snli': {'task': 'nli', 'text_a': 'premise', 'text_b': 'hypothesis', 'label': 'label', 'huggingface': True},
    # 'wikihow': {'task': 'summarization', 'text_a': 'text', 'text_b': 'headline', 'label': None, 'huggingface': False, 'using_hf_api': True, 'data_dir': 'data/wikihow_raw'},
    # 'mrpc': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
    # 'mrpc_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label','huggingface': True},
    # 'paws_val': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    # 'paws_unlabeled': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': 'label', 'huggingface': True},
    # 'msmarco': {'task': 'ir', 'text_a': 'passages', 'text_b': ['query', 'answers'], 'label': None,'huggingface': True},
    # 'paws_qqp': {'task': 'paraphrase', 'text_a': 'sentence1', 'text_b': 'sentence2', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'paws_qqp/output/train.tsv' },
    # 'wiki103': {'task': 'paraphrase', 'text_a': 'original_sent', 'text_b': 'paraphrase', 'label': None,'huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json': True, 'data_path':'data/model_generated_data/backtranslation/wiki103_single_sent_backtranslation.json'},
    # 'qqp': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
    # 'qqp_val': {'task': 'paraphrase', 'text_a':'question1', 'text_b':'question2', 'label': 'label', 'huggingface': True},
    # 'wmt17xxx': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': True, 'data_path':'data/wmt/wmt17/2017-da.csv' },
    # 'wmt15': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt15_eval.jsonl' },
    # 'wmt16': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt16_eval.jsonl' },
    # 'wmt17': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt17_eval.jsonl' },
    # 'wmt18': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt18_eval.jsonl' },
    # 'wmt19': {'task': 'wmt', 'text_a': 'ref', 'text_b': 'mt', 'label': 'score','huggingface': False, 'using_hf_api': False, 'using_pandas': False, 'using_json':True, 'data_path':'data/eval/wmt19_eval.jsonl' },
    # 'squad_v2_new': {'task': 'qa', 'huggingface': True},
    # 'adversarial_qa': {'task': 'qa', 'huggingface': True},
    # 'drop': {'task': 'qa', 'huggingface': True},
    # 'duorc_self': {'task': 'qa', 'huggingface': True},
    # 'duorc_paraphrase': {'task': 'qa', 'huggingface': True},
    # 'quoref': {'task': 'qa', 'huggingface': True},
    # 'hotpot_qa_distractor': {'task': 'qa', 'huggingface': True},
    # 'hotpot_qa_fullwiki': {'task': 'qa', 'huggingface': True},
    # 'newsqa': {'task': 'qa',  'using_json': True, 'raw_json': True, 'data_path': 'data/newsqa_raw/combined-newsqa-data-v1.json'},
    # 'ropes': {'task': 'qa', 'huggingface': True},
    # 'boolq': {'task': 'qa', 'huggingface': True},
    # 'eraser_multi_rc': {'task': 'qa', 'huggingface': True},
    # 'quail': {'task': 'qa', 'huggingface': True},
    # 'sciq': {'task': 'qa', 'huggingface': True},
    'strategy_qa': {'task': 'qa', 'huggingface': True},
    # 'gap': {'task': 'coreference', 'huggingface': True},
}


class QA2D():
    def __init__(self, batch_size=32, device='cuda', verbose=True) -> None:
        from transformers import BartTokenizer, BartForConditionalGeneration
        self.tokenizer = BartTokenizer.from_pretrained("MarkS/bart-base-qa2d")
        self.model = BartForConditionalGeneration.from_pretrained("MarkS/bart-base-qa2d").to(device)
        self.batch_size = batch_size
        self.device=device
        self.verbose = verbose

    def generate(self, questions: list, answers: list):
        assert len(questions) == len(answers)
        qa_list = []
        for q, a in zip(questions, answers):
            qa_list.append(f"question: {q} answer: {a}")
        output = []
        for qa_pairs in tqdm(
            self.chunks(qa_list, self.batch_size),
            desc="QA to Declarative",
            total=int(len(qa_list)/self.batch_size), 
            disable=(not self.verbose)
        ):
            input_text = qa_pairs
            input_token = self.tokenizer(
                input_text, return_tensors='pt', padding=True, truncation=True).to(self.device)
            dec_sents = self.model.generate(
                input_token.input_ids, max_length=512)
            result = self.tokenizer.batch_decode(
                dec_sents, skip_special_tokens=True)
            output.extend(result)

        return output

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]


class QAnswering():
    """
    To answer not-answerable questions
    """

    def __init__(self, batch_size=32, device='cuda') -> None:
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        self.tokenizer = T5Tokenizer.from_pretrained(
            "valhalla/t5-base-qa-qg-hl")
        self.model = T5ForConditionalGeneration.from_pretrained(
            "valhalla/t5-base-qa-qg-hl").to(device)
        self.batch_size = batch_size
        self.device = device

    def generate(self, questions: list, contexts: list):
        assert len(questions) == len(contexts)
        answers = []
        for qs, cs in tqdm(zip(self.chunks(questions, self.batch_size), self.chunks(contexts, self.batch_size)), desc="Generating Answers for not answerable", total=int(len(questions)/self.batch_size)):
            qc_pairs = []
            assert len(qs) == len(cs)
            for one_q, one_c in zip(qs, cs):
                qc_pairs.append(f"""question: {one_q} context: {one_c}""")
            input_ids = self.tokenizer(
                qc_pairs, padding=True, truncation=True, return_tensors='pt').to(self.device).input_ids
            outputs = self.model.generate(input_ids, max_length=512)
            answers.extend(self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True))

        return answers

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]


class MLMGeneratorWithPairedData():
    def __init__(self, corpra: list, device='cuda', batch_size=8, mask_percent=0.25) -> None:
        self.device = device
        self.tokenizer = transformers.DistilBertTokenizer.from_pretrained(
            "distilbert-base-uncased")
        self.model = transformers.DistilBertForMaskedLM.from_pretrained(
            "distilbert-base-uncased").to(self.device)
        self.mask_percent = mask_percent
        self.batch_size = batch_size

        self.dataset = corpra  # text needs to be noised

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self):
        sents_output = []
        for examples in tqdm(self.chunks(self.dataset, self.batch_size), total=int(len(self.dataset)/self.batch_size), desc="MLM Generating"):
            sents_to_be_noised = [each for each in examples]
            sents_noised = self.mlm_infiller(sents_to_be_noised)

            sents_output.extend(sents_noised)

        return sents_output

    def mlm_infiller(self, batch):
        """
        input a batch of sentences, list
        """
        masked_batch = []
        masked_batch_ids = []
        for each_sent in batch:
            sent_tokens = self.tokenizer.tokenize(each_sent)
            sent_token_ids = self.tokenizer(each_sent)['input_ids']
            mask_list = sample(list(range(len(sent_tokens))), int(
                self.mask_percent * len(sent_tokens)))
            sent_tokens = [
                each if i not in mask_list else self.tokenizer.mask_token for i, each in enumerate(sent_tokens)]
            masked_batch_ids.append(
                [each if i-1 not in mask_list else self.tokenizer.mask_token_id for i, each in enumerate(sent_token_ids)])
            masked_batch.append(' '.join(sent_tokens))

        inputs = self.tokenizer(
            masked_batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits

        infill_tokens = []
        for i in range(len(masked_batch)):
            mask_token_index = (inputs.input_ids == self.tokenizer.mask_token_id)[
                i].nonzero(as_tuple=True)[0]

            predicted_token_id = logits[i, mask_token_index].argmax(axis=-1)
            infill_tokens.append(predicted_token_id)

        infilled_sent = []
        for masked_sent_ids, infill_token in zip(masked_batch_ids, infill_tokens):
            for infill_one_token in infill_token:
                for i, each_id in enumerate(masked_sent_ids):
                    if each_id == self.tokenizer.mask_token_id:
                        masked_sent_ids[i] = infill_one_token
                        break
                # masked_sent_ids = masked_sent_ids.replace(self.tokenizer.mask_token_id, infill_one_token, 1)
            infilled_sent.append(self.tokenizer.decode(
                masked_sent_ids, skip_special_tokens=True))

        return infilled_sent


class ExtractiveSummarizationGenerator():
    def __init__(self) -> None:
        pass

    def generate(self, texts):
        '''
        texts: list of string
        '''
        from summa.summarizer import summarize

        summaries = []
        for text in tqdm(texts, desc="Extracting Summary"):
            for prop in range(1, 20):
                summ = summarize(text, ratio=prop/20.)
                if len(summ) > 0:
                    break
            summaries.append(summ)

        return summaries


class DataGenerator():
    def __init__(self, dataset_names) -> None:
        self.dataset_names = dataset_names
        self.datasets = dict()
        self.t5_qa = None
        self.t5_tokenizer = None

        self.load_dataset_from_huggingface()

    def load_dataset_from_huggingface(self):
        for each_dataset in self.dataset_names:
            if DATASET_CONFIG[each_dataset].get('huggingface'):
                self.datasets[each_dataset] = load_dataset(
                    *DATASET_HUGGINGFACE[each_dataset][:-1])[DATASET_HUGGINGFACE[each_dataset][-1]]
            elif DATASET_CONFIG[each_dataset].get('using_hf_api'):
                self.datasets[each_dataset] = load_dataset(
                    *DATASET_HUGGINGFACE[each_dataset][:-1], data_dir=DATASET_CONFIG[each_dataset]['data_dir'])[DATASET_HUGGINGFACE[each_dataset][-1]]
            elif DATASET_CONFIG[each_dataset].get('using_pandas'):
                if DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'tsv':
                    self.datasets[each_dataset] = pd.read_csv(
                        DATASET_CONFIG[each_dataset]['data_path'], sep='\t')
                elif DATASET_CONFIG[each_dataset]['data_path'].split('.')[-1] == 'csv':
                    self.datasets[each_dataset] = pd.read_csv(
                        DATASET_CONFIG[each_dataset]['data_path'])
            elif DATASET_CONFIG[each_dataset].get('using_json'):
                self.datasets[each_dataset] = []
                if DATASET_CONFIG[each_dataset].get('raw_json'):
                    with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f:
                        self.datasets[each_dataset] = json.load(f)
                else:
                    try:
                        json_file = json.load(
                            open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8'))
                        for example in json_file:
                            self.datasets[each_dataset].append(example)
                    except:
                        with open(DATASET_CONFIG[each_dataset]['data_path'], 'r', encoding='utf8') as f:
                            for example in f:
                                self.datasets[each_dataset].append(
                                    json.loads(example))
            else:
                error('unable to locate raw dataset...')

    def process_squad(self):
        from rake_nltk import Rake
        r = Rake()
        topk = 5
        threshold = 0.6

        output = []
        label = -1
        for example in tqdm(self.datasets['squad'], desc=f'Constructing squad'):
            text_a = example[DATASET_CONFIG['squad']['text_a']]
            question = example[DATASET_CONFIG['squad']['text_b'][0]]
            answer = example[DATASET_CONFIG['squad']
                             ['text_b'][1]]['text']  # a list
            text_b = [question+' '+answer_ele for answer_ele in answer]
            text_c = []

            r.extract_keywords_from_text(text_a)
            keywords_in_context = r.get_ranked_phrases()[:topk]
            for each_keyword in keywords_in_context:
                # then it is an incorrect answer
                if sentence_bleu([answer_ele.lower().split() for answer_ele in answer], each_keyword.split(), weights=(0.33, 0.33, 0.33)) < threshold:
                    text_c.append(question+' '+each_keyword)

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_squad_v2(self):
        # first collect answerable items
        not_answerable_contexts = []
        not_answerable_questions = []
        not_answerable_answers = []

        answerable_contexts = []
        answerable_questions = []
        answerable_answers = []

        qa_generator = QAnswering(batch_size=32, device='cuda')
        qa2d_generator = QA2D(batch_size=32, device='cuda')

        for example in tqdm(self.datasets['squad_v2'], desc=f'Collecting (not)answerable examples'):
            if len(example['answers']['text']) == 0:
                not_answerable_contexts.append(example['context'])
                not_answerable_questions.append(example['question'])
            else:
                answerable_contexts.append(example['context'])
                answerable_questions.append(example['question'])
                answerable_answers.append(example['answers']['text'][0])

        not_answerable_answers = qa_generator.generate(
            not_answerable_questions, not_answerable_contexts)
        answerable_declarative_sents = qa2d_generator.generate(
            answerable_questions, answerable_answers)
        not_answerable_declarative_sents = qa2d_generator.generate(
            not_answerable_questions, not_answerable_answers)

        output = []
        for i, dec_sent in enumerate(answerable_declarative_sents):
            output.append({
                'text_a': answerable_contexts[i],
                'text_b': [dec_sent],
                'text_c': [],
                'label': 1
            })

        for i, dec_sent in enumerate(not_answerable_declarative_sents):
            output.append({
                'text_a': not_answerable_contexts[i],
                'text_b': [dec_sent],
                'text_c': [],
                'label': 0
            })

        return output

    def process_race(self):
        qa2d_generator = QA2D(batch_size=32, device='cuda')
        option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        output = []

        correct_context = []
        correct_question = []
        correct_answer = []

        wrong_context = []
        wrong_question = []
        wrong_answer = []

        for example in tqdm(self.datasets['race'], desc=f'Constructing race'):
            text_a = example[DATASET_CONFIG['race']['text_a']]
            # text_b = [] # pos
            # text_c = [] # neg
            label = -1
            question = example[DATASET_CONFIG['race']['text_b'][0]]
            if "_" in question:
                answer_id = option_dict[example[DATASET_CONFIG['race']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]):
                    if i == answer_id:
                        # text_b.append(' '.join(question.replace("_", " "+options+" ").split()))
                        output.append({
                            'text_a': text_a,
                            'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
                            'text_c': [],
                            'label': 1
                        })
                    else:
                        # text_c.append(' '.join(question.replace("_", " "+options+" ").split()))
                        output.append({
                            'text_a': text_a,
                            'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
                            'text_c': [],
                            'label': 0
                        })
            else:
                answer_id = option_dict[example[DATASET_CONFIG['race']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race']['text_b'][1]]):
                    if i == answer_id:
                        ## text_b.append(question+" "+options+" ")
                        # correct_context.append(text_a)
                        # correct_question.append(question)
                        # correct_answer.append(options)
                        output.append({
                                'text_a': text_a,
                                'text_b': [question],
                                'text_c': [options],
                                'label': 1
                            })
                    else:
                        ## text_c.append(question+" "+options+" ")
                        # wrong_context.append(text_a)
                        # wrong_question.append(question)
                        # wrong_answer.append(options)
                        output.append({
                                'text_a': text_a,
                                'text_b': [question],
                                'text_c': [options],
                                'label': 0
                            })

            
        # correct_declarative = qa2d_generator.generate(
        #     correct_question, correct_answer)
        # wrong_declarative = qa2d_generator.generate(
        #     wrong_question, wrong_answer)
        # assert len(correct_context) == len(correct_declarative)
        # assert len(wrong_context) == len(wrong_declarative)
        # for context, dec in zip(correct_context, correct_declarative):
        #     output.append({
        #         'text_a': context,
        #         'text_b': [dec],
        #         'text_c': [],
        #         'label': 1
        #     })

        # for context, dec in zip(wrong_context, wrong_declarative):
        #     output.append({
        #         'text_a': context,
        #         'text_b': [dec],
        #         'text_c': [],
        #         'label': 0
        #     })

        return output

    def process_race_val(self):
        qa2d_generator = QA2D(batch_size=32, device='cuda')
        option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        output = []

        correct_context = []
        correct_question = []
        correct_answer = []

        wrong_context = []
        wrong_question = []
        wrong_answer = []

        for example in tqdm(self.datasets['race_val'], desc=f'Constructing race_val'):
            text_a = example[DATASET_CONFIG['race_val']['text_a']]
            # text_b = [] # pos
            # text_c = [] # neg
            label = -1
            question = example[DATASET_CONFIG['race_val']['text_b'][0]]
            if "_" in question:
                answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
                    if i == answer_id:
                        # text_b.append(' '.join(question.replace("_", " "+options+" ").split()))
                        output.append({
                            'text_a': text_a,
                            'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
                            'text_c': [],
                            'label': 1
                        })
                    else:
                        # text_c.append(' '.join(question.replace("_", " "+options+" ").split()))
                        output.append({
                            'text_a': text_a,
                            'text_b': [' '.join(question.replace("_", " "+options+" ").split())],
                            'text_c': [],
                            'label': 0
                        })
            else:
                answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
                    if i == answer_id:
                        # text_b.append(question+" "+options+" ")
                        correct_context.append(text_a)
                        correct_question.append(question)
                        correct_answer.append(options)
                    else:
                        # text_c.append(question+" "+options+" ")
                        wrong_context.append(text_a)
                        wrong_question.append(question)
                        wrong_answer.append(options)

            # output.append({
            #         'text_a': text_a,
            #         'text_b': text_b,
            #         'text_c': text_c,
            #         'label': label
            #     })
        correct_declarative = qa2d_generator.generate(
            correct_question, correct_answer)
        wrong_declarative = qa2d_generator.generate(
            wrong_question, wrong_answer)
        assert len(correct_context) == len(correct_declarative)
        assert len(wrong_context) == len(wrong_declarative)
        for context, dec in zip(correct_context, correct_declarative):
            output.append({
                'text_a': context,
                'text_b': [dec],
                'text_c': [],
                'label': 1
            })

        for context, dec in zip(wrong_context, wrong_declarative):
            output.append({
                'text_a': context,
                'text_b': [dec],
                'text_c': [],
                'label': 0
            })

        return output
        # option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        # output = []
        # for example in tqdm(self.datasets['race_val'], desc=f'Constructing race_val'):
        #     text_a = example[DATASET_CONFIG['race_val']['text_a']]
        #     text_b = [] # pos
        #     text_c = [] # neg
        #     label = -1
        #     question = example[DATASET_CONFIG['race_val']['text_b'][0]]
        #     if "_" in question:
        #         answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
        #         for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
        #             if i == answer_id:
        #                 text_b.append(' '.join(question.replace("_", " "+options+" ").split()))
        #             else:
        #                 text_c.append(' '.join(question.replace("_", " "+options+" ").split()))
        #     else:
        #         answer_id = option_dict[example[DATASET_CONFIG['race_val']['label']]]
        #         for i, options in enumerate(example[DATASET_CONFIG['race_val']['text_b'][1]]):
        #             if i == answer_id:
        #                 text_b.append(question+" "+options+" ")
        #             else:
        #                 text_c.append(question+" "+options+" ")

        #     output.append({
        #             'text_a': text_a,
        #             'text_b': text_b,
        #             'text_c': text_c,
        #             'label': label
        #         })

        # return output
    def process_race_test(self):
        option_dict = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        output = []
        for example in tqdm(self.datasets['race_test'], desc=f'Constructing race_test'):
            text_a = example[DATASET_CONFIG['race_test']['text_a']]
            text_b = []  # pos
            text_c = []  # neg
            label = -1
            question = example[DATASET_CONFIG['race_test']['text_b'][0]]
            if "_" in question:
                answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]):
                    if i == answer_id:
                        text_b.append(' '.join(question.replace(
                            "_", " "+options+" ").split()))
                    else:
                        text_c.append(' '.join(question.replace(
                            "_", " "+options+" ").split()))
            else:
                answer_id = option_dict[example[DATASET_CONFIG['race_test']['label']]]
                for i, options in enumerate(example[DATASET_CONFIG['race_test']['text_b'][1]]):
                    if i == answer_id:
                        text_b.append(question+" "+options+" ")
                    else:
                        text_c.append(question+" "+options+" ")

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_xsum(self):
        '''
        text_a: raw_text
        text_b: raw_summary + ***extractive summ*** removed
        text_c: cliff xsum + DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b***
        '''
        # interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000
        output = []

        # try: ### load cliff xsum
        #     self.cliff_xsum
        # except:
        #     self.cliff_xsum = []
        #     with open(DATASET_CONFIG['xsum']['cliff_path'], 'r') as f:
        #         for i, content in enumerate(f):
        #             self.cliff_xsum.append(json.loads(content))

        gold_summary = [example[DATASET_CONFIG['xsum']['text_b']]
                        for example in self.datasets['xsum']]
        ext_summarizer = ExtractiveSummarizationGenerator()
        extracted_summ = ext_summarizer.generate(
            [example[DATASET_CONFIG['xsum']['text_a']] for example in self.datasets['xsum']])

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
        gold_summary_hallucinated = mlm_hallucinator.generate()

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
        extracted_summ_hallucinated = mlm_hallucinator.generate()

        assert len(self.datasets['xsum']) == len(gold_summary_hallucinated) and len(
            self.datasets['xsum']) == len(extracted_summ_hallucinated)

        for i, example in tqdm(enumerate(self.datasets['xsum']), desc="Constructing xsum", total=len(self.datasets['xsum'])):
            text_a = example[DATASET_CONFIG['xsum']['text_a']]
            text_b = [gold_summary[i], extracted_summ[i]]
            text_c = [gold_summary_hallucinated[i],
                      extracted_summ_hallucinated[i]]
            label = -1

            # for each_can in self.cliff_xsum[i]['candidates']:
            #     if each_can[-1] < 0.21: # the threshold from cliff paper
            #         text_c.append(each_can[0])

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_cnndm(self):
        '''
        text_a: raw_text
        text_b: raw_summary + ***extractive summ*** removed
        text_c: DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b***
        '''
        # interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000
        output = []

        gold_summary = [example[DATASET_CONFIG['cnndm']['text_b']]
                        for example in self.datasets['cnndm']]
        ext_summarizer = ExtractiveSummarizationGenerator()
        extracted_summ = ext_summarizer.generate(
            [example[DATASET_CONFIG['cnndm']['text_a']] for example in self.datasets['cnndm']])

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
        gold_summary_hallucinated = mlm_hallucinator.generate()

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
        extracted_summ_hallucinated = mlm_hallucinator.generate()

        assert len(self.datasets['cnndm']) == len(gold_summary_hallucinated) and len(
            self.datasets['cnndm']) == len(extracted_summ_hallucinated)

        for i, example in tqdm(enumerate(self.datasets['cnndm']), desc="Constructing cnndm", total=len(self.datasets['cnndm'])):
            text_a = example[DATASET_CONFIG['cnndm']['text_a']]
            text_b = [gold_summary[i], extracted_summ[i]]
            text_c = [gold_summary_hallucinated[i],
                      extracted_summ_hallucinated[i]]
            label = -1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wikihow(self):
        '''
        text_a: raw_text
        text_b: raw_summary + ***extractive summ*** removed
        text_c: DistillBERT from raw_text_b + ***DistillBERT from extractive summ text_b***
        '''
        # interpretation of fairseq-generate output: https://github.com/facebookresearch/fairseq/issues/3000
        output = []

        gold_summary = [example[DATASET_CONFIG['wikihow']['text_b']]
                        for example in self.datasets['wikihow']]
        ext_summarizer = ExtractiveSummarizationGenerator()
        extracted_summ = ext_summarizer.generate(
            [example[DATASET_CONFIG['wikihow']['text_a']] for example in self.datasets['wikihow']])

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=gold_summary, device='cuda:0', batch_size=64, mask_percent=0.25)
        gold_summary_hallucinated = mlm_hallucinator.generate()

        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=extracted_summ, device='cuda:0', batch_size=64, mask_percent=0.25)
        extracted_summ_hallucinated = mlm_hallucinator.generate()

        assert len(self.datasets['wikihow']) == len(gold_summary_hallucinated) and len(
            self.datasets['wikihow']) == len(extracted_summ_hallucinated)

        for i, example in tqdm(enumerate(self.datasets['wikihow']), desc="Constructing wikihow", total=len(self.datasets['wikihow'])):
            text_a = example[DATASET_CONFIG['wikihow']['text_a']]
            text_b = [gold_summary[i], extracted_summ[i]]
            text_c = [gold_summary_hallucinated[i],
                      extracted_summ_hallucinated[i]]
            label = -1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wiki103(self):
        output = []

        paraphrases = [example[DATASET_CONFIG['wiki103']['text_b']]
                       for example in self.datasets['wiki103']]
        mlm_hallucinator = MLMGeneratorWithPairedData(
            corpra=paraphrases, device='cuda:3', batch_size=64, mask_percent=0.25)
        paraphrase_hallucinated = mlm_hallucinator.generate()

        assert len(self.datasets['wiki103']) == len(paraphrase_hallucinated)

        for i, example in tqdm(enumerate(self.datasets['wiki103']), desc=f'Constructing wiki103'):
            output.append({
                'text_a': example[DATASET_CONFIG['wiki103']['text_a']],
                'text_b': [example[DATASET_CONFIG['wiki103']['text_b']]],
                'text_c': [],
                'label': 1
            })
            output.append({
                'text_a': example[DATASET_CONFIG['wiki103']['text_a']],
                'text_b': [paraphrase_hallucinated[i]],
                'text_c': [],
                'label': 0
            })

        return output

    def process_mnli(self):
        output = []
        for example in tqdm(self.datasets['mnli'], desc=f'Constructing mnli'):
            text_a = example[DATASET_CONFIG['mnli']['text_a']]
            text_b = [example[DATASET_CONFIG['mnli']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['mnli']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_nli_fever(self):
        output = []
        for example in tqdm(self.datasets['nli_fever'], desc=f'Constructing nli_fever'):
            text_a = example[DATASET_CONFIG['nli_fever']['text_a']]
            text_b = [example[DATASET_CONFIG['nli_fever']['text_b']]]
            text_c = []
            raw_label = example[DATASET_CONFIG['nli_fever']['label']]
            if raw_label == 'SUPPORTS':  # convert to nli style label
                label = 0
            elif raw_label == 'REFUTES':
                label = 2
            else:
                label = 1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_doc_nli(self):
        output = []
        for example in tqdm(self.datasets['doc_nli'], desc=f'Constructing doc_nli'):
            text_a = example[DATASET_CONFIG['doc_nli']['text_a']]
            text_b = [example[DATASET_CONFIG['doc_nli']['text_b']]]
            text_c = []
            raw_label = example[DATASET_CONFIG['doc_nli']['label']]
            if raw_label == 'entailment':  # convert to paraphrase style label
                label = 1
            else:
                label = 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_anli_r1(self):
        output = []
        for example in tqdm(self.datasets['anli_r1'], desc=f'Constructing anli_r1'):
            text_a = example[DATASET_CONFIG['anli_r1']['text_a']]
            text_b = [example[DATASET_CONFIG['anli_r1']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['anli_r1']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_anli_r2(self):
        output = []
        for example in tqdm(self.datasets['anli_r2'], desc=f'Constructing anli_r2'):
            text_a = example[DATASET_CONFIG['anli_r2']['text_a']]
            text_b = [example[DATASET_CONFIG['anli_r2']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['anli_r2']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_anli_r3(self):
        output = []
        for example in tqdm(self.datasets['anli_r3'], desc=f'Constructing anli_r3'):
            text_a = example[DATASET_CONFIG['anli_r3']['text_a']]
            text_b = [example[DATASET_CONFIG['anli_r3']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['anli_r3']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_snli(self):
        output = []
        for example in tqdm(self.datasets['snli'], desc=f'Constructing snli'):
            text_a = example[DATASET_CONFIG['snli']['text_a']]
            text_b = [example[DATASET_CONFIG['snli']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['snli']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_paws(self):
        output = []
        for example in tqdm(self.datasets['paws'], desc=f'Constructing paws'):
            text_a = example[DATASET_CONFIG['paws']['text_a']]
            text_b = [example[DATASET_CONFIG['paws']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['paws']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_vitaminc(self):
        output = []
        for example in tqdm(self.datasets['vitaminc'], desc=f'Constructing vitaminc'):
            text_a = example[DATASET_CONFIG['vitaminc']['text_a']]
            text_b = [example[DATASET_CONFIG['vitaminc']['text_b']]]
            text_c = []
            raw_label = example[DATASET_CONFIG['vitaminc']['label']]
            if raw_label == 'SUPPORTS':  # convert to nli style label
                label = 0
            elif raw_label == 'REFUTES':
                label = 2
            else:
                label = 1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_stsb(self):
        output = []
        for example in tqdm(self.datasets['stsb'], desc=f'Constructing stsb'):
            text_a = example[DATASET_CONFIG['stsb']['text_a']]
            text_b = [example[DATASET_CONFIG['stsb']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['stsb']['label']] / 5.0
            # label = 1 if label > 4.0 else 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_sick(self):
        output = []
        for example in tqdm(self.datasets['sick'], desc=f'Constructing sick'):
            text_a = example[DATASET_CONFIG['sick']['text_a']]
            text_b = [example[DATASET_CONFIG['sick']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['sick']['label']] / 5.0
            # label = 1 if label > 4.0 else 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_mrpc(self):
        output = []
        for example in tqdm(self.datasets['mrpc'], desc=f'Constructing mrpc'):
            text_a = example[DATASET_CONFIG['mrpc']['text_a']]
            text_b = [example[DATASET_CONFIG['mrpc']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['mrpc']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_mrpc_val(self):
        output = []
        for example in tqdm(self.datasets['mrpc_val'], desc=f'Constructing mrpc_val'):
            text_a = example[DATASET_CONFIG['mrpc_val']['text_a']]
            text_b = [example[DATASET_CONFIG['mrpc_val']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['mrpc_val']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_paws_val(self):
        output = []
        for example in tqdm(self.datasets['paws_val'], desc=f'Constructing paws_val'):
            text_a = example[DATASET_CONFIG['paws_val']['text_a']]
            text_b = [example[DATASET_CONFIG['paws_val']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['paws_val']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_paws_unlabeled(self):
        output = []
        for example in tqdm(self.datasets['paws_unlabeled'], desc=f'Constructing paws_unlabeled'):
            text_a = example[DATASET_CONFIG['paws_unlabeled']['text_a']]
            text_b = [example[DATASET_CONFIG['paws_unlabeled']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['paws_unlabeled']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_qqp(self):
        output = []
        for example in tqdm(self.datasets['qqp'], desc=f'Constructing qqp'):
            text_a = example[DATASET_CONFIG['qqp']['text_a']]
            text_b = [example[DATASET_CONFIG['qqp']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['qqp']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_qqp_val(self):
        output = []
        for example in tqdm(self.datasets['qqp_val'], desc=f'Constructing qqp_val'):
            text_a = example[DATASET_CONFIG['qqp_val']['text_a']]
            text_b = [example[DATASET_CONFIG['qqp_val']['text_b']]]
            text_c = []
            label = example[DATASET_CONFIG['qqp_val']['label']]

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_msmarco(self):
        qa2d_generator = QA2D(batch_size=32, device='cuda')
        output = []
        correct_contexts = []
        correct_questions = []
        correct_answers = []

        wrong_contexts = []
        wrong_questions = []
        wrong_answers = []

        filtered_examples = []
        questions = []
        answers = []
        declaratives = []

        for example in tqdm(self.datasets['msmarco'], desc=f'Collecting msmarco'):
            if sum(example['passages']['is_selected']) > 0:  # has answer
                questions.append(example['query'])
                answers.append(example['answers'][0] if len(
                    example['wellFormedAnswers']) == 0 else example['wellFormedAnswers'][0])
                filtered_examples.append(example)
        #         for i, is_selected in enumerate(example['passages']['is_selected']):
        #             # text_a = example['passages']['passage_text'][i]
        #             # text_b = [example['query'] + ' ' + each_answer for each_answer in example['answers']] if len(example['wellFormedAnswers']) == 0 else [example['query'] + ' ' + each_answer for each_answer in example['wellFormedAnswers']]
        #             # text_c = []

        #             if is_selected == 1:
        #                 correct_contexts.append(example['passages']['passage_text'][i])
        #                 correct_questions.append(example['query'])
        #                 correct_answers.append(example['answers'][0] if len(example['wellFormedAnswers']) == 0 else example['wellFormedAnswers'][0])
        #             else:
        #                 wrong_contexts.append(example['passages']['passage_text'][i])
        #                 wrong_questions.append(example['query'])
        #                 wrong_answers.append(example['answers'][0] if len(example['wellFormedAnswers']) == 0 else example['wellFormedAnswers'][0])

        # correct_dec = qa2d_generator.generate(correct_questions, correct_answers)
        # wrong_dec = qa2d_generator.generate(wrong_questions, wrong_answers)
        # assert len(correct_contexts) == len(correct_dec)
        # assert len(wrong_contexts) == len(wrong_dec)
        # for context, dec in zip(correct_contexts, correct_dec):
        #     output.append({
        #                     'text_a': context,
        #                     'text_b': [dec],
        #                     'text_c': [],
        #                     'label': 1
        #                 })
        # for context, dec in zip(wrong_contexts, wrong_dec):
        #     output.append({
        #                     'text_a': context,
        #                     'text_b': [dec],
        #                     'text_c': [],
        #                     'label': 0
        #                 })

        # declaratives = qa2d_generator.generate(
        #     questions=questions, answers=answers)
        # assert len(filtered_examples) == len(declaratives)
        # for example, one_dec in zip(filtered_examples, declaratives):
        for example in filtered_examples:
            for i, is_selected in enumerate(example['passages']['is_selected']):
                if is_selected == 1:
                    output.append({
                        'text_a': example['passages']['passage_text'][i],
                        'text_b': [example['query']],
                        'text_c': [],
                        'label': 1
                    }
                    )
                else:
                    output.append({
                        'text_a': example['passages']['passage_text'][i],
                        'text_b': [example['query']],
                        'text_c': [],
                        'label': 0
                    }
                    )
        return output

    def process_paws_qqp(self):
        output = []

        for i in range(len(self.datasets['paws_qqp'])):
            text_a = self.datasets['paws_qqp'].iloc[i]['sentence1'][2:-1]
            text_b = [self.datasets['paws_qqp'].iloc[i]['sentence2'][2:-1]]
            text_c = []
            label = self.datasets['paws_qqp'].iloc[i]['label']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': int(label)
            })

        return output

    def process_wmt15(self):
        output = []

        for example in self.datasets['wmt15']:
            text_a = example['reference']
            text_b = [example['candidate']]
            text_c = []
            label = example['score']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wmt16(self):
        output = []

        for example in self.datasets['wmt16']:
            text_a = example['reference']
            text_b = [example['candidate']]
            text_c = []
            label = example['score']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wmt17(self):
        # output = []

        # for i in range(len(self.datasets['wmt17'])):
        #     if self.datasets['wmt17'].iloc[i]['lp'][-2:] != 'en':
        #         continue
        #     text_a = self.datasets['wmt17'].iloc[i]['ref']
        #     text_b = [self.datasets['wmt17'].iloc[i]['mt']]
        #     text_c = []
        #     label = self.datasets['wmt17'].iloc[i]['score']

        #     output.append({
        #                 'text_a': text_a,
        #                 'text_b': text_b,
        #                 'text_c': text_c,
        #                 'label': label
        #             })

        # return output

        output = []

        for example in self.datasets['wmt17']:
            text_a = example['reference']
            text_b = [example['candidate']]
            text_c = []
            label = example['score']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wmt18(self):
        output = []

        for example in self.datasets['wmt18']:
            text_a = example['reference']
            text_b = [example['candidate']]
            text_c = []
            label = example['score']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_wmt19(self):
        output = []

        for example in self.datasets['wmt19']:
            text_a = example['reference']
            text_b = [example['candidate']]
            text_c = []
            label = example['score']

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output
    
    def process_boolq(self):
        output = []

        for example in self.datasets['boolq']:
            text_a = example['passage']
            text_b = [example['question']]
            text_c = ["Yes." if example['answer'] else "No."]
            label = 1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

            text_a = example['passage']
            text_b = [example['question']]
            text_c = ["Yes." if not example['answer'] else "No."]
            label = 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output
    
    def process_eraser_multi_rc(self):
        output = []

        for example in self.datasets['eraser_multi_rc']:
            text_a = example['passage']
            text_b = [example['query_and_answer'].replace("|", "")]
            text_c = []
            label = int(example['label'])

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output
    
    def process_quail(self):
        output = []

        for example in self.datasets['quail']:
            for i, ans in enumerate(example['answers']):
                text_a = example['context']
                text_b = [example['question']]
                text_c = [ans]
                label = 1 if i == example['correct_answer_id'] else 0

                output.append({
                    'text_a': text_a,
                    'text_b': text_b,
                    'text_c': text_c,
                    'label': label
                })

        return output
    
    def process_sciq(self):
        output = []

        for example in self.datasets['sciq']:
            text_a = example['support']

            output.append({
                'text_a': text_a,
                'text_b': [example['question']],
                'text_c': [example['distractor1']],
                'label': 0
            })
            output.append({
                'text_a': text_a,
                'text_b': [example['question']],
                'text_c': [example['distractor2']],
                'label': 0
            })
            output.append({
                'text_a': text_a,
                'text_b': [example['question']],
                'text_c': [example['distractor3']],
                'label': 0
            })
            output.append({
                'text_a': text_a,
                'text_b': [example['question']],
                'text_c': [example['correct_answer']],
                'label': 1
            })

        return output
    
    def process_strategy_qa(self):
        output = []

        for example in self.datasets['strategy_qa']:
            text_a = ' '.join(example['facts'])
            text_b = [example['question']]
            text_c = ["Yes." if example['answer'] else "No."]
            label = 1

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

            text_a = ' '.join(example['facts'])
            text_b = [example['question']]
            text_c = ["Yes." if not example['answer'] else "No."]
            label = 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def process_gap(self):
        output = []

        for example in self.datasets['gap']:
            text_a = example['Text']
            text_b = [example['Text'][:example['Pronoun-offset']]+example['A']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):]]
            text_c = []
            label = 1 if example['A-coref'] else 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

            text_a = example['Text']
            text_b = [example['Text'][:example['Pronoun-offset']]+example['B']+example['Text'][(example['Pronoun-offset']+len(example['Pronoun'])):]]
            text_c = []
            label = 1 if example['B-coref'] else 0

            output.append({
                'text_a': text_a,
                'text_b': text_b,
                'text_c': text_c,
                'label': label
            })

        return output

    def init_qa_t5(self):
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        if self.t5_qa is None:
            self.t5_tokenizer = T5Tokenizer.from_pretrained(
                "t5-base", model_max_length=800)
            self.t5_qa = T5ForConditionalGeneration.from_pretrained("t5-base")
            self.t5_qa.to('cuda:1')
            self.t5_qa.eval()

    @staticmethod
    def mask_answer(context, answers):
        answers = sorted(answers, key=len, reverse=True)
        for answer in answers:
            pattern = f'(?<![\w\\-\u2013]){re.escape(answer)}(?![\w\\-\u2013])'
            context = re.sub(pattern, '', context, flags=re.IGNORECASE)
        return context

    def generate_fake_answer(self, context, question, answers):
        self.init_qa_t5()

        context_no_answer = self.mask_answer(context, answers)

        input_ids = self.t5_tokenizer(
            f'question: {question} context: {context_no_answer}',
            return_tensors="pt",
            truncation='only_first'
        ).input_ids.to(self.t5_qa.device)

        outputs = self.t5_qa.generate(
            input_ids,
            max_new_tokens=40,
            remove_invalid_values=True
        )

        return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def negative_sample_qa(self, samples, negative_sample_no_ans_only=True):
        outputs = []
        for context, question, answers in samples:
            if answers:
                outputs.append({
                    'text_a': context,
                    'text_b': [question],
                    'text_c': answers,
                    'label': 1
                })
            if not answers or not negative_sample_no_ans_only:
                fake_answer = self.generate_fake_answer(
                    context, question, answers)
                outputs.append({
                    'text_a': context,
                    'text_b': [question],
                    'text_c': [fake_answer],
                    'label': 0
                })

        return outputs

    def process_squad_v2_new(self):
        samples = (
            (sample['context'], sample['question'], sample['answers']['text'])
            for sample in tqdm(self.datasets['squad_v2_new'], desc=f'squad_v2_new')
        )
        return self.negative_sample_qa(samples)

    def process_adversarial_qa(self):
        samples = (
            (sample['context'], sample['question'], sample['answers']['text'])
            for sample in tqdm(self.datasets['adversarial_qa'], desc=f'adversarial_qa')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_drop(self):
        samples = (
            (sample['passage'], sample['question'],
             sample['answers_spans']['spans'])
            for sample in tqdm(self.datasets['drop'], desc=f'drop')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_duorc_self(self):
        samples = (
            (sample['plot'], sample['question'],
             sample['answers'])
            for sample in tqdm(self.datasets['duorc_self'], desc=f'duorc_self')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_duorc_paraphrase(self):
        samples = (
            (sample['plot'], sample['question'],
             sample['answers'])
            for sample in tqdm(self.datasets['duorc_paraphrase'], desc=f'duorc_paraphrase')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_quoref(self):
        samples = (
            (sample['context'], sample['question'], sample['answers']['text'])
            for sample in tqdm(self.datasets['quoref'], desc=f'quoref')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    @staticmethod
    def prepare_hotpot_qa_samples(dateset):
        for sample in dateset:
            question = sample['question']
            answer = sample['answer']
            supporting_docs = set(sample['supporting_facts']['title'])
            irrelevant_docs = []
            context_paragraphs = []
            for title, setences in zip(sample['context']['title'], sample['context']['sentences']):
                doc = ''.join(setences)
                if title in supporting_docs:
                    context_paragraphs.append(doc)
                else:
                    irrelevant_docs.append(doc)
            # Add some irrelevant documents
            if irrelevant_docs and len(context_paragraphs) < 4:
                context_paragraphs.append(random.choice(irrelevant_docs))
            random.shuffle(context_paragraphs)
            yield '\n'.join(context_paragraphs), question, [answer]

    def process_hotpot_qa_distractor(self):
        samples = self.prepare_hotpot_qa_samples(
            tqdm(self.datasets['hotpot_qa_distractor'],
                 desc=f'hotpot_qa_distractor')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_hotpot_qa_fullwiki(self):
        samples = self.prepare_hotpot_qa_samples(
            tqdm(self.datasets['hotpot_qa_fullwiki'],
                 desc=f'hotpot_qa_fullwiki')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_newsqa(self):
        def get_samples(dataset):
            for story in tqdm(dataset['data'], desc='newsqa'):
                if story['type'] != 'train':
                    continue
                context = story['text']
                for question in story['questions']:
                    if question.get('isQuestionBad', 0.) > 0.2:
                        continue
                    answers = []
                    if 's' in question['consensus']:
                        start = question['consensus']['s']
                        end = question['consensus']['e']
                        answers.append(context[start:end].strip())
                    yield context, question['q'], answers
        samples = get_samples(self.datasets['newsqa'])
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def process_ropes(self):
        samples = (
            (
                sample['situation'] + ' ' + sample['background'],
                sample['question'], sample['answers']['text']
            )
            for sample in tqdm(self.datasets['ropes'], desc=f'ropes')
        )
        return self.negative_sample_qa(samples, negative_sample_no_ans_only=False)

    def generate(self):
        for each_dataset in self.datasets:
            with open(f'./data/training/{each_dataset}.json', 'w', encoding='utf8') as outfile:
                outfile.write("")
        for each_dataset in self.datasets:
            outputs = eval(f'self.process_{each_dataset}()')

            for each_output in outputs:
                dict_write_to_file = {
                    'task': DATASET_CONFIG[each_dataset]['task'],
                    'text_a': each_output['text_a'],  # string
                    # list of positive examples
                    'text_b': each_output['text_b'],
                    # list of negative examples
                    'text_c': each_output['text_c'],
                    # original label, if -1 only has positive pairs and negative pairs
                    'orig_label': each_output['label']
                }
                # if len(outputs) % 1000 == 0:
                with open(f'./data/training/{each_dataset}.json', 'a', encoding='utf8') as outfile:
                    # for each in outputs:
                    json.dump(dict_write_to_file, outfile, ensure_ascii=False)
                    outfile.write('\n')


if __name__ == "__main__":
    random.seed(42)
    gen = DataGenerator(list(DATASET_CONFIG.keys()))
    # gen = DataGenerator(['sick'])
    gen.generate()
