import argparse
import csv
import os
from decimal import Decimal

import torch
from PIL import Image

from myfunc import get_text_input
from User_model import OptimizedImageModifier

def parse_args():
    """Parse and validate the command-line options for the experiment.

    Returns:
        argparse.Namespace with one attribute per option below.

    Exits (via ``parser.error``, status 2) when the batch selection is
    inconsistent: ``--batch_num`` must be at least 1 and ``--batch_id`` must
    lie in ``1..batch_num``. Without this check an out-of-range id either
    divides by zero or selects images with negative indices, and that failure
    only surfaces after the model has been loaded.
    """
    parser = argparse.ArgumentParser(description='Generate adversarial examples')
    parser.add_argument('--image_folder', required=True, help='The folder containing the images used in the experiment')
    parser.add_argument('--model', default='Salesforce/blip2-opt-2.7b', help='the model used')
    parser.add_argument('--start_slice_num', type=int, default=2, help='The starting token position of the slice extracted from the adversarial prompt')
    parser.add_argument('--length_slice', type=int, default=2, help='The token length of the extracted slice')
    parser.add_argument('--epsilon', type=float, default=0.032, help='the perturbation magnitude')
    parser.add_argument('--lr', type=float, default=0.03, help='the step size')
    parser.add_argument('--alpha', type=float, default=1.0, help='Coefficient of the standard deviation term')
    parser.add_argument('--num_steps', type=int, default=5000, help='iteration steps')
    parser.add_argument('--mu', type=float, default=0.9, help='the momentum decay factor')
    parser.add_argument('--batch_num', type=int, default=2, help='the total number of batches')
    parser.add_argument('--batch_id', type=int, default=1, help='batch id')
    parser.add_argument('--gpu_id', default='0', help='the GPU used')
    args = parser.parse_args()

    # Batches are numbered from 1; reject selections that cannot name a batch.
    if args.batch_num < 1:
        parser.error("--batch_num must be at least 1")
    if not 1 <= args.batch_id <= args.batch_num:
        parser.error(f"--batch_id must be between 1 and --batch_num ({args.batch_num})")
    return args

# Parse the command line once and expose each option as a module-level name;
# the remainder of the script reads these globals directly.
args = parse_args()
device = torch.device("cuda:" + args.gpu_id)
images_folder_path = args.image_folder + "/"
model_name = args.model
start_slice_num = args.start_slice_num
length_slice = args.length_slice

# Optimisation hyper-parameters. argparse already converts these with
# type=float / type=int, so no further casting is needed here.
EPSILON = args.epsilon
lr = args.lr
alpha = args.alpha
num_steps = args.num_steps
mu = args.mu

# Which slice of the (sorted) image folder this process handles.
batch_num = args.batch_num
batch_id = args.batch_id

# Attack driver from the project-local User_model module; every option is
# forwarded unchanged from the command line.
USER = OptimizedImageModifier(
    model_name=model_name,
    device=device,
    start_slice_num=start_slice_num,
    length_slice=length_slice,
    EPSILON=EPSILON,
    lr=lr,
    alpha=alpha,
    num_steps=num_steps,
    mu=mu,
)

# Load the adversarial prompt recorded by an earlier attack run. Its filename
# is keyed by the model type and the slice parameters.
adv_prompt_record_path = f"attack_{USER.model_type}_{start_slice_num}_{length_slice}_record_adv_prompt.txt"
with open(adv_prompt_record_path, "r", encoding="utf-8") as f:
    replacement_text = f.read()

text_input = get_text_input(USER)

# Single dry run on demo.jpg with the prompt hook installed; it reports the
# token length returned by USER.generate_text. The hook is removed in
# `finally` so a failure during generation does not leave it registered.
USER.register_hook(replacement_text=replacement_text)
try:
    # Use a context manager so PIL's handle on demo.jpg is closed; convert()
    # returns an independent, fully loaded copy.
    with Image.open("demo.jpg") as demo_image:
        image = demo_image.convert("RGB").resize((336, 336))
    inputs = USER.processor(images=image, text=text_input, return_tensors="pt").to(USER.device)
    generated_text, token_length = USER.generate_text(inputs=inputs, max_length=1024, do_sample=False)
    print("token length:", token_length)
finally:
    USER.remove_hook()
print("##################################################################################")

def _float_parts(value):
    """Split *value* into its integer and fractional digit strings.

    The number is rendered in positional notation first. Values whose
    ``str()`` uses scientific notation (e.g. ``1e-05``) therefore still yield
    digits instead of raising ``IndexError``. For ordinary values the result
    matches ``str(value).split('.')``, e.g. ``0.032 -> ("0", "032")``. A value
    with no fractional digits yields an empty fractional part.
    """
    int_part, _, frac_part = format(Decimal(str(value)), "f").partition(".")
    return int_part, frac_part


if batch_num < 1 or not 1 <= batch_id <= batch_num:
    # An out-of-range selection would divide by zero or pick images with
    # negative indices, silently processing the wrong files.
    raise SystemExit(f"invalid batch selection: batch_id={batch_id}, batch_num={batch_num}")

# Sorted so that independent runs with different batch_id values partition
# the folder identically.
images_list = sorted(os.listdir(images_folder_path))
images_num = len(images_list)
# Integer ceiling division: every batch but the last has batch_size images.
batch_size = (images_num + batch_num - 1) // batch_num
batch_start = batch_size * (batch_id - 1)
batch_end = min(batch_size * batch_id, images_num)

# Encode the hyper-parameters in the output names. EPSILON and lr keep only
# their fractional digits; alpha and mu keep both parts joined by "_".
EPSILON_decimal = _float_parts(EPSILON)[1]
lr_decimal = _float_parts(lr)[1]
alpha_decimal = "_".join(_float_parts(alpha))
mu_decimal = "_".join(_float_parts(mu))
path_part = (
    f"_{USER.model_type}_{start_slice_num}_{length_slice}"
    f"_clamp_{EPSILON_decimal}_lr_{lr_decimal}"
    f"_alpha_{alpha_decimal}_{num_steps}_momentum_{mu_decimal}"
)
save_folder_path = "exp" + path_part

print("save folder path:", save_folder_path)

# exist_ok avoids a FileExistsError race when several batch processes with
# the same settings start at the same time.
os.makedirs(save_folder_path, exist_ok=True)

logfile_path = os.path.join(save_folder_path, f"{USER.model_type}_{batch_id}.csv")
with open(logfile_path, mode='w', newline='', encoding='utf-8') as logfile:
    writer = csv.writer(logfile)
    for image_name in images_list[batch_start:batch_end]:
        image_path = os.path.join(images_folder_path, image_name)
        # Replace the file extension (whatever it is) with the parameter suffix.
        save_path = os.path.join(save_folder_path, os.path.splitext(image_name)[0] + path_part)

        log_data = USER.optimize_image(image_path=image_path, text_input=text_input)
        writer.writerow(log_data)
        # Each image takes many optimisation steps; flush so completed rows
        # survive if a later image crashes the run.
        logfile.flush()
        USER.save_optimized_image(save_path=save_path)
