from collections import OrderedDict
import os
import argparse
import pickle
import torch
import torch.nn as nn
import torch.utils.data as data
import autoattack
import torchvision.datasets as datasets
import torchvision.transforms as transforms 
from torch.nn import Conv2d, Linear
from models.layers import SubnetConv, SubnetLinear
from models.resnet_cifar import resnet18
from models.resnet_cifar_silu import resnet18_silu
from utils.utils import measure_model_sparsity
from models.dense_resnet_cifar import resnet18_dense
from models.dense_resnet_silu_cifar import resnet18_silu_dense
from models.wrn_cifar import wrn_28_4
from models.dense_wrn_cifar import wrn_28_4_dense
from models.vgg_cifar import vgg16_bn
from models.vgg_cifar_dense import vgg16_bn_dense


# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR PGD Attack Evaluation')
parser.add_argument('--batch-size', type=int, default=500, metavar='N',
                    help='input batch size for testing (default: 500)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# type=float so that a value given on the command line has the same type as
# the default; without it argparse would hand back a str.
parser.add_argument('--epsilon', type=float, default=0.031,
                    help='perturbation')
parser.add_argument('--model-path',
                    default='./checkpoints/model_cifar_wrn.pt',
                    help='model for white-box attack evaluation')
parser.add_argument("--dataset", type=str, choices=["CIFAR10", "SVHN", "CIFAR100"], default="CIFAR10")
parser.add_argument("--model", type=str, choices=["resnet18", "resnet18_silu", "resnet18_dense", "resnet18_silu_dense", "wrn_28_4", "wrn_28_4_dense", "vgg16_bn", "vgg16_bn_dense"], default="resnet18")
parser.add_argument("--prune_reg", type=str, choices=["channel", "weight"], default="weight")
parser.add_argument('--log-path', type=str, default='./log_file.txt')
parser.add_argument('--version', type=str, default='standard')
parser.add_argument('--gpu', type=str, default='cuda:0')
parser.add_argument('--task_mode', type=str, default='pretrain')
parser.add_argument(
    "--k",
    type=float,
    default=1.0,
    help="Fraction of weight variables kept in subnet",
)


def set_prune_rate_model(model, device):
    """Push the ``--k`` keep fraction into every prunable submodule of *model*.

    Walks ``model.named_modules()`` and, for each submodule that exposes a
    ``set_prune_rate`` method, calls it with ``(args.k, args.k, 0.1, device)``.
    Submodules without that method are left untouched.

    Reads the module-level ``args`` namespace, so it must only be called
    after the command line has been parsed.

    NOTE(review): the meaning of the individual ``set_prune_rate`` arguments
    (including the constant 0.1) is defined by the layer classes, not here —
    confirm against ``models.layers``.
    """
    for _name, module in model.named_modules():
        if not hasattr(module, "set_prune_rate"):
            continue
        module.set_prune_rate(args.k, args.k, 0.1, device)


# parse arguments
args = parser.parse_args()

# settings
# --gpu may be a bare index ("0") or a full device string ("cuda:0", which is
# the parser default). Unconditionally prefixing "cuda:" turned the default
# into the invalid device string "cuda:cuda:0". --no-cuda forces the CPU.
if args.no_cuda:
    device = torch.device("cpu")
else:
    gpu_spec = str(args.gpu)
    device = torch.device(gpu_spec if gpu_spec.startswith("cuda") else f"cuda:{gpu_spec}")

# attack settings
norm = 'Linf'
# These are taken from the command line instead of being hard-coded; the
# parser defaults ('standard', 0.031, 500) equal the previous constants, so a
# default invocation is unchanged.
version = args.version
epsilon = float(args.epsilon)  # float(): the option may be delivered as a str
batch_size = args.batch_size
n_ex = 10000  # upper bound on the number of test examples that get attacked

# apply transform: images are only converted to tensors here; the
# per-dataset mean/std below are passed to the model constructors instead.
transform_list = [transforms.ToTensor()]
transform_chain = transforms.Compose(transform_list)

# data loading: pick the held-out split and the channel statistics
if args.dataset == 'CIFAR10':
    data_dir = "/.../shared_data/CIFAR10"
    item = datasets.CIFAR10(root=data_dir, train=False, transform=transform_chain, download=False)
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
elif args.dataset == 'SVHN':
    data_dir = "/.../shared_datasets/SVHN"
    item = datasets.SVHN(root=data_dir, split='test', transform=transform_chain, download=False)
    mean = [0.43090966, 0.4302428, 0.44634357]
    std = [0.19759192, 0.20029082, 0.19811132]
else:
    # Fail with an explicit message instead of a NameError on `item` below
    # when a dataset name has no loader branch in this script.
    raise NotImplementedError(
        f"No test-set loader is defined for dataset {args.dataset!r}"
    )

# create testset
test_loader = data.DataLoader(item, batch_size=batch_size, shuffle=False, num_workers=0)
# A single pass over the loader yields images and labels together; the loader
# is unshuffled, so the order matches what two separate passes would produce.
x_batches, y_batches = zip(*test_loader)
x_test = torch.cat(x_batches, 0)
y_test = torch.cat(y_batches, 0)

# Read the checkpoint, mapping its tensors onto the evaluation device.
model_checkpoint = args.model_path
ckpt = torch.load(model_checkpoint, map_location=device)

# Subnet architectures: name -> (constructor, apply --k prune rate when the
# task mode contains "finetune").
# NOTE(review): resnet18_silu and vgg16_bn never apply the prune rate, unlike
# resnet18 and wrn_28_4 — confirm this asymmetry is intended.
subnet_archs = {
    "resnet18": (resnet18, True),
    "resnet18_silu": (resnet18_silu, False),
    "vgg16_bn": (vgg16_bn, False),
    "wrn_28_4": (wrn_28_4, True),
}

# Dense architectures: name -> (constructor, extra keyword arguments,
# `strict` flag for load_state_dict). Only wrn_28_4_dense receives mean/std
# and is loaded strictly.
dense_archs = {
    "resnet18_dense": (resnet18_dense, {}, False),
    "resnet18_silu_dense": (resnet18_silu_dense, {}, False),
    "vgg16_bn_dense": (vgg16_bn_dense, {}, False),
    "wrn_28_4_dense": (wrn_28_4_dense, {"mean": mean, "std": std}, True),
}

# define model and load checkpoint
if args.model in subnet_archs:
    build_model, prune_when_finetuning = subnet_archs[args.model]
    model = build_model(
        SubnetConv,
        SubnetLinear,
        init_type='kaiming_normal',
        mean=mean,
        std=std,
        prune_reg=args.prune_reg,
        task_mode=args.task_mode,
        normalize=False,
    )
    model.load_state_dict(ckpt['state_dict'], strict=True)
    if prune_when_finetuning and "finetune" in args.task_mode:
        set_prune_rate_model(model, device)
elif args.model in dense_archs:
    build_model, extra_kwargs, strict_load = dense_archs[args.model]
    model = build_model(
        conv_layer=Conv2d,
        linear_layer=Linear,
        init_type="kaiming_normal",
        num_classes=10,
        **extra_kwargs,
    )
    model.load_state_dict(ckpt['state_dict'], strict=strict_load)

# Put the network on the evaluation device and switch it to eval mode.
model.to(device)
model.eval()

# Report sparsity, then the checkpoint path so the model can be recognised
# in the output.
sparsity_report = measure_model_sparsity(model)
print(sparsity_report)
print(args.model_path)

# Run AutoAttack's standard evaluation on the first n_ex test examples.
x_eval = x_test[:n_ex].to(device)
y_eval = y_test[:n_ex].to(device)
adversary = autoattack.AutoAttack(
    model,
    norm=norm,
    eps=epsilon,
    version=version,
    device=device,
)
adversary.run_standard_evaluation(x_eval, y_eval, bs=batch_size)
