import numpy as np
import pandas as pd
import ast
from explainers.matching_based_explainer import MatchingBasedExplainer
from utils.constants import COLS_MAP_CEBAB, COLS_MAP_STANCE, CONFOUNDERS_MAPS


class Propensity(MatchingBasedExplainer):
    """Matching-based explainer that selects matches by propensity distance.

    For every base example, the candidates from ``set_to_match`` whose
    probability for the target direction is closest (smallest absolute
    difference) to the base example's probability are returned as matches.
    """

    # Position of each target direction inside a per-example probability
    # vector, keyed by setup name.
    _DIRECTION_INDEX = {
        'cebab': {'Positive': 2, 'Negative': 1, 'unknown': 0},
        'stance': {'poet': 0, '5-year-old': 1, 'professor': 2, 'farmer': 3},
    }

    def __init__(self, set_to_match, description, representation_model=None, top_k=1, k_th=None, threshold=0.5,
                 assign=False, adding_prompt=False, setup_name='cebab'):
        """Store the matching configuration and select the column mapping.

        Args:
            set_to_match: DataFrame of candidate matches; rows with a missing
                ``text`` value are dropped.
            description: Forwarded to the base explainer.
            representation_model: Forwarded to the base explainer.
            top_k: Forwarded to the base explainer and stored on the instance.
            k_th: Stored on the instance; not used by the methods of this class.
            threshold: Stored on the instance; not used by the methods of this class.
            assign: Stored on the instance; not used by the methods of this class.
            adding_prompt: Stored on the instance; not used by the methods of this class.
            setup_name: Either ``'cebab'`` or ``'stance'``.

        Raises:
            ValueError: If ``setup_name`` is not a known setup.
        """
        super().__init__(set_to_match=set_to_match, description=description, representation_model=representation_model,
                         top_k=top_k)
        self.top_k = top_k
        self.k_th = k_th
        self.threshold = threshold
        self.set_to_match = self.set_to_match.dropna(subset=['text'])
        self.adding_prompt = adding_prompt
        self.assign = assign
        self.setup_name = setup_name
        if self.setup_name == 'cebab':
            self.cols_map = COLS_MAP_CEBAB
        elif self.setup_name == 'stance':
            self.cols_map = COLS_MAP_STANCE
        else:
            raise ValueError(f'Unknown setup name: {self.setup_name}')

    @staticmethod
    def _closest_indexes(propensity_base, propensity_matches, top_k):
        """Return, per base score, the positions of the ``top_k`` closest candidate scores.

        The result has one row per base score; columns are ordered by
        increasing absolute difference. It has fewer than ``top_k`` columns
        when there are fewer than ``top_k`` candidates.
        """
        base = np.asarray(propensity_base)
        candidates = np.asarray(propensity_matches)
        # Broadcasting builds the full (n_base, n_candidates) distance matrix.
        distances = np.abs(base[:, None] - candidates[None, :])
        return np.argsort(distances, axis=1)[:, :top_k]

    def _direction_index(self, target_direction):
        """Return the probability-vector position of ``target_direction`` for the current setup.

        Raises:
            ValueError: If the direction is not known for the current setup.
        """
        directions = self._DIRECTION_INDEX[self.setup_name]
        if target_direction not in directions:
            raise ValueError(f'Unknown direction {target_direction}')
        return directions[target_direction]

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Find the ``top_k`` propensity-closest matches for each selected pair.

        Pairs are selected by base direction, target direction and
        ``intervention_type == concept``. Candidates are the rows of
        ``set_to_match`` whose ``{concept}_label`` equals ``target_direction``.

        Args:
            pairs: DataFrame of base/counterfactual pairs.
            concept: Concept that was intervened on.
            base_direction: Required value of the base-direction column.
            target_direction: Required value of the target-direction column.
            top_k: Number of matches to return per pair.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            A DataFrame with one row per selected pair, or an empty DataFrame
            when no pair is selected.

        Raises:
            ValueError: If the setup name or the target direction is unknown.
            IndexError: If fewer than ``top_k`` candidates are available.
        """
        # Build the whole mask on ``pairs`` so it is aligned with the frame it filters.
        pair_mask = (
            (pairs[self.cols_map['base_direction']] == base_direction)
            & (pairs[self.cols_map['target_direction']] == target_direction)
            & (pairs['intervention_type'] == concept)
        )
        candidate_pairs = pairs[pair_mask]
        if len(candidate_pairs) == 0:
            return pd.DataFrame()

        candidate_matches = self.set_to_match[
            self.set_to_match[f'{concept}_label'] == target_direction
        ].reset_index(drop=True)
        candidate_pairs = candidate_pairs.reset_index(drop=True)

        if self.setup_name == 'cebab':
            probs_matching_set = candidate_matches[f'{concept}_probs'].values
            probs_base = candidate_pairs[f'{concept}_probs_base'].values
        elif self.setup_name == 'stance':
            # The probabilities are stored as string literals; parse them into Python objects.
            probs_matching_set = candidate_matches['text_to_speaker_probs'].apply(ast.literal_eval).values
            probs_base = candidate_pairs['text_to_speaker_probs'].apply(ast.literal_eval).values
        else:
            raise ValueError(f'Unknown setup name: {self.setup_name}')

        idx = self._direction_index(target_direction)
        propensity_matches = [p[idx] for p in probs_matching_set]
        propensity_base = [p[idx] for p in probs_base]
        matches_indexes = self._closest_indexes(propensity_base, propensity_matches, top_k)

        df_matches_pairs = pd.DataFrame()
        df_matches_pairs['original_id'] = candidate_pairs[self.cols_map['base_original_id']].values
        df_matches_pairs['text_base'] = candidate_pairs[self.cols_map['base_text']].values
        df_matches_pairs['text_counterfactual'] = candidate_pairs[self.cols_map['cf_text']].values
        if self.setup_name == 'stance':
            df_matches_pairs['instruction_base'] = candidate_pairs['instruction'].values
            df_matches_pairs['instruction_counterfactual'] = candidate_pairs['edit_instruction'].values
            df_matches_pairs['original_instruction_base'] = candidate_pairs['original_instruction'].values
            df_matches_pairs['instruction_to_label_probs_base'] = candidate_pairs['instruction_to_label_probs'].values
            df_matches_pairs['instruction_to_label_probs_counterfactual'] = candidate_pairs[
                'edit_instruction_to_label_probs'].values
            df_matches_pairs['original_instruction_to_label_probs_base'] = candidate_pairs[
                'original_instruction_to_label_probs'].values

        for i in range(top_k):
            match = candidate_matches.iloc[matches_indexes[:, i]]
            if self.setup_name == 'cebab':
                df_matches_pairs[f'original_id_match_{i}'] = match['Unnamed: 0'].values
            else:
                # Only 'stance' can reach this branch: unknown setups were rejected above.
                df_matches_pairs[f'original_id_match_{i}'] = match['id'].values
                df_matches_pairs[f'instruction_match_{i}'] = match['instruction'].values
                df_matches_pairs[f'original_instruction_match_{i}'] = match['original_instruction'].values
                df_matches_pairs[f'instruction_to_label_probs_match_{i}'] = match[
                    'instruction_to_label_probs'].values
                df_matches_pairs[f'original_instruction_to_label_probs_match_{i}'] = match[
                    'original_instruction_to_label_probs'].values
            df_matches_pairs[f'text_match_{i}'] = match['text'].values

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Find, within each domain, the ``top_k`` propensity-closest matches for each selected pair.

        Pairs are selected where ``{concept}_text == base_direction``,
        ``edit_goal == target_direction`` and ``edit_type == concept``.
        Candidates are the rows of ``set_to_match`` whose ``{concept}_text``
        equals ``target_direction``; a pair is only matched against
        candidates that share its ``domain`` value.

        Args:
            pairs: DataFrame of base/edited pairs.
            concept: Concept that was edited; also the key into ``CONFOUNDERS_MAPS``.
            base_direction: Required value of ``{concept}_text`` for the pairs.
            target_direction: Required value of ``edit_goal`` for the pairs.
            top_k: Number of matches to return per pair.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            The per-domain results concatenated row-wise, or an empty
            DataFrame when no pair is selected.

        Raises:
            IndexError: If a domain has fewer than ``top_k`` candidates.
        """
        # Copy so that the parsed probability column is not written onto a slice of ``pairs``.
        candidate_pairs = pairs[(pairs[f'{concept}_text'] == base_direction) & (
                pairs['edit_goal'] == target_direction) & (pairs['edit_type'] == concept)].copy()

        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        cols_to_keep.extend(col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col))

        domains = candidate_pairs['domain'].unique()
        if len(domains) == 0:
            return pd.DataFrame()

        probs_col = f'text_to_{concept}_probs'
        candidate_matches = self.set_to_match[self.set_to_match[f'{concept}_text'] == target_direction].copy()
        # The probabilities are stored as string literals; parse them into Python objects.
        candidate_matches[probs_col] = candidate_matches[probs_col].apply(ast.literal_eval)
        candidate_pairs[probs_col] = candidate_pairs[probs_col].apply(ast.literal_eval)

        propensity_idx = CONFOUNDERS_MAPS[concept][target_direction]
        per_domains_matches = []
        for domain in domains:
            pairs_per_domain = candidate_pairs[candidate_pairs['domain'] == domain]
            if len(pairs_per_domain) == 0:
                # Happens for a missing (NaN) domain, which never compares equal
                # to itself; skip it instead of discarding the other domains.
                continue
            pairs_per_domain = pairs_per_domain.reset_index(drop=True)
            candidate_matches_per_domain = candidate_matches[
                candidate_matches['domain'] == domain].reset_index(drop=True)

            propensity_matches = [p[propensity_idx] for p in candidate_matches_per_domain[probs_col].values]
            propensity_base = [p[propensity_idx] for p in pairs_per_domain[probs_col].values]
            matches_indexes = self._closest_indexes(propensity_base, propensity_matches, top_k)

            df_matches_pairs_per_domain = pd.DataFrame()
            for col in cols_to_keep:
                df_matches_pairs_per_domain[col] = pairs_per_domain[col].values

            for i in range(top_k):
                # The positions refer to the per-domain candidate frame, so that
                # is the frame that has to be indexed.
                match = candidate_matches_per_domain.iloc[matches_indexes[:, i]]
                for col in cols_to_keep:
                    df_matches_pairs_per_domain[f'{col}_match_{i}'] = match[col].values

            per_domains_matches.append(df_matches_pairs_per_domain)

        if not per_domains_matches:
            return pd.DataFrame()

        # Concatenate all the matches per domain.
        df_matches_pairs = pd.concat(per_domains_matches, axis=0)
        df_matches_pairs['treatment'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)

        return df_matches_pairs
