import pytest
import numpy as np
from aiau.data.index_initialisation_utils import (
    lower_left_corner_indices_2d,
    random_indices
)


@pytest.fixture
def grid_data():
    """Return the nine (x, y) points of a 3x3 grid as a (9, 2) float array.

    Rows are laid out in groups of three: x runs 0.0 -> 2.0 inside each
    group, while y steps down 2.0 -> 0.0 from one group to the next.
    """
    x_values = (0.0, 1.0, 2.0)
    y_values = (2.0, 1.0, 0.0)  # top row of the grid comes first
    return np.array([[x, y] for y in y_values for x in x_values])


def test_lower_left_corner_selection_subset(grid_data):
    """Indices returned for a restricted corner map to points inside it."""
    lo, hi = 0.0, 1.5
    wanted = 3
    picked = lower_left_corner_indices_2d(
        grid_data, num_samples=wanted, lower_bound=lo, upper_bound=hi
    )
    assert len(picked) == wanted
    for point in (grid_data[i] for i in picked):
        px, py = point
        assert lo <= px <= hi
        # Only the upper limit is checked on y; the grid has no y below 0.0.
        assert py <= hi


def test_lower_left_corner_selection_all_points(grid_data):
    """With bounds spanning the whole grid, five samples can be drawn.

    Note that only five of the nine grid points are requested here.
    """
    lo, hi = 0.0, 2.0  # covers the full extent of the fixture grid
    wanted = 5
    picked = lower_left_corner_indices_2d(
        grid_data, num_samples=wanted, lower_bound=lo, upper_bound=hi
    )
    assert len(picked) == wanted
    # Each selected point must still satisfy the (full-range) bounds.
    for point in (grid_data[i] for i in picked):
        px, py = point
        assert lo <= px <= hi
        assert py <= hi


def test_lower_left_corner_too_many_requested(grid_data):
    """Asking for more samples than the corner can supply raises ValueError."""
    # 10 exceeds the 9 points in the whole grid, so it necessarily exceeds
    # however many points the upper_bound=1.5 corner contains.
    too_many = 10
    with pytest.raises(ValueError):
        lower_left_corner_indices_2d(
            grid_data, num_samples=too_many, upper_bound=1.5
        )


def test_random_indices_returns_correct_number(grid_data):
    """The requested number of indices comes back, each a valid row index."""
    wanted = 4
    picked = random_indices(grid_data, num_samples=wanted)
    assert len(picked) == wanted
    row_count = len(grid_data)
    for idx in picked:
        assert 0 <= idx < row_count


def test_random_indices_no_duplicates_when_replace_false(grid_data):
    """Drawing as many samples as there are rows yields every index once.

    The call relies on the default arguments of ``random_indices``; no
    explicit replacement flag is passed.
    """
    row_count = len(grid_data)  # the fixture grid has 9 rows
    picked = random_indices(grid_data, num_samples=row_count)
    distinct = set(picked)
    assert len(picked) == len(distinct)  # no index appears twice
    # Sorted, the result must be exactly 0..row_count-1, i.e. a permutation.
    assert sorted(picked) == list(range(row_count))


def test_random_indices_too_many_requested(grid_data):
    """Requesting more samples than there are rows raises ValueError."""
    too_many = 20  # the fixture grid only has 9 rows
    with pytest.raises(ValueError):
        random_indices(grid_data, num_samples=too_many)