from eval_pipeline.explainers import InclusiveExplainer
from eval_pipeline.utils.metric_utils import _calculate_ite, _aggregate_metrics

import numpy as np

class ATEExplainer(InclusiveExplainer):
    """Explainer that predicts each pair's effect from a per-direction aggregate of training ITEs.

    ``fit`` computes individual treatment effects (``ITE``) with ``_calculate_ite`` and
    aggregates them per intervention direction with ``_aggregate_metrics`` (presumably a
    mean, i.e. the ATE -- confirm in ``metric_utils``). ``estimate_icace`` then looks up
    that aggregate for every pair it is given.
    """

    # Columns that identify an intervention "direction". They are used both to group
    # the training ITEs in ``fit`` and as join keys in ``estimate_icace``, so they must agree.
    _GROUP_COLUMNS = ['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']

    def __init__(self):
        # Aggregated ITE table per direction; stays None until ``fit`` is called.
        self.ate = None

    def __str__(self):
        return 'ATEExplainer'

    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Aggregate the ITEs of ``pairs`` per intervention direction and store them in ``self.ate``.

        Only ``pairs`` is used here. ``singles``, ``classifier`` and ``dev_dataset`` are
        accepted but ignored (presumably to share a signature with other explainers --
        confirm against ``InclusiveExplainer``).
        """
        pairs = _calculate_ite(pairs)
        self.ate = _aggregate_metrics(pairs, self._GROUP_COLUMNS, ['ITE'])

    @staticmethod
    def _is_missing_estimate(value):
        """Return True if ``value`` is a missing join result (None or a scalar float NaN).

        Array-valued estimates are never treated as missing. The isinstance guard keeps
        ``np.isnan`` away from arrays and non-numeric objects.
        """
        return value is None or (isinstance(value, float) and np.isnan(value))

    def estimate_icace(self, pairs):
        """Return one effect estimate per row of ``pairs``, in row order.

        Each estimate is the ``ITE`` aggregate stored by ``fit`` for the row's
        intervention direction (``_GROUP_COLUMNS``). Directions that were not seen
        during ``fit`` get a zero vector. Its length is taken from the first row's
        ``review_majority_base`` entry (assumed to be a per-class vector -- confirm).

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.ate is None:
            raise RuntimeError('ATEExplainer.fit must be called before estimate_icace')
        if len(pairs) == 0:
            # Nothing to estimate. This also avoids indexing row 0 below.
            return []

        # DataFrame.join with ``on`` matches these columns against the index of
        # ``self.ate``. This assumes ``_aggregate_metrics`` indexes its result by
        # the group columns.
        predictions = pairs.join(self.ate, on=self._GROUP_COLUMNS)

        n_classes = len(pairs['review_majority_base'].iloc[0])
        # Directions with no training estimate come back as NaN from the join.
        # Fall back to "no effect" for those.
        return [
            np.zeros(n_classes) if self._is_missing_estimate(ite) else ite
            for ite in predictions['ITE'].to_list()
        ]