"""Tests for data loaders"""

import pytest

import tensorflow as tf
import tensorflow_datasets as tfds

from ncc.data import loader
from ncc.data import transformation
from ncc.data import helpers


def _check_first_example(dataset, features):
    """Assert that the first batch of `dataset` is consistent with `features`.

    Reads a single batch (the tests in this module use ``batch_size=1``)
    and checks that its label and color label are among the allowed
    values and that the image is a 28x28 array with 3 channels.

    Args:
        dataset: dataset yielding dicts with 'image', 'label' and
            'label_color' entries, each carrying a leading batch dimension.
        features: dict mapping 'label' and 'label_color' to the lists of
            allowed values.
    """
    # Depending on the tfds version, `tfds.as_numpy` returns either a
    # generator or a re-iterable wrapper that is not itself an iterator;
    # `iter()` handles both, whereas calling `next()` directly does not.
    datapoint = next(iter(tfds.as_numpy(dataset)))
    # Index 0 removes the batch dimension.
    assert datapoint['label_color'][0] in features['label_color']
    assert datapoint['label'][0] in features['label']
    assert datapoint['image'][0].shape == (28, 28, 3)


@pytest.mark.requires_dataset
def test_mnist_loader():
    """Check split contents and metadata returned by the colored MNIST loader.

    The train and in-sample test splits must only contain the train feature
    values, the out-of-sample split only the out-of-sample ones, and
    ``ds_info`` must hold the feature lists produced by the helpers plus the
    message-length and alphabet-size entries.
    """
    name = 'mnist'
    train_features = {'label': [0, 1, 2], 'label_color': [0, 1, 2]}
    train_features_helper = helpers.features_cartesian
    out_of_sample_features = {'label': [1, 2], 'label_color': [3, 4]}
    out_of_sample_features_helper = helpers.features_zip

    ds_train, _, ds_test_in_sample, ds_test_out_of_sample, ds_info = \
        loader.load(
            name,
            batch_size=1,
            train_features=train_features,
            train_features_helper=train_features_helper,
            out_of_sample_features=out_of_sample_features,
            out_of_sample_features_helper=out_of_sample_features_helper,
            update_dataset_class=transformation.UpdateMNISTColors,
        )

    _check_first_example(ds_train, train_features)
    # The in-sample test split is drawn from the same feature values as
    # the train split.
    _check_first_example(ds_test_in_sample, train_features)
    _check_first_example(ds_test_out_of_sample, out_of_sample_features)

    assert ds_info['features_list_train'] == \
        train_features_helper(train_features)
    assert ds_info['features_list_oos'] == \
        out_of_sample_features_helper(out_of_sample_features)
    assert 'message_length' in ds_info
    assert 'receiver_alphabet_size' in ds_info
    assert 'sender_alphabet_size' in ds_info


def _epoch_prefix(dataset, num_batches):
    """Return stacked images and labels of the first batches of a new pass.

    Args:
        dataset: dataset yielding dicts with 'image' and 'label' entries.
        num_batches: maximum number of leading batches to collect.

    Returns:
        Tuple ``(images, labels)``, each stacked along a new leading axis.
    """
    # Every iteration over a tf.data dataset starts a new pass; whether that
    # pass is reshuffled is exactly what the calling test checks.
    batches = list(dataset.take(num_batches))
    assert batches, 'dataset yielded no batches'
    images = tf.stack([batch['image'] for batch in batches])
    labels = tf.stack([batch['label'] for batch in batches])
    return images, labels


@pytest.mark.requires_dataset
def test_mnist_loader_shuffles_data_each_epoch():
    """Check that only the train split is reshuffled between epochs.

    Each split is iterated twice, and the first ``num_batches`` batches of
    the two passes are compared as whole sequences. Comparing sequences
    instead of a single element per pass keeps the label check meaningful:
    only three label values are used here, so two single shuffled elements
    share a label by chance roughly one time in three. Taking a prefix also
    avoids iterating every split to exhaustion twice.
    """
    name = 'mnist'
    train_features = {'label': [0, 1, 2], 'label_color': [0, 1, 2]}
    train_features_helper = helpers.features_cartesian
    out_of_sample_features = {'label': [1, 2], 'label_color': [3, 4]}
    out_of_sample_features_helper = helpers.features_zip
    # Number of leading batches compared between two passes over a split.
    num_batches = 64

    ds_train, _, ds_test_in_sample, ds_test_out_of_sample, _ = \
        loader.load(
            name,
            batch_size=1,
            train_features=train_features,
            train_features_helper=train_features_helper,
            out_of_sample_features=out_of_sample_features,
            out_of_sample_features_helper=out_of_sample_features_helper,
            update_dataset_class=transformation.UpdateMNISTColors,
        )

    # Train set is shuffled: two passes must yield different orders.
    images1, labels1 = _epoch_prefix(ds_train, num_batches)
    images2, labels2 = _epoch_prefix(ds_train, num_batches)
    assert not tf.math.reduce_all(images1 == images2)
    assert not tf.math.reduce_all(labels1 == labels2)

    # Test sets (in sample and out of sample) are not shuffled: two passes
    # must yield the same order.
    for ds_test in (ds_test_in_sample, ds_test_out_of_sample):
        images1, labels1 = _epoch_prefix(ds_test, num_batches)
        images2, labels2 = _epoch_prefix(ds_test, num_batches)
        assert tf.math.reduce_all(images1 == images2)
        assert tf.math.reduce_all(labels1 == labels2)


# Slow: takes about 2 minutes.
@pytest.mark.requires_dataset
def test_cardinality_mnist_loader():
    """Check the per-digit example counts reported by the loader.

    With every label/color combination selected for training, summing the
    counts in ``ds_info['cardinality']`` over colors and splits must give
    the per-digit totals of MNIST.
    """
    name = 'mnist'
    train_features = {'label': list(range(10)),
                      'label_color': list(range(10))}
    train_features_helper = helpers.features_cartesian

    _, _, _, _, ds_info = loader.load(
        name,
        batch_size=1,
        train_features=train_features,
        train_features_helper=train_features_helper,
        count_dataset_cardinality=True,
        update_dataset_class=transformation.UpdateMNISTColors,
    )

    # ds_info['cardinality'] maps each split to per-feature-combination
    # counts, e.g.:
    # {'ds_train_size':
    #       {(('label', 2), ('label_color', 0)): 629, ...},
    #  'ds_test_in_sample_size':
    #       {(('label', 2), ('label_color', 0)): 68, ...}}
    cardinality = [0] * 10
    for split_counts in ds_info['cardinality'].values():
        for combination, count in split_counts.items():
            # Look the label up by name instead of relying on its position
            # inside the key tuple.
            cardinality[dict(combination)['label']] += count

    # Correct MNIST per-digit counts (train + test; they sum to 70,000).
    mnist_cardinality = [6903, 7877, 6990, 7141, 6824, 6313,
                         6876, 7293, 6825, 6958]

    # Compare whole lists so a failure reports every mismatching digit.
    assert cardinality == mnist_cardinality
