import numpy as np
import logging
from .eigen_decomposition_batching import eigen_decomposition_batching

logger = logging.getLogger(__name__)

def select_top_k_indices(scores, num_suggestions, requery=True, labelled_indices=None):
    """
    Select the indices of the highest-scoring observations.

    Args:
        scores (array-like): One-dimensional scores, one per observation.
        num_suggestions (int): The number of indices to suggest for querying.
        requery (bool): Whether already labelled observations may be suggested
            again. When True, ``labelled_indices`` is ignored.
        labelled_indices (iterable of int or None): Indices that are already
            labelled. Only consulted when ``requery`` is False.

    Returns:
        np.ndarray: Integer indices ordered from highest to lowest score. The
        result holds fewer than ``num_suggestions`` entries when not enough
        candidates exist; in the non-requery case a warning is logged and all
        remaining unlabelled indices are returned.

    Raises:
        ValueError: If ``scores`` is not 1-dimensional or ``num_suggestions``
            is not positive.
    """
    # Accept lists/tuples as well as arrays.
    scores = np.asarray(scores)
    if scores.ndim != 1:
        raise ValueError(f"Scores must be 1-dimensional, got shape {scores.shape}")

    if num_suggestions <= 0:
        raise ValueError(f"num_suggestions must be positive, got {num_suggestions}")

    # Ascending argsort reversed -> highest scores first.
    # NOTE: argsort places NaN last, so NaN scores end up ranked first here.
    ranked_indices = np.argsort(scores)[::-1]

    if requery or labelled_indices is None:
        # Simple case: no labelled-index constraint applies.
        return ranked_indices[:num_suggestions]

    # Filter with a boolean mask instead of rebuilding the array from a Python
    # list: the mask keeps the integer index dtype even when nothing is left,
    # whereas np.array([]) would produce an empty float64 array that cannot be
    # used for indexing.
    labelled = np.asarray(list(labelled_indices))
    unlabelled_indices = ranked_indices[~np.isin(ranked_indices, labelled)]

    if len(unlabelled_indices) < num_suggestions:
        logger.warning(
            "Only %s unlabelled indices available, "
            "but %s suggestions requested. Returning all available.",
            len(unlabelled_indices),
            num_suggestions,
        )
        return unlabelled_indices

    return unlabelled_indices[:num_suggestions]


def select_indices_robust(data_manager, scores, num_suggestions, requery=True, batch_strategy="top-k"):
    """
    Dispatch index selection to the requested batch strategy.

    Args:
        data_manager: Object exposing a ``labelled_indices`` attribute; it is
            read for the "top-k" strategy when ``requery`` is False and is
            passed through to the eigen-decomposition helper.
        scores (np.ndarray): The scores for each observation (1D for top-k,
            2D for eigen-decomposition).
        num_suggestions (int): The number of indices to suggest for querying.
        requery (bool): Whether to allow requerying of the same observation.
        batch_strategy (str): Either "top-k" or "eigen-decomposition".

    Returns:
        Whatever the chosen strategy returns: the result of
        ``select_top_k_indices`` or of ``eigen_decomposition_batching``.

    Raises:
        ValueError: If ``num_suggestions`` is not positive, if the
            eigen-decomposition strategy receives non-2D scores, or if
            ``batch_strategy`` is not recognised.
    """
    if num_suggestions <= 0:
        raise ValueError(f"num_suggestions must be positive, got {num_suggestions}")

    if batch_strategy == "eigen-decomposition":
        # This strategy operates on a 2D score matrix only.
        if scores.ndim != 2:
            raise ValueError(f"Eigen-decomposition requires 2D scores matrix, got shape {scores.shape}")
        return eigen_decomposition_batching(scores, num_suggestions, requery, data_manager)

    if batch_strategy != "top-k":
        raise ValueError(f"Unsupported batch_strategy: {batch_strategy}")

    # Only touch the data manager when labelled observations must be excluded.
    if requery:
        already_labelled = None
    else:
        already_labelled = data_manager.labelled_indices

    return select_top_k_indices(
        scores=scores,
        num_suggestions=num_suggestions,
        requery=requery,
        labelled_indices=already_labelled,
    )


# Legacy function for backward compatibility
def selection(data_manager, scores, num_suggestions, requery, batch_strategy):
    """
    Deprecated entry point kept for backward compatibility.

    Logs a deprecation warning and forwards every argument unchanged to
    ``select_indices_robust``, returning its result.
    """
    logger.warning("selection() is deprecated, use select_indices_robust() instead")
    chosen = select_indices_robust(
        data_manager=data_manager,
        scores=scores,
        num_suggestions=num_suggestions,
        requery=requery,
        batch_strategy=batch_strategy,
    )
    return chosen