import pandas as pd

import os
import numpy as np


import matplotlib.pyplot as plt
from scipy.stats import norm
import math
import numpy as  np

import seaborn as sns

# Apply seaborn's default theme (sns.set) globally so that every matplotlib
# figure created below uses seaborn styling.
sns.set()



def j_b_prime_lower_bound(eps, k, delta, barc, jb):
    """Return the lower bound on the adversarial loss after ``k`` modifications.

    Computes ``exp(-k*eps) * jb - (1 - exp(-k*eps)) * delta * barc / (exp(eps) - 1)``
    and clips the result at 0. When ``k == 0`` the input ``jb`` is returned
    unchanged (no clipping).

    NOTE(review): presumably this is the DP-based bound on J(D') for an
    adversary altering ``k`` records under (eps, delta)-DP with loss bound
    ``barc``; confirm against the accompanying derivation.

    Args:
        eps: Privacy parameter epsilon. ``eps == 0`` with ``k != 0`` raises
            ZeroDivisionError because ``exp(0) - 1 == 0``.
        k: Number of modified records.
        delta: Privacy parameter delta.
        barc: Constant scaling the delta term.
        jb: Loss value the bound is derived from.

    Returns:
        ``jb`` if ``k == 0``; otherwise the bound, or ``0`` when it is not positive.
    """
    if k == 0:
        return jb

    decay = math.exp(-k * eps)
    penalty = (1 - decay) * delta * barc / (math.exp(eps) - 1)
    bound = decay * jb - penalty
    # Keep the explicit comparison: a NaN bound must map to 0, like the original.
    return bound if bound > 0 else 0


def get_dp_result(folder_prefix):
    """Load adversarial-loss results for every saved model under ``folder_prefix``.

    Each sub-directory of ``folder_prefix`` is treated as one saved model. Entries
    whose names start with ``'.'`` (e.g. ``.DS_Store``) or contain ``'csv'`` are
    skipped, as are plain files. Each model directory must contain an
    ``all_exp.csv`` with the columns ``eps``, ``adv_loss_mean`` and ``adv_loss_var``.

    Args:
        folder_prefix: Directory holding one sub-directory per saved model.
            A trailing path separator is optional.

    Returns:
        Tuple ``(dict_eps, dict_mean, dict_var, saved_model_names)``. The three
        dicts map each model name to the values of the ``eps``, ``adv_loss_mean``
        and ``adv_loss_var`` columns, in CSV row order. ``saved_model_names`` lists
        the model directories that were read, sorted by name.

    Raises:
        FileNotFoundError: if ``folder_prefix`` or a model's ``all_exp.csv`` is missing.
        KeyError: if one of the expected columns is absent from a CSV.
    """
    # Only directories can hold an all_exp.csv. Filtering out stray files avoids
    # a read_csv crash on e.g. 'notes.txt/all_exp.csv'. Sorting makes the result
    # independent of os.listdir's arbitrary order.
    saved_model_names = sorted(
        name
        for name in os.listdir(folder_prefix)
        if not name.startswith('.')
        and 'csv' not in name
        and os.path.isdir(os.path.join(folder_prefix, name))
    )

    dict_eps = {}
    dict_mean = {}
    dict_var = {}
    for saved_model_name in saved_model_names:
        # os.path.join works whether or not folder_prefix ends with a separator.
        filename = os.path.join(folder_prefix, saved_model_name, 'all_exp.csv')
        print(saved_model_name)
        df = pd.read_csv(filename)

        # Whole-column extraction keeps row order and does not depend on the
        # frame having a 0..n-1 index (unlike df.loc[i, col]).
        dict_eps[saved_model_name] = df['eps'].tolist()
        dict_mean[saved_model_name] = df['adv_loss_mean'].tolist()
        dict_var[saved_model_name] = df['adv_loss_var'].tolist()

    return dict_eps, dict_mean, dict_var, saved_model_names



def _plot_with_band(xs, mean, spread, label):
    """Plot ``mean`` against ``xs`` and shade ``mean - spread`` .. ``mean + spread``.

    The band reuses the colour matplotlib assigned to the line just drawn, so each
    curve and its band match.
    """
    plt.plot(xs, mean, label=label, marker='o', markersize=3)
    lower = [m - s for m, s in zip(mean, spread)]
    upper = [m + s for m, s in zip(mean, spread)]
    plt.fill_between(xs, lower, upper, color=plt.gca().lines[-1].get_color(), alpha=0.2)


if __name__ == '__main__':
    # insDP selects the 'insdp' vs 'userdp' experiment settings and isFlip selects
    # label flipping ('LF') vs backdoor ('BKD') results (see labels/outfiles below).
    insDP = True
    isFlip = False
    if insDP:
        ks = [0, 10]
        delta = 1e-5
        barc = 0.5
    else:
        ks = [0, 4]
        delta = 0.0029
        barc = 0.5

    # NOTE(review): the paths below are placeholders. folder_name is concatenated
    # directly with 'k<k>' further down, so it is expected to end with a separator.
    if insDP:
        if isFlip:
            folder_name = 'folder/path'
        else:
            folder_name = 'folder/path'
    else:
        if isFlip:
            folder_name = 'folder/path'
        else:
            folder_name = 'folder/path'
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(folder_name, exist_ok=True)

    # One results folder per noise level, i.e. one point on the eps x-axis.
    if insDP:
        if isFlip:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
        else:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
    else:
        if isFlip:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]
        else:
            folder_prefixs = [
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
                'path/to/models/',
            ]

    label_name = 'LF' if isFlip else 'BKD'
    end_epoch = 10 if insDP else 5

    # The loaded results do not depend on k, so read them once up front.
    all_eps_dicts = []
    all_advloss_mean_dicts = []
    all_advloss_var_dicts = []
    for pre in folder_prefixs:
        dict_eps, dict_mean, dict_var, _ = get_dp_result(pre)
        all_eps_dicts.append(dict_eps)
        all_advloss_mean_dicts.append(dict_mean)
        all_advloss_var_dicts.append(dict_var)

    # Figures are produced only for ks[1:]; ks[0] (== 0) is not plotted.
    for k in ks[1:]:
        k_dir = folder_name + 'k' + str(k)
        os.makedirs(k_dir, exist_ok=True)
        adv_key = 'adv' + str(k)

        for used_epoch in range(1, end_epoch + 1):
            # Assumes row i of each all_exp.csv holds epoch i + 1.
            row = used_epoch - 1
            plt.figure()

            # x-axis eps values and the 'adv0' losses the bound is computed from.
            used_eps = [d['adv0'][row] for d in all_eps_dicts]
            used_jbs = [d['adv0'][row] for d in all_advloss_mean_dicts]
            lowerbounds = [
                j_b_prime_lower_bound(eps, k, delta, barc, jb)
                for eps, jb in zip(used_eps, used_jbs)
            ]
            plt.plot(used_eps, lowerbounds, color='black', label='lower bound', ls='--')

            # NOTE(review): bands are mean +/- the 'adv_loss_var' column as-is. If
            # that column is a variance rather than a std, confirm this is intended.
            # Each band uses the variance of the same model variant as its curve.
            if not insDP:
                key = adv_key + '_s1'
                _plot_with_band(
                    used_eps,
                    [d[key][row] for d in all_advloss_mean_dicts],
                    [d[key][row] for d in all_advloss_var_dicts],
                    label_name + r' $\gamma=1$',
                )

            # The empirical adversarial loss.
            base_label = (label_name + r' $\gamma=50$') if not insDP else label_name
            _plot_with_band(
                used_eps,
                [d[adv_key][row] for d in all_advloss_mean_dicts],
                [d[adv_key][row] for d in all_advloss_var_dicts],
                base_label,
            )

            if not insDP:
                key = adv_key + '_s100'
                _plot_with_band(
                    used_eps,
                    [d[key][row] for d in all_advloss_mean_dicts],
                    [d[key][row] for d in all_advloss_var_dicts],
                    label_name + r' $\gamma=100$',
                )

            # Raw strings: '\e' / '\g' are invalid escapes in plain string literals.
            plt.xlabel(r"$\epsilon$", fontsize=22)
            plt.legend(prop={'size': 18})

            if insDP:
                if isFlip:
                    outfile = k_dir + '/m_insdp_flip_epoch' + str(used_epoch)
                    title = "(d) MNIST Label Flipping ($k$=" + str(k) + ')'
                else:
                    outfile = k_dir + '/m_insdp_bkd_epoch' + str(used_epoch)
                    title = "(c) MNIST Backdoor ($k$=" + str(k) + ')'
            else:
                if isFlip:
                    outfile = k_dir + '/m_userdp_flip_epoch' + str(used_epoch)
                    title = "(c) MNIST Label Flipping ($k$=" + str(k) + ')'
                else:
                    plt.ylabel("$J(D')$", fontsize=22)
                    outfile = k_dir + '/m_userdp_bkd_epoch' + str(used_epoch)
                    title = "(a) MNIST Backdoor ($k$=" + str(k) + ')'
            plt.title(title, fontsize=22)
            plt.tight_layout()
            plt.savefig(outfile + ".png", dpi=300)
            plt.savefig(outfile + ".pdf")
            plt.close()
