import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from PIL import Image
import os
from os import path
from tqdm.auto import tqdm
import csv
import numpy as np
from differential_color_functions import rgb2lab_diff, ciede2000_diff
from perc_al import PerC_AL
import time
import differential_color_functions
from matplotlib import pyplot as plt
import sys
from pathlib import Path
from utils import *

# simple Module to normalize an image
class Normalize(nn.Module):
    """Channel-wise standardization layer computing ``(x - mean) / std``.

    ``mean`` and ``std`` are per-channel sequences; in ``forward`` they are
    indexed with ``[None, :, None, None]`` so they broadcast along dimension 1
    of a 4-D input (presumably an N x C x H x W image batch -- confirm at the
    call site).
    """

    def __init__(self, mean, std):
        super().__init__()
        # Kept as ordinary tensor attributes rather than registered buffers;
        # forward() re-casts them with type_as() so they always match x.
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)

    def forward(self, x):
        # Match x's dtype/device first, then expand to (1, C, 1, 1).
        centre = self.mean.type_as(x)[None, :, None, None]
        spread = self.std.type_as(x)[None, :, None, None]
        shifted = x - centre
        return shifted / spread

# Fix the PyTorch RNG seed and make cuDNN deterministic so repeated runs
# reproduce the same results.
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True

# Load the image list. load_ground_truth comes from the star-import of
# `utils`; presumably it returns (image ids, target labels, true labels) in
# matching order -- NOTE(review): confirm against utils.py.
image_id_list, label_tar_list, label_true = load_ground_truth('dataset/images.csv')

# Prefer the GPU, but fall back to the CPU so the script still runs (slowly)
# on machines without CUDA instead of failing at model.to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained classifier under attack and freeze its weights: only
# the input images need gradients during the attack.
# model = models.inception_v3(pretrained=True).eval()
model = models.resnet50(pretrained=True).eval()
for param in model.parameters():
    param.requires_grad = False
model.to(device)

# Attack mode: 1 -> untargeted (uses label_true),
#              0 -> targeted   (uses label_tar_list).
untargeted = 1
batch_size = 32
# Never process more images than the CSV actually lists; previously a short
# CSV made the trailing batch sizes non-positive and torch.zeros crashed.
num_images = min(1000, len(image_id_list))
# np.int was removed in NumPy 1.24; use the builtin int instead.
num_batches = int(np.ceil(num_images / batch_size))
print(num_batches)
color_differences = []
output_folder = "./results_resnet/images_adv_PerC_AL/"
Path(output_folder).mkdir(parents=True, exist_ok=True)

# Stateless tensor -> PIL converter, reused for every saved image.
to_pil = transforms.ToPILImage()
# Label list matching the attack mode selected above.
label_source = label_true if untargeted else label_tar_list

tic = time.perf_counter()
for k in range(num_batches):
    start = k * batch_size
    # The last batch may be smaller than batch_size.
    batch_size_cur = min(batch_size, num_images - start)

    # Load a batch of input images of shape batch_size_cur x 3 x 299 x 299.
    X_ori = torch.zeros(batch_size_cur, 3, 299, 299, device=device)
    for i in range(batch_size_cur):
        img_path = os.path.join('dataset/images', image_id_list[start + i]) + '.png'
        # Close each file once transformed; trn (from utils) is assumed to
        # turn the PIL image into a 3 x 299 x 299 tensor.
        with Image.open(img_path) as img:
            X_ori[i] = trn(img)

    labels = torch.tensor(label_source[start:start + batch_size_cur], device=device)

    # A fresh attacker per batch, as before, in case PerC_AL keeps
    # per-run state. Pass the real mode: previously targeted=False was
    # hard-coded, so untargeted=0 silently ran an untargeted attack
    # against the target labels.
    approach = PerC_AL(device=device, max_iterations=300, alpha_l_init=1,
                       alpha_c_init=0.5, confidence=20)
    X_adv = approach.adversary(model, X_ori, labels=labels,
                               targeted=not untargeted)

    # Save the adversarial images as PNGs named after the source image ids.
    for j in range(batch_size_cur):
        # torch.save(X_adv[j].detach(), os.path.join(output_folder, image_id_list[start + j]) + '.pt')
        adv_img = to_pil(X_adv[j].detach().cpu())
        adv_img.save(os.path.join(output_folder, image_id_list[start + j]) + '.png')
toc = time.perf_counter()
print(f"PerC in {toc - tic:0.4f} seconds")