from in9_eval import *
import os
import pickle
from tqdm import tqdm
import argparse
from torchvision import models
'''
We hypothesize that robust models have greater background sensitivity.
We will now leverage IN-9 to test this hypothesis.

We will focus on Original, BG-T, Mixed-Same, and Mixed-Rand.

Other result: on the IN-9 challenge (adversarial backgrounds),
the standard resnet50 was vulnerable on 67.33% (2727/4050) of foregrounds.

robust_resnet50_l2_eps1 was vulnerable on 82.07% (3324/4050) of foregrounds.
robust_resnet50_l2_eps3 was vulnerable on 87.73% (3553/4050) of foregrounds.
robust_resnet50_l2_eps5 was vulnerable on 91.38% (3701/4050) of foregrounds.

robust_resnet50_linf_eps8.0 was vulnerable on 88.35% (3578/4050) of foregrounds.
robust_resnet50_linf_eps4.0 was vulnerable on 82.42% (3338/4050) of foregrounds.
robust_resnet50_linf_eps2.0 was vulnerable on 79.83% (3233/4050) of foregrounds.
'''
def load_cached_results(results_path='./results/model_eval.pkl'):
    """Return the results object previously pickled at ``results_path``.

    When no file exists at that path an empty dict is returned, so a fresh
    run simply starts with no cached entries.

    NOTE: this unpickles the file, so only point it at caches you produced.
    """
    if not os.path.exists(results_path):
        return {}
    with open(results_path, 'rb') as handle:
        cached = pickle.load(handle)
    return cached

def cache_results(results_path, results):
    """Pickle ``results`` to ``results_path``, replacing any previous cache.

    The parent directory is created if it is missing. The data is first
    written to ``results_path + '.tmp'`` and then moved over the real path
    with ``os.replace``, so an interrupted write cannot leave a truncated
    cache (and thereby lose every previously stored result).

    Args:
        results_path: Destination pickle file.
        results: Any picklable object (the caller passes a nested dict).
    """
    parent_dir = os.path.dirname(results_path)
    if parent_dir:
        # The results directory may not exist yet on a fresh checkout.
        os.makedirs(parent_dir, exist_ok=True)
    tmp_path = results_path + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(results, f)
        # Atomic on POSIX: readers see either the old or the new cache.
        os.replace(tmp_path, results_path)
    except BaseException:
        # Don't leave a half-written temp file behind on failure/interrupt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def get_arch(mtype):
    """Map a model/checkpoint key to the architecture spec passed to ``main``.

    Matching is done by substring on ``mtype`` (e.g.
    ``'resnet50_l2_eps3.ckpt'``, ``'wide_resnet50_2_linf_eps4.0.ckpt'``).

    Args:
        mtype: Model key / checkpoint file name.

    Returns:
        ``(arch, add_custom_forward)``. For plain ResNets, ``arch`` is the
        name prefix before the first underscore (e.g. ``'resnet50'``) and
        ``add_custom_forward`` is False. For every other supported backbone,
        ``arch`` is a torchvision model built with default constructor
        arguments and ``add_custom_forward`` is True.

    Raises:
        ValueError: if ``mtype`` matches none of the supported backbones.
    """
    # Order matters: these are substring tests and e.g. 'wide_resnet50_2'
    # also contains 'resnet', so the generic 'resnet' check must come last.
    if 'wide' in mtype:
        # '50' anywhere in the key selects WRN-50-2; otherwise WRN-101-2.
        arch = models.wide_resnet50_2() if '50' in mtype else models.wide_resnet101_2()
    elif 'mobilenet' in mtype:
        arch = models.mobilenet_v2()
    elif 'shufflenet' in mtype:
        arch = models.shufflenet_v2_x1_0()
    elif 'vgg' in mtype:
        arch = models.vgg16_bn()
    elif 'densenet' in mtype:
        arch = models.densenet161()
    elif 'resnext50' in mtype:
        arch = models.resnext50_32x4d()
    elif 'resnet' in mtype:
        # Plain ResNets are passed by name, e.g. 'resnet50_l2_eps3.ckpt' -> 'resnet50'.
        return mtype.split('_')[0], False
    else:
        # Previously this fell through to an UnboundLocalError on `arch`.
        raise ValueError('Unrecognized model type: {!r}'.format(mtype))
    return arch, True

def eval_robust_models(mkeys, eval_dsets=('mixed_same', 'mixed_rand'), *,
                       data_path='/REDACTED/data/bg_challenge',
                       root_path='/REDACTED/dcr_models/pretrained-robust/',
                       results_path='./results/model_eval.pkl',
                       skip_cached=False):
    """Evaluate each checkpoint in ``mkeys`` on each dataset in ``eval_dsets``.

    Accuracies are stored as ``results[mkey][dset]`` in the pickled dict at
    ``results_path``. The cache is rewritten after every (model, dataset)
    evaluation, so partial progress survives an interruption.

    Args:
        mkeys: Checkpoint file names relative to ``root_path`` (e.g.
            ``'resnet50_l2_eps3.ckpt'``); also used as result keys and
            passed to ``get_arch``.
        eval_dsets: Dataset names forwarded to ``main`` as
            ``args.eval_dataset``.
        data_path: Value assigned to ``args.data_path``.
        root_path: Directory containing the checkpoints.
        results_path: Pickle file used to load and store cached results.
        skip_cached: If True, skip (model, dataset) pairs already present in
            the cache instead of re-evaluating them.

    Returns:
        The results dict (also persisted at ``results_path``).
    """
    # Build the namespace directly: parse_args() on an empty parser would
    # read sys.argv and exit on any unrelated command-line argument.
    args = argparse.Namespace()
    args.data_path = data_path
    args.in9 = False

    results = load_cached_results(results_path)
    print(results.keys())

    for mkey in tqdm(mkeys):
        args.arch, args.add_custom_forward = get_arch(mkey)
        args.checkpoint = os.path.join(root_path, mkey)
        model_results = results.setdefault(mkey, dict())
        print()
        for dset in eval_dsets:
            if skip_cached and dset in model_results:
                continue
            args.eval_dataset = dset
            acc = main(args)
            model_results[dset] = acc
            # Persist after every evaluation so a crash loses at most one result.
            cache_results(results_path, results)
            print('Model: {:<20}.......Dset: {:<15}.......Acc: {:.3f}'.format(
                mkey, dset, acc
            ))
    return results
       
if __name__ == "__main__":
    # Running this script does nothing by default: each experiment suite
    # below is left commented out. Uncomment the suite you want to reproduce.
    # eval_robust_models loads and re-saves the same results pickle, so
    # suites can be run one at a time and their results accumulate.
    pass
    ### MAIN TEXT ANALYSIS: ResNet18s and ResNet50s
    # Standard (eps0) and L2/Linf adversarially trained checkpoints for each arch.
    # l2_models = [['resnet{}_l2_eps{}.ckpt'.format(arch, eps) for eps in [0, 0.25, 0.5, 1,3,5]] for arch in [18, 50]]
    # linf_models = [['resnet{}_linf_eps{}.ckpt'.format(arch, eps) for eps in [0.5, 1.0, 2.0, 4.0, 8.0]] for arch in [18, 50]]
    # eval_robust_models(l2_models[0]+linf_models[0]+l2_models[1]+linf_models[1])

    ### SUPPLEMENTARY: WideResNet50s
    # l2_models = ['wide_resnet50_2_l2_eps{}.ckpt'.format(eps) for eps in [0, 0.25, 0.5, 1,3,5]]
    # linf_models = ['wide_resnet50_2_linf_eps{}.ckpt'.format(eps) for eps in [0.5, 1.0, 2.0, 4.0, 8.0]]
    # eval_robust_models(l2_models+linf_models)

    ### SUPPLEMENTARY: Additional Backbones
    # Compares the standard (eps0) and L2 eps3 checkpoint of each backbone;
    # names must be recognized by get_arch.
    # backbones = ['shufflenet', 'mobilenet', 'vgg16_bn', 'resnext50_32x4d', 'densenet']
    # standard = [f'{x}_l2_eps0.ckpt' for x in backbones]
    # robust = [f'{x}_l2_eps3.ckpt' for x in backbones]
    # eval_robust_models(standard+robust)