from utils import load_image, save_img
from generate_hybrid_img import imfilter, gen_hybrid_img
from tqdm import tqdm
import os
import numpy as np 
from PIL import Image
from pattern import circle, square, prismatic
import argparse

# Command-line options. Numeric defaults are given as real numbers rather than
# strings so they do not rely on argparse re-parsing string defaults via `type`.
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, default='input/', help='The input directory of the clean img')
parser.add_argument('--save_dir', type=str, default='output/', help='The output directory of the adversarial img')
parser.add_argument('--cutoff_frequency', type=int, default=4, help='generate a (4*cutoff_frequency+1)*(4*cutoff_frequency+1) gaussian kernel ')
parser.add_argument('--weight_factor', type=float, default=1.0, help='balance the low_frequencies and high_frequencies ')
# Expressed on the 0-255 scale; it is divided by 255 before clipping the
# per-pixel perturbation.
parser.add_argument('--max_epsilon', type=float, default=16.0, help='max epsilon')
# Number of tile repetitions along each axis of the perturbation pattern.
parser.add_argument('--tile_size', type=int, default=6, help='tile size')


opt = parser.parse_args()
# Reject values that would otherwise crash later with an unhelpful traceback:
# tile_size is used as a divisor and as an np.tile repeat count, and a
# negative epsilon would give np.clip a lower bound above its upper bound.
if opt.tile_size < 1:
    parser.error('--tile_size must be a positive integer')
if opt.max_epsilon < 0:
    parser.error('--max_epsilon must be non-negative')
# Top-level entries of the input root, in a deterministic order.
dir_name = sorted(os.listdir(opt.input_dir))

# Build the tiled perturbation pattern from the pre-rendered circle patch:
# take a fixed 300x300 crop, shrink it to one tile, repeat the tile
# tile_size x tile_size times, then resize the mosaic to 299x299.
# NOTE(review): the crop assumes the source image is at least 449x449 px.
with Image.open('adversarial patch/circle.png') as patch_src:
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    patch = np.array(patch_src.convert('RGB'), dtype=np.float64)[149:449, 149:449, :]
tile_px = 300 // opt.tile_size  # side length of one tile, in pixels
# Image.ANTIALIAS was removed in Pillow 10; it was an alias of Image.LANCZOS.
tile = Image.fromarray(np.uint8(patch)).resize((tile_px, tile_px), Image.LANCZOS)
mosaic = np.tile(np.array(tile), (opt.tile_size, opt.tile_size, 1))
Image.fromarray(mosaic).resize((299, 299)).save('AdversarialPatch.png')
# Re-read through the project loader so the pattern goes through the same
# load_image path as the clean images in the main loop.
pert = load_image('AdversarialPatch.png')

# Per-pixel perturbation bound, converted from the 0-255 scale to the scale
# used by load_image (the /255 suggests [0, 1] -- TODO confirm).
eps = opt.max_epsilon / 255.0

for class_dir in tqdm(dir_name):
    # os.path.join works whether or not --input_dir ends with a separator.
    src_dir = os.path.join(opt.input_dir, class_dir)
    # Only sub-directories are processed; stray top-level files
    # (e.g. .DS_Store) would otherwise make os.listdir raise.
    if not os.path.isdir(src_dir):
        continue
    save_dir = os.path.join(opt.save_dir, class_dir)
    os.makedirs(save_dir, exist_ok=True)
    for img_name in os.listdir(src_dir):
        img_path = os.path.join(src_dir, img_name)
        # Skip nested directories; load_image expects a file.
        if not os.path.isfile(img_path):
            continue
        image = load_image(img_path)
        # Only the blended image is needed; the frequency components are unused.
        _, _, hybrid_image = gen_hybrid_img(image, pert, opt.cutoff_frequency, opt.weight_factor)
        # Keep the hybrid's change, but bound each pixel's deviation from the
        # clean image to [-eps, eps] (an L-inf constraint).
        # NOTE(review): result is not clipped to the valid pixel range; if
        # hybrid_image can leave that range, save_img must handle it.
        result_img = image + np.clip(hybrid_image - image, -eps, eps)
        save_img(os.path.join(save_dir, img_name), result_img)






