import torch
from torch.utils.data import Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os
import argparse
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

def parse_args(argv=None):
    """Parse command-line options for the dataset-generation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            arguments are read from ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``dataset``, ``device``,
        ``batch_size``, ``num_images_per_class``, ``output_dir`` and
        ``model_id``.

    Exits via ``SystemExit`` (argparse behaviour) on invalid arguments,
    e.g. a ``--dataset`` value outside the allowed choices.
    """
    parser = argparse.ArgumentParser(description="Generate images dataset using Stable Diffusion")

    parser.add_argument('--dataset', type=str, choices=['CIFAR10', 'STL10'], default='CIFAR10', help="Dataset to use: CIFAR10 or STL10")
    parser.add_argument('--device', type=str, default="cuda:0", help='Device to use for model inference (default: cuda:0)')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for inference (default: 32)')
    parser.add_argument('--num_images_per_class', type=int, default=5000, help='Number of images to generate per class (default: 5000)')
    parser.add_argument('--output_dir', type=str, default='/data/rzheng/sd_cifar10_50000', help='Root directory to save generated dataset')
    parser.add_argument('--model_id', type=str, default='stabilityai/stable-diffusion-2-1', help='Pretrained model ID for Stable Diffusion (default: stabilityai/stable-diffusion-2-1)')

    # parse_args(None) falls back to sys.argv[1:], preserving CLI behaviour.
    return parser.parse_args(argv)

def _categories_for(dataset):
    """Return the 10 class phrases for *dataset*, ordered by label index."""
    if dataset == "CIFAR10":
        return ["an airplane", "an automobile", "a bird", "a cat", "a deer", "a dog", "a frog", "a horse", "a ship", "a truck"]
    if dataset == "STL10":
        return ["an airplane", "a bird", "a car", "a cat", "a deer", "a dog", "a horse", "a monkey", "a ship", "a truck"]
    raise ValueError(f"Unsupported dataset: {dataset}")


def main():
    """Generate a synthetic, ImageFolder-compatible dataset with Stable Diffusion.

    For every class of the selected dataset, writes
    ``<output_dir>/<label>/image_<i>.png`` for ``i`` in
    ``range(num_images_per_class)``, using the prompt ``"a photo of <class>"``.
    Files that already exist are skipped, so an interrupted run can be resumed.
    Images are generated ``batch_size`` at a time.
    """
    args = parse_args()

    # Fall back to CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # Half precision is only practical on GPU; use float32 on CPU.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Load the Stable Diffusion pipeline with a faster multistep solver.
    pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    categories = _categories_for(args.dataset)
    os.makedirs(args.output_dir, exist_ok=True)

    # Guard against a zero/negative batch size, which would break range().
    batch_size = max(1, args.batch_size)

    for label, category in enumerate(categories):
        # One sub-directory per label index, as torchvision ImageFolder expects.
        label_dir = os.path.join(args.output_dir, str(label))
        os.makedirs(label_dir, exist_ok=True)

        # Only generate images whose files are missing (resume support).
        all_paths = (
            os.path.join(label_dir, f"image_{i}.png")
            for i in range(args.num_images_per_class)
        )
        pending = [path for path in all_paths if not os.path.exists(path)]

        prompt = f"a photo of {category}"
        for start in range(0, len(pending), batch_size):
            chunk = pending[start:start + batch_size]
            images = pipe(prompt=prompt, num_images_per_prompt=len(chunk)).images
            for image_path, image in zip(chunk, images):
                image.save(image_path)
                print(f"Saved image: {image_path}")

    print("Dataset saved in directory structure compatible with ImageFolder.")


if __name__ == "__main__":
    main()
