#%%

import sys
sys.path.append('..')
import torch
import torchvision
from utils.data_process import ILSVRC2012_LOAD
from PIL import Image,ImageDraw
from utils.utils import pil2tensor, tensor2pil, to_cuda
from utils.utils import generate_path_im_and_mask, restore_patch_im
from utils.eot import EoT
from utils.attacker import adv_patch,EoD_patch
import random
import argparse
import torchvision.transforms as transforms
from utils.data_process import proxy_dataset_select, model_select
from utils.s_map import cal_patch_saliency,I_s_map
from utils.data_process import class_name_list #variable
import os
from utils.eval import fooling_rate_test
import numpy as np
import matplotlib.pyplot as plt




def EoD_patch_n_Saliency(patch_im, patch_im_mask, img, model, I_s, I_d, loss_fn_sal,
                         loss_fn_dis, lr_s, lr_d, target):
    """Run one EoD round: distract the background, then boost the patch.

    Phase 1 (distraction, ``I_d`` steps): the background ``img`` is updated
    with ``adv_patch`` under ``loss_fn_dis`` at step size ``lr_d``. The roles
    of image and patch are swapped in that call, so ``img`` is the tensor
    being optimised and the inverted patch mask selects the region it may
    change.

    Phase 2 (saliency increase, ``I_s`` steps): the patch image is updated
    against the distracted background under ``loss_fn_sal`` at step size
    ``lr_s``. Each step continues from the previous step's result.

    After every step of either phase, ``cal_patch_saliency`` is evaluated for
    ``target`` and appended to the returned list.

    Args:
        patch_im: image-sized tensor holding the pasted patch (as produced by
            ``generate_path_im_and_mask`` in this script).
        patch_im_mask: mask of the pasted patch; assumed to be 0/1 — TODO confirm.
        img: background image batch the patch is pasted on.
        model: classifier being attacked.
        I_s: number of saliency-increase steps.
        I_d: number of distraction steps.
        loss_fn_sal: loss minimised when updating the patch.
        loss_fn_dis: loss minimised when updating the background.
        lr_s: step size for the saliency phase.
        lr_d: step size for the distraction phase.
        target: target class index passed to ``cal_patch_saliency``.

    Returns:
        ``(patch_im, patch_im_mask, saliency_list)`` after the saliency
        phase. ``saliency_list`` has ``I_d + I_s`` entries. When ``I_s == 0``,
        the incoming patch image and mask are returned.
    """
    saliency_list = []

    # Region outside the patch (1 - mask for a 0/1 mask): the part of the
    # background the distraction phase is allowed to modify.
    img_mask = torch.abs(patch_im_mask - 1)

    ######################### distraction of EoD #################################
    for _ in range(I_d):
        # Roles are swapped on purpose: the patched image is passed as the
        # fixed "image", and the background `img` is the "patch" being updated.
        d_im, d_im_mask, loss_list = adv_patch(img=patch_im,
                                               patch_im=img,
                                               patch_im_mask=img_mask,
                                               model=model,
                                               iteration=1,
                                               lr=lr_d,
                                               loss_seg=100,
                                               loss_fn=loss_fn_dis)
        img = d_im
        saliency = cal_patch_saliency(patch_im=patch_im, patch_im_mask=patch_im_mask,
                                      model=model, target=target, background=img)
        saliency_list.append(saliency)
    ##############################################################################

    ########################### saliency increase ################################
    # Start from the incoming patch so the result is defined even when I_s == 0.
    tem_patch_im, tem_patch_im_mask = patch_im, patch_im_mask
    for _ in range(I_s):
        # Continue from the previous step's patch image so the I_s single-step
        # updates accumulate instead of each restarting from `patch_im`.
        tem_patch_im, tem_patch_im_mask, loss_list = adv_patch(img=img,
                                                               patch_im=tem_patch_im,
                                                               patch_im_mask=patch_im_mask,
                                                               model=model,
                                                               iteration=1,
                                                               lr=lr_s,
                                                               loss_fn=loss_fn_sal,
                                                               loss_seg=100)
        saliency = cal_patch_saliency(patch_im=tem_patch_im, patch_im_mask=tem_patch_im_mask,
                                      model=model, target=target, background=img)
        saliency_list.append(saliency)
    ############################################################################

    return tem_patch_im, tem_patch_im_mask, saliency_list
#%%
# Experiment configuration, hard-coded here instead of parsed from the command line.

# --- victim model, patch geometry and data source ---
model = 'vgg19'          # model name; replaced by the loaded network via model_select
diameter = 50            # side length (pixels) of the square canvas holding the circular patch
GPU_device = 2           # device index handed to to_cuda
proxy_data = 'uniform'   # name handed to proxy_dataset_select
target_random_seed = 0   # NOTE(review): not referenced in the visible script

# --- optimisation schedule ---
attack_iters = 80
restarts = 1             # NOTE(review): not referenced in the visible script
lr_d, I_d = 1/255, 2     # step size / steps of the distraction (background) phase
lr_s, I_s = 1/255, 10    # step size / steps of the saliency (patch) phase
#%%
# Load the ImageNet validation data. The loader returns a dict with the dataset
# and the class-name list; the latter rebinds the imported `class_name_list`.
dataset_dict = ILSVRC2012_LOAD(root='/home/liujiawei/local_pycharm/Data/ImageNet/val_seg')
val_dataset, class_name_list = (dataset_dict['val_dataset'],
                                dataset_dict['class_name_list'])

# Proxy distribution from which backgrounds are sampled during patch generation.
dataset = proxy_dataset_select(proxy_data)

# Swap the model *name* for the actual network and put it in eval mode.
model = model_select(model)
_ = model.eval()

# Default target class index (the next cell sets its own).
target = 323


#%%
attack_iters = 400
target = 909

print('current target: ',class_name_list[target])

# Build the initial patch on a diameter x diameter canvas: a white disc on
# black becomes the patch mask (presumably 0/1 after pil2tensor), and the
# pixels inside the disc start from uniform noise.
patch = Image.new("RGB", (diameter, diameter),(0,0,0))
draw = ImageDraw.Draw(patch)
draw.ellipse(((0, 0), (diameter, diameter)), fill=(255,255,255), outline=None)
patch = pil2tensor(patch)
patch_mask = patch.clone()
patch = patch.uniform_(0.0, 1.0)*patch_mask

# Move the model and the patch tensors to the selected GPU.
model = to_cuda(model, GPU_device)
patch = to_cuda(patch,GPU_device)
patch_mask = to_cuda(patch_mask, GPU_device)

# Sampler for background images and patch locations.
eot = EoT()


# Objectives for the first sample of the batch. log_softmax is used instead of
# log(softmax(...)) because it does not underflow to log(0) = -inf when the
# target probability becomes tiny.
def loss_fn_sal(output):
    """Negative log-probability of `target`; minimising it raises the target score."""
    loss = -torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


def loss_fn_dis(output):
    """Log-probability of `target`; minimising it lowers the target score."""
    loss = torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


saliency_list_EoD = []  # patch saliency recorded after every inner step

# Patch generation stage.
for i in range(attack_iters):
    if i%10 == 0:
        print(i)
    img = eot.img_select(dataset)  # sample an image from the proxy distribution
    img = to_cuda(img.unsqueeze(0), GPU_device)
    # NOTE(review): locate() receives a 100x100 patch_size although the patch
    # side is `diameter`; confirm this margin is intended.
    center_loc = eot.locate(patch_size=[100,100],img_size=[224,224])  # random location
    patch_im, patch_im_mask = generate_path_im_and_mask(patch = patch, patch_mask= patch_mask,
                                  image_size=img.shape[-3:],center_loc=center_loc)  # paste patch at that location

    # i-th iteration: distract the background, then boost the patch.
    patch_im,patch_im_mask,sub_list = EoD_patch_n_Saliency(img = img, patch_im= patch_im,patch_im_mask= patch_im_mask,
                                            model = model, I_d = I_d ,I_s = I_s, loss_fn_dis=loss_fn_dis, loss_fn_sal = loss_fn_sal,
                                            lr_s = lr_s, lr_d = lr_d, target=target)

    patch = restore_patch_im(patch_im=patch_im, loc=center_loc, patch_side=diameter)  # cut the patch back out

    saliency_list_EoD.extend(sub_list)

#%%
# np.save takes (file, arr); the original call had the arguments swapped.
np.save('EoD.npy', np.array(saliency_list_EoD))







#%%EoT+Uniform

#%%

#attack

print('current target: ',class_name_list[target])

# Build the initial patch on a diameter x diameter canvas: a white disc on
# black becomes the patch mask (presumably 0/1 after pil2tensor), and the
# pixels inside the disc start from uniform noise.
patch = Image.new("RGB", (diameter, diameter),(0,0,0))
draw = ImageDraw.Draw(patch)
draw.ellipse(((0, 0), (diameter, diameter)), fill=(255,255,255), outline=None)
patch = pil2tensor(patch)
patch_mask = patch.clone()
patch = patch.uniform_(0.0, 1.0)*patch_mask

# Move the model and the patch tensors to the selected GPU.
model = to_cuda(model, GPU_device)
patch = to_cuda(patch,GPU_device)
patch_mask = to_cuda(patch_mask, GPU_device)

# Sampler for background images and patch locations.
eot = EoT()


# Objectives for the first sample of the batch. log_softmax is used instead of
# log(softmax(...)) because it does not underflow to log(0) = -inf when the
# target probability becomes tiny.
def loss_fn_sal(output):
    """Negative log-probability of `target`; minimising it raises the target score."""
    loss = -torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


def loss_fn_dis(output):
    """Log-probability of `target`; minimising it lowers the target score."""
    loss = torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


saliency_list_uniform = []  # patch saliency recorded after every inner step

# Patch generation stage. lr_d = 0.0 disables the background update, leaving
# plain sampling over the proxy distribution; the I_d saliency records are
# still taken so the list lengths match the EoD run.
for i in range(attack_iters):
    if i%10 == 0:
        print(i)
    img = eot.img_select(dataset)  # sample an image from the proxy distribution
    img = to_cuda(img.unsqueeze(0), GPU_device)
    center_loc = eot.locate(patch_size=[100,100],img_size=[224,224])  # random location
    patch_im, patch_im_mask = generate_path_im_and_mask(patch = patch, patch_mask= patch_mask,
                                  image_size=img.shape[-3:],center_loc=center_loc)  # paste patch at that location

    patch_im,patch_im_mask,sub_list = EoD_patch_n_Saliency(img = img, patch_im= patch_im,patch_im_mask= patch_im_mask,
                                            model = model, I_d = I_d ,I_s = I_s, loss_fn_dis=loss_fn_dis, loss_fn_sal = loss_fn_sal,
                                            lr_s = lr_s, lr_d = 0.0, target=target)  # i-th iteration

    patch = restore_patch_im(patch_im=patch_im, loc=center_loc, patch_side=diameter)  # cut the patch back out

    saliency_list_uniform.extend(sub_list)

# np.save takes (file, arr); the original call had the arguments swapped.
np.save('uniform.npy', np.array(saliency_list_uniform))







#%% ImageNet
print('current target: ',class_name_list[target])

# Build the initial patch on a diameter x diameter canvas: a white disc on
# black becomes the patch mask (presumably 0/1 after pil2tensor), and the
# pixels inside the disc start from uniform noise.
patch = Image.new("RGB", (diameter, diameter),(0,0,0))
draw = ImageDraw.Draw(patch)
draw.ellipse(((0, 0), (diameter, diameter)), fill=(255,255,255), outline=None)
patch = pil2tensor(patch)
patch_mask = patch.clone()
patch = patch.uniform_(0.0, 1.0)*patch_mask

# Move the model and the patch tensors to the selected GPU.
model = to_cuda(model, GPU_device)
patch = to_cuda(patch,GPU_device)
patch_mask = to_cuda(patch_mask, GPU_device)

# Sampler for background images and patch locations.
eot = EoT()


# Objectives for the first sample of the batch. log_softmax is used instead of
# log(softmax(...)) because it does not underflow to log(0) = -inf when the
# target probability becomes tiny.
def loss_fn_sal(output):
    """Negative log-probability of `target`; minimising it raises the target score."""
    loss = -torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


def loss_fn_dis(output):
    """Log-probability of `target`; minimising it lowers the target score."""
    loss = torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss


saliency_list_ImageNet = []  # patch saliency recorded after every inner step

# Patch generation stage on real ImageNet validation backgrounds. lr_d = 0.0
# disables the background update.
for i in range(attack_iters):
    if i%10 == 0:
        print(i)
    img = eot.img_select(val_dataset)  # sample an image from the ImageNet validation set
    img = to_cuda(img.unsqueeze(0), GPU_device)
    center_loc = eot.locate(patch_size=[100,100],img_size=[224,224])  # random location
    patch_im, patch_im_mask = generate_path_im_and_mask(patch = patch, patch_mask= patch_mask,
                                  image_size=img.shape[-3:],center_loc=center_loc)  # paste patch at that location

    patch_im,patch_im_mask,sub_list = EoD_patch_n_Saliency(img = img, patch_im= patch_im,patch_im_mask= patch_im_mask,
                                            model = model, I_d = I_d ,I_s = I_s, loss_fn_dis=loss_fn_dis, loss_fn_sal = loss_fn_sal,
                                            lr_s = lr_s, lr_d = 0.0, target=target)  # i-th iteration

    patch = restore_patch_im(patch_im=patch_im, loc=center_loc, patch_side=diameter)  # cut the patch back out

    saliency_list_ImageNet.extend(sub_list)

# np.save takes (file, arr); the original call had the arguments swapped.
np.save('ImageNet.npy', np.array(saliency_list_ImageNet))
