import pytest
import numpy as np
from aiau.strategy.selection_utils import select_top_k_indices, select_indices_robust


class MockDataManager:
    """Minimal stand-in for a DataManager used by the selection-utility tests.

    Only the ``labelled_indices`` attribute is provided. The iterable passed
    to the constructor is copied into a new ``set``, so later changes to the
    caller's collection do not affect the mock.
    """

    def __init__(self, labelled_indices):
        already_labelled = set()
        already_labelled.update(labelled_indices)
        self.labelled_indices = already_labelled


class TestSelectionUtils:
    """Test cases for the robust selection utility functions.

    Exercises ``select_top_k_indices`` (ranking by score, exclusion of
    labelled indices when ``requery=False``, input validation, edge cases)
    and ``select_indices_robust`` (called with a mock data manager and a
    ``batch_strategy``).
    """

    def test_select_top_k_indices_with_requery(self):
        """Test normal case with requery=True."""
        scores = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
        indices = select_top_k_indices(scores, num_suggestions=2, requery=True)
        expected = np.array([3, 1])  # Highest scores first
        np.testing.assert_array_equal(indices, expected)

    def test_select_top_k_indices_without_requery(self):
        """Test normal case with requery=False and some labelled indices."""
        scores = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
        labelled_indices = {1, 3}  # Exclude indices 1 and 3
        indices = select_top_k_indices(
            scores,
            num_suggestions=2,
            requery=False,
            labelled_indices=labelled_indices,
        )
        expected = np.array([2, 4])  # Next highest scores after excluding 1 and 3
        np.testing.assert_array_equal(indices, expected)

    def test_select_top_k_indices_insufficient_unlabelled(self):
        """Test edge case - requesting more suggestions than available unlabelled."""
        scores = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
        labelled_indices = {0, 1, 2, 3}  # Only index 4 is unlabelled
        indices = select_top_k_indices(
            scores,
            num_suggestions=3,
            requery=False,
            labelled_indices=labelled_indices,
        )
        expected = np.array([4])  # Only one unlabelled index available
        np.testing.assert_array_equal(indices, expected)

    def test_select_top_k_indices_all_labelled(self):
        """Test edge case - all indices are already labelled."""
        scores = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
        labelled_indices = {0, 1, 2, 3, 4}  # All indices are labelled
        indices = select_top_k_indices(
            scores,
            num_suggestions=2,
            requery=False,
            labelled_indices=labelled_indices,
        )
        # Assert emptiness directly: an element-wise comparison against an
        # empty array has no elements to compare.
        assert len(indices) == 0
        expected = np.array([])  # No unlabelled indices available
        np.testing.assert_array_equal(indices, expected)

    def test_select_top_k_indices_invalid_inputs(self):
        """Test validation of input parameters."""
        scores = np.array([0.1, 0.8, 0.3])

        # Zero and negative suggestion counts must both be rejected.
        for bad_num_suggestions in (0, -1):
            with pytest.raises(ValueError, match="num_suggestions must be positive"):
                select_top_k_indices(scores, num_suggestions=bad_num_suggestions)

        # Test invalid scores shape
        scores_2d = np.array([[0.1, 0.8], [0.3, 0.9]])
        with pytest.raises(ValueError, match="Scores must be 1-dimensional"):
            select_top_k_indices(scores_2d, num_suggestions=1)

    def test_select_indices_robust_top_k_integration(self):
        """Test integration with data manager using top-k strategy."""
        dm = MockDataManager(labelled_indices=[1, 3])
        scores = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
        indices = select_indices_robust(
            dm,
            scores,
            num_suggestions=2,
            requery=False,
            batch_strategy="top-k",
        )
        # Indices 1 and 3 are labelled on the mock, so the next best remain.
        expected = np.array([2, 4])
        np.testing.assert_array_equal(indices, expected)

    def test_select_indices_robust_invalid_batch_strategy(self):
        """Test validation of batch strategy parameter."""
        dm = MockDataManager(labelled_indices=[])
        scores = np.array([0.1, 0.8, 0.3])

        with pytest.raises(ValueError, match="Unsupported batch_strategy"):
            select_indices_robust(
                dm,
                scores,
                num_suggestions=1,
                requery=True,
                batch_strategy="invalid_strategy",
            )

    def test_select_indices_robust_eigen_decomposition_validation(self):
        """Test validation for eigen-decomposition strategy."""
        dm = MockDataManager(labelled_indices=[])
        scores_1d = np.array([0.1, 0.8, 0.3])  # Should be 2D for eigen

        with pytest.raises(ValueError, match="Eigen-decomposition requires 2D scores matrix"):
            select_indices_robust(
                dm,
                scores_1d,
                num_suggestions=1,
                requery=True,
                batch_strategy="eigen-decomposition",
            )

    def test_edge_case_single_point(self):
        """Test behavior with a single data point."""
        scores = np.array([0.5])

        # Should work with requery
        indices = select_top_k_indices(scores, num_suggestions=1, requery=True)
        expected = np.array([0])
        np.testing.assert_array_equal(indices, expected)

        # Should work when unlabelled
        indices = select_top_k_indices(
            scores, num_suggestions=1, requery=False, labelled_indices=set()
        )
        expected = np.array([0])
        np.testing.assert_array_equal(indices, expected)

        # Should return empty when labelled
        indices = select_top_k_indices(
            scores, num_suggestions=1, requery=False, labelled_indices={0}
        )
        # Assert emptiness directly rather than only via array comparison.
        assert len(indices) == 0
        expected = np.array([])
        np.testing.assert_array_equal(indices, expected)

    def test_large_num_suggestions(self):
        """Test behavior when requesting more suggestions than total points."""
        scores = np.array([0.1, 0.8, 0.3])

        # With requery, the result is capped at the number of available
        # points: all three are returned even though five were requested.
        indices = select_top_k_indices(scores, num_suggestions=5, requery=True)
        expected = np.array([1, 2, 0])  # All points in descending score order
        np.testing.assert_array_equal(indices, expected)

    def test_score_ordering_consistency(self):
        """Test that indices are consistently ordered by descending scores."""
        num_suggestions = 3

        # Test with various score patterns
        test_cases = [
            np.array([0.1, 0.2, 0.3, 0.4, 0.5]),  # Ascending
            np.array([0.5, 0.4, 0.3, 0.2, 0.1]),  # Descending
            np.array([0.3, 0.1, 0.5, 0.2, 0.4]),  # Random
            np.array([0.2, 0.2, 0.3, 0.2, 0.1]),  # With duplicates
        ]

        for scores in test_cases:
            indices = select_top_k_indices(
                scores, num_suggestions=num_suggestions, requery=True
            )

            # Guard against a vacuous pass: the ordering check below holds
            # trivially for an empty or single-element result.
            assert len(indices) == num_suggestions, \
                f"Expected {num_suggestions} indices, got {len(indices)} for input {scores}"

            # Verify returned indices are in descending score order
            selected_scores = scores[indices]
            assert np.all(selected_scores[:-1] >= selected_scores[1:]), \
                f"Scores not in descending order: {selected_scores} for input {scores}"

            # The selected scores must be the highest ones. Comparing score
            # values (not indices) keeps this independent of tie-breaking.
            top_scores = np.sort(scores)[::-1][:num_suggestions]
            np.testing.assert_array_equal(selected_scores, top_scores)
