import os
from functools import partial

# Point the Hugging Face model and dataset caches at project-local directories.
# NOTE(review): presumably this has to run before `transformers` / `datasets`
# are imported so that they pick the paths up -- keep it above those imports.
os.environ.update({
    'TRANSFORMERS_CACHE': 'data/hg_data/transformers',
    'HF_DATASETS_CACHE': 'data/hg_data/datasets',
})

from transformers import AutoTokenizer
from datasets import load_dataset, Dataset
import pandas as pd


class Dataloader:
    """Base class that loads task datasets from disk and tokenizes them for a model.

    ``necessary_items``, ``task_name``, ``dataset_info`` and
    ``train_dataset_name`` are initialised to ``None`` here and must be
    assigned (typically by a subclass ``__init__``) before any ``load_*``
    method is called.

    ``dataset_info`` maps a dataset name to a dict from which this class reads:

    * ``'dataset_path_type'``: ``{split: (path relative to data_dir, loader key)}``
    * one entry per name in ``necessary_items`` plus ``'labels'``: the source
      column that is renamed to that name
    * ``'label_name_to_label'``: label string -> label id, used when the
      labels column holds strings
    * ``'verbalizers'``: label id -> word, read from the *train* dataset entry
      when a generative model is used
    * ``'additional_information'`` (optional): extra columns to keep
    """

    def __init__(self, data_dir, max_length):
        """Store the configuration and build the loader / tokenizer dispatch tables.

        Args:
            data_dir: Directory that the paths in ``dataset_info`` are relative to.
            max_length: ``max_length`` handed to the tokenizer (with truncation on).
        """
        self.data_dir = data_dir
        self.max_length = max_length
        self.tokenizer_dict = {}  # model name -> tokenizer, filled lazily
        self.necessary_items = None
        self.task_name = None
        self.dataset_info = None
        self.train_dataset_name = None
        # loader key (second element of a 'dataset_path_type' tuple) -> loader
        self.load_func_dict = {
            'json': self._load_json_dataset,
            'tsv': self._load_tsv_dataset,
            'csv': self._load_csv_dataset,
        }
        self.model_types = {
            'discriminative': {
                'roberta-base', 'roberta-large',
                'microsoft/deberta-v3-xsmall', 'microsoft/deberta-v3-small',
                'microsoft/deberta-v3-base', 'microsoft/deberta-v3-large',
                'google/electra-small-discriminator', 'google/electra-base-discriminator',
                'google/electra-large-discriminator',
                # T5s are now being supported to be used as discriminative models
                # by adding classification heads (the t5-v1_1 checkpoints stay
                # in the generative group below).
                't5-small', 't5-base', 't5-large', 't5-xl',
            },
            'generative': {
                'google/t5-v1_1-small', 'google/t5-v1_1-base',
                'google/t5-v1_1-large', 'google/t5-v1_1-xl',
            },
        }
        # Discriminative entries are merged last so they win if a name were
        # ever listed in both groups (same precedence as a membership test on
        # the discriminative set).
        self.tokenizer_func_dict = {
            **{name: self._generative_tokenize_func
               for name in self.model_types['generative']},
            **{name: self._discriminative_tokenize_func
               for name in self.model_types['discriminative']},
        }
        self.verbalizer_indices_dict = {}  # We need mask to select verbalizer tokens in T5

    def _load_tokenizer(self, model_name):
        """Load the fast tokenizer for ``model_name``, cache it and return it."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.tokenizer_dict[model_name] = tokenizer
        return tokenizer

    def _get_tokenize_func(self, model_name):
        """Return the tokenize function registered for ``model_name``.

        Models that are not listed in ``model_types`` (for example the
        ``'bert-base-uncased'`` default of the ``load_*`` methods) fall back to
        the discriminative function instead of raising ``KeyError``; this
        matches ``_process``, which only treats members of the generative set
        as generative.
        """
        return self.tokenizer_func_dict.get(model_name, self._discriminative_tokenize_func)

    @staticmethod
    def _load_json_dataset(dataset_path):
        """Load a JSON / JSON-lines file and return its ``'train'`` split."""
        return load_dataset('json', data_files=dataset_path)['train']

    @staticmethod
    def _load_tsv_dataset(dataset_path):
        """Load a tab-separated file, dropping every row that has a missing value."""
        df = pd.read_csv(dataset_path, sep='\t')
        return Dataset.from_pandas(df.dropna(axis=0).reset_index(drop=True))

    @staticmethod
    def _load_csv_dataset(dataset_path):
        """Load a comma-separated file, dropping every row that has a missing value."""
        df = pd.read_csv(dataset_path)
        return Dataset.from_pandas(df.dropna(axis=0).reset_index(drop=True))

    def _discriminative_tokenize_func(self, examples, tokenizer, max_length):
        """Tokenize a batch by passing each necessary column as a positional argument.

        Args:
            examples: Batch mapping column name -> list of values.
            tokenizer: Callable tokenizer.
            max_length: Truncation length.

        Returns:
            Whatever the tokenizer returns (no padding is applied here).
        """
        columns = [examples[item] for item in self.necessary_items]
        return tokenizer(*columns, padding=False, max_length=max_length, truncation=True)

    def _generative_tokenize_func(self, examples, tokenizer, max_length):
        """Tokenize a batch as ``"<item>: <value> <item>: <value> ..."`` prompts.

        When the batch labels are strings they are tokenized too and their
        ``input_ids`` are stored under ``'labels'`` in the result.
        """
        rows = zip(*[examples[item] for item in self.necessary_items])
        text = [
            ' '.join(f'{item}: {value}' for item, value in zip(self.necessary_items, row))
            for row in rows
        ]
        result = tokenizer(text, padding=False, max_length=max_length, truncation=True)
        labels = examples['labels']
        # Guard against an empty batch before peeking at the first label.
        if len(labels) > 0 and isinstance(labels[0], str):
            encoded = tokenizer(labels, padding=False, max_length=max_length, truncation=True)
            result['labels'] = encoded['input_ids']
        return result

    def _process(self, dataset, dataset_name, model_name):
        """Normalise column names and labels, then tokenize ``dataset``.

        Steps: add a ``data_idx`` column, rename the source columns to the
        names in ``necessary_items`` / ``'labels'``, map string labels to ids,
        convert labels to verbalizers for generative models (keeping the ids in
        ``restricted_labels``), and finally run the batched tokenizer while
        dropping every column that is not needed afterwards.

        The tokenizer for ``model_name`` must already be in ``tokenizer_dict``.
        """
        info = self.dataset_info[dataset_name]
        dataset = dataset.add_column('data_idx', list(range(len(dataset))))
        for item in [*self.necessary_items, 'labels']:
            source_column = info[item]
            if source_column != item:
                dataset = dataset.rename_column(source_column, item)
        # An empty dataset has no first row to inspect, so skip the label mapping.
        if len(dataset) > 0 and isinstance(dataset[0]['labels'], str):
            label_ids = info['label_name_to_label']
            dataset = dataset.map(lambda example: {'labels': label_ids[example['labels']]})
        # For T5, we need to convert labels from numbers to verbalizers
        if model_name in self.model_types['generative']:
            if model_name not in self.verbalizer_indices_dict:
                self._init_verbalizer_indices(model_name)
            # Verbalizers always come from the train dataset entry.
            verbalizers = self.dataset_info[self.train_dataset_name]['verbalizers']
            dataset = dataset.map(lambda example: {
                'restricted_labels': example['labels'],
                'labels': verbalizers[example['labels']],
            })

        kept_columns = {
            *self.necessary_items, 'labels', 'data_idx', 'restricted_labels',
            *info.get('additional_information', []),
        }
        extra_columns = [name for name in dataset.column_names if name not in kept_columns]
        tokenize = partial(
            self._get_tokenize_func(model_name),
            tokenizer=self.tokenizer_dict[model_name],
            max_length=self.max_length,
        )
        return dataset.map(tokenize, batched=True, remove_columns=extra_columns)

    def _init_verbalizer_indices(self, model_name):
        """Encode the train dataset's verbalizers and cache their first token ids.

        The result, stored in ``verbalizer_indices_dict[model_name]``, is
        ordered by ascending label id.
        """
        # Extract verbalizer tokens and encode them into indices
        verbalizers = self.dataset_info[self.train_dataset_name]['verbalizers']
        tokenizer = self.tokenizer_dict[model_name]
        # Optional: extract only the first (meaningful) verbalizer token
        self.verbalizer_indices_dict[model_name] = [
            tokenizer.encode(verbalizers[label_idx])[0] for label_idx in sorted(verbalizers)
        ]

    def _load_dataset(self, dataset_name, split, model_name):
        """Load, process and split one dataset into model inputs and side information.

        Returns:
            A 2-tuple ``(forward_dataset, info_dataset)``: the first has the
            raw text / additional-information columns removed, the second keeps
            only those columns plus ``data_idx``.
        """
        info = self.dataset_info[dataset_name]
        relative_path, dataset_type = info['dataset_path_type'][split]
        dataset = self.load_func_dict[dataset_type](os.path.join(self.data_dir, relative_path))
        if model_name not in self.tokenizer_dict:
            self._load_tokenizer(model_name)
        dataset = self._process(dataset, dataset_name, model_name)
        info_columns = [*self.necessary_items, *info.get('additional_information', [])]
        info_side = {*info_columns, 'data_idx'}
        forward_columns = [name for name in dataset.column_names if name not in info_side]
        return dataset.remove_columns(info_columns), dataset.remove_columns(forward_columns)

    def _load_split(self, split, dataset_names, model_name):
        """Load ``split`` for ``dataset_names`` (default: every dataset that has it)."""
        if dataset_names is None:
            dataset_names = [name for name, info in self.dataset_info.items()
                             if split in info['dataset_path_type']]
        return {name: self._load_dataset(name, split, model_name) for name in dataset_names}

    def load_train(self, model_name='bert-base-uncased'):
        """Return ``(forward_dataset, info_dataset)`` for the train split of the train dataset."""
        return self._load_dataset(self.train_dataset_name, 'train', model_name)

    def load_dev(self, dataset_names=None, model_name='bert-base-uncased'):
        """Return ``{dataset_name: (forward_dataset, info_dataset)}`` for the dev splits.

        With ``dataset_names=None`` every dataset that declares a ``'dev'``
        split is loaded.
        """
        return self._load_split('dev', dataset_names, model_name)

    def load_test(self, dataset_names=None, model_name='bert-base-uncased'):
        """Return ``{dataset_name: (forward_dataset, info_dataset)}`` for the test splits.

        With ``dataset_names=None`` every dataset that declares a ``'test'``
        split is loaded.
        """
        return self._load_split('test', dataset_names, model_name)


class NLIDataloader(Dataloader):
    """Loader for the natural language inference datasets, trained on MNLI.

    Required fields (after column renaming): premise, hypothesis, labels.
    """

    def __init__(self, data_dir='data/datasets/nli', max_length=None):
        """Register the NLI datasets and the MNLI-specific TSV loader.

        Args:
            data_dir: Directory that the dataset paths below are relative to.
            max_length: ``max_length`` handed to the tokenizer.
        """
        super().__init__(data_dir, max_length)
        self.train_dataset_name = 'mnli'
        self.task_name = 'nli'
        self.necessary_items = ['premise', 'hypothesis']
        self.dataset_info = {
            'mnli': {
                'dataset_path_type': {
                    'train': ('train.tsv', 'mnli_tsv'),
                    'dev': ('dev_matched.tsv', 'mnli_tsv'),
                },
                'premise': 'sentence1',
                'hypothesis': 'sentence2',
                'labels': 'gold_label',
                'label_name_to_label': {'entailment': 0, 'neutral': 1, 'contradiction': 2},
                'label_names': ['entailment', 'neutral', 'contradiction'],
                'verbalizers': {0: 'yes', 1: 'maybe', 2: 'no'},
            },
            'hans': {
                'dataset_path_type': {
                    'test': ('heuristics_evaluation_set.txt', 'tsv'),
                },
                'premise': 'sentence1',
                'hypothesis': 'sentence2',
                'labels': 'gold_label',
                'label_name_to_label': {'entailment': 0, 'non-entailment': 1},
                'label_names': ['entailment', 'non-entailment'],
                'additional_information': ['heuristic'],
            },
            'wanli': {
                'dataset_path_type': {
                    'test': ('wanli_test.jsonl', 'json'),
                },
                'premise': 'premise',
                'hypothesis': 'hypothesis',
                'labels': 'gold',
                'label_name_to_label': {'entailment': 0, 'neutral': 1, 'contradiction': 2},
                'label_names': ['entailment', 'neutral', 'contradiction'],
                'additional_information': ['genre'],
            },
            # The three ANLI rounds share one schema and differ only in paths.
            'anli_r1': self._anli_round_info('anli/R1/dev.jsonl', 'anli/R1/test.jsonl'),
            'anli_r2': self._anli_round_info('anli/R2/dev.jsonl', 'anli/R2/test.jsonl'),
            'anli_r3': self._anli_round_info('anli/R3/dev.jsonl', 'anli/R3/test.jsonl'),
        }
        self.load_func_dict['mnli_tsv'] = self._load_mnli_train_tsv_dataset

    @staticmethod
    def _anli_round_info(dev_path, test_path):
        """Build the ``dataset_info`` entry for one ANLI round from its two file paths."""
        return {
            'dataset_path_type': {
                'dev': (dev_path, 'json'),
                'test': (test_path, 'json'),
            },
            'premise': 'context',
            'hypothesis': 'hypothesis',
            'labels': 'label',
            'label_name_to_label': {'e': 0, 'n': 1, 'c': 2},
            'label_names': ['entailment', 'neutral', 'contradiction'],
            'additional_information': ['genre', 'reason', 'model_label'],
        }

    @staticmethod
    def _load_mnli_train_tsv_dataset(dataset_path):
        """Read an MNLI-style TSV file into a ``Dataset``.

        The first line is treated as a header and skipped. For every other
        non-blank line, tab-separated fields 8 and 9 (0-based) become
        ``sentence1`` / ``sentence2`` and the last field becomes ``gold_label``.

        Args:
            dataset_path: Path of the TSV file.

        Returns:
            A ``Dataset`` with columns ``sentence1``, ``sentence2``, ``gold_label``.

        Raises:
            IndexError: If a non-blank data line has fewer than 10 fields.
        """
        columns = {'sentence1': [], 'sentence2': [], 'gold_label': []}
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(dataset_path, encoding='utf-8') as fin:
            fin.readline()  # skip the header row
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                # Strip only the line ending: stripping tabs as well would shift
                # the columns when the final field is empty.
                fields = line.rstrip('\r\n').split('\t')
                columns['sentence1'].append(fields[8])
                columns['sentence2'].append(fields[9])
                columns['gold_label'].append(fields[-1].strip())
        return Dataset.from_pandas(pd.DataFrame(columns))


class TDDataloader(Dataloader):
    """Loader for the toxicity detection datasets, trained on CAD.

    Required fields: text, labels.
    """

    def __init__(self, data_dir='data/datasets/td', max_length=None):
        """Register the toxicity datasets and the headerless TSV loader.

        Args:
            data_dir: Directory that the dataset paths below are relative to.
            max_length: ``max_length`` handed to the tokenizer.
        """
        super().__init__(data_dir, max_length)
        self.load_func_dict['td_tsv'] = self._load_td_tsv
        self.task_name = 'td'
        self.train_dataset_name = 'cad'
        self.necessary_items = ['text']
        entry = self._binary_toxicity_entry
        self.dataset_info = {
            # Only the train dataset carries verbalizers.
            'cad': entry(
                {'train': 'cad.train', 'dev': 'cad.dev', 'test': 'cad.test'},
                verbalizers={0: 'no', 1: 'yes'},
            ),
            'gab': entry({'dev': 'gab.dev', 'test': 'gab.test'}),
            'stormfront': entry({'dev': 'stormfront.dev', 'test': 'stormfront.test'}),
            'dynahate_r2_original': entry({
                'dev': 'dynahate/r2_original.dev',
                'test': 'dynahate/r2_original.test',
            }),
            'dynahate_r2_perturbation': entry({
                'dev': 'dynahate/r2_perturbation.dev',
                'test': 'dynahate/r2_perturbation.test',
            }),
            'dynahate_r3_original': entry({
                'dev': 'dynahate/r3_original.dev',
                'test': 'dynahate/r3_original.test',
            }),
            'dynahate_r3_perturbation': entry({
                'dev': 'dynahate/r3_perturbation.dev',
                'test': 'dynahate/r3_perturbation.test',
            }),
            'dynahate_r4_original': entry({
                'dev': 'dynahate/r4_original.dev',
                'test': 'dynahate/r4_original.test',
            }),
            'dynahate_r4_perturbation': entry({
                'dev': 'dynahate/r4_perturbation.dev',
                'test': 'dynahate/r4_perturbation.test',
            }),
        }

    @staticmethod
    def _binary_toxicity_entry(split_files, verbalizers=None):
        """Build one ``dataset_info`` entry for a non-toxic / toxic dataset.

        Args:
            split_files: Mapping ``split name -> file path``; every file is
                registered with the ``'td_tsv'`` loader key.
            verbalizers: Optional ``label id -> word`` mapping; the
                ``'verbalizers'`` key is only added when this is given.
        """
        info = {
            'dataset_path_type': {
                split: (path, 'td_tsv') for split, path in split_files.items()
            },
            'text': 'text',
            'labels': 'labels',
            'label_name_to_label': {'non-toxic': 0, 'toxic': 1},
            'label_names': ['non-toxic', 'toxic'],
        }
        if verbalizers is not None:
            info['verbalizers'] = verbalizers
        return info

    @staticmethod
    def _load_td_tsv(dataset_path):
        """Read a headerless two-column TSV file (labels, then text) into a ``Dataset``."""
        frame = pd.read_csv(dataset_path, names=['labels', 'text'], sep='\t')
        return Dataset.from_pandas(frame)


# Registry: task identifier -> loader class that handles that task.
DATALOADER_DICT = {'nli': NLIDataloader, 'td': TDDataloader}