import ast

import numpy as np
import pandas as pd
import torch
from utils.constants import DEVICE, COLS_MAP_CEBAB, COLS_MAP_STANCE
from explainers.matching_based_explainer import MatchingBasedExplainer
from utils.metric_utils import cosine_similarity_matrix, cace


class Matching(MatchingBasedExplainer):
    def __init__(self, set_to_match, description, representation_model=None, top_k=1, k_th=None, threshold=0,
                 assign=False, representation_model_per_concept=None, adding_prompt=False,
                 cfc_include=False, tcf_include=False, model_to_device=False, setup_name='cebab',
                 text_col='text'):
        """Configure the matcher and optionally pre-compute matching-set embeddings.

        Args:
            set_to_match: DataFrame of candidate matches; forwarded to the parent
                constructor and read back through ``self.set_to_match``.
            description: Forwarded to the parent constructor.
            representation_model: Shared encoder exposing ``get_embeddings(texts)``.
            top_k: Number of matches requested per pair.
            k_th: Stored on the instance; not used by the code visible here.
            threshold: Minimum similarity a match needs to be used in ``icace_error``.
            assign: When true, embeddings of the matching set are computed once
                here and stored in ``set_to_match`` columns.
            representation_model_per_concept: Mapping ``concept -> encoder``.
            adding_prompt: Stored on the instance; not used by the code visible here.
            cfc_include: Stored; consulted by the ground-truth matching methods.
            tcf_include: Stored; consulted by the ground-truth matching methods.
            model_to_device: Accepted for interface compatibility; not used here.
            setup_name: ``'cebab'`` or ``'stance'``; selects the column map.
            text_col: Name of the text column in ``set_to_match``.

        Raises:
            ValueError: If ``setup_name`` is neither ``'cebab'`` nor ``'stance'``.
        """
        super().__init__(set_to_match=set_to_match, description=description, representation_model=representation_model,
                         representation_model_per_concept=representation_model_per_concept, top_k=top_k)
        self.top_k = top_k
        self.cfc_include = cfc_include
        self.tcf_include = tcf_include
        self.k_th = k_th
        self.threshold = threshold
        self.adding_prompt = adding_prompt
        self.assign = assign
        self.setup_name = setup_name
        self.text_col = text_col
        if self.setup_name == 'cebab':
            self.cols_map = COLS_MAP_CEBAB
        elif self.setup_name == 'stance':
            self.cols_map = COLS_MAP_STANCE
        else:
            raise ValueError(f'Unknown setup name: {self.setup_name}')

        # Always present so later lookups never depend on which branch ran here.
        self.matching_set_embeddings_per_concept = {}

        if self.assign and representation_model is not None:
            self.set_to_match['embeddings'] = self.representation_model.get_embeddings(
                list(self.set_to_match[self.text_col].values))
        if self.assign and representation_model_per_concept is not None:
            matching_texts = list(self.set_to_match[self.text_col].values)
            for concept, concept_model in self.representation_model_per_concept.items():
                raw_embeddings = concept_model.get_embeddings(matching_texts)
                stacked = np.array([np.array(embedding, dtype=np.float32) for embedding in raw_embeddings])
                # Device-side copy, one row per row of ``set_to_match`` (positional order).
                self.matching_set_embeddings_per_concept[concept] = torch.from_numpy(stacked).to(DEVICE)
                # Host-side copy for the DataFrame column; taken from the numpy array
                # directly instead of moving every tensor row back from the device.
                self.set_to_match[f'{concept}_embeddings'] = list(stacked.copy())

    def icace_error(self, model, pairs, concept, base_direction, target_direction,
                    save_outputs=False):
        if self.setup_name == 'cebab':
            candidates_pairs = self.get_matches_for_concept_cabab(pairs, concept, base_direction, target_direction,
                                                                  top_k=self.top_k)
        elif self.setup_name == 'stance':
            candidates_pairs = self.get_matches_for_concept_stance(pairs, concept, base_direction, target_direction,
                                                                   top_k=self.top_k)

        if self.top_k is not None:
            k_to_match = self.top_k
            if len(candidates_pairs) <= self.top_k:
                k_to_match = len(candidates_pairs) - 1
            for j in range(k_to_match):
                # TODO remove self.top_candidate
                if self.setup_name == 'cebab':
                    candidates_pairs[f'prediction_match_{j}'] = model.get_predictions(
                        list(candidates_pairs[f'text_match_{j}'].values))
                elif self.setup_name == 'stance':
                    candidates_pairs[f'prediction_match_{j}'] = candidates_pairs[f'{model}_probs_match_{j}'].apply(
                        lambda x: ast.literal_eval(x))

        if self.top_k is not None:
            k_to_match = self.top_k
            if len(candidates_pairs) <= self.top_k:
                k_to_match = len(candidates_pairs) - 1

        def icace_error_approach_1(row):
            under_the_threshold = 0
            prediction_base = row['prediction_base']
            prediction_counterfactual = row['prediction_counterfactual']
            s = np.array([0] * len(prediction_base))
            for i in range(k_to_match):
                if row[f'similarity_match_{i}'] > self.threshold:
                    s = s + np.array(row[f'prediction_match_{i}'])
                    under_the_threshold += 1
            if under_the_threshold == 0:
                return None
            else:
                explanation = s / under_the_threshold - np.array(prediction_base)
                icace = np.array(prediction_counterfactual) - np.array(prediction_base)

                return np.linalg.norm(explanation - icace, ord=2)

        candidates_pairs['icace_error'] = candidates_pairs.apply(lambda row: icace_error_approach_1(row), axis=1)

        # drop None values of icace error
        candidates_pairs = candidates_pairs.dropna(subset=['icace_error'])
        if save_outputs:
            self.save_matches(candidates_pairs, concept, base_direction, target_direction)
        return candidates_pairs

    def set_representation_model(self, model):
        self.representation_model = model
        self.set_to_match['embeddings'] = self.representation_model.get_embeddings(
            list(self.set_to_match[self.text_col].values))
        print(f'\nlm is changed and embeddings are updated for the explainer {self.get_explainer_description()}')

    def set_representation_model_per_concept(self, model, concepts):
        for concept in concepts:
            self.representation_model_per_concept[concept] = model
            self.set_to_match[f'{concept}_embeddings'] = self.representation_model_per_concept[concept].get_embeddings(
                list(self.set_to_match[self.text_col].values))
        print(
            f'\nlm is changed and embeddings are updated for the explainer {self.get_explainer_description()}, for concepts {concepts}')

    def get_matches_for_concept_cebab(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Find the most similar matching-set rows for each CEBaB candidate pair.

        Candidate pairs are the rows of ``pairs`` whose intervention goes from
        ``base_direction`` to ``target_direction`` on ``concept``. Candidate
        matches are the rows of ``self.set_to_match`` whose ``{concept}_label``
        equals ``target_direction``. Similarity comes from
        ``cosine_similarity_matrix`` on the base-text and match-text embeddings.

        Args:
            pairs: DataFrame of base/counterfactual pairs.
            concept: Concept being intervened on.
            base_direction: Concept value on the base side.
            target_direction: Concept value on the counterfactual side.
            top_k: Requested number of matches per pair.
            batch_size: Only forwarded to ``get_matches_for_concept_with_gt``.

        Returns:
            DataFrame with one row per candidate pair holding the base and
            counterfactual texts, optional predictions, and
            ``original_id_match_{i}`` / ``text_match_{i}`` /
            ``similarity_match_{i}`` columns. Empty when no pair qualifies.

        Raises:
            ValueError: If no representation model is configured.
        """
        candidate_pairs = pairs[(pairs['intervention_base'] == base_direction) & (
                pairs['intervention_counterfactual'] == target_direction) & (pairs['intervention_type'] == concept)]

        if len(candidate_pairs) == 0:
            return pd.DataFrame()
        if self.tcf_include or self.cfc_include:
            return self.get_matches_for_concept_with_gt(pairs, concept, base_direction, target_direction, top_k,
                                                        batch_size)

        # Remember the positions of the candidates inside the full matching set
        # *before* resetting the index: the cached tensors are laid out in that order.
        match_mask = (self.set_to_match[f'{concept}_label'] == target_direction).to_numpy()
        match_positions = np.flatnonzero(match_mask)
        candidate_matches = self.set_to_match[match_mask].reset_index(drop=True)
        candidate_pairs = candidate_pairs.reset_index(drop=True).copy()

        cached_embeddings = None
        if self.representation_model_per_concept:
            encoder = self.representation_model_per_concept[concept]
            embeddings_col = f'{concept}_embeddings'
            cache = getattr(self, 'matching_set_embeddings_per_concept', None)
            if cache:
                cached_embeddings = cache.get(concept)
        elif self.representation_model:
            encoder = self.representation_model
            embeddings_col = 'embeddings'
        else:
            raise ValueError('representation model is not set')

        k = top_k
        if len(candidate_matches) <= k:
            # NOTE(review): this keeps one candidate fewer than is available
            # (original behaviour, preserved) - confirm whether that is intended.
            k = len(candidate_matches) - 1

        df_matches_pairs = pd.DataFrame()
        df_matches_pairs['original_id'] = candidate_pairs[self.cols_map['base_original_id']].values
        df_matches_pairs['text_base'] = candidate_pairs[self.cols_map['base_text']].values
        df_matches_pairs['text_counterfactual'] = candidate_pairs[self.cols_map['cf_text']].values
        if 'prediction_base' in candidate_pairs.columns:
            df_matches_pairs['prediction_base'] = candidate_pairs['prediction_base'].values
            df_matches_pairs['prediction_counterfactual'] = candidate_pairs['prediction_counterfactual'].values

        if k > 0:
            # embedding for pairs base
            base_embeddings = encoder.get_embeddings(list(candidate_pairs['description_base'].values))

            # embeddings for matches, always restricted to the filtered candidates
            if not self.assign:
                matches_candidates_embeddings = encoder.get_embeddings(
                    list(candidate_matches[self.text_col].values))
            elif cached_embeddings is not None:
                matches_candidates_embeddings = cached_embeddings[match_positions]
            else:
                matches_candidates_embeddings = candidate_matches[embeddings_col].values

            dist_mat = cosine_similarity_matrix(base_embeddings, matches_candidates_embeddings,
                                                already_on_device=False)
            # Both branches yield (n_pairs, k) arrays so every pair gets its own match.
            if k == 1:
                matches_indexes = np.asarray(np.argmax(dist_mat, axis=1)).reshape(-1, 1)
                similarities_values = np.asarray(np.max(dist_mat, axis=1)).reshape(-1, 1)
            else:
                matches_indexes = np.argsort(-dist_mat, axis=1)[:, :k]
                similarities_values = np.take_along_axis(dist_mat, matches_indexes, axis=1)

            for i in range(k):
                match = candidate_matches.iloc[matches_indexes[:, i]]
                # 'Unnamed: 0' is the id column the matching set is expected to carry.
                df_matches_pairs[f'original_id_match_{i}'] = match['Unnamed: 0'].values
                df_matches_pairs[f'text_match_{i}'] = match[self.text_col].values
                df_matches_pairs[f'similarity_match_{i}'] = similarities_values[:, i]

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_matches_for_concept_with_gt(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Match CEBaB pairs against the matching set plus ground-truth candidates.

        For every candidate pair, extra "ground-truth" candidates are built from
        ``pairs`` itself and concatenated in front of the matching-set rows whose
        ``{concept}_label`` equals ``target_direction``:

        * ``cfc_include`` without ``tcf_include``: texts that share the pair's
          base description through an intervention on another concept, kept when
          ``cace(...) > 0.01``; pairs without such candidates are skipped.
        * ``tcf_include``: the pair's own counterfactual.

        Args:
            pairs: DataFrame of base/counterfactual pairs.
            concept: Concept being intervened on.
            base_direction: Concept value on the base side.
            target_direction: Concept value on the counterfactual side.
            top_k: Requested number of matches per pair.
            batch_size: Unused by this method; kept for interface compatibility.

        Returns:
            DataFrame indexed by the pair's position among the candidate pairs,
            with ``text_match_{j}``, ``similarity_match_{j}`` and
            ``original_id_match_{j}`` columns.

        Raises:
            ValueError: If no representation model is configured, or if neither
                ``cfc_include`` nor ``tcf_include`` is set while pairs exist.
        """
        k = top_k
        # drop pairs with intervention type None
        pairs = pairs[pairs['intervention_type'].notnull()]
        # Copy: an embedding column is added below and must not be written into a slice of ``pairs``.
        candidate_pairs = pairs[
            (pairs['intervention_type'] == concept) & (
                    pairs['intervention_aspect_base'] == base_direction) & (
                    pairs['intervention_aspect_counterfactual'] == target_direction)].copy()

        # One encoder / one embedding column for base texts, ground-truth
        # candidates and the matching set, so all three live in the same space.
        if self.representation_model_per_concept:
            encoder = self.representation_model_per_concept[concept]
            embeddings_col = f'{concept}_embeddings'
        elif self.representation_model:
            encoder = self.representation_model
            embeddings_col = 'embeddings'
        else:
            raise ValueError('representation model is not set')

        if len(candidate_pairs) > 0:
            candidate_pairs[f'{concept}_embeddings_base'] = encoder.get_embeddings(
                list(candidate_pairs['description_base'].values))

        no_gt_candidates = None  # loop-invariant; computed on first use
        df_matches_pairs = pd.DataFrame()
        for i in range(len(candidate_pairs)):
            pair = candidate_pairs.iloc[i]
            if self.cfc_include and not self.tcf_include:
                gt_candidates_base = pairs[
                    (pairs['description_base'] == pair['description_base']) & (
                            pairs['intervention_type'] != concept)]
                gt_candidates_base = gt_candidates_base[
                    [col for col in gt_candidates_base.columns if 'counterfactual' in col]]
                gt_candidates_base = gt_candidates_base.rename(columns=lambda x: x.replace('_counterfactual', ''))

                gt_candidates_target = pairs[
                    (pairs['description_counterfactual'] == pair['description_base']) & (
                            pairs['intervention_type'] != concept)]
                gt_candidates_target = gt_candidates_target[
                    [col for col in gt_candidates_target.columns if 'base' in col]]
                gt_candidates_target = gt_candidates_target.rename(columns=lambda x: x.replace('_base', ''))

                gt_candidates = pd.concat([gt_candidates_base, gt_candidates_target])
                gt_candidates = gt_candidates.rename(columns={'description': self.text_col})
                if len(gt_candidates) == 0:
                    continue
                gt_candidates['cace'] = gt_candidates.apply(lambda x: cace(row_tcf=pair, row_cfc=x), axis=1)
                gt_candidates = gt_candidates[gt_candidates['cace'] > 0.01].copy()
                if len(gt_candidates) == 0:
                    continue
                gt_candidates[embeddings_col] = encoder.get_embeddings(
                    list(gt_candidates[self.text_col].values))

            elif self.tcf_include:
                gt_candidates = pd.DataFrame(pair).T
                gt_candidates = gt_candidates[
                    [col for col in gt_candidates.columns if 'counterfactual' in col]]
                gt_candidates = gt_candidates.rename(columns=lambda x: x.replace('_counterfactual', ''))
                gt_candidates = gt_candidates.rename(columns={'description': self.text_col})
                gt_candidates[embeddings_col] = encoder.get_embeddings(
                    list(gt_candidates[self.text_col].values))

            else:
                raise ValueError('cfc_include and tcf_include cannot be both false')

            if no_gt_candidates is None:
                # Expected to already carry ``embeddings_col`` (filled when ``assign`` is set).
                no_gt_candidates = self.set_to_match[self.set_to_match[f'{concept}_label'] == target_direction]
            candidates = pd.concat([gt_candidates, no_gt_candidates])
            matches_candidates_embeddings = candidates[embeddings_col].values
            base_embeddings = pair[f'{concept}_embeddings_base']
            dist_mat = cosine_similarity_matrix(base_embeddings, matches_candidates_embeddings)

            # assumes dist_mat has the candidates along axis 0 - TODO confirm
            matches_indexes = np.argsort(-dist_mat, axis=0)[:k]

            similarities_values = np.take_along_axis(dist_mat, matches_indexes, axis=0)
            df_matches_pairs.loc[i, 'text_base'] = pair['description_base']
            df_matches_pairs.loc[i, 'text_counterfactual'] = pair['description_counterfactual']
            df_matches_pairs.loc[i, f'original_id'] = pair['original_id_base']
            # There may be fewer candidates than requested matches.
            for j in range(min(k, len(matches_indexes))):
                match = candidates.iloc[matches_indexes[j]]
                df_matches_pairs.loc[i, f'text_match_{j}'] = match[self.text_col]
                df_matches_pairs.loc[i, f'similarity_match_{j}'] = similarities_values[j]
                df_matches_pairs.loc[i, f'original_id_match_{j}'] = match['original_id']

        df_matches_pairs['concept_intervention'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)
        return df_matches_pairs

    def get_matches_for_concept_stance(self, pairs, concept, base_direction, target_direction, top_k, batch_size=128):
        """Find, per domain, the most similar matching-set rows for each stance pair.

        Candidate pairs are the rows of ``pairs`` whose ``{concept}_text`` equals
        ``base_direction`` and that were edited on ``concept`` towards
        ``target_direction``. For each domain among them, candidate matches are
        the rows of ``self.set_to_match`` in the same domain whose
        ``{concept}_text`` equals ``target_direction``.

        Args:
            pairs: DataFrame of original/edited pairs.
            concept: Concept being edited.
            base_direction: Concept value on the base side.
            target_direction: Concept value the edit aims for.
            top_k: Requested number of matches per pair.
            batch_size: Only forwarded to ``get_matches_for_concept_with_gt_stance``.

        Returns:
            DataFrame with the kept pair columns, ``{col}_match_{i}`` columns
            copied from each match and ``similarity_match_{i}``; the per-domain
            frames are concatenated without resetting the index. Empty when no
            pair qualifies.

        Raises:
            ValueError: If no representation model is configured.
        """
        if self.tcf_include:
            return self.get_matches_for_concept_with_gt_stance(pairs, concept, base_direction, target_direction, top_k,
                                                               batch_size)
        text_col = self.text_col
        candidate_pairs = pairs[
            (pairs[f'{concept}_text'] == base_direction) & (pairs['edit_goal'] == target_direction) & (
                    pairs['edit_type'] == concept)]

        domains = candidate_pairs['domain'].unique()
        if len(domains) == 0:
            return pd.DataFrame()

        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        cols_to_keep.extend(col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col))
        # Edit-specific columns only exist on the pair side.
        match_cols = [c for c in cols_to_keep if 'edit' not in c]

        cached_embeddings = None
        if self.representation_model_per_concept:
            encoder = self.representation_model_per_concept[concept]
            embeddings_col = f'{concept}_embeddings'
            cache = getattr(self, 'matching_set_embeddings_per_concept', None)
            if cache:
                cached_embeddings = cache.get(concept)
        elif self.representation_model:
            encoder = self.representation_model
            embeddings_col = 'embeddings'
        else:
            raise ValueError('representation model is not set')

        per_domains_matches = []
        for domain in domains:
            pairs_per_domain = candidate_pairs[candidate_pairs['domain'] == domain]
            if len(pairs_per_domain) == 0:
                continue
            pairs_per_domain = pairs_per_domain.reset_index(drop=True).copy()

            # Positions inside the full matching set, taken before the index is
            # reset: the cached tensors are laid out in that order.
            match_mask = ((self.set_to_match[f'{concept}_text'] == target_direction) & (
                    self.set_to_match['domain'] == domain)).to_numpy()
            match_positions = np.flatnonzero(match_mask)
            candidate_matches_per_domain = self.set_to_match[match_mask].reset_index(drop=True)

            k = top_k
            if len(candidate_matches_per_domain) <= k:
                # NOTE(review): keeps one candidate fewer than is available
                # (original behaviour, preserved) - confirm whether that is intended.
                k = len(candidate_matches_per_domain) - 1

            df_matches_pairs_per_domain = pd.DataFrame()
            for c in cols_to_keep:
                df_matches_pairs_per_domain[c] = pairs_per_domain[c].values

            if 'prediction_base' in pairs_per_domain.columns:
                df_matches_pairs_per_domain['prediction_base'] = pairs_per_domain['prediction_base'].values
                df_matches_pairs_per_domain['prediction_counterfactual'] = pairs_per_domain[
                    'prediction_counterfactual'].values

            if k > 0:
                # embedding for pairs base
                base_embeddings = encoder.get_embeddings(list(pairs_per_domain[text_col].values))

                # embeddings for matches, always restricted to this domain's candidates
                if not self.assign:
                    matches_candidates_embeddings = encoder.get_embeddings(
                        list(candidate_matches_per_domain[text_col].values))
                elif cached_embeddings is not None:
                    matches_candidates_embeddings = cached_embeddings[match_positions]
                else:
                    matches_candidates_embeddings = candidate_matches_per_domain[embeddings_col].values

                dist_mat = cosine_similarity_matrix(base_embeddings, matches_candidates_embeddings,
                                                    already_on_device=False)
                # Both branches yield (n_pairs, k) arrays.
                if k == 1:
                    matches_indexes = np.asarray(np.argmax(dist_mat, axis=1)).reshape(-1, 1)
                    similarities_values = np.asarray(np.max(dist_mat, axis=1)).reshape(-1, 1)
                else:
                    matches_indexes = np.argsort(-dist_mat, axis=1)[:, :k]
                    similarities_values = np.take_along_axis(dist_mat, matches_indexes, axis=1)

                for i in range(k):
                    match = candidate_matches_per_domain.iloc[matches_indexes[:, i]][match_cols]
                    match = match.rename(columns={c: f'{c}_match_{i}' for c in match_cols})
                    df_matches_pairs_per_domain = pd.concat(
                        [df_matches_pairs_per_domain.reset_index(drop=True),
                         match.reset_index(drop=True)], axis=1)
                    df_matches_pairs_per_domain[f'similarity_match_{i}'] = similarities_values[:, i]

            per_domains_matches.append(df_matches_pairs_per_domain)

        if not per_domains_matches:
            return pd.DataFrame()

        # concat all the matches per domain
        df_matches_pairs = pd.concat(per_domains_matches, axis=0)
        df_matches_pairs['treatment'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)

        return df_matches_pairs

    def get_matches_for_concept_with_gt_stance(self, pairs, concept, base_direction, target_direction, top_k,
                                               batch_size=128):
        """Match stance pairs against the matching set plus each pair's own edited text.

        For every candidate pair, the pair's ``edit_*`` columns (with the
        ``edit_`` prefix removed) form one ground-truth candidate that is
        concatenated in front of the matching-set rows with the same
        ``domain_text`` whose ``{concept}_text`` equals ``target_direction``.

        Args:
            pairs: DataFrame of original/edited pairs.
            concept: Concept being edited.
            base_direction: Concept value on the base side.
            target_direction: Concept value the edit aims for.
            top_k: Requested number of matches per pair.
            batch_size: Unused by this method; kept for interface compatibility.

        Returns:
            DataFrame with one row per candidate pair holding the kept pair
            columns, ``{col}_match_{j}`` columns from each match and
            ``similarity_match_{j}``.

        Raises:
            ValueError: If no representation model is configured, or if
                ``tcf_include`` is false while pairs exist.
        """
        k = top_k
        text_col = self.text_col
        # Copy: embedding columns are added below and must not be written into a slice of ``pairs``.
        candidate_pairs = pairs[
            (pairs[f'{concept}_text'] == base_direction) & (pairs['edit_goal'] == target_direction) & (
                    pairs['edit_type'] == concept)].copy()

        if self.representation_model_per_concept:
            encoder = self.representation_model_per_concept[concept]
            embeddings_col = f'{concept}_embeddings'
        elif self.representation_model:
            encoder = self.representation_model
            embeddings_col = 'embeddings'
        else:
            raise ValueError('representation model is not set')

        if len(candidate_pairs) > 0:
            candidate_pairs[f'{concept}_embeddings_base'] = encoder.get_embeddings(
                list(candidate_pairs[text_col].values))
            # Stored as ``edit_<embeddings_col>`` so that stripping ``edit_`` below
            # yields exactly the column the matching set uses for this encoder.
            candidate_pairs[f'edit_{embeddings_col}'] = encoder.get_embeddings(
                list(candidate_pairs[f'edit_{text_col}'].values))

        cols_to_keep = ['id', 'text', 'label', 'edit_label', 'edit_text', 'edit_instruction', 'instruction',
                        'original_instruction', 'edit_id',
                        'edit_type', 'edit_goal']
        cols_to_keep.extend(col for col in candidate_pairs.columns if ('preds' in col) or ('probs' in col))

        matched_frames = []
        for i in range(len(candidate_pairs)):
            pair = candidate_pairs.iloc[i]
            if not self.tcf_include:
                raise ValueError('tcf_include must be true to match with ground-truth counterfactuals')
            gt_candidates = pd.DataFrame(pair).T
            gt_candidates = gt_candidates[
                [col for col in gt_candidates.columns if 'edit' in col]]
            gt_candidates = gt_candidates.rename(columns=lambda x: x.replace('edit_', ''))

            # Expected to already carry ``embeddings_col`` (filled when ``assign`` is set).
            no_gt_candidates = self.set_to_match[(self.set_to_match[f'{concept}_text'] == target_direction) &
                                                 (self.set_to_match['domain_text'] == pair['domain_text'])]

            candidates = pd.concat([gt_candidates, no_gt_candidates])
            matches_candidates_embeddings = candidates[embeddings_col].values
            base_embeddings = pair[f'{concept}_embeddings_base']
            dist_mat = cosine_similarity_matrix(base_embeddings, matches_candidates_embeddings, already_on_device=False)

            # assumes dist_mat has the candidates along axis 0 - TODO confirm
            matches_indexes = np.argsort(-dist_mat, axis=0)[:k]

            similarities_values = np.take_along_axis(dist_mat, matches_indexes, axis=0)
            single_df_matches_pairs = pd.DataFrame()

            for c in cols_to_keep:
                single_df_matches_pairs[c] = [pair[c]]
            # There may be fewer candidates than requested matches.
            for j in range(min(k, len(matches_indexes))):
                match = candidates.iloc[matches_indexes[j]]
                for c in cols_to_keep:
                    if 'edit' in c:
                        continue
                    single_df_matches_pairs[f'{c}_match_{j}'] = match[c]
                # Needed downstream: ``icace_error`` filters matches by similarity.
                single_df_matches_pairs[f'similarity_match_{j}'] = similarities_values[j]
            matched_frames.append(single_df_matches_pairs)

        # Concatenate once instead of growing the frame inside the loop.
        df_matches_pairs = pd.concat(matched_frames, axis=0) if matched_frames else pd.DataFrame()
        df_matches_pairs['treatment'] = [concept] * len(df_matches_pairs)
        df_matches_pairs['base_direction'] = [base_direction] * len(df_matches_pairs)
        df_matches_pairs['target_direction'] = [target_direction] * len(df_matches_pairs)

        return df_matches_pairs
