"""The PAWS dataset."""
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import PreTrainedTokenizer

from . import glue

# Training-set size for each supported PAWS task; this is the table that
# `examples_per_epoch` reads from.
PAWS_TRAIN_SET_SIZES = dict(
    final=49_401,
    swap=30_397,
    noisy=645_652,
)

# Task names accepted by `load`, in the same order as the size table above.
PAWS_TASK_NAMES = tuple(PAWS_TRAIN_SET_SIZES)

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Load one PAWS task and convert it to model input features.

    The examples are fetched from the `paws_wiki` TFDS dataset, renamed to
    the QQP field layout, and then passed through
    `glue.convert_dataset_to_features` as if they were QQP examples (PAWS is
    a paraphrase detection dataset, so the QQP handling is reused).

    Args:
        task: One of `PAWS_TASK_NAMES`.
        split: Split specifier forwarded unchanged to `tfds.load`.
        tokenizer: Tokenizer forwarded to the GLUE feature conversion.
        sequence_length: Sequence length forwarded to the GLUE feature
            conversion.

    Returns:
        Whatever `glue.convert_dataset_to_features` returns for the
        QQP-style dataset.

    Raises:
        ValueError: If `task` is not in `PAWS_TASK_NAMES`.
    """
    if task not in PAWS_TASK_NAMES:
        raise ValueError(f'Invalid paws task: {task}')

    tfds_name = f"paws_wiki/{_to_tfds_task_name(task)}"
    qqp_examples = tfds.load(tfds_name, split=split).map(_to_qqp_style)

    # PAWS is a paraphrase detection dataset; treat as QQP.
    return glue.convert_dataset_to_features(
        qqp_examples,
        tokenizer,
        sequence_length,
        task='qqp',
    )


def n_classes_for_task(task: str):
    """Return the number of label classes for a PAWS task.

    The result does not depend on `task`: every task reports two classes.
    """
    n_classes = 2
    return n_classes


def de_facto_validation_split(task):
    """Return the name of the split to use for validation.

    The same split name, 'validation', is returned regardless of `task`.
    """
    split_name = 'validation'
    return split_name


def examples_per_epoch(task):
    """Return the training-set size recorded for `task`.

    The value is looked up in `PAWS_TRAIN_SET_SIZES`; a task missing from
    that table raises `KeyError`.
    """
    train_size = PAWS_TRAIN_SET_SIZES[task]
    return train_size


###############################################################################

def _to_tfds_task_name(task: str):
    if task == 'final':
        return 'labeled_final_tokenized'
    elif task == 'swap':
        return 'labeled_swap_tokenized'
    elif task == 'noisy':
        return 'unlabeled_final_tokenized'
    else:
        raise ValueError(task)


def _to_qqp_style(x):
    """Rename one PAWS example's fields to the QQP field names.

    `sentence1`/`sentence2` become `question1`/`question2`, `label` is passed
    through unchanged, and `idx` is filled with a constant int32 zero for
    every example.
    """
    placeholder_idx = tf.constant(0, dtype=tf.int32)
    return dict(
        idx=placeholder_idx,
        label=x['label'],
        question1=x['sentence1'],
        question2=x['sentence2'],
    )
