from collections import Counter

import numpy as np
import pytest

from llm_inference.utils import (
    create_batches_by_seq_len,
)


def compare_batches(result, expected):
    """Compare batches in an order-invariant manner.

    Two batch lists match when they contain the same batches the same number
    of times, with each batch treated as a multiset of its elements. Neither
    the order of the batches nor the order of elements inside a batch
    matters, but duplicates do: ``[[10, 10]]`` does not match ``[[10]]``, and
    ``[[5], [5], [6]]`` does not match ``[[5], [6], [6]]``.

    Args:
        result: Sized sequence of batches; each batch is an iterable of
            mutually comparable, hashable items (e.g. a list of ints or a
            1-D numpy array).
        expected: Batches in the same form as ``result``.

    Returns:
        True if both describe the same multiset of batches, otherwise False.
    """
    if len(result) != len(expected):
        return False

    # A sorted tuple is a hashable, order-independent key that, unlike a
    # frozenset, keeps repeated elements. Counting those keys (rather than
    # putting them in a set) also keeps repeated batches.
    def _canonical(batches):
        return Counter(tuple(sorted(batch)) for batch in batches)

    return _canonical(result) == _canonical(expected)


@pytest.mark.parametrize(
    "batch_sizes, sequence_length_thresholds, input_lengths, expected",
    [
        # ``expected`` lists each batch as the input lengths it contains;
        # batch order and within-batch order are ignored by compare_batches.
        (
            [3, 2, 1],
            [10, 20, 30],
            [5, 15, 25, 8, 18, 28],
            [[25], [28], [15, 18], [5, 8]],
        ),
        ([4, 2], [20, 30], [5, 10, 15, 18], [[5, 10, 15, 18]]),
        ([4, 2, 1], [10, 20, 30], [25, 28, 22, 26], [[25], [28], [22], [26]]),
        ([2, 1], [10, 20], [15], [[15]]),
        (
            [4, 2, 1],
            [10, 20, 30],
            [5, 8, 12, 15, 18, 22, 25, 28],
            [[5], [18, 15], [12, 8], [22], [25], [28]],
        ),
        ([3, 2], [10, 20], [10, 10, 20, 20, 20], [[20, 20], [20, 10], [10]]),
        ([3, 2], [10, 20], [5, 6, 7, 8, 9], [[9, 8, 7], [6, 5]]),
        ([4, 2], [10, 20], [5, 8, 12, 15, 18], [[18, 15], [12, 8], [5]]),
    ],
)
def test_create_batches_by_seq_len(
    batch_sizes, sequence_length_thresholds, input_lengths, expected
):
    """Check that create_batches_by_seq_len produces the expected batches.

    Each element of the returned value is used as an index into
    ``input_lengths``; the resulting groups of lengths are compared with
    ``expected`` without regard to order.
    """
    input_lengths = np.array(input_lengths)
    indices = create_batches_by_seq_len(
        batch_sizes, sequence_length_thresholds, input_lengths
    )
    # Convert each batch to a list of Python ints so a failure message reads
    # like ``[[25], [28], ...]`` instead of ``[array([25]), ...]``.
    result = [input_lengths[i].tolist() for i in indices]
    assert len(result) == len(
        expected
    ), f"Expected {len(expected)} batches, but got {len(result)}"
    assert compare_batches(
        result, expected
    ), f"Expected batches {expected}, but got {result}"


@pytest.mark.parametrize(
    "batch_sizes, sequence_length_thresholds, input_lengths",
    [
        # 12 and 15 both exceed the only (hence largest) threshold of 10
        ([2], [10], [5, 8, 12, 15]),
        # every length is above the largest threshold of 30
        ([3, 2, 1], [10, 20, 30], [35, 40, 45, 50]),
    ],
)
def test_create_batches_by_seq_len_error_cases(
    batch_sizes, sequence_length_thresholds, input_lengths
):
    """Lengths beyond the largest threshold must make the batcher raise ValueError."""
    lengths_array = np.asarray(input_lengths)
    call_args = (batch_sizes, sequence_length_thresholds, lengths_array)
    with pytest.raises(ValueError):
        create_batches_by_seq_len(*call_args)
