import json
from collections import defaultdict as ddict
from tqdm import tqdm
import datasets
import random

from src.utils.data import dataset_info, data_keys


class DatasetBuilder():
    """Base class that converts raw dataset instances into tokenized features.

    Sub-classes override ``process_instance`` (and usually
    ``load_raw_dataset``) for a concrete dataset; ``process_instances`` then
    fills ``self.dataset_dict`` with one entry per instance.

    io_mode values handled by this class:
      * 'IR-O'         -- the tokenized rationale is appended to the input.
      * 'IreplacedR-O' -- stored as 'IR-O', but the appended rationale token
                          ids are overwritten with random ids.
      * 'IshuffledR-O' -- rationales are shuffled across instances when the
                          raw dataset is loaded; the mode then becomes 'IR-O'.
      * 'I-OR', 'I-RO' -- the target is a single string.
      * anything else  -- the target is an iterable of candidate strings.
    """

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        """Load the raw dataset and the per-dataset metadata.

        Args:
            dataset_name: key into ``dataset_info``.
            args_split: split name used for the progress-bar label and for
                selecting the generated-rationale file.
            split: split name used to select the raw data.
            io_mode: input/output layout (see the class docstring).
            rationale_src: source of rationales; stored for sub-classes.
            tokenizer: callable returning an object with ``input_ids``.
        """
        self.args_split = args_split
        self.split = split
        # 'IreplacedR-O' is handled as 'IR-O' plus a flag that triggers the
        # random replacement of rationale tokens in update_dataset_dict().
        self.ftr_token_replacement = False
        if io_mode == 'IreplacedR-O':
            io_mode = 'IR-O'
            self.ftr_token_replacement = True

        self.io_mode = io_mode
        self.rationale_src = rationale_src
        self.tokenizer = tokenizer

        # load_raw_dataset() reads self.split / self.io_mode, so it must run
        # after the assignments above.
        self.dataset = self.load_raw_dataset(dataset_name)
        self.indices = range(len(self.dataset))
        self.classes = dataset_info[dataset_name]['classes']
        self.delimiters = dataset_info[dataset_name]['delimiters']
        self.generated_rationales = None

        self.dataset_dict = ddict(list)

    def shuffle_rationale(self, dataset, key=None):
        """Shuffle the values stored under ``key`` across all instances.

        Args:
            dataset: iterable of dict-like instances; it is materialised into
                a list so datasets loaded from HuggingFace can be edited.
            key: name of the rationale field to shuffle. When omitted there is
                no generic way to know which field holds the rationale, so a
                sub-class has to supply it.

        Returns:
            The list of instances with the ``key`` values permuted.

        Raises:
            NotImplementedError: if ``key`` is not given.
            AssertionError: if the current io_mode is not 'IshuffledR-O'.
        """
        if key is None:
            raise NotImplementedError(
                'shuffle_rationale needs the rationale key; it should be provided by the sub-class.'
            )
        # Check the mode before touching the data so a wrong call leaves the
        # instances unmodified.
        assert self.io_mode == 'IshuffledR-O'

        dataset = list(dataset)  # It is for dataset loaded from HuggingFace.
        keys = [instance[key] for instance in dataset]
        random.shuffle(keys)
        for instance, new_key in zip(dataset, keys):
            instance[key] = new_key

        # After shuffling, the data is processed exactly like 'IR-O'.
        self.io_mode = 'IR-O'
        return dataset

    def load_raw_dataset(self, dataset_name):
        """Load the requested split from the HuggingFace hub.

        In 'IshuffledR-O' mode this default loader does not know the rationale
        field, so ``shuffle_rationale`` raises NotImplementedError unless a
        sub-class overrides it (or overrides this loader).
        """
        dataset = datasets.load_dataset(dataset_info[dataset_name]['hf'])[self.split]
        if self.io_mode == 'IshuffledR-O':
            dataset = self.shuffle_rationale(dataset)
            self.io_mode = 'IR-O'
        return dataset

    def load_generated_rationales(self, dataset_name):
        """Read generated rationales from a JSON-lines file.

        The file is selected by ``'{args_split}_{rationale_src}'`` in the
        dataset's 'rationale_file' mapping.

        Returns:
            A list with one decoded JSON object per line.
        """
        filename = dataset_info[dataset_name]['rationale_file'][f'{self.args_split}_{self.rationale_src}']
        with open(filename, 'r', encoding='utf-8') as fin:
            return [json.loads(line) for line in fin]

    def process_instance(self, instance, item_idx):
        """Convert one raw instance; sub-classes override this.

        Returns:
            A tuple ``(example, rationale, label, target_seq)``. The base
            implementation returns four ``None`` values.
        """
        example, rationale, label, target_seq = None, None, None, None
        return example, rationale, label, target_seq

    def update_dataset_dict(self, item_idx, example, rationale, label, target_seq):
        """Tokenize one processed instance and append it to ``dataset_dict``.

        Args:
            item_idx: index of the instance in the raw dataset.
            example: input text.
            rationale: rationale text, or ``None``; required in 'IR-O' mode.
            label: label value stored unchanged.
            target_seq: a string in 'I-OR'/'I-RO' mode, otherwise an iterable
                of strings that are tokenized one by one.
        """
        self.dataset_dict['item_idx'].append(item_idx)

        example_ids = self.tokenizer(text=example, add_special_tokens=True).input_ids
        if self.io_mode == 'IR-O':
            assert rationale is not None
            num_example_tokens = len(example_ids)
            example_rationale_ids = example_ids + self.tokenizer(text='explain: {}'.format(rationale), add_special_tokens=True).input_ids
            # token_type marks input tokens with 1 and rationale tokens with 0.
            token_type = [1] * num_example_tokens + [0] * (len(example_rationale_ids) - num_example_tokens)
            if self.ftr_token_replacement:
                assert token_type[num_example_tokens - 1] == 1
                # Overwrite every rationale token with a random vocabulary id.
                for i in range(num_example_tokens, len(example_rationale_ids)):
                    assert token_type[i] == 0
                    example_rationale_ids[i] = random.randint(0, len(self.tokenizer) - 1)
            example_ids = example_rationale_ids
        else:
            token_type = [1] * len(example_ids)
        self.dataset_dict['example'].append(example_ids)
        self.dataset_dict['token_type'].append(token_type)

        rationale_ids = None if rationale is None else self.tokenizer(text=rationale, add_special_tokens=True).input_ids
        self.dataset_dict['rationale'].append(rationale_ids)

        self.dataset_dict['label'].append(label)

        if self.io_mode in ['I-OR', 'I-RO']:
            target_seq_ids = self.tokenizer(text=target_seq, add_special_tokens=True).input_ids
        else:
            target_seq_ids = [self.tokenizer(text=x, add_special_tokens=True).input_ids for x in target_seq]
        self.dataset_dict['target_seq'].append(target_seq_ids)

    def process_instances(self):
        """Process every instance of the dataset and return ``dataset_dict``."""
        for item_idx in tqdm(self.indices, total=len(self.indices), desc=f'Building {self.args_split} dataset'):
            instance = self.dataset[item_idx]
            example, rationale, label, target_seq = self.process_instance(instance, item_idx)
            self.update_dataset_dict(item_idx, example, rationale, label, target_seq)
        return self.dataset_dict

    def get_example(self, example, rationale):
        """Return the input text adapted to the current io_mode.

        In 'I-OR'/'I-RO' mode the text is prefixed with 'explain '. In every
        other mode it is returned unchanged; for 'IR-O' the rationale is
        appended at token level in ``update_dataset_dict`` instead.
        """
        assert isinstance(example, str)
        if self.io_mode in ['I-OR', 'I-RO']:
            example = f'explain {example}'
        return example


class eSNLIBuilder(DatasetBuilder):
    """Builder for instances with 'hypothesis', 'premise', 'label' and
    'explanation_1' fields (registered under the name 'esnli')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one instance.

        In 'I-OR'/'I-RO' mode the label stays a class name and the target is a
        single string; otherwise the label is the class index and the target
        is the list of class names.
        """
        hypothesis = instance['hypothesis']
        premise = instance['premise']
        label = self.classes[instance['label']]
        rationale = instance['explanation_1']

        source_text = f'{self.delimiters["hypothesis"]} {hypothesis} {self.delimiters["premise"]} {premise}'
        example = self.get_example(source_text, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        return example, rationale, self.classes.index(label), self.classes


class ECQABuilder(DatasetBuilder):
    """Builder for multiple-choice instances with 'q_text', 'q_op<N>',
    'q_ans' and 'taskA_pos' fields (registered as 'ecqa' and 'ecqa_unk')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        self.dataset_name = dataset_name
        # Any source other than None / 'gold' means rationales come from a file.
        if self.rationale_src not in (None, 'gold'):
            self.generated_rationales = self.load_generated_rationales(dataset_name)

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one question.

        Raises:
            NotImplementedError: when generated rationales were loaded.
        """
        question = instance['q_text']
        num_choices = len(self.classes)
        choices = [instance[f'q_op{position}'] for position in range(1, num_choices + 1)]
        label = instance['q_ans']

        if self.generated_rationales is None:
            rationale = instance['taskA_pos']
        else:
            # Mismatch between instances in self.dataset and
            # self.generated_rationales, so generated rationales cannot be
            # looked up by item_idx.
            raise NotImplementedError

        if self.dataset_name == 'ecqa_unk':
            # Mask every rationale word that also occurs in the answer text.
            words = rationale.split()
            banned_words = label.lower().split()
            rationale = ' '.join(
                '<extra_id_0>' if word.lower() in banned_words else word
                for word in words
            )

        choices_text = ' '.join(f'({item}) {choice}' for item, choice in zip(self.classes, choices))
        source_text = f'{self.delimiters["question"]} {question} {self.delimiters["choices"]} {choices_text}'
        example = self.get_example(source_text, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        # Classification-style target: label index plus the candidate answers.
        return example, rationale, choices.index(label), choices


class OpenBookQABuilder(DatasetBuilder):
    """Builder for four-way multiple-choice instances whose gold rationale is
    stored under 'fact1' (registered as 'openbookqa')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        # Any source other than None / 'gold' means rationales come from a file.
        if self.rationale_src not in (None, 'gold'):
            self.generated_rationales = self.load_generated_rationales(dataset_name)

    def load_raw_dataset(self, dataset_name):
        """Load the 'additional' configuration, shuffling 'fact1' if requested."""
        raw = datasets.load_dataset(dataset_name, 'additional')[self.split]
        if self.io_mode != 'IshuffledR-O':
            return raw
        return self.shuffle_rationale(raw, 'fact1')

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one question."""
        question = instance['question_stem']
        choices = instance['choices']['text']
        assert instance['choices']['label'] == ['A', 'B', 'C', 'D']
        answer_position = instance['choices']['label'].index(instance['answerKey'])
        label = choices[answer_position]
        choices_text = ' '.join(
            f'({item}) {choice}'
            for item, choice in zip(instance['choices']['label'], instance['choices']['text'])
        )

        if self.generated_rationales is None:
            rationale = instance['fact1']
        else:
            generated = self.generated_rationales[item_idx]
            # Generated rationales must be aligned with the raw dataset.
            assert generated['id'] == instance['id']
            rationale = generated['explanation'][generated['answer']]

        source_text = f'{self.delimiters["question"]} {question} {self.delimiters["choices"]} {choices_text}'
        example = self.get_example(source_text, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        # Classification-style target: label index plus the candidate answers.
        return example, rationale, choices.index(label), choices


class StrategyQABuilder(DatasetBuilder):
    """Builder for question/answer instances read from a JSON file whose gold
    rationale is the list stored under 'facts' (registered as 'strategyqa')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        # Any source other than None / 'gold' means rationales come from a file.
        if self.rationale_src not in (None, 'gold'):
            self.generated_rationales = self.load_generated_rationales(dataset_name)

    def load_raw_dataset(self, dataset_name):
        """Read the split's JSON file, shuffling 'facts' if requested."""
        path = dataset_info[dataset_name]['file'][self.split]
        with open(path, 'r', encoding='utf-8') as fin:
            records = json.load(fin)
        if self.io_mode != 'IshuffledR-O':
            return records
        return self.shuffle_rationale(records, 'facts')

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one question."""
        question = instance['question']
        label = instance['answer']

        if self.generated_rationales is None:
            rationale = ' '.join(instance['facts'])
        else:
            generated = self.generated_rationales[item_idx]
            # Generated rationales must be aligned with the raw dataset.
            assert generated['context'].strip() == instance['question'].strip()
            rationale = generated['explanation'][generated['answer']]

        example = self.get_example(question, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        # Classification-style target: class index plus the class names.
        return example, rationale, self.classes.index(label), self.classes

class CREAKBuilder(DatasetBuilder):
    """Builder for instances with 'sentence', 'label' and 'explanation'
    fields read from a JSON-lines file (registered as 'creak')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)

    def load_raw_dataset(self, dataset_name):
        """Return one decoded JSON object per line of the split's file."""
        path = dataset_info[dataset_name]['file'][self.split]
        records = []
        with open(path) as fin:
            for line in fin:
                records.append(json.loads(line))
        return records

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one sentence."""
        sentence = instance['sentence']
        label = instance['label']
        rationale = instance['explanation']

        example = self.get_example(sentence, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        # Classification-style target: class index plus the class names.
        return example, rationale, self.classes.index(label), self.classes

class QuaRTzBuilder(DatasetBuilder):
    """Builder for two-choice questions read from a JSON-lines file whose
    rationale is stored under 'para' (registered as 'quartz')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        """Initialise the builder.

        The signature mirrors the other builders: the base class requires
        ``rationale_src``, so omitting it made construction fail with a
        TypeError.
        """
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)

    def load_raw_dataset(self, dataset_name):
        """Return one decoded JSON object per line, shuffling 'para' if requested."""
        filename = dataset_info[dataset_name]['file'][self.split]
        with open(filename, encoding='utf-8') as fin:
            dataset = [json.loads(line) for line in fin]
        if self.io_mode == 'IshuffledR-O':
            dataset = self.shuffle_rationale(dataset, 'para')
        return dataset

    def process_instance(self, instance, item_idx=None):
        """Build ``(example, rationale, label, target_seq)`` for one question.

        Args:
            instance: raw instance with 'question', 'answerKey' and 'para'.
            item_idx: index of the instance; unused here, but passed by
                ``process_instances``.
        """
        question = instance['question']['stem']
        raw_choices = instance['question']['choices']
        assert len(raw_choices) == 2
        assert (raw_choices[0]['label'], raw_choices[1]['label']) == ('A', 'B')
        choices = [raw_choices[0]['text'], raw_choices[1]['text']]
        label = choices[('A', 'B').index(instance['answerKey'])]
        rationale = instance['para']
        target_seq = None
        choices_text = ' '.join([f'({item}) {choice}' for item, choice in zip(('A', 'B'), choices)])

        example = self.get_example(f'{self.delimiters["question"]} {question} {self.delimiters["choices"]} {choices_text}', rationale)
        if self.io_mode == 'I-OR':
            target_seq = f'{label} explanation: {rationale}'
        elif self.io_mode == 'I-RO':
            target_seq = f'{rationale} label: {label}'
        else:
            # Classification-style target: label index plus the two answers.
            label = choices.index(label)
            target_seq = choices

        return example, rationale, label, target_seq

class AQuA_RATBuilder(DatasetBuilder):
    """Builder for five-option questions loaded from the 'aqua_rat' dataset
    (registered as 'aqua_rat' and 'aqua_rat_unk')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        """Initialise the builder.

        The signature mirrors the other builders: the base class requires
        ``rationale_src``, so omitting it made construction fail with a
        TypeError.
        """
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        self.dataset_name = dataset_name

    def load_raw_dataset(self, dataset_name):
        """Load the 'raw' configuration, shuffling 'rationale' if requested."""
        dataset = datasets.load_dataset('aqua_rat', 'raw')[self.split]
        if self.io_mode == 'IshuffledR-O':
            dataset = self.shuffle_rationale(dataset, 'rationale')
        return dataset

    def unk_process(self, rationale, label):
        """Strip answer information from a rationale.

        The rationale is cut at the last occurrence of an option letter
        ('A'-'E'), and every occurrence of the answer text (the label without
        its two-character option prefix) is replaced by a space.
        """
        for index in range(len(rationale) - 1, -1, -1):
            if rationale[index] in ('A', 'B', 'C', 'D', 'E'):
                rationale = rationale[:index]
                break
        answer_text = label[2:]
        # str.replace('') would insert a space between every character, so an
        # empty answer text is skipped.
        if answer_text:
            rationale = rationale.replace(answer_text, ' ')
        return rationale

    def process_instance(self, instance, item_idx=None):
        """Build ``(example, rationale, label, target_seq)`` for one question.

        Args:
            instance: raw instance with 'question', 'options', 'correct' and
                'rationale'.
            item_idx: index of the instance; unused here, but passed by
                ``process_instances``.
        """
        question = instance['question']
        choices = instance['options']
        # Every option is expected to start with its letter prefix.
        for item, choice in zip(('A)', 'B)', 'C)', 'D)', 'E)'), choices):
            assert choice[:2] == item
        label = choices[('A', 'B', 'C', 'D', 'E').index(instance['correct'])]
        rationale = instance['rationale']
        if self.dataset_name == 'aqua_rat_unk':
            rationale = self.unk_process(rationale, label)
        target_seq = None
        choices_text = ' '.join(choices)

        example = self.get_example(f'{self.delimiters["question"]} {question} {self.delimiters["choices"]} {choices_text}', rationale)
        if self.io_mode == 'I-OR':
            target_seq = f'{label} explanation: {rationale}'
        elif self.io_mode == 'I-RO':
            target_seq = f'{rationale} label: {label}'
        else:
            # Classification-style target: label index plus the options.
            label = choices.index(label)
            target_seq = choices

        return example, rationale, label, target_seq

class QASCBuilder(DatasetBuilder):
    """Builder for eight-way multiple-choice questions whose rationales are
    collected from separate JSON files (registered as 'qasc', 'qasc_unk' and
    'eqasc')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        self.dataset_name = dataset_name
        # Any source other than None / 'gold' means rationales come from a file.
        if self.rationale_src not in (None, 'gold'):
            self.generated_rationales = self.load_generated_rationales(dataset_name)

        # Maps instance id -> {choice text (no trailing period): rationale}.
        self.collected_rationale = {}
        rationale_files = dataset_info[dataset_name]['rationale_file']
        for split in ('train', 'dev', 'test'):
            with open(rationale_files[split]) as fin:
                for record in json.load(fin):
                    per_choice = {}
                    for position, choice in enumerate(record['question']['choices']):
                        if position == 0:
                            # The first choice takes the instance-level facts.
                            fact1, fact2 = record['fact1'], record['fact2']
                        elif choice['chains']:
                            first_chain = choice['chains'][0]
                            fact1, fact2 = first_chain[0]['text'], first_chain[1]['text']
                        else:
                            fact1 = fact2 = ''
                        joined_facts = ' . '.join([self.eliminate_period(fact1), self.eliminate_period(fact2)])
                        per_choice[self.eliminate_period(choice['text'])] = joined_facts
                    self.collected_rationale[record['id']] = per_choice

    def eliminate_period(self, word : str):
        """Return ``word`` without a single trailing period, if it has one."""
        return word[: -1] if word.endswith('.') else word

    def process_instance(self, instance, item_idx):
        """Build ``(example, rationale, label, target_seq)`` for one question."""
        question = instance['question']

        choices = instance['choices']['text']
        # Strip trailing periods in place so the texts match the keys of
        # self.collected_rationale.
        choices[:] = [self.eliminate_period(choice) for choice in choices]
        assert instance['choices']['label'] == ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
        choices_text = ' '.join(f'({item}) {choice}' for item, choice in zip(instance['choices']['label'], choices))

        label = choices[instance['choices']['label'].index(instance['answerKey'])]

        if self.generated_rationales is not None:
            generated = self.generated_rationales[item_idx]
            # Generated rationales must be aligned with the raw dataset.
            assert generated['id'] == instance['id']
            rationale = generated['explanation'][generated['answer']]
        elif self.dataset_name == 'eqasc':
            # One rationale per choice, joined by the '</s>' separator.
            rationale = '</s>'.join([self.collected_rationale[instance['id']][choice] for choice in choices])
        else:
            rationale = self.collected_rationale[instance['id']][label]
            if self.dataset_name == 'qasc_unk':
                # Mask every rationale word that also occurs in the answer text.
                words = rationale.split()
                banned_words = label.lower().split()
                rationale = ' '.join(
                    '<extra_id_0>' if word.lower() in banned_words else word
                    for word in words
                )

        source_text = f'{self.delimiters["question"]} {question} {self.delimiters["choices"]} {choices_text}'
        example = self.get_example(source_text, rationale)

        if self.io_mode == 'I-OR':
            return example, rationale, label, f'{label} explanation: {rationale}'
        if self.io_mode == 'I-RO':
            return example, rationale, label, f'{rationale} label: {label}'
        # Classification-style target: label index plus the candidate answers.
        return example, rationale, choices.index(label), choices

class ComVEBuilder(DatasetBuilder):
    """Builder for sentence-pair instances ('sent0', 'sent1') read from a JSON
    file (registered as 'comve' and 'comve_unk')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        """Initialise the builder.

        The signature mirrors the other builders: the base class requires
        ``rationale_src``, so omitting it made construction fail with a
        TypeError.
        """
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        self.dataset_name = dataset_name

    def load_raw_dataset(self, dataset_name):
        """Read the split's JSON file."""
        filename = dataset_info[dataset_name]['file'][self.split]
        with open(filename, 'r', encoding='utf-8') as fin:
            dataset = json.load(fin)
        return dataset

    def process_instance(self, instance, item_idx=None):
        """Build ``(example, rationale, label, target_seq)`` for one pair.

        Args:
            instance: raw instance with 'sent0', 'sent1', 'label' and
                'explanation'.
            item_idx: index of the instance; unused here, but passed by
                ``process_instances``.
        """
        sent0, sent1 = instance['sent0'], instance['sent1']
        label = instance['label']
        rationale = instance['explanation']
        target_seq = None

        if self.dataset_name == 'comve_unk':
            sets = {
                'sent0': set(sent0.lower().replace('.', '').replace(',', '').split()),
                'sent1': set(sent1.lower().replace('.', '').replace(',', '').split()),
            }
            intersection = sets['sent0'] & sets['sent1']
            rationale = rationale.lower()
            # Mask the words that occur only in the labelled sentence.
            for word in sets[label]:
                if word not in intersection:
                    rationale = rationale.replace(word, '<extra_id_0>')

        example = self.get_example(' </s> '.join(['sent0: {}'.format(sent0), 'sent1: {}'.format(sent1)]), rationale)
        if self.io_mode == 'I-OR':
            target_seq = f'{label} explanation: {rationale}'
        elif self.io_mode == 'I-RO':
            target_seq = f'{rationale} label: {label}'
        else:
            # Classification-style target: class index plus the class names.
            label = self.classes.index(label)
            target_seq = self.classes

        return example, rationale, label, target_seq

class WinoWhyBuilder(DatasetBuilder):
    """Builder for two-answer pronoun instances read from a JSON file
    (registered as 'winowhy_0' ... 'winowhy_4')."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        """Initialise the builder.

        The signature mirrors the other builders: the base class requires
        ``rationale_src``, so omitting it made construction fail with a
        TypeError.
        """
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)

    def load_raw_dataset(self, dataset_name):
        """Read the split's JSON file."""
        filename = dataset_info[dataset_name]['file'][self.split]
        with open(filename, 'r', encoding='utf-8') as fin:
            dataset = json.load(fin)
        return dataset

    def process_instance(self, instance, item_idx=None):
        """Build ``(example, rationale, label, target_seq)`` for one instance.

        Args:
            instance: raw instance with 'text', 'answers', 'correctAnswer' and
                'explanation'.
            item_idx: index of the instance; unused here, but passed by
                ``process_instances``.
        """
        # The pronoun is wrapped in asterisks between the two text parts.
        text = 'wsc: {} *{}* {}'.format(instance['text']['txt1'], instance['text']['pron'], instance['text']['txt2'])
        choices = instance['answers']
        # Periods are removed from 'correctAnswer' before the letter lookup.
        label = choices[('A', 'B').index(instance['correctAnswer'].replace('.', ''))]
        rationale = instance['explanation']
        target_seq = None

        example = self.get_example(text, rationale)
        if self.io_mode == 'I-OR':
            target_seq = f'{label} explanation: {rationale}'
        elif self.io_mode == 'I-RO':
            target_seq = f'{rationale} label: {label}'
        else:
            # Classification-style target: label index plus the answers.
            label = choices.index(label)
            target_seq = choices

        return example, rationale, label, target_seq

class MNLIBuilder(DatasetBuilder):
    """Builder for hypothesis/premise instances loaded from the 'mnli'
    configuration of 'glue'; only the 'I-O' mode is accepted."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        # These instances carry no rationale, so only 'I-O' is accepted.
        assert io_mode == 'I-O'

    def load_raw_dataset(self, dataset_name):
        """Return the requested split of the 'glue'/'mnli' dataset."""
        mnli = datasets.load_dataset('glue', 'mnli')
        return mnli[self.split]

    def process_instance(self, instance, item_idx):
        """Return ``(example, None, class index, class names)`` for one pair."""
        hypothesis = instance['hypothesis']
        premise = instance['premise']
        class_name = self.classes[instance['label']]

        example = f'{self.delimiters["hypothesis"]} {hypothesis} {self.delimiters["premise"]} {premise}'
        return example, None, self.classes.index(class_name), self.classes

class ANLIBuilder(DatasetBuilder):
    """Builder for hypothesis/premise instances loaded with the base-class
    loader (registered as 'anli_r1'..'anli_r3'); only 'I-O' is accepted."""

    def __init__(self, dataset_name, args_split, split, io_mode, rationale_src, tokenizer):
        super().__init__(dataset_name, args_split, split, io_mode, rationale_src, tokenizer)
        # These instances carry no rationale, so only 'I-O' is accepted.
        assert io_mode == 'I-O'

    def process_instance(self, instance, item_idx):
        """Return ``(example, None, class index, class names)`` for one pair."""
        hypothesis = instance['hypothesis']
        premise = instance['premise']
        class_name = self.classes[instance['label']]

        example = f'{self.delimiters["hypothesis"]} {hypothesis} {self.delimiters["premise"]} {premise}'
        return example, None, self.classes.index(class_name), self.classes

# Registry mapping each dataset name to the builder class that handles it.
builder_dict = {
    'esnli': eSNLIBuilder,
    'ecqa': ECQABuilder,
    'ecqa_unk': ECQABuilder,
    'strategyqa': StrategyQABuilder,
    'creak': CREAKBuilder,
    'quartz': QuaRTzBuilder,
    'aqua_rat': AQuA_RATBuilder,
    'aqua_rat_unk': AQuA_RATBuilder,
    'winowhy_0': WinoWhyBuilder,
    'winowhy_1': WinoWhyBuilder,
    'winowhy_2': WinoWhyBuilder,
    'winowhy_3': WinoWhyBuilder,
    'winowhy_4': WinoWhyBuilder,
    'qasc': QASCBuilder,
    'qasc_unk': QASCBuilder,
    'eqasc': QASCBuilder,
    'comve': ComVEBuilder,
    'comve_unk': ComVEBuilder,
    'mnli_matched': MNLIBuilder,
    'mnli_mismatched': MNLIBuilder,
    'anli_r1': ANLIBuilder,
    'anli_r2': ANLIBuilder,
    'anli_r3': ANLIBuilder,
    'openbookqa': OpenBookQABuilder,
}