import pytest

from data_loading.multidimensional_data_generation import generate_multidimensional_data


@pytest.mark.parametrize(
    "num_samples, archetypes, dimensions",
    [
        (100, 10, 3),
        (50, 5, 2),
        (200, 20, 4),
    ],
)
def test_generate_multidimensional_data(num_samples, archetypes, dimensions):
    """Check that seeding controls the output of the data generator.

    The generator is called three times with the same size arguments: twice
    with seed 42 and once with seed 43. The two seed-42 results must match
    element for element, and the seed-43 result must differ from them in at
    least one element.

    Only the first item of each returned pair is inspected; the second item
    is discarded.
    """
    # Same positional arguments for every call; only the seed varies.
    size_args = (num_samples, archetypes, dimensions)

    baseline, _ = generate_multidimensional_data(*size_args, seed=42)
    repeated, _ = generate_multidimensional_data(*size_args, seed=42)
    reseeded, _ = generate_multidimensional_data(*size_args, seed=43)

    # NOTE(review): the comparisons below presumably yield element-wise
    # boolean arrays (they must expose .all()/.any()) — confirm against the
    # generator's return type.
    same_seed_matches = baseline == repeated
    assert same_seed_matches.all(), "Data with the same seed should be equal"

    other_seed_differs = baseline != reseeded
    assert other_seed_differs.any(), "Data with different seeds should not be equal"
