import os
import random
import torch
from utils import CLASS_NAME_SUN397, CLASS_FOR_PROMPT

import sys
# Add the parent directory to the module search path so packages living one
# level up can be imported (presumably `dataset_interface`, imported further
# down in this script — TODO confirm).
# NOTE(review): a relative entry like '../' is resolved against the current
# working directory, not this file's location. Also, `utils` is imported above
# this line, so it does not benefit from this path entry — confirm launch dir.
sys.path.append('../')

# Dataset 1/4: generate images from the class label alone.
# For every SUN397 class index in `classes`, repeatedly prompt Stable Diffusion
# v1.5 with a randomly chosen template filled with the class name, and keep
# images the pipeline's safety checker did not flag, until IMAGES_PER_CLASS
# images have been written to <data_dir>/<class name>/sample<k>.png
# (k = 1 .. IMAGES_PER_CLASS).
data_dir = 'data/gen1000'
os.makedirs(data_dir, exist_ok=True)
from diffusers import StableDiffusionPipeline
from dataset_interface.templates import imagenet_templates_small

IMAGES_PER_CLASS = 1000  # number of kept (non-flagged) images per class
BATCH_SIZE = 8           # images requested per pipeline call

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to('cuda')
# Only class indices 297..396 are generated here; presumably earlier classes
# were produced by a separate run — adjust the range as needed.
classes = range(297, 397)
for i in classes:
    class_name = CLASS_NAME_SUN397[i]
    print('Generating image for {}-th class: {}'.format(i+1, class_name))
    class_dir = os.path.join(data_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)
    img_idx = 0
    while img_idx < IMAGES_PER_CLASS:
        # A fresh random template per batch diversifies the prompts.
        prompt = random.choice(imagenet_templates_small).format(CLASS_FOR_PROMPT[i])
        output = pipe(prompt, num_inference_steps=50, num_images_per_prompt=BATCH_SIZE, guidance_scale=3.5)
        images = output.images
        # nsfw_content_detected may be None (e.g. when the pipeline runs without
        # a safety checker); treat that as "nothing flagged" instead of crashing.
        nsfw = output.nsfw_content_detected or [False] * len(images)
        for image, flagged in zip(images, nsfw):
            if flagged:
                continue  # skip images the safety checker flagged
            if img_idx >= IMAGES_PER_CLASS:
                # Stop mid-batch so the class never exceeds the target count.
                break
            image.save(os.path.join(class_dir, 'sample{}.png'.format(img_idx + 1)))
            img_idx += 1

