import os
import random
import sys

# comment this out if you are using the pip package
sys.path.append('../')

import torch
import matplotlib.pyplot as plt
from pets.utils import PETS_CLASS_NAMES
from dataset_interface import run_textual_inversion
from dataset_interface import generate
import dataset_interface.imagenet_utils as in_utils
import dataset_interface.inference_utils as infer_utils

# Root directory of the ImageNet dataset; None means it has not been configured.
IMAGENET_ROOT = None

# Directory in which to store an encoder that is loaded with the learned tokens.
encoder_root = "./encoder_root_pets"

# Pet classes to use, as indices into PETS_CLASS_NAMES (range(37) selects 0-36).
classes = range(37)
class_names = [PETS_CLASS_NAMES[c] for c in classes]
# One placeholder token per class of the form "<class name>-<position>".
tokens = [f"<{name}-{i}>" for i, name in enumerate(class_names)]

# Per-class training-image directories: data/fs/<class index>.
train_data_dirs = ['data/fs/{}'.format(c) for c in classes]

# Run identifier; generated images are written under data/gen<code>/<class index>/.
code = '1900'
data_dir = 'data/gen{}'.format(code)
from diffusers import StableDiffusionPipeline
from dataset_interface.templates import imagenet_templates_small
# Stable Diffusion v1.5 in float16 on the GPU, used to synthesize per-class images.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.to('cuda')

# Number of images to keep per class, and number requested per pipeline call.
num_samples_per_class = 2000
images_per_prompt = 8

# For each class, sample prompts from the templates until exactly
# `num_samples_per_class` images not flagged by the safety checker have been
# saved as <data_dir>/<class index>/sample<k>.png, k = 1..num_samples_per_class.
# Directories are named by class id (as in train_data_dirs), which equals the
# position in class_names while `classes` is range(37).
for i, (c, class_name) in enumerate(zip(classes, class_names)):
    print('Generating image for {}-th class: {}'.format(i + 1, class_name))
    class_dir = os.path.join(data_dir, str(c))
    os.makedirs(class_dir, exist_ok=True)
    img_idx = 0
    while img_idx < num_samples_per_class:
        prompt = random.choice(imagenet_templates_small).format(class_name)
        output = pipe(prompt, num_inference_steps=50,
                      num_images_per_prompt=images_per_prompt, guidance_scale=3.5)
        # nsfw_content_detected is None when no safety check was performed;
        # keep every image in that case instead of failing on the subscript.
        nsfw = output.nsfw_content_detected
        if nsfw is None:
            nsfw = [False] * len(output.images)
        # Pair each image with its own flag rather than assuming a fixed batch size.
        for image, flagged in zip(output.images, nsfw):
            if flagged:
                continue  # flagged outputs are skipped and not counted
            image.save(os.path.join(class_dir, 'sample{}.png'.format(img_idx + 1)))
            img_idx += 1
            # Stop exactly at the quota: the final batch may contain more safe
            # images than are still needed, which previously overshot the count.
            if img_idx >= num_samples_per_class:
                break
