"""Sub-college level math database.
This dataset code generates mathematical question and answer pairs, from a range of
question types at roughly school-level difficulty. This is designed to test the
mathematical learning and algebraic reasoning skills of learning models.

Links:
    https://www.tensorflow.org/datasets/catalog/math_dataset
    https://github.com/deepmind/mathematics_dataset

"""
import multiprocessing

import tensorflow as tf
import tensorflow_datasets as tfds

from em.datasets import common_processing as processing


# Task names defined for this dataset module. `load` and `n_classes_for_task`
# accept 'original_true_false' and raise ValueError for anything else.
MATH_DATASET_TASK_NAMES = (
    'original_true_false',
)


# Names of every builder config exposed by tfds.text.MathDataset. Evaluated at
# import time; used by the multiprocessing utilities further down in this file.
CONFIG_NAMES = [c.name for c in tfds.text.MathDataset.BUILDER_CONFIGS]

# Configs that `load_true_false_only` and `count_number_of_og_tf_examples` read
# from (after filtering to examples whose answer is 'True' or 'False').
# NOTE(review): presumably produced by running
# `see_which_configs_have_true_false()` -- confirm before editing by hand.
TRUE_FALSE_CONFIGS = [
    'comparison__pair',
    'comparison__pair_composed',
    'numbers__is_factor',
    'numbers__is_factor_composed',
    'numbers__is_prime',
    'numbers__is_prime_composed',
]

###############################################################################


def load(
    task: str,
    split: str,
    tokenizer,
    sequence_length: int,
):
    """Build the dataset for a named math_dataset task.

    Args:
        task: Task name; only 'original_true_false' is recognised.
        split: Split name, forwarded unchanged to the task loader.
        tokenizer: Tokenizer object, forwarded unchanged to the task loader.
        sequence_length: Token sequence length, forwarded unchanged.

    Returns:
        Whatever `load_true_false_only` returns for the given arguments.

    Raises:
        ValueError: If `task` is not a recognised task name.
    """
    # Reject unknown tasks up front so the happy path reads straight through.
    if task != 'original_true_false':
        raise ValueError(f'Invalid math_dataset task: {task}')

    return load_true_false_only(
        split=split,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
    )


def n_classes_for_task(task: str):
    """Return the number of label classes for a math_dataset task.

    Args:
        task: Task name; only 'original_true_false' is recognised.

    Returns:
        2 for 'original_true_false'.

    Raises:
        ValueError: If `task` is not a recognised task name.
    """
    if task != 'original_true_false':
        raise ValueError(f'Invalid math_dataset task: {task}')
    # The true/false task has two labels.
    return 2


def de_facto_validation_split(task):
    """Return the name of the split to use for validation.

    `task` is accepted but not inspected: the answer is 'test' for every task.
    """
    validation_split = 'test'
    return validation_split


def examples_per_epoch(task):
    """Return the number of examples per epoch for `task`.

    Always returns None, regardless of `task`.
    NOTE(review): presumably None tells callers that no fixed epoch size is
    defined for this dataset -- confirm against the callers.
    """
    del task  # Unused: the result does not depend on the task.
    return None


###############################################################################


def _is_example_true_false(x):
    return (x['answer'] == 'True') | (x['answer'] == 'False')


def load_true_false_only(split: str, tokenizer, sequence_length: int):
    """Build a tokenized dataset of the True/False math questions.

    Each config in `TRUE_FALSE_CONFIGS` is loaded for `split`, filtered down to
    examples whose answer is 'True' or 'False', and the per-config datasets are
    combined with `sample_from_datasets` (no weights or seed are passed).

    Args:
        split: Split name passed to `tfds.load`.
        tokenizer: Object exposing `encode_plus(...)` whose result can be
            indexed by "input_ids" and "token_type_ids".
            NOTE(review): presumably a HuggingFace tokenizer -- confirm.
        sequence_length: Length every question is padded/truncated to.

    Returns:
        A dataset of `(features, label)` pairs. `features` is a dict with
        int32 "input_ids" and "token_type_ids", each reshaped to
        `[sequence_length]`; `label` is int64, 1 when the answer is 'True'
        and 0 otherwise.
    """
    per_config_datasets = []
    for config_name in TRUE_FALSE_CONFIGS:
        config_ds = tfds.load(f'math_dataset/{config_name}', split=split)
        per_config_datasets.append(config_ds.filter(_is_example_true_false))
    mixed = tf.data.experimental.sample_from_datasets(per_config_datasets)

    def _tokenize(question):
        # Runs eagerly inside tf.py_function, so `.numpy()` is available here.
        text = tf.compat.as_str(question.numpy())
        encoded = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=sequence_length,
            return_token_type_ids=True,
            truncation=True,
            padding='max_length',
            return_tensors='tf',
        )
        return encoded["input_ids"], encoded["token_type_ids"]

    def _to_features_and_label(example):
        ids, type_ids = tf.py_function(
            func=_tokenize,
            inp=[example['question']],
            Tout=[tf.int32, tf.int32],
        )
        # Reshape so the static shape is known; downstream steps often need it.
        features = {
            "input_ids": tf.reshape(ids, [sequence_length]),
            "token_type_ids": tf.reshape(type_ids, [sequence_length]),
        }
        label = tf.cast(example['answer'] == 'True', tf.int64)
        return features, label

    return mixed.map(_to_features_and_label)


###############################################################################

# def binarize_answer(answer):
#     """Turns the answer into, hopefully, a non-trivial binary choice."""
#     # True/False is easy to handle.

###############################################################################


def _prepare_tfds__load_split(c: str):
    """Call `tfds.load` on the math_dataset config named `c`.

    No split argument is passed. Kept as a module-level function so that
    `prepare_tfds` can hand it to a multiprocessing pool.
    """
    dataset_name = f'math_dataset/{c}'
    return tfds.load(dataset_name)


def prepare_tfds():
    """Utility for preparing all of the configs.

    Calls `_prepare_tfds__load_split` once for every name in `CONFIG_NAMES`,
    spread across a process pool with the default number of workers.

    Returns:
        None.
    """
    # Manage the pool with a context manager so its worker processes are shut
    # down once the (blocking) `map` call has returned, instead of lingering
    # until the abandoned Pool object happens to be garbage collected.
    with multiprocessing.Pool() as pool:
        pool.map(_prepare_tfds__load_split, CONFIG_NAMES)


def _config_has_true_false(config: str):
    """Report whether the 'test' split of `config` has any True/False answer.

    Iterates the split as numpy values and stops at the first example whose
    'answer' is b'True' or b'False'.

    Returns:
        True if such an example exists, otherwise False.
    """
    test_examples = tfds.load(f'math_dataset/{config}', split='test')
    boolean_answers = (b'True', b'False')
    # `any` short-circuits, so scanning stops at the first match.
    return any(
        example['answer'] in boolean_answers
        for example in test_examples.as_numpy_iterator()
    )


def see_which_configs_have_true_false():
    """List the configs that contain at least one True/False answer.

    Runs `_config_has_true_false` for every name in `CONFIG_NAMES` across a
    process pool with the default number of workers.

    Returns:
        The names from `CONFIG_NAMES`, in their original order, for which
        `_config_has_true_false` returned a truthy value.
    """
    # Manage the pool with a context manager so its worker processes are shut
    # down once the (blocking) `map` call has returned, rather than leaked.
    with multiprocessing.Pool() as pool:
        has_tfs = pool.map(_config_has_true_false, CONFIG_NAMES)
    return [c for c, b in zip(CONFIG_NAMES, has_tfs) if b]

###############################################################################


def count_number_of_og_tf_examples(split='train'):
    """Count and print the True/False examples across `TRUE_FALSE_CONFIGS`.

    The running total is printed every 1000 examples, once more after each
    config has been exhausted, and a final time at the end. Nothing is
    returned.

    Args:
        split: Split name passed to `tfds.load` for every config.
    """
    total = 0
    for config_name in TRUE_FALSE_CONFIGS:
        examples = (
            tfds.load(f'math_dataset/{config_name}', split=split)
            .filter(_is_example_true_false)
            .prefetch(2500)
        )
        # NOTE: Would batching and adding the number of examples in each batch be faster?
        for _ in examples:
            total += 1
            if total % 1000 == 0:
                print(total)
        print(total)
    print(total)
