import os
import argparse

from tqdm import tqdm

import numpy as np

import torch
import torch.nn as nn

import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, ImageFolder

from utils import pgd, logits_accuracy, setup_seed
from models import ResNet18, PreActResNet18, WideResNet_28_10

# When a CUDA device is present, turn on the cuDNN benchmark flag and print
# the name of the currently selected GPU.
cuda_present = torch.cuda.is_available()
if cuda_present:
    cudnn.benchmark = True
    active_gpu = torch.cuda.current_device()
    print(torch.cuda.get_device_name(active_gpu))

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="./")
parser.add_argument("--folder", type=str, default="")
parser.add_argument("-m", "--model", type=str, default="ResNet18", choices=["PreActResNet18", "ResNet18", "WideResNet-28-10"])
parser.add_argument("-w", "--weight", type=str, default="")
parser.add_argument("-d", "--dataset", type=str, default="CIFAR10", choices=["CIFAR10", "CIFAR100", "SVHN", "Tiny-ImageNet"])
parser.add_argument("-a", "--attacker", choices=["CLEAN", "FGSM", "IFGSM", "PGD", "AA"], default="CLEAN", type=str)
parser.add_argument("--eps", default=8, type=int)
# Defaults to None so that, when omitted, the class count follows --dataset
# instead of silently staying at 10 for every dataset.
parser.add_argument("--num_class", default=None, type=int)
parser.add_argument("--cuda", action="store_true", default=False)
parser.add_argument("--seed", default=0, type=int)
args = parser.parse_args()

# Class count of each supported dataset; an explicit --num_class still wins.
DATASET_NUM_CLASSES = {"CIFAR10": 10, "CIFAR100": 100, "SVHN": 10, "Tiny-ImageNet": 200}
if args.num_class is None:
    args.num_class = DATASET_NUM_CLASSES[args.dataset]

batch_size = 128
# --eps is given as an integer and rescaled by 1/255 once here; everything
# after this line reads the rescaled value from args.eps.
args.eps = args.eps/255

# Use the GPU only when it was requested with --cuda AND is actually available.
device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
setup_seed(args.seed)

# Build the architecture selected with --model. The final branch mirrors the
# parser default (ResNet18); argparse `choices` normally prevents reaching it.
if args.model == "ResNet18":
    net = ResNet18(args.num_class)
elif args.model == "PreActResNet18":
    net = PreActResNet18(args.num_class, None)
elif args.model == "WideResNet-28-10":
    net = WideResNet_28_10(args.num_class)
else:
    net = ResNet18(args.num_class)

net = net.to(device)

# Checkpoint layout: <root>/<folder>/<dataset>_<model>/<weight>.
# Each part is passed to os.path.join separately: with an empty --folder the
# old single f-string component started with "/" and therefore made
# os.path.join drop --root entirely.
weight_path = os.path.join(args.root, args.folder, f"{args.dataset}_{args.model}", args.weight)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only run.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
net.load_state_dict(torch.load(weight_path, map_location=device))
net = net.eval()

# ToTensor is the only preprocessing step applied to the evaluation images.
transform = transforms.Compose([transforms.ToTensor()])

data_path = os.path.join(args.root, "data")
# Keyword arguments shared by the torchvision datasets stored under data_path.
_shared_kwargs = dict(root=data_path, download=False, transform=transform)

# Pick the held-out split of the dataset selected with --dataset.
if args.dataset == "Tiny-ImageNet":
    # Tiny-ImageNet is read from a fixed folder rather than from data_path.
    dataset = ImageFolder(root="/content/tiny-imagenet-200/val", transform=transform)
elif args.dataset == "SVHN":
    dataset = SVHN(split="test", **_shared_kwargs)
elif args.dataset == "CIFAR100":
    dataset = CIFAR100(train=False, **_shared_kwargs)
elif args.dataset == "CIFAR10":
    dataset = CIFAR10(train=False, **_shared_kwargs)

# Fixed order (no shuffling) so repeated runs iterate the data identically.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


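# NOTE: input normalization is currently disabled. The sketch below is kept
# for reference only; it relies on a `--normalize` flag that the argument
# parser in this script does not define.
# if args.normalize:
#     normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
# else:
#     normalize = None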

from Attacker import AutoAttack

# AutoAttack adversary used when `-a AA` is selected. The budget comes from
# --eps (already rescaled by 1/255) so that AA honours the same command-line
# setting as the PGD/FGSM helpers instead of a hard-coded 8/255.
adversary = AutoAttack(net, norm="Linf", eps=args.eps, version="standard", verbose=False)
# Only these two attacks are listed to run.
adversary.attacks_to_run = ['apgd-ce', 'apgd-t']
# adversary.apgd.n_restarts = 2
# adversary.fab.n_restarts = 2

def PGD_20_10(model, images, labels):
    """Call `pgd` with 20 steps of size 1/255 and 10 restarts, bounded by `args.eps`.

    Returns whatever `pgd` returns; the caller adds it to `images`, so it is
    presumably a perturbation (confirm against `utils.pgd`).
    """
    attack_config = dict(steps=20, step_size=1/255, epsilon=args.eps, restarts=10)
    return pgd(model, images, labels, **attack_config)
def PGD10(model, images, labels):
    """Call `pgd` with 10 steps of size 2/255 and a single restart, bounded by `args.eps`.

    Returns whatever `pgd` returns; the caller adds it to `images`, so it is
    presumably a perturbation (confirm against `utils.pgd`).
    """
    attack_config = dict(steps=10, step_size=2/255, epsilon=args.eps, restarts=1)
    return pgd(model, images, labels, **attack_config)
def FGSM(model, images, labels):
    """Call `pgd` with one step whose size equals the full budget `args.eps`.

    Returns whatever `pgd` returns; the caller adds it to `images`, so it is
    presumably a perturbation (confirm against `utils.pgd`).
    """
    budget = args.eps
    return pgd(model, images, labels, steps=1, step_size=budget, epsilon=budget, restarts=1)

# Evaluate the network on the selected dataset, optionally under attack.
# `Acc` accumulates the value returned by logits_accuracy for every batch and
# is reported divided by the number of batches seen.
# NOTE(review): if logits_accuracy returns a per-batch mean, a smaller final
# batch is weighted like a full one in this average -- confirm whether that
# is intended.
Acc = 0
num_batches = 0
with tqdm(enumerate(dataloader), total=len(dataloader), desc="Eval") as t:
    for i, (images, labels) in t:
        images, labels = images.to(device), labels.to(device)
        # Crafting adversarial inputs needs gradients with respect to the
        # images, so autograd is explicitly enabled here rather than relying
        # on the attack helpers to re-enable it inside a no_grad block.
        with torch.enable_grad():
            if args.attacker == "AA":
                images = adversary.run_standard_evaluation(images, labels)
            elif args.attacker == "PGD":
                images = images + PGD_20_10(net, images, labels)
            elif args.attacker == "IFGSM":
                images = images + PGD10(net, images, labels)
            elif args.attacker == "FGSM":
                images = images + FGSM(net, images, labels)
            # "CLEAN": the images are evaluated unmodified.
        # Drop any autograd graph left over from the attack before scoring.
        images = images.detach()
        # The scoring forward pass itself needs no gradients.
        with torch.no_grad():
            Acc += logits_accuracy(net(images).detach(), labels)
        num_batches = i + 1
        t.set_postfix(Acc=f"{Acc/num_batches:4.2%}")

# Without this guard an empty loader would fail on the final print with an
# unrelated NameError/ZeroDivisionError.
if num_batches == 0:
    raise SystemExit("No evaluation batches were produced; check the dataset location.")
print(f"Robust: {Acc/num_batches:4.2%}")
