import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import sys
import torch
import torchvision
from torchvision import datasets, models, transforms
from os import path
from matplotlib import colors
import sys
from pathlib import Path
import time
import gc
import torchvision.transforms.functional as F
from perceptual_advex.perceptual_attacks import *
sys.path.append('../')
from utils import *

# Load the per-image metadata from the dataset CSV. Going by the variable
# names these are image ids, target labels and true labels; the exact format
# is defined by `load_ground_truth` in utils.
image_id_list, label_tar_list, label_true = load_ground_truth('../dataset/images.csv')

# Everything below runs on the GPU.
device = torch.device("cuda")

# Pre-trained ResNet-50 used as the classifier under attack. It is put in
# eval mode and its weights are frozen, so no gradients are accumulated on
# the parameters themselves.
model = models.resnet50(pretrained=True)
model.eval()
model.requires_grad_(False)
model = model.to(device)


# Alternative attack, kept for reference:
# attack = PerceptualPGDAttack(
#     model,
#     num_iterations=20,
#     # The LPIPS distance bound on the adversarial examples.
#     bound=1.0,
#     # The model to use for calculate LPIPS; here we use AlexNet.
#     # You can also use 'self' to perform a self-bounded attack.
#     lpips_model='alexnet',
# )

# Attack actually used by this script; LPIPS is computed with AlexNet.
attack = LagrangePerceptualAttack(model, lpips_model='alexnet')

# NOTE(review): `total_time` and `success` are not used in this section of the
# script; they are kept in case other code relies on them.
total_time = 0.0
batch_size = 32
num_images_run = 1000

# Never process more images than requested, nor more than the ground-truth
# list actually contains (the latter would give a negative batch size).
num_images = min(num_images_run, len(image_id_list))
# Integer ceiling division. The former `np.int(np.ceil(...))` fails on current
# NumPy: the `np.int` alias was deprecated in 1.20 and removed in 1.24.
num_batches = (num_images + batch_size - 1) // batch_size

output_folder = "./results_resnet/LPA/"
Path(output_folder).mkdir(parents=True, exist_ok=True)
success = 0

# Build the tensor -> PIL converter once instead of once per saved image.
to_pil = transforms.ToPILImage()

tic = time.perf_counter()
for k in range(num_batches):
    start = k * batch_size
    # The last batch may be smaller than `batch_size`.
    batch_size_cur = min(batch_size, num_images - start)

    # assumes `trn` yields 3x299x299 tensors for every image -- TODO confirm
    X_ori = torch.zeros(batch_size_cur, 3, 299, 299).to(device)

    for i in range(batch_size_cur):
        image_path = os.path.join('../dataset/images', image_id_list[start + i]) + '.png'
        # Close the file handle as soon as the image has been converted.
        with Image.open(image_path) as img:
            X_ori[i] = trn(img)

    label = torch.tensor(label_true[start:start + batch_size_cur]).to(device)

    adv_inputs = attack(X_ori, label)

    for i in range(batch_size_cur):
        # ToPILImage scales float tensors by 255 and casts to uint8, so clamp
        # to [0, 1] first; this is a no-op if the attack already returns
        # in-range values and prevents wrap-around artefacts otherwise.
        adv_image = to_pil(adv_inputs[i].detach().cpu().clamp(0, 1))
        adv_image.save(os.path.join(output_folder, image_id_list[start + i]) + '.png')

toc = time.perf_counter()
print(f"Attack in {toc - tic:0.4f} seconds")