"""Tests for neural network architectures used in ncc.algos"""

import tensorflow as tf

from ncc.algos import core
from ncc.data import loader
from ncc.data import transformation
from ncc.data import wrappers
from ncc.data import helpers


def test_reference_game_receiver():
    """Check the output shapes of ``core.ReferenceGameReceiver`` on one batch.

    One training batch is drawn from the 'mnist' dataset loaded with
    ``transformation.UpdateMNISTColors``. It is wrapped by
    ``wrappers.ReferenceGameDatasetWrapper`` and fed to the receiver together
    with a random message of shape ``(batch_size, message_len, alphabet_size)``.

    The test asserts that:
    * the receiver returns logits of shape ``(batch_size, num_candidates)``;
    * it returns one prediction per example;
    * the batch carries one label per example.
    """
    batch_size = 5
    num_distractors = 3
    # Presumably the target plus its distractors; this mirrors the
    # `num_candidates` argument passed to the receiver below.
    num_candidates = num_distractors + 1
    message_len = 2
    alphabet_size = 10
    # NOTE(review): assumes images coming out of the loader/transformation
    # are 128x128 RGB -- confirm against UpdateMNISTColors.
    visual_input_dim = (128, 128, 3)
    embedding_size = 32
    dataset_name = 'mnist'
    train_features = {'label': [0, 1, 2], 'label_color': [0, 1, 2]}
    train_features_helper = helpers.features_cartesian
    out_of_sample_features = {'label': [1, 2], 'label_color': [3, 4]}
    out_of_sample_features_helper = helpers.features_zip

    # Only the training split is needed for this shape test.
    ds_train, _, _, _, _ = loader.load(
        dataset_name,
        batch_size=batch_size,
        train_features=train_features,
        train_features_helper=train_features_helper,
        out_of_sample_features=out_of_sample_features,
        out_of_sample_features_helper=out_of_sample_features_helper,
        update_dataset_class=transformation.UpdateMNISTColors,
    )
    wrapper = wrappers.ReferenceGameDatasetWrapper(ds_train, num_distractors)
    # The wrapper yields (candidates, target_image, label); the target image
    # is not used by this test.
    candidates, _, label = next(iter(wrapper))

    receiver = core.ReferenceGameReceiver(
        visual_input_dim=visual_input_dim,
        message_mlp_hidden_sizes=[embedding_size],
        message_filter_size=embedding_size,
        weight_decay=3e-4,
        num_candidates=num_candidates,
        embedding_size=embedding_size,
    )
    # Fixed op-level seed so the random message is reproducible across runs.
    message = tf.random.uniform(
        shape=(batch_size, message_len, alphabet_size),
        dtype=tf.dtypes.float32,
        seed=0,
    )
    preds, logits = receiver(message, candidates)

    # Separate assertions so a failure reports exactly which shape is wrong.
    assert tuple(logits.shape) == (batch_size, num_candidates)
    assert tuple(preds.shape) == (batch_size,)
    assert tuple(label.shape) == (batch_size,)
