"""Integration tests for TFMPE class."""

import jax
import jax.numpy as jnp
import pytest
from flax import nnx

from tfmpe.estimators.tfmpe import TFMPE, NormalDistribution
from .conftest import create_mock_tokens


class TestTFMPESampling:
    """Test TFMPE sampling functionality."""

    @pytest.fixture
    def identity_vf(self):
        """Zero vector field for testing.

        f(context, params, t) = 0, so the ODE never moves the state
        and the resulting flow is the identity map.
        """
        def vf(context, params, t):
            return jnp.zeros_like(params.data)
        return vf

    @pytest.fixture
    def tfmpe_identity(self, identity_vf, solver):
        """TFMPE with the zero (identity-flow) vector field."""
        rngs = nnx.Rngs(params=jax.random.PRNGKey(0))
        return TFMPE(
            vf_network=identity_vf,
            base_dist=NormalDistribution(rngs=rngs),
            solver=solver,
            ode_kwargs={'rtol': 1e-5, 'atol': 1e-5},
        )

    def test_sample_posterior_single_sample(
        self, tfmpe_identity
    ):
        """Test sample_posterior returns correct shape.

        Given:
        - TFMPE instance
        - Context Token with data shape (1, 2, 1): a single
          sample of two tokens
        - Params Token template with same structure

        When:
        - Call sample_posterior()

        Then:
        - Output is of the same Tokens type as the inputs
        - Output shape matches params shape
        - All values are finite
        """
        context = create_mock_tokens(jnp.zeros((1, 2, 1)))
        params = create_mock_tokens(jnp.zeros((1, 2, 1)))

        samples = tfmpe_identity.sample_posterior(
            context=context,
            params=params,
        )

        # Check output is Tokens
        assert isinstance(samples, type(context))

        # Check shape matches input (single sample)
        assert samples.data.shape == params.data.shape

        # Check all values are finite
        assert jnp.all(jnp.isfinite(samples.data))

    def test_sample_posterior_batch_params(
        self,
        identity_vf,
        solver
    ):
        """Test sample_posterior returns correct shape with batched params.

        Given:
        - Context Token with a single sample, data shape (1, 2, 1)
        - Params Token template with n_batch=10 samples,
          data shape (10, 2, 1)
        - Zero VF wrapped in a spy that records the shapes it sees

        When:
        - Call sample_posterior()

        Then:
        - Output shape matches params shape
        - All values are finite
        - Each batch sample differs from the first one
        - Every VF call receives unbatched (per-sample) shapes
        """
        # Create spy wrapper to capture function call shapes
        call_logs = []

        def spy_vf(context, params, t):
            call_logs.append({
                'context_shape': context.data.shape,
                'params_shape': params.data.shape
            })
            return identity_vf(context, params, t)

        # Create TFMPE with spied vector field
        rngs = nnx.Rngs(params=jax.random.PRNGKey(0))
        tfmpe = TFMPE(
            vf_network=spy_vf,
            base_dist=NormalDistribution(rngs=rngs),
            solver=solver,
            ode_kwargs={'rtol': 1e-5, 'atol': 1e-5},
        )

        n_batch = 10
        context = create_mock_tokens(jnp.zeros((1, 2, 1)))
        params = create_mock_tokens(jnp.zeros((n_batch, 2, 1)))

        samples = tfmpe.sample_posterior(
            context=context,
            params=params,
        )

        # Check output is Tokens
        assert isinstance(samples, type(context))

        # Check shape matches input
        assert samples.data.shape == params.data.shape

        # Check all values are finite
        assert jnp.all(jnp.isfinite(samples.data))
        # Check every batch sample differs (elementwise) from the
        # first, i.e. the batch is not one draw broadcast n_batch times
        assert jnp.all(samples.data[0:1] != samples.data[1:])

        # Verify the spy saw per-sample shapes on EVERY call, not just
        # the first one (the VF may be traced/called several times)
        assert len(call_logs) > 0, "VF was never called"
        expected = {'context_shape': (2, 1), 'params_shape': (2, 1)}
        for i, log in enumerate(call_logs):
            assert log == expected, (
                f"VF call {i} received shapes {log}, expected {expected}"
            )

    def test_sample_posterior_preserves_token_metadata(
        self, tfmpe_identity
    ):
        """Test that sample_posterior preserves Token metadata.

        Given:
        - TFMPE instance
        - Context Token with specific metadata
        - Params Token template

        When:
        - Call sample_posterior()

        Then:
        - Labels match params labels
        - Self-attention mask matches params
        - Slices metadata is preserved
        """
        context = create_mock_tokens(jnp.zeros((2, 1)))
        params = create_mock_tokens(jnp.zeros((2, 1)))

        samples = tfmpe_identity.sample_posterior(
            context=context,
            params=params,
        )

        # Check labels preserved
        assert jnp.array_equal(
            samples.labels, params.labels
        )

        # Check self-attention mask preserved
        assert jnp.array_equal(
            samples.self_attention_mask,
            params.self_attention_mask
        )

        # Check slices preserved
        assert samples.slices == params.slices

    def test_sample_posterior_with_identity_flow(
        self, tfmpe_identity
    ):
        """Test sample_posterior reproducibility with the zero VF.

        With f(context, params, t) = 0 the flow does not transform
        the state, so sampling reduces to drawing from the base
        distribution. This test checks that such a draw is
        reproducible under a reseeded RNG.

        Given:
        - TFMPE with f(context, params, t) = 0
        - Params Token template

        When:
        - Call sample_posterior() twice, reseeding the RNG with the
          same seed before each call

        Then:
        - Samples are deterministic (same RNG seed gives same
          result)
        """
        context = create_mock_tokens(jnp.zeros((1, 1)))
        params = create_mock_tokens(jnp.zeros((1, 1)))

        # Sample twice with same RNG seed
        nnx.reseed(tfmpe_identity, params=42)
        samples1 = tfmpe_identity.sample_posterior(
            context=context,
            params=params,
        )
        params2 = create_mock_tokens(jnp.zeros((1, 1)))
        nnx.reseed(tfmpe_identity, params=42)
        samples2 = tfmpe_identity.sample_posterior(
            context=context,
            params=params2,
        )

        # Should be identical (deterministic)
        assert jnp.allclose(samples1.data, samples2.data)

    def test_determinism_across_calls(
        self, tfmpe_identity
    ):
        """Test that reseeding reproduces a whole sequence of calls.

        Unlike the single-call reproducibility test, this checks
        that the RNG state evolves identically across consecutive
        sample_posterior() calls after a reseed.

        Given:
        - TFMPE instance
        - Same RNG seed used for both reseeds
        - A fresh Params Token template for every call

        When:
        - Reseed, then call sample_posterior() twice in a row
        - Reseed with the same seed and repeat the two calls

        Then:
        - Each call of the second run matches the corresponding
          call of the first run
        """
        context = create_mock_tokens(jnp.zeros((1, 1)))

        def draw_sequence(n_calls=2):
            # Reseed once, then make several consecutive calls
            nnx.reseed(tfmpe_identity, params=42)
            return [
                tfmpe_identity.sample_posterior(
                    context=context,
                    params=create_mock_tokens(jnp.zeros((1, 1))),
                )
                for _ in range(n_calls)
            ]

        first_run = draw_sequence()
        second_run = draw_sequence()

        for i, (a, b) in enumerate(zip(first_run, second_run)):
            assert jnp.allclose(a.data, b.data), (
                f"call {i} not reproduced after reseeding"
            )

    def test_different_rng_seeds_give_different_samples(
        self, tfmpe_identity
    ):
        """Test that different RNG seeds give different samples.

        Given:
        - TFMPE instance
        - Different RNG seeds for resetting
        - Params Token templates

        When:
        - Call sample_posterior() with different seeds

        Then:
        - Samples differ
        """
        context = create_mock_tokens(jnp.zeros((1, 1)))
        params1 = create_mock_tokens(jnp.zeros((1, 1)))
        params2 = create_mock_tokens(jnp.zeros((1, 1)))

        nnx.reseed(tfmpe_identity, params=42)
        samples1 = tfmpe_identity.sample_posterior(
            context=context,
            params=params1,
        )
        nnx.reseed(tfmpe_identity, params=43)
        samples2 = tfmpe_identity.sample_posterior(
            context=context,
            params=params2,
        )

        # Should not be identical
        assert not jnp.allclose(samples1.data, samples2.data)

    @pytest.mark.parametrize(
        "n_tokens,batch_size",
        [
            (1, 1),
            (2, 1),
            (1, 3),
            (2, 2),
            (3, 5),
        ],
    )
    def test_sampling_various_token_shapes(
        self, tfmpe_identity, n_tokens, batch_size
    ):
        """Test sample_posterior with various token shapes.

        Given:
        - TFMPE instance
        - Different (n_tokens, batch_size) combinations
        - Params Token template

        When:
        - Call sample_posterior()

        Then:
        - Output shape matches params shape
        - All values finite
        """
        context = create_mock_tokens(
            jnp.zeros((n_tokens, batch_size))
        )
        params = create_mock_tokens(
            jnp.zeros((n_tokens, batch_size))
        )

        samples = tfmpe_identity.sample_posterior(
            context=context,
            params=params,
        )

        assert samples.data.shape == (n_tokens, batch_size)
        assert jnp.all(jnp.isfinite(samples.data))

class TestTFMPELogProb:
    """Tests for TFMPE.log_prob_posterior_samples."""

    @pytest.fixture
    def identity_vf(self):
        """Zero vector field: returns zeros shaped like the params data."""
        def vf(context, params, t):
            zero_velocity = jnp.zeros_like(params.data)
            return zero_velocity
        return vf

    @pytest.fixture
    def tfmpe_identity(self, identity_vf, solver):
        """Build a TFMPE around the zero vector field."""
        rngs = nnx.Rngs(params=jax.random.PRNGKey(0))
        tolerances = {'rtol': 1e-5, 'atol': 1e-5}
        return TFMPE(
            vf_network=identity_vf,
            base_dist=NormalDistribution(rngs=rngs),
            solver=solver,
            ode_kwargs=tolerances,
        )

    def test_log_prob_returns_scalar(
        self, tfmpe_identity
    ):
        """log_prob_posterior_samples reduces a single sample to a scalar.

        Given:
        - TFMPE instance
        - One posterior sample Token of zeros, data shape (1, 1)

        When:
        - log_prob_posterior_samples() is evaluated

        Then:
        - The result has shape ()
        - The result is finite
        """
        context = create_mock_tokens(jnp.zeros((1, 1)))
        theta = create_mock_tokens(jnp.zeros((1, 1)))

        result = tfmpe_identity.log_prob_posterior_samples(
            theta=theta, context=context
        )

        assert result.shape == ()
        assert jnp.isfinite(result)

    def test_log_prob_consistency_with_samples(
        self, tfmpe_identity
    ):
        """Scoring a freshly drawn posterior sample gives a sane value.

        Given:
        - TFMPE instance reseeded with a fixed seed
        - A sample obtained from sample_posterior()

        When:
        - The same sample is scored with log_prob_posterior_samples()

        Then:
        - The log prob is finite
        - The log prob is negative
        """
        context = create_mock_tokens(jnp.zeros((1, 1)))
        template = create_mock_tokens(jnp.zeros((1, 1)))

        # Draw one posterior sample under a fixed seed
        nnx.reseed(tfmpe_identity, params=42)
        drawn = tfmpe_identity.sample_posterior(
            context=context, params=template
        )

        # Score that very sample
        score = tfmpe_identity.log_prob_posterior_samples(
            theta=drawn, context=context
        )

        assert jnp.isfinite(score)
        # With the zero VF the flow is the identity; assuming the base
        # is a 1-D standard normal, its density peaks below 1, so the
        # log density is negative.
        assert score < 0

    @pytest.mark.parametrize(
        "n_tokens,batch_size",
        [(1, 1), (2, 1), (1, 3), (2, 2), (3, 5)],
    )
    def test_log_prob_various_token_shapes(
        self, tfmpe_identity, n_tokens, batch_size
    ):
        """log_prob_posterior_samples always yields a finite scalar.

        Given:
        - TFMPE instance
        - Random theta for each (n_tokens, batch_size) combination

        When:
        - log_prob_posterior_samples() is evaluated

        Then:
        - The result is a scalar regardless of input shape
        - The result is finite
        """
        shape = (n_tokens, batch_size)
        context = create_mock_tokens(jnp.zeros(shape))
        key = jax.random.PRNGKey(42)
        theta = create_mock_tokens(jax.random.normal(key, shape))

        result = tfmpe_identity.log_prob_posterior_samples(
            theta=theta, context=context
        )

        assert result.shape == ()
        assert jnp.isfinite(result)


class TestTFMPEInitialization:
    """Construction-time checks for TFMPE."""

    def test_tfmpe_initialization(self, solver):
        """TFMPE can be built from a VF, a base distribution and a solver.

        Given:
        - A zero vector field function
        - A NormalDistribution base module
        - An ODE solver

        When:
        - A TFMPE instance is constructed

        Then:
        - Construction succeeds
        - vf_network, base_dist and solver are populated (not None)
        """
        def zero_vf(context, params, time):
            return jnp.zeros_like(params.data)

        base_dist = NormalDistribution(
            rngs=nnx.Rngs(params=jax.random.PRNGKey(0))
        )
        tfmpe = TFMPE(
            vf_network=zero_vf,
            base_dist=base_dist,
            solver=solver,
            ode_kwargs={'rtol': 1e-5, 'atol': 1e-5},
        )

        for attr in ('vf_network', 'base_dist', 'solver'):
            assert getattr(tfmpe, attr) is not None, f"{attr} not set"

    def test_tfmpe_with_custom_ode_kwargs(self, solver):
        """User-supplied ODE kwargs are stored on the instance.

        Given:
        - Non-default rtol and atol values

        When:
        - A TFMPE instance is constructed with those kwargs

        Then:
        - tfmpe.ode_kwargs compares equal to what was passed in
        """
        def zero_vf(context, params, time):
            return jnp.zeros_like(params.data)

        base_dist = NormalDistribution(
            rngs=nnx.Rngs(params=jax.random.PRNGKey(0))
        )
        requested = {'rtol': 1e-3, 'atol': 1e-4}
        tfmpe = TFMPE(
            vf_network=zero_vf,
            base_dist=base_dist,
            solver=solver,
            ode_kwargs=requested,
        )

        assert tfmpe.ode_kwargs == requested


class TestPosteriorSamplingBenchmark:
    """Benchmark posterior sampling performance with realistic
    scales."""

    @pytest.fixture
    def tfmpe_with_transformer(self, solver):
        """TFMPE instance with Transformer vector field.

        Uses a lightweight transformer config suitable for
        benchmarking across varying token/sample sizes. The
        returned module is switched to eval mode.
        """
        from tfmpe.preprocessing.tokens import Tokens
        from tfmpe.nn.transformer import (
            Transformer, TransformerConfig
        )

        # Create minimal template tokens
        params_dict = {'x': jnp.ones((1, 1, 1)) * 0.5}
        params_tokens = Tokens.from_pytree(
            params_dict, sample_ndims=1
        )

        # Create lightweight transformer
        config = TransformerConfig(
            latent_dim=16,
            n_encoder=1,
            n_decoder=1,
            n_heads=1,
            n_ff=2,
        )
        rngs = nnx.Rngs(
            params=jax.random.PRNGKey(0),
            dropout=jax.random.PRNGKey(1),
        )
        transformer = Transformer(
            config=config, tokens=params_tokens, rngs=rngs
        )
        base_dist = NormalDistribution(rngs=rngs)
        tfmpe = TFMPE(
            vf_network=transformer,
            base_dist=base_dist,
            solver=solver,
        )
        tfmpe.eval()

        return tfmpe

    @pytest.mark.slow
    @pytest.mark.parametrize("n_tokens", [1, 10, 20, 50])
    @pytest.mark.parametrize("sample_size", [1, 100, 1000])
    def test_sample_posterior_benchmark(
        self, tfmpe_with_transformer, n_tokens, sample_size,
        benchmark
    ):
        """Benchmark sample_posterior across token/sample sizes.

        Measures wall-clock time for sampling from posterior with
        varying token lengths and sample sizes to understand scaling
        dynamics.

        Parameters
        ----------
        tfmpe_with_transformer : TFMPE
            Transformer-backed TFMPE instance (eval mode)
        n_tokens : int
            Number of tokens in the sequence
        sample_size : int
            Leading (sample) dimension of the context and params
            tokens
        benchmark : pytest_benchmark.fixture.BenchmarkFixture
            Benchmark fixture for timing
        """
        # Create tokens of specified shape
        context = create_mock_tokens(
            jnp.zeros((sample_size, n_tokens, 1))
        )
        params = create_mock_tokens(
            jnp.zeros((sample_size, n_tokens, 1))
        )

        def sample():
            result = tfmpe_with_transformer.sample_posterior(
                context=context,
                params=params,
            )
            # JAX dispatches work asynchronously; without blocking, the
            # benchmark would time only the dispatch, not the compute.
            jax.block_until_ready(result.data)
            return result

        # Run once outside the timed region so any one-off tracing /
        # compilation cost is not attributed to the first timed round.
        sample()

        # Benchmark the sampling operation
        benchmark(sample)
