import argparse
import torch
import numpy as np

from src.evaluation import test_clean, test_adv, test_transfer_adv, test_simba
from src.utils_dataset import load_dataset, load_svhn
from src.utils_general import seed_everything, get_model
from src.attacks import pgd_rand
from src.model import ResNet10


# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser(
    description="Evaluate a saved ResNet10 state dict with one of several test routines."
)
# --method selects which evaluation branch runs. Restricting it to the four
# handled values (and making it mandatory) turns a typo or an omitted flag into
# an immediate usage error instead of a run that silently evaluates nothing.
parser.add_argument(
    "--method",
    required=True,
    choices=["standard", "pgd", "transfer", "simba"],
    help="evaluation routine: standard (test_clean), pgd (test_adv), "
         "transfer (test_transfer_adv) or simba (test_simba)",
)
# Only the literal value "svhn" is special-cased; anything else, including
# leaving the flag out, goes through load_dataset.
parser.add_argument(
    "--dataset",
    help='"svhn" loads data with load_svhn; any other value (or no value) uses load_dataset',
)
parser.add_argument(
    "--pgd_step",
    type=int,
    help="value passed to the attack as num_iter (used by --method pgd and transfer)",
)
# The checkpoint is loaded unconditionally, so the path is mandatory.
parser.add_argument(
    "--path",
    required=True,
    help="path to the saved state dict of the model being evaluated",
)
parser.add_argument(
    "--transfer_path",
    help="path to the saved state dict of the second model (used by --method transfer)",
)
args = parser.parse_args()

# Call the project's seeding helper with a fixed seed of 1 (presumably so that
# repeated evaluations give the same numbers -- confirm in seed_everything).
seed_everything(1)

# Use the first CUDA device when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network and restore its weights from the checkpoint given by --path.
model = ResNet10()
# NOTE(security): torch.load unpickles the file, so --path must only ever point
# at a trusted checkpoint.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be restored on a CPU-only machine.
model.load_state_dict(torch.load(args.path, map_location=device))
model.to(device)
# Put the model in evaluation mode before any accuracy is measured, so layers
# such as BatchNorm/Dropout use their inference behaviour. The test helpers may
# already do this internally; if so this call changes nothing.
model.eval()


# Run the evaluation selected by --method and print the resulting accuracy.

# Reject anything that is not one of the handled methods up front; otherwise
# the script would load the model and then exit without evaluating anything.
METHODS = ("standard", "pgd", "transfer", "simba")
if args.method not in METHODS:
    parser.error(
        "--method must be one of {0} (got {1!r})".format(", ".join(METHODS), args.method)
    )

# Both PGD-based methods forward --pgd_step as num_iter; without the flag the
# attack would be handed num_iter=None.
uses_pgd = args.method in ("pgd", "transfer")
if uses_pgd and args.pgd_step is None:
    parser.error("--pgd_step is required when --method is pgd or transfer")
if args.method == "transfer" and args.transfer_path is None:
    parser.error("--transfer_path is required when --method is transfer")

# The simba evaluation loads data with batch size 1; every other method uses 128.
batch_size = 1 if args.method == "simba" else 128

# Only the test split is evaluated, so the training loader is discarded.
if args.dataset == "svhn":
    _, test_loader = load_svhn(batch_size)
else:
    _, test_loader = load_dataset(batch_size)

if uses_pgd:
    # Attack settings shared by the pgd and transfer evaluations: infinity-norm
    # (ord=np.inf), epsilon 8/255, alpha 2/255, restart 1, with the iteration
    # count taken from --pgd_step.
    attack_param = {
        "ord": np.inf,
        "epsilon": 8./255.,
        "alpha": 2./255.,
        "num_iter": args.pgd_step,
        "restart": 1,
    }

if args.method == "standard":
    eval_log = test_clean(test_loader, model, device)
elif args.method == "pgd":
    eval_log = test_adv(test_loader, model, pgd_rand, attack_param, device)
elif args.method == "transfer":
    # Second ResNet10 restored from --transfer_path.
    # NOTE(security): torch.load unpickles the file; use trusted checkpoints only.
    transfer_model = ResNet10()
    transfer_model.load_state_dict(torch.load(args.transfer_path, map_location=device))
    transfer_model.to(device)
    # Evaluation mode for the second model as well (no-op if the helper does it).
    transfer_model.eval()
    # transfer_model is passed first and the evaluated model second; presumably
    # the adversarial examples are crafted on the former and scored on the
    # latter -- confirm against test_transfer_adv.
    eval_log = test_transfer_adv(
        test_loader, transfer_model, model, pgd_rand, attack_param, device
    )
else:  # "simba" -- the only remaining value after the check above
    eval_log = test_simba(test_loader, model, device)

# Every test routine's result is indexed at 0 to obtain the accuracy that is printed.
print("Accuracy: {0:.2f}%".format(eval_log[0]))