import os
from abc import ABC, abstractmethod
import time

import pandas as pd

from explainers.explainer import Explainer
from utils.constants import RESULTS_PATH, CONFOUNDERS_MAPS
from utils.results_utils import make_dir


class MatchingBasedExplainer(Explainer, ABC):
    """Abstract base for explainers that operate on matched examples.

    Stores the configuration shared by concrete matchers and implements the
    setup-independent plumbing: iterating over (base, target) direction pairs,
    collecting the matches returned by the subclass hooks, and writing the
    results to CSV.
    """

    def __init__(self, set_to_match, description, top_k, representation_model=None,
                 representation_model_per_concept=None, text_column='text'):
        """Store the matcher configuration on the instance.

        Args:
            set_to_match: Stored as ``self.set_to_match``; not used in this class.
            description: Stored as ``self.description``.
            top_k: Stored as ``self.top_k``.
            representation_model: Optional model stored as
                ``self.representation_model``.
            representation_model_per_concept: Optional value stored as
                ``self.representation_model_per_concept``.
            text_column: Name stored as ``self.text_column`` (default ``'text'``).
        """
        super().__init__()
        self.text_column = text_column
        self.set_to_match = set_to_match
        self.description = description
        self.representation_model = representation_model
        self.representation_model_per_concept = representation_model_per_concept
        self.top_k = top_k

    @abstractmethod
    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, k):
        """Return the matches for one concept/direction pair in the 'cebab' setup.

        Called by ``build_matches_pairs`` with the full ``pairs`` frame; the
        return value is concatenated onto the running result with ``pd.concat``.
        """
        # `NotImplemented` is a comparison sentinel, not an exception class.
        raise NotImplementedError

    @abstractmethod
    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, k):
        """Return the matches for one concept/direction pair in the 'stance' setup.

        Called by ``build_matches_pairs`` with ``pairs`` already filtered to the
        rows whose ``f'{concept}_text'`` equals ``base_direction`` and whose
        ``'edit_goal'`` equals ``target_direction``. The signature mirrors the
        call site and the cebab hook.
        """
        raise NotImplementedError

    def set_representation_model(self, model):
        """Hook for swapping the representation model; not implemented here."""
        raise NotImplementedError

    def save_matches(self, all_pairs, concept, base_direction, target_direction):
        """Write the description/ICaCE columns of ``all_pairs`` to a CSV file.

        The file is written to
        ``RESULTS_PATH/matches_icace/<base_direction>-><target_direction>/<explainer description>.csv``
        with columns ``factual``, every column of ``all_pairs`` whose name
        contains ``'description_match'``, ``counterfactual`` and ``icace``.

        Args:
            all_pairs: DataFrame providing ``description_base``,
                ``description_counterfactual`` and ``icace_error`` columns.
            concept: Accepted for interface compatibility; unused here.
            base_direction: Used in the output directory name.
            target_direction: Used in the output directory name.
        """
        p = os.path.join(RESULTS_PATH, 'matches_icace')
        make_dir(p)
        p = os.path.join(p, f'{base_direction}->{target_direction}')
        make_dir(p)
        p = os.path.join(p, f'{self.get_explainer_description()}.csv')
        df = pd.DataFrame()
        df['factual'] = list(all_pairs['description_base'].values)
        matching_columns = [c for c in all_pairs.columns if 'description_match' in c]
        for c in matching_columns:
            df[c] = list(all_pairs[c].values)
        df['counterfactual'] = list(all_pairs['description_counterfactual'].values)
        df['icace'] = list(all_pairs['icace_error'].values)
        df.to_csv(p, index=False)

    def build_matches_pairs(self, pairs, concepts, k, path_dir, save_outputs=True, directions=None, setup='cebab',
                            return_df=False,
                            models_to_save_predictions=None):
        """Collect matches for every concept and ordered direction pair.

        For each concept and each ordered pair of distinct directions the
        setup-specific hook is called and its result is appended to one frame.

        Args:
            pairs: DataFrame of pairs handed to the matching hooks.
            concepts: Iterable of concept names.
            k: Number of matches; forwarded to the hooks and used to address
                the ``text_match_{i}`` columns when predictions are saved.
            path_dir: Directory the CSV output is written to.
            save_outputs: When true, drop rows with a missing ``'text'`` value
                and write the result to CSV.
            directions: Directions to iterate over. May be ``None`` only for
                ``setup='stance'``, in which case
                ``list(CONFOUNDERS_MAPS[concept])`` is used.
            setup: Either ``'cebab'`` or ``'stance'``.
            return_df: When true, return the collected frame.
            models_to_save_predictions: Optional mapping of model key to a
                model exposing ``get_predictions(texts, batch_size=...)``. When
                non-empty, one CSV per model is written with prediction
                columns added.

        Returns:
            The collected DataFrame if ``return_df`` is true, otherwise ``None``.
            When ``save_outputs`` is true the returned frame is the one left
            after the ``dropna`` step.

        Raises:
            ValueError: If ``setup`` is unknown, or ``directions`` is ``None``
                for a setup other than ``'stance'``.
        """
        final_df = pd.DataFrame()
        p_l = len(pairs)
        for concept in concepts:
            if (directions is None) and (setup == 'stance'):
                c_direction = list(CONFOUNDERS_MAPS[concept])
            else:
                c_direction = directions
            if c_direction is None:
                raise ValueError("`directions` must be provided unless setup is 'stance'")
            for base_direction in c_direction:
                for target_direction in c_direction:
                    if base_direction == target_direction:
                        continue
                    if setup == 'cebab':
                        pairs_of_matches = self.get_matches_for_concept_cebab(
                            pairs, concept, base_direction, target_direction, k)
                    elif setup == 'stance':
                        intervention_aspect_base_col = f'{concept}_text'
                        intervention_aspect_counterfactual_col = 'edit_goal'

                        pairs_prime = pairs[
                            (pairs[intervention_aspect_base_col] == base_direction)
                            & (pairs[intervention_aspect_counterfactual_col] == target_direction)
                        ]

                        pairs_of_matches = self.get_matches_for_concept_stance(
                            pairs_prime, concept, base_direction, target_direction, k)
                    else:
                        # Previously this fell through and reused a stale or
                        # unbound `pairs_of_matches`.
                        raise ValueError(f"Unknown setup: {setup!r}; expected 'cebab' or 'stance'")
                    final_df = pd.concat([final_df, pairs_of_matches])
                    # NOTE(review): this only leaves the innermost loop; the outer
                    # loops keep running. Kept as-is — confirm whether a full
                    # early exit was intended.
                    if len(final_df) == p_l:
                        break
        print(len(final_df))
        if save_outputs:
            # NOTE(review): the column name is hard-coded rather than taken from
            # self.text_column — confirm which is intended.
            final_df = final_df.dropna(subset=['text'])
            if models_to_save_predictions:
                temp = final_df.copy()
                # Prediction columns are overwritten for each model before its
                # own CSV is written, so every file holds one model's outputs.
                for model_key, model in models_to_save_predictions.items():
                    for i in range(k):
                        temp[f'prediction_match_{i}'] = model.get_predictions(
                            list(temp[f'text_match_{i}'].values), batch_size=1024)
                    temp['prediction_base'] = model.get_predictions(
                        list(temp['text_base'].values), batch_size=1024)
                    temp['prediction_counterfactual'] = model.get_predictions(
                        list(temp['text_counterfactual'].values), batch_size=1024)
                    temp.to_csv(os.path.join(path_dir, f'{self.get_explainer_description()}_{model_key}.csv'),
                                index=False)
            else:
                # NOTE(review): this branch writes the index while the per-model
                # branch does not; left unchanged to keep the file format stable.
                final_df.to_csv(os.path.join(path_dir, f'{self.get_explainer_description()}.csv'), index=True)
        if return_df:
            return final_df
