import numpy as np
from aiau.strategy.abstract_strategy import AbstractStrategy

class BatchBALDStrategy(AbstractStrategy):
    """
    Greedy, BatchBALD-inspired batch acquisition strategy for regression.

    Each candidate is scored by the Gaussian entropy of the ensemble's
    predictive variance. After the first pick, a diversity bonus is added:
    the smallest absolute difference between the candidate's mean prediction
    and the mean predictions of the points already chosen for the batch.
    This is a heuristic stand-in for the joint mutual information used by
    BatchBALD, not an exact computation of it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.name = "BatchBALD"

    def select_next_indices(self, data_manager, model, num_suggestions, requery=True, batch_strategy="top-k"):
        """
        Selects a batch of indices to label using a greedy BatchBALD approximation.

        Args:
            data_manager (DataManager): Provides ``full_X`` (the candidate pool) and,
                when ``requery`` is False, ``labelled_indices`` (indices to exclude).
            model (Model): Its ``predict(full_X)`` output is reduced over axis 0;
                assumed to have shape (n_models, n_points) -- TODO confirm against the model API.
            num_suggestions (int): The number of indices to suggest for querying.
            requery (bool): Whether to allow requerying of already labelled observations.
            batch_strategy (str): Accepted for signature compatibility; this
                implementation does not read it.

        Returns:
            np.ndarray: Integer array of the selected indices, in selection order.
                Its length is at most ``num_suggestions`` and never exceeds the
                number of available candidates; it is empty when there are no
                candidates or ``num_suggestions`` is not positive.
        """
        ensemble_predictions = model.predict(data_manager.full_X)
        pred_means = np.mean(ensemble_predictions, axis=0)
        var_pred = np.var(ensemble_predictions, axis=0)

        # Differential entropy of a Gaussian with the ensemble variance; the
        # epsilon keeps the log finite when all ensemble members agree.
        epsilon = 1e-12
        entropy = 0.5 * np.log(2 * np.pi * np.e * var_pred + epsilon)

        available_indices = np.arange(len(data_manager.full_X))
        if not requery:
            available_indices = np.setdiff1d(available_indices, data_manager.labelled_indices)

        # Never ask for more points than the pool holds. This also covers the
        # empty-pool case, where np.argmax would otherwise raise ValueError.
        budget = min(num_suggestions, len(available_indices))
        if budget <= 0:
            return np.array([], dtype=int)

        selected = []
        # Running minimum, for every point, of |mean - mean of a selected point|.
        # None until the first pick, because the first pick is by entropy alone.
        min_gap = None
        for _ in range(budget):
            score = entropy[available_indices]
            if min_gap is not None:
                # Combine entropy and diversity (unweighted sum).
                score = score + min_gap[available_indices]
            idx = available_indices[np.argmax(score)]
            selected.append(idx)

            # Fold the new pick into the running minimum instead of recomputing
            # the distance to every selected point at each step.
            gap_to_new = np.abs(pred_means - pred_means[idx])
            min_gap = gap_to_new if min_gap is None else np.minimum(min_gap, gap_to_new)

            available_indices = available_indices[available_indices != idx]
        return np.array(selected, dtype=int)
