"""The WinoGrande dataset.


Examples like:
    Option1: Kyle
    Option2: Hunter
    Sentence: Kyle could not sleep because of the noise made by Hunter and _ was angry.
get converted to a random choice of:
    Version1: [CLS] Kyle could not sleep because of the noise made by Hunter and _ was angry. [SEP] Hunter
    Version2: [CLS] Kyle could not sleep because of the noise made by Hunter and _ was angry. [SEP] Kyle

Validation and test splits will contain exactly one copy of both to remain deterministic.


For the bert-base-uncased tokenizer, this dataset requires a minimum sequence length of 47.

"""
import random
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import PreTrainedTokenizer


# Number of `train_xl` examples set aside for the 'custom' task's heldout split.
CUSTOM_HELDOUT_SIZE = 10_000

# Training-set size for each task, keyed by task name.
NUM_WINOGRANDE_TRAIN_EXAMPLES = dict(
    xs=160,
    s=640,
    m=2_558,
    l=10_234,
    xl=40_398,
)
# The 'custom' task has train, dev, test, and heldout splits; its train split
# is whatever remains of the 'xl' training set once the heldout examples
# have been removed.
NUM_WINOGRANDE_TRAIN_EXAMPLES['custom'] = (
    NUM_WINOGRANDE_TRAIN_EXAMPLES['xl'] - CUSTOM_HELDOUT_SIZE
)

# Valid task names, in the order they are declared above.
WINOGRANDE_TASK_NAMES = tuple(NUM_WINOGRANDE_TRAIN_EXAMPLES)

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    force_deterministic: bool = False,
) -> tf.data.Dataset:
    """Build the tokenized WinoGrande dataset for a task and split.

    Args:
        task: One of `WINOGRANDE_TASK_NAMES`.
        split: Name of the split to load. For non-custom tasks, 'train' is
            rewritten to 'train_<task>'; any other value is passed through.
        tokenizer: Tokenizer used to encode each (sentence, option) pair.
        sequence_length: Length every encoded sequence is truncated/padded to.
        force_deterministic: When true, the deterministic conversion is used
            even for the 'train' split.

    Returns:
        The dataset produced by the underlying loader for the resolved split.

    Raises:
        ValueError: If `task` is not in `WINOGRANDE_TASK_NAMES`.
    """
    if task not in WINOGRANDE_TASK_NAMES:
        raise ValueError(f'Invalid task: {task}')

    shared_kwargs = dict(
        tokenizer=tokenizer,
        sequence_length=sequence_length,
        # Anything other than the 'train' split is always deterministic.
        deterministic=force_deterministic or split != 'train',
    )

    if task == 'custom':
        return load_custom_winogrande_dataset(split=split, **shared_kwargs)

    # Tasks vs. splits are treated a bit oddly here, mostly to stay (sort of)
    # consistent with other datasets: the task picks which training subset
    # the 'train' split refers to.
    resolved_split = f'train_{task}' if split == 'train' else split

    return load_winogrande_dataset(split=resolved_split, **shared_kwargs)


def n_classes_for_task(task: str):
    """Return the number of label classes, which is 2 regardless of `task`."""
    num_classes = 2
    return num_classes


def de_facto_validation_split(task):
    """Return the split name used for validation; the same for every `task`."""
    validation_split_name = 'validation'
    return validation_split_name


def examples_per_epoch(task):
    """Look up the number of training examples for `task`.

    Raises:
        KeyError: If `task` is not a key of `NUM_WINOGRANDE_TRAIN_EXAMPLES`.
    """
    epoch_size = NUM_WINOGRANDE_TRAIN_EXAMPLES[task]
    return epoch_size


###############################################################################


def convert_dataset_to_features(
    ds: tf.data.Dataset,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    deterministic: bool = False,
) -> tf.data.Dataset:
    """Tokenize raw WinoGrande examples into (features, label) pairs.

    Note that this is only for single examples; won't work with batched inputs.
    Each element of `ds` must be a dict with the keys 'sentence', 'option1',
    'option2' and 'label'.

    Every example is encoded as the sentence paired with one of its options.
    The new label is 1 when the index of the paired option equals the original
    label and 0 otherwise; original labels outside {0, 1} are passed through
    unchanged (presumably unlabeled examples -- confirm against the TFDS
    schema).

    Args:
        ds: Dataset of single, unbatched raw examples.
        tokenizer: Tokenizer used to encode each (sentence, option) pair.
        sequence_length: Length each encoding is truncated to and then
            right-padded to.
        deterministic: When true, each example yields two outputs, one per
            option, in a fixed order. When false, each example yields one
            output for a randomly picked option.

    Returns:
        A dataset of `(features, label)` where `features` has int32
        'input_ids' and 'token_type_ids' entries of shape `[sequence_length]`
        and `label` is a scalar int64.
    """

    pad_token = tokenizer.pad_token_id
    pad_token_type_id = tokenizer.pad_token_type_id

    copies_per_example = 2 if deterministic else 1

    def mutative_pad_list(x, pad_token):
        # Right-pads `x` in place up to `sequence_length`.
        padding_length = sequence_length - len(x)
        x.extend(padding_length * [pad_token])

    def py_map_fn(sentence, option1, option2, label):
        # Runs eagerly inside tf.py_function, so `.numpy()` is available.
        label = int(label.numpy())
        sentence = tf.compat.as_str(sentence.numpy())
        option1 = tf.compat.as_str(option1.numpy())
        option2 = tf.compat.as_str(option2.numpy())

        if deterministic:
            choices = range(2)
        else:
            choices = [random.randrange(2)]

        input_ids = []
        token_type_ids = []
        new_labels = []

        for choice in choices:
            chosen_option = (option1, option2)[choice]
            # Binary target: does the paired option match the original label?
            # Labels outside {0, 1} are kept as they are.
            new_label = int(choice == label) if label in range(2) else label

            inputs = tokenizer.encode_plus(
                sentence,
                chosen_option,
                add_special_tokens=True,
                max_length=sequence_length,
                return_token_type_ids=True,
                truncation=True,
            )

            mutative_pad_list(inputs["input_ids"], pad_token)
            mutative_pad_list(inputs["token_type_ids"], pad_token_type_id)

            input_ids.append(inputs["input_ids"])
            token_type_ids.append(inputs["token_type_ids"])
            new_labels.append(int(new_label))

        input_ids = tf.stack(input_ids, axis=0)
        token_type_ids = tf.stack(token_type_ids, axis=0)
        new_labels = tf.constant(new_labels, dtype=tf.int64)

        return input_ids, token_type_ids, new_labels

    def map_fn(example):
        input_ids, token_type_ids, label = tf.py_function(
            func=py_map_fn,
            inp=[example['sentence'], example['option1'], example['option2'], example['label']],
            Tout=[tf.int32, tf.int32, tf.int64],
        )
        # tf.py_function outputs carry no static shape information, so every
        # output is reshaped to its known shape. This includes the label, so
        # that labels are scalars with a known shape once unbatched.
        features = {
            'input_ids': tf.reshape(input_ids, [copies_per_example, sequence_length]),
            'token_type_ids': tf.reshape(token_type_ids, [copies_per_example, sequence_length]),
        }
        label = tf.reshape(label, [copies_per_example])
        return features, label

    ds = ds.map(map_fn)
    # Each mapped element holds `copies_per_example` rows; split them into
    # individual examples.
    ds = ds.unbatch()
    return ds


def load_custom_winogrande_dataset(
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    deterministic: bool = False,
):
    """Load a split of the 'custom' task.

    The 'heldout' split is the first `CUSTOM_HELDOUT_SIZE` examples of
    `train_xl` and the 'train' split is the remainder of `train_xl`. Any
    other split name is forwarded unchanged to `load_winogrande_dataset`.
    """
    custom_slices = {
        'heldout': f'train_xl[:{CUSTOM_HELDOUT_SIZE}]',
        'train': f'train_xl[{CUSTOM_HELDOUT_SIZE}:]',
    }
    resolved_split = custom_slices.get(split, split)

    return load_winogrande_dataset(
        split=resolved_split,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
        deterministic=deterministic,
    )


def load_winogrande_dataset(
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    deterministic: bool = False,
):
    """Load a "winogrande" split from TFDS and convert it to model features.

    The special split name 'validation_test' yields the converted
    'validation' split followed by the converted 'test' split. Any other
    value is handed directly to `tfds.load`.
    """
    if split != 'validation_test':
        raw_ds = tfds.load("winogrande", split=split)
        return convert_dataset_to_features(
            raw_ds,
            tokenizer,
            sequence_length,
            deterministic=deterministic,
        )

    # Load validation first, then test, and chain them in that order.
    val_ds, test_ds = [
        load_winogrande_dataset(
            split=part,
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            deterministic=deterministic,
        )
        for part in ('validation', 'test')
    ]
    return val_ds.concatenate(test_ds)
