import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import math
import numpy as np
import seaborn as sns

# Apply seaborn's default visual theme globally so the matplotlib figures
# drawn below pick up seaborn styling (runs as an import-time side effect).
sns.set()


def j_b_prime_lower_bound(eps, k, delta, barc, jb):
    """Return the lower bound J_b' derived from a baseline loss ``jb``.

    Evaluates

        J_b' = exp(-k*eps) * jb
               - (1 - exp(-k*eps)) * delta * barc / (exp(eps) - 1)

    and clamps negative (or NaN) results to 0.

    Args:
        eps: the epsilon value (the x-axis of the plots in this file).
        k: integer parameter of the bound; ``k == 0`` returns ``jb`` unchanged.
        delta: the delta constant of the bound.
        barc: the ``\\bar{c}`` constant of the bound.
        jb: the baseline loss J_b (in this file, the 'adv0' mean loss).

    Returns:
        The bound, or 0 when it is not positive. For ``k == 0``, ``jb`` is
        returned as-is (not clamped).
    """
    if k == 0:
        return jb
    if eps == 0:
        # Limit eps -> 0: exp(-k*eps) -> 1 and
        # (1 - exp(-k*eps)) / (exp(eps) - 1) -> k, so the bound becomes
        # jb - k*delta*barc. Evaluating the formula directly would divide by 0.
        jb_prime = jb - k * delta * barc
    else:
        # -(1 - exp(-k*eps)) == expm1(-k*eps); expm1 avoids cancellation in
        # both factors when eps is small.
        jb_prime = (math.exp(-k * eps) * jb
                    + math.expm1(-k * eps) * delta * barc / math.expm1(eps))
    # `> 0` (rather than max()) also maps NaN to 0, as the original did.
    return jb_prime if jb_prime > 0 else 0


def get_dp_result(folder_prefix):
    """Load per-model experiment results from the sub-directories of ``folder_prefix``.

    Every sub-directory ``<folder_prefix>/<name>/`` is read as
    ``<name>/all_exp.csv``, which must contain the columns ``eps``,
    ``adv_loss_mean`` and ``adv_loss_var``.

    Args:
        folder_prefix: directory to scan; a trailing path separator is optional.

    Returns:
        ``(dict_eps, dict_mean, dict_var, saved_model_names)``. Each dict maps a
        sub-directory name to that column's values in CSV row order.
        ``saved_model_names`` lists the sub-directories that were read, sorted
        by name.
    """
    saved_model_names = sorted(
        name for name in os.listdir(folder_prefix)
        # Skip macOS metadata and CSV summaries kept next to the model folders;
        # only directories can contain an all_exp.csv.
        if name != '.DS_Store' and 'csv' not in name
        and os.path.isdir(os.path.join(folder_prefix, name))
    )

    dict_eps = {}
    dict_mean = {}
    dict_var = {}
    for saved_model_name in saved_model_names:
        filename = os.path.join(folder_prefix, saved_model_name, 'all_exp.csv')
        print(saved_model_name)
        df = pd.read_csv(filename)

        dict_eps[saved_model_name] = df['eps'].tolist()
        dict_mean[saved_model_name] = df['adv_loss_mean'].tolist()
        dict_var[saved_model_name] = df['adv_loss_var'].tolist()

    return dict_eps, dict_mean, dict_var, saved_model_names


def _epoch_values(result_dicts, key, used_epoch):
    """Return ``d[key][used_epoch - 1]`` for every dict ``d`` in ``result_dicts``.

    ``result_dicts`` holds one dict per folder prefix (as built from
    ``get_dp_result``), so the result has one value per prefix, in prefix
    order. ``used_epoch`` is 1-based.
    """
    return [result[key][used_epoch - 1] for result in result_dicts]


def _plot_mean_with_band(xs, mean, var, label):
    """Plot ``mean`` against ``xs`` and shade from ``mean - var`` to ``mean + var``.

    The shaded band reuses the colour of the line just drawn.
    """
    line, = plt.plot(xs, mean, label=label, marker='o', markersize=3)
    lower = [m - v for m, v in zip(mean, var)]
    upper = [m + v for m, v in zip(mean, var)]
    plt.fill_between(xs, lower, upper, color=line.get_color(), alpha=0.2)


if __name__ == '__main__':
    # insDP selects the 'insdp' setting (True) or the user-level 'userdp'
    # setting (False); isFlip selects label flipping (True) or backdoor (False).
    insDP = True
    isFlip = True
    if insDP:
        ks = [0, 4]
        delta = 1e-5
        barc = 2
        end_epoch = 3
    else:
        ks = [0, 4]
        delta = 0.0029
        barc = 0.2
        end_epoch = 4

    # Output root for each setting (placeholder paths).
    if insDP:
        folder_name = 'folder/path' if isFlip else 'folder/path'
    else:
        folder_name = 'folder/path' if isFlip else 'folder/path'
    os.makedirs(folder_name, exist_ok=True)

    # One results folder per noise level, i.e. per point on the eps x-axis
    # (placeholder paths).
    if insDP:
        if isFlip:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
        else:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
    else:
        if isFlip:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
        else:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]

    label_name = 'LF' if isFlip else 'BKD'

    # The loaded results do not depend on k, so read every prefix once.
    all_eps_dicts = []
    all_advloss_mean_dicts = []
    all_advloss_var_dicts = []
    for pre in folder_prefixs:
        dict_eps, dict_mean, dict_var, _ = get_dp_result(pre)
        all_eps_dicts.append(dict_eps)
        all_advloss_mean_dicts.append(dict_mean)
        all_advloss_var_dicts.append(dict_var)

    mode = 'insdp' if insDP else 'userdp'
    attack = 'flip' if isFlip else 'bkd'
    panel = '(d)' if isFlip else ('(c)' if insDP else '(b)')
    attack_title = 'Label Flipping' if isFlip else 'Backdoor'

    # ks[0] gets no figure of its own; the 'adv0' results serve as the
    # baseline J_b of the lower bound for every other k.
    for k in ks[1:]:
        k_folder = folder_name + 'k' + str(k)
        os.makedirs(k_folder, exist_ok=True)

        adv_key = 'adv' + str(k)
        # (results key, legend label); raw strings keep the LaTeX backslashes.
        # Each band uses the variance of its own run (the γ=1/γ=100 bands
        # previously reused the γ=50 variance).
        if insDP:
            series = [(adv_key, label_name)]
        else:
            series = [
                (adv_key + '_s1', label_name + r' $\gamma=1$'),
                (adv_key, label_name + r' $\gamma=50$'),
                (adv_key + '_s100', label_name + r' $\gamma=100$'),
            ]
        title = f"{panel} CIFAR-10 {attack_title} ($k$={k})"

        for used_epoch in range(1, end_epoch + 1):
            plt.figure()
            used_eps = _epoch_values(all_eps_dicts, 'adv0', used_epoch)
            used_jbs = _epoch_values(all_advloss_mean_dicts, 'adv0', used_epoch)
            lowerbounds = [
                j_b_prime_lower_bound(eps, k, delta, barc, jb)
                for eps, jb in zip(used_eps, used_jbs)
            ]
            plt.plot(used_eps, lowerbounds, color='black', label="lower bound",
                     ls='--')

            # Empirical adversarial loss, mean ± var band.
            for key, label in series:
                _plot_mean_with_band(
                    used_eps,
                    _epoch_values(all_advloss_mean_dicts, key, used_epoch),
                    _epoch_values(all_advloss_var_dicts, key, used_epoch),
                    label,
                )

            plt.xlabel(r"$\epsilon$", fontsize=22)
            plt.legend(prop={'size': 18})
            plt.title(title, fontsize=22)
            plt.tight_layout()
            outfile = (k_folder + '/c_' + mode + '_' + attack + '_epoch'
                       + str(used_epoch))
            plt.savefig(outfile + ".png", dpi=300)
            plt.savefig(outfile + ".pdf")
            plt.close()
