import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so use whichever spelling this
# installation actually provides.
_SEABORN_PAPER_STYLE = (
    'seaborn-v0_8-paper'
    if 'seaborn-v0_8-paper' in plt.style.available
    else 'seaborn-paper'
)
plt.style.use([_SEABORN_PAPER_STYLE, './paper.mplstyle'])

def save_figs(fn,types=('.pdf', '.png')):
    """Write the current pyplot figure to disk once per requested extension.

    Parameters
    ----------
    fn : str
        Output path without extension; each entry of ``types`` is appended
        to it verbatim.
    types : iterable of str
        File extensions (including the leading dot) to save.
    """
    current = plt.gcf()
    # Tighten the layout once up front; every export then reuses it.
    current.tight_layout()
    for extension in types:
        target = fn + extension
        current.savefig(target, bbox_inches='tight')

# Rebind pyplot's savefig so that every later ``plt.savefig(name)`` call in
# this script goes through save_figs (extension-less name, one file per type).
plt.savefig = save_figs

# Column-name prefixes used in the result files:
# cw  == Carlini-Wagner
# pgd == PGD
# au  == AutoAttack
# d   == DeepFool
# n   == New

def preprocess(full_str):
    """Load one tab-separated results file and blank out unusable entries.

    Processing steps:

    1. Read the file and strip whitespace from the column names.
    2. Convert every cell to ``float`` (the textual ``' nan '`` becomes NaN).
    3. For each prefix in ``cw_, cw_f_, n_, n_f_, d_, au_, pgd_, pgd_f_, co_``
       set the ``<prefix>d`` value to NaN when it is below 1e-5 or above 50.
    4. Set ``<prefix>t`` to NaN wherever the matching ``<prefix>d`` is NaN.
    5. Keep only rows whose ``co_d`` is present and greater than 1e-5.

    Parameters
    ----------
    full_str : str
        Path of the tab-delimited file to read.

    Returns
    -------
    pandas.DataFrame
        The cleaned, all-float frame.

    Raises
    ------
    ValueError
        If a cell cannot be converted with ``float``.
    KeyError
        If one of the expected ``<prefix>d`` / ``<prefix>t`` columns is absent.
    """
    df = pd.read_csv(full_str, delimiter='\t')
    df = df.rename(columns=lambda name: name.strip())

    # float() already parses ' nan ' (and padded numbers) correctly, so no
    # sentinel check is needed.  Series.map is applied column by column
    # because DataFrame.applymap is deprecated in recent pandas releases.
    df = df.apply(lambda column: column.map(float))

    prefixes = ['cw_', 'cw_f_', 'n_', 'n_f_', 'd_', 'au_', 'pgd_', 'pgd_f_', 'co_']
    dist_cols = [prefix + 'd' for prefix in prefixes]

    # Keep only values inside [1e-5, 50]; NaN fails both comparisons and so
    # stays NaN.
    distances = df[dist_cols]
    df[dist_cols] = distances.where((distances >= 1e-5) & (distances <= 50))

    for prefix in prefixes:
        # A timing is meaningless when the corresponding distance was dropped.
        df.loc[df[prefix + 'd'].isna(), prefix + 't'] = np.nan

    # NaN compares False, so this removes missing and too-small co_d rows.
    return df[df['co_d'] > 1e-5]


def postprocess(df, select_col='au_d'):
    """Summarise how one attack's distances compare with the other attacks.

    The attack columns considered are ``n_d, pgd_d, cw_d, au_d, d_d``.  For
    ``select_col`` the row-wise minimum over the *other* four columns is used
    as the reference distance.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the five attack columns and the ``<prefix>_t``
        column matching ``select_col`` (e.g. ``au_t`` for ``au_d``).
    select_col : str
        Which attack column to summarise; must be one of the five above.

    Returns
    -------
    tuple
        ``(mean_pct, median_pct, prop, median_dist, attacked_pct, median_time)``

        * ``mean_pct`` / ``median_pct`` -- NaN-ignoring mean / median of
          ``100 * (selected - reference) / reference``.
        * ``prop`` -- fraction of rows where ``selected <= reference``.
        * ``median_dist`` -- NaN-ignoring median of ``select_col``.
        * ``attacked_pct`` -- ``100 *`` the fraction of non-NaN entries in
          ``select_col``.
        * ``median_time`` -- NaN-ignoring median of the matching ``_t`` column.

    Raises
    ------
    ValueError
        If ``select_col`` is not one of the five attack columns.
    """
    cols = ['n_d', 'pgd_d', 'cw_d', 'au_d', 'd_d']
    if select_col not in cols:
        # Fail with a clear message rather than reaching the return statement
        # with unbound locals.
        raise ValueError(
            'select_col must be one of {}, got {!r}'.format(cols, select_col)
        )

    rivals = [name for name in cols if name != select_col]
    selected = df[select_col]
    # Row-wise best (smallest) distance achieved by any other attack.
    reference = df[rivals].min(axis=1)

    percentage = 100*(selected - reference) / reference

    save_1 = np.nanmean(percentage)
    save_2 = np.nanmedian(percentage)
    prop = np.mean(selected <= reference)
    save_3 = np.nanmedian(selected)
    save_4 = 100*np.mean(~np.isnan(selected))
    save_5 = np.nanmedian(df[select_col.split('_')[0] + '_t'])

    return save_1, save_2, prop, save_3, save_4, save_5

# Common filename prefix of the result files read by processing().
base_str = 'cifar10-1.0-samples-750_v2_'

# Sweep values embedded in the filenames; processing() divides these by 255
# for the 'AutoAttack Radius' axis.
vals_in = [1, 2, 3, 4, 383, 510, 638, 765, 893, 1020]
# Negative sweep values; they are negated again for plotting and shown as
# multiples of the Cohen radius.
vals_2 = [-multiple for multiple in range(1, 5)]


#f, ax = plt.subplots(1,2)


def processing(vals, color='b', plotting=False):
    """Run preprocess/postprocess over a sweep of result files.

    For each ``val`` the file ``base_str + str(val) + '.0-0.0784313725490196'``
    is loaded, per-attack success proportions are printed, and the default
    (``au_d``) summary from ``postprocess`` is collected.

    Parameters
    ----------
    vals : sequence of numbers
        Sweep values used to build the filenames.
    color : str
        ``'b'`` rescales ``vals`` by ``1/255`` (AutoAttack-radius axis); any
        other value negates ``vals`` (Cohen-radius multiples axis).
    plotting : bool
        When true, the collected mean percentage differences are plotted and
        saved as ``'AutoAttack-rad'`` or ``'AutoAttack-Cohen-rad'``.

    Returns
    -------
    tuple
        ``(vals, means, avgs, attackeds, timesset)`` where ``vals`` is the
        rescaled array and the remaining entries are lists holding, per sweep
        value, the 1st, 4th, 5th and 6th values returned by ``postprocess``.
    """
    means, avgs, attackeds, timesset = [], [], [], []
    # Order matches the placeholders of the summary line printed below.
    reported = ('cw_d', 'n_d', 'pgd_d', 'au_d', 'd_d', 'co_d')

    for val in vals:
        print('='*20)
        df = preprocess(base_str + str(val) + '.0-0.0784313725490196')

        print('Successful proportion')
        rates = [np.mean(df[name] > 0) for name in reported]
        print('CW: {}, New: {}, PGD: {}, Auto: {}, DeepFool: {}, Cohen: {}'.format(*rates))

        print('Outperformance')
        summary = postprocess(df)
        means.append(summary[0])
        avgs.append(summary[3])
        attackeds.append(summary[4])
        timesset.append(summary[5])

    radius_axis = color == 'b'
    if radius_axis:
        vals = np.asarray(vals) / 255
    else:
        vals = -1*np.asarray(vals)

    if plotting:
        plt.clf()
        if radius_axis:
            plt.plot(vals, means, 'b')
            plt.xlabel('AutoAttack Radius')
            plt.ylabel('Average Percentage Difference')
            plt.xlim(0, np.max(vals))
            plt.tight_layout()
            plt.savefig('AutoAttack-rad')
        else:
            plt.plot(vals, means, 'r')
            plt.xlabel('Multiples of Cohen Radius')
            plt.xlim(1,4)
            plt.tight_layout()
            plt.savefig('AutoAttack-Cohen-rad')

    return vals, means, avgs, attackeds, timesset
    
# Figure 'AutoAttack-Consolidated': the `attackeds` series returned by
# processing() (percentage of rows with a non-NaN au_d), one panel per sweep.
plt.clf()

vals, means, avgs, attackeds, timesset = processing(vals_in, plotting=False)
vals_in = np.asarray(vals_in)
fig, axs = plt.subplots(1,2)
axs[0].plot(vals_in / 255, attackeds)
axs[0].set_ylabel('Successfully Attacked')
axs[0].set_xlabel('Average Radius')
axs[0].set_xlim(0, np.max(vals_in / 255))

# `color` only influences the rescaled `vals` that processing() returns, and
# that value is not used below; the x-axis is built from vals_2 directly.
vals, means, avgs, attackeds, timesset = processing(vals_2, color='b', plotting=False)
vals_2 = np.asarray(vals_2)
axs[1].plot(-1*vals_2, attackeds)
axs[1].set_xlim(1,4)
axs[1].set_xlabel('Autoattack Radius = Cohen$\\times n$')

plt.savefig('AutoAttack-Consolidated')


# Figure 'AutoAttack-Consolidated-Grid': top row shows the attacked
# percentage, bottom row the `avgs` series from processing(); the left column
# is the vals_in sweep and the right column the vals_2 sweep.
plt.clf()

vals, means, avgs, attackeds, timesset = processing(vals_in, plotting=False)
vals_in = np.asarray(vals_in)
fig, axs = plt.subplots(2,2)
# `attackeds` is already a percentage (postprocess multiplies by 100), so it
# is plotted as-is instead of being scaled by 100 a second time.
axs[0,0].plot(vals_in / 255, np.asarray(attackeds))
axs[0,0].set_ylabel('Attacked Proportion')
axs[1,0].set_xlabel('$\\epsilon$')
axs[0,0].set_xlim(0, np.max(vals_in / 255))
axs[1,0].plot(vals_in / 255, avgs)
axs[1,0].set_ylabel('Average Attack Radius')
axs[1,0].set_xlim(0, np.max(vals_in / 255))
axs[1,0].set_ylim(0, 4)


vals, means, avgs, attackeds, timesset = processing(vals_2, color='b', plotting=False)

vals_2 = np.asarray(vals_2)
axs[0,1].plot(-1*vals_2, np.asarray(attackeds))
axs[0,1].set_xlim(1,4)
# Backslashes are doubled so the TeX commands are not read as (invalid)
# Python escape sequences; the resulting label text is unchanged.
axs[1,1].set_xlabel('$n$ ($\\epsilon$ = Cohen $\\times \\textrm{  } n$)')

axs[1,1].plot(-1*vals_2, avgs)
axs[1,1].set_xlim(1,4)
axs[1,1].set_ylim(0,4)

plt.savefig('AutoAttack-Consolidated-Grid')




# Re-run both sweeps with plotting enabled; processing() then saves the
# 'AutoAttack-rad' and 'AutoAttack-Cohen-rad' figures. Return values are
# discarded.
processing(vals_in, plotting=True)
processing(vals_2, color='r', plotting=True)
#plt.show()

# From here on base_str is the filename prefix read by processing_pgd().
base_str = 'cifar10-1.0-samples-750_v2_-2.0-'

# Sweep values appended to base_str, arranged from largest to smallest.
pgd_vals = np.flip(np.sort(np.array([
    0.058823529411764705,
    0.13725490196078433,
    0.11764705882352941,
    0.09803921568627451,
    0.0784313725490196,
    0.0392156862745098,
])))

def processing_pgd(vals):
    """Plot the PGD summary statistics over a sweep and save it as 'pgd'.

    For each ``val`` the file ``base_str + str(val)`` is loaded with
    ``preprocess`` and summarised with ``postprocess(df, select_col='pgd_d')``.
    The 4th returned statistic is drawn on the left axis ('Median
    Perturbation Radius') and the 3rd on a twin right axis ('Proportion
    Best'), both against ``vals * 255``.

    Parameters
    ----------
    vals : sequence of numbers
        Sweep values; each one is appended to ``base_str`` via ``str(val)``.

    Returns
    -------
    pandas.DataFrame or None
        The preprocessed frame of the last sweep value, or ``None`` when
        ``vals`` is empty.
    """
    props = []
    val_means = []
    df = None  # stays None for an empty sweep instead of being unbound
    for val in vals:
        df = preprocess(base_str + str(val))

        _, _, p, s_3, _, _ = postprocess(df, select_col='pgd_d')
        props.append(p)
        val_means.append(s_3)

    # Convert explicitly: ``list * 255`` would repeat the list rather than
    # scale its entries.
    scaled = np.asarray(vals) * 255

    plt.clf()

    fig, ax1 = plt.subplots()

    color = 'tab:red'
    ax1.set_xlabel(r'$\epsilon \times 255$')
    ax1.set_ylabel('Median Perturbation Radius', color=color)
    ax1.plot(scaled, val_means, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

    color = 'tab:blue'
    ax2.set_ylabel('Proportion Best', color=color)  # x-label already set on ax1
    ax2.plot(scaled, props, color=color)
    ax2.tick_params(axis='y', labelcolor=color)

    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.savefig('pgd')
    return df

# Run the PGD sweep; df_out is whatever processing_pgd returns (the frame it
# loaded for the last entry of pgd_vals).
df_out = processing_pgd(pgd_vals)

def at_radius(df, radius, col):
    """Return the fraction of rows whose ``col`` value is at most ``radius``.

    NaN entries compare as False, so they count towards the denominator but
    never towards the numerator.
    """
    within = df[col].le(radius)
    return within.mean()

def processing_general(datasets, vals):
    """Print LaTeX table rows and draw the 2x3 'consolidated' figure.

    For every ``(dataset, val)`` pair the file ``dataset + '-' + str(val)`` is
    loaded through ``preprocess``.  Per attack column the function computes
    the success percentage, the percentage of rows where the attack is the
    best one, the median distance, the median relative gap to ``co_d`` and
    the median time, and prints one LaTeX table row per attack.  It then
    plots, in the next free panel, the fraction of rows attacked within each
    radius of ``np.linspace(0, 5, 400)`` for every technique.

    Parameters
    ----------
    datasets : sequence of str
        Keys of the internal dataset-name table (``'mnist'``, ``'cifar10'``,
        ``'tinyimagenet'``).
    vals : sequence of numbers
        Values appended to the dataset name to form the filename; shown as
        sigma in the panel titles.

    Notes
    -----
    The figure has exactly six panels and six panel letters, so at most six
    ``(dataset, val)`` combinations are supported.
    """
    radii = np.linspace(0, 5, 400)
    cols = ['n_d', 'pgd_d', 'cw_d', 'au_d', 'd_d', 'co_d']
    attack_cols = cols[:-1]
    bonus_set = ['pgd', 'cw', 'n']

    keys = ['Cohen', 'Ours', 'Ours (F)', 'PGD', 'PGD (F)', 'C-W', 'C-W (F)', 'Auto', 'DeepF']
    keys_prefixes = ['co', 'n', 'n_f', 'pgd', 'pgd_f', 'cw', 'cw_f', 'au', 'd']
    keys_reference = ['k:', 'b', 'b-.', 'r', 'r-.', 'g', 'g-.', 'm', 'c']

    label_letters = iter(['a)', 'b)', 'c)', 'd)', 'e)', 'f)'])

    dataset_names = {'mnist': 'M', 'cifar10': 'C', 'tinyimagenet': 'T-I'}
    technique_names = {
        'n_d': 'Ours',
        'pgd_d': 'PGD',
        'd_d': 'D.Fool',
        'cw_d': 'C-W',
        'au_d': 'Auto',
    }

    plt.clf()
    fig, axs = plt.subplots(2, 3)

    panel = 0

    for dataset in datasets:
        for val in vals:
            success, best, med, perc_coh, times, outputs = {}, {}, {}, {}, {}, {}

            df = preprocess(dataset + '-' + str(val))
            # Only rows with a usable co_d are analysed.
            df = df[df['co_d'] > 1e-5]

            for col in cols:
                prefix = col.split('_')[0]
                distances = df[col]

                success[col] = 100*np.mean(~np.isnan(distances))

                # "Best" comparison: a failed attack (NaN) is treated as an
                # infinite distance.  This avoids writing a numeric sentinel
                # into the frame and never counts a failed attack as best.
                rivals = [name for name in attack_cols if name != col]
                rival_best = df[rivals].fillna(np.inf).min(axis=1)
                best[col] = 100*np.mean(distances.fillna(np.inf) < rival_best + 1e-5)

                med[col] = np.nanmedian(distances)
                perc_coh[col] = 100*np.nanmedian((distances - df['co_d']) / (df['co_d'] + 1e-5))
                times[col] = np.nanmedian(df[prefix + '_t'])

                outputs[col] = np.asarray([at_radius(df, rad, col) for rad in radii])

                col_mod = prefix + '_f_d'
                if prefix in bonus_set:
                    # "(F)" variant: ratios relative to the base attack, and a
                    # curve computed from the (F) column itself.
                    med[col_mod] = np.nanmedian(df[col_mod]) / med[col]
                    times[col_mod] = np.nanmedian(df[prefix + '_f_t']) / times[col]
                    outputs[col_mod] = np.asarray([at_radius(df, rad, col_mod) for rad in radii])
                else:
                    med[col_mod] = 1
                    times[col_mod] = 1

            for col in attack_cols:
                # Only the first row of each group carries the dataset label.
                if col == 'n_d':
                    base_string = dataset_names[dataset] + '-$' + str(val) + '$ & \t '
                else:
                    base_string = ' & \t '

                col_mod = col.split('_')[0] + '_f_d'
                cells = [
                    f"${success[col]:.0f}\\%$",
                    f"${best[col]:.0f}\\%$",
                    f"${med[col]:.2f}$",
                    f"${perc_coh[col]:.0f}\\%$",
                    f"${times[col]:.2f}$",
                    f"${med[col_mod]:.2f}$",
                    f"${times[col_mod]:.2f}$",
                ]
                print(base_string + technique_names[col] + ' & \t ' + ' & \t '.join(cells) + '\\\\')

            print('\\cmidrule(r){1-2} \\cmidrule(r){3-7} \\cmidrule(r){8-9}')

            ax = axs[np.unravel_index(panel, (2, 3))]
            for label, key_prefix, line_style in zip(keys, keys_prefixes, keys_reference):
                ax.plot(radii, outputs[key_prefix + '_d'], line_style, label=label)
            ax.set_title(next(label_letters) + ' ' + dataset_names[dataset] + ', $\\sigma = $ ' + str(val))

            # y ticks/label only on the left column, x ticks/label only on the
            # bottom row.
            if panel == 0 or panel == 3:
                ax.set_ylabel('Attack Proportion')
            else:
                ax.get_yaxis().set_ticks([])
            if panel >= 3:
                ax.set_xlabel('Radius')
            else:
                ax.get_xaxis().set_ticks([])

            panel += 1

    plt.savefig('consolidated')

# Print the results table and save the 'consolidated' figure for every
# dataset / value combination.
processing_general(['mnist', 'cifar10', 'tinyimagenet'], [0.5, 1.0])


# Load the 'tinyimagenet-1.0' results once through preprocess().  The earlier
# manual read/convert pass produced a frame that was immediately overwritten,
# so it has been dropped.
full_str = 'tinyimagenet' + '-' + str(1.0)

df = preprocess(full_str)


def sigma_variance_test(dataset='cifar10', val=1.0):
    """Plot the mean of each ``n_d_<scaling>`` column against its scaling.

    Reads the tab-separated file
    ``'sigma_variance_test_' + dataset + '-' + str(val)``, averages the
    columns ``n_d_0.5`` ... ``n_d_1.5`` and saves the resulting curve as
    ``'scaling_performance'``.

    Parameters
    ----------
    dataset : str
        Dataset name used in the filename.
    val : float
        Value appended to the filename after the dash.
    """
    full_str = 'sigma_variance_test_' + dataset + '-' + str(val)
    df = pd.read_csv(full_str, delimiter='\t')
    df = df.rename(columns=lambda name: name.strip())

    scalings = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]

    outputs = [df['n_d_' + str(scaling)].mean() for scaling in scalings]

    plt.clf()
    plt.plot(np.asarray(scalings), outputs)
    # Raw string: the TeX backslashes must not be read as Python escapes.
    plt.xlabel(r'$\hat{\sigma}$')
    plt.ylabel('Average Perturbation Radius')
    plt.savefig('scaling_performance')
    
# Produce the 'scaling_performance' figure with the default arguments
# (dataset='cifar10', val=1.0).
sigma_variance_test()    
