"""The ANLI dataset."""

import tensorflow as tf
import tensorflow_datasets as tfds
# from transformers.data.processors import glue as hf_glue

from . import glue

###############################################################################

# Number of training examples in each ANLI round, keyed by task name.
NUM_ANLI_TRAIN_EXAMPLES = dict(
    r1=16_946,
    r2=45_460,
    r3=100_459,
)


# Valid ``task`` values. Derived from the table above (in insertion order)
# so the list of rounds and their sizes cannot drift apart.
ANLI_TASK_NAMES = tuple(NUM_ANLI_TRAIN_EXAMPLES)

###############################################################################


def _rekey_to_be_like_mnli(x):
    ret = x.copy()

    ret['premise'] = x['context']
    del ret['context']

    ret['idx'] = x['uid']
    del ret['uid']

    return ret


def rekey_to_be_like_mnli(ds: tf.data.Dataset) -> tf.data.Dataset:
    """Rename the keys of every ANLI example so they match MNLI's names.

    Maps ``_rekey_to_be_like_mnli`` over ``ds``: ``context`` becomes
    ``premise`` and ``uid`` becomes ``idx``; other keys pass through.
    """
    per_example = _rekey_to_be_like_mnli
    return ds.map(per_example)


###############################################################################

def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
):
    """Load one ANLI round from TFDS and featurize it via the MNLI path.

    Args:
        task: ANLI round name; must be one of ``ANLI_TASK_NAMES``.
        split: Split specifier forwarded unchanged to ``tfds.load``.
        tokenizer: Forwarded to ``glue.convert_dataset_to_features``.
        sequence_length: Forwarded to ``glue.convert_dataset_to_features``.

    Returns:
        Whatever ``glue.convert_dataset_to_features(..., task='mnli')``
        produces for the rekeyed ANLI dataset.

    Raises:
        ValueError: If ``task`` is not a known ANLI round.
    """
    if task not in ANLI_TASK_NAMES:
        raise ValueError(f'Invalid anli task: {task}')

    tfds_name = f"anli/{task}"
    # After renaming context->premise and uid->idx, ANLI examples are
    # featurized with the exact same code path as MNLI.
    mnli_like = rekey_to_be_like_mnli(tfds.load(tfds_name, split=split))
    return glue.convert_dataset_to_features(
        mnli_like, tokenizer, sequence_length, task='mnli')


def n_classes_for_task(task: str) -> int:
    """Return the number of label classes for an ANLI task.

    The answer does not depend on ``task``: every round (and any other
    value passed in) is reported as a three-class problem.
    """
    del task  # Unused; the class count is the same for all rounds.
    return 3


def de_facto_validation_split(task):
    """Return the split name this module treats as validation data.

    The same split name, ``'validation'``, is returned for every ``task``.
    """
    del task  # Unused; all rounds use the same split name.
    return 'validation'


def examples_per_epoch(task):
    """Return the number of training examples for the given ANLI round.

    Looked up in ``NUM_ANLI_TRAIN_EXAMPLES``.

    Raises:
        KeyError: If ``task`` is not a key of ``NUM_ANLI_TRAIN_EXAMPLES``.
    """
    n_train = NUM_ANLI_TRAIN_EXAMPLES[task]
    return n_train
