import numpy as np

from explainers.explainer import Explainer
from explainers.matching_based_explainer import MatchingBasedExplainer
from utils.constants import COLS_MAP_CEBAB, COLS_MAP_STANCE
from utils.metric_utils import cosine_similarity_matrix, cosine_similarity_matrix
import pandas as pd


class ZeroExplainer(Explainer):
    """Baseline explainer whose estimated counterfactual prediction is all zeros.

    For each row the explanation is ``zeros - prediction_base``; it is scored
    against the observed effect ``prediction_counterfactual - prediction_base``.
    """

    def __init__(self):
        super().__init__()

    def get_explainer_description(self):
        """Return the display name of this explainer."""
        return 'Zero Explainer'

    def icace_error(self, model, pairs, concept, base_direction, target_direction, save_outputs=False):
        """Compute the per-row L2 error of the zero baseline.

        Args:
            model: Not used by this baseline.
            pairs: DataFrame with ``prediction_base`` and
                ``prediction_counterfactual`` columns; each cell holds a
                numeric sequence.
            concept: Not used by this baseline.
            base_direction: Not used by this baseline.
            target_direction: Not used by this baseline.
            save_outputs: Not used by this baseline.

        Returns:
            A copy of ``pairs`` with an added float ``icace_error`` column.
            The caller's frame is left untouched.
        """
        # Work on a copy so the caller's frame is never modified (this matches
        # RandomMatching.icace_error, which also copies its input).
        pairs = pairs.copy()

        def zero_explainer(prediction_base, prediction_counterfactual):
            base = np.asarray(prediction_base)
            explanation = np.zeros(len(prediction_base)) - base
            icace = np.array(prediction_counterfactual) - base
            return np.linalg.norm(explanation - icace, ord=2)

        # An explicit loop (rather than DataFrame.apply(axis=1)) keeps the
        # result a flat list of scalars even when ``pairs`` has no rows.
        errors = [
            zero_explainer(base, counterfactual)
            for base, counterfactual in zip(pairs['prediction_base'], pairs['prediction_counterfactual'])
        ]
        # Positional assignment: no index alignment is involved.
        pairs['icace_error'] = np.asarray(errors, dtype=float)
        return pairs


class RandomExplainer(Explainer):
    """Baseline explainer whose estimated counterfactual prediction is random.

    For each row a random probability vector (uniform draws normalised to sum
    to one) stands in for the counterfactual prediction. The explanation
    ``random - prediction_base`` is scored against the observed effect
    ``prediction_counterfactual - prediction_base``.
    """

    def __init__(self):
        super().__init__()

    def get_explainer_description(self):
        """Return the display name of this explainer."""
        return 'Random Probability Explainer'

    def icace_error(self, model, pairs, concept, base_direction, target_direction, save_outputs=False):
        """Compute the per-row L2 error of the random-probability baseline.

        Draws come from NumPy's global random state, so results are only
        reproducible if the caller seeds ``np.random``.

        Args:
            model: Not used by this baseline.
            pairs: DataFrame with ``prediction_base`` and
                ``prediction_counterfactual`` columns; each cell holds a
                numeric sequence.
            concept: Not used by this baseline.
            base_direction: Not used by this baseline.
            target_direction: Not used by this baseline.
            save_outputs: Not used by this baseline.

        Returns:
            A copy of ``pairs`` with an added float ``icace_error`` column.
            The caller's frame is left untouched.
        """
        # Work on a copy so the caller's frame is never modified (this matches
        # RandomMatching.icace_error, which also copies its input).
        pairs = pairs.copy()

        def _get_random_probability_vectors(L, N):
            # L rows of N non-negative entries, each row normalised to sum to 1.
            p = np.random.uniform(size=(L, N))
            return p / p.sum(axis=1, keepdims=True)

        def random_explainer(prediction_base, prediction_counterfactual):
            base = np.array(prediction_base)
            icace = np.array(prediction_counterfactual) - base
            # Take the single sampled row as a 1-D vector so that ``ord=2``
            # below is the Euclidean vector norm; on a 2-D input numpy would
            # compute a matrix (spectral) norm instead.
            random_prediction = _get_random_probability_vectors(1, len(icace))[0]
            explanation = random_prediction - base
            return np.linalg.norm(explanation - icace, ord=2)

        # An explicit loop (rather than DataFrame.apply(axis=1)) keeps the
        # result a flat list of scalars even when ``pairs`` has no rows.
        errors = [
            random_explainer(base, counterfactual)
            for base, counterfactual in zip(pairs['prediction_base'], pairs['prediction_counterfactual'])
        ]
        # Positional assignment: no index alignment is involved.
        pairs['icace_error'] = np.asarray(errors, dtype=float)
        return pairs


#
# class AveragePredictions(Explainer):
#     def set_representation_model(self, model):
#         pass
#
#     def __init__(self, set_to_match, description=None):
#         super().__init__(set_to_match=set_to_match, description=description, top_k=None)
#         self.set_to_match = self.set_to_match.dropna(subset=['text'])
#
#     def icace_error(self, model, pairs, concept, base_direction, target_direction, save_outputs=False):
#         pairs = pairs.copy()
#         pairs = pairs.reset_index()
#
#         def icace_error_approach_1(row):
#             prediction_base = row['prediction_base']
#             prediction_counterfactual = row['prediction_counterfactual']
#
#             s = np.array([1 / len(prediction_base)] * len(prediction_base))
#
#             explanation = s - np.array(prediction_base)
#             icace = np.array(prediction_counterfactual) - np.array(prediction_base)
#
#             return np.linalg.norm(explanation - icace, ord=2)
#
#         pairs['icace_error'] = pairs.apply(lambda row: icace_error_approach_1(row), axis=1)
#
#         # drop None values of icace error
#         pairs = pairs.dropna(subset=['icace_error'])
#         if save_outputs:
#             self.save_matches(pairs, concept, base_direction, target_direction)
#         return pairs
#
#     def get_explainer_description(self):
#         if self.description is not None:
#             return f'Matching-{self.description}'
#
#         return f'Matching - Average Prediction'


class RandomMatching(MatchingBasedExplainer):
    """Matching baseline that pairs each example with randomly drawn candidates.

    Candidates come from ``self.set_to_match`` (presumably stored by
    ``MatchingBasedExplainer.__init__`` -- confirm against the base class) and
    are restricted to rows whose concept value equals the target direction.
    All sampling uses pandas' default random state, so results are not
    reproducible unless the caller seeds NumPy's global generator.
    """

    def __init__(self, set_to_match, description, top_k=1, setup_name='cebab'):
        """Select the column mapping for ``setup_name`` and initialise the base class.

        Args:
            set_to_match: Candidate pool, forwarded to the base class.
            description: Optional display name, forwarded to the base class.
            top_k: Number of random matches drawn per example.
            setup_name: ``'cebab'`` or ``'stance'``; selects ``self.cols_map``.

        Raises:
            ValueError: If ``setup_name`` is not one of the supported setups.
        """
        self.top_k = top_k
        self.setup_name = setup_name
        if self.setup_name == 'cebab':
            self.cols_map = COLS_MAP_CEBAB
        elif self.setup_name == 'stance':
            self.cols_map = COLS_MAP_STANCE
        else:
            raise ValueError(f'Unknown setup name: {self.setup_name}')
        super().__init__(set_to_match=set_to_match, description=description, top_k=top_k)

    @staticmethod
    def _draw_random_rows(pool, n):
        """Return ``n`` randomly drawn rows of ``pool``.

        When ``pool`` has at least ``n`` rows they are drawn without
        replacement. Otherwise every row is used once (in shuffled order) and
        the shortfall is topped up with draws made with replacement.

        Raises:
            ValueError: If rows are requested from an empty ``pool``.
        """
        available = len(pool)
        if n <= available:
            return pool.sample(n=n)
        if available == 0:
            raise ValueError(f'Cannot draw {n} random matches from an empty candidate set')
        shuffled = pool.sample(n=available)
        # One vectorised draw instead of concatenating single-row samples.
        top_up = pool.sample(n=n - available, replace=True)
        return pd.concat([shuffled, top_up])

    def icace_error(self, model, pairs, concept, base_direction, target_direction, save_outputs=False):
        """Score random matching on every row of ``pairs``.

        For each row, ``self.top_k`` candidates are sampled (all candidates if
        fewer than ``top_k`` exist), their model predictions are averaged, and
        ``average - prediction_base`` is compared with
        ``prediction_counterfactual - prediction_base`` using the L2 norm.

        Args:
            model: Object exposing ``get_predictions(text)``; the result is
                converted to an array and squeezed.
            pairs: DataFrame with ``prediction_base`` and
                ``prediction_counterfactual`` columns.
            concept: Concept name; candidates are filtered on
                ``f'{concept}_label'``.
            base_direction: Only forwarded to ``save_matches``.
            target_direction: Required candidate value of the concept label.
            save_outputs: When true, the result is passed to
                ``self.save_matches``.

        Returns:
            A re-indexed copy of ``pairs`` (the previous index is kept as a
            column by ``reset_index``) with ``description_match_{i}`` columns
            holding the sampled texts and a float ``icace_error`` column. The
            error is NaN when no candidate is available.
        """
        pairs = pairs.copy()
        pairs = pairs.reset_index()
        candidates = self.set_to_match[self.set_to_match[f'{concept}_label'] == target_direction]
        candidates = candidates.reset_index()

        use_all_candidates = self.top_k > len(candidates)
        n_matches = len(candidates) if use_all_candidates else self.top_k

        # Results are collected by row position and written once at the end.
        # After reset_index() the previous index labels only live in a column
        # and are not valid row labels, so they must not be used with .loc.
        errors = []
        match_texts = [[] for _ in range(n_matches)]
        for prediction_base, prediction_counterfactual in zip(pairs['prediction_base'],
                                                               pairs['prediction_counterfactual']):
            matches = candidates if use_all_candidates else candidates.sample(n=self.top_k)
            base = np.array(prediction_base)
            sum_predictions = np.zeros(len(prediction_base), dtype=float)
            for slot, text in enumerate(matches['text']):
                match_texts[slot].append(text)
                sum_predictions = sum_predictions + np.array(model.get_predictions(text)).squeeze()

            if n_matches == 0:
                # Nothing to average: report an undefined error explicitly
                # instead of dividing by zero.
                errors.append(np.nan)
                continue

            average = sum_predictions / n_matches
            explanation = average - base
            icace = np.array(prediction_counterfactual) - base
            errors.append(np.linalg.norm(explanation - icace, ord=2))

        for slot, texts in enumerate(match_texts):
            pairs[f'description_match_{slot}'] = texts
        pairs['icace_error'] = np.asarray(errors, dtype=float)

        if save_outputs:
            self.save_matches(pairs, concept, base_direction, target_direction)
        return pairs

    def get_explainer_description(self):
        """Return the configured description, or a default name if none was given."""
        if self.description is not None:
            return self.description
        return 'Random Matching'

    def set_representation_model(self, model):
        """Ignore ``model``: random matching uses no representation model."""
        # Call the method; formatting the bare attribute would print the
        # bound-method repr instead of the description.
        print(f'no model in {self.get_explainer_description()}')

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, k):
        """Build a table of ``k`` random matches per intervention pair (CEBaB setup).

        Pairs are selected where the base/target direction columns (looked up
        through ``self.cols_map``) equal the requested directions and
        ``intervention_type`` equals ``concept``. Candidates are the rows of
        ``self.set_to_match`` whose ``f'{concept}_label'`` equals
        ``target_direction``.

        Returns:
            DataFrame with ``original_id``, ``text_base``,
            ``text_counterfactual``, ``text_match_{i}`` / ``id_match_{i}`` for
            ``i < k``, plus ``concept_intervention``, ``base_direction`` and
            ``target_direction``.

        Raises:
            ValueError: If pairs were selected but no candidate exists.
        """
        # Build a single mask on ``pairs`` rather than indexing an already
        # filtered frame with a mask computed on the unfiltered one.
        selected = (
            (pairs[self.cols_map['base_direction']] == base_direction)
            & (pairs[self.cols_map['target_direction']] == target_direction)
            & (pairs['intervention_type'] == concept)
        )
        candidate_pairs = pairs[selected]

        candidates_matches = self.set_to_match[self.set_to_match[f'{concept}_label'] == target_direction]
        candidates_matches = candidates_matches.reset_index()

        df_matches_pairs = pd.DataFrame()
        df_matches_pairs['original_id'] = candidate_pairs[self.cols_map['base_original_id']].values
        df_matches_pairs['text_base'] = candidate_pairs[self.cols_map['base_text']].values
        df_matches_pairs['text_counterfactual'] = candidate_pairs[self.cols_map['cf_text']].values

        for i in range(k):
            matches = self._draw_random_rows(candidates_matches, len(candidate_pairs))
            df_matches_pairs[f'text_match_{i}'] = matches['text'].values
            df_matches_pairs[f'id_match_{i}'] = matches['Unnamed: 0'].values

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)

        return df_matches_pairs

    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Build a table of ``top_k`` random same-domain matches per pair (stance setup).

        Pairs are selected where ``f'{concept}_text'`` equals
        ``base_direction``, ``edit_goal`` equals ``target_direction`` and
        ``edit_type`` equals ``concept``. For each distinct ``domain_text``
        the candidates are the rows of ``self.set_to_match`` from that domain
        whose ``f'{concept}_text'`` equals ``target_direction``.

        Args:
            batch_size: Not used; kept so the signature stays unchanged.

        Returns:
            An empty DataFrame if no pair is selected. Otherwise the per-domain
            tables concatenated row-wise, with the kept columns, a
            ``{col}_match_{i}`` copy of each kept column for ``i < top_k``,
            and ``treatment``, ``base_direction`` and ``target_direction``.

        Raises:
            ValueError: If a domain has selected pairs but no candidate.
        """
        candidate_pairs = pairs[
            (pairs[f'{concept}_text'] == base_direction)
            & (pairs['edit_goal'] == target_direction)
            & (pairs['edit_type'] == concept)
        ]

        domains = candidate_pairs['domain_text'].unique()
        if len(domains) == 0:
            return pd.DataFrame()

        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        # Also carry over every prediction / probability column.
        cols_to_keep.extend(col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col))

        per_domains_matches = []
        for domain in domains:
            pairs_per_domain = candidate_pairs[candidate_pairs['domain_text'] == domain].reset_index(drop=True)
            candidate_matches_per_domain = self.set_to_match[
                (self.set_to_match[f'{concept}_text'] == target_direction)
                & (self.set_to_match['domain_text'] == domain)
            ].reset_index(drop=True)

            df_matches_pairs_per_domain = pd.DataFrame()
            for col in cols_to_keep:
                df_matches_pairs_per_domain[col] = pairs_per_domain[col].values

            if 'prediction_base' in pairs_per_domain.columns:
                df_matches_pairs_per_domain['prediction_base'] = pairs_per_domain['prediction_base'].values
                df_matches_pairs_per_domain['prediction_counterfactual'] = pairs_per_domain[
                    'prediction_counterfactual'].values

            for i in range(top_k):
                matches = self._draw_random_rows(candidate_matches_per_domain, len(pairs_per_domain))
                for col in cols_to_keep:
                    df_matches_pairs_per_domain[f'{col}_match_{i}'] = matches[col].values
            per_domains_matches.append(df_matches_pairs_per_domain)

        # Concatenate all the per-domain match tables.
        df_matches_pairs = pd.concat(per_domains_matches, axis=0)
        df_matches_pairs['treatment'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)

        return df_matches_pairs
# class Confounder_Explainer(MatchingBasedExplainer):
#     def __init__(self, set_to_match=None, description=None, representation_model=None, top_k=1,
#                  representation_model_per_concept=None):
#         super().__init__(set_to_match=set_to_match, description=description, representation_model=representation_model,
#                          representation_model_per_concept=representation_model_per_concept, top_k=top_k)
#
#     def get_matches_for_concept(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
#         pairs = pairs[pairs['intervention_type'].notnull()]
#         candidate_pairs = pairs[
#             (pairs['intervention_type'] == concept) & (
#                     pairs['intervention_aspect_base'] == base_direction) & (
#                     pairs['intervention_aspect_counterfactual'] == target_direction)]
#
#         df_matches_pairs = pd.DataFrame()
#         # df_matches_pairs['original_id'] = candidate_pairs['original_id_base'].values
#         for i in range(len(candidate_pairs)):
#             pair = candidate_pairs.iloc[i]
#             original_id = pair['original_id_base']
#             gt_candidates_base = pairs[
#                 (pairs['description_base'] == pair['description_base']) & (
#                         pairs['intervention_type'] != concept)]
#             gt_candidates_base = gt_candidates_base[
#                 [col for col in gt_candidates_base.columns if 'counterfactual' in col]]
#             gt_candidates_base = gt_candidates_base.rename(columns=lambda x: x.replace('_counterfactual', ''))
#
#             gt_candidates_target = pairs[
#                 (pairs['description_counterfactual'] == pair['description_base']) & (
#                         pairs['intervention_type'] != concept)]
#             gt_candidates_target = gt_candidates_target[
#                 [col for col in gt_candidates_target.columns if 'base' in col]]
#             gt_candidates_target = gt_candidates_target.rename(columns=lambda x: x.replace('_base', ''))
#
#             gt_candidates = pd.concat([gt_candidates_base, gt_candidates_target])
#             gt_candidates = gt_candidates.rename(columns={'description': 'text'})
#             if len(gt_candidates) == 0:
#                 continue
#
#             df_matches_pairs.loc[i, 'text_base'] = pair['description_base']
#             df_matches_pairs.loc[i, 'text_counterfactual'] = pair['description_counterfactual']
#             df_matches_pairs.loc[i, f'original_id'] = pair['original_id_base']
#             k = len(gt_candidates)
#             for j in range(k):
#                 match = gt_candidates.iloc[j]
#                 df_matches_pairs.loc[i, f'text_match_{j}'] = match['text']
#                 df_matches_pairs.loc[i, f'similarity_match_{j}'] = [1]
#                 df_matches_pairs.loc[i, f'original_id_match_{j}'] = match['original_id']
#
#         df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
#         df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
#         df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
#         return df_matches_pairs
