from explainers.matching_based_explainer import MatchingBasedExplainer
import pandas as pd
import numpy as np

from utils.constants import COLS_MAP_STANCE, COLS_MAP_CEBAB


class Generative(MatchingBasedExplainer):
    """Matching-based explainer whose candidate matches are pre-computed generations read from a CSV.

    No representation model is used (see ``set_representation_model``). Instead, for every
    base/counterfactual pair, the generations recorded for the same ``original_id`` are attached
    to the pair as matches.

    Supported setups (``setup_name``):
      * ``'cebab'``  - reads ``sets/generative_explainers/gpt-3.5-turbo_{version}/test_generations.csv``
        and drops rows with a missing ``generation``.
      * ``'stance'`` - reads the stance-detection ``{version}_preds_included{level}.csv`` file and keeps
        only rows whose ``split`` is ``'test'``.
    """

    def __init__(self, version, approach=2, description=None, all_generations_slice=True, top_k=1, setup_name='cebab',
                 level=''):
        """Load the generations that serve as the set to match.

        Args:
            version: Identifier interpolated into the generations file path.
            approach: Stored as ``self.approach``; only echoed in ``set_representation_model`` within this class.
            description: Forwarded to the base class; overrides ``get_explainer_description`` when truthy.
            all_generations_slice: If True, attach every generation found for a pair; otherwise attach
                ``top_k`` randomly sampled ones.
            top_k: Forwarded to the base class.
            setup_name: ``'cebab'`` or ``'stance'``.
            level: Stance only; when non-empty it is appended to the file name as ``_{level}``.

        Raises:
            ValueError: If ``setup_name`` is neither ``'cebab'`` nor ``'stance'``.
        """
        if setup_name == 'cebab':
            set_to_match = pd.read_csv(
                f'sets/generative_explainers/gpt-3.5-turbo_{version}/test_generations.csv')
        elif setup_name == 'stance':
            if level != '':
                level = f'_{level}'
            # NOTE(review): absolute (anonymised) path - consider making this configurable.
            set_to_match = pd.read_csv(
                f'/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection/{version}_preds_included{level}.csv')
            set_to_match = set_to_match[set_to_match['split'] == 'test']
        else:
            # Without this, an unknown setup fell through to an UnboundLocalError on `set_to_match`.
            raise ValueError(f"Unsupported setup_name {setup_name!r}; expected 'cebab' or 'stance'")
        self.setup_name = setup_name
        self.version = version
        self.all_generations_slice = all_generations_slice
        # For this explainer, the set to match is a set of generations.
        self.approach = approach
        super().__init__(set_to_match=set_to_match, description=description, top_k=top_k)
        if setup_name == 'cebab':
            self.cols_map = COLS_MAP_CEBAB
            # NOTE(review): drops on the literal 'generation' column while matching reads
            # self.cols_map['generation'] - confirm these are the same column.
            self.set_to_match = self.set_to_match.dropna(subset=['generation'])
        else:
            self.cols_map = COLS_MAP_STANCE

    def get_explainer_description(self):
        """Return the user-supplied description if set, else ``'gpt-3.5-turbo_{version}'``."""
        if self.description:
            return f'{self.description}'
        return f'gpt-3.5-turbo_{self.version}'

    def set_representation_model(self, model):
        """No-op: this explainer matches against stored generations, so ``model`` is ignored."""
        # Call the method; previously the bound-method object itself was interpolated.
        print(f'no model in {self.get_explainer_description()}, app-{self.approach}')

    def _select_generations(self, generations_slice, top_k):
        """Return the generation rows (as Series) to attach as matches for a single pair.

        An empty slice yields no rows. With ``self.all_generations_slice`` every row is returned in
        order; otherwise ``top_k`` rows are drawn with independent ``sample(1)`` calls, so the same
        generation may be drawn more than once.
        """
        if len(generations_slice) == 0:
            return []
        if self.all_generations_slice:
            return [generations_slice.iloc[i] for i in range(len(generations_slice))]
        return [generations_slice.sample(1).iloc[0] for _ in range(top_k)]

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Attach generations as matches to the pairs that intervene on ``concept``.

        Args:
            pairs: DataFrame of base/counterfactual pairs; column names are resolved via ``self.cols_map``
                plus the literal ``intervention_type`` / ``original_id_base`` columns.
            concept: Intervention concept, compared with ``intervention_type`` / ``intervention_aspect``.
            base_direction: Required value of the pairs' base-direction column.
            target_direction: Required value of the target-direction column (pairs and generations).
            top_k: Number of sampled generations per pair when ``all_generations_slice`` is False.
            batch_size: Unused here.

        Returns:
            DataFrame with one row per selected pair (``original_id``, ``text_base``,
            ``text_counterfactual`` and, for the stance setup, instruction/probability columns), plus
            ``text_match_{i}``, ``original_id_match_{i}`` and ``generation_id_match_{i}`` for each
            attached generation (NaN where a pair has fewer), and constant ``concept_intervention``,
            ``base_direction`` and ``target_direction`` columns.
        """
        # One combined mask: the concept condition used to be built from the unfiltered `pairs`
        # and applied to an already-filtered frame, which relies on index reindexing.
        candidate_pairs = pairs[(pairs[self.cols_map['base_direction']] == base_direction)
                                & (pairs[self.cols_map['target_direction']] == target_direction)
                                & (pairs['intervention_type'] == concept)]
        candidates_generations = self.set_to_match[(self.set_to_match['intervention_aspect'] == concept) & (
                self.set_to_match['target_direction'] == target_direction)]

        df_matches_pairs = pd.DataFrame()
        df_matches_pairs['original_id'] = candidate_pairs[self.cols_map['base_original_id']].values
        df_matches_pairs['text_base'] = candidate_pairs[self.cols_map['base_text']].values
        df_matches_pairs['text_counterfactual'] = candidate_pairs[self.cols_map['cf_text']].values
        if self.setup_name == 'stance':
            df_matches_pairs['instruction_base'] = candidate_pairs['instruction'].values
            df_matches_pairs['original_instruction_base'] = candidate_pairs['original_instruction'].values
            df_matches_pairs['instruction_counterfactual'] = candidate_pairs['edit_instruction'].values
            df_matches_pairs['instruction_to_label_probs_base'] = candidate_pairs['instruction_to_label_probs'].values
            df_matches_pairs['instruction_to_label_probs_counterfactual'] = candidate_pairs[
                'edit_instruction_to_label_probs'].values
            df_matches_pairs['original_instruction_to_label_probs_base'] = candidate_pairs[
                'original_instruction_to_label_probs'].values

        # Positional 0..n-1 index so it lines up with df_matches_pairs (built from `.values`).
        # drop=True avoids a ValueError when `pairs` already has an 'index' column.
        candidate_pairs = candidate_pairs.reset_index(drop=True)
        for idx, row in candidate_pairs.iterrows():
            # NOTE(review): hard-coded 'original_id_base' while the column above uses
            # self.cols_map['base_original_id'] - confirm they agree.
            generations_slice = candidates_generations[candidates_generations['original_id'] == row['original_id_base']]
            for i, example in enumerate(self._select_generations(generations_slice, top_k)):
                df_matches_pairs.loc[idx, f'text_match_{i}'] = example[self.cols_map['generation']]
                df_matches_pairs.loc[idx, f'original_id_match_{i}'] = example['original_id']
                df_matches_pairs.loc[idx, f'generation_id_match_{i}'] = example['generation_index']

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Attach same-domain generations as matches to stance pairs that edit ``concept``.

        Args:
            pairs: Stance DataFrame; must contain ``{concept}_text``, ``edit_goal``, ``edit_type``,
                ``original_id``, ``domain_text`` and the columns listed in ``cols_to_keep``.
            concept: Edit type; also selects the ``{concept}_text`` column compared to ``base_direction``.
            base_direction: Required value of ``{concept}_text``.
            target_direction: Required ``edit_goal`` (pairs and generations).
            top_k: Number of sampled generations per pair when ``all_generations_slice`` is False.
            batch_size: Unused here.

        Returns:
            DataFrame with the kept pair columns, plus, for every ``edit_``-column ``c`` and attached
            generation ``i``, a column named ``c`` with its ``edit_`` marker removed, suffixed by
            ``_match_{i}`` (e.g. ``edit_text`` -> ``text_match_0``), and constant
            ``concept_intervention``, ``base_direction`` and ``target_direction`` columns.
        """
        candidate_pairs = pairs[(pairs[f'{concept}_text'] == base_direction) & (
                pairs['edit_goal'] == target_direction) & (pairs['edit_type'] == concept)]

        candidates_generations = self.set_to_match[
            (self.set_to_match['edit_type'] == concept) & (self.set_to_match['edit_goal'] == target_direction)]
        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        cols_to_keep += [col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col)]

        df_matches_pairs = pd.DataFrame()
        for col in cols_to_keep:
            df_matches_pairs[col] = candidate_pairs[col].values

        # Columns copied from each matched generation, renamed by removing the first 'edit_'.
        # Splitting once keeps names containing 'edit_' twice intact. Selecting on 'edit_'
        # (not 'edit') skips names like 'edited_x' that used to raise IndexError.
        match_cols = {c: c.split('edit_', 1)[1] for c in cols_to_keep if 'edit_' in c}

        # Positional 0..n-1 index so it lines up with df_matches_pairs (built from `.values`).
        candidate_pairs = candidate_pairs.reset_index(drop=True)
        for idx, row in candidate_pairs.iterrows():
            generations_slice = candidates_generations[candidates_generations['original_id'] == row['original_id']]
            # make sure the generations are from the same domain
            generations_slice = generations_slice[generations_slice['domain'] == row['domain_text']]
            for i, example in enumerate(self._select_generations(generations_slice, top_k)):
                for src_col, base_name in match_cols.items():
                    df_matches_pairs.loc[idx, f'{base_name}_match_{i}'] = example[src_col]

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs
