"""Tests for Tokens class basic functionality.

Tests verify creation of Tokens from PyTree and decoding back to PyTree.
"""

import jax.numpy as jnp
import pytest

from tfmpe.preprocessing import Tokens
from tfmpe.preprocessing.utils import Independence


def test_independence_default_instantiation():
    """An Independence built with no arguments has every rule list empty."""
    default = Independence()
    for field_name in ('local', 'cross', 'cross_local'):
        assert getattr(default, field_name) == []


def test_independence_partial_instantiation():
    """Fields not passed to Independence fall back to empty lists.

    Each case sets exactly one field; that field must hold the given value
    and the other two must be empty.
    """
    cases = [
        ('local', ['obs']),
        ('cross', [('a', 'b')]),
        ('cross_local', [('x', 'y', None)]),
    ]
    field_names = [name for name, _ in cases]

    for given_field, value in cases:
        # Pass a fresh copy so the comparison below is against an
        # untouched expected value.
        indep = Independence(**{given_field: list(value)})
        for name in field_names:
            expected = value if name == given_field else []
            assert getattr(indep, name) == expected


def test_independence_truthiness_empty():
    """An Independence holding no rules at all is falsy."""
    empty = Independence()
    assert bool(empty) is False
    assert not empty


def test_independence_truthiness_nonempty():
    """An Independence with at least one rule of any kind is truthy."""
    populated = [
        # local rules only
        Independence(local=['obs']),
        # cross rules only
        Independence(cross=[('a', 'b')]),
        # cross_local rules only
        Independence(cross_local=[('x', 'y', None)]),
        # every kind of rule at once
        Independence(
            local=['obs'],
            cross=[('a', 'b')],
            cross_local=[('x', 'y', None)],
        ),
    ]
    for indep in populated:
        assert indep


@pytest.fixture
def three_level_pytree():
    """Three-level hierarchical PyTree plus observations.

    Leaf shapes: global_mu (1, 1), group_mu (2, 1), local_theta (4, 1),
    obs (4, 1).
    """
    def column(*values):
        # Stack scalars into a (len(values), 1) column array.
        return jnp.array([[v] for v in values])

    return {
        'global_mu': column(1.0),
        'group_mu': column(2.0, 3.0),
        'local_theta': column(4.0, 5.0, 6.0, 7.0),
        'obs': column(0.1, 0.2, 0.3, 0.4),
    }


@pytest.fixture
def three_level_independence():
    """Independence rules keyed on the three-level PyTree's leaf names."""
    local_keys = ['local_theta', 'obs']
    cross_pairs = [
        ('global_mu', 'obs'),
        ('obs', 'global_mu'),
    ]
    cross_local_triples = [
        ('group_mu', 'local_theta', None),
        ('local_theta', 'obs', None),
    ]
    return Independence(
        local=local_keys,
        cross=cross_pairs,
        cross_local=cross_local_triples,
    )


def test_from_pytree_data_shape(simple_pytree, simple_independence):
    """Tokenising simple_pytree yields a (7, 1) data array."""
    # Token counts per key, and the largest batch size among them.
    n_mu, n_theta, n_obs = 1, 3, 3
    max_batch_size = 1

    tokens = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        sample_ndims=0,
        batch_ndims=dict(mu=1, theta=1, obs=1),
    )

    expected_shape = (n_mu + n_theta + n_obs, max_batch_size)
    assert tokens.data.shape == expected_shape


def test_from_pytree_labels_shape(simple_pytree, simple_independence):
    """Labels: one per token, constant within a key, distinct across keys."""
    tokens = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        sample_ndims=0,
        batch_ndims=dict(mu=1, theta=1, obs=1),
    )
    labels = tokens.labels

    # 1 (mu) + 3 (theta) + 3 (obs) = 7 tokens, one label each.
    assert labels.shape == (7,)

    # Token blocks in layout order: mu [0], theta [1:4], obs [4:7].
    mu_label = labels[0]
    theta_block = labels[1:4]
    obs_block = labels[4:7]

    # Every token from the same key shares that key's label.
    for block in (theta_block, obs_block):
        assert jnp.all(block == block[0])

    # No two keys share a label (pairs: mu/theta, mu/obs, theta/obs).
    representatives = [mu_label, theta_block[0], obs_block[0]]
    for idx, left in enumerate(representatives):
        for right in representatives[idx + 1:]:
            assert left != right


def test_from_pytree_labels_with_sample_dims():
    """With one sample dim, labels gain a sample axis and repeat per sample."""
    pytree = {
        'a': jnp.array([[1.0, 2.0], [3.0, 4.0]])[..., None],  # (2, 2, 1)
        'b': jnp.array([[5.0], [6.0]])[..., None],  # (2, 1, 1)
    }

    tokens = Tokens.from_pytree(
        pytree,
        condition=[],
        independence=Independence(),
        sample_ndims=1,
        batch_ndims=dict(a=1, b=1),
    )
    labels = tokens.labels

    # 2 samples x (2 tokens from 'a' + 1 token from 'b').
    n_samples, n_tokens = 2, 3
    assert labels.shape == (n_samples, n_tokens)

    # The label layout is identical for every sample.
    assert jnp.array_equal(labels[0], labels[1])

    first_sample = labels[0]
    a_block = first_sample[:2]
    b_block = first_sample[2:3]

    for block in (a_block, b_block):
        assert jnp.all(block == block[0])
    assert a_block[0] != b_block[0]


def test_from_pytree_mask_shapes(simple_pytree, simple_independence):
    """Self-attention mask is square over tokens; padding mask is absent."""
    n_tokens = 7

    tokens = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        sample_ndims=0,
        batch_ndims=dict(mu=1, theta=1, obs=1),
    )

    assert tokens.self_attention_mask.shape == (n_tokens, n_tokens)
    # No padding mask is produced for this basic case.
    assert tokens.padding_mask is None


def test_decode_round_trip(simple_pytree, simple_independence):
    """Decoding freshly built tokens reproduces the input PyTree."""
    tokens, decoder = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        sample_ndims=0,
        batch_ndims=dict(mu=1, theta=1, obs=1),
        return_decoder=True,
    )
    reconstructed = decoder(tokens)

    # Same set of keys.
    assert set(reconstructed) == set(simple_pytree)

    # Same per-key shapes.
    expected_shapes = {k: leaf.shape for k, leaf in simple_pytree.items()}
    actual_shapes = {k: reconstructed[k].shape for k in simple_pytree}
    assert actual_shapes == expected_shapes

    # Same per-key values.
    for key, original in simple_pytree.items():
        assert jnp.allclose(reconstructed[key], original)


def test_decode_after_modification(simple_pytree, simple_independence):
    """Test that the decoder routes edited token data back to the right keys.

    Each key's block of token rows is scaled by a distinct coefficient. The
    decoder must then return every leaf multiplied by its own coefficient.
    A wrong block offset, or a wrong key-to-block mapping, would apply the
    wrong factor and fail the check.
    """
    tokens, decoder = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        sample_ndims=0,
        batch_ndims={'mu': 1, 'theta': 1, 'obs': 1},
        return_decoder=True
    )

    # (key, token rows, coefficient). Token layout for simple_pytree:
    # mu at offset 0 (1 token), theta at 1 (3 tokens), obs at 4 (3 tokens).
    scaling = (
        ('mu', slice(0, 1), 10.0),
        ('theta', slice(1, 4), 2.0),
        ('obs', slice(4, 7), 0.5),
    )

    # JAX arrays are immutable and ``.at[...]`` updates return a new array,
    # so ``tokens.data`` itself is never modified and needs no copy.
    modified_data = tokens.data
    for _, rows, coeff in scaling:
        modified_data = modified_data.at[rows, :].multiply(coeff)

    # Create modified tokens with new data; every other field is unchanged.
    modified_tokens = Tokens(
        data=modified_data,
        labels=tokens.labels,
        position=tokens.position,
        condition=tokens.condition,
        self_attention_mask=tokens.self_attention_mask,
        padding_mask=tokens.padding_mask,
        functional_inputs=tokens.functional_inputs
    )

    reconstructed = decoder(modified_tokens)

    # Each decoded leaf must carry exactly its own block's coefficient.
    for key, _, coeff in scaling:
        assert jnp.allclose(reconstructed[key], simple_pytree[key] * coeff), (
            f"decoded {key!r} was not scaled by {coeff}"
        )




def test_from_pytree_with_functional_inputs(
    simple_pytree,
    simple_independence
):
    """Functional inputs are tokenised into an array shaped like the data."""
    # One functional-input leaf per pytree leaf, with matching shapes.
    fn_inputs = {
        'mu': jnp.zeros((1, 1)),
        'theta': jnp.ones((3, 1)),
        'obs': jnp.array([2.0, 2.1, 2.2])[:, None],
    }

    tokens = Tokens.from_pytree(
        simple_pytree,
        condition=['obs'],
        independence=simple_independence,
        functional_inputs=fn_inputs,
        sample_ndims=0,
        batch_ndims=dict(mu=1, theta=1, obs=1),
    )

    encoded = tokens.functional_inputs
    assert encoded is not None
    assert encoded.shape == tokens.data.shape


def test_from_pytree_functional_inputs_with_sample_dims():
    """Tokenised functional inputs keep the leading sample axis."""
    pytree = {
        'a': jnp.array([[1.0, 2.0], [3.0, 4.0]])[..., None],  # (2, 2, 1)
        'b': jnp.array([[5.0], [6.0]])[..., None],  # (2, 1, 1)
    }
    fn_inputs = {
        'a': jnp.array([[0.0, 0.1], [0.0, 0.1]])[..., None],  # (2, 2, 1)
        'b': jnp.ones((2, 1, 1)),
    }

    tokens = Tokens.from_pytree(
        pytree,
        condition=[],
        independence=Independence(),
        functional_inputs=fn_inputs,
        sample_ndims=1,
        batch_ndims=dict(a=1, b=1),
    )

    # 2 samples, 2 + 1 = 3 tokens, trailing size 1.
    expected_shape = (2, 3, 1)
    encoded = tokens.functional_inputs
    assert encoded is not None
    assert encoded.shape == expected_shape
    assert encoded.shape == tokens.data.shape


