import numpy as np
from aiau.strategy.abstract_strategy import AbstractStrategy
from aiau.strategy.bias_estimation import estimate_bias_squared
from aiau.strategy.selection_utils import select_indices_robust
from aiau.oracles.oracle import Oracle

class EMSEBiasStrategy(AbstractStrategy):
    """
    EMSE-Bias strategy for active learning.

    Each candidate observation gets a score made of two terms:
    - the spread (variance) of ``model.predict`` over ``data_manager.full_X``;
    - an estimated bias term from ``estimate_bias_squared``.

    The scores are passed to ``select_indices_robust``, which chooses the next
    indices to label.
    """

    # Estimation approaches the "eigen-decomposition" batch strategy knows how
    # to combine into per-observation scores.
    _EIGEN_APPROACHES = ("cheat", "direct", "quadratic")

    def __init__(self, estimation_approach: str, oracle_type: str, oracle: Oracle) -> None:
        """
        Args:
            estimation_approach: Bias-estimation method name forwarded to
                ``estimate_bias_squared`` (e.g. "cheat", "direct", "quadratic").
            oracle_type: Oracle type identifier forwarded to
                ``estimate_bias_squared``.
            oracle: Oracle instance forwarded to ``estimate_bias_squared``.
        """
        super().__init__()
        self.name = "EMSEBias"
        self.estimation_approach = estimation_approach
        self.oracle_type = oracle_type
        self.oracle = oracle

    def select_next_indices(self, data_manager, model, num_suggestions, requery=True, batch_strategy="top-k"):
        """
        Select the next indices to label from a variance-plus-bias score.

        Args:
            data_manager: Supplies ``full_X`` (the inputs that are scored). It is
                also passed through to ``estimate_bias_squared`` and
                ``select_indices_robust``.
            model: Object whose ``predict`` is applied to ``data_manager.full_X``.
                The variance along axis 0 of that output is the variance term.
                Presumably axis 0 indexes ensemble members / samples and axis 1
                indexes observations — TODO confirm.
            num_suggestions: Number of indices requested from
                ``select_indices_robust``.
            requery: Forwarded to ``select_indices_robust``.
            batch_strategy: Either "top-k" or "eigen-decomposition".

        Returns:
            Whatever ``select_indices_robust`` returns for the computed scores.

        Raises:
            ValueError: If ``batch_strategy`` is unsupported. Also raised when
                ``batch_strategy`` is "eigen-decomposition" and
                ``self.estimation_approach`` is not "cheat", "direct" or
                "quadratic".
        """
        # Validate before doing any (potentially expensive) prediction work.
        if batch_strategy not in ("top-k", "eigen-decomposition"):
            raise ValueError(f"Batch strategy {batch_strategy} currently not supported for EMSEBiasStrategy.")
        if batch_strategy == "eigen-decomposition" and self.estimation_approach not in self._EIGEN_APPROACHES:
            # Previously this case fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"Estimation approach {self.estimation_approach} currently not supported "
                f"for batch strategy {batch_strategy} in EMSEBiasStrategy."
            )

        model_prediction = model.predict(data_manager.full_X)

        if batch_strategy == "top-k":
            # Population variance (ddof=0) across axis 0.
            variance = np.var(model_prediction, axis=0)
            bias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle)
            pemse = variance + bias
            return select_indices_robust(data_manager, pemse, num_suggestions, requery, batch_strategy)

        # "eigen-decomposition": the scores use only the diagonals of the
        # prediction covariance and of the bias co-matrix. The diagonals are
        # therefore computed directly, without materialising the O(n^2) matrices.
        #
        # diag(np.cov(model_prediction.T)) == var(axis=0, ddof=1), because
        # np.cov normalises by N-1 by default.
        # NOTE(review): this differs from the ddof=0 variance used by "top-k";
        # the discrepancy is preserved from the original — confirm it is intended.
        variance = np.var(model_prediction, axis=0, ddof=1)

        if self.estimation_approach == "quadratic":
            _, cobias = estimate_bias_squared(
                data_manager, model, self.estimation_approach, self.oracle_type, self.oracle, batch_strategy
            )
            bias_diag = np.diag(cobias)
        else:  # "cheat" or "direct"
            bias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle)
            # diag(np.outer(b, b)) == ravel(b) ** 2, because np.outer flattens its inputs.
            # NOTE(review): if estimate_bias_squared already returns a squared
            # bias, this squares it again, whereas "top-k" adds it directly —
            # confirm which is intended.
            flat_bias = np.ravel(bias)
            bias_diag = flat_bias * flat_bias

        scores = variance + bias_diag
        return select_indices_robust(data_manager, scores, num_suggestions, requery, batch_strategy)