from copy import deepcopy
from typing import Dict
import pathlib

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from pylab import rcParams
import seaborn as sns
import pandas as pd


# Global plotting settings, applied once at import time. Most plotting
# functions below re-apply their own (mostly identical) settings before drawing.
rcParams['figure.figsize'] = 12, 4
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.8)
# Important for paper compatibility: serif fonts rendered through LaTeX and
# restricted to the standard PostScript/PDF core fonts.
# NOTE(review): kept after sns.set_style() on purpose — seaborn styles appear to
# set 'font.family' themselves; confirm before reordering.
matplotlib.rcParams['font.family'] = "serif"
matplotlib.rcParams['ps.useafm'] = True
matplotlib.rcParams['pdf.use14corefonts'] = True
# text.usetex requires a working LaTeX installation on the machine producing plots.
matplotlib.rcParams['text.usetex'] = True
# NOTE(review): this module-level value is not read anywhere in the visible code.
legend_bool = "brief"


def performance_plot_n_analysis(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Plot CIFAR10 sensitivity results in dB against the number of tasks N.

    Keeps the ``average_stability0`` rows, collapses the per-perturbation analysis
    methods into three families ("perturbed gradient", "gradient",
    "loss perturbation") and, for the N=2/6/10 experiments, keeps only the task
    with uid N-1. Results are converted to dB (``20 * log10``) and drawn as one
    bar-plot column per method family, coloured by whether the experiment name
    contains ``BiasReset``.

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``discretizationanalysisplot.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    # Work on an explicit copy so column updates are plain assignments; chained
    # indexing (df[col][mask] = ...) may only modify a temporary copy.
    df = df_input.loc[df_input["result name"] == "average_stability0"].copy()

    # Collapse the per-perturbation variants into one method family each.
    method_families = {
        "single_task_blur_stability_gradient_sensitivity_analysis": "task_perturbed_gradient",
        "single_task_contrast_stability_gradient_sensitivity_analysis": "task_perturbed_gradient",
        "single_task_shift_stability_gradient_sensitivity_analysis": "task_perturbed_gradient",
        "single_task_loss_blur_sensitivity_analysis": "task_loss_perturbation_sensitivity_analysis",
        "loss_input_sensitivity_analysis": "task_loss_perturbation_sensitivity_analysis",
        "single_task_loss_contrast_sensitivity_analysis": "task_loss_perturbation_sensitivity_analysis",
        "single_task_loss_shift_sensitivity_analysis": "task_loss_perturbation_sensitivity_analysis",
    }
    df["analysis method"] = df["analysis method"].replace(method_families)
    kept_methods = ["task_perturbed_gradient",
                    "single_task_gradient_sensitivity_analysis",
                    "task_loss_perturbation_sensitivity_analysis"]
    df = df.loc[df["analysis method"].isin(kept_methods)].copy()
    # Short labels: "perturbed gradient", "gradient", "loss perturbation".
    # strip() removes the blanks left around the words by the splits, so the
    # labels compare equal to the plain strings (as in the sibling functions).
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]

    print(f"Analysis methods: {df['analysis method']}")

    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df["task dataset"] = [str(value).strip("[]").replace(", ", "-")
                          for value in df["task dataset"].values]
    df = df.loc[df["task dataset"] == "'CIFAR10'"]
    # NOTE: normalising by the 'Base CIFAR' runs is currently disabled. Per-method
    # baselines, if needed again, are:
    #   df.loc[df["exp set name"].str.startswith('Base CIFAR')].groupby("analysis method")["result"].mean()

    # The last character of the experiment name encodes N (...2, ...6, ...10);
    # keep only the task whose uid is N-1.
    last_task_uid = {'2': '1', '6': '5', '0': '9'}
    df = df.loc[df["task uid"] == df["exp set name"].str[-1].map(last_task_uid)]

    print(f"FINAL DF: {df}")
    value_col = r'Analysis results [dB]'
    df_plot = pd.DataFrame({
        'experiment name': df['exp set name'],
        'seed': df['exp seed'],
        'Analysis method': df['analysis method'],
    })
    df_plot['bias reset'] = df_plot['experiment name'].str.contains('BiasReset')
    df_plot['N'] = df_plot['experiment name'].str[-1].map({'2': 'N=2', '6': 'N=6', '0': 'N=10'})
    # Every remaining row belongs to one of the three method families, all of
    # which are plotted in dB.
    df_plot[value_col] = 20 * np.log10(df['result'].astype(float))
    print(df_plot)

    pal = ["#005858", "#daa520"]
    g = sns.catplot(x="N", y=value_col, col="Analysis method", sharey=False,
                    data=df_plot, hue='bias reset', order=['N=2', 'N=6', 'N=10'],
                    kind="bar", ci="sd", aspect=1.6, height=3.6, palette=pal)
    for ax in g.axes.flat:
        # Facet titles read "Analysis method = <label>"; keep only the label.
        ax.set_title(str(ax.get_title()).split("=")[-1])
    g._legend.remove()
    axes = g.axes.flatten()
    axes[0].set(xlabel="")
    axes[1].set(xlabel="")
    axes[2].set(xlabel="")
    g.savefig(logdir/'discretizationanalysisplot.pdf')


def analysis_dataset_and_transf(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Plot task-1 sensitivity per dataset, transformation and analysis method.

    Keeps ``average_stability0`` rows of task uid "1", collapses the
    per-perturbation analysis methods into "gradient", "perturbed gradient" and
    "loss perturbation", and draws a 3x3 bar-plot grid (rows: method, columns:
    CIFAR10/SVHN/FashionMNIST, x: Blur/Contrast/Shift). An experiment counts as
    "bias reset" when its name starts with '0'.

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``datasets_and_transformations_analysis.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    def transform_label(name, raw_name):
        # Prefer an explicit transformation keyword (checked in this order),
        # otherwise fall back to the second-to-last word of the raw name.
        for keyword in ("Blur", "Shift", "Contrast"):
            if keyword in name:
                return keyword
        words = str(raw_name).strip().split(" ")
        return words[-2] if len(words) > 1 else words[-1]

    # Explicit copy: column updates below are plain assignments, not chained indexing.
    df = df_input.loc[df_input["result name"] == "average_stability0"].copy()
    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df["task dataset"] = [str(value).strip("[]").replace(", ", "-").replace("'", "")
                          for value in df["task dataset"].values]

    df = df.loc[df["task uid"] == "1"].copy()

    codes = {'single_task_blur_stability_gradient_sensitivity_analysis': "task_perturbed_gradient",
             'single_task_contrast_stability_gradient_sensitivity_analysis': "task_perturbed_gradient",
             'single_task_shift_stability_gradient_sensitivity_analysis': "task_perturbed_gradient",
             'single_task_loss_blur_sensitivity_analysis': "task_loss_perturbation_sensitivity_analysis",
             'single_task_loss_contrast_sensitivity_analysis': "task_loss_perturbation_sensitivity_analysis",
             'single_task_loss_shift_sensitivity_analysis': "task_loss_perturbation_sensitivity_analysis"}
    df["analysis method"] = df["analysis method"].replace(codes)

    kept_methods = ["task_perturbed_gradient",
                    "single_task_gradient_sensitivity_analysis",
                    "task_loss_perturbation_sensitivity_analysis"]
    df = df.loc[df["analysis method"].isin(kept_methods)].copy()
    # Short labels: "perturbed gradient", "gradient", "loss perturbation".
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]
    print(f"ANALYSIS METHODS: {df['analysis method']}")
    # TODO: normalise against the matching Base runs - need blur for blur,
    #  shift for shift, contrast for contrast.

    raw_names = df['exp set name']
    df_plot = pd.DataFrame({
        'experiment name': raw_names,
        'dataset': df['task dataset'],
        'transform': raw_names.map(lambda name: transform_label(str(name), name)),
        'seed': df['exp seed'],
        'bias reset': raw_names.map(lambda name: str(name)[:1] == '0'),
        'Analysis result': df['result'],
        'Analysis method': df['analysis method'],
    })
    print(f"Transforms: {df_plot['transform']}")
    print(df_plot)

    pal = ["#005858", "#daa520"]
    g = sns.catplot(x="transform", y=r'Analysis result', col="dataset",
                    row="Analysis method", sharey=False, order=["Blur", "Contrast", "Shift"],
                    data=df_plot, hue='bias reset', hue_order=[False, True],
                    col_order=['CIFAR10', 'SVHN', 'FashionMNIST'],
                    row_order=["gradient", "perturbed gradient", "loss perturbation"],
                    kind="bar", ci="sd", aspect=1.3, palette=pal, height=3.6)
    for ax in g.axes.flat:
        # Facet titles read "<row> = <a> | <col> = <b>"; keep only the last value.
        ax.set_title(str(ax.get_title()).split("=")[-1])
    g._legend.remove()
    axes = g.axes.flatten()
    # Only the top row carries dataset titles; each row's first axis gets the y label.
    axes[0].set_title(r"CIFAR10")
    axes[0].set(xlabel="", ylabel="Gradient [dB]")
    axes[1].set_title(r"SVHN")
    axes[2].set_title("FashionMNIST")
    axes[3].set(xlabel="", ylabel="Gradient perturbation [dB]")
    axes[6].set(xlabel="", ylabel="Loss perturbation [dB]")
    for ax in axes[3:9]:
        ax.set_title("")
    axes[7].set(xlabel="")
    axes[8].set(xlabel="")
    g.savefig(logdir / 'datasets_and_transformations_analysis.pdf')

def analysis_fisher_n2(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Plot the task-1 Fisher-diagonal sensitivity per dataset and transformation.

    Keeps ``average_stability0`` rows of task uid "1" for the
    ``single_task_fisher_diag_sensitivity_analysis`` method (labelled
    "fisher gradient") and draws one bar-plot row with columns
    CIFAR10/SVHN/FashionMNIST and x = Blur/Contrast/Shift. An experiment counts
    as "bias reset" when its name starts with '0'.

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``datasets_and_transformations_fisher_analysis.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    def transform_label(name, raw_name):
        # Prefer an explicit transformation keyword (checked in this order),
        # otherwise fall back to the second-to-last word of the raw name.
        for keyword in ("Blur", "Shift", "Contrast"):
            if keyword in name:
                return keyword
        words = str(raw_name).strip().split(" ")
        return words[-2] if len(words) > 1 else words[-1]

    # Explicit copy: column updates below are plain assignments, not chained indexing.
    df = df_input.loc[df_input["result name"] == "average_stability0"].copy()
    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df["task dataset"] = [str(value).strip("[]").replace(", ", "-").replace("'", "")
                          for value in df["task dataset"].values]

    df = df.loc[df["task uid"] == "1"].copy()

    df["analysis method"] = df["analysis method"].replace(
        {"single_task_fisher_diag_sensitivity_analysis": "task_fisher_gradient"})

    # TODO: combine properly
    df = df.loc[df["analysis method"] == "task_fisher_gradient"].copy()
    # Label becomes "fisher gradient".
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]
    print(f"ANALYSIS METHODS: {df['analysis method']}")

    raw_names = df['exp set name']
    df_plot = pd.DataFrame({
        'experiment name': raw_names,
        'dataset': df['task dataset'],
        'transform': raw_names.map(lambda name: transform_label(str(name), name)),
        'seed': df['exp seed'],
        'bias reset': raw_names.map(lambda name: str(name)[:1] == '0'),
        'Analysis result': df['result'],
        'Analysis method': df['analysis method'],
    })
    print(f"Transforms: {df_plot['transform']}")
    print(df_plot)
    pal = ["#005858", "#daa520"]

    g = sns.catplot(x="transform", y=r'Analysis result', col="dataset",
                    row="Analysis method", sharey=False, order=["Blur", "Contrast", "Shift"],
                    data=df_plot, hue='bias reset', hue_order=[False, True],
                    col_order=['CIFAR10', 'SVHN', 'FashionMNIST'],
                    row_order=["fisher gradient"],
                    kind="bar", ci="sd", aspect=1.3, palette=pal, height=3.6)
    for ax in g.axes.flat:
        ax.set_title(str(ax.get_title()).split("=")[-1])
    g._legend.remove()
    axes = g.axes.flatten()
    axes[0].set_title(r"CIFAR10")
    axes[0].set(xlabel="", ylabel="Fisher gradient [dB]")
    axes[1].set_title(r"SVHN")
    axes[2].set_title("FashionMNIST")
    # NOTE(review): analysis_fisher_LR and analysis_fisher_bias_reset save to this
    # same file name, so whichever runs last overwrites the others.
    g.savefig(logdir / 'datasets_and_transformations_fisher_analysis.pdf')


def analysis_fisher_LR(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Compare the task-1 Fisher-diagonal sensitivity across learning-rate variants.

    Keeps ``average_stability0`` rows of task uid "1" for the
    ``single_task_fisher_diag_sensitivity_analysis`` method, normalises experiment
    names (underscores -> spaces, "Long"/"long" removed) and draws one column each
    for 'Decrease LR Blur N=2', 'Blur N=2' and 'Increase LR Blur N=2', with the
    dataset on the x axis. An experiment counts as "bias reset" when its
    normalised name starts with '0'.

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``datasets_and_transformations_fisher_analysis.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    def transform_label(name, raw_name):
        # Prefer an explicit transformation keyword (checked in this order),
        # otherwise fall back to the second-to-last word of the raw name.
        for keyword in ("Blur", "Shift", "Contrast"):
            if keyword in name:
                return keyword
        words = str(raw_name).strip().split(" ")
        return words[-2] if len(words) > 1 else words[-1]

    # Explicit copy: column updates below are plain assignments, not chained indexing.
    df = df_input.loc[df_input["result name"] == "average_stability0"].copy()
    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df["task dataset"] = [str(value).strip("[]").replace(", ", "-").replace("'", "")
                          for value in df["task dataset"].values]

    df = df.loc[df["task uid"] == "1"].copy()

    df["analysis method"] = df["analysis method"].replace(
        {"single_task_fisher_diag_sensitivity_analysis": "task_fisher_gradient"})
    df = df.loc[df["analysis method"] == "task_fisher_gradient"].copy()
    # Label becomes "fisher gradient".
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]
    print(f"ANALYSIS METHODS: {df['analysis method']}")

    raw_names = df['exp set name']
    # Display names: the "Long" training-length marker is dropped so the variants
    # match the column order used below.
    names = raw_names.map(lambda value: str(value).replace("_", " ").replace("Long", "")
                          .replace("long", "").strip())
    df_plot = pd.DataFrame({
        'experiment name': names,
        'dataset': df['task dataset'],
        # Keywords are matched on the display name, the fallback uses the raw name.
        'transform': pd.Series([transform_label(name, raw) for name, raw in zip(names, raw_names)],
                               index=df.index, dtype=object),
        'seed': df['exp seed'],
        'bias reset': names.map(lambda name: name[:1] == '0'),
        'Analysis result': df['result'],
        'Analysis method': df['analysis method'],
    })
    print(df_plot)
    pal = ["#005858", "#daa520"]
    g = sns.catplot(x="dataset", y=r'Analysis result', col="experiment name",
                    row="Analysis method", sharey=True,
                    data=df_plot, hue='bias reset', hue_order=[False, True],
                    col_order=['Decrease LR Blur N=2', 'Blur N=2', 'Increase LR Blur N=2'],
                    row_order=["fisher gradient"],
                    kind="bar", ci="sd", aspect=1.3, palette=pal, height=3.6)
    for ax in g.axes.flat:
        ax.set_title(str(ax.get_title()).split("=")[-1])
    g._legend.remove()
    axes = g.axes.flatten()
    axes[0].set_title(r"Decrease LR")
    axes[0].set(xlabel="", ylabel="Fisher gradient [dB]")
    axes[1].set(xlabel="")
    axes[2].set(xlabel="")
    axes[1].set_title(r"Blur N=2")
    axes[2].set_title("Increase")
    # NOTE(review): analysis_fisher_n2 and analysis_fisher_bias_reset save to this
    # same file name, so whichever runs last overwrites the others.
    g.savefig(logdir / 'datasets_and_transformations_fisher_analysis.pdf')


def analysis_fisher_bias_reset(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Compare the CIFAR10 task-1 Fisher-diagonal sensitivity across bias-reset variants.

    Keeps ``average_stability0`` rows of task uid "1" on CIFAR10 for the
    ``single_task_fisher_diag_sensitivity_analysis`` method, renames the Blur N=2
    experiments to 'FF/FT/TF/TT Bias Reset', drops the plain and learning-rate
    variants, and draws one bar-plot column per bias-reset variant.

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``datasets_and_transformations_fisher_analysis.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    def transform_label(name, raw_name):
        # Prefer an explicit transformation keyword (checked in this order),
        # otherwise fall back to the second-to-last word of the raw name.
        for keyword in ("Blur", "Shift", "Contrast"):
            if keyword in name:
                return keyword
        words = str(raw_name).strip().split(" ")
        return words[-2] if len(words) > 1 else words[-1]

    # Explicit copy: column updates below are plain assignments, not chained indexing.
    df = df_input.loc[df_input["result name"] == "average_stability0"].copy()
    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df["task dataset"] = [str(value).strip("[]").replace(", ", "-").replace("'", "")
                          for value in df["task dataset"].values]

    df = df.loc[(df["task uid"] == "1") & (df["task dataset"] == "CIFAR10")].copy()

    df["analysis method"] = df["analysis method"].replace(
        {"single_task_fisher_diag_sensitivity_analysis": "task_fisher_gradient"})
    df = df.loc[df["analysis method"] == "task_fisher_gradient"].copy()
    # Label becomes "fisher gradient".
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]
    print(f"ANALYSIS METHODS: {df['analysis method']}")

    raw_names = df['exp set name']
    # Rename to bias-reset variant labels. The "Long" run is matched before the
    # "Long" marker is stripped; the others are matched on the cleaned name.
    names = raw_names.replace({"Long Blur N=2": "FF Bias Reset"})
    names = names.map(lambda value: str(value).replace("_", " ").replace("Long", "")
                      .replace("long", "").strip())
    names = names.replace({
        "0 02 No BR BR Blur N=2": "FT Bias Reset",
        "0 02 BiasReset Blur N=2": "TT Bias Reset",
        "0 02 BR No BR Blur N=2": "TF Bias Reset",
    })

    df_plot = pd.DataFrame({
        'experiment name': names,
        'dataset': df['task dataset'],
        # Keywords are matched on the renamed name, the fallback uses the raw name.
        'transform': pd.Series([transform_label(name, raw) for name, raw in zip(names, raw_names)],
                               index=df.index, dtype=object),
        'seed': df['exp seed'],
        'bias reset': names.map(lambda name: name[:1] == '0'),
        'Analysis result': df['result'],
        'Analysis method': df['analysis method'],
    })
    # Drop the plain and learning-rate variants (plotted by analysis_fisher_LR).
    df_plot = df_plot.loc[~df_plot["experiment name"].isin(
        ["Blur N=2", "Increase LR Blur N=2", "Decrease LR Blur N=2"])]
    print(df_plot)

    pal = ["#005858", "#daa520"]
    g = sns.catplot(x="dataset", y=r'Analysis result', col="experiment name",
                    row="Analysis method", sharey=True,
                    data=df_plot, hue='bias reset', hue_order=[False, True],
                    col_order=['FF Bias Reset', 'FT Bias Reset', 'TF Bias Reset', 'TT Bias Reset'],
                    kind="bar", ci="sd", aspect=1.3, palette=pal, height=3.6)
    for ax in g.axes.flat:
        ax.set_title(str(ax.get_title()).split("=")[-1])
    g._legend.remove()
    axes = g.axes.flatten()
    axes[0].set_title(r"FF Bias Reset")
    axes[1].set_title(r"FT Bias Reset")
    axes[2].set_title(r"TF Bias Reset")
    axes[3].set_title(r"TT Bias Reset")
    axes[0].set(xlabel="", ylabel="Fisher gradient [dB]")
    axes[1].set(xlabel="")
    axes[2].set(xlabel="")
    axes[3].set(xlabel="")
    # NOTE(review): analysis_fisher_n2 and analysis_fisher_LR save to this same
    # file name, so whichever runs last overwrites the others.
    g.savefig(logdir / 'datasets_and_transformations_fisher_analysis.pdf')


def performance_plot_pretrained_analysis(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """Plot gradient / input-gradient sensitivity of the pretraining experiments in dB.

    Keeps ``average_stability0`` rows of the gradient and input-gradient analyses
    for task uid "1" and derives from each experiment name:

    * target dataset: CIFAR10 if the name ends in '0', else FashionMNIST/SVHN;
    * pretraining dataset: the other one of the pair;
    * "start CIFAR10": the name does not end in '10';
    * "bias reset": the name starts with 'BR'.

    Results are converted to dB as ``20 * log10(1e-6 + x)`` and drawn as a 2x2
    grid (rows: analysis method, columns: SVHN vs FashionMNIST pairing).

    :param df_input: long-format results table; it is not modified.
    :param logdir: directory that receives ``pretraining_analysis_plot.pdf``.
    :param method: unused; kept so all plotting functions share one signature.
    """
    rcParams['figure.figsize'] = 3, 8
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=2.0)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    df = df_input.loc[df_input["result name"] == "average_stability0"]
    df = df.loc[df["analysis method"].isin(["single_task_gradient_sensitivity_analysis",
                                            "single_taskinput_gradient_sensitivity_analysis"])].copy()
    # Labels become "gradient" and "input gradient".
    df["analysis method"] = [str(value).replace("_", " ").split("task")[-1].split("sensitivity")[0].strip()
                             for value in df["analysis method"].values]
    print(f"ANALYSIS METHODS: {df['analysis method']}")
    # NOTE: normalising by the 'Base <dataset>' runs is currently disabled. Per
    # (dataset, method) baselines, if needed again, can be taken from
    #   df.loc[df["exp set name"].str.startswith('Base ...')].groupby("analysis method")["result"].mean()

    df["task uid"] = [str(value).strip("[]").replace(", ", "-")
                      for value in df["task uid"].values]
    df = df.loc[df["task uid"] == "1"]

    value_col = r'Analysis results $\Delta$ [dB]'
    names = df['exp set name']
    is_fashion = names.str.contains('Fashion')
    cifar_is_target = names.str[-1] == '0'
    # The non-CIFAR10 dataset of each pretraining pair.
    paired_dataset = np.where(is_fashion, 'FashionMNIST', 'SVHN')
    df_plot = pd.DataFrame({
        'experiment name': names,
        'pretraining dataset': pd.Series(np.where(cifar_is_target, paired_dataset, 'CIFAR10'),
                                         index=df.index),
        'target dataset': pd.Series(np.where(cifar_is_target, 'CIFAR10', paired_dataset),
                                    index=df.index),
        'start CIFAR10': ~names.str.endswith('10'),
        'FashionMNIST': is_fashion,
        'seed': df['exp seed'],
        'bias reset': names.str.startswith('BR'),
        # Both remaining methods are plotted in dB; 1e-6 guards against log10(0).
        value_col: 20 * np.log10(1e-6 + df['result'].astype(float)),
        'Analysis method': df['analysis method'],
    })
    print(df_plot)

    pal = ["#005858", "#daa520"]
    g = sns.catplot(x="start CIFAR10", y=value_col,
                    row="Analysis method", col='FashionMNIST', sharey=False,
                    data=df_plot, hue='bias reset', order=[False, True], col_order=[False, True],
                    row_order=["gradient", "input gradient"],
                    kind="bar", ci="sd", aspect=1.3, palette=pal, height=3.6)
    for ax in g.axes.flat:
        ax.set_title(str(ax.get_title()).split("=")[-1])
    axes = g.axes.flatten()
    axes[0].set_title(r"SVHN and CIFAR10")
    axes[0].set(ylabel="Gradient [dB]")
    axes[1].set_title(r"FashionMNIST and CIFAR10")
    axes[2].set_title("")
    axes[2].set(ylabel="Input gradient [dB]")
    axes[3].set_title("")
    g.savefig(logdir/'pretraining_analysis_plot.pdf')


def step_effect_plot(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """
    Bar-plot the result of the last task of each CIFAR10 -> CIFAR10 experiment.

    Rows of ``df_input`` are kept when all of the following hold:

    - their "result name" and "analysis method" match ``method``;
    - their task datasets are exactly 'CIFAR10' followed by 'CIFAR10';
    - their task uid equals ``N - 1``, where ``N`` is the integer after the
      last "=" in the row's "exp set name".

    Of those, only task uids starting with "0" are plotted (x: task uid,
    y: result, hue: exp set name). The figure is saved to
    ``logdir / (method["savename"] + ".pdf")``.

    :param df_input: results table; left unmodified (a deep copy is filtered).
    :param logdir: directory the PDF is saved into.
    :param method: plot configuration; reads the keys "result name",
        "analysis method", "legend", "title", "xlabel", "ylabel", "savename".
    :return: None
    :raises ValueError: if a selected "exp set name" does not end in "=<integer>".
    """
    df = deepcopy(df_input)
    df_plot = df.loc[df["result name"] == method["result name"]]
    # .copy() so the column rewrites below act on an independent frame rather
    # than on a filtered slice (avoids pandas' SettingWithCopyWarning).
    df_plot = df_plot.loc[df_plot["analysis method"] == method["analysis method"]].copy()
    # List-like cells -> flat strings, e.g. [0, 1] -> "0-1".
    df_plot["task uid"] = [str(value).strip("[]").replace(", ", "-")
                           for value in df_plot["task uid"].values]
    df_plot["task dataset"] = [str(value).strip("[]").replace(", ", "-")
                               for value in df_plot["task dataset"].values]

    # SELECTING DATASET
    df_plot = df_plot.loc[df_plot["task dataset"] == str("'CIFAR10'-'CIFAR10'")].copy()

    df_plot["exp set name"] = [str(value).replace("_", "-")
                               for value in df_plot["exp set name"].values]

    # Select the last task of each experiment. The expected uid has to be
    # derived per row: stringifying the whole Series parses its repr (which
    # ends in "Name: ..., dtype: ...") and int() on that always fails.
    # NOTE(review): uids joined with "-" above (e.g. "0-2") can never equal a
    # bare integer string, so only single-element uids can match; confirm intent.
    last_task_uid = df_plot["exp set name"].map(
        lambda name: str(int(name.split("=")[-1]) - 1))
    df_plot = df_plot.loc[df_plot["task uid"] == last_task_uid]

    print(f"checking: {df_plot['task uid'].str.startswith('0')}")
    df_plot = df_plot[df_plot["task uid"].str.startswith('0')]
    df_plot = df_plot.sort_values(by=['task uid'], ascending=True)

    print(f"df_plot: {df_plot}")

    # Draw explicitly onto the freshly created axes.
    fig, ax = plt.subplots()
    sns.barplot(x="task uid",
                y="result",
                hue="exp set name", data=df_plot,
                palette="muted",
                ax=ax)
    plt.legend(title=method["legend"])
    plt.title(method["title"])
    plt.xlabel(method["xlabel"])
    plt.ylabel(method["ylabel"])
    plt.savefig(logdir / str(method["savename"] + ".pdf"),
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)


def step_effect_plot_old(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """
    Bar-plot every CIFAR10 -> CIFAR10 task whose uid starts with "0".

    Rows of ``df_input`` are kept when their "result name" and
    "analysis method" match ``method`` and their task datasets are exactly
    'CIFAR10' followed by 'CIFAR10'. Of those, rows whose flattened task uid
    starts with "0" are plotted (x: task uid, y: result, hue: exp set name).
    The figure is saved to ``logdir / (method["savename"] + ".pdf")``.

    :param df_input: results table; left unmodified (a deep copy is filtered).
    :param logdir: directory the PDF is saved into.
    :param method: plot configuration; reads the keys "result name",
        "analysis method", "legend", "title", "xlabel", "ylabel", "savename".
    :return: None
    """
    df = deepcopy(df_input)
    df_plot = df.loc[df["result name"] == method["result name"]]
    # .copy() so the column rewrites below act on an independent frame rather
    # than on a filtered slice (avoids pandas' SettingWithCopyWarning).
    df_plot = df_plot.loc[df_plot["analysis method"] == method["analysis method"]].copy()
    # List-like cells -> flat strings, e.g. [0, 1] -> "0-1".
    df_plot["task uid"] = [str(value).strip("[]").replace(", ", "-")
                           for value in df_plot["task uid"].values]
    df_plot["task dataset"] = [str(value).strip("[]").replace(", ", "-")
                               for value in df_plot["task dataset"].values]

    # SELECTING DATASET
    df_plot = df_plot.loc[df_plot["task dataset"] == str("'CIFAR10'-'CIFAR10'")].copy()

    df_plot["exp set name"] = [str(value).replace("_", "-")
                               for value in df_plot["exp set name"].values]

    print(f"checking: {df_plot['task uid'].str.startswith('0')}")
    df_plot = df_plot[df_plot["task uid"].str.startswith('0')]
    df_plot = df_plot.sort_values(by=['task uid'], ascending=True)

    print(f"df_plot: {df_plot}")

    # Draw explicitly onto the freshly created axes.
    fig, ax = plt.subplots()
    sns.barplot(x="task uid",
                y="result",
                hue="exp set name", data=df_plot,
                palette="muted",
                ax=ax)
    plt.legend(title=method["legend"])
    plt.title(method["title"])
    plt.xlabel(method["xlabel"])
    plt.ylabel(method["ylabel"])
    plt.savefig(logdir / str(method["savename"] + ".pdf"),
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)


def step_size_plot(df_input: pd.DataFrame, logdir: pathlib.Path, method: Dict):
    """
    Swarm-plot results against the step size between two task uids.

    Rows of ``df_input`` are kept when their "result name" and
    "analysis method" match ``method``. Each row's "task uid" is indexed as
    ``uid[1] - uid[0]`` (the step size). Rows with a non-positive step are
    dropped. Results are plotted per step (x: step, y: result,
    hue: task dataset) and saved to ``logdir / (method["savename"] + ".pdf")``.

    :param df_input: results table; left unmodified (a deep copy is filtered).
    :param logdir: directory the PDF is saved into.
    :param method: plot configuration; reads the keys "result name",
        "analysis method", "legend", "title", "xlabel", "ylabel", "savename".
    :return: None
    """
    df = deepcopy(df_input)
    df_plot = df.loc[df["result name"] == method["result name"]]
    df_plot = df_plot.loc[df_plot["analysis method"] == method["analysis method"]]

    # Filter whole rows (not just values) on a positive step. Dropping values
    # inside a list comprehension would leave the new column shorter than the
    # frame and make the assignment raise.
    steps = df_plot["task uid"].map(lambda uid: uid[1] - uid[0])
    forward = (steps > 0).to_numpy(dtype=bool)
    df_plot = df_plot.loc[forward].copy()
    df_plot["task uid"] = steps.to_numpy()[forward]

    df_plot["task dataset"] = [str(value).strip("[]").replace(", ", "-")
                               for value in df_plot["task dataset"].values]

    # Sort on the numeric step *before* stringifying. String categories are
    # placed on the x-axis in order of appearance, so sorting strings would put
    # "10" before "2".
    df_plot = df_plot.sort_values(by=['task uid'], ascending=True)
    df_plot["task uid"] = df_plot["task uid"].map(str)

    print(f"df_plot: {df_plot}")

    # Draw explicitly onto the freshly created axes.
    fig, ax = plt.subplots()
    sns.swarmplot(x="task uid",
                  y="result",
                  hue="task dataset", data=df_plot,
                  palette="muted",
                  ax=ax)
    plt.legend(title=method["legend"])
    plt.title(method["title"])
    plt.xlabel(method["xlabel"])
    plt.ylabel(method["ylabel"])
    plt.savefig(logdir / str(method["savename"] + ".pdf"),
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)