#%%
import sys
sys.path.append('..')
import torch
import torchvision
from utils.data_process import ILSVRC2012_LOAD
from PIL import Image,ImageDraw
from utils.utils import pil2tensor, tensor2pil, to_cuda
from utils.utils import generate_path_im_and_mask, restore_patch_im
from utils.eot import EoT
from utils.attacker import adv_patch
import random
import argparse
import os
from utils.data_process import proxy_dataset_select, model_select

def get_args(argv=None):
    """Parse the command-line options of the patch-generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests parse options without touching sys.argv.

    Returns:
        argparse.Namespace with the attributes ``target_random_seed``,
        ``GPU_device``, ``model``, ``attack_iters``, ``patch_diameter``,
        ``proxy_data`` and ``patch_save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--target-random-seed', default=0, type=int,
                        help='seed for random.sample when picking the target classes')
    parser.add_argument('--GPU-device', default=0, type=int,
                        help='device index passed to to_cuda')
    parser.add_argument('--model', default='vgg19', type=str,
                        help='model name passed to model_select')
    parser.add_argument('--attack-iters', default=2000, type=int,
                        help='number of patch update steps per target class')
    parser.add_argument('--patch-diameter', default=50, type=int,
                        help='diameter in pixels of the circular patch')
    parser.add_argument('--proxy-data', default='', type=str,
                        help='proxy dataset name passed to proxy_dataset_select; '
                             'also used in the saved file names')
    parser.add_argument('--patch-save-dir', default='', type=str,
                        help='directory the patches are written to')

    return parser.parse_args(argv)



# Parameters are parsed from the command line by get_args(), called in main().





def _init_circular_patch(diameter):
    """Build a randomly initialised circular patch and its mask.

    A ``diameter`` x ``diameter`` black RGB canvas gets a filled white
    ellipse; ``pil2tensor`` converts it to the mask. The patch is uniform
    noise in [0, 1) multiplied by the mask, so it is zero outside the disc.

    Returns:
        ``(patch, patch_mask)`` as produced by ``pil2tensor`` (presumably
        CPU tensors, since the caller moves them with to_cuda -- confirm).
    """
    canvas = Image.new("RGB", (diameter, diameter), (0, 0, 0))
    draw = ImageDraw.Draw(canvas)
    # NOTE(review): the far corner of the bounding box is (diameter, diameter)
    # on a canvas whose last pixel index is diameter - 1; confirm whether
    # (diameter - 1, diameter - 1) was intended for a symmetric disc.
    draw.ellipse(((0, 0), (diameter, diameter)), fill=(255, 255, 255), outline=None)
    mask_tensor = pil2tensor(canvas)
    patch_mask = mask_tensor.clone()
    # uniform_ overwrites mask_tensor in place; patch_mask is the untouched copy.
    patch = mask_tensor.uniform_(0.0, 1.0) * patch_mask
    return patch, patch_mask


def _target_nll_loss(target):
    """Return a loss callable: the negative log-probability of class ``target``.

    The callable takes softmax over dim 1 of ``output`` and reads row 0,
    i.e. ``-log_softmax(output, dim=1)[0][target]``. ``log_softmax`` equals
    ``log(softmax(...))`` mathematically, but it does not underflow to
    ``-inf`` (and produce NaN gradients) when the target probability is tiny.
    """
    def loss_fn(output):
        return -torch.nn.functional.log_softmax(output, dim=1)[0][target]
    return loss_fn


def _optimize_patch(model, dataset, eot, patch, patch_mask, loss_fn,
                    attack_iters, diameter, GPU_device):
    """Apply ``attack_iters`` single-step ``adv_patch`` updates; return the patch.

    Each step samples an image with ``eot.img_select``, picks a location with
    ``eot.locate`` and pastes the patch into an image-sized canvas with
    ``generate_path_im_and_mask``. It then runs one ``adv_patch`` iteration
    (lr 1/255) and cuts the updated patch back out with ``restore_patch_im``.
    """
    for _ in range(attack_iters):
        img = to_cuda(eot.img_select(dataset), GPU_device)
        # NOTE(review): the location is always drawn for a 100x100 patch in a
        # 224x224 image, independent of ``diameter`` and the actual image
        # size; confirm the patch stays inside the image if diameter > 100
        # or the proxy images are not 224x224.
        center_loc = eot.locate(patch_size=[100, 100], img_size=[224, 224])

        patch_im, patch_im_mask = generate_path_im_and_mask(
            patch=patch, patch_mask=patch_mask,
            image_size=img.shape[-3:], center_loc=center_loc)

        # The loss returned by adv_patch is not needed here.
        patch_im, patch_im_mask, _loss = adv_patch(
            img=img, patch_im=patch_im, patch_im_mask=patch_im_mask,
            model=model, iteration=1, loss_fn=loss_fn, lr=1 / 255, loss_seg=100)
        patch = restore_patch_im(patch_im=patch_im, loc=center_loc, patch_side=diameter)
    return patch


def main():
    """Generate and save one adversarial patch per sampled target class.

    Reads options with ``get_args()``, loads the class-name list, the proxy
    dataset and the model, then samples 100 class indices from range(1000)
    after ``random.seed(target_random_seed)``. For each target it optimizes
    a circular patch and writes it to
    ``<patch_save_dir>/<target>_<proxy_data>.png``.
    """
    args = get_args()
    target_random_seed = args.target_random_seed
    model_name = args.model
    diameter = args.patch_diameter
    GPU_device = args.GPU_device
    attack_iters = args.attack_iters
    proxy_data = args.proxy_data
    patch_save_dir = args.patch_save_dir

    # exist_ok replaces the racy exists()/makedirs() pair. The CLI default ''
    # means "current directory", on which os.makedirs would raise, so use '.'.
    os.makedirs(patch_save_dir or '.', exist_ok=True)

    # Select proxy data and target model.
    # NOTE(review): machine-specific, hard-coded dataset root.
    dataset_dict = ILSVRC2012_LOAD(root='/home/liujiawei/local_pycharm/Data/ImageNet/val_seg')
    class_name_list = dataset_dict['class_name_list']
    dataset = proxy_dataset_select(proxy_data)
    model = model_select(model_name)
    _ = model.eval()
    # Loop-invariant: move the model to the device once, not once per target.
    model = to_cuda(model, GPU_device)

    # Select 100 targets.
    random.seed(target_random_seed)
    target_list = random.sample(range(0, 1000), k=100)

    # Attack.
    for count, target in enumerate(target_list):
        print('current target: ', class_name_list[target], '  ', str(count + 1), '/', len(target_list))

        # Initialize a circular patch of diameter args.patch_diameter.
        patch, patch_mask = _init_circular_patch(diameter)
        eot = EoT()
        patch = to_cuda(patch, GPU_device)
        patch_mask = to_cuda(patch_mask, GPU_device)

        patch = _optimize_patch(model, dataset, eot, patch, patch_mask,
                                _target_nll_loss(target), attack_iters,
                                diameter, GPU_device)

        save_patch = tensor2pil(patch.cpu())
        # os.path.join, unlike dir + '/' + name, does not turn an empty
        # directory into the filesystem root.
        path = os.path.join(patch_save_dir, f'{target}_{proxy_data}.png')
        save_patch.save(path)
        print('patch have been saved in: ', path)

if __name__ == '__main__':
    # Start the attack only when this file runs as a script, not on import.
    main()