import sys
sys.path.append('..')

import torch
import os
from utils.utils import  to_cuda,to_cpu, pil2tensor,tensor2pil, generate_path_im_and_mask
from utils.data_process import  ILSVRC2012_LOAD
from PIL import Image,ImageDraw
import torchvision
from utils.eval import fooling_rate_test
import argparse
import numpy as np


def get_args(argv=None):
    """Parse the command-line options for the patch fooling-rate evaluation.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``GPU_device`` (int), ``model`` (str),
        ``test_center_loc`` (list of two ints) and ``patch_dir`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--GPU-device', default=0, type=int,
                        help='index of the GPU passed to fooling_rate_test')
    parser.add_argument('--model', default='vgg19', type=str,
                        choices=('vgg19', 'resnet34'),
                        help='pretrained torchvision classifier to evaluate against')
    # type=list would split a CLI string such as "112" into ['1', '1', '2'];
    # take exactly two integers instead, e.g. --test-center-loc 112 112.
    parser.add_argument('--test-center-loc', default=[112, 112], type=int, nargs=2,
                        help='centre location of the patch on the test image (two ints)')
    parser.add_argument('--patch-dir', default='', type=str,
                        help='directory holding the generated patch images')
    return parser.parse_args(argv)




def _load_model(model_name):
    """Return the pretrained torchvision classifier ``model_name`` in eval mode.

    Raises:
        ValueError: if ``model_name`` is not 'vgg19' or 'resnet34'.
    """
    if model_name == 'vgg19':
        net = torchvision.models.vgg19(pretrained=True)
    elif model_name == 'resnet34':
        net = torchvision.models.resnet34(pretrained=True)
    else:
        # Without this branch an unknown name stays a str and crashes on .eval().
        raise ValueError('unsupported model {!r}; expected vgg19 or resnet34'.format(model_name))
    net.eval()
    return net


def _list_patch_files(patch_dir, expected_count=100):
    """Return the sorted file names of the generated patches in ``patch_dir``.

    Hidden entries (e.g. ``.ipynb_checkpoints``) are ignored. Sorting makes the
    row order of the saved results independent of ``os.listdir`` order.

    Raises:
        ValueError: if ``patch_dir`` is empty.
        Exception: if the directory does not hold exactly ``expected_count`` patches.
    """
    if not patch_dir:
        raise ValueError('--patch-dir must point to a directory of generated patches')
    file_names = sorted(name for name in os.listdir(patch_dir) if not name.startswith('.'))
    if len(file_names) != expected_count:
        raise Exception(patch_dir,' is empty or has some wrong files or not 100 patchs')
    return file_names


def _circular_patch_mask(patch):
    """Return a mask tensor with a white ellipse inscribed in the patch's extent.

    NOTE(review): assumes ``pil2tensor`` yields ``(..., H, W)`` tensors like
    ``torchvision.transforms.ToTensor`` -- confirm.
    """
    height, width = patch.shape[-2], patch.shape[-1]
    # PIL sizes/coordinates are (x, y) == (width, height); passing the tensor's
    # (H, W) directly would transpose the mask for non-square patches.
    mask = Image.new("RGB", (width, height), (0, 0, 0))
    ImageDraw.Draw(mask).ellipse(((0, 0), (width, height)), fill=(255, 255, 255), outline=None)
    return pil2tensor(mask)


def _result_file_name(patch_dir):
    """Build '<parent>_<leaf>_fooling_rate.npy' from the last two components of ``patch_dir``."""
    # normpath drops a trailing separator that would otherwise make the leaf ''.
    parent, leaf = os.path.split(os.path.normpath(patch_dir))
    return os.path.basename(parent) + '_' + leaf + '_' + 'fooling_rate.npy'


def main():
    """Evaluate every generated patch in ``--patch-dir`` and save the fooling rates.

    Each patch file name must start with its target class index (``<target>_...``).
    A circular mask is drawn for the patch, the patch is placed at
    ``--test-center-loc`` on a 3x224x224 canvas, and ``fooling_rate_test`` runs it
    against the validation dataset. One row ``[target, targeted_rate,
    non_targeted_rate]`` per patch is saved with ``np.save`` to
    ``<parent>_<leaf>_fooling_rate.npy`` in the current working directory.
    """
    args = get_args()
    patch_dir = args.patch_dir  # e.g. '../EoT_patch_generation/generated_patch/vgg19/ensemble/group_1'
    test_center_loc = args.test_center_loc
    GPU_device = args.GPU_device

    # Validate the patch directory and the model name before the (slower)
    # dataset load so that bad arguments fail fast.
    image_file_list = _list_patch_files(patch_dir)
    model = _load_model(args.model)

    # load data
    dataset_dict = ILSVRC2012_LOAD(root='/home/liujiawei/local_pycharm/Data/ImageNet/val_seg')
    val_dataset = dataset_dict['val_dataset']
    class_name_list = dataset_dict['class_name_list']

    fooling_rate_list = []
    for image_file in image_file_list:
        target = int(image_file.split('_')[0])
        print('=====current patch: ',class_name_list[target],'==============')
        # Close the file handle once the pixels have been converted to a tensor.
        with Image.open(os.path.join(patch_dir, image_file)) as patch_image:
            patch = pil2tensor(patch_image)
        patch_mask = _circular_patch_mask(patch)
        patch_im, patch_im_mask = generate_path_im_and_mask(
            patch=patch, patch_mask=patch_mask,
            image_size=[3, 224, 224], center_loc=test_center_loc)

        targeted_fooling_time, non_targeted_fooling_time, total_time = fooling_rate_test(
            patch_im=patch_im, patch_im_mask=patch_im_mask, model=model, target=target,
            dataset=val_dataset, batch_size=32, GPU_device=GPU_device)

        targeted_rate = targeted_fooling_time / total_time
        non_targeted_rate = non_targeted_fooling_time / total_time
        print('=====targeted fooling rate: ', targeted_rate, '==============')
        fooling_rate_list.append([target, targeted_rate, non_targeted_rate])

    np.save(_result_file_name(patch_dir), np.array(fooling_rate_list))


if "__main__" == __name__:
    # Run the evaluation only when executed as a script, not when imported.
    main()