import pandas as pd

import os
import numpy as np

import matplotlib.pyplot as plt
from scipy.stats import norm
import math
import numpy as np
from scipy.stats.morestats import kstat

import seaborn as sns

# Apply seaborn's default plotting theme globally, so every matplotlib figure
# produced below picks up the seaborn look.
sns.set()


def j_b_prime_upper_bound(eps, k, delta, barc, jb):
    """Return exp(k*eps)*jb + delta*barc*(exp(k*eps) - 1)/(exp(eps) - 1).

    This is algebraically identical to the expression
    ``exp(k*eps)*jb + (1 - exp(k*eps))*delta*barc/(1 - exp(eps))``.
    Presumably it upper-bounds J(D') for a dataset D' at distance ``k`` from
    D under an (eps, delta) guarantee, with ``barc`` a bound on the loss --
    TODO confirm against the accompanying derivation.

    Args:
        eps: Privacy parameter epsilon (``eps == 0`` is supported).
        k: Distance parameter; ``k == 0`` returns ``jb`` unchanged.
        delta: Privacy parameter delta.
        barc: Constant multiplying ``delta`` in the additive term.
        jb: Baseline value J(D) being bounded.

    Returns:
        The upper bound as a float.
    """
    if eps == 0:
        # The ratio below is 0/0 at eps == 0; its value there (the limit, and
        # the geometric sum sum_{i<k} e^{i*eps} for integer k) is exactly k.
        growth = k
    else:
        # expm1 avoids catastrophic cancellation in exp(x) - 1 for small eps.
        growth = math.expm1(k * eps) / math.expm1(eps)
    return math.exp(k * eps) * jb + growth * delta * barc


def j_b_prime_lower_bound(eps, k, delta, barc, jb):
    """Return max(exp(-k*eps)*jb - delta*barc*(1 - exp(-k*eps))/(exp(eps) - 1), 0).

    Presumably a lower bound on J(D') for a dataset D' at distance ``k`` from
    D under an (eps, delta) guarantee -- TODO confirm against the derivation.
    The script calls it with the k = 0 run's per-epoch eps and mean
    adversarial loss as ``jb``.

    Args:
        eps: Privacy parameter epsilon (``eps == 0`` is supported).
        k: Distance parameter; ``k == 0`` returns ``jb`` unchanged (unclamped).
        delta: Privacy parameter delta.
        barc: Constant multiplying ``delta`` in the subtracted term.
        jb: Baseline value J(D) being bounded.

    Returns:
        ``jb`` when ``k == 0``; otherwise the bound clamped below at 0
        (the clamp returns the int ``0``).
    """
    if k == 0:
        return jb
    if eps == 0:
        # (1 - e^{-k*eps}) / (e^{eps} - 1) is 0/0 at eps == 0; its limit is k.
        shrink = k
    else:
        # expm1 avoids catastrophic cancellation for small eps;
        # -expm1(-k*eps) == 1 - exp(-k*eps).
        shrink = -math.expm1(-k * eps) / math.expm1(eps)
    jb_prime = math.exp(-k * eps) * jb - shrink * delta * barc

    # The loss cannot be negative, so a negative bound is reported as 0.
    return jb_prime if jb_prime > 0 else 0


def get_dp_result(folder_prefix, saved_model_name):
    """Load the per-epoch results CSV of one saved model.

    Reads ``folder_prefix + saved_model_name + '/all_exp.csv'``. The path is
    built by plain string concatenation, so ``folder_prefix`` must end with a
    path separator. The path is printed before reading.

    Args:
        folder_prefix: Directory prefix holding the model folders.
        saved_model_name: Model folder name, e.g. ``'adv0'``.

    Returns:
        Tuple ``(epochs, epss, mean, var)`` of lists holding the CSV columns
        ``epoch``, ``eps``, ``adv_loss_mean`` and ``adv_loss_var``, in file
        row order.
    """
    filename = folder_prefix + saved_model_name + '/all_exp.csv'
    print(filename)
    df = pd.read_csv(filename)

    # Whole-column extraction: one conversion per column instead of a .loc
    # lookup per row, and independent of the frame's index labels.
    epochs = df['epoch'].tolist()
    epss = df['eps'].tolist()
    mean = df['adv_loss_mean'].tolist()
    var = df['adv_loss_var'].tolist()

    return epochs, epss, mean, var


if __name__ == '__main__':
    # Experiment switches. Going by the output names below: insDP selects the
    # "insdp" runs (otherwise "userdp"), isFlip the label-flipping "LF"
    # results (otherwise backdoor "BKD"), and isuser_alpha the "_alpha" sweep
    # of the user-level runs (otherwise "_gamma").
    insDP = True
    isFlip = False
    isuser_alpha = True

    # k values to load (model folders 'adv<k>'), noise level (used only in the
    # output directory name) and number of epochs to plot.
    if insDP:
        ks = [0, 2, 4, 6, 8, 10]
        noise = [10]
        noise_idx = 0
        end_epoch = 10
    else:
        if isFlip:
            ks = [0, 2, 4, 6, 8, 10]
        else:
            ks = [0, 1, 2, 4, 6, 8, 10]
        noise = [2.5]
        noise_idx = 0
        end_epoch = 5

    # Output root per configuration (placeholder paths).
    if insDP:
        if isFlip:
            folder_name = 'folder/path/'
        else:
            folder_name = 'folder/path/'
    else:
        if isFlip:
            folder_name = 'folder/path/'
        else:
            folder_name = 'folder/path/'

    # makedirs also creates missing parents (os.mkdir fails on a nested root
    # like 'folder/path/'); exist_ok keeps reruns safe.
    out_dir = folder_name + 'noise' + str(noise[noise_idx])
    os.makedirs(out_dir, exist_ok=True)

    # Constants passed to j_b_prime_lower_bound.
    if insDP:
        delta = 1.0e-05
    else:
        delta = 0.0029
    barc = 0.5

    # Model root folders; each holds one 'adv<k>' sub-folder per k.
    if insDP:
        if isFlip:
            folder_prefixs = [
                'path/to/models/',
            ]
        else:
            folder_prefixs = [
                'path/to/models/',
            ]
    else:
        if isFlip:
            if isuser_alpha:
                folder_prefixs = [
                    'path/to/models/',
                    'path/to/models/',
                ]
            else:
                folder_prefixs = [
                    'path/to/models/',
                    'path/to/models/',
                ]
        else:
            if isuser_alpha:
                folder_prefixs = [
                    'path/to/models/',
                    'path/to/models/',
                ]
            else:
                folder_prefixs = [
                    'path/to/models/',
                    'path/to/models/',
                ]

    # Legend label per entry of folder_prefixs (indexed by position). Raw
    # strings keep the LaTeX backslashes without invalid-escape warnings.
    if insDP:
        if isFlip:
            labels_name = ['LF']
        else:
            labels_name = ['BKD']
    else:
        if isFlip:
            if isuser_alpha:
                labels_name = [
                    r'LF $\alpha$=10%', r'LF $\alpha$=40%', r'LF $\alpha$=60%',
                    r'LF $\alpha$=100%'
                ]
            else:
                labels_name = [
                    r'LF $\gamma=1$', r'LF $\gamma=50$', r'LF $\gamma=100$'
                ]
        else:
            if isuser_alpha:
                labels_name = [
                    r'BKD $\alpha$=10%', r'BKD $\alpha$=40%',
                    r'BKD $\alpha$=60%', r'BKD $\alpha$=100%'
                ]
            else:
                labels_name = [
                    r'BKD $\gamma=1$',
                    r'BKD $\gamma=50$',
                    r'BKD $\gamma=100$',
                    r'DBA $\gamma=1$',
                    r'DBA $\gamma=50$',
                    r'DBA $\gamma=100$',
                ]

    # The k = 0 run of the first model folder supplies the per-epoch eps and
    # mean loss fed to the lower bound; its eps is also shown in titles and
    # file names so the label matches the plotted bound.
    epochs, epss_0, mean_0, _ = get_dp_result(folder_prefixs[0], 'adv0')

    # all_*_diff_lines[i][k] -> per-epoch list for folder_prefixs[i], run adv<k>.
    all_mean_diff_lines = {}
    all_var_diff_lines = {}
    for i, prefix in enumerate(folder_prefixs):
        all_mean_ks = {}
        all_var_ks = {}
        for k in ks:
            epochs, _, mean, var = get_dp_result(prefix, 'adv' + str(k))
            all_mean_ks[k] = mean
            all_var_ks[k] = var
        all_mean_diff_lines[i] = all_mean_ks
        all_var_diff_lines[i] = all_var_ks

    k_grid = list(range(max(ks) + 1))
    dp_tag = 'insdp' if insDP else 'userdp'
    attack_tag = 'flip' if isFlip else 'bkd'
    attack_name = 'Label Flipping' if isFlip else 'Backdoor'
    if insDP:
        panel = 'b' if isFlip else 'a'
    elif isFlip:
        panel = 'c' if isuser_alpha else 'g'
    else:
        panel = 'a' if isuser_alpha else 'e'

    # One figure per epoch: the analytical lower bound plus each model's
    # measured mean loss as a function of k.
    for ep in range(1, end_epoch + 1):
        epoch_idx = ep - 1
        eps_str = str(epss_0[epoch_idx])
        fig = plt.figure()

        lowerbounds = [
            j_b_prime_lower_bound(epss_0[epoch_idx], k, delta, barc,
                                  mean_0[epoch_idx]) for k in k_grid
        ]
        plt.plot(k_grid, lowerbounds, color='black', label="lower bound",
                 ls='--')

        row = epochs[epoch_idx] - 1  # per-run lists are indexed by epoch - 1
        for i in range(len(folder_prefixs)):
            jbprimes_mean = [all_mean_diff_lines[i][k][row] for k in ks]
            # No band for the k = 0 run: its spread is forced to zero.
            jbprimes_var = [
                0 if k == 0 else all_var_diff_lines[i][k][row] for k in ks
            ]
            line, = plt.plot(ks,
                             jbprimes_mean,
                             label=labels_name[i],
                             marker='o',
                             markersize=3)
            # NOTE(review): the band half-width is the variance column, not a
            # standard deviation -- confirm this is intended.
            lower = [m - v for m, v in zip(jbprimes_mean, jbprimes_var)]
            upper = [m + v for m, v in zip(jbprimes_mean, jbprimes_var)]
            plt.fill_between(ks, lower, upper, color=line.get_color(),
                             alpha=0.2)

        plt.xlabel("$k$", fontsize=22)
        if not isFlip:
            plt.ylabel(r"$J(D')$", fontsize=22)
        plt.legend(prop={'size': 16})

        outfile = (out_dir + '/m_' + attack_tag + '_' + dp_tag + '_k_eps' +
                   eps_str)
        if not insDP:
            outfile = outfile + ('_alpha' if isuser_alpha else '_gamma')
        title = ('(' + panel + ') MNIST ' + attack_name + r' ($\epsilon$=' +
                 eps_str + ')')

        print("outfile", outfile)
        plt.title(title, fontsize=22)

        plt.tight_layout()
        plt.savefig(outfile + ".png", dpi=300)
        plt.savefig(outfile + ".pdf")
        plt.close(fig)
