import matplotlib.pyplot as plt
import pickle
from matplotlib.patches import Patch
import matplotlib.cm as cmap
import matplotlib.lines as mlines
import numpy as np
from model_eval import load_cached_results, cache_results
import labellines
### We now visualize the results obtained in model_eval.py

def extract_val(results_dict, eval_dset):
    """Return the metric named ``eval_dset`` from one model's results.

    Two derived metrics are computed from the stored entries:

    * ``'BG Gap'``: ``mixed_same - mixed_rand``
    * ``'Relative BG Gap'``: the same gap divided by ``mixed_same``

    Any other name is looked up directly as a key of ``results_dict``.
    """
    if eval_dset not in ('BG Gap', 'Relative BG Gap'):
        return results_dict[eval_dset]

    gap = results_dict['mixed_same'] - results_dict['mixed_rand']
    if eval_dset == 'BG Gap':
        return gap
    return gap / results_dict['mixed_same']

# Result keys that plot_results_bar treats as plain accuracies (it appends
# ' Accuracy' to the y-axis label when the requested metric is one of these).
eval_dsets=['original', 'only_bg_t', 'mixed_same', 'mixed_rand']
# Adversarial-training budgets; formatted into model keys such as
# '{arch}_linf_eps{eps}' / '{arch}_l2_eps{eps}' by the plotting functions below.
# NOTE(review): the linf values look like they are expressed in units of 1/255
# (other functions in this file divide by 255 before lookup) -- confirm.
linf_eps = [0.5, 1.0, 2.0, 4.0, 8.0]
l2_eps = [0.25, 0.5, 1, 3, 5]

### We'll start simple with bar plots for bg-only-t and bg-gap metrics.
def plot_results_bar(eval_dset):
    """Bar-plot one metric for every L2 / Linf adversarially trained ResNet.

    One subplot is drawn per architecture (resnet18, resnet50). Within a
    subplot the L2-trained models come first, followed after a one-slot gap
    by the Linf-trained models. The figure is written to
    ``plots/{eval_dset}_l2_linf_AT.jpg`` and then closed.

    Args:
        eval_dset: metric name understood by ``extract_val`` -- either a key
            of the per-model results dict or one of ``'BG Gap'`` /
            ``'Relative BG Gap'``.
    """
    arches = ['resnet18', 'resnet50']
    colors = ['coral', 'deepskyblue']
    # NOTE: pickle.load executes arbitrary code; only use on trusted result files.
    with open('./results/model_eval.pkl', 'rb') as f:
        results = pickle.load(f)

    # Linf bars start one slot after the last L2 bar to leave a visual gap.
    linf_offset = 1 + len(l2_eps)
    f_bar, axs_bar = plt.subplots(1, 2, figsize=(7, 3))
    for ax_bar, arch in zip(axs_bar, arches):
        for i, eps in enumerate(l2_eps):
            results_dict = results['{}_l2_eps{}.ckpt'.format(arch, eps)]
            ax_bar.bar(i, extract_val(results_dict, eval_dset), color=colors[0])

        for i, eps in enumerate(linf_eps):
            # A budget of 0 means "no adversarial training", stored under the l2 key.
            if eps != 0:
                key = '{}_linf_eps{}.ckpt'.format(arch, eps)
            else:
                key = '{}_l2_eps0.ckpt'.format(arch)
            ax_bar.bar(linf_offset + i, extract_val(results[key], eval_dset), color=colors[1])

        ax_bar.set_title('Robust ResNet{}s'.format(arch[-2:]))
        ax_bar.set_xticks(list(range(len(l2_eps))) + [linf_offset + i for i in range(len(linf_eps))])
        # Raw strings: '\e' and '\i' are invalid escape sequences in normal literals.
        ax_bar.set_xticklabels([r'$\epsilon={}$'.format(eps) for eps in (l2_eps + linf_eps)],
                               rotation='vertical')
        ax_bar.legend(handles=[Patch(facecolor=c, label=l)
                               for (c, l) in zip(colors, ['$L_2$ AT', r'$L_\infty$ AT'])],
                      loc='lower right')
        ax_bar.set_ylabel(eval_dset + (' Accuracy' if eval_dset in eval_dsets else ''))

    f_bar.tight_layout()
    f_bar.savefig('plots/{}_l2_linf_AT.jpg'.format(eval_dset), dpi=200)
    # Release the figure; pyplot otherwise keeps it alive for the whole process.
    plt.close(f_bar)

### Now we get fancier. We plot BG Gap by showing Mixed Same and Mixed Rand on same fig
def draw_vertical_segment(ax, x, y1, y2, c, add_label=False):
    """Draw a vertical segment at ``x`` between ``y1`` and ``y2`` on ``ax``.

    The ``y1`` end is marked with a triangle ('^') and the ``y2`` end with a
    circle ('o'), all in color ``c``. When ``add_label`` is true the segment
    line is labelled with ``y1 - y2`` formatted to one decimal place.

    Returns the same ``ax`` that was passed in.
    """
    segment_kwargs = {'c': c}
    if add_label:
        segment_kwargs['label'] = '{:.1f}'.format(y1 - y2)
    ax.plot([x, x], [y1, y2], '-', **segment_kwargs)

    for y_end, marker in ((y1, '^'), (y2, 'o')):
        ax.plot(x, y_end, marker, c=c)
    return ax

def line_segment_plot(eval_dset1='mixed_same', eval_dset2='mixed_rand', ylabel='Accuracy on IN-9 Subset',
                      title='Background Gap', results_path='./results/model_eval.pkl', add_label=True):
    """Plot, per model, a vertical segment spanning two accuracies.

    For every model the ``eval_dset1`` accuracy is drawn as the triangle end
    of the segment and the ``eval_dset2`` accuracy as the circle end (both
    scaled by 100). Models are ordered as: L2-trained (including eps=0,
    labelled 'No AT'), then Linf-trained. The figure is saved to
    ``./plots/{title, lower-cased with underscores}_best_wide.jpg`` and closed.

    Args:
        eval_dset1: results key plotted as the triangle end of each segment.
        eval_dset2: results key plotted as the circle end of each segment.
        ylabel: y-axis label.
        title: axes title; also used to build the output file name.
        results_path: pickle file holding ``{model_key: {metric: value}}``.
            Model keys carry a '.ckpt' suffix unless the path contains
            'waterbirds'.
        add_label: if true, each segment is labelled with its length and
            the labels are placed inline by ``labellines``.
    """
    # NOTE: pickle.load executes arbitrary code; only use on trusted result files.
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    ext = '' if 'waterbirds' in results_path else '.ckpt'

    # Local budget lists (these intentionally differ from the module-level
    # ones: the l2 list here includes the eps=0 / 'No AT' model).
    linf_eps = [0.5, 1.0, 2.0, 4.0, 8.0]
    l2_eps = [0, 0.25, 0.5, 1, 3, 5]

    arches = ['wide_resnet50_2']
    f, axs = plt.subplots(1, 1, figsize=(4.25, 4.25))
    axs = [axs]  # keep the loop below generic over one or more subplots

    colors = [cmap.viridis(x / len(l2_eps)) for x in range(len(l2_eps))]
    for ax, arch in zip(axs, arches):
        labels = []
        for i, eps in enumerate(l2_eps):
            results_dict = results['{}_l2_eps{}{}'.format(arch, eps, ext)]
            y1, y2 = [100. * results_dict[x] for x in [eval_dset1, eval_dset2]]
            ax = draw_vertical_segment(ax, i + 1, y1, y2, colors[i], add_label)
            # Raw strings: '\e' is an invalid escape sequence in a normal literal.
            labels.append(r'$\ell_2$ $\epsilon={}$'.format(eps) if eps > 0 else 'No AT')

        for i, eps in enumerate(linf_eps):
            results_dict = results['{}_linf_eps{}{}'.format(arch, eps, ext)]
            y1, y2 = [100. * results_dict[x] for x in [eval_dset1, eval_dset2]]
            # colors[1 + i] pairs the i-th linf budget with the color of the
            # i-th non-zero l2 budget.
            ax = draw_vertical_segment(ax, 1 + len(l2_eps) + i, y1, y2, colors[1 + i], add_label)
            labels.append(r'$\ell_\infty$ $\epsilon={}$'.format(eps))

        labellines.labelLines(ax.get_lines())

        mixed_same_handle = mlines.Line2D([], [], marker='^', c='black', linestyle='None',
                                          label=eval_dset1.replace('_', ' ').title())
        mixed_rand_handle = mlines.Line2D([], [], marker='o', c='black', linestyle='None',
                                          label=eval_dset2.replace('_', ' ').title())
        ax.legend(handles=[mixed_same_handle, mixed_rand_handle], loc='lower left')
        ax.set_xticks(np.arange(1, 1 + len(l2_eps) + len(linf_eps)))
        ax.set_xticklabels(labels, rotation='vertical')
        ax.set_xlabel('WideResNet50 Adversarial Training Norm and Budget')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
    f.tight_layout()
    f.savefig('./plots/{}_best_wide.jpg'.format(title.lower().replace(' ', '_')), dpi=300,
              bbox_inches='tight', pad_inches=0.03)
    # Release the figure; pyplot otherwise keeps it alive for the whole process.
    plt.close(f)


### We also want some scatter plots to compare Adv Robust Acc to other metrics (objectnet, inet-c)
def scatter_vs_robust_acc(results_path, extract_val, ylabel):
    """Plot an out-of-distribution accuracy against cached ImageNet accuracy.

    For resnet18 and resnet50, one curve is drawn per adversarial-training
    norm (linf, l2). Each curve starts at the non-adversarially-trained model
    ('{arch}_l2_eps0') and continues through the budgets in the module-level
    ``linf_eps`` / ``l2_eps`` lists. A straight reference line through the
    origin and the no-AT point is added per architecture. The figure is saved
    to ``./plots/{ylabel, lower-cased with underscores}_vs_inet_acc.jpg``.

    Args:
        results_path: cache file passed to ``load_cached_results``; must map
            model keys such as 'resnet18_l2_eps3' to a per-model result.
        extract_val: callable turning one per-model result into a number.
        ylabel: name of the OOD benchmark, used in the y label and file name.
    """
    def to_scalar(value):
        # Cached accuracies are a mix of plain floats and objects exposing
        # .item(); normalize on the value's type instead of the arch name.
        return value.item() if hasattr(value, 'item') else value

    l2_robust_accs = load_cached_results('../spurious/results/l2_robust_accs_inet.pkl')
    linf_robust_accs = load_cached_results('../spurious/results/linf_robust_accs_inet.pkl')
    saved_adv_accs = {'l2': l2_robust_accs, 'linf': linf_robust_accs}

    ood_saved_accs = load_cached_results(results_path)

    styles, colors = ['-^', '-v'], ['deepskyblue', 'coral']
    f, ax = plt.subplots(1, 1)
    max_x, max_y = 0, 0
    for arch in ['resnet18', 'resnet50']:
        arch_prettyname = arch.replace('resnet', 'RN')
        # Standard (no adversarial training) model: same for both norms, so
        # it is computed once per architecture.
        no_AT_key = f'{arch}_l2_eps0'
        no_AT_inet_acc = to_scalar(100. * saved_adv_accs['l2'][no_AT_key][0])
        no_AT_ood_acc = extract_val(ood_saved_accs[no_AT_key])

        for adv_train_norm, epsilons, s, c in zip(['linf', 'l2'], [linf_eps, l2_eps], styles, colors):
            adv_robust_accs, ood_accs = [no_AT_inet_acc], [no_AT_ood_acc]
            for adv_train_eps in epsilons:
                # A budget of 0 is the no-AT model, which is stored under the l2 key.
                norm2 = adv_train_norm if adv_train_eps > 0 else 'l2'
                mkey = f'{arch}_{norm2}_eps{adv_train_eps}'
                adv_robust_accs.append(to_scalar(100. * saved_adv_accs[adv_train_norm][mkey][0]))
                ood_accs.append(extract_val(ood_saved_accs[mkey]))

            # Raw strings: '\e' and '\i' are invalid escape sequences otherwise.
            norm_tex = adv_train_norm.replace('l', r'\ell_').replace('inf', r'\infty')
            ax.plot(adv_robust_accs, ood_accs, s, c=c,
                    label='${}$ AT {}'.format(norm_tex, arch_prettyname))
            ax.plot(no_AT_inet_acc, no_AT_ood_acc, s[-1], c='gray', label=f'No AT {arch_prettyname}')

            max_x = max(max_x, no_AT_inet_acc)
            max_y = max(max_y, no_AT_ood_acc)

        ax.set_xlim([0, 1.1 * max_x])
        ax.set_ylim([0, 1.1 * max_y])
        # Reference line: OOD accuracy shrinking proportionally with x.
        xs = np.linspace(0, no_AT_inet_acc, 100)
        ax.plot(xs, no_AT_ood_acc / no_AT_inet_acc * xs, '-.' if arch == 'resnet18' else '--')

    ax.legend()
    ax.set_xlabel('Standard ImageNet Accuracy')
    ax.set_ylabel(ylabel + ' Accuracy')
    f.savefig('./plots/{}_vs_inet_acc.jpg'.format(ylabel.lower().replace(' ', '_')), dpi=300,
              bbox_inches='tight', pad_inches=0.03)
    # Release the figure; pyplot otherwise keeps it alive for the whole process.
    plt.close(f)

def both_ood_accs_vs_id_acc():
    '''
    Plot two out-of-distribution accuracies against in-distribution accuracy.

    Goal is to have one subplot per backbone (currently only wide_resnet50_2).
    Within one subplot, we plot OOD acc vs. in-distribution acc for *two* OOD
    benchmarks (ImageNet-C and ObjectNet). The point is ImageNet-C acc goes
    down linearly w/ standard acc, while ObjectNet (breaking spurious
    correlations) goes down faster for robust models.

    Two reference lines through the origin and the no-AT point are drawn and
    labelled inline; then one ObjectNet curve and one ImageNet-C curve are
    drawn per adversarial-training norm. The figure is saved to
    './plots/both_ood_vs_id_acc_wide.jpg' and closed.
    '''
    def to_scalar(value):
        # Cached accuracies are a mix of plain floats and objects exposing
        # .item(); convert based on the value rather than on the arch name
        # (the old `arch == 'resnet50'` test never matched 'wide_resnet50_2').
        return value.item() if hasattr(value, 'item') else value

    def extract_inet_c_acc(d):
        # Stored ImageNet-C values are converted via 100 * (1 - v) and averaged.
        return np.average([100. * (1 - v.item()) for v in d.values()])

    def extract_objectnet_acc(x):
        return x.cpu().item()

    # getting our OOD accs. Both need some of their own postprocessing
    inet_c_accs = load_cached_results('./results/imagenet_c3.pkl')
    objectnet_accs = load_cached_results('./results/objectnet_no_norm.pkl')

    # this is pretty extra, all we want is the standard accs, which is available by accessing test_eps=0
    l2_robust_accs = load_cached_results('../spurious/results/l2_robust_accs_inet.pkl')
    linf_robust_accs = load_cached_results('../spurious/results/linf_robust_accs_inet.pkl')
    saved_adv_accs = {'l2': l2_robust_accs, 'linf': linf_robust_accs}

    f, axs = plt.subplots(1, 1, figsize=(4.5, 3.75))
    arches = ['wide_resnet50_2']
    axs = [axs]  # keep the loop below generic over one or more subplots

    styles, colors = ['-*', '-o'], ['deepskyblue', 'coral']
    for ax, arch in zip(axs, arches):
        arch_prettyname = 'Wide ResNet50' if 'wide' in arch else arch.replace('resnet', 'ResNet')

        # standard (no adversarial training) model
        no_AT_key = f'{arch}_l2_eps0'
        no_AT_id_acc = to_scalar(100. * saved_adv_accs['l2'][no_AT_key][0])
        no_AT_ood_objectnet_acc = extract_objectnet_acc(objectnet_accs[no_AT_key])
        no_AT_ood_inet_c_acc = extract_inet_c_acc(inet_c_accs[no_AT_key])

        # Reference lines; labelled inline before any other line is added so
        # that labelLines only sees these two.
        xs = np.linspace(0, no_AT_id_acc, 100)
        ax.plot(xs, no_AT_ood_inet_c_acc / no_AT_id_acc * xs, '-.', c='black', label='ImageNet-C')
        ax.plot(xs, no_AT_ood_objectnet_acc / no_AT_id_acc * xs, '--', c='gray', label='ObjectNet')
        labellines.labelLines(ax.get_lines())

        handles = []
        for adv_train_norm, epsilons, s, c in zip(['linf', 'l2'], [linf_eps, l2_eps], styles, colors):
            # in distribution (standard acc) vs out of distr, each starting at the no-AT model
            id_accs = [no_AT_id_acc]
            ood_objectnet_accs = [no_AT_ood_objectnet_acc]
            ood_inet_c_accs = [no_AT_ood_inet_c_acc]

            for adv_train_eps in epsilons:
                norm2 = adv_train_norm if adv_train_eps > 0 else 'l2'
                mkey = f'{arch}_{norm2}_eps{adv_train_eps}'

                # Standard accuracy is read from the 'l2' cache for every model.
                id_accs.append(to_scalar(100. * saved_adv_accs['l2'][mkey][0]))
                ood_objectnet_accs.append(extract_objectnet_acc(objectnet_accs[mkey]))
                ood_inet_c_accs.append(extract_inet_c_acc(inet_c_accs[mkey]))

            # Raw strings: '\e' and '\i' are invalid escape sequences otherwise.
            curve_label = '${}$ AT'.format(adv_train_norm.replace('l', r'\ell_').replace('inf', r'\infty'))
            handles.append(ax.plot(id_accs, ood_objectnet_accs, s, c=c, label=curve_label))
            handles.append(ax.plot(id_accs, ood_inet_c_accs, s, c=c, label=curve_label))
            # Labels starting with an underscore are left out of the legend.
            top_handle = ax.plot(no_AT_id_acc, no_AT_ood_objectnet_acc, 's', c='gray',
                                 label=('No AT' if adv_train_norm == 'l2' else '_no_legend_'))
            ax.plot(no_AT_id_acc, no_AT_ood_inet_c_acc, 's', c='gray', label='_no_legend_')

        # One legend entry per norm (first linf curve, last l2 curve) plus the no-AT marker.
        ax.legend(handles=[handles[0][0], handles[-1][0], top_handle[0]])
        ax.set_xlabel('In Distribution Accuracy')
        ax.set_ylabel('Out of Distribution Accuracy')
        ax.set_title(arch_prettyname)

    f.tight_layout()
    f.savefig('./plots/both_ood_vs_id_acc_wide.jpg', dpi=300, bbox_inches='tight', pad_inches=0.03)
    # Release the figure; pyplot otherwise keeps it alive for the whole process.
    plt.close(f)

#### Now we plot the related, reverse hypothesis: the presence of spurious correlations --> better adversarial robustness
def in9_adv_robustness():
    """Plot the Mixed-Same minus Mixed-Rand robust-accuracy gap per test budget.

    For each architecture (resnet18, resnet50) a two-panel figure is made:
    the left panel covers l2, the right panel linf. In each panel the x axis
    is the adversarial-training budget, and one curve is drawn per test-time
    attack budget. A training budget of 0 always refers to the model stored
    under the 'l2' key. Figures are saved to
    ``./plots/ms_mr_diff_adv_robustness_{arch}.jpg`` and closed.
    """
    results_path = '../spurious/results/in9_eval.pkl'
    results = load_cached_results(results_path)
    l2_epsilons = [0, 0.25, 0.5, 1, 3, 5]
    linf_epsilons = [0, 0.5, 1.0, 2.0, 4.0, 8.0]

    l2_test_eps = [0, 0.25, 0.5, 1, 3, 5]
    # NOTE(review): these floats are used directly as dict keys, so they must
    # match the cached keys bit-for-bit -- confirm against the writer.
    linf_test_eps = [0, 0.5/255, 1/255, 2/255, 4/255, 8/255]
    colors = [cmap.viridis(i / len(linf_epsilons)) for i in range(len(linf_epsilons))]
    for arch in ['resnet18', 'resnet50']:
        f, axs = plt.subplots(1, 2, figsize=(9, 3.75), sharey=True)
        for ax, norm, epsilons, test_epsilons in zip(axs, ['l2', 'linf'], [l2_epsilons, linf_epsilons],
                                                     [l2_test_eps, linf_test_eps]):
            # Raw strings: '\e' is an invalid escape sequence in a normal literal.
            labels = [r'Test $\epsilon={}$'.format(
                          '{:.1f}'.format(255 * x).rstrip('0').rstrip('.') + '/255' if norm == 'linf' else x)
                      for x in test_epsilons]
            # Model keys do not depend on the test budget; build them once.
            model_keys = [f"{arch}_{norm if train_eps > 0 else 'l2'}_eps{train_eps}" for train_eps in epsilons]
            # Distinct name for the loop variable: the original reused the
            # name of the list being iterated.
            for i, test_eps in enumerate(test_epsilons):
                diffs = [results[mkey]['mixed_same'][norm][test_eps] -
                         results[mkey]['mixed_rand'][norm][test_eps]
                         for mkey in model_keys]
                ax.plot(epsilons, diffs, '-o' if test_eps > 0 else '-*', c=colors[i], label=labels[i])
            ax.legend()
            ax.set_ylabel(r'$\Delta$ Adv. Robust Accuracy')
            ax.set_xlabel(r'${}$ Adversarial Training $\epsilon$'.format(
                norm.replace('l', r'\ell_').replace('inf', r'\infty')))
            ax.set_title('Mixed Same - Mixed Rand')
        f.tight_layout()
        f.savefig('./plots/ms_mr_diff_adv_robustness_{}.jpg'.format(arch), dpi=300,
                  bbox_inches='tight', pad_inches=0.03)
        # One figure per architecture: close it so figures do not accumulate.
        plt.close(f)


def simpler_in9_robustness_plot():
    """Compare the Mixed-Same vs Mixed-Rand gap on clean and attacked inputs.

    For resnet18 and resnet50 and for both training norms, each model is
    attacked with the same budget it was trained with (linf budgets are
    divided by 255 for the lookup). Four figures are populated (line, bar,
    scatter/line and a 2x2 "four panel" figure); only the four-panel figure
    is saved, to ``./plots/ms_mr_fourpanel.jpg``. Averages of the clean and
    attacked Mixed-Same / Mixed-Rand accuracies are printed as LaTeX-style
    ``a&b`` rows. All figures are closed before returning.

    Side effect: switches the global matplotlib style to 'ggplot'.
    """
    results_path = '../spurious/results/in9_eval.pkl'
    results = load_cached_results(results_path)
    l2_epsilons = [0, 0.25, 0.5, 1, 3, 5]
    linf_epsilons = [0.5, 1.0, 2.0, 4.0, 8.0]

    colors = ['coral', 'deepskyblue']
    f, axs = plt.subplots(1, 2, figsize=(9, 4))
    bar_f, bar_axs = plt.subplots(1, 2, figsize=(9, 3.75))
    hatches = ['////', '\\\\\\']

    fourpanel_f, fourpanel_axs = plt.subplots(2, 2, figsize=(10, 8), sharey=True)
    plt.style.use('ggplot')

    scatter_line_f, scatter_line_ax = plt.subplots(1, 1)
    clean_ms_s, clean_mr_s, adv_ms_s, adv_mr_s = [], [], [], []
    for ax, bar_ax, arch, fourp_row in zip(axs, bar_axs, ['resnet18', 'resnet50'], fourpanel_axs):
        i = 0  # running bar index across both norms
        # Clean accuracy gap of the model without adversarial training.
        no_at_ms_clean_acc, no_at_mr_clean_acc = [results[f'{arch}_l2_eps0'][f'mixed_{x}']['l2'][0]
                                                  for x in ['same', 'rand']]
        no_at_gap = no_at_ms_clean_acc - no_at_mr_clean_acc
        for norm, epsilons, s, c, fourpanel_ax in zip(['l2', 'linf'], [l2_epsilons, linf_epsilons],
                                                      ['-o', '-*'], colors, fourp_row):
            diffs = []
            # The linf list has no eps=0 entry, so the no-AT point is
            # prepended for linf; for l2 it comes from the eps=0 iteration.
            # baseline_diffs tracks 'BG Gap'-like metric (diff in clean acc bw MS and MR).
            baseline_diffs = [no_at_gap] if norm == 'linf' else []
            clean_diffs = [no_at_gap] if norm == 'linf' else []
            attacked_diffs = [no_at_gap] if norm == 'linf' else []
            for eps in epsilons:
                i += 1
                mkey = f'{arch}_{norm}_eps{eps}'

                # Attack budget equals the training budget.
                e = eps if norm == 'l2' else eps/255
                ms_at, mr_at = [results[mkey]['mixed_{}'.format(x)][norm][e] for x in ['same', 'rand']]
                diffs.append(ms_at - mr_at)

                clean_ms, clean_mr = [results[mkey][f'mixed_{x}'][norm][0] for x in ['same', 'rand']]
                baseline_diffs.append(clean_ms - clean_mr)

                bar_color = c if eps != 0 else 'lightgray'
                bar_ax.bar(1.5*i - 0.25, clean_ms - clean_mr, color=bar_color, hatch=hatches[0], width=0.5)
                bar_ax.bar(1.5*i + 0.25, ms_at - mr_at, color=bar_color, hatch=hatches[1], width=0.5)

                # for scatter/line
                clean_diffs.append(clean_ms - clean_mr)
                attacked_diffs.append(ms_at - mr_at)
                scatter_line_ax.plot(clean_diffs[-1], attacked_diffs[-1],
                                     color=(c if eps != 0 else 'gray'),
                                     marker=(('v' if '18' in arch else '^') if eps != 0 else 's'))

                clean_ms_s.append(clean_ms)
                clean_mr_s.append(clean_mr)
                adv_ms_s.append(ms_at)
                adv_mr_s.append(mr_at)

            if norm == 'linf':
                # Reuse the eps=0 entry of the l2 curve (processed just
                # before) as the first linf point.
                diffs = [old_diffs[0]] + diffs
            # Raw strings: '\e' and '\i' are invalid escape sequences otherwise.
            norm_tex = norm.replace('l', r'\ell_').replace('inf', r'\infty')
            for curr_ax in [ax, fourpanel_ax]:
                curr_ax.plot(diffs, s, c=c, label='${}$ AT, Robust Acc. Gap'.format(norm_tex))
                curr_ax.plot(baseline_diffs, ls=':', marker=s[1], color='gray', label='No AT, Standard Acc. Gap')
            old_diffs = diffs

            ls = '-' if '18' in arch else '--'
            scatter_line_ax.plot(clean_diffs, attacked_diffs, c=c, ls=ls,
                                 label='${}$ AT {}'.format(norm_tex, arch.replace('resnet', 'ResNet')))

            fourpanel_ax.set_xticks(np.arange(len(diffs)))
            # One label per tick: the eps=0 entry of the l2 list *is* the
            # 'No AT' model, so it must not get a second label (a count
            # mismatch makes set_xticklabels raise on current matplotlib).
            fourpanel_ax.set_xticklabels(['No AT'] + [r'$\epsilon={}$'.format(eps) for eps in epsilons if eps != 0])
            fourpanel_ax.set_ylabel('Accuracy Gain due to Spurious Feature', fontsize=10)
            fourpanel_ax.set_xlabel('Train and Test Attack Budget')
            fourpanel_ax.legend()

        # for the bar plot
        bar_ax.set_xticks(1.5*np.arange(1, i+1))
        # Built from the local budget lists so labels always match the bars drawn above.
        ticklabels = (['No AT']
                      + [r'$\ell_2, \epsilon={}$'.format(eps) for eps in l2_epsilons if eps != 0]
                      + [r'$\ell_\infty, \epsilon={}$'.format(eps) for eps in linf_epsilons])
        bar_ax.set_xticklabels(ticklabels, rotation='vertical')
        if arch == 'resnet50':
            bar_ax.legend(handles=[Patch(facecolor=c, label=l)
                                   for (c, l) in zip(colors, [r'$\ell_2$ AT', r'$\ell_\infty$ AT'])] +
                                  [Patch(facecolor='white', hatch=h, label=l)
                                   for (h, l) in zip(hatches, ['Clean', 'Attacked'])],
                          loc='upper left')
        bar_ax.set_xlabel('Train and Test Attack Budget')
        bar_ax.set_ylabel('Accuracy Gain due to Spurious Feature')
        bar_ax.set_title(arch.replace('resnet', 'ResNet'))

        ax.set_xticks(np.arange(len(diffs)))
        ax.set_xticklabels(l2_epsilons)
        ax.legend()
        ax.set_xlabel(r'$\ell_2$ Adv Train $\epsilon$')
        ax.set_ylabel(r'$\Delta$ Adv. Robust Accuracy')
        # Second x axis on top showing the matching linf budgets.
        ax2 = ax.twiny()
        ax2.set_xticks(np.arange(len(diffs)))
        ax2.set_xlabel(r'$\ell_\infty$ Adv Train $\epsilon$')
        ax2.set_xticklabels([f'{x}/255' for x in ([0] + linf_epsilons)])
    # f.tight_layout();f.savefig('./plots/ms_mr_diff_avg2.jpg', dpi=300, bbox_inches='tight', pad_inches=0.03)
    # bar_f.tight_layout();bar_f.savefig('./plots/ms_mr_diff_avg_bar.jpg', dpi=300, bbox_inches='tight', pad_inches=0.03)

    scatter_line_ax.legend()
    xs = np.linspace(0, 0.15, 100)
    scatter_line_ax.plot(xs, xs, '-.')  # y = x reference
    scatter_line_ax.set_xlabel('Clean Accuracy Gain due to Spur Ftr')
    scatter_line_ax.set_ylabel('Adv Robust Accuracy Gain due to Spur Ftr')
    # scatter_line_f.tight_layout(); scatter_line_f.savefig('./plots/ms_mr_diff_scatter_line.jpg', dpi=300, bbox_inches='tight', pad_inches=0.03)

    avg_ms, avg_mr, avg_adv_ms, avg_adv_mr = [100*np.average(x) for x in [clean_ms_s, clean_mr_s, adv_ms_s, adv_mr_s]]
    print('{:.2f}&{:.2f}'.format(avg_ms, avg_adv_ms))
    print('{:.2f}&{:.2f}'.format(avg_mr, avg_adv_mr))
    print('{:.2f}&{:.2f}'.format(avg_ms-avg_mr, avg_adv_ms-avg_adv_mr))

    fourpanel_f.tight_layout()
    fourpanel_f.savefig('./plots/ms_mr_fourpanel.jpg', dpi=300)

    # Release all four figures; pyplot otherwise keeps them alive for the whole process.
    for fig in (f, bar_f, scatter_line_f, fourpanel_f):
        plt.close(fig)

def robustness_undefended_ms_mr():
    """Plot Mixed-Same and Mixed-Rand accuracy of one model over l2 test budgets.

    Reads the 'resnet18_l2_eps1' entry of the IN-9 evaluation cache and
    writes the two curves to 'test.png'.
    NOTE(review): the function name says "undefended" but the model key is
    the eps=1 one -- confirm which model is intended.
    """
    results = load_cached_results('../spurious/results/in9_eval.pkl')
    test_budgets = [0, 0.25, 0.5, 1, 3, 5]
    fig, axis = plt.subplots(1, 1)

    accs_by_split = {
        split: [results['resnet18_l2_eps1'][f'mixed_{split}']['l2'][budget] for budget in test_budgets]
        for split in ('same', 'rand')
    }

    for split, legend_name in (('same', 'Mixed-Same'), ('rand', 'Mixed-Rand')):
        axis.plot(test_budgets, accs_by_split[split], '-*', label=legend_name)
    axis.legend()
    fig.savefig('test.png')

def simple_bar(ax, results, extract_val, arch_keys, arch_names, ylabel):
    """Draw a pair of bars per backbone on ``ax``.

    For each backbone the gray bar is the value for the '_l2_eps0' model and
    the coral bar the value for the '_l2_eps3' model, as returned by
    ``summarize_val``. Backbones are placed at x = 1, 2, ... and ticked with
    ``arch_names``. ``ylabel`` is set on the axis and also printed; if it
    contains 'Background', ``summarize_val`` is called with ``in9=True``.
    """
    ax.set_ylabel(ylabel)
    print(ylabel)

    uses_in9_keys = 'Background' in ylabel
    tick_positions = []
    tick_names = []
    for position, (arch_key, arch_name) in enumerate(zip(arch_keys, arch_names), start=1):
        no_at_val, l2_val = summarize_val(results, extract_val, arch_key, in9=uses_in9_keys)
        ax.bar(position - 0.15, no_at_val, width=0.3, color='gray')
        ax.bar(position + 0.15, l2_val, width=0.3, color='coral')
        tick_positions.append(position)
        tick_names.append(arch_name)

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_names)

def summarize_val(results, extract_val, mkey, in9=False):
    """Return ``(no_at_val, l2_val)`` for one backbone.

    The two values come from applying ``extract_val`` to the entries
    ``{prefix}{mkey}_l2_eps0{suffix}`` and ``{prefix}{mkey}_l2_eps3{suffix}``
    of ``results``. The prefix is 'robust_' when the last key of ``results``
    contains 'robust' (otherwise empty); the suffix is '.ckpt' when ``in9``
    is true (otherwise empty).
    """
    last_key = list(results.keys())[-1]
    prefix = 'robust_' if 'robust' in last_key else ''
    suffix = '.ckpt' if in9 else ''

    def lookup(eps_tag):
        return extract_val(results[prefix + mkey + eps_tag + suffix])

    no_at_val = lookup('_l2_eps0')
    l2_val = lookup('_l2_eps3')
    return no_at_val, l2_val

def extra_backbones_all_benchmarks():
    '''
    For the supplementary, generate 'figure 1' type plots for five extra
    backbones on all benchmarks: a total of 5 bar plots spread over three
    figures, plus a standalone legend figure.

    Files written under plots/:
      extra_backbones_rfs_rcs.jpg  (RIVAL10 RFS, Salient ImageNet-1M RCS)
      extra_backbones_gaps.jpg     (IN-9 background gap, Waterbirds gap)
      extra_backbones_ood.jpg      (ObjectNet : ImageNet accuracy ratio)
      extra_backbones_legend.jpg   (legend only)

    Side effect: switches the global matplotlib style to 'ggplot'.
    '''
    plt.style.use('ggplot')
    backbones = ['mobilenet', 'shufflenet', 'vgg16_bn', 'densenet', 'resnext50_32x4d']
    pretty_names = ['MobileNetv2', 'ShuffleNet', 'VGG16 (bn)', 'DenseNet161', 'ResNext50 (32x4d)']

    f, axs = plt.subplots(2, 1, figsize=(12, 6.4))
    # ax1 : rfs
    results = load_cached_results('../attempt2/l2_vs_linf/rival10_processed_results2.pkl')
    extract_val = lambda x: x['rca']
    simple_bar(axs[0], results, extract_val, backbones, pretty_names, '$RFS$ (RIVAL10)')

    # ax2 : rcs (same 'rca' extractor as above)
    results = load_cached_results('../attempt2/l2_vs_linf/results.pkl')
    simple_bar(axs[1], results, extract_val, backbones, pretty_names, '$RCS$ (Salient ImageNet-1M)')
    f.savefig('plots/extra_backbones_rfs_rcs.jpg', dpi=300, bbox_inches='tight', pad_inches=0.5)
    plt.close(f)

    f, axs = plt.subplots(2, 1, figsize=(12, 6.4))
    # ax3 : bg gap
    results = load_cached_results('./results/model_eval.pkl')
    extract_val = lambda x: (x['mixed_same'] - x['mixed_rand'])
    simple_bar(axs[0], results, extract_val, backbones, pretty_names, 'Background Gap (IN-9)')

    # ax4: waterbirds
    results = load_cached_results('../spurious/results/waterbirds_eval_best_val_saved2.pkl')
    extract_val = lambda x: (x['majority'] - x['minority'])
    simple_bar(axs[1], results, extract_val, backbones, pretty_names, 'Waterbirds Gap')
    f.savefig('plots/extra_backbones_gaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0.5)
    plt.close(f)

    f, axs = plt.subplots(1, 1, figsize=(12, 3.2))
    # ax5: objectnet, expressed as a ratio to the cached ImageNet accuracy
    objectnet_accs = load_cached_results('./results/objectnet_no_norm.pkl')
    l2_robust_accs = load_cached_results('../spurious/results/l2_robust_accs_inet.pkl')
    # The 0.01 factor rescales the ObjectNet value before dividing.
    # NOTE(review): presumably ObjectNet is stored in percent and the
    # ImageNet accuracy as a fraction -- confirm.
    ood_id_ratio = {k: 0.01 * objectnet_accs[k].cpu().item() / l2_robust_accs[k][0] for k in objectnet_accs}
    extract_val = lambda x: x
    simple_bar(axs, ood_id_ratio, extract_val, backbones, pretty_names, 'ObjectNet : ImageNet Acc')
    f.savefig('plots/extra_backbones_ood.jpg', dpi=300, bbox_inches='tight', pad_inches=0.5)
    plt.close(f)

    # Legend-only figure: the axes are hidden and the tight bounding box
    # shrinks the saved image to the legend itself.
    f, ax = plt.subplots(1, 1, figsize=(0, 0))
    ax.set_axis_off()
    # Raw string: '\e' is an invalid escape sequence in a normal literal.
    handles = [Patch(color=c, label=l)
               for c, l in zip(['gray', 'coral'], ['No AT', r'$\ell_2$ AT ($\epsilon=3$)'])]
    l = ax.legend(handles=handles, bbox_to_anchor=(0.5, -0.2), loc='upper center', ncol=2)
    # savefig's parameter is `bbox_extra_artists`; `extra_artists` is not a
    # savefig option and is rejected by current matplotlib backends.
    f.savefig('./plots/extra_backbones_legend.jpg', dpi=300, bbox_extra_artists=[l],
              bbox_inches='tight', pad_inches=0.01)
    plt.close(f)

# Script entry point. Each group below regenerates one family of plots;
# uncomment the calls you want to run. Only the extra-backbones plots are
# currently enabled.
if __name__ == '__main__':
    
    # import seaborn as sns
    # sns.set_theme()
    
    ### simple bar plots (one metric per figure)
    # plot_results_bar('only_bg_t')
    # plot_results_bar('BG Gap')
    # plot_results_bar('Relative BG Gap')
    
    ### line segment plots -- for demonstrating gaps between two metrics
    # line_segment_plot()
    # line_segment_plot('original', 'only_bg_t', title='Removing Foregrounds')
    # line_segment_plot('majority', 'minority', 'Validation Accuracy', 'Waterbirds', '../spurious/results/waterbirds_eval_best_val_saved2.pkl')

    ### scatter plots b/w adversarial robustness and distributional robustness
    # avg = lambda d : np.average([100.* (1-v.item()) for v in d.values()])
    # scatter_vs_robust_acc('./results/imagenet_c3.pkl', extract_val=avg, ylabel='ImageNet C')

    # scatter_vs_robust_acc('./results/objectnet_no_norm.pkl', extract_val = lambda x: x.cpu().item(), ylabel='ObjectNet')

    # both_ood_accs_vs_id_acc()

    ### reverse hypothesis line plots
    # in9_adv_robustness()
    # simpler_in9_robustness_plot()
    # robustness_undefended_ms_mr()

    ### extra backbones (supplementary bar plots and standalone legend)
    extra_backbones_all_benchmarks()