import numpy as np
import pandas as pd
from tqdm import tqdm

from explainers.matching_based_explainer import MatchingBasedExplainer
from utils.constants import CEBAB_CONCEPTS, COLS_MAP_CEBAB, COLS_MAP_STANCE, STANCE_CONFOUNDERS
from utils.metric_utils import cosine_similarity_matrix


class Approx(MatchingBasedExplainer):
    """Matching-based explainer that approximates counterfactuals with matched examples.

    For each (base, counterfactual) pair, candidate matches are drawn from
    ``set_to_match``. A candidate qualifies when its intervened concept already
    has the target value and its predictions for the remaining concepts
    (CEBaB) or confounders (stance) agree with those of the base example.
    Matches are drawn with ``DataFrame.sample`` without a fixed seed, so results
    vary between runs.

    ``approach`` is ``'sample_approx'`` when no representation model is passed
    to the constructor and ``'exact_approx'`` otherwise. In the latter case an
    ``embeddings`` column is precomputed for ``set_to_match``.
    """

    def __init__(self, set_to_match, description, representation_model=None, top_k=1,
                 matching_approach=1, setup_name='cebab'):
        """Create the explainer.

        Args:
            set_to_match: DataFrame of candidate examples; rows with a missing
                ``text`` are dropped.
            description: Passed through to ``MatchingBasedExplainer``.
            representation_model: Optional model exposing ``get_embeddings``;
                when given, embeddings of ``set_to_match['text']`` are computed.
            top_k: Passed through to ``MatchingBasedExplainer``.
            matching_approach: Stored as ``self.matching_approach`` (not read
                within this class).
            setup_name: ``'cebab'`` or ``'stance'``; selects the column map.

        Raises:
            ValueError: If ``setup_name`` is not recognised.
        """
        self.setup_name = setup_name
        # NOTE(review): the base class is expected to store ``set_to_match`` and
        # ``representation_model`` on ``self``; both are read right below.
        super().__init__(set_to_match=set_to_match, representation_model=representation_model, description=description,
                         top_k=top_k)
        self.set_to_match = self.set_to_match.dropna(subset=['text'])
        self.approach = 'sample_approx'
        self.matching_approach = matching_approach
        if self.setup_name == 'cebab':
            self.cols_map = COLS_MAP_CEBAB
        elif self.setup_name == 'stance':
            self.cols_map = COLS_MAP_STANCE
        else:
            raise ValueError(f'Unknown setup name: {self.setup_name}')
        if self.representation_model is not None:
            self.approach = 'exact_approx'
            self.set_to_match['embeddings'] = self.representation_model.get_embeddings(
                list(self.set_to_match['text'].values))

    def find_match_idx(self, candidates, example):
        """Return the index label of the candidate most similar to ``example``.

        Similarity is ``cosine_similarity_matrix`` between the embedding of the
        example's base text and the precomputed ``embeddings`` column of
        ``candidates``, so a representation model is required unless there is
        exactly one candidate (returned directly, without embedding anything).

        Args:
            candidates: DataFrame (subset of ``set_to_match``) with an
                ``embeddings`` column.
            example: Row-like mapping holding the base text under
                ``self.cols_map['base_text']``.

        Returns:
            The label from ``candidates.index`` of the best-scoring candidate.

        Raises:
            ValueError: If ``candidates`` is empty.
        """
        indexes = candidates.index
        if len(indexes) == 0:
            raise ValueError('find_match_idx received no candidates to match against')
        if len(indexes) == 1:
            return indexes.values[0]

        # ``.values`` is positional, so no index reset is needed to keep the
        # similarity columns aligned with ``indexes``.
        candidates_embeddings = candidates['embeddings'].values
        bases_embeddings = self.representation_model.get_embeddings([example[self.cols_map['base_text']]])

        sim_mat = cosine_similarity_matrix(bases_embeddings, candidates_embeddings)

        # One row per base text (a single one here); take that row's best column.
        match_temp_idx = list(np.argmax(sim_mat, axis=1))
        return indexes[match_temp_idx].values[0]

    def set_representation_model(self, model):
        """Replace the representation model.

        Embeddings of ``set_to_match`` are recomputed only in ``'exact_approx'``
        mode (i.e. a model was passed to the constructor); a ``'sample_approx'``
        explainer keeps its mode and gets no embeddings.
        """
        self.representation_model = model
        if self.approach == 'exact_approx':
            self.set_to_match['embeddings'] = self.representation_model.get_embeddings(
                list(self.set_to_match['text'].values))

    @staticmethod
    def _sample_top_k(pool, top_k):
        """Randomly draw ``top_k`` rows from the non-empty ``pool``.

        When the pool has fewer than ``top_k`` rows, every row is drawn once and
        the remainder is filled with single-row resamples (duplicates allowed).
        """
        if len(pool) >= top_k:
            return pool.sample(n=top_k)
        matches = pool.sample(n=len(pool))
        # Drawn one at a time, in order, after the full draw above.
        padding = [pool.sample(n=1) for _ in range(top_k - len(pool))]
        return pd.concat([matches, *padding])

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Sample ``top_k`` CEBaB matches for every pair intervening on ``concept``.

        Args:
            pairs: DataFrame of CEBaB base/counterfactual pairs.
            concept: Intervened concept (``intervention_type``).
            base_direction: Required ``intervention_base`` value.
            target_direction: Required ``intervention_counterfactual`` value;
                matches must have ``{concept}_label == target_direction``.
            top_k: Number of matches per pair; smaller pools are padded by
                resampling.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            DataFrame indexed by the pair's position among the selected pairs,
            with ``original_id``, ``text_base``, ``text_counterfactual``,
            ``original_id_match_{i}``, ``text_match_{i}`` plus the
            ``concept_intervention``/``base_direction``/``target_direction``
            columns. Pairs with no eligible candidate are omitted.
        """
        # Copy immediately: the prediction columns are rewritten below and must
        # not be assigned on a slice of ``pairs`` (SettingWithCopyWarning).
        candidate_pairs = pairs[(pairs['intervention_base'] == base_direction) & (
                pairs['intervention_counterfactual'] == target_direction) & (pairs['intervention_type'] == concept)].copy()

        candidates_matches = self.set_to_match.copy()
        # Predictions are stored as stringified tensors (presumably after a CSV
        # round-trip); map them to label strings. Unmapped values become NaN.
        encode_str = {'tensor(0)': 'unknown', 'tensor(1)': 'Negative', 'tensor(2)': 'Positive'}
        for c in CEBAB_CONCEPTS:
            candidate_pairs[f'{c}_predictions_base'] = candidate_pairs[f'{c}_predictions_base'].map(encode_str)
            candidates_matches[f'{c}_predictions'] = candidates_matches[f'{c}_predictions'].map(encode_str)

        candidate_pairs = candidate_pairs.reset_index()
        other_concepts = [c for c in CEBAB_CONCEPTS if c != concept]
        # Loop-invariant: every match must already carry the target concept value.
        target_pool = candidates_matches[candidates_matches[f'{concept}_label'] == target_direction]

        df_matches_pairs = pd.DataFrame()
        for idx in candidate_pairs.index:
            row = candidate_pairs.loc[idx]
            pool = target_pool
            # Keep candidates whose other-concept predictions agree with the base.
            for c in other_concepts:
                pool = pool[pool[f'{c}_predictions'] == row[f'{c}_predictions_base']]
            if len(pool) == 0:
                continue

            matches = self._sample_top_k(pool, top_k)
            if matches.empty:  # only when top_k == 0: nothing is recorded for this pair
                continue
            df_matches_pairs.loc[idx, 'original_id'] = row[self.cols_map['base_original_id']]
            df_matches_pairs.loc[idx, 'text_base'] = row[self.cols_map['base_text']]
            df_matches_pairs.loc[idx, 'text_counterfactual'] = row[self.cols_map['cf_text']]
            for i in range(top_k):
                df_matches_pairs.loc[idx, f'original_id_match_{i}'] = matches.iloc[i]['Unnamed: 0']
                df_matches_pairs.loc[idx, f'text_match_{i}'] = matches.iloc[i]['text']

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Sample ``top_k`` stance matches for every pair editing ``concept``.

        Args:
            pairs: DataFrame of stance base/edit pairs.
            concept: Edited concept (``edit_type``).
            base_direction: Required ``{concept}_text`` value of the base.
            target_direction: Required ``edit_goal``; matches must have
                ``{concept}_text == target_direction``.
            top_k: Number of matches per pair; smaller pools are padded by
                resampling.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            An empty DataFrame if no pair is selected. Otherwise a DataFrame
            with the kept pair columns, ``{col}_match_{i}`` for every kept
            non-``edit`` column, and the
            ``concept_intervention``/``base_direction``/``target_direction``
            columns.
        """
        candidate_pairs = pairs[(pairs[f'{concept}_text'] == base_direction) & (
                pairs['edit_goal'] == target_direction) & (pairs['edit_type'] == concept)]

        if len(candidate_pairs) == 0:
            return pd.DataFrame()

        candidates_matches = self.set_to_match.copy().reset_index(drop=True)
        confounders = [c for c in STANCE_CONFOUNDERS if c != concept]

        candidate_pairs = candidate_pairs.copy().reset_index()
        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        cols_to_keep += [col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col)]
        # Columns copied from each match; ``edit*`` columns describe the edit only.
        match_cols = [c for c in cols_to_keep if 'edit' not in c]

        # NOTE(review): unlike the CEBaB variant, pairs with no eligible candidate
        # are kept (their ``*_match_{i}`` columns stay NaN) because the frame is
        # pre-filled with every selected pair.
        df_matches_pairs = pd.DataFrame()
        for col in cols_to_keep:
            df_matches_pairs[col] = candidate_pairs[col].values

        # Loop-invariant: every match must already carry the target concept value.
        target_pool = candidates_matches[candidates_matches[f'{concept}_text'] == target_direction]

        for idx in tqdm(candidate_pairs.index):
            row = candidate_pairs.loc[idx]
            pool = target_pool
            # Keep candidates whose confounder predictions agree with the base.
            for c in confounders:
                pool = pool[pool[f'text_to_{c}_preds'] == row[f'text_to_{c}_preds']]
            if len(pool) == 0:
                continue

            matches = self._sample_top_k(pool, top_k)
            for i in range(top_k):
                for c in match_cols:
                    df_matches_pairs.loc[idx, f'{c}_match_{i}'] = matches.iloc[i][c]

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs
