from eval_pipeline.explainers import InclusiveExplainer
from eval_pipeline.utils.metric_utils import _calculate_ite, _aggregate_metrics

import numpy as np

class CONATEExplainer(InclusiveExplainer):
    """Explainer that estimates a pair's effect from an aggregated ITE table.

    ``fit`` computes per-pair ITEs with ``_calculate_ite``. It then aggregates them
    with ``_aggregate_metrics``, grouped by the intervention columns plus one
    ``'<confounder>_aspect_majority_base'`` column per confounder.

    ``estimate_icace`` looks each pair up in that table by the same columns. It
    returns the aggregated ITE, or a zero vector when no group matches.
    """

    # Columns identifying an intervention; per-confounder columns are appended to these.
    _INTERVENTION_COLUMNS = ('intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual')

    def __init__(self, confounders):
        """
        Args:
            confounders: iterable of confounder names. Each name ``c`` selects the
                ``f'{c}_aspect_majority_base'`` column used as an extra grouping key.
        """
        # Aggregated ITE table; stays None until ``fit`` is called.
        self.ate = None
        self.confounders = confounders
        self.confounders_base = [f'{confounder}_aspect_majority_base' for confounder in self.confounders]

    def __str__(self):
        return 'CONATEExplainer' + str(self.confounders)

    def _group_columns(self):
        """Return the key columns shared by the aggregation in ``fit`` and the lookup in ``estimate_icace``."""
        return list(self._INTERVENTION_COLUMNS) + self.confounders_base

    @staticmethod
    def _is_missing(value):
        """Return True if ``value`` is None or a scalar float NaN (a left join fills unmatched rows with NaN)."""
        return value is None or (isinstance(value, (float, np.floating)) and bool(np.isnan(value)))

    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Build the aggregated ITE table from training pairs.

        Args:
            pairs: training pairs passed to ``_calculate_ite``. Its result must carry
                an ``'ITE'`` column and all grouping columns.
            singles, classifier, dev_dataset: not used by this explainer. They are
                presumably accepted to match the shared explainer interface (confirm).
        """
        pairs = _calculate_ite(pairs)
        self.ate = _aggregate_metrics(pairs, self._group_columns(), ['ITE'])

    def estimate_icace(self, pairs):
        """Return one ITE estimate per row of ``pairs``, in row order.

        Rows whose grouping key has no entry in the fitted table get
        ``np.zeros(n_classes)``.

        Args:
            pairs: DataFrame containing the grouping columns and ``'review_majority_base'``.

        Returns:
            list of ITE estimates (one per row); an empty list when ``pairs`` is empty.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.ate is None:
            raise RuntimeError('CONATEExplainer.estimate_icace called before fit')
        if len(pairs) == 0:
            # Nothing to estimate; also avoids an IndexError from .iloc[0] below.
            return []

        # DataFrame.join(on=...) matches these columns of `pairs` against the index of
        # self.ate; presumably _aggregate_metrics returns a frame indexed by them — confirm.
        predictions = pairs.join(self.ate, on=self._group_columns())

        # Width of the fallback vector; assumes each 'review_majority_base' entry is a
        # per-class sequence of equal length — TODO confirm.
        n_classes = len(pairs['review_majority_base'].iloc[0])

        # TODO: better handling of intervention directions with no estimates; for now
        # unmatched rows (NaN after the join) fall back to a zero effect.
        ite = predictions['ITE'].apply(lambda x: np.zeros(n_classes) if self._is_missing(x) else x)
        return ite.to_list()