import json
import torch
import os
import random
import argparse
from diffusers import SanaPipeline
# Parse CLI options *before* loading the model: `--help` and argument errors then
# return immediately, and `--device` is known when the pipeline is placed.
parser = argparse.ArgumentParser(
    description="Generate per-class images with SANA 1.5 (1.6B, 1024px)."
)
parser.add_argument("--device", type=int, default=0,
                    help="CUDA device index to run the pipeline on.")
parser.add_argument("--start", type=int, default=0,
                    help="First class index to generate (inclusive).")
parser.add_argument("--end", type=int, default=5,
                    help="Class index to stop at (exclusive).")
parser.add_argument("--batch_size", type=int, default=2,
                    help="Number of prompts per pipeline call.")
parser.add_argument("--images_per_class", type=int, default=5,
                    help="Number of images to generate for each class.")
args = parser.parse_args()

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
# Place the model on the GPU selected by --device (so several processes can
# each take a different card) instead of always using the default CUDA device.
pipe.to(f"cuda:{args.device}")
pipe.text_encoder.to(torch.bfloat16)

# NOTE(review): placeholder path — replace with the real class-name JSON file.
# The generation loop indexes this with an int (classnames[i]), so it is
# presumably a JSON list of class-name strings ordered by class id; confirm.
# Decode explicitly as UTF-8 so non-ASCII class names do not depend on the
# platform's locale encoding.
with open("<imagenet100classes path>", encoding="utf-8") as f:
    classnames = json.load(f)

# Caption templates for turning a bare class name into a text prompt; "{}" is
# filled with the class name by str.format. The first nine use the indefinite
# article "a", the last nine repeat the same variants with "the".
prompt_templates = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a black and white photo of a {}.",
    "a low contrast photo of a {}.",
    "a high contrast photo of a {}.",
    "a bad photo of a {}.",
    "a good photo of a {}.",
    "a photo of a small {}.",
    "a photo of a big {}.",
    "a photo of the {}.",
    "a blurry photo of the {}.",
    "a black and white photo of the {}.",
    "a low contrast photo of the {}.",
    "a high contrast photo of the {}.",
    "a bad photo of the {}.",
    "a good photo of the {}.",
    "a photo of the small {}.",
    "a photo of the big {}.",
]

def convert_prompt(prompt):
    """Wrap a class name in a randomly chosen photo-caption template.

    Args:
        prompt: The class name to substitute into the template.

    Returns:
        One entry of ``prompt_templates`` with its ``{}`` placeholder
        replaced by ``prompt``, e.g. ``"a blurry photo of a goldfish."``.
        The template is drawn from the global (unseeded) ``random`` state,
        so repeated calls vary.
    """
    chosen_template = random.choice(prompt_templates)
    return chosen_template.format(prompt)

output_root = "./images/"
os.makedirs(output_root, exist_ok=True)

# For every class index in [start, end), write `images_per_class` PNGs into a
# zero-padded sub-directory (e.g. ./images/07/0000.png, 0001.png, ...).
for i in range(args.start, args.end):
    # NOTE(review): assumes classnames is a list indexed by class id — confirm.
    classname = classnames[i]
    class_dir = os.path.join(output_root, f"{i:02d}")
    os.makedirs(class_dir, exist_ok=True)

    for image_number in range(0, args.images_per_class, args.batch_size):
        # Shrink the final batch so exactly `images_per_class` images are made:
        # 5 images at batch_size=2 -> batches of 2, 2, 1 (not 2, 2, 2 = 6).
        n_images = min(args.batch_size, args.images_per_class - image_number)
        # Each image gets its own randomly chosen template for prompt variety.
        prompts = [convert_prompt(classname) for _ in range(n_images)]
        # Single pipeline call per batch (the result used to be generated twice
        # with the first output discarded, doubling GPU time).
        images = pipe(
            prompt=prompts,
            height=1024,
            width=1024,
            guidance_scale=4.5,
            num_inference_steps=20,
        ).images
        for j, image in enumerate(images):
            image.save(os.path.join(class_dir, f"{image_number + j:04d}.png"))
