import pandas as pd
import numpy as np
import os

from eval_pipeline.explainers import InclusiveExplainer

class GPT3Counterfactual(InclusiveExplainer):
    """Explainer that scores pre-generated GPT-3 counterfactual descriptions.

    Counterfactual texts are read from JSON files named
    ``{eval_split}_{k}_pairs.json`` inside ``GPT3_output_dir``. The estimate for
    each pair is the classifier's probability vector on the GPT-3 counterfactual
    description minus the stored ``prediction_base`` of that pair.
    """

    def __init__(self, GPT3_output_dir, k=0, eval_split='test', num_classes=2):
        """Load the GPT-3 counterfactual file that best matches ``k``.

        Args:
            GPT3_output_dir: directory holding ``{eval_split}_{k}_pairs.json`` files.
            k: requested prompt size. The file with the largest available size
                that is ``<= k`` is loaded; if none qualifies, the file with the
                smallest available size is loaded instead.
            eval_split: split name that prefixes the file names to consider.
            num_classes: length of the all-zero estimates returned by
                ``estimate_icace`` when no matching file exists.
        """
        self.num_classes = num_classes

        # Map every available prompt size to its file name. Only names of the
        # exact form '{eval_split}_{integer}_pairs.json' are accepted, so files
        # that merely contain the split name (or use another suffix) are ignored
        # instead of crashing the int() parse or loading a non-existent path.
        prefix = f'{eval_split}_'
        suffix = '_pairs.json'
        files_by_size = {}
        for name in sorted(os.listdir(GPT3_output_dir)):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            middle = name[len(prefix):len(name) - len(suffix)]
            if middle.isdecimal():
                files_by_size.setdefault(int(middle), name)

        # Sort the sizes numerically: sorting the file names lexicographically
        # would place e.g. 10 before 2 and select the wrong file.
        sizes = sorted(files_by_size)
        applicable = [size for size in sizes if size <= k]
        if applicable:
            chosen = applicable[-1]  # largest prompt size not exceeding k
        elif sizes:
            chosen = sizes[0]  # nothing <= k: fall back to the smallest prompt
        else:
            chosen = None

        if chosen is None:
            self.gpt3_counterfactuals = None
        else:
            GPT3_file = os.path.join(GPT3_output_dir, files_by_size[chosen])
            print(f'Loading {GPT3_file}')
            self.gpt3_counterfactuals = pd.read_json(GPT3_file)

        self.classifier = None

    def __str__(self):
        return 'GPT3Counterfactual'

    def fit(self, pairs, singles, classifier, dev_dataset=None):
        """Store the classifier used to score counterfactuals; nothing is trained.

        ``pairs``, ``singles`` and ``dev_dataset`` are accepted for interface
        compatibility and are not used.
        """
        self.classifier = classifier

    def estimate_icace(self, pairs):
        """Estimate the per-pair effect from the GPT-3 counterfactuals.

        Args:
            pairs: DataFrame of evaluation pairs. It must contain the columns
                compared below and a ``prediction_base`` column whose entries
                are probability arrays (squeezed before stacking).

        Returns:
            A list with one array per pair: counterfactual probabilities minus
            base probabilities. When no GPT-3 file was loaded, all-zero arrays of
            length ``num_classes`` are returned instead.

        Raises:
            ValueError: if the loaded GPT-3 file does not match ``pairs`` row by
                row on the compared columns.
        """
        if not isinstance(self.gpt3_counterfactuals, pd.DataFrame):
            # No GPT-3 output available: return dummy (all-zero) estimates.
            return list(np.zeros((len(pairs), self.num_classes)))

        # Check that the loaded JSON describes exactly these pairs, in order.
        # An explicit raise is used instead of `assert`, which is stripped
        # under `python -O` and would silently disable this check.
        columns_to_compare = [
            'description_base',
            'intervention_type',
            'intervention_aspect_base',
            'intervention_aspect_counterfactual',
        ]
        loaded = self.gpt3_counterfactuals[columns_to_compare].reset_index(drop=True)
        expected = pairs[columns_to_compare].reset_index(drop=True)
        if len(loaded) != len(expected):
            raise ValueError(
                f'GPT-3 counterfactual file has {len(loaded)} rows but '
                f'{len(expected)} pairs were given'
            )
        # NaN never equals NaN, so count a missing value in the same cell of
        # both frames as a match.
        matches = (loaded == expected) | (loaded.isna() & expected.isna())
        if not matches.to_numpy().all():
            raise ValueError(
                'GPT-3 counterfactual file does not correspond to the given pairs'
            )

        # Score the GPT-3 counterfactual descriptions with the classifier.
        counterfactual_dataset = pd.DataFrame()
        counterfactual_dataset['description'] = self.gpt3_counterfactuals['description_counterfactual']
        # NOTE: dummy label column, only present to keep predict_proba happy.
        counterfactual_dataset['review_majority'] = 0

        counterfactual_probas, _ = self.classifier.predict_proba(counterfactual_dataset)

        # Estimate = counterfactual prediction - factual (base) prediction.
        factual_probas = np.stack(
            [np.squeeze(proba) for proba in pairs['prediction_base'].to_numpy()]
        )
        return list(counterfactual_probas - factual_probas)

