import os

import numpy as np
import pandas as pd
import wandb
from tqdm import tqdm
import json
from utils.constants import CEBAB_DIRECTIONS, RESULTS_PATH, DATE, CONFOUNDERS_MAPS
import ast


def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` rather than a separate ``isdir`` check so that another
    process creating the same directory concurrently does not raise. A
    ``FileExistsError`` is still raised if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def calculate_cebab_score_stance_setup(model_to_explain_name, concepts, explainers, pairs,
                                       name, wandb_on, path_dir=None, return_log=False,
                                       save_outputs=False):
    """Score explainers by mean ICACE error on the stance-detection setup.

    The explained model's stored probability strings are parsed into two new columns
    of ``pairs`` (mutated in place): ``prediction_base`` from
    ``{model_to_explain_name}_probs`` and ``prediction_counterfactual`` from
    ``edit_{model_to_explain_name}_probs``.

    For each concept, every ordered pair of distinct values in
    ``CONFOUNDERS_MAPS[concept]`` is considered. The rows of ``pairs`` with
    ``{concept}_text == base`` and ``edit_goal == target`` are handed to each
    explainer's ``icace_error``. The mean of the returned ``icace_error`` column is
    recorded, rounded to 2 decimals.

    Args:
        model_to_explain_name: Name of the explained model. Used as the column prefix
            in ``pairs`` and passed to ``icace_error`` as ``model``.
        concepts: Concept names; each must be a key of ``CONFOUNDERS_MAPS``.
        explainers: Objects exposing ``get_explainer_description()`` and
            ``icace_error(...)`` (the latter returning a DataFrame with an
            ``icace_error`` column).
        pairs: DataFrame of base/counterfactual pairs; mutated in place.
        name: Run name used in result paths and logged keys.
        wandb_on: If True, log the per-explainer overall averages to wandb;
            otherwise print them.
        path_dir: If given, per-concept and overall CSVs are written under it.
        return_log: If True and ``wandb_on`` is True, return the logged dict.
        save_outputs: Forwarded to ``icace_error``.

    Returns:
        The logged dict when ``wandb_on`` and ``return_log`` are both True.
        Otherwise the overall DataFrame: one row per explainer, one column per
        concept, plus an ``average`` column.
    """
    print('\nThe pipeline has started now.')
    explainers_descriptions = [e.get_explainer_description() for e in explainers]
    # Indexed up front so an empty ``concepts`` still yields one (NaN) row per explainer.
    df_overall = pd.DataFrame(index=explainers_descriptions, dtype=float)
    pairs['prediction_base'] = pairs[f'{model_to_explain_name}_probs'].apply(ast.literal_eval)
    pairs['prediction_counterfactual'] = pairs[f'edit_{model_to_explain_name}_probs'].apply(ast.literal_eval)

    for concept in concepts:
        c_direction = list(CONFOUNDERS_MAPS[concept])
        direction_cols = [f'{b}->{t}' for b in c_direction for t in c_direction if b != t]
        # Float dtype so means/rounding operate on numbers, not object columns.
        df = pd.DataFrame(columns=direction_cols + ['average'], index=explainers_descriptions, dtype=float)

        for base_direction in c_direction:
            for target_direction in c_direction:
                if target_direction == base_direction:
                    continue
                pairs_prime = pairs[(pairs[f'{concept}_text'] == base_direction)
                                    & (pairs['edit_goal'] == target_direction)]

                for e in explainers:
                    pairs_results = e.icace_error(pairs=pairs_prime, model=model_to_explain_name, concept=concept,
                                                  base_direction=base_direction, target_direction=target_direction,
                                                  save_outputs=save_outputs)
                    icace_error = pairs_results['icace_error'].mean()
                    # An empty/all-missing result gives NaN (not None); leave the cell empty.
                    if pd.isna(icace_error):
                        continue
                    df.loc[e.get_explainer_description(), f'{base_direction}->{target_direction}'] = np.round(
                        icace_error, decimals=2)

        # Average only over the direction columns; the 'average' column is not an input.
        df['average'] = df[direction_cols].mean(axis=1, skipna=True)
        if path_dir is not None:
            results_path = get_results_path_per_concept(path_dir=path_dir, concept=concept, name=name)
            df = df.round(2)
            df.to_csv(f'{results_path}.csv')
        df_overall[concept] = df['average']

    # Aggregate once, after all concepts. Doing it inside the loop would fold the previous
    # 'average' column back into the mean and log partial results per concept.
    df_overall['average'] = round(df_overall.mean(axis=1), 2)
    if path_dir is not None:
        p = get_results_path_per_concept(path_dir=path_dir, concept='overall', name=name)
        df_overall.to_csv(f'{p}.csv')

    wandb_dict = {
        f'{name}_{model_to_explain_name}_{description}': df_overall.loc[description, 'average']
        for description in explainers_descriptions}

    if wandb_on:
        wandb.log(wandb_dict)
        if return_log:
            return wandb_dict
    else:
        print(wandb_dict)

    return df_overall


def calculate_cebab_score(model_to_explain, concepts, explainers, pairs,
                          name, wandb_on, path_dir=None, return_log=False, save_outputs=False):
    """Score explainers by mean ICACE error on the CEBaB setup.

    ``model_to_explain.get_predictions`` is run on the base and counterfactual
    descriptions. The results are stored in ``pairs`` (mutated in place) as
    ``prediction_base`` and ``prediction_counterfactual``.

    For each concept and each ordered pair of distinct ``CEBAB_DIRECTIONS``, the rows
    of ``pairs`` with matching ``intervention_type``, ``intervention_aspect_base`` and
    ``intervention_aspect_counterfactual`` are passed to each explainer's
    ``icace_error``. The mean of the returned ``icace_error`` column is recorded,
    rounded to 2 decimals.

    Args:
        model_to_explain: Model exposing ``get_predictions(list_of_texts)``.
        concepts: Values of ``intervention_type`` to evaluate.
        explainers: Objects exposing ``get_explainer_description()`` and
            ``icace_error(...)`` (the latter returning a DataFrame with an
            ``icace_error`` column).
        pairs: DataFrame of base/counterfactual pairs; mutated in place.
        name: Run name used in result paths and logged keys.
        wandb_on: If True, log the per-explainer overall averages to wandb.
        path_dir: If given, per-concept and overall CSVs are written under it.
        return_log: If True and ``wandb_on`` is True, return the logged dict.
        save_outputs: Forwarded to ``icace_error``.

    Returns:
        The logged dict when ``wandb_on`` and ``return_log`` are both True.
        Otherwise the overall DataFrame: one row per explainer, one column per
        concept, plus an ``average`` column.
    """
    print('\nThe pipeline has started now.')
    explainers_descriptions = [e.get_explainer_description() for e in explainers]
    df_overall = pd.DataFrame(columns=concepts, index=explainers_descriptions, dtype=float)
    pairs['prediction_base'] = model_to_explain.get_predictions(list(pairs['description_base'].values))
    pairs['prediction_counterfactual'] = model_to_explain.get_predictions(
        list(pairs['description_counterfactual'].values))

    # Every ordered (base, target) pair of distinct directions; identical for all concepts.
    direction_pairs = [(b, t) for b in CEBAB_DIRECTIONS for t in CEBAB_DIRECTIONS if b != t]
    direction_cols = [f'{b}->{t}' for b, t in direction_pairs]

    for concept in tqdm(concepts):
        print(f'\nTreatment: {concept}')
        # Float dtype so means/rounding operate on numbers, not object columns.
        df = pd.DataFrame(columns=direction_cols + ['average'], index=explainers_descriptions, dtype=float)

        for base_direction, target_direction in direction_pairs:
            pairs_prime = pairs[
                (pairs['intervention_type'] == concept)
                & (pairs['intervention_aspect_base'] == base_direction)
                & (pairs['intervention_aspect_counterfactual'] == target_direction)]

            for e in explainers:
                pairs_results = e.icace_error(pairs=pairs_prime, model=model_to_explain, concept=concept,
                                              base_direction=base_direction, target_direction=target_direction,
                                              save_outputs=save_outputs)
                icace_error = pairs_results['icace_error'].mean()
                # An empty/all-missing result gives NaN (not None); leave the cell empty.
                if pd.isna(icace_error):
                    continue
                df.loc[e.get_explainer_description(), f'{base_direction}->{target_direction}'] = np.round(
                    icace_error, decimals=2)

        # Average only over the direction columns; the 'average' column is not an input.
        df['average'] = df[direction_cols].mean(axis=1, skipna=True)
        if path_dir is not None:
            results_path = get_results_path_per_concept(path_dir=path_dir, concept=concept, name=name)
            df = df.round(2)
            df.to_csv(f'{results_path}.csv')
        df_overall[concept] = df['average']

    df_overall['average'] = round(df_overall.mean(axis=1), 2)
    if path_dir is not None:
        p = get_results_path_per_concept(path_dir=path_dir, concept='overall', name=name)
        df_overall.to_csv(f'{p}.csv')

    wandb_dict = {
        f'{name}_{description}': df_overall.loc[description, 'average']
        for description in explainers_descriptions}

    if wandb_on:
        wandb.log(wandb_dict)
        # NOTE(review): return_log is only honoured when wandb logging is enabled.
        if return_log:
            return wandb_dict

    return df_overall


def get_results_path_per_concept(path_dir, concept, name):
    """Return ``<path_dir>/cebab_score/<name>/<concept>`` (without extension).

    The ``cebab_score`` and ``<name>`` directories are created on the way (one at a
    time, outermost first). The final ``<concept>`` component is not created; callers
    append a file extension to it.
    """
    directory = path_dir
    for segment in ('cebab_score', name):
        directory = os.path.join(directory, segment)
        make_dir(directory)
    return os.path.join(directory, concept)


def get_results_path(config, save_config=True):
    """Build and create the results directory for a fine-tuning run.

    Layout::

        RESULTS_PATH/<setup_name>/<model_to_fine_tune>/<seed>/<treatment>
            [/<filter_level>]/lr_<learning_rate>_epochs_<epochs>

    ``<filter_level>`` is included only when ``config.setup_name == 'cebab'``.

    Args:
        config: Object with the attributes above, convertible via ``dict(config)``.
        save_config: If True, write ``dict(config)`` as JSON to ``config.txt`` in the
            results directory.

    Returns:
        The results directory path.
    """
    segments = [config.setup_name, config.model_to_fine_tune, config.seed, config.treatment]
    if config.setup_name == 'cebab':
        segments.append(config.filter_level)
    segments.append(f'lr_{config.learning_rate}_epochs_{config.epochs}')
    # str() so non-string config values (e.g. an int seed or filter level) are valid path parts.
    results_dir = os.path.join(RESULTS_PATH, *(str(s) for s in segments))
    # make_dir creates every missing parent, so one call on the leaf replaces the per-level
    # calls (which mistakenly re-created RESULTS_PATH each time).
    make_dir(results_dir)

    if save_config:
        with open(os.path.join(results_dir, 'config.txt'), "w") as text_file:
            json.dump(dict(config), text_file)

    return results_dir
