from utils.utils import *
import torch
import os
import argparse

# Command-line options, declared as (flag, value type, default) triples so the
# whole CLI surface is visible in one table.
_CLI_OPTIONS = (
    ('--target_concept', str, "elon musk"),
    # " " (a single space) means "not art"; generate_data only checks for "art".
    ('--concept_type', str, " "),
    ('--save_folder', str, "./data"),
    ('--num_images', int, 30),
)

parser = argparse.ArgumentParser()
for _flag, _value_type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, default=_default)

# NOTE(review): parsing happens at import time, so importing this module reads
# sys.argv. Consider moving this under the __main__ guard.
args = parser.parse_args()

def generate_data(args):
    diffuser = StableDiffuser(scheduler='DDIM').to('cuda:0')
    seed = 42
    print(seed)

    if args.concept_type == "art":
        prompt = f"a painting in the style of {args.erase_concept}"
    else:
        prompt = f"a photo of {args.erase_concept}"
        
    save_folder = os.path.join(args.save_folder, args.target_concept.replace(" ", "").lower())

    os.makedirs(save_folder, exist_ok=True)
    num_images = args.num_images
    generator = torch.Generator().manual_seed(seed)
    
    for i in range(num_images):
        images = diffuser(prompt,
                    img_size=512,
                    n_steps=50,
                    n_imgs=1,
                    generator=generator,
                    guidance_scale=7.5
                    )
        images[0][0].save(f"{save_folder}/{i}.png")
    
        
if __name__ == "__main__":
    # Script entry point: run generation with the module-level parsed CLI args.
    generate_data(args=args)

