from explainers.matching_based_explainer import MatchingBasedExplainer
import pandas as pd
import numpy as np

from utils.constants import CEBAB_CONCEPTS


class CF_Gen(MatchingBasedExplainer):
    """Matching-based explainer backed by precomputed CF-Gen generations.

    Generations are read from a CSV whose ``id`` column has the form
    ``"<a>_<b>"``: the first component (plus an offset) becomes
    ``original_id`` and the second becomes ``edit_id``. Counterfactual pairs
    are matched to generations by these two ids, not by any learned
    representation.
    """

    # Default locations of the input CSVs (overridable via ``__init__``).
    _DEFAULT_GENERATIONS_PATH = (
        '/home/XXXXXX/MatchingBasedCausalExplanation/sets/generative_explainers/nitagens/test_generations.csv')
    _DEFAULT_SOURCE_PATH = '/home/XXXXXX/MatchingBasedCausalExplanation/sets/sources/test.csv'
    # Added to the first ``id`` component to form ``original_id``; presumably
    # aligns the generation file with the id space of ``pairs`` — TODO confirm.
    _DEFAULT_ORIGINAL_ID_OFFSET = 2027

    def __init__(self, all_generations_slice, description=None, top_k=1,
                 generations_path=_DEFAULT_GENERATIONS_PATH,
                 source_path=_DEFAULT_SOURCE_PATH,
                 original_id_offset=_DEFAULT_ORIGINAL_ID_OFFSET):
        """Load generations and source examples from CSV.

        Args:
            all_generations_slice: If truthy, every matching generation is
                emitted per pair; otherwise ``top_k`` random draws are used.
            description: Optional override for ``get_explainer_description``.
            top_k: Forwarded to ``MatchingBasedExplainer.__init__``.
            generations_path: CSV of generations with ``id`` and
                ``prediction`` columns.
            source_path: CSV of source examples used by
                ``create_df_factual_counterfactual``.
            original_id_offset: Integer added to the first ``id`` component.
        """
        self.all_generations_slice = all_generations_slice
        set_to_match = pd.read_csv(generations_path)
        self.source = pd.read_csv(source_path)

        def _original_id(raw_id):
            return str(int(raw_id.split('_')[0]) + original_id_offset)

        def _edit_id(raw_id):
            return str(int(raw_id.split('_')[1]))

        # Ids are stored as strings; matching below compares against str().
        set_to_match['original_id'] = set_to_match['id'].apply(_original_id)
        set_to_match['edit_id'] = set_to_match['id'].apply(_edit_id)
        super().__init__(set_to_match=set_to_match, description=description, top_k=top_k)

    def get_matches_for_concept_stance(self):
        """Not supported for this explainer."""
        raise NotImplementedError

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Match CEBaB counterfactual pairs for one intervention to generations.

        Args:
            pairs: DataFrame of counterfactual pairs. It must contain
                ``intervention_type``, ``intervention_aspect_base``,
                ``intervention_aspect_counterfactual``, ``original_id_base``,
                ``edit_id_counterfactual``, ``description_base`` and
                ``description_counterfactual``.
            concept: Value of ``intervention_type`` to select.
            base_direction: Value of ``intervention_aspect_base`` to select.
            target_direction: Value of ``intervention_aspect_counterfactual``
                to select.
            top_k: Number of random draws, with replacement, per pair when
                ``self.all_generations_slice`` is falsy.
            batch_size: Unused here; presumably kept for signature parity
                with other explainers.

        Returns:
            DataFrame with one row per selected pair. Columns are
            ``original_id``, ``text_base``, ``text_counterfactual``,
            ``original_id_match_{i}`` / ``text_match_{i}``,
            ``concept_intervention``, ``base_direction`` and
            ``target_direction``. The match cells are NaN for pairs without
            any generation. The match columns are absent if no pair matched.
        """
        candidate_pairs = pairs[
            (pairs['intervention_type'] == concept)
            & (pairs['intervention_aspect_base'] == base_direction)
            & (pairs['intervention_aspect_counterfactual'] == target_direction)]
        # Only read below, so no defensive copy is needed.
        generations = self.set_to_match

        df_matches_pairs = pd.DataFrame({
            'original_id': candidate_pairs['original_id_base'].values,
            'text_base': candidate_pairs['description_base'].values,
            'text_counterfactual': candidate_pairs['description_counterfactual'].values,
        })
        # Positional 0..n-1 index aligns with df_matches_pairs' RangeIndex.
        # drop=True avoids adding (or colliding with) an 'index' column.
        candidate_pairs = candidate_pairs.reset_index(drop=True)

        for idx, row in candidate_pairs.iterrows():
            matches = generations[
                (generations['original_id'] == str(row['original_id_base']))
                & (generations['edit_id'] == str(row['edit_id_counterfactual']))]
            if matches.empty:
                continue
            if self.all_generations_slice:
                chosen = [matches.iloc[i] for i in range(len(matches))]
            else:
                # Independent single-row draws, i.e. sampling with replacement.
                chosen = [matches.sample(1).iloc[0] for _ in range(top_k)]
            for i, example in enumerate(chosen):
                df_matches_pairs.loc[idx, f'original_id_match_{i}'] = example['original_id']
                df_matches_pairs.loc[idx, f'text_match_{i}'] = example['prediction']

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_explainer_description(self):
        """Return the user-supplied description, or ``'CF-Gen'`` by default."""
        return self.description if self.description is not None else 'CF-Gen'

    def fit(self):
        """Not supported for this explainer."""
        # ``NotImplemented`` is not an exception; raising it gives TypeError.
        raise NotImplementedError

    def set_representation_model(self, model):
        """Ignore ``model``: matching here is id-based, so no model is used."""
        # getattr: ``approach`` is never set in this class; presumably the base
        # class may define it — avoid crashing a log line if it does not.
        print(f'no model in {self.get_explainer_description()}, app-{getattr(self, "approach", None)}')

    def create_df_factual_counterfactual(self, cebab_counterfactual):
        """Build a per-source-text table of generated counterfactuals.

        Args:
            cebab_counterfactual: DataFrame with ``id`` and ``prediction``
                columns. Rows whose ``id`` equals a source example's ``id``
                are that example's counterfactuals.

        Returns:
            DataFrame indexed by source text. It has one column per concept
            in ``CEBAB_CONCEPTS`` holding a list of predictions (empty when
            none), a ``'<concept> label'`` column for each edited concept, and
            an ``original_text`` column mirroring the index.
        """
        # Two source schemas: a ``text`` column with labels stored under the
        # concept name, or a ``description`` column with labels stored under
        # ``<concept>_aspect_majority``.
        uses_text_schema = 'text' in self.source.columns
        text_col = 'text' if uses_text_schema else 'description'

        source_counterfactual_map = {
            text: {concept: [] for concept in CEBAB_CONCEPTS}
            for text in self.source[text_col]
        }
        # Group predictions by id once instead of re-filtering for every row.
        predictions_by_id = {
            key: list(group)
            for key, group in cebab_counterfactual.groupby('id', sort=False)['prediction']
        }

        for _, example in self.source.iterrows():
            intervention = example['edit_type']
            # Unedited rows carry the literal 'None'. pandas >= 2.0 read_csv
            # parses that string as NaN by default, so skip both forms.
            if pd.isna(intervention) or intervention == 'None':
                continue
            text = example[text_col]
            source_counterfactual_map[text][intervention] = predictions_by_id.get(example['id'], [])
            label_col = intervention if uses_text_schema else f'{intervention}_aspect_majority'
            source_counterfactual_map[text][f'{intervention} label'] = example[label_col]

        df = pd.DataFrame.from_dict(source_counterfactual_map, orient='index')
        df['original_text'] = df.index
        return df
