"""
Code adapted from the TRAK examples: https://github.com/MadryLab/trak/blob/main/examples/qnli.py

Model: bert-base-cased (https://huggingface.co/bert-base-cased)

Tokenizers and loaders are adapted from the Hugging Face example
(https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification).
"""

from torch.utils.data import Dataset, DataLoader

# Huggingface
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    default_data_collator,
)

# Input-text column names for each GLUE task, as (first, second) keys.
# The second key is None for single-sentence tasks.
GLUE_TASK_TO_KEYS = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# NOTE: get_dataset() keeps only the first TRAIN_SET_SIZE training and
# VAL_SET_SIZE validation examples -- change these to run on the full dataset.
TRAIN_SET_SIZE = 50000
VAL_SET_SIZE = 5463

def get_dataset(split, inds=None):
    """Load, truncate and tokenize one split of GLUE QNLI.

    The first ``TRAIN_SET_SIZE`` training (or ``VAL_SET_SIZE`` validation)
    examples are kept, clamped to the split's actual length; a size of
    ``None`` keeps the whole split. Each (question, sentence) pair is then
    tokenized with the ``bert-base-cased`` tokenizer, padded and truncated
    to 128 tokens. Only the requested split is tokenized.

    Args:
        split: ``'train'`` selects the training split; any other value
            (this file uses ``'val'`` and ``'validation'``) selects the
            validation split.
        inds: Optional iterable of integer indices into the truncated
            split. When given, only those rows are returned, in that order.

    Returns:
        The tokenized split as returned by ``datasets`` ``map``/``select``.
    """
    raw_datasets = load_dataset(
        "glue",
        "qnli",
        cache_dir=None,
        token=None,
    )

    # Pick the requested split up front so only that split gets tokenized.
    if split == "train":
        ds = raw_datasets["train"]
        limit = TRAIN_SET_SIZE
    else:
        ds = raw_datasets["validation"]
        limit = VAL_SET_SIZE
    if limit is not None:
        # Clamp so a size larger than the split does not raise in select().
        ds = ds.select(range(min(limit, len(ds))))

    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS["qnli"]

    tokenizer = AutoTokenizer.from_pretrained(
        "bert-base-cased",
        cache_dir=None,
        use_fast=True,
        revision="main",
        token=None,
    )
    padding = "max_length"
    max_seq_length = 128

    def preprocess_function(examples):
        # sentence2_key is None only for single-sentence tasks; QNLI is a pair task.
        if sentence2_key is None:
            args = (examples[sentence1_key],)
        else:
            args = (examples[sentence1_key], examples[sentence2_key])
        # GLUE labels are already class ids, so no label remapping is done.
        return tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

    ds = ds.map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

    if inds is not None:
        ds = ds.select(inds)
    return ds


def init_loaders(batch_size=16):
    """Build the QNLI train and validation DataLoaders.

    Args:
        batch_size: Number of examples per batch, used for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Both loaders use
        ``shuffle=False``, so batches follow dataset order, and both collate
        with ``default_data_collator``.
    """
    # get_dataset() already truncates each split to TRAIN_SET_SIZE /
    # VAL_SET_SIZE, so re-selecting the same range here would be redundant.
    ds_train = get_dataset('train')
    ds_val = get_dataset('val')

    def make_loader(ds):
        return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)

    return make_loader(ds_train), make_loader(ds_val)



class QNLI(Dataset):
    """Map-style ``Dataset`` wrapper around one tokenized QNLI split.

    ``len`` and indexing delegate to the object returned by ``get_dataset``.
    """

    def __init__(self, train=True):
        # get_dataset() routes every split name other than 'train' to validation.
        split_name = 'validation'
        if train:
            split_name = 'train'
        self.name = 'qnli'
        self.data = get_dataset(split_name)

    def __len__(self):
        """Return the number of examples in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example(s) at ``idx`` from the wrapped split."""
        item = self.data[idx]
        return item

    def select(self, indices):
        """Return ``self.data.select(indices)``.

        The result is the wrapped dataset's own subset object, not a ``QNLI``.
        """
        subset = self.data.select(indices)
        return subset


if __name__ == "__main__":
    def test():
        """Smoke-test init_loaders(): both loaders must yield a tokenized batch."""
        loader_train, loader_val = init_loaders()
        for name, loader in (("train", loader_train), ("val", loader_val)):
            # Guard first: next(iter(...)) on an empty loader raises StopIteration.
            assert len(loader.dataset) > 0, f"{name} dataset is empty"
            batch = next(iter(loader))
            assert "input_ids" in batch, f"{name} batch missing input_ids: {sorted(batch)}"

    print("Running test...")
    test()
    print("Test passed")

