"""Shared fixtures for estimator tests."""

import diffrax
import jax.numpy as jnp
import pytest

from tfmpe.preprocessing.tokens import Tokens
from tfmpe.preprocessing.utils import SliceInfo, Labeller


def create_mock_tokens(data: jnp.ndarray) -> Tokens:
    """Create a mock Tokens object for testing.

    The returned object has a single slice ``'x'`` starting at offset 0,
    every token labelled ``0`` (mapped from ``'x'`` by the labeller), an
    all-ones ``(n_tokens, n_tokens)`` self-attention mask, and no padding
    mask, functional inputs or independence structure.

    Parameters
    ----------
    data : Array
        Data array for the tokens, shape (n_tokens, batch) or
        (sample_size, n_tokens, batch)

    Returns
    -------
    Tokens
        Mock Tokens object with minimal setup

    Raises
    ------
    ValueError
        If ``data`` is not 2- or 3-dimensional.
    """
    # Dispatch explicitly on rank: previously any non-2D array was read as
    # 3D, so 1D input raised a cryptic IndexError and 4D+ input produced
    # silently wrong n_tokens / batch_size metadata.
    if data.ndim == 2:
        # Shape is (n_tokens, batch)
        n_tokens, batch_size = data.shape
    elif data.ndim == 3:
        # Shape is (sample_size, n_tokens, batch)
        _, n_tokens, batch_size = data.shape
    else:
        raise ValueError(
            "data must have shape (n_tokens, batch) or "
            f"(sample_size, n_tokens, batch); got shape {data.shape}"
        )
    return Tokens(
        data=data,
        labels=jnp.zeros(n_tokens, dtype=jnp.int32),
        self_attention_mask=jnp.ones((n_tokens, n_tokens)),
        padding_mask=None,
        functional_inputs=None,
        slices={'x': SliceInfo(
            offset=0,
            event_shape=(),
            batch_shape=(batch_size,)
        )},
        labeller=Labeller(label_map={'x': 0}),
        independence=None,
    )


@pytest.fixture
def doubling_vf():
    """Provide a test vector field with signature ``f(context, params, t)``.

    The returned callable ignores both ``context`` and ``t`` and returns
    ``log(2) * params.data``. Because the field is linear in the
    parameters, integrating it over a time interval of length 1 multiplies
    the initial data by ``exp(log(2)) = 2``, hence the "doubling" name.
    """

    def _scale_by_log_two(context: Tokens, params: Tokens, t):
        # The field depends only on the current parameter values.
        del context, t
        rate = jnp.log(2.0)
        return rate * params.data

    vf = _scale_by_log_two
    return vf


@pytest.fixture
def solver():
    """Provide a fresh ``diffrax.Dopri5`` ODE solver for each test."""
    ode_solver = diffrax.Dopri5()
    return ode_solver
