import numpy as np

from aiau.strategy.abstract_strategy import AbstractStrategy
from aiau.strategy.error_estimation import estimate_pemse
from aiau.strategy.bias_estimation import estimate_bias_squared
from aiau.strategy.selection_utils import select_indices_robust
from aiau.oracles.oracle import Oracle

class PemsePemseDiffStrategy(AbstractStrategy):
    """
    Active-learning strategy that scores observations by the change in
    estimated mean-squared error (EMSE) between successive calls.

    On every call the EMSE is estimated as the spread of the model's
    predictions (variance, or covariance for the eigen-decomposition batch
    strategy) plus an estimated bias term. For the first two calls the raw
    EMSE estimate is used as the score; from the third call onwards the score
    is ``previous_emse - current_emse``, i.e. how much the estimated error
    dropped since the previous call. The final selection is delegated to
    ``select_indices_robust``.
    """

    # Batch strategies handled by ``select_next_indices``.
    _BATCH_STRATEGIES = ("top-k", "eigen-decomposition")
    # Estimation approaches for which the eigen-decomposition branch knows how to build a co-bias matrix.
    _EIGEN_APPROACHES = ("cheat", "direct", "quadratic")

    def __init__(self, estimation_approach: str, oracle_type: str, oracle: Oracle) -> None:
        super().__init__()
        # EMSE estimate from the previous call (None until the first call).
        self.previous_model_emse = None
        self.name = "PemsePemseDiffStrategy"
        self.estimation_approach = estimation_approach
        self.oracle_type = oracle_type
        self.oracle = oracle
        # Number of calls that returned the raw EMSE; saturates at 2, after which differences are returned.
        self.iteration_counter = 0

    def select_next_indices(self, data_manager, model, num_suggestions, requery=True, batch_strategy="top-k"):
        """
        Selects the next indices to label.

        Args:
            data_manager (DataManager): Reference to the data_manager which will load the observation if necessary.
            model (Model): Reference to the model which will be used to predict the target values.
            num_suggestions (int): The number of indices to suggest for querying.
            requery (bool): Whether to allow requerying of the same observation.
            batch_strategy (str): ``"top-k"`` scores each observation independently;
                ``"eigen-decomposition"`` builds a full observation-by-observation score matrix.

        Returns:
            list: The indices of the observations to label (as returned by ``select_indices_robust``).

        Raises:
            ValueError: If ``batch_strategy`` is not supported, or if ``batch_strategy`` is
                ``"eigen-decomposition"`` and ``self.estimation_approach`` is not one of
                ``"cheat"``, ``"direct"`` or ``"quadratic"``.
        """
        # Validate before running the (potentially expensive) model prediction.
        if batch_strategy not in self._BATCH_STRATEGIES:
            raise ValueError(
                f"Unsupported batch_strategy {batch_strategy!r}; expected one of {self._BATCH_STRATEGIES}"
            )
        if batch_strategy == "eigen-decomposition" and self.estimation_approach not in self._EIGEN_APPROACHES:
            raise ValueError(
                f"Unsupported estimation_approach {self.estimation_approach!r} for batch_strategy "
                f"'eigen-decomposition'; expected one of {self._EIGEN_APPROACHES}"
            )

        # Presumably one row per ensemble member / sample and one column per observation — TODO confirm.
        model_prediction = model.predict(data_manager.full_X)

        # NOTE(review): previous_model_emse is shared by both batch strategies. Switching strategy
        # between calls subtracts a 1-D array from a 2-D one (or vice versa), which numpy broadcasts
        # silently instead of failing.
        if batch_strategy == "top-k":
            # NOTE(review): np.var defaults to ddof=0 while np.cov (eigen branch) defaults to ddof=1,
            # so the two branches scale the prediction spread differently — confirm this is intended.
            variance = np.var(model_prediction, axis=0) # shape (num_observations,)
            bias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle) # shape (num_observations,)
            estimated_mse = variance + bias # shape (num_observations,)
            output = self._diff_with_previous(estimated_mse)
            return select_indices_robust(data_manager, output, num_suggestions, requery, batch_strategy)

        # batch_strategy == "eigen-decomposition"
        cobias = self._estimate_cobias(data_manager, model, batch_strategy)
        covariance = np.cov(model_prediction.T) # shape (num_observations, num_observations)
        comse = covariance + cobias
        scores = self._diff_with_previous(comse) # shape (num_observations, num_observations)
        return select_indices_robust(data_manager, scores, num_suggestions, requery, batch_strategy)

    def _estimate_cobias(self, data_manager, model, batch_strategy):
        """Return the bias co-term added to the prediction covariance in the eigen-decomposition branch."""
        if self.estimation_approach == "quadratic":
            # With batch_strategy passed, the quadratic estimator returns a (bias, cobias) pair.
            _, cobias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle, batch_strategy)
            return cobias
        # "cheat" and "direct": outer product of the per-observation bias estimate.
        # NOTE(review): if estimate_bias_squared returns the *squared* bias (as its name suggests),
        # the diagonal of this outer product is that value squared again — confirm this is intended.
        bias = estimate_bias_squared(data_manager, model, self.estimation_approach, self.oracle_type, self.oracle)
        return np.outer(bias, bias) # shape (num_observations, num_observations)

    def _diff_with_previous(self, current_emse):
        """Return this call's score and remember ``current_emse`` for the next call.

        The first two calls return ``current_emse`` unchanged; from the third call
        onwards the result is ``previous_model_emse - current_emse``.
        """
        if self.iteration_counter > 1: # 3rd iteration onwards
            scores = self.previous_model_emse - current_emse
        else:
            scores = current_emse
            self.iteration_counter += 1
        # Update the previous model EMSE with the current one for the next iteration
        self.previous_model_emse = current_emse
        return scores
            