from core_accuracy import *
from my_utils import load_robust_resnet, rel_score
import pickle 
from matplotlib.patches import Patch
# Plan: load multiple L2 and Linf adversarially trained models,
# evaluate their core accuracy and spurious accuracy,
# then plot core vs. spurious accuracy, with marker shape/color determining the norm and marker size determining epsilon.

# Adversarial-training budgets plotted by plot_results.
# Linf budgets are in units of 1/255 (the line-plot axis labels them as eps/255).
# An L2 budget of 0 is the non-adversarially-trained baseline ("No AT").
# Keep the exact int/float spelling: values are formatted into result keys
# such as 'robust_<arch>_linf_eps1.0' / 'robust_<arch>_l2_eps1'.
linf_eps = [0.5, 1.0, 2.0, 4.0, 8.0]
l2_eps = [0, 0.25, 0.5, 1, 3, 5]
def plot_results(metric='rcs'):
    # arches = ['resnet50', 'resnet18']
    arches = ['wide_resnet50_2']#, 'resnet18']
    colors = ['coral', 'deepskyblue']
    line_styles = ['-','--']
    if metric == 'rcs':
        results_path = './l2_vs_linf/results.pkl'
        dset, core_or_fg, c_or_f = 'Salient ImageNet-1M', 'Core', 'C'
    else:
        results_path = './l2_vs_linf/rival10_processed_results2.pkl'
        dset, core_or_fg, c_or_f = 'RIVAL10', 'Foreground', 'F'
    with open(results_path, 'rb') as f:
        results = pickle.load(f)

    if metric == 'rcs':
        results['robust_resnet18_l2_eps0'] = dict({'rca':0.4169, 'core': 0.78, 'spur':0.48})
        results['robust_resnet50_l2_eps0'] = dict({'rca':0.4639, 'core': 0.84, 'spur':0.57})
    # l2_to_linf_eps_ratio = 1/255
    f, axs = plt.subplots(1,2, figsize=(7,3.8))
    f_bar, axs_bar = plt.subplots(1,2, figsize=(7,3))
    f_line, ax_line = plt.subplots(1,1, figsize=(6,5.5))
    # for norm, epsilon, marker in zip([['linf', 'l2'], [linf_eps, l2_eps], ['o','*']]):
    for ax, ax_bar, arch, ls in zip(axs, axs_bar, arches, line_styles):
        prefix = 'robust_'# if 'wide' not in arch else ''
        handles = []
        ax.set_xlim([0,1])
        ax.set_ylim([0,1])
        xs = np.linspace(0,1,100)
        ax.plot(xs, xs, '-')
        ax.set_xlabel('Spurious Accuracy')
        ax.set_ylabel('Core Accuracy')
        l2_rcas, linf_rcas = [], []
        line_xs = np.arange(len(l2_eps))

        print(results.keys())

        for i,eps in enumerate(l2_eps):
            results_dict = results['{}{}_l2_eps{}'.format(prefix, arch, eps)]
            core_acc, spur_acc = [results_dict[x]/100 for x in ['core', 'spur']]
            handle = ax.scatter(spur_acc, core_acc, marker='s', s=20*eps, color=colors[0], 
                                edgecolors='black',label='$L_2$ AT')
            ax_bar.bar(i, results_dict['rca'], color=colors[0])
            l2_rcas.append(results_dict['rca'])

        for i, eps in enumerate(linf_eps):
            results_dict = results['{}{}_linf_eps{}'.format(prefix, arch, eps)]
            core_acc, spur_acc = [results_dict[x]/100 for x in ['core', 'spur']]
            handle = ax.scatter(spur_acc, core_acc, marker='^', s=20*eps, color=colors[1], 
                                edgecolors='black', label='$L_\infty$ AT')
            ax_bar.bar(1+len(l2_eps)+i, results_dict['rca'], color=colors[1])
            linf_rcas.append(results_dict['rca'])
        ax_line.plot(line_xs, [l2_rcas[0]]+linf_rcas, marker='*', ls=ls, color='deepskyblue', label='$\ell_\infty$ AT WideResNet50s')#.format(arch[-2:]))
        handles.append(handle)
        ax_line.plot(line_xs, l2_rcas, marker='o', ls=ls, color='coral', label='$\ell_2$ AT WidResNet50s')#.format(arch[-2:]))
        handles.append(handle)

        ax.set_title('Robust ResNet{}s'.format(arch[-2:]))
        ax_bar.set_title('Robust ResNet{}s'.format(arch[-2:]))
        ax.legend(handles=handles, loc='lower right')

        ax_bar.set_xticks([i for i in range(len(l2_eps))]+[1+len(l2_eps)+i for i in range(len(linf_eps))])
        ax_bar.set_xticklabels(['$\epsilon={}$'.format(eps) for eps in (l2_eps+linf_eps)], rotation='vertical')
        ax_bar.legend(handles=[Patch(facecolor=c, label=l) for (c,l) in zip(colors, ['$L_2$ AT', '$L_\infty$ AT'])],
                    loc='lower right')
        ax_bar.set_ylabel('$RCS$')
        
        
        ax_line2 = ax_line.twiny()
        _ = [a.set_xticks(line_xs) for a in [ax_line, ax_line2]]
        _ = [a.set_xlim(-0.25, -0.75+len(line_xs)) for a in [ax_line, ax_line2]]
        ax_line.set_xticklabels(['No AT' if eps ==0 else f'$\ell_2, \epsilon={eps}$' for eps in l2_eps], rotation=45, fontsize=9)
        ax_line2.set_xticklabels(['No AT' if eps ==0 else f'$\ell_\infty, \epsilon={eps}/255$' for eps in ([0]+linf_eps)], rotation=45, fontsize=9)
        ax_line.set_xlabel('Adversarial Training Attack Budget', fontsize=12)
        ax_line.set_ylabel('Relative {} Sensitivity ($R{}S$)'.format(core_or_fg, c_or_f), fontsize=12)
        ax_line.set_title(dset, fontsize=13)
        ax_line.legend()
    # f.tight_layout()
    # f.savefig('l2_vs_linf/scatter.jpg', dpi=200)

    # f_bar.tight_layout()
    # f_bar.savefig('l2_vs_linf/bar.jpg', dpi=200)

    f_line.tight_layout()
    f_line.savefig(f'l2_vs_linf/{metric}_line_wide.jpg', dpi=300, bbox_inches='tight', pad_inches=0.03)

    return ax_line

# plot_results()
def obtain_results():
    """Evaluate adversarially trained models and cache their core/spurious accuracy.

    For each (arch, norm, eps) configured below:
      * the model is loaded with ``load_robust_resnet``;
      * it is scored with ``core_spur_accuracy`` (noise_sigma=0.25,
        apply_norm=True);
      * the relative score is ``rel_score`` applied to the accuracies divided
        by 100.

    Results are stored in a pickled dict at ./l2_vs_linf/results3.pkl, keyed
    by model name. The file is rewritten after every new evaluation. Models
    already present in the cache are not re-evaluated; their numbers are only
    printed.
    """
    import os  # explicit, rather than relying on the star import of core_accuracy

    results_path = './l2_vs_linf/results3.pkl'
    linf_eps = []  # [0.5, 1.0, 2.0, 4.0, 8.0]
    l2_eps = [0, 3]  # [0, 0.25, 0.5, 1, 3, 5]
    # Other architectures run previously:
    # ['resnext50_32x4d', 'mobilenet', 'densenet', 'shufflenet', 'vgg16_bn',
    #  'wide_resnet101_2', 'resnet18', 'resnet50']
    arches = ['wide_resnet50_2']

    for arch in arches:
        for norm, eps_list in zip(['l2', 'linf'], [l2_eps, linf_eps]):
            # Start each (arch, norm) group from the latest on-disk cache.
            if os.path.exists(results_path):
                with open(results_path, 'rb') as fh:
                    d = pickle.load(fh)
            else:
                d = dict()

            for eps in eps_list:
                model_name = '{}_{}_eps{}'.format(arch, norm, eps)
                # Only names containing 'resnet' carry the 'robust_' prefix.
                if 'resnet' in model_name:
                    model_name = 'robust_' + model_name
                if model_name not in d:
                    model = load_robust_resnet(model_name)
                    core_acc, spur_acc, core_acc_by_class, spur_acc_by_class = core_spur_accuracy(
                        model, noise_sigma=0.25, apply_norm=True)
                    rca = rel_score(core_acc / 100, spur_acc / 100)
                    d[model_name] = dict({
                        'core': core_acc, 'spur': spur_acc,
                        'core_by_class': core_acc_by_class, 'spur_by_class': spur_acc_by_class,
                        'rca': rca})
                    # Persist right away so an interrupted run keeps its progress;
                    # make sure the output directory exists before writing.
                    os.makedirs(os.path.dirname(results_path), exist_ok=True)
                    with open(results_path, 'wb') as fh:
                        pickle.dump(d, fh)
                else:
                    core_acc, spur_acc = [d[model_name][x] for x in ['core', 'spur']]
                    rca = rel_score(core_acc / 100, spur_acc / 100)
                print('Model: {:<50}, Core Acc: {:.2f}, Spur Acc: {:.2f}, RCA: {:.2f}'.format(
                    model_name, core_acc, spur_acc, rca
                ))
                

def reprocess_rival10_results():
    '''Convert raw RIVAL10 noise results into the dict layout of the RCS results.

    Here we just process the raw RIVAL10 results into a new dict matching the
    RCS results, so that our plotting functionality works symmetrically.
    For every configured model:
      * 'core' is the mean of raw[0.25]['noisy_bg_accs'];
      * 'spur' is the mean of raw[0.25]['noisy_fg_accs'];
      * 'rca' is rel_score(core, spur).
    The resulting dict, keyed 'robust_{arch}_{norm}_eps{eps}', overwrites
    ./l2_vs_linf/rival10_processed_results2.pkl.
    '''
    import os  # explicit, rather than relying on the star import of core_accuracy

    linf_eps = []  # [0.5, 1.0, 2.0, 4.0, 8.0]
    l2_eps = [0, 3]  # [0, 0.25, 0.5, 1, 3, 5]
    # Previously processed: ['resnet18', 'resnet50'], ['wide_resnet50_2']
    arches = ['mobilenet', 'densenet', 'shufflenet', 'resnext50_32x4d', 'vgg16_bn']
    # Earlier raw-result directories:
    #   ../local_dcr/kiarash_noise_no_norm_in_eval_norm_in_training
    #   ../local_dcr/kiarash_noise
    raw_dir = '../local_dcr/kiarash_noise_no_norm_train_nor_eval'
    out_path = './l2_vs_linf/rival10_processed_results2.pkl'

    processed = dict()
    for arch in arches:
        for norm, epsilons in zip(['l2', 'linf'], [l2_eps, linf_eps]):
            for eps in epsilons:
                mkey = f'robust_{arch}_{norm}_eps{eps}'
                with open(f'{raw_dir}/{mkey}_linf.pkl', 'rb') as fh:
                    raw = pickle.load(fh)

                # 0.25 is the noise level key in the raw results.
                # Background-noised accuracy -> core, foreground-noised -> spurious.
                core, spur = [np.average(raw[0.25][f'noisy_{x}g_accs']) for x in ['b', 'f']]
                rfs = rel_score(core, spur)
                processed[mkey] = dict({'core': core, 'spur': spur, 'rca': rfs})
                # '{:.3f}' (was '{:>3f}', which printed 6 decimals in a width-3 field).
                print('Model : {:<20}, RFS: {:.3f}, Core: {:.2f}, Spur: {:.2f}'.format(mkey, rfs, core, spur))

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as fh:
        pickle.dump(processed, fh)

# Script entry point: only the RIVAL10 reprocessing step is currently enabled;
# uncomment the other steps as needed.
# NOTE(review): this call also runs whenever the module is imported; consider an
# `if __name__ == "__main__":` guard if it is ever imported elsewhere.
reprocess_rival10_results()
# obtain_results()
# plot_results('rcs')
# Print every cached Salient ImageNet result:
# with open('./l2_vs_linf/results.pkl', 'rb') as f:
#     results = pickle.load(f)
# for k in results:
#     core_acc, spur_acc, rcs = [results[k][x] for x in ['core', 'spur', 'rca']]
#     print('Model: {:<50}, Core Acc: {:.2f}, Spur Acc: {:.2f}, RCS: {:.2f}'.format(
#                     k, core_acc, spur_acc, rcs
#                 ))
# plot_results('rfs')
