import pytest
from jax import numpy as jnp
from flax import nnx
from sfmpe.nn.transformer.transformer import Transformer

@pytest.mark.parametrize('batch_dim', [4])
@pytest.mark.parametrize('value_dim', [1])
@pytest.mark.parametrize('context_token_dim', [10])
@pytest.mark.parametrize('latent_dim', [12])
@pytest.mark.parametrize('n_labels', [3])
@pytest.mark.parametrize('theta_token_dim', [10])
@pytest.mark.parametrize('index_dim', [2])
def test_forward(
    batch_dim,
    value_dim,
    context_token_dim,
    latent_dim,
    n_labels,
    theta_token_dim,
    index_dim,
):
    """Smoke-test a single Transformer forward pass on all-zero inputs.

    A small Transformer is built and put into eval mode. It receives
    zero-valued context and theta token sets, with int32 zero labels and zero
    indices, and no context/theta/cross masks. The test checks that the
    output holds one ``value_dim`` vector per theta token for every batch
    element.
    """
    config = dict(
        latent_dim=latent_dim,
        label_dim=2,
        index_out_dim=2,
        n_encoder=2,
        n_decoder=2,
        n_heads=2,
        n_ff=2,
        dropout=.5,
        activation=nnx.relu,
    )

    model = Transformer(
        config, value_dim, n_labels, index_dim, rngs=nnx.Rngs(params=0)
    )
    # Switch to inference mode before the forward pass. In flax nnx, eval()
    # is what makes Dropout deterministic, so only a `params` RNG stream is
    # needed above.
    model.eval()

    def _token_set(n_tokens):
        # Zero values, int32 zero labels and zero (float) indices for a
        # token set of length n_tokens, batched over batch_dim.
        values = jnp.zeros((batch_dim, n_tokens, value_dim))
        labels = jnp.zeros((batch_dim, n_tokens), dtype=jnp.int32)
        indices = jnp.zeros((batch_dim, n_tokens, index_dim))
        return values, labels, indices

    ctx_values, ctx_labels, ctx_indices = _token_set(context_token_dim)
    th_values, th_labels, th_indices = _token_set(theta_token_dim)

    # No masking anywhere: context mask, theta mask and cross mask are None.
    no_mask = None

    out = model(
        ctx_values, ctx_labels, ctx_indices, no_mask,
        th_values, th_labels, th_indices, no_mask,
        no_mask,
        time=jnp.array(.5),
    )

    expected_shape = (batch_dim, theta_token_dim, value_dim)
    assert out.shape == expected_shape
