import numpy as np

from eval_pipeline.explainers.abstract_inclusive_explainer import InclusiveExplainer

class RandomInclusiveExplainer(InclusiveExplainer):
    """Random baseline explainer: ICaCE estimates built from random predictions.

    Counterfactual predictions are random. They are either uniform random
    probability vectors, or predictions resampled with replacement from the
    ``singles`` passed to :meth:`fit`. Factual predictions come from each
    pair's ``prediction_base`` unless ``random_factual`` is set, in which case
    they are drawn the same way as the counterfactuals.

    Args:
        random_factual: If True, factual predictions are random as well.
        use_real_inputs: If True, random predictions are resampled from the
            ``prediction`` column of the ``singles`` given to :meth:`fit`
            instead of being sampled as uniform probability vectors.
    """

    def __init__(self, random_factual = False, use_real_inputs = False):
        self.random_factual = random_factual
        self.use_real_inputs = use_real_inputs
        # Set by fit(); only read when use_real_inputs is True.
        self.singles = None

    def __str__(self):
        return f'Random(RealInput:{str(self.use_real_inputs)},RandomFactual:{str(self.random_factual)})'

    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Store ``singles`` for later resampling; nothing is trained.

        ``pairs``, ``classifier`` and ``dev_dataset`` are not used by this
        explainer.
        """
        self.singles = singles

    def _get_probabilities(self, L, N):
        """Return ``L`` random prediction vectors stacked into a 2-D array.

        With ``use_real_inputs`` the rows are resampled with replacement from
        the ``prediction`` column of the fitted ``singles`` (``N`` is not used
        in that case; the width is whatever those predictions have).
        Otherwise the result is an (L, N) array of rows summing to 1.

        Raises:
            AttributeError: if ``use_real_inputs`` is True and :meth:`fit`
                has not been called.
        """
        if self.use_real_inputs:
            singles = getattr(self, 'singles', None)
            if singles is None:
                # AttributeError keeps the exception type an unfitted instance raised before.
                raise AttributeError(
                    'RandomInclusiveExplainer.fit must be called before '
                    'estimating when use_real_inputs=True'
                )
            sampled = singles.sample(n=L, replace=True)['prediction']
            return np.stack(sampled.to_numpy())
        return self._get_random_probability_vectors(L, N)

    @staticmethod
    def _get_random_probability_vectors(L, N):
        """Return an (L, N) array of uniform [0, 1) draws, each row scaled to sum to 1."""
        p = np.random.uniform(size=(L, N))
        # Entries are non-negative, so the row sum is the row's L1 norm;
        # keepdims lets the division broadcast across the N columns.
        return p / p.sum(axis=1, keepdims=True)

    def estimate_icace(self, pairs):
        """Estimate ICaCE per pair as (counterfactual - factual) predictions.

        Args:
            pairs: DataFrame-like table with one row per pair. The length of
                the first ``review_majority_base`` entry is used as the number
                of classes N. ``prediction_base`` holds the factual prediction
                vectors and is only read when ``random_factual`` is False.

        Returns:
            A list with one 1-D numpy array per row of ``pairs``; an empty
            list when ``pairs`` is empty.
        """
        L = len(pairs)
        if L == 0:
            # Nothing to estimate; iloc[0] and np.stack would fail on no rows.
            return []

        # number of classes to predict for
        N = pairs['review_majority_base'].iloc[0].shape[0]

        # counterfactual predictions are drawn first (RNG order matters for seeding)
        counterfactual_predictions = self._get_probabilities(L, N)

        if self.random_factual:
            factual_predictions = self._get_probabilities(L, N)
        else:
            factual_predictions = np.stack(pairs['prediction_base'].to_numpy())

        return list(counterfactual_predictions - factual_predictions)