import numpy as np
from collections import defaultdict
from math import ceil

# from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer, util
import torch


from eval_pipeline.explainers.abstract_exclusive_explainer import ExclusiveExplainer
# from eval_pipeline.customized_models.bert import BertForNonlinearSequenceClassification
from eval_pipeline.explainers.exclusive_explainers.explainer_utils import dataset_aspects_to_onehot 


class SEmbeddingCounterfactual(ExclusiveExplainer):
    """
    Nearest-neighbour counterfactual explainer working in sentence-embedding space.

    `fit` stores, for every training example, its one-hot concept encoding
    (as produced by `dataset_aspects_to_onehot`), the classifier prediction and
    a sentence embedding of its 'description' text.

    `estimate_icace` applies each requested intervention in concept space,
    looks up the training examples whose concept encoding equals the
    intervened encoding, picks the one whose embedding has the highest dot
    product with the factual text, and reports the difference between that
    example's prediction and the factual prediction.
    """

    # Aspect names; each one owns an '<aspect>_aspect_majority_base' column in `pairs`.
    _ASPECTS = ('food', 'service', 'noise', 'ambiance')

    def __init__(self, device = 'cpu', batch_size = 64, num_classes = 2):
        """
        :param device: device the sentence-embedding model is moved to.
        :param batch_size: batch size forwarded to `SentenceTransformer.encode`.
        :param num_classes: length of the all-zero effect returned when no
            matching counterfactual exists.
        """
        self.device = device
        self.batch_size = batch_size

        self.aspects = list(self._ASPECTS)

        # Populated by fit().
        self.encoded_dataset = None
        self.counterfactual_id_to_embedding = None
        self._train_embeddings = None

        self.num_classes = num_classes

        self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2').to(self.device)

    def __str__(self):
        return 'SEmbeddingCounterfactual'

    def _embed(self, texts):
        """
        Return the sentence embeddings of `texts` as an array with one row per text.
        """
        # Hand encode() a plain list so the result never depends on the pandas
        # index of `texts`, and honour the batch size given to the constructor.
        return np.asarray(self.model.encode(list(texts), batch_size=self.batch_size))

    def fit(self, dataset, classifier_predictions, classifier, dev_dataset=None):
        """
        Memorise the training examples that can later serve as counterfactuals.

        :param dataset: dataframe with a 'description' column and whatever
            columns `dataset_aspects_to_onehot` reads.
        :param classifier_predictions: one prediction per row of `dataset`.
        :param classifier: unused by this explainer.
        :param dev_dataset: unused by this explainer.
        """
        # Work on a copy so the caller's dataframe does not silently gain columns.
        dataset = dataset.copy()

        # get the concept encodings for the dataset
        dataset['encoded'] = list(dataset_aspects_to_onehot(dataset))
        dataset['prediction'] = list(classifier_predictions)

        self.encoded_dataset = dataset

        # sentence embed; row i of the array belongs to row i of the dataframe
        embeddings = self._embed(dataset['description'])
        self._train_embeddings = embeddings
        self.counterfactual_id_to_embedding = {ident: embedding for (ident, embedding) in zip(dataset.index, embeddings)}

    # NOTE: code duplication with s_learner
    @staticmethod
    def _get_representations_after_interventions(pairs):
        """
        Simulate interventions in the explainable representation space.

        For every row, the '<aspect>_aspect_majority_base' column named by
        'intervention_type' is replaced with 'intervention_aspect_counterfactual';
        the other aspect columns keep their value. The result is passed to
        `dataset_aspects_to_onehot` after stripping '_base' from the column names.
        """
        pairs_after_intervention = pairs.copy()
        intervention_type = pairs_after_intervention['intervention_type']
        counterfactual_value = pairs_after_intervention['intervention_aspect_counterfactual']

        for aspect in SEmbeddingCounterfactual._ASPECTS:
            column = f'{aspect}_aspect_majority_base'
            # Select instead of "mask * a + (~mask) * b": a missing value in the
            # operand that is not selected must not leak into the result.
            pairs_after_intervention[column] = pairs_after_intervention[column].where(
                intervention_type != aspect, counterfactual_value)

        return dataset_aspects_to_onehot(pairs_after_intervention.rename(columns=lambda col: col.replace('_base', '')))

    def estimate_icace(self, pairs):
        """
        Estimate the effect of each intervention in `pairs` on the prediction.

        :param pairs: dataframe with 'prediction_base', 'description_base',
            'intervention_type', 'intervention_aspect_counterfactual' and the
            '<aspect>_aspect_majority_base' columns.
        :return: list with one entry per row of `pairs`: the prediction of the
            selected training counterfactual minus the factual prediction, or
            `[0.0] * num_classes` when no training example has the intervened
            concept encoding.
        :raises RuntimeError: if `fit` has not been called yet.
        """
        if self.encoded_dataset is None or self._train_embeddings is None:
            raise RuntimeError('fit() must be called before estimate_icace().')

        pairs = pairs.copy()

        # apply the intervention in concept-space
        pairs_after_intervention = self._get_representations_after_interventions(pairs)

        # NOTE: this only works in inclusive pipelines, move this explainer to inclusive explainers.
        factual_predictions = pairs['prediction_base']
        factual_embeddings = self._embed(pairs['description_base'])

        # hack to compare array values with pandas
        encoded = self.encoded_dataset['encoded'].astype(str)
        targets = [str(arr) for arr in pairs_after_intervention]
        train_predictions = self.encoded_dataset['prediction']

        # for every test time example, sample the closest potential counterfactual
        predictions = []
        for target, factual_prediction, factual_embedding in zip(targets, factual_predictions, factual_embeddings):
            # Row positions (not index labels) of all matching counterfactuals, so
            # the lookups below are correct for any dataframe index.
            positions = np.flatnonzero((encoded == target).to_numpy())

            if positions.size == 0:
                predictions.append([0.0] * self.num_classes)
                continue

            candidate_embeddings = self._train_embeddings[positions]

            # find the highest dot score
            dot_scores = util.dot_score(factual_embedding, candidate_embeddings).numpy().reshape(-1)
            best_position = positions[int(np.argmax(dot_scores))]

            counterfactual_prediction = train_predictions.iloc[best_position]

            predictions.append(counterfactual_prediction - factual_prediction)

        return predictions