"""The CIFAR10 dataset.

We also have a binarized classification task that predicts
whether the image is of a vehicle or an animal. Label 0
denotes vehicle, label 1 denotes animal.
"""
import tensorflow as tf
import tensorflow_datasets as tfds


###############################################################################

# Task names accepted by `load`. 'default' leaves the dataset labels untouched;
# 'binarized' additionally maps every label through `_binarize`.
CIFAR10_TASK_NAMES = ('default', 'binarized')

# Original label indices that `_binarize` maps to 0 (vehicle); every other
# label is mapped to 1 (animal).
# NOTE(review): presumably airplane, automobile, ship and truck in the TFDS
# cifar10 label ordering -- confirm against the dataset's label names.
_VEHICLE_LABEL_INDICES = (0, 1, 8, 9)

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
):
    """Build the CIFAR10 dataset for the given task and split.

    Args:
        task: One of `CIFAR10_TASK_NAMES`.
        split: Split specification forwarded unchanged to `tfds.load`.
        tokenizer: Ignored; discarded without being read.
        sequence_length: Ignored; discarded without being read.

    Returns:
        A dataset of `(image, label)` pairs as produced by `_preprocess`.
        For the 'binarized' task the labels are further mapped through
        `_binarize`.

    Raises:
        ValueError: If `task` is not in `CIFAR10_TASK_NAMES`.
    """
    if task not in CIFAR10_TASK_NAMES:
        raise ValueError(f'Invalid cifar10 task: {task}')

    # Neither argument is used by this loader; drop them so that any
    # accidental use below fails loudly.
    del tokenizer, sequence_length

    dataset = tfds.load("cifar10", split=split).map(_preprocess)

    # Only the binarized task rewrites the labels.
    return dataset.map(_binarize) if task == 'binarized' else dataset


def n_classes_for_task(task: str) -> int:
    """Return the number of label classes for `task`.

    The 'binarized' task has 2 classes; any other task name yields 10
    (the name is not validated here).
    """
    return 2 if task == 'binarized' else 10


def de_facto_validation_split(task):
    """Return the name of the split to use for validation.

    The answer is 'test' for every task; `task` is accepted but not
    consulted.
    """
    # NOTE(review): presumably 'test' stands in because the dataset ships no
    # dedicated validation split -- confirm.
    del task
    split_name = 'test'
    return split_name


def examples_per_epoch(task):
    """Return the number of examples counted as one epoch.

    The value is the same for every task; `task` is accepted but not
    consulted.
    """
    # NOTE(review): 60_000 looks like the train and test splits combined. If
    # an epoch is meant to cover only the training split this may need to be
    # smaller -- confirm against how callers use this value.
    del task
    n_examples = 60_000
    return n_examples


###############################################################################

def _preprocess(ex):
    """Convert a raw example dict into an `(image, label)` pair.

    The 'image' entry is cast to float32 and divided by 255.0; the 'label'
    entry is passed through unchanged.
    """
    # NOTE(review): assumes pixel values are in [0, 255], so the result lies
    # in [0, 1] -- confirm against the dataset's image dtype.
    image, label = ex['image'], ex['label']
    pixels = tf.cast(image, tf.float32)
    return pixels / 255.0, label


@tf.function
def _binarize(x, y):
    """Collapse the label `y` into a binary vehicle/animal label.

    Returns `(x, label)` where `label` is an int64 0 when `y` equals any of
    `_VEHICLE_LABEL_INDICES`, and an int64 1 otherwise. `x` is passed through
    unchanged.
    """
    # One boolean per vehicle index, reduced to a single "is a vehicle" flag.
    vehicle_matches = [tf.equal(y, idx) for idx in _VEHICLE_LABEL_INDICES]
    is_vehicle = tf.reduce_any(vehicle_matches)

    # Select between the two constant labels instead of branching.
    vehicle_label = tf.constant(0, dtype=tf.int64)
    animal_label = tf.constant(1, dtype=tf.int64)
    return x, tf.where(is_vehicle, vehicle_label, animal_label)
