"""The SciTail dataset is an entailment dataset created from multiple-choice science exams and web sentences."""
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import PreTrainedTokenizer

from . import glue

# Task names accepted by `load`; only the single 'default' task is supported.
SCITAIL_TASK_NAMES = ('default',)
# Backward-compatible alias. The name presumably dates from a copy of a PAWS
# loader; prefer SCITAIL_TASK_NAMES in new code.
PAWS_TASK_NAMES = SCITAIL_TASK_NAMES

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer: 'PreTrainedTokenizer',
    sequence_length: int,
):
    """Load a SciTail split and convert it to model input features.

    Examples are renamed to the RTE field layout and then passed through the
    GLUE feature converter with ``task='rte'``. SciTail, like RTE, has two
    labels rather than three.

    Args:
        task: Task name; must be one of ``PAWS_TASK_NAMES``.
        split: Split specification forwarded unchanged to ``tfds.load``
            (e.g. ``'validation'``).
        tokenizer: Tokenizer forwarded to ``glue.convert_dataset_to_features``.
        sequence_length: Forwarded to ``glue.convert_dataset_to_features``;
            presumably the maximum tokenized length (TODO confirm).

    Returns:
        The result of ``glue.convert_dataset_to_features`` for the mapped
        dataset (presumably a ``tf.data.Dataset`` of features).

    Raises:
        ValueError: If ``task`` is not a recognized task name.
    """
    if task not in PAWS_TASK_NAMES:
        # Keep the original message prefix and list the accepted names so the
        # caller can see how to fix the call.
        raise ValueError(
            f'Invalid task: {task}. Valid tasks: {PAWS_TASK_NAMES}'
        )

    ds = tfds.load("sci_tail", split=split)
    ds = ds.map(_to_rte_style)
    ds = glue.convert_dataset_to_features(
        ds,
        tokenizer,
        sequence_length,
        # Treat as RTE since SciTail has only 2 labels instead of 3.
        task='rte',
    )
    return ds


def n_classes_for_task(task: str):
    """Return the number of label classes for a SciTail task.

    SciTail is a two-label dataset, so the answer is the same for every task
    name. ``task`` is not inspected; presumably it is accepted so that all the
    per-task helpers share one signature.
    """
    del task  # Unused: a single label space covers every task.
    num_labels = 2
    return num_labels


def de_facto_validation_split(task):
    """Return the name of the split to use for validation.

    The same split name is returned for every task; ``task`` is not inspected.
    """
    del task  # Unused; kept for a uniform signature across task helpers.
    split_name = 'validation'
    return split_name


def examples_per_epoch(task):
    """Return the number of examples that make up one epoch.

    The count is a fixed constant and does not depend on ``task``.
    """
    del task  # Unused.
    # NOTE(review): presumably the size of the SciTail training split; confirm.
    return 23097


###############################################################################


def _to_rte_style(x):
    """Rename a SciTail example's fields to the RTE feature layout.

    Args:
        x: One example mapping with ``'label'``, ``'premise'`` and
            ``'hypothesis'`` entries.

    Returns:
        A dict with keys ``'idx'``, ``'label'``, ``'sentence1'`` (from
        ``'premise'``) and ``'sentence2'`` (from ``'hypothesis'``).
    """
    # Every example receives the same placeholder index of 0 (int32). Nothing
    # here derives a per-example id.
    placeholder_idx = tf.constant(0, dtype=tf.int32)
    rte_example = dict(
        idx=placeholder_idx,
        label=x['label'],
        sentence1=x['premise'],
        sentence2=x['hypothesis'],
    )
    return rte_example
