import jax
import jax.numpy as jnp
import pytest

from maml.data import sample_task_batch
from maml.model import MetaMLP, MetaConvNet


@pytest.mark.parametrize("implicit", [True, False])
def test_batch_predict(implicit):
    """Smoke test: ``MetaMLP.batch_predict`` runs on per-task duplicated params.

    Runs for both the implicit and the non-implicit model variants. The test
    checks only that the call completes; it makes no assertion on the output.
    """
    meta_mlp = MetaMLP(implicit=implicit)
    meta_params = meta_mlp.initialize_params()
    key = jax.random.PRNGKey(0)
    # NOTE(review): the meaning of the positional sizes (3, 4, 5) is defined by
    # sample_task_batch; confirm against maml.data.
    training_task_batch_train, _, _ = sample_task_batch(key, 3, 4, 5)
    # Index 0 along axis 1 is used as the task inputs (presumably the
    # input half of each (input, output) pair -- confirm against maml.data).
    tasks_inputs = training_task_batch_train[:, 0]
    n_tasks = tasks_inputs.shape[0]
    # One copy of the meta-parameters per task, so the model can be applied batch-wise.
    duplicated_nested_params = meta_mlp.duplicate_params(meta_params, n_tasks)
    meta_mlp.batch_predict(duplicated_nested_params, tasks_inputs)


@pytest.mark.parametrize("implicit", [True, False])
def test_model_call(implicit):
    """Smoke test: calling ``MetaMLP`` on a sampled training task batch runs.

    Runs for both the implicit and the non-implicit model variants. The
    parameters and regularization come from ``initialize_params_and_reg``. The
    test checks only that the forward call completes.
    """
    meta_mlp = MetaMLP(implicit=implicit)
    meta_params_and_reg = meta_mlp.initialize_params_and_reg()
    key = jax.random.PRNGKey(0)
    # Only the training split is needed; the test split and the third return
    # value are discarded.
    training_task_batch_train, _, _ = sample_task_batch(key, 3, 4, 5)
    meta_mlp(training_task_batch_train, meta_params_and_reg)


@pytest.mark.parametrize("implicit", [True, False])
def test_batch_predict_cnn(implicit):
    """Smoke test: ``MetaConvNet.batch_predict`` runs on random image batches.

    Builds a batch of standard-normal images laid out as
    (tasks, instances per task, height, width, channels). It duplicates the
    meta-parameters once per task and checks only that prediction completes.
    """
    num_classes = 3
    num_tasks = 3
    instances_per_task = 10
    image_shape = (28, 28, 1)  # height, width, channels

    model = MetaConvNet(implicit=implicit, n_output=num_classes)
    params = model.initialize_params()

    rng = jax.random.PRNGKey(0)
    inputs_shape = (num_tasks, instances_per_task) + image_shape
    task_inputs = jax.random.normal(rng, shape=inputs_shape)

    per_task_params = model.duplicate_params(params, num_tasks)
    model.batch_predict(per_task_params, task_inputs)


@pytest.mark.parametrize("implicit", [True, False])
def test_model_call_cnn(implicit):
    """Smoke test: calling ``MetaConvNet`` on random images and one-hot labels runs.

    Inputs are standard-normal images shaped
    (meta batch, instances per task, 28, 28, 1). Labels are integers in
    [0, n_classes) that are one-hot encoded along the last axis. The model is
    called with the ``(inputs, outputs)`` tuple; the test checks only that the
    call completes.
    """
    n_classes = 3
    meta_batch_size = 4
    n_instances = 10
    meta_cnn = MetaConvNet(implicit=implicit, n_output=n_classes)
    meta_params_and_reg = meta_cnn.initialize_params_and_reg()
    key = jax.random.PRNGKey(0)
    # Give each random draw its own key. Reusing one key for both draws
    # would make the inputs and the labels statistically correlated.
    inputs_key, labels_key = jax.random.split(key)
    training_task_inputs = jax.random.normal(inputs_key, shape=(
        meta_batch_size,  # meta batch size
        n_instances,  # number of instances per task
        28,  # image height
        28,  # image width
        1,  # number of channels
    ))
    training_task_labels = jax.random.randint(labels_key, (
        meta_batch_size,  # meta batch size
        n_instances,  # number of instances per task
    ), 0, n_classes)
    training_task_outputs = jax.nn.one_hot(training_task_labels, n_classes, axis=-1)
    meta_cnn((training_task_inputs, training_task_outputs), meta_params_and_reg)
