import numpy as np
from collections import defaultdict
from math import ceil

from transformers import AutoTokenizer
import torch


from eval_pipeline.explainers.abstract_exclusive_explainer import ExclusiveExplainer
from eval_pipeline.customized_models.bert import BertForNonlinearSequenceClassification
from eval_pipeline.explainers.exclusive_explainers.explainer_utils import dataset_aspects_to_onehot 


class ApproxCounterfactual(ExclusiveExplainer):
    """
    Approximate-counterfactual ICaCE explainer.

    ``fit`` stores the training data together with its one-hot aspect encodings
    and the classifier predictions. ``estimate_icace`` applies each pair's
    intervention in the aspect-label space. It then samples a training example
    whose encoding equals the intervened encoding and reports that example's
    prediction minus the base example's prediction.
    """

    # Aspects labelled by the ABSA model and rewritten by interventions.
    _ASPECTS = ('food', 'service', 'noise', 'ambiance')

    def __init__(self, absa_model_path, device = 'cpu', batch_size = 64, num_classes = 2):
        """
        Args:
            absa_model_path: path or hub id passed to ``from_pretrained`` for the
                ABSA model. If it contains 'CEBaB/', the tokenizer is loaded from
                the second '/'-separated segment, truncated at its first '.'.
            device: device the ABSA model (and its inputs) are moved to.
            batch_size: batch size used when running the ABSA model.
            num_classes: length of the zero vector returned by
                ``estimate_icace`` when no matching counterfactual exists.
        """
        self.device = device
        self.batch_size = batch_size

        self.aspects = list(self._ASPECTS)

        self.absa_model = BertForNonlinearSequenceClassification.from_pretrained(absa_model_path).to(self.device)
        if 'CEBaB/' in absa_model_path:
            self.tokenizer = AutoTokenizer.from_pretrained(absa_model_path.split('/')[1].split('.')[0])
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(absa_model_path)

        # Set by fit(): a copy of the training data with added 'encoded' and 'prediction' columns.
        self.encoded_dataset = None

        self.num_classes = num_classes

    def __str__(self):
        return 'ApproxCounterfactual'

    def fit(self, dataset, classifier_predictions, classifier, dev_dataset=None):
        """
        Store the training data with its one-hot aspect encodings and predictions.

        The caller's ``dataset`` is not modified; a copy with the added
        'encoded' and 'prediction' columns is kept in ``self.encoded_dataset``.
        ``classifier`` and ``dev_dataset`` are accepted for interface
        compatibility and are not used.
        """
        # copy so that the caller's DataFrame does not gain/lose columns as a side effect
        encoded_dataset = dataset.copy()
        encoded_dataset['encoded'] = list(dataset_aspects_to_onehot(encoded_dataset))
        encoded_dataset['prediction'] = list(classifier_predictions)

        self.encoded_dataset = encoded_dataset

    # NOTE: code duplication with s_learner
    @staticmethod
    def _get_representations_after_interventions(pairs):
        """
        Simulate interventions in the explainable representation space.

        For each aspect, rows whose 'intervention_type' equals that aspect get
        their '<aspect>_aspect_majority_base' label replaced by
        'intervention_aspect_counterfactual'; other rows keep the base label.
        The '_base' suffix is then stripped from column names and the result
        is one-hot encoded with ``dataset_aspects_to_onehot``.
        ``pairs`` itself is not modified.
        """
        pairs_after_intervention = pairs.copy()
        for aspect in ApproxCounterfactual._ASPECTS:
            column = f'{aspect}_aspect_majority_base'
            not_intervened = pairs_after_intervention['intervention_type'] != aspect
            pairs_after_intervention[column] = pairs_after_intervention[column].where(
                not_intervened, pairs_after_intervention['intervention_aspect_counterfactual'])

        return dataset_aspects_to_onehot(pairs_after_intervention.rename(columns=lambda col: col.replace('_base', '')))

    # NOTE: code duplication with s_learner
    def _get_pairs_with_predicted_aspect_labels(self, pairs):
        """
        Replace the base aspect labels of ``pairs`` with ABSA-model predictions.

        For every aspect in ``self.aspects``, the ABSA model is run on
        ('description_base', aspect) and the '<aspect>_aspect_majority_base'
        column is overwritten in place (for all aspects; the intervened one is
        later replaced by ``_get_representations_after_interventions``).
        ABSA class ids are mapped 0 -> 'Negative', 1 -> 'Positive', 2 -> 'unknown'.

        Returns:
            ``pairs`` (the same object, modified in place).
        """
        text = pairs['description_base'].to_list()
        if not text:
            # nothing to label; also avoids torch.concat / np.vectorize on empty input
            return pairs

        self.absa_model.to(self.device)
        self.absa_model.eval()

        n_batches = ceil(len(text) / self.batch_size)
        predictions = defaultdict(list)

        # inference only: no autograd graph is needed
        with torch.no_grad():
            for aspect in self.aspects:
                absa_input = self.tokenizer(text, [aspect] * len(text), return_tensors='pt', padding=True, truncation=True).to(self.device)

                for i in range(n_batches):
                    batch_slice = slice(i * self.batch_size, (i + 1) * self.batch_size)
                    absa_input_batch = {k: v[batch_slice] for k, v in absa_input.items()}
                    absa_output_batch = self.absa_model(**absa_input_batch)
                    predictions[aspect].append(absa_output_batch.logits.cpu())

        # stack batches and take the argmax class id per example
        predictions = {k: np.argmax(torch.concat(v).numpy(), axis=1) for k, v in predictions.items()}

        # ABSA model uses different encodings
        absa_to_lime_encodings = {
            0: 'Negative',
            1: 'Positive',
            2: 'unknown',
        }

        encoder = np.vectorize(absa_to_lime_encodings.get)
        predictions = {k: encoder(v) for k, v in predictions.items()}

        # overwrite current labels
        for aspect in self.aspects:
            pairs[f'{aspect}_aspect_majority_base'] = predictions[aspect]

        return pairs

    def _positions_by_encoding(self):
        """Map ``str(encoding)`` to the row positions in ``self.encoded_dataset`` with that encoding."""
        positions = defaultdict(list)
        for position, encoding in enumerate(self.encoded_dataset['encoded']):
            # encodings are compared via their string form (array values are not hashable)
            positions[str(encoding)].append(position)
        return positions

    def estimate_icace(self, pairs):
        """
        Estimate the ICaCE of every pair by sampling an approximate counterfactual.

        Args:
            pairs: DataFrame of pairs; reads 'prediction_base', 'intervention_type',
                'intervention_aspect_counterfactual' and the
                '<aspect>_aspect_majority_base' columns (plus whatever
                ``dataset_aspects_to_onehot`` reads once '_base' is stripped).

        Returns:
            A list with one entry per pair. Each entry is the prediction of a
            uniformly sampled training example with the intervened encoding
            minus 'prediction_base', or ``[0.0] * num_classes`` when no
            training example has that encoding. Sampling uses numpy's global
            RNG, so ``np.random.seed`` controls it.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if self.encoded_dataset is None:
            raise RuntimeError('ApproxCounterfactual.fit must be called before estimate_icace')

        # use predicted aspect labels for the base examples, instead of the ground truth labels
        # NOTE: performance seems to be worse.
        # pairs_predicted = self._get_pairs_with_predicted_aspect_labels(pairs.copy())
        pairs_predicted = pairs.copy()
        # apply the intervention in concept-space
        intervened_encodings = self._get_representations_after_interventions(pairs_predicted)
        # NOTE: this only works in inclusive pipelines, move this explainer to inclusive explainers.
        factual_predictions = pairs_predicted['prediction_base']

        # build the lookup once instead of rescanning the training set for every pair
        candidates_by_encoding = self._positions_by_encoding()
        train_predictions = self.encoded_dataset['prediction']

        # for every test time example, sample a potential counterfactual
        predictions = []
        for encoding, factual_prediction in zip(intervened_encodings, factual_predictions):
            candidates = candidates_by_encoding.get(str(encoding))
            if candidates:
                position = candidates[np.random.randint(len(candidates))]
                # positional lookup: robust to duplicate index labels in the training data
                counterfactual_prediction = train_predictions.iloc[position]
                predictions.append(counterfactual_prediction - factual_prediction)
            else:
                predictions.append([0.0] * self.num_classes)

        return predictions