import torchvision.transforms as T
import sys
import os
sys.path.append(os.getcwd())
from utils.utils import get_train_val_image_prompt_list, get_train_lists

from attack.stable_diffusion import StableDiffusion
from immunization.diffvax import DiffVax

# torchvision transform that converts a tensor (or ndarray) into a PIL image.
# NOTE(review): not referenced anywhere in the visible part of this script —
# possibly a leftover from debugging/visualisation, or imported by another
# module; confirm before removing.
to_pil = T.ToPILImage()

def main(*, alpha=4, iter_num=2000, batch_size=5):
    """Train a DiffVax immunization model against the Stable Diffusion attack model.

    Builds the ``StableDiffusion`` attack model, wraps it in ``DiffVax``, loads
    the training split via ``get_train_val_image_prompt_list`` /
    ``get_train_lists`` and runs ``DiffVax.train_immunization`` on it.

    The script appends the current working directory to ``sys.path`` at import
    time, so it is expected to be launched from the repository root.

    Args:
        alpha: Forwarded unchanged as ``alpha`` to ``train_immunization``.
            NOTE(review): presumably a perturbation budget / step scale —
            confirm against ``DiffVax.train_immunization``.
        iter_num: Forwarded as ``iter_num`` (presumably the number of training
            iterations).
        batch_size: Forwarded as ``batch_size`` to ``train_immunization``.

    The defaults reproduce the original hard-coded configuration, so calling
    ``main()`` with no arguments behaves exactly as before.
    """
    attack_model = StableDiffusion()
    immunization_mdl = DiffVax(attack_model)

    # Only the training split is consumed here; the validation split is unused.
    train_list, _val_list = get_train_val_image_prompt_list()
    image_torch_list, mask_torch_list, prompt_train_list = get_train_lists(train_list)

    immunization_mdl.train_immunization(
        image_torch_list,
        mask_torch_list,
        prompt_train_list,
        alpha=alpha,
        iter_num=iter_num,
        batch_size=batch_size,
    )


if __name__ == "__main__":
    main()