from eval_pipeline.explainers import InclusiveExplainer
from eval_pipeline.utils.metric_utils import _calculate_icace, _aggregate_metrics

import numpy as np

class CaCEExplainer(InclusiveExplainer):
    """Explainer that predicts each pair's ICaCE as the average effect of its intervention direction.

    ``fit`` computes per-pair ICaCE values with ``_calculate_icace`` and
    aggregates them with ``_aggregate_metrics`` over the intervention
    direction columns. ``estimate_icace`` then looks each new pair up in that
    aggregated table.
    """

    # Columns identifying an intervention "direction". They are used both to
    # group training pairs in `fit` and to look up estimates in `estimate_icace`.
    _GROUP_COLUMNS = (
        'intervention_type',
        'intervention_aspect_base',
        'intervention_aspect_counterfactual',
    )

    def __init__(self):
        # Aggregated CaCE table produced by `fit`; None until `fit` is called.
        self.cace = None

    def __str__(self):
        return 'CaCEExplainer'

    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Compute the CaCE table from training pairs and store it on ``self.cace``.

        Args:
            pairs: DataFrame of base/counterfactual pairs. It is passed to
                ``_calculate_icace``, which presumably adds an 'ICaCE' column
                (TODO confirm). The result is aggregated per intervention direction.
            singles: Unused by this explainer; accepted for interface compatibility.
            classifier: Unused by this explainer; accepted for interface compatibility.
            dev_dataset: Unused by this explainer; accepted for interface compatibility.
        """
        pairs = _calculate_icace(pairs)
        # A fresh list is passed so the helper can never mutate the shared class constant.
        self.cace = _aggregate_metrics(pairs, list(self._GROUP_COLUMNS), ['ICaCE'])

    @staticmethod
    def _is_missing(value):
        """Return True if ``value`` is a scalar missing marker (None or NaN) rather than an estimate."""
        if value is None:
            return True
        # Estimates are presumably vectors (the fallback below is a length-n_classes
        # vector), so only 0-d values can be a fill value. Calling np.isnan on an
        # array would return an array, not a bool.
        if np.ndim(value) != 0:
            return False
        try:
            return bool(np.isnan(value))
        except TypeError:
            # Non-numeric scalars (e.g. strings) cannot be NaN.
            return False

    def estimate_icace(self, pairs):
        """Return the fitted CaCE for each pair's intervention direction.

        Args:
            pairs: DataFrame containing the intervention direction columns and a
                'review_majority_base' column. The length of the first entry of
                'review_majority_base' is taken as the number of classes.

        Returns:
            A list with one entry per row of ``pairs``, in row order. Each entry
            is the fitted CaCE for that row's direction. If the direction had no
            estimate in the fitted table, the entry is ``np.zeros(n_classes)``.
        """
        # TODO: deal with directions where we have no estimates (currently a zero effect is assumed)
        if len(pairs) == 0:
            # Nothing to look up; also avoids `.iloc[0]` raising IndexError below.
            return []

        group_columns = list(self._GROUP_COLUMNS)
        # Join only the key columns, so columns already present on `pairs`
        # (e.g. an 'ICaCE' column) cannot overlap with the fitted table.
        predictions = pairs[group_columns].join(self.cace, on=group_columns)

        n_classes = len(pairs['review_majority_base'].iloc[0])
        # Unmatched directions come back from the left join as missing values.
        return [
            np.zeros(n_classes) if self._is_missing(estimate) else estimate
            for estimate in predictions['ICaCE']
        ]