import torchvision.transforms as T
import sys
import os
sys.path.append(os.getcwd())
from utils.utils import load_image, save_image, recover_image, get_train_val_image_prompt_list, get_train_lists

from attack.stable_diffusion import StableDiffusion
from immunization.diffvax import DiffVax

# torchvision converter from a tensor / ndarray to a PIL.Image.
# NOTE(review): not referenced by main() in this file; possibly kept for
# debugging or for modules that import it — confirm before removing.
to_pil = T.ToPILImage()

def main():
    """Immunize a single image with DiffVax and save the result.

    Usage: ``python <script> IMAGE_NAME``

    ``IMAGE_NAME`` (``sys.argv[1]``) is passed unchanged to ``load_image``
    to load both the image and its mask (``is_mask=True``). A
    ``StableDiffusion`` attack model is constructed and wrapped by
    ``DiffVax`` in inference mode with weights loaded from
    ``../checkpoints/diffvax_unetmodel.pth``. The immunized image is written
    to ``../results/<IMAGE_NAME>.png``. Both paths are resolved relative to
    the current working directory.

    Raises:
        SystemExit: with status 2 when no image name is given on the
            command line.
    """
    if len(sys.argv) < 2:
        # Fail fast with a usage message instead of an IndexError traceback,
        # and before the expensive model construction below.
        prog = os.path.basename(sys.argv[0]) if sys.argv else "script"
        print(f"usage: python {prog} IMAGE_NAME", file=sys.stderr)
        raise SystemExit(2)

    image_name = sys.argv[1]
    # NOTE(review): ".png" is always appended, so an IMAGE_NAME that already
    # carries an extension yields e.g. "x.png.png"; kept to preserve the
    # existing output naming.
    output_path = f"../results/{image_name}.png"
    # Create the output directory (including any subdirectory implied by
    # IMAGE_NAME) up front so a bad/unwritable path fails before the long
    # model run rather than at save time.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    attack_model = StableDiffusion()
    immunization_mdl = DiffVax(
        attack_model,
        inference=True,
        load_path="../checkpoints/diffvax_unetmodel.pth",
    )

    image = load_image(image_name)
    image_mask = load_image(image_name, is_mask=True)
    immunized_image = immunization_mdl.immunize_img(image, image_mask)

    save_image(immunized_image, output_path)



# Run the immunization pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()