import pytest

from offline_rl.data.index_utils import cumulative_index_to_multi_index


# Each case is ((index, cumulative_lengths), (segment, offset)).
# Judging by the cases themselves, ``cumulative_lengths`` lists the running
# end offset of each segment, and the expected pair is the segment the flat
# index falls into plus its position inside that segment, e.g. index 5 with
# boundaries [1, 3, 6] lands in segment 2 (covering 3..5) at offset 2.
# yapf: disable
_MULTI_INDEX_CASES = [
    ((0, [1,]), (0, 0)),
    ((1, [1, 3]), (1, 0)),
    ((2, [1, 3]), (1, 1)),
    ((3, [1, 3, 6]), (2, 0)),
    ((5, [1, 3, 6]), (2, 2)),
    ((5, [4, 5, 6]), (2, 0)),
    ((3, [4, 5, 6]), (0, 3)),
]
# yapf: enable


@pytest.mark.parametrize("inputs,expected", _MULTI_INDEX_CASES)
def test_cumulative_index_to_multi_index(inputs, expected):
    """Check that a flat index maps to the expected (segment, offset) pair."""
    flat_index, boundaries = inputs
    actual = cumulative_index_to_multi_index(flat_index, boundaries)
    assert actual == expected


def test_cumulative_index_to_multi_index_edge_cases():
    """Degenerate inputs must be rejected with an ``AssertionError``.

    ``pytest.raises`` is used instead of ``try/except``: with a bare
    ``try/except AssertionError: assert True`` the test passes even when no
    exception is raised at all, so it could never detect missing validation.
    """
    # No segments at all: there is nothing an index could map into.
    with pytest.raises(AssertionError):
        cumulative_index_to_multi_index(0, [])

    # The only cumulative boundary is 0, so index 10 is out of range.
    with pytest.raises(AssertionError):
        cumulative_index_to_multi_index(10, [0])
