import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from PIL import Image
import os
from os import path
from tqdm.auto import tqdm
import numpy as np
import time
from matplotlib import pyplot as plt
import sys
from pathlib import Path
from utils import *
from attacks import *

# Load the evaluation set from dataset/images.csv via utils.load_ground_truth:
# image ids, label_tar_list (presumably target labels -- not used in this
# script), and the true labels used as attack labels below.
image_id_list, label_tar_list, label_true = load_ground_truth('dataset/images.csv')
# Prefer the first GPU, but fall back to CPU so the script still runs
# (slowly) on machines without CUDA instead of crashing at model.to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the pre-trained ResNet-50 in eval mode and freeze its weights so no
# parameter gradients are tracked while the attack runs.
model = models.resnet50(pretrained=True).eval()
for param in model.parameters():
    param.requires_grad = False
model.to(device)


# Run configuration.
untargeted = 1  # NOTE(review): not read anywhere in this script
batch_size = 32
num_images = 1000
# np.int was deprecated in NumPy 1.20 and removed in 1.24; use the builtin
# int. Cap at the number of listed images so the loop is never asked for an
# empty or negative-sized batch when the CSV has fewer than num_images rows.
num_batches = int(np.ceil(min(num_images, len(image_id_list)) / batch_size))

# NOTE(review): these accumulators are never appended to in this script.
l2_norm = []
l_inf_norm = []

# Timing starts here and covers loss setup, folder creation and the batch loop.
tic = time.perf_counter()
criterion = nn.CrossEntropyLoss().to(device)
step_size = 0.1  # forwarded unchanged to PGD_attack
output_folder = "./results_resnet/PGD/"
Path(output_folder).mkdir(parents=True, exist_ok=True)
# Built once and reused for every saved image.
to_pil = transforms.ToPILImage()
for k in range(num_batches):
    start = k * batch_size
    batch_size_cur = min(batch_size, len(image_id_list) - start)
    if batch_size_cur <= 0:
        # num_batches can exceed what the image list fills; nothing left to do
        # (a negative size would otherwise crash torch.zeros below).
        break
    batch_ids = image_id_list[start:start + batch_size_cur]

    # Load a batch of input images into a preallocated (N, 3, 299, 299) buffer
    # on the device. trn comes from utils; the file handle is closed as soon
    # as trn has consumed the image.
    X_ori = torch.zeros(batch_size_cur, 3, 299, 299, device=device)
    for i, image_id in enumerate(batch_ids):
        with Image.open(os.path.join('dataset/images', image_id) + '.png') as img:
            X_ori[i] = trn(img)
    labels = torch.tensor(label_true[start:start + batch_size_cur]).to(device)

    X_adv = PGD_attack(model, criterion, step_size, X_ori, labels)

    # Save the adversarial images. Copy the whole batch to host memory once
    # instead of once per image; .cpu() also waits for the GPU work, so the
    # timing below includes it.
    X_adv_cpu = X_adv.detach().cpu()
    for j, image_id in enumerate(batch_ids):
        to_pil(X_adv_cpu[j]).save(os.path.join(output_folder, image_id) + '.png')
toc = time.perf_counter()
print(f"PGD in {toc - tic:0.4f} seconds")