
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from resnet18_32x32 import ResNet18_32x32

import numpy as np
import random
import os
import math

from autoattack import AutoAttack




def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Also forces deterministic cuDNN kernels and disables the cuDNN autotuner so
    repeated runs with the same seed pick the same convolution algorithms.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _predict(model, inputs, batch_size):
    """Return the argmax class index of ``model(inputs)`` for every row of ``inputs``.

    Runs without autograd and in chunks of ``batch_size`` rows, so evaluating the
    whole set neither builds a gradient graph nor needs one giant forward pass.
    """
    if inputs.shape[0] == 0:
        return torch.empty(0, dtype=torch.long, device=inputs.device)
    with torch.no_grad():
        return torch.cat([model(chunk).argmax(dim=1)
                          for chunk in torch.split(inputs, batch_size)])


def _load_labels(path, device):
    """Load integer class labels from the ``.npy`` file at ``path`` onto ``device``.

    A 2-D array (one row per sample, e.g. one-hot or per-class scores) is reduced
    with argmax along axis 1 *before* casting to int64, so fractional scores are
    not truncated to 0 first. Otherwise the array is used as class indices as-is.
    """
    raw = np.load(path)
    if raw.ndim == 2:
        raw = raw.argmax(axis=1)
    return torch.as_tensor(raw, dtype=torch.long, device=device)


def main(*, seed=100, epsilon=8/255., batch_size=128, attacks=('fab-t',)):
    """Evaluate the ResNet-18 checkpoint on saved clean data and dump AutoAttack AEs.

    Loads ``./weights/resnet18_9554.pth`` into ``ResNet18_32x32`` and prints its
    accuracy on ``./AEs/clean_inputs.npy`` / ``./AEs/clean_labels.npy``. It then
    runs each requested AutoAttack component individually (L-inf,
    ``eps=epsilon``). For each attack it saves ``<ATTACK>_AdvSamples.npy`` and
    ``<ATTACK>_AdvLabels.npy`` under ``./AEs/raw/`` and prints a success rate.

    Args:
        seed: Seed for the Python, NumPy and PyTorch RNGs.
        epsilon: L-inf perturbation budget passed to AutoAttack.
        batch_size: Batch size for the clean forward pass and for the attacks.
        attacks: AutoAttack component names to run individually, e.g.
            'apgd-ce', 'apgd-t', 'fab-t', 'square'.
    """
    # CUDA reads CUDA_VISIBLE_DEVICES once, at initialisation. It must therefore
    # be set before the first CUDA query (torch.cuda.is_available() inside
    # _seed_everything); it has no effect if an earlier import already
    # initialised CUDA. setdefault keeps a GPU selection exported by the caller.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    _seed_everything(seed)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load model and images
    model = ResNet18_32x32().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('./weights/resnet18_9554.pth', map_location=device))
    model.eval()

    # Plain tensors: torch.autograd.Variable is deprecated, and the clean pass
    # below runs under no_grad. AutoAttack is expected to enable gradients on its
    # own copies of the inputs.
    nature_samples = torch.from_numpy(np.load('./AEs/clean_inputs.npy')).to(device)
    labels_samples = _load_labels('./AEs/clean_labels.npy', device)

    pred_labels = _predict(model, nature_samples, batch_size)
    acc = (pred_labels == labels_samples).float().mean()
    print(f"Model Accuracy is {acc.item()*100:.2f}")

    adversary = AutoAttack(model, norm='Linf', eps=epsilon, version='standard', device=device)
    # Available components: apgd-ce, apgd-t, fab-t, square.
    adversary.attacks_to_run = list(attacks)
    adv_dict = adversary.run_standard_evaluation_individual(
        nature_samples, labels_samples, bs=batch_size, return_labels=True)

    out_dir = './AEs/raw'
    # np.save does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for attack_name, (adv_samples, adv_labels) in adv_dict.items():
        tag = attack_name.upper()
        np.save(f'{out_dir}/{tag}_AdvSamples.npy', adv_samples.detach().cpu().numpy())
        np.save(f'{out_dir}/{tag}_AdvLabels.npy', adv_labels.detach().cpu().numpy())

        # NOTE: this counts every sample whose adversarial prediction differs from
        # its label, including samples already misclassified on clean inputs. It
        # is therefore 100 - robust accuracy, not success on correct samples only.
        success = (labels_samples != adv_labels.to(labels_samples.device)).float().mean()
        print(f"For {tag}, Attack success rate is {success.item() * 100:.2f}")

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()