"""Module-level stuff for the datasets code."""
from typing import Union

import tensorflow as tf

from . import anli
from . import annotated_anli
from . import cifar10
from . import glue
from . import hans
from . import math_dataset
from . import paws
from . import sci_tail
from . import snli
from . import winogrande
from .antiderivative import antiderivative_ds
from .imagenet import imagenet
from .protein import signal_peptide

from . import common_processing

###############################################################################


def load(
    task_ri: str,
    split: str,
    tokenizer,
    sequence_length: int,
    *extra_args,
    **extra_kwargs,
) -> tf.data.Dataset:
    """Loads a dataset by delegating to the `load` of the task's group.

    The group part of `task_ri` picks the implementation; the remaining part
    is passed on as `task`, together with `split`, `tokenizer` and
    `sequence_length` (all four as keyword arguments). Any extra positional
    and keyword arguments are forwarded unchanged.

    Raises a ValueError if the group part of `task_ri` is not recognised.
    """
    group_name, task_name = split_task_ri(task_ri)
    group_loader = _get_group_module(group_name).load
    # NOTE(review): the extras in `extra_args` are passed positionally while
    # the four named values go by keyword. That only works if the group's
    # `load` accepts positional arguments that do not bind to those same
    # names -- TODO confirm against the group modules.
    return group_loader(
        *extra_args,
        task=task_name,
        split=split,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
        **extra_kwargs,
    )


def n_classes_for_task(task_ri: str) -> int:
    """Returns what the task's group reports from `n_classes_for_task`."""
    group_name, task_name = split_task_ri(task_ri)
    group_module = _get_group_module(group_name)
    return group_module.n_classes_for_task(task_name)


def de_facto_validation_split(task_ri: str) -> Union[str, None]:
    """Returns the name of the split conventionally used for validation.

    Typically that is the "validation" or "test" split. None is returned for
    tasks that have no such split. The answer comes from the implementation
    of the task's group.
    """
    group_name, task_name = split_task_ri(task_ri)
    group_module = _get_group_module(group_name)
    return group_module.de_facto_validation_split(task_name)


def examples_per_epoch(task_ri: str) -> Union[int, None]:
    """Returns the number of examples in each epoch of training.

    The value is obtained from the `examples_per_epoch` function of the
    task's group.

    Returns None if this is not supported.
    """
    # The return annotation is an int count (or None), matching the
    # documented contract; it was previously annotated as a str.
    group, task = split_task_ri(task_ri)
    return _get_group_module(group).examples_per_epoch(task)


###############################################################################

def infer_sequence_length(ds: tf.data.Dataset) -> Union[int, None]:
    """Returns the last dimension of the dataset's 'input_ids' spec.

    The dataset's `element_spec` must be a 2-tuple; only its first entry is
    inspected. None is returned for datasets that do not appear to be text
    datasets, i.e. when that first entry is not a dict with an 'input_ids'
    key.
    """
    element_spec = ds.element_spec
    assert isinstance(element_spec, tuple) and len(element_spec) == 2
    head_spec = element_spec[0]
    if isinstance(head_spec, dict) and 'input_ids' in head_spec:
        return head_spec['input_ids'].shape[-1]
    return None


###############################################################################

def split_task_ri(task_ri: str):
    """Splits a task resource indicator ("ri") into a (group, task) pair.

    The group is the text before the first '/'. The task is everything after
    that first '/', with any further '/' characters kept as they are. When
    there is no '/', the task is the empty string.
    """
    group, _separator, task = task_ri.partition('/')
    return group, task


def _get_group_module(group: str):
    """Returns the dataset implementation registered for a task group.

    Raises a ValueError if `group` is not one of the known group names.
    """
    # Every key matches the name of the imported object except 'ead', which
    # is served by `antiderivative_ds`.
    implementations = {
        'anli': anli,
        'annotated_anli': annotated_anli,
        'cifar10': cifar10,
        'glue': glue,
        'hans': hans,
        'imagenet': imagenet,
        'math_dataset': math_dataset,
        'winogrande': winogrande,
        'ead': antiderivative_ds,
        'paws': paws,
        'sci_tail': sci_tail,
        'signal_peptide': signal_peptide,
        'snli': snli,
    }
    if group not in implementations:
        raise ValueError(f'Invalid task group: {group}')
    return implementations[group]
