"""Prepares tf.dataset used in training."""

import gin
import tensorflow as tf
import tensorflow_datasets as tfds

# pylint: disable=unused-import
from ncc.data import obverter, obverter10k  # Register `Obverter`
from ncc.data import dataset_size
from ncc.data import helpers
# pylint: disable=unused-wildcard-import, wildcard-import
from ncc.data.ds_transforms import *  # register gin hooks

# pylint: disable=missing-param-doc
@gin.configurable
def load(name=None,
         data_dir=None,
         batch_size=None,
         split_percentage=90,
         normalize=True,
         sender_alphabet_size=None,
         receiver_alphabet_size=None,
         update_dataset_class=None,
         count_dataset_cardinality=False,
         train_features=None,
         train_features_helper=None,
         out_of_sample_features=None,
         out_of_sample_features_helper=None,
         image_transform_fn=None,
         features_transform_fn=None
         ):
    """Loads the named dataset and builds the train/eval `tf.data` pipelines.

    **Warning**: calling this function might potentially trigger the download
    of hundreds of GiB to disk.

    Args:
        name (str): The registered name of the DatasetBuilder.
        data_dir (str): Directory to read/write data.
        batch_size (int): Batch size of the training dataset.
        split_percentage (int): Percentage of the data that are used for
            training. For example, value 90, indicates that 90% of (filtered)
            data will be used for training and 10% for test. If the builder
            also exposes a 'test' split, it is divided with the same
            percentage and merged into the respective parts.
        normalize (bool): If True, images are cast to float32 and rescaled
            with `2 * (image / 255) - 1`.
        sender_alphabet_size (int): The alphabet size used by the sender.
            Defaults to the value of `helpers.max_cardinality` computed on
            the train + oos features list and must not be smaller than it.
        receiver_alphabet_size (int): The alphabet size used by the receiver.
            Defaults to the value of `helpers.max_value` computed on the
            train + oos features list and must not be smaller than it.
        update_dataset_class (class): A class instantiated without arguments
            whose `transformation_fn()` result is mapped over the train and
            test datasets (e.g. to add color to MNIST).
        count_dataset_cardinality (bool): Computes the size of each
            dataset (train, test_in_sample, test_out_of_sample). For
            MNIST (70000 examples) takes approximately 120 sec. Use with care.
        train_features (dict or list or user-specified): Config for which
            features to include in training. Examples:
            * dict version: {'label_0': [0, 1], 'label_1': [0, 1]},
            * list version: [{'label_0': 0, 'label_1':1},
                             {'label_0': 1, 'label_1':2}]
        train_features_helper (callable): Transforms train_features dict version
            to a list version (as described above). Available helpers:
            * features_cartesian: takes all combinations of labels values,
            * features_zip: zips the label values,
            * features_manual: used when train_features are in a list version.
        out_of_sample_features (dict, list or user-specified): Config oos for
            which features to include in testing, as described above. If no oos,
            then set to None.
            IMPORTANT: oos features will be removed from train_features if
            there is any overlap.
        out_of_sample_features_helper (callable): Helper for oos features, as
            described above.
        image_transform_fn (callable): Optional function applied to each image
            (after the optional normalization).
        features_transform_fn (callable): Optional function called as
            `features_transform_fn(features, features_list_train)`, where
            `features` is the dict of int32 feature tensors of a datapoint;
            its result replaces that dict.

    Returns:
        ds_train (tf.data.Dataset): Filtered and shuffled training data,
            batched with `batch_size` (incomplete last batch is dropped).
        ds_train_for_eval (tf.data.Dataset): The same filtered and shuffled
            training data, repeated indefinitely and batched by 2.
        ds_test_in_sample (tf.data.Dataset): Test part restricted to the
            training feature combinations, repeated and batched by 2.
        ds_test_out_of_sample (tf.data.Dataset or None): Test part restricted
            to the oos feature combinations, repeated and batched by 2;
            None when `out_of_sample_features` is None.
        info (dict): Dictionary with the keys 'message_length',
            'sender_alphabet_size', 'receiver_alphabet_size',
            'features_names', 'features_list_train', 'features_list_oos' and,
            if `count_dataset_cardinality` is set, 'cardinality'.

    Raises:
        AssertionError: If a required features config/helper is missing or a
            given alphabet size is smaller than the computed lower bound.
    """
    assert train_features is not None, 'Train_features are None.'
    assert train_features_helper is not None, 'Train_features_helper is None.'

    # Get the features list: a list of dicts (keys=str, values=int)
    features_list_train = train_features_helper(train_features)
    features_names = features_list_train[0].keys()
    features_list_total = list(features_list_train)  # copy

    # Get the out-of-sample features list
    if out_of_sample_features is not None:
        assert out_of_sample_features_helper is not None,\
            'Out_of_sample_features_helper is None.'
        # Get the features list: a list of dicts (keys=str, values=int)
        features_list_oos = out_of_sample_features_helper(
            out_of_sample_features)
        # Remove oos from features_list_train.
        features_list_train = [feature for feature in features_list_train
                               if feature not in features_list_oos]
        # Remove features for training from oos. After the step above no
        # training entry is left in the oos list, so every oos entry is kept;
        # the filter is retained as a defensive copy.
        features_list_oos = [feature for feature in features_list_oos
                             if feature not in features_list_train]
        features_list_total += features_list_oos
    else:
        features_list_oos = None

    # Calculate max_i |F_i|, where F_i equals the set of values feature_i takes.
    # Gives a lower bound on the sender's alphabet_size.
    max_cardinality = helpers.max_cardinality(features_list_total)

    # Calculate max_i(maxF_i), where F_i = set of feature_i values.
    # Gives a lower bound on the receiver's alphabet_size.
    max_value = helpers.max_value(features_list_total)

    # Create a builder.
    builder = tfds.builder(name, data_dir=data_dir)
    builder.download_and_prepare()

    # Get builder info
    ds_info = builder.info

    # Initialize update_dataset object
    update_dataset = update_dataset_class() if update_dataset_class else None

    # Set senders_alphabet
    if sender_alphabet_size is None:
        sender_alphabet_size = max_cardinality
    assert sender_alphabet_size >= max_cardinality, \
        'Sender alphabet is smaller than # of ' \
        'values for some feature in features list.'

    # Set receiver_alphabet
    if receiver_alphabet_size is None:
        receiver_alphabet_size = max_value
    assert receiver_alphabet_size >= max_value, \
        'Receiver alphabet is smaller than # of values ' \
        'for some feature in the dataset.'

    info = {
        'message_length': len(features_names),
        'sender_alphabet_size': sender_alphabet_size,
        'receiver_alphabet_size': receiver_alphabet_size,
        'features_names': features_names,
        'features_list_train': features_list_train,
        'features_list_oos': features_list_oos,
    }

    # Split the data into test and train
    ri_train = tfds.core.ReadInstruction('train', from_=0, to=split_percentage,
                                         unit='%')
    ri_test = tfds.core.ReadInstruction('train', from_=split_percentage, to=100,
                                        unit='%')
    # Datasets shipping their own 'test' split are pooled with 'train' and
    # divided using the same percentage.
    if 'test' in ds_info.splits.keys():
        ri_train += tfds.core.ReadInstruction('test', from_=0,
                                              to=split_percentage, unit='%')
        ri_test += tfds.core.ReadInstruction('test', from_=split_percentage,
                                             to=100, unit='%')

    # Get a dataset.
    ds_train = builder.as_dataset(split=ri_train)
    ds_test = builder.as_dataset(split=ri_test)

    # Image transformation (e.g. add color to mnist)
    if update_dataset:
        transformation_fn = update_dataset.transformation_fn()
        ds_train = ds_train.map(transformation_fn)
        ds_test = ds_test.map(transformation_fn)

    # TODO: use python instead of tf
    # Filters out only some values of features (key-values of features_dict).
    def features_fn(features_list):
        """Builds a predicate keeping datapoints matching any combination."""
        def matches_any(datapoint):
            query = []
            for combination in features_list:
                # A combination matches when all of its labels are equal.
                features_query = [datapoint[label] == value
                                  for label, value in combination.items()]
                query.append(tf.reduce_all(features_query))
            return tf.reduce_any(query)
        return matches_any

    # Take features from features_list
    ds_train = ds_train.filter(features_fn(features_list_train))
    ds_test_in_sample = ds_test.filter(features_fn(features_list_train))

    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test.filter(features_fn(features_list_oos))
    else:
        ds_test_out_of_sample = None

    # If requested, count the datasets cardinality.
    if count_dataset_cardinality:
        count_dict = {
            'ds_train_size': dataset_size.count_dataset_size(
                ds_train, features_list_train),
            'ds_test_in_sample_size': dataset_size.count_dataset_size(
                ds_test_in_sample, features_list_train),
        }
        if out_of_sample_features is not None:
            count_dict.update({
                'ds_test_out_of_sample_size': dataset_size.count_dataset_size(
                    ds_test_out_of_sample, features_list_oos),
            })

        info.update({'cardinality': count_dict})

    ## Filter out features and normalize to [-1.0, 1.0].
    def normalize_fn(datapoint):
        """Keeps only the image and the selected features of a datapoint."""
        image = datapoint['image']
        if normalize:
            image = tf.cast(image, tf.float32)
            image = 2. * (image / 255.0) - 1.0
        if image_transform_fn:
            image = image_transform_fn(image)
        return_dict = {'image': image}

        features = {key: tf.cast(datapoint[key], tf.int32)
                    for key in features_names}
        if features_transform_fn:
            features = features_transform_fn(features, features_list_train)
        return_dict.update(features)

        return return_dict

    # Mask out the labels: applying normalization before `ds.cache()`
    # to re-use it.
    ds_train = ds_train.map(normalize_fn)
    ds_test_in_sample = ds_test_in_sample.map(normalize_fn)
    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test_out_of_sample.map(normalize_fn)

    # Cache
    # Note: Random transformations (e.g. images augmentations) should be applied
    # after both `ds.cache()` (to avoid caching randomness)and `ds.batch()` (for
    # # vectorization [1]).
    # https://www.tensorflow.org/datasets/performances
    # https://www.tensorflow.org/guide/data_performance
    # ds = ds.cache()

    ds_train = ds_train.shuffle(1024)
    ds_test_in_sample = ds_test_in_sample.shuffle(1024)
    # Explicit None check (same guard as everywhere else) instead of relying
    # on the truthiness of a `tf.data.Dataset` object.
    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test_out_of_sample.shuffle(1024)

    # The evaluation view of the training data is taken before batching, so
    # it can use its own batch size below.
    ds_train_for_eval = ds_train.repeat()
    ds_test_in_sample = ds_test_in_sample.repeat()
    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test_out_of_sample.repeat()

    # Batch after shuffling to get unique batches at each epoch.
    # drop_remainder=True prevents the smaller batch from being produced.
    ds_train = ds_train.batch(batch_size, drop_remainder=True)

    # TODO: change 2->1, now generate_eval_protocol needs proper batches
    ds_train_for_eval = ds_train_for_eval.batch(2)
    ds_test_in_sample = ds_test_in_sample.batch(2)
    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test_out_of_sample.batch(2)

    # Prefetch (see https://www.tensorflow.org/guide/data_performance).
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
    ds_train_for_eval = ds_train_for_eval.prefetch(
        tf.data.experimental.AUTOTUNE)
    ds_test_in_sample = ds_test_in_sample.prefetch(
        tf.data.experimental.AUTOTUNE)
    if out_of_sample_features is not None:
        ds_test_out_of_sample = ds_test_out_of_sample.prefetch(
            tf.data.experimental.AUTOTUNE)

    return ds_train, ds_train_for_eval, ds_test_in_sample, \
           ds_test_out_of_sample, info
