import os 
import argparse
import json
import torch
import numpy as np
import random
import pprint 
from utils.loader import load_model, robustbench_weight_download, attack_loader,  data_loader
from utils.compute_accuracy import compute_class_acc, benchmark
from utils.extract_shap import compute_shap
# from robustbench.eval import benchmark
from torchvision import transforms

# Project directory: the current working directory at import time. The results
# folder and the pretrained-weights folder are both built from this path.
PROJECT = os.getcwd()
''' Robustbench Model
CIFAR10, Linf, eps=8/255
[WRN28-10] Wang2020Improving(ICLR20;mart), Wu2020Adversarial_extra(NIPS20;awp), Wang2023Better_WRN-28-10(ICML23), Xu2023Exploring_WRN-28-10(ICLR23),
           Pang2022Robustness_WRN28_10(ICML22), Rade2021Helper_ddpm(ICLR22), Sridhar2021Robust(ACC22), Zhang2020Geometry(ICLR21), 
           Gowal2021Improving_28_10_ddpm_100m(NIPS21), Sehwag2020Hydra(NIPS20), Carmon2019Unlabeled(NIPS19),  
           
[WRN34-10] Zhang2019Theoretically(ICLR19;trades), Rade2021Helper_extra(ICLR22;hat), Chen2024Data_WRN_34_10(Pattern Recognition 24), 
           Sehwag2021Proxy(ICLR22), Addepalli2021Towards_WRN34(ECCV22), Addepalli2022Efficient_WRN_34_10(NIPS22),
           Cui2020Learnable_34_10(ICCV21), Zhang2020Attacks(ICML20), Huang2020Self(NIPS20), Wu2020Adversarial(NIPS20), Zhang2019You(NIPS19)
           
CIFAR100
[WRN28-10] Cui2023Decoupled_WRN-28-10(arXiv23), Wang2023Better_WRN-28-10(ICML23), Rebuffi2021Fixing_28_10_cutmix_ddpm(arXiv21), Pang2022Robustness_WRN28_10(ICML22), 
           [x]Hendrycks2019Using(ICML19)
[WRN34-10] Cui2023Decoupled_WRN-34-10_autoaug(arXiv23), Addepalli2022Efficient_WRN_34_10(NIPS22), Cui2023Decoupled_WRN-34-10(arXiv23),
           Cui2020Learnable_34_10_LBGAT9_eps_8_255(ICCV21), Sehwag2021Proxy(ICLR22), Jia2022LAS-AT_34_10(arXiv22), Chen2021LTD_WRN34_10(arXiv21), Addepalli2021Towards_WRN34(ECCV22), Cui2020Learnable_34_10_LBGAT6(ICCV21)

ImageNet, Linf, eps=4/255
[ResNet50] Salman2020Do_R50, Engstrom2019Robustness, Wong2020Fast
[Trans]    Liu2023Comprehensive_Swin-L
           Liu2023Comprehensive_Swin-B
           Mo2022When_Swin-B
           Mo2022When_ViT-B
'''
def parse_arguments():
    """Parse the command-line options of the robustness evaluation script.

    Returns:
        argparse.Namespace with the attributes ``network``, ``method``,
        ``threat_model``, ``batch_size``, ``steps``, ``seed``, ``dataset``,
        ``dataset_split`` and ``mode`` (``mode`` is still the raw
        comma-separated string at this point).
    """
    cli = argparse.ArgumentParser(description='Arguments for Robustness')

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps its layout.
    option_specs = (
        ('--network', dict(default='resnet50', type=str,
                           help='Cifar10/100: [wrn28-10, wrn34-10], ImageNet: [resnet50, trans]')),
        ('--method', dict(default='Salman2020Do_R50', type=str,
                          help='baseline, mixup, cutmix, robustbench method names ...')),
        ('--threat_model', dict(default='Linf', type=str,
                                help='None, Linf, L2 ...(robustbench weight)')),
        ('--batch_size', dict(default=64, type=int)),
        ('--steps', dict(default=10, type=int, help='adv. steps')),
        ('--seed', dict(default=0, type=int, help='random seed')),
        ('--dataset', dict(type=str, default='imagenet',
                           help='cifar10, cifar100, imagenet')),
        ('--dataset_split', dict(type=str, default='val',
                                 help='imagenet = [train, val]')),
        ('--mode', dict(type=str, default='robust_acc,attack_acc,shap',
                        help='robust_acc, attack_acc, shap')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def initial_setting(args):
    """Fill in the settings derived from the parsed command-line arguments.

    The namespace is modified in place and also returned. Attributes set here:
    ``mode`` (turned into a list), ``result_dir``, ``device``, ``data_root``,
    ``num_classes``, ``eps``, ``n_ex``, ``pth_dir`` and ``weight_path``.

    Args:
        args: Namespace providing ``mode`` (comma-separated string),
            ``dataset``, ``network``, ``method`` and ``threat_model``.

    Returns:
        The same namespace, with the attributes above added.

    Raises:
        ValueError: if ``args.dataset`` is not 'imagenet', 'cifar10' or
            'cifar100'. Raised before anything on ``args`` is modified.
        AssertionError: if the resolved weight file does not exist (after the
            download attempt for non-augmentation methods).
    """
    # Per-dataset constants:
    # (primary root, fallback root, num_classes, eps, n_ex).
    cifar_roots = ('/local_datasets/', '/data2/local_datasets/')
    dataset_table = {
        'imagenet': ('/local_datasets/ILSVRC2012/',
                     '/data2/local_datasets/ILSVRC2012/',
                     1000, 4/255, 5000),
        'cifar10': cifar_roots + (10, 8/255, 10000),   # 10000 = number of cifar validation set
        'cifar100': cifar_roots + (100, 8/255, 10000),
    }

    # Fail early with a clear message; otherwise the dataset-specific
    # attributes would simply be missing and the run would break later with
    # an unrelated AttributeError.
    if args.dataset not in dataset_table:
        raise ValueError(
            f'Unsupported dataset {args.dataset!r}; '
            f'expected one of {sorted(dataset_table)}'
        )

    # mode setting: 'a, b,c' -> ['a', 'b', 'c']
    args.mode = args.mode.replace(' ', '').split(',')
    # Result save directory
    args.result_dir = f'{PROJECT}/results'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Dataset path, num of classes, eps
    primary_root, fallback_root, num_classes, eps, n_ex = dataset_table[args.dataset]
    # The fallback is used whenever the primary root is absent; it is not
    # itself checked for existence here.
    args.data_root = primary_root if os.path.exists(primary_root) else fallback_root
    args.num_classes = num_classes
    args.eps = eps
    args.n_ex = n_ex

    # Model weight path
    args.pth_dir = f'{PROJECT}/pretrained_weights/{args.network}'
    if args.method in ['baseline', 'mixup', 'cutmix']:
        args.weight_path = f'{args.pth_dir}/{args.dataset}/Aug/{args.method}_model_best.pth.tar'
    else:
        args.weight_path = f'{args.pth_dir}/{args.dataset}/{args.threat_model}/{args.method}.pt'
        # Fetch the checkpoint only when it is not already on disk.
        if not os.path.exists(args.weight_path):
            robustbench_weight_download(args.method, args.network, args.dataset, args.threat_model, args.pth_dir)
    assert os.path.exists(args.weight_path), f'Not Found {args.weight_path}'
    return args

def main(args):
    """Run the requested evaluations for one pretrained model.

    Steps, in order:
      1. Complete ``args`` through ``initial_setting`` and print the result.
      2. Load the model, move it to ``args.device`` and put it in eval mode.
      3. If ``'robust_acc'`` is in ``args.mode``: call ``benchmark`` and store
         the clean / robust accuracy in ``robustbench_acc.json``.
      4. If ``'attack_acc'`` is in ``args.mode``: for each attack in
         ``no, fgsm, pgd, cw`` call ``compute_class_acc`` and, when ``'shap'``
         is also requested, call ``compute_shap`` for the ``'no'`` attack only.

    Every result file acts as a cache: when it already exists the
    corresponding computation is skipped.

    NOTE: SHAP extraction lives inside the ``attack_acc`` branch, so passing
    ``--mode shap`` on its own performs no work.
    NOTE(review): ``args.seed`` is parsed but not applied anywhere in this
    function -- confirm whether the helper modules seed the RNGs themselves.

    Args:
        args: Namespace produced by ``parse_arguments``. It is modified in
            place (``attack``, ``acc_dir``, ``acc_file``, ``shap_save_dir``,
            ``shap_save_file`` are added; for ImageNet ``data_root`` gets the
            split appended when ``attack_acc`` runs).
    """
    # Initial setting: dataset path, num of classes etc.
    args = initial_setting(args)
    print('Arguments: \n', pprint.pformat(vars(args)))

    # Load model
    device = torch.device(args.device)
    model = load_model(model_name=args.method, model_dir=args.pth_dir,
                       dataset=args.dataset, threat_model=args.threat_model,
                       num_classes=args.num_classes)
    # Honour the detected device rather than hard-coding CUDA, so the script
    # does not crash on CPU-only hosts.
    model = model.to(device)
    model.eval()

    if args.dataset == 'imagenet':
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
            ])
    elif args.dataset in ['cifar10', 'cifar100']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            ])

    if 'robust_acc' in args.mode:
        # Robustbench benchmark: clean acc, robust acc(=auto attack)
        robust_acc_dir = f'{args.result_dir}/class_acc/{args.network}/{args.dataset}/{args.method}_{args.dataset_split}'
        robust_acc_file = f'{robust_acc_dir}/robustbench_acc.json'
        os.makedirs(robust_acc_dir, exist_ok=True)

        if not os.path.exists(robust_acc_file):
            clean_acc, robust_acc = benchmark(model, model_name=args.method, n_examples=args.n_ex, dataset=args.dataset,
                                threat_model=args.threat_model, eps=args.eps, device=device, to_disk=True,
                                data_dir=args.data_root, preprocessing=transform)
            robust_acc_result = {'clean': clean_acc, 'robust': robust_acc}
            print(f'Clean acc: {clean_acc}, Robust acc: {robust_acc}')
            with open(robust_acc_file, 'w') as f:
                json.dump(robust_acc_result, f, indent=4)

    if 'attack_acc' in args.mode:
        want_shap = 'shap' in args.mode
        if args.dataset == 'imagenet':
            # The attack/SHAP loaders read from the split sub-directory.
            args.data_root = f'{args.data_root}/{args.dataset_split}'
        for attack in ['no', 'fgsm', 'pgd', 'cw']:
            args.attack = attack
            print(f'Attack: {args.attack}')

            ### Prepare Datasets
            _, loader = data_loader(args.dataset, args.data_root, batch_size=args.batch_size, split=args.dataset_split, mode=args.method, network=args.network, att=args.attack)
            shap_loader = None
            att_loader = None
            if attack == 'no':
                if want_shap:
                    # Shap loader (batch size 1); only built when SHAP
                    # extraction was actually requested.
                    _, shap_loader = data_loader(args.dataset, args.data_root, batch_size=1, split=args.dataset_split, mode=args.method, network=args.network, att=args.attack)
            else:
                att_loader = attack_loader(args, model)

            # 1. Model Accuracy Evaluation - attack (no, fgsm, pgd, cw, bim)
            args.acc_dir = f'{args.result_dir}/class_acc/{args.network}/{args.dataset}/{args.method}_{args.dataset_split}'
            os.makedirs(args.acc_dir, exist_ok=True)
            args.acc_file = f'{args.acc_dir}/attack_{args.attack}.json'
            if not os.path.exists(args.acc_file):
                compute_class_acc(model=model, attack=args.attack, save_path=args.acc_file, nat_loader=loader, att_loader=att_loader, eps=args.eps, num_classes=args.num_classes)
                print(f'[Done] 1. Model Accuracy Evaluation - Save {args.acc_file}')
            else:
                print(f'[Done] 1. Model Accuracy Evaluation - Load {args.acc_file}')

            # 2. Extract Shap - only for the un-attacked ('no') pass
            if want_shap and attack == 'no':
                args.shap_save_dir = f'{args.result_dir}/shap/{args.network}/{args.dataset}/{args.method}_{args.dataset_split}'
                os.makedirs(args.shap_save_dir, exist_ok=True)
                args.shap_save_file = f'{args.shap_save_dir}/attack_{args.attack}.pkl'
                if not os.path.exists(args.shap_save_file):
                    compute_shap(model=model, network=args.network, attack=args.attack, save_path=args.shap_save_file, nat_loader=shap_loader, att_loader=att_loader, eps=args.eps, num_classes=args.num_classes, batch_size=1)
                    print(f'[Done] 2. Extract Shap - Save {args.shap_save_file}')
                else:
                    print(f'[Done] 2. Extract Shap - Load {args.shap_save_file}')

    print('Done')

if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the evaluation.
    main(parse_arguments())