from collections.abc import Sequence
from itertools import combinations

from flax import nnx
from jax import Array, numpy as jnp
import pytest

from offline.modules.mlp import MLP, MLPEnsemble


def check_unique(sequence: Sequence[Array]) -> bool:
    """Return True when every pair of arrays in *sequence* is element-wise close.

    NOTE(review): despite its name, this reports whether the arrays are all
    (approximately) *identical*, not whether they are unique. The tests in
    this module use ``not check_unique(outputs)`` to assert that at least one
    pair of outputs differs. A sequence with fewer than two arrays has no
    pairs and therefore yields True.

    Args:
        sequence: Arrays to compare pairwise with ``jnp.allclose``.

    Returns:
        False as soon as one pair is found that is not ``allclose``;
        True otherwise.
    """
    for first, second in combinations(sequence, 2):
        if not jnp.allclose(first, second):
            return False
    return True


def test_mlp():
    """An MLP built with ``dropout=0.5`` keeps its output shape across calls,
    and repeated calls on the same input are expected to differ.

    Ten forward passes over a constant (3, 5) input must each produce a
    (3, 4) output, and the outputs must not all be ``allclose`` to each other.
    """
    batch_size, in_dim, out_dim = 3, 5, 4
    model = MLP(in_features=in_dim, out_features=out_dim, rngs=nnx.Rngs(0), dropout=0.5)
    batch = jnp.ones((batch_size, in_dim))
    results = [model(batch) for _ in range(10)]

    expected_shape = (batch_size, out_dim)
    for result in results:
        assert result.shape == expected_shape
    # check_unique is True only when *all* outputs are allclose, so negating
    # it asserts that at least one pair of forward passes differs.
    assert not check_unique(results)


@pytest.mark.parametrize("out_axis", range(-3, 3))
def test_mlp_ensemble(out_axis: int):
    """MLPEnsemble places its ensemble axis at ``out_axis`` in a rank-3 output.

    For each parametrised axis (every valid negative and non-negative position
    in a rank-3 result), ten forward passes over a constant (3, 5) input must
    each have the (3, 4) per-member shape with an axis of size 6 (the ensemble
    size) inserted at ``out_axis``. The outputs must also not all be
    ``allclose`` to each other.
    """
    batch_size, in_dim, out_dim, members = 3, 5, 4, 6
    ensemble = MLPEnsemble(
        dropout=0.5,
        ensemble_size=members,
        in_features=in_dim,
        out_axis=out_axis,
        out_features=out_dim,
        rngs=nnx.Rngs(0),
    )
    batch = jnp.ones((batch_size, in_dim))
    results = [ensemble(batch) for _ in range(10)]

    # Map a possibly negative axis onto its non-negative position in the
    # rank-3 output, then splice the ensemble dimension in at that position,
    # e.g. -1 -> 2 gives (3, 4, 6) and -3 -> 0 gives (6, 3, 4).
    member_shape = (batch_size, out_dim)
    position = out_axis % (len(member_shape) + 1)
    expected_shape = member_shape[:position] + (members,) + member_shape[position:]

    for result in results:
        assert result.shape == expected_shape
    # check_unique is True only when *all* outputs are allclose, so negating
    # it asserts that at least one pair of forward passes differs.
    assert not check_unique(results)
