import numpy as np

from aiau.strategy.abstract_strategy import AbstractStrategy
from aiau.oracles.oracle import Oracle
from aiau.strategy.bias_estimation import estimate_bias_squared
from aiau.strategy.selection_utils import select_indices_robust

class DiffBiasReductionStrategy(AbstractStrategy):
    """
    Implementation of the DiffBiasReduction strategy for active learning.

    Observations are scored by how much the estimated (squared) bias term
    changed since the previous call: on the first call the raw estimate is
    used as the score, on every later call the score is
    ``previous_estimate - current_estimate``. The current estimate is stored
    on the instance so the next call can diff against it.
    """

    def __init__(self, estimation_approach: str, oracle_type: str, oracle: Oracle) -> None:
        """
        Args:
            estimation_approach (str): Approach forwarded to ``estimate_bias_squared``.
                With the "eigen-decomposition" batch strategy it must be one of
                "cheat", "direct" or "quadratic".
            oracle_type (str): Oracle type forwarded to ``estimate_bias_squared``.
            oracle (Oracle): Oracle forwarded to ``estimate_bias_squared``.
        """
        super().__init__()
        # Estimate produced by the previous selection call; None until the first call.
        # NOTE(review): this slot is shared by both batch strategies, so switching
        # batch_strategy between calls would diff estimates of different kinds
        # (per-observation estimate vs. outer-product matrix) — confirm callers
        # never switch, or reset this attribute when they do.
        self.previous_bias_squared = None
        self.name = "DiffBiasReduction"
        self.estimation_approach = estimation_approach
        self.oracle_type = oracle_type
        self.oracle = oracle

    def select_next_indices(self, data_manager, model, num_suggestions, requery=True, batch_strategy="top-k"):
        """
        Selects the next indices to label.

        Args:
            data_manager (DataManager): Reference to the data_manager which will load the observation if necessary.
            model (Model): Reference to the model which will be used to predict the target values.
            num_suggestions (int): The number of indices to suggest for querying.
            requery (bool): Whether to allow requerying of the same observation.
            batch_strategy (str): "top-k" scores with the change in the estimate returned by
                ``estimate_bias_squared``; "eigen-decomposition" scores with the change in the
                co-bias term (see ``_estimate_cobias``). Also forwarded to ``select_indices_robust``.

        Returns:
            list: The indices of the observations to label, as produced by ``select_indices_robust``.

        Raises:
            ValueError: If ``batch_strategy`` is not "top-k" or "eigen-decomposition", or if
                ``estimation_approach`` is not supported for "eigen-decomposition".
        """
        if batch_strategy == "top-k":
            current = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle)
        elif batch_strategy == "eigen-decomposition":
            current = self._estimate_cobias(data_manager, model, batch_strategy)
        else:
            # Previously this fell through and returned None silently.
            raise ValueError(
                f"Unsupported batch_strategy {batch_strategy!r}; "
                "expected 'top-k' or 'eigen-decomposition'."
            )

        scores = self._diff_against_previous(current)
        return select_indices_robust(data_manager, scores, num_suggestions, requery, batch_strategy)

    def _estimate_cobias(self, data_manager, model, batch_strategy):
        """Return the co-bias term scored by the "eigen-decomposition" batch strategy.

        Raises:
            ValueError: If ``self.estimation_approach`` is not "cheat", "direct" or "quadratic".
        """
        if self.estimation_approach in ("cheat", "direct"):
            bias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle)
            # Outer product of the estimate with itself: shape (num_observations, num_observations).
            return np.outer(bias, bias)
        if self.estimation_approach == "quadratic":
            # This approach is given the batch strategy and returns a (bias, cobias) pair.
            _bias, cobias = estimate_bias_squared(
                data_manager, model, self.estimation_approach, self.oracle_type, self.oracle, batch_strategy
            )
            return cobias
        # Previously this left `cobias` unbound and crashed with UnboundLocalError.
        raise ValueError(
            f"Unsupported estimation_approach {self.estimation_approach!r} for "
            "batch_strategy 'eigen-decomposition'; expected 'cheat', 'direct' or 'quadratic'."
        )

    def _diff_against_previous(self, current):
        """Return the score for ``current`` and remember ``current`` for the next call.

        On the first call the score is ``current`` itself; afterwards it is
        ``previous - current``.
        """
        previous = self.previous_bias_squared
        self.previous_bias_squared = current
        if previous is None:
            return current
        return previous - current