import scipy
import numpy as np
from typing import AnyStr, List
from slingpy import AbstractDataSource
from slingpy.models.abstract_base_model import AbstractBaseModel
from genedisco.active_learning_methods.acquisition_functions.base_acquisition_function import \
    BaseBatchAcquisitionFunction



#import IPython



class MeanOptimismAcquisition(BaseBatchAcquisitionFunction):
    """Batch acquisition function that ranks candidates by an optimistic score.

    Every available sample is scored as
    ``predicted_mean + beta_optimism * predicted_uncertainty`` and the
    ``select_size`` samples with the highest scores are returned.
    """

    def __call__(self,
                 dataset_x: AbstractDataSource,
                 select_size: int,
                 available_indices: List[AnyStr],
                 last_selected_indices: List[AnyStr] = None,
                 model: AbstractBaseModel = None,
                 beta_optimism: float = 0.1,
                 ) -> List:
        """Select the next batch of samples to query.

        Args:
            dataset_x: Data source holding the candidate samples; it is
                restricted to ``available_indices`` via ``subset``.
            select_size: Number of samples to select.
            available_indices: Identifiers of the samples that may still be
                selected. The returned identifiers are drawn from this list.
            last_selected_indices: Not used by this acquisition function.
            model: Model whose ``predict(..., return_std_and_margin=True)``
                call must return exactly three values; the first two are used
                as the predicted mean and the predicted uncertainty.
            beta_optimism: Weight of the uncertainty term in the score.

        Returns:
            The ``select_size`` entries of ``available_indices`` with the
            highest scores, ordered from highest to lowest score.

        Raises:
            ValueError: If no model is given, if ``select_size`` is negative,
                or if ``select_size`` exceeds the number of predictions.
            TypeError: If the model's prediction does not consist of three
                outputs.
        """
        if model is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("A model is required to compute acquisition scores.")
        if select_size < 0:
            # A negative size would silently slice from the end of the ranking.
            raise ValueError("The number of query samples must be non-negative.")

        avail_dataset_x = dataset_x.subset(available_indices)
        model_predictions = model.predict(avail_dataset_x, return_std_and_margin=True)

        if len(model_predictions) != 3:
            raise TypeError("The provided model does not output uncertainty.")

        pred_mean, pred_uncertainties, _ = model_predictions

        if len(pred_mean) < select_size:
            raise ValueError("The number of query samples exceeds "
                             "the size of the available data.")

        # NOTE(review): assumes the predictions are 1-D with one entry per
        # available sample, in the same order as `available_indices` - confirm
        # against the model implementation.
        scores = np.asarray(pred_mean) + beta_optimism * np.asarray(pred_uncertainties)

        # argsort is ascending, so reverse it to rank the best scores first.
        ranked_positions = np.flip(np.argsort(scores))

        return [available_indices[i] for i in ranked_positions[:select_size]]