import numpy as np
import pandas as pd
import ast
from scipy import stats
from scipy.spatial.distance import cosine
import warnings

# Suppress, for the whole script, RuntimeWarnings whose message starts with
# "invalid value encountered in double_scalars".
# NOTE(review): presumably these come from the cosine distance computed in
# get_icace_error when one of the vectors has zero norm -- confirm.
# NOTE(review): newer NumPy releases may word this warning differently
# (e.g. "... in scalar divide"), in which case this filter would no longer
# match -- verify against the installed NumPy version.
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in double_scalars")


def icace_per_row_k_th(row, k_th_match):
    """Return the L2 error obtained when a single match is the explanation.

    The row stores score dicts as Python-literal strings. The estimated
    effect is ``match - base`` for the match with index ``k_th_match``; the
    reference effect is ``counterfactual - base``. The result is the
    Euclidean norm of the difference between the two.

    Args:
        row: Mapping (e.g. a DataFrame row) with the columns
            ``text_match_{k_th_match}_scores``, ``text_base_scores`` and
            ``text_counterfactual_scores``.
        k_th_match: Index of the match column to evaluate.

    Returns:
        The L2 norm of ``(match - base) - (counterfactual - base)``.
    """

    def _score_vector(column):
        # Values are taken in dict order, so the score dicts are paired
        # position by position.
        return np.array(list(ast.literal_eval(row[column]).values()))

    match_vec = _score_vector(f'text_match_{k_th_match}_scores')
    base_vec = _score_vector('text_base_scores')
    estimated_effect = match_vec - base_vec
    reference_effect = _score_vector('text_counterfactual_scores') - base_vec
    return np.linalg.norm(estimated_effect - reference_effect, ord=2)


def _is_missing_scores(cell):
    """Return True when a scores cell holds no usable score dict."""
    if cell is None:
        return True
    # Compare by value, not identity: a NaN read from a CSV is not guaranteed
    # to be the very same object as ``np.nan``.
    if isinstance(cell, (float, np.floating)) and np.isnan(cell):
        return True
    # Stringified dict in which every score is nan (ast.literal_eval cannot
    # parse the bare name ``nan``).
    return isinstance(cell, str) and cell == "{'1': nan, '2': nan, '3': nan, '4': nan, '5': nan}"


def _score_values(cell):
    """Parse a stringified score dict and return its values as an array."""
    return np.array(list(ast.literal_eval(cell).values()))


def get_icace_error(row, k, threshold=1):
    """Compare the averaged scores of the first ``k`` matches with the counterfactual.

    The score dicts of the used matches are averaged per key. ``explanation``
    is that average minus the base scores and ``icace`` is the counterfactual
    scores minus the base scores; the returned metrics measure how far apart
    these two vectors are.

    Args:
        row: Mapping (e.g. a DataFrame row) holding the stringified score
            dicts ``text_match_{j}_scores`` for ``j < k``,
            ``text_base_scores`` and ``text_counterfactual_scores``, and
            optionally ``similarity_match_{j + 1}`` values.
        k: Number of matches to consider.
        threshold: Only applied when the row has a ``similarity_match_1``
            entry; match ``j`` is then skipped if
            ``similarity_match_{j + 1}`` is below this value.

    Returns:
        ``None`` if a required scores cell is missing, the string
        ``'no examples above threshold'`` if every match was skipped, and
        otherwise a dict with the keys ``'l2'``, ``'cosine'``,
        ``'norm-diff'`` and ``'under_the_threshold'``. Despite its name the
        last entry is the number of matches that were kept and averaged.
    """
    # Every one of the first k matches must be present, even those that the
    # similarity threshold would skip later.
    if any(_is_missing_scores(row[f'text_match_{idx}_scores']) for idx in range(k)):
        return None

    has_similarity = 'similarity_match_1' in row

    # The keys of the first match define which scores are accumulated.
    sum_predictions = {label: 0 for label in ast.literal_eval(row['text_match_0_scores'])}
    used_matches = 0
    for j in range(k):
        if has_similarity and row[f'similarity_match_{j + 1}'] < threshold:
            continue
        used_matches += 1
        for label, value in ast.literal_eval(row[f'text_match_{j}_scores']).items():
            sum_predictions[label] += value

    if used_matches == 0:
        return 'no examples above threshold'

    if _is_missing_scores(row['text_base_scores']) or _is_missing_scores(row['text_counterfactual_scores']):
        return None

    average = {label: total / used_matches for label, total in sum_predictions.items()}

    # Vectors are paired position by position (dict order), so all score
    # dicts are expected to list their keys in the same order.
    base = _score_values(row['text_base_scores'])
    explanation = np.array(list(average.values())) - base
    icace = _score_values(row['text_counterfactual_scores']) - base

    return {'l2': np.linalg.norm(explanation - icace, ord=2),
            'cosine': cosine(explanation, icace),
            'norm-diff': np.abs(np.linalg.norm(explanation, ord=2) - np.linalg.norm(icace, ord=2)),
            'under_the_threshold': used_matches}


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Explainer variants of the ablation study. Each name is used as the stem of
# a results CSV; commented-out names are excluded from the run.
ablations_explainers = [
    'Causal_S-Transformer',
    # 'Causal_Bert',
    'Causal_Roberta',
    'no_tcf',
    'no_cfc',
    # 'zero_cfc_new',
    'no_pax',
    'no_nax',
    'no_counterfactuals',
    'no_samples',
    'rule-based filtering',
    'v2_Causal S-Transformer',
]

# Baseline explainers; all of them are currently switched off.
baselines_explainers = [
    # 'CF-Gen',
    # 'gpt-3.5-turbo_few_shot',
    # 'gpt-3.5-turbo_zero_shot',
    # 'Approx',
    # 'Random_Matching',
    # 'roberta',
    # 'sentence_transformer',
    # 'ST_FT',
    # 'Causal_S-Transformer_GT',
    # 'Matching By Propensity',
    # 'sentence_transformer_causal',
    # 'sentiment_roberta',
    # 'sentence_transformer_causal_with_gt',
]

# The list of explainers that is actually evaluated (same object as the
# ablation list, not a copy).
models = ablations_explainers

# False: score the average of the first k matches (get_icace_error).
# True: score a single match by index (icace_per_row_k_th).
k_th = False

# Evaluation grid.
sizes = ['Llama-2-7b-chat-hf', 'Llama-2-13b-chat-hf']
sets = ['base', 'GT', 'Miss_CF']
k = [1]
p_filtering = [0]

# Error metrics that are aggregated; they are keys of the dict returned by
# get_icace_error.
metrics = ['l2', 'cosine', 'norm-diff']

# Root directory of the input CSVs; the summary tables are written here too.
dir_path = '/home/XXXXXX/MatchingBasedCausalExplanation/pairs/all/ablations'

# Result accumulators: one independent, initially empty table per metric.
df_results = pd.DataFrame()
df_results_per_metric = {name: pd.DataFrame() for name in metrics}
df_total_per_metric = {name: pd.DataFrame() for name in metrics}
# Evaluate every results CSV of the grid (set x model size x explainer) and
# store the mean of each error metric in one table per metric, using the row
# label '<explainer>-k=<k>' and the column label '<size>-<set>'.

# Columns a row must have before it can be scored.
_REQUIRED_SCORE_COLUMNS = ['text_match_0_scores', 'text_base_scores', 'text_counterfactual_scores']
# Marker returned by get_icace_error when the threshold skipped every match.
_NO_MATCH_MARKER = 'no examples above threshold'

for s in sets:
    for size in sizes:
        for model in models:
            for p in p_filtering:
                csv_path = f'{dir_path}/{size}/{s}/{model}.csv'
                complete_rows = pd.read_csv(csv_path).dropna(subset=_REQUIRED_SCORE_COLUMNS)

                # A fixed threshold of 0 is used instead of a quantile of the
                # similarities, so `p` does not influence the result for now.
                threshold = 0

                if complete_rows.empty:
                    # Nothing to score; say so instead of failing later on.
                    print(f'skipping {csv_path}: no rows with complete scores')
                    continue

                for i in k:
                    # Start every k from the same rows: the filtering below
                    # must not leak from one value of k into the next.
                    df = complete_rows.copy()

                    if k_th:
                        # NOTE(review): the per-row error of the k-th match is
                        # computed here but never written to the summary
                        # tables -- confirm whether it should be aggregated.
                        df['icace_error'] = df.apply(lambda row: icace_per_row_k_th(row, i), axis=1)
                        continue

                    df['icace_error'] = df.apply(
                        lambda row: get_icace_error(row, i, threshold=threshold), axis=1)
                    # get_icace_error returns None for rows it cannot score.
                    df = df.dropna(subset=['icace_error'])

                    # Share of the scorable rows for which the threshold
                    # skipped every match (0.0 when no row is left).
                    no_match = df['icace_error'] == _NO_MATCH_MARKER
                    under_the_threshold = no_match.sum() / len(df) if len(df) else 0.0

                    # Explicit copy so the metric columns are added to an
                    # independent frame rather than to a slice of `df`.
                    df = df[~no_match].copy()
                    for metric in metrics:
                        df[f'icace_error_{metric}'] = df['icace_error'].apply(lambda x: x[metric])

                    for metric in metrics:
                        df_total_per_metric[metric].loc[
                            f'{model}-k={i}', f'{size}-{s}'] = round(
                            df[f'icace_error_{metric}'].mean(), 4)

# One summary CSV per metric, written next to the input data.
for metric in metrics:
    df_total_per_metric[metric].to_csv(
        f'{dir_path}/ablations_lastt_llamas_{metric}_k=1_{max(k)}.csv')

print('done')
