import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import sys
import torch
import torchvision
from torchvision import datasets, models, transforms
from os import path
from matplotlib import colors
import sys
from pathlib import Path
import time
import gc
import torchvision.transforms.functional as F
from utils import *
from attacks import *

# Load the image ids and the two label lists from the dataset CSV via
# utils.load_ground_truth (presumably target labels and true labels --
# TODO confirm column meaning in utils).
image_id_list, label_tar_list, label_true = load_ground_truth('dataset/images.csv')
# Prefer the first GPU, but fall back to CPU so the script can still run
# (slowly) on machines without CUDA instead of crashing at model.to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-50 in eval mode with all weights frozen, so no gradients
# are accumulated for model parameters.
# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=...`; it is kept here for compatibility with older torchvision.
model = models.resnet50(pretrained=True).eval()
for param in model.parameters():
    param.requires_grad = False
model.to(device)

success_attack = 0  # NOTE(review): not used anywhere in this script
total_time = 0.0    # NOTE(review): not used anywhere in this script
batch_size = 32
num_images_run = 1000
# Exact integer ceiling division. The former `np.int(np.ceil(...))` fails on
# NumPy >= 1.24, where the `np.int` alias was removed.
num_batches = (num_images_run + batch_size - 1) // batch_size
criterion = nn.CrossEntropyLoss().to(device)

# Hyper-parameter grids used for the ablation study in the paper.
stepsize_list = [0.05, 0.2]  # step sizes; the main loop iterates over these
conf_list = [1, 5]           # confidence values (not iterated by default)
lambda_list = [0.5, 2.0]     # lambda values (not iterated by default)

# Defaults for a single run. `stepsize` is rebound by the loop over
# stepsize_list; `lam` stays an int so folder names read "lam1".
stepsize, confidence, lam = 0.1, 0, 1

# Run the attack once per step size. To sweep several parameters for the
# ablation study (e.g. conf_list / lambda_list), nest more for-loops here.
# Process at most num_images_run images and never index past the id list, so
# the last batch neither overshoots num_images_run nor raises IndexError.
n_images = min(num_images_run, len(image_id_list))
to_pil = transforms.ToPILImage()  # loop-invariant converter, built once
for stepsize in stepsize_list:
    # NOTE: "noconf" is hard-coded in the folder name; `confidence` is not encoded.
    output_folder = "results_resnet/results_PieGD_step"+str(stepsize)+"_noconf_lam"+str(lam)+"/"
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    tic = time.time()
    for k in range(num_batches):
        start = k * batch_size
        if start >= n_images:
            break
        # Resume support: skip a batch whose first output image already exists.
        if os.path.exists(os.path.join(output_folder, image_id_list[start]) + '.png'):
            continue
        batch_size_cur = min(batch_size, n_images - start)
        # Clean input batch of shape (batch_size_cur, 3, 299, 299); assumes `trn`
        # (from utils) returns 3x299x299 tensors -- TODO confirm.
        X_ori = torch.zeros(batch_size_cur, 3, 299, 299, device=device)
        for i in range(batch_size_cur):
            img_path = os.path.join('dataset/images', image_id_list[start + i]) + '.png'
            # The context manager closes the file handle after conversion.
            with Image.open(img_path) as img:
                X_ori[i] = trn(img)

        # Despite its name, label_tar is filled from label_true; label_tar_list is
        # unused here. NOTE(review): confirm an untargeted attack is intended.
        label_tar = torch.tensor(label_true[start:start + batch_size_cur]).to(device)
        # Name the batch by its first image id, consistent with the resume check.
        attack = PieAPP_GD_attack_batch(model, criterion, stepsize, X_ori, label_tar,
                                        output_folder, image_id_list[start] + ".png",
                                        conf=confidence, lam=lam)
        for i in range(batch_size_cur):
            x_np = to_pil(attack[i].detach().cpu())
            x_np.save(os.path.join(output_folder, image_id_list[start + i]) + '.png')

    toc = time.time()
    print(f"Attack in {toc - tic:0.4f} seconds")