"""Tests for ncc.data.wrappers"""

import pytest

import tensorflow as tf

from ncc.data import loader
from ncc.data import transformation
from ncc.data import wrappers
from ncc.data import helpers


@pytest.mark.requires_dataset
def test_reference_game_wrapper_with_mnist():
    """Shape-check the first batch the wrapper yields for the MNIST train split.

    The training split is obtained from ``loader.load('mnist', ...)`` with
    ``transformation.UpdateMNISTColors`` as the dataset-update class. It is
    wrapped in ``wrappers.ReferenceGameDatasetWrapper``, and the shapes of the
    candidates, target image and label of the first batch are asserted.
    The test is marked ``requires_dataset`` because it goes through the real
    loader rather than an in-memory stand-in.
    """
    n_batch = 5
    n_distractors = 3
    # Expected per-image shape: 28x28 with 3 channels.
    image_shape = (28, 28, 3)

    load_kwargs = dict(
        batch_size=n_batch,
        train_features={'label': [0, 1, 2], 'label_color': [0, 1, 2]},
        train_features_helper=helpers.features_cartesian,
        out_of_sample_features={'label': [1, 2], 'label_color': [3, 4]},
        out_of_sample_features_helper=helpers.features_zip,
        update_dataset_class=transformation.UpdateMNISTColors,
    )
    # loader.load returns five values; only the first (training dataset)
    # is used here.
    ds_train, _, _, _, _ = loader.load('mnist', **load_kwargs)

    reference_game = wrappers.ReferenceGameDatasetWrapper(ds_train, n_distractors)
    candidates, target_image, label = next(iter(reference_game))

    # One target plus `n_distractors` distractors per batch entry.
    assert candidates.shape == (n_batch, n_distractors + 1) + image_shape
    assert target_image.shape == (n_batch,) + image_shape
    assert label.shape == (n_batch,)


def test_reference_game_wrapper_with_mock_dataset():
    """Check the wrapper's output on a tiny in-memory stand-in for a dataset.

    The stand-in is a one-shot Python iterator over three feature dicts, each
    with an ``'image'`` tensor of shape ``(1, 5)`` (a batch of one row of five
    values). With two distractors, the test expects exactly one batch of three
    candidates. After that batch, the wrapper is expected to be exhausted.
    """
    mock_dataset = iter([
        {'image': tf.expand_dims(tf.range(5), axis=0)},
        {'image': tf.expand_dims(tf.range(5, 10), axis=0)},
        {'image': tf.expand_dims(tf.range(10, 15), axis=0)}
    ])
    wrapper = wrappers.ReferenceGameDatasetWrapper(mock_dataset, 2)
    candidates, target_image, label = next(iter(wrapper))

    # Check shapes first. The value checks below index into these tensors,
    # so a wrong shape would otherwise surface as an opaque indexing or
    # conversion error instead of a clear assertion failure.
    assert candidates.shape == (1, 3, 5)
    assert target_image.shape == (1, 5)
    assert label.shape == (1,)

    # The first dataset element is expected in the first candidate slot.
    assert bool(tf.math.reduce_all(
        candidates[0, 0] == tf.expand_dims(tf.range(5), axis=0)
    ))

    # `label` indicates the position of `target_image` in `candidates`.
    # Index the single batch entry explicitly. Calling int() on the whole
    # shape-(1,) tensor relies on NumPy's deprecated ndim>0
    # array-to-scalar conversion.
    target_position = int(label[0])
    assert bool(tf.math.reduce_all(
        candidates[:, target_position] == target_image
    ))

    # The mock is a one-shot iterator over only three elements, so no
    # further batch is expected after the first one.
    with pytest.raises(StopIteration):
        next(iter(wrapper))
