import os
import torch
import random
from diffusers import SanaPipeline


# Object categories to generate images for. The position of a name in this
# list is the index that `enumerate(class_names)` yields for it.
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Templates that refer to the subject with the indefinite article ("a {}").
_indefinite_templates = [
    'a photo of a {}.',
    'a blurry photo of a {}.',
    'a black and white photo of a {}.',
    'a low contrast photo of a {}.',
    'a high contrast photo of a {}.',
    'a bad photo of a {}.',
    'a good photo of a {}.',
    'a photo of a small {}.',
    'a photo of a big {}.',
]

# The same phrasings, using the definite article ("the {}").
_definite_templates = [
    'a photo of the {}.',
    'a blurry photo of the {}.',
    'a black and white photo of the {}.',
    'a low contrast photo of the {}.',
    'a high contrast photo of the {}.',
    'a bad photo of the {}.',
    'a good photo of the {}.',
    'a photo of the small {}.',
    'a photo of the big {}.',
]

# Full template pool; each entry has a single `{}` placeholder for the label.
prompt_templates = _indefinite_templates + _definite_templates

def convert_prompt(label):
    """Return a text prompt for *label* built from a randomly chosen template.

    One entry of the module-level ``prompt_templates`` list is picked with
    ``random.choice`` and its ``{}`` placeholder is filled with *label*.
    """
    template = random.choice(prompt_templates)
    prompt = template.format(label)
    return prompt

# Root directory; one sub-directory of generated images is created per class.
output_root = "./images/"
os.makedirs(output_root, exist_ok=True)

images_per_class = 4000  # number of images to generate for each class
batch_size = 1  # number of prompts (one image each) sent per pipeline call
GUIDANCE = 4.5  # forwarded to the pipeline as `guidance_scale`

# Load the pipeline a single time. Nothing about it depends on the class or
# the batch, so rebuilding it inside the loops only repeats the checkpoint
# load and the host-to-GPU transfer for every batch.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
pipe.text_encoder.to(torch.bfloat16)

for class_idx, class_prompt in enumerate(class_names):
    # Directory name is "<two-digit class index>_<class name>".
    class_dir = os.path.join(output_root, f"{class_idx:02d}_{class_prompt.replace(' ', '_')}")
    os.makedirs(class_dir, exist_ok=True)

    for i in range(0, images_per_class, batch_size):
        # Clamp the last batch so exactly `images_per_class` images are
        # produced even when `batch_size` does not divide it evenly.
        current_batch = min(batch_size, images_per_class - i)

        # One independently sampled prompt per image. Keeping this a list
        # (not a bare string) is what makes per-image prompt lookup valid
        # when naming the files below.
        prompts = [convert_prompt(class_prompt) for _ in range(current_batch)]

        images = pipe(
            prompt=prompts,
            height=1024,
            width=1024,
            guidance_scale=GUIDANCE,
            num_inference_steps=20,
        ).images

        # File name is "<four-digit image index>_<prompt with underscores>.png".
        for j, (prompt, image) in enumerate(zip(prompts, images)):
            filename = f"{i + j:04d}_{prompt.replace(' ', '_')}.png"
            image.save(os.path.join(class_dir, filename))