from PIL import Image
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.optimization import get_scheduler
from diffusers import StableDiffusionPipeline, DDPMScheduler

# Class-name vocabularies. The generator functions pick ``CLASSES[i % N]`` for
# image ``i``, so list order determines which class each file index shows.
# Names are written whitespace-separated and split into lists of str.
CIFAR_10_CLASSES = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
CIFAR_100_CLASSES = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle
    chair chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster house kangaroo keyboard
    lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree
    plain plate poppy porcupine possum rabbit raccoon ray road rocket
    rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger tractor
    train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()
# Tiny-ImageNet class-description file. "~" is expanded here because the
# built-in open() does not perform home-directory expansion on its own.
TINY_WORDS_PATH = os.path.expanduser("~/Dataset/tiny-imagenet-200/words200.txt")
# Prompt templates; the generator functions substitute the literal "{class}"
# token with a class name via str.replace.
PROMPTS = [
    'a photo of a {class}.',
    'a blurry photo of a {class}.',
    'a black and white photo of a {class}.',
    'a high contrast photo of a {class}.',
    'a good photo of a {class}.',
    'a photo of a small {class}.',
    'a photo of a big {class}.',
]

# Torch device for the diffusion pipelines. Defaults to 'cuda:3'; set the
# GEN_SD_DEVICE environment variable (e.g. 'cuda:0') on machines without that GPU.
device = os.environ.get("GEN_SD_DEVICE", "cuda:3")
def gen_cifar_10(mode='origin', num_images=6000):
    """Generate CIFAR-10-style images with Stable Diffusion v1.4.

    Image ``i`` depicts ``CIFAR_10_CLASSES[i % 10]``. It uses the prompt template
    ``PROMPTS[(i // 10) % len(PROMPTS)]`` and is downscaled to 32x32. It is saved
    as ``<i>.png`` under ``~/Gen_SD/gen_data/cifar-10/origin/``, which is created
    if missing.

    Args:
        mode: Generation mode. Only ``'origin'`` is implemented; any other
            value returns immediately without doing anything.
        num_images: Number of images to generate (default 6000).
    """
    if mode != 'origin':
        return

    # Neither diffusers' from_pretrained nor PIL's Image.save expands "~",
    # so resolve both paths explicitly.
    model_dir = os.path.expanduser("~/models/stable-diffusion-v1-4")
    out_dir = os.path.expanduser('~/Gen_SD/gen_data/cifar-10/origin')
    os.makedirs(out_dir, exist_ok=True)

    # default to use stable-diffusion-v1-4
    pipe = StableDiffusionPipeline.from_pretrained(
        model_dir, safety_checker=None, torch_dtype=torch.float16
    ).to(device)
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    for i in range(num_images):
        # Classes cycle every image; the prompt template changes every 10 images.
        prompt = PROMPTS[(i // 10) % len(PROMPTS)].replace("{class}", CIFAR_10_CLASSES[i % 10])
        print(f"{i}: {prompt}")
        img = pipe(prompt=prompt).images[0]
        img = img.resize((32, 32))
        img.save(os.path.join(out_dir, f'{i}.png'))
        
def gen_cifar_100(mode='origin', num_images=6000):
    """Generate CIFAR-100-style images with Stable Diffusion v1.4.

    Image ``i`` depicts ``CIFAR_100_CLASSES[i % 100]``, with underscores replaced
    by spaces so the prompt reads as natural language. It uses the prompt template
    ``PROMPTS[(i // 100) % len(PROMPTS)]`` and is downscaled to 32x32. It is saved
    once as ``<i>.png`` under ``~/gen_data/cifar-100/<mode>/``, which is created
    if missing.

    Args:
        mode: Generation mode. Only ``'origin'`` is implemented; any other
            value returns immediately without doing anything.
        num_images: Number of images to generate (default 6000).
    """
    if mode != 'origin':
        return

    # Neither diffusers' from_pretrained nor PIL's Image.save expands "~",
    # so resolve both paths explicitly.
    # NOTE(review): gen_cifar_10 writes under ~/Gen_SD/gen_data/ while this
    # writes under ~/gen_data/ — confirm which root is intended.
    model_dir = os.path.expanduser("~/models/stable-diffusion-v1-4")
    out_dir = os.path.expanduser(f'~/gen_data/cifar-100/{mode}')
    os.makedirs(out_dir, exist_ok=True)

    # default to use stable-diffusion-v1-4
    pipe = StableDiffusionPipeline.from_pretrained(
        model_dir, safety_checker=None, torch_dtype=torch.float16
    ).to(device)
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    for i in range(num_images):
        # Label ids like "maple_tree" become "maple tree" in the prompt text.
        class_name = CIFAR_100_CLASSES[i % len(CIFAR_100_CLASSES)].replace('_', ' ')
        # The prompt template changes every 100 images (one pass over all classes).
        prompt = PROMPTS[(i // 100) % len(PROMPTS)].replace("{class}", class_name)
        img = pipe(prompt=prompt).images[0]
        img = img.resize((32, 32))
        img.save(os.path.join(out_dir, f'{i}.png'))
    

def _load_tiny_imagenet_classes(words_path):
    """Return one short class name per tab-separated line of ``words_path``.

    Each line is split on a tab. The second field (presumably the synonym list
    following a WordNet id — confirm against the file) is cut at its first comma
    and stripped. Lines without a tab are skipped. A leading "~" in
    ``words_path`` is expanded.
    """
    classes = []
    with open(os.path.expanduser(words_path), 'r', encoding='utf-8') as file:
        for line in file:
            parts = line.split('\t')
            if len(parts) > 1:
                # split(',')[0] yields the whole field when there is no comma,
                # so one expression covers both the single- and multi-synonym cases.
                classes.append(parts[1].split(',')[0].strip())
    return classes


def gen_tiny_imagenet(mode='origin'):
    """Load the Tiny-ImageNet class names used for prompts and print them.

    When ``mode == 'origin'`` a Stable Diffusion v1.4 pipeline is also loaded
    onto ``device``. The generation loop that would use it is commented out
    below, so currently this function only prints the class-name list.
    """
    if mode == 'origin':
        # NOTE(review): `pipe` is unused while the generation loop below is
        # commented out; loading it only costs time and GPU memory.
        pipe = StableDiffusionPipeline.from_pretrained(
            os.path.expanduser("~/models/stable-diffusion-v1-4"),
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to(device)
        pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    TINY_IMAGENET_CLASSES = _load_tiny_imagenet_classes(TINY_WORDS_PATH)
    print(TINY_IMAGENET_CLASSES)
    # NOTE(review): before re-enabling this loop:
    #   - `pipe` is only created for mode == 'origin', so the 'flux' branch would
    #     raise NameError.
    #   - The "~" save paths must go through os.path.expanduser, and the output
    #     directories must exist, because PIL's Image.save does neither.
    # for i in range(7000):
    #     prompt = PROMPTS[(i // 200 % len(PROMPTS))].replace("{class}", TINY_IMAGENET_CLASSES[i % 200])
    #     if mode == 'origin':
    #         img = pipe(prompt=prompt).images[0]
    #     elif mode == 'flux':
    #         img = pipe(prompt=prompt, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256).images[0]
    #     img = img.resize((64, 64))
    #     img.save(f'~/Gen_SD/gen_data/tiny-imagenet/{mode}/{i}.png')
    #     img.save(f'~/FedDM_v4/gen_data/tiny-imagenet/{mode}/{i}.png')

if __name__ == "__main__":
    # Script entry point: only the Tiny-ImageNet generator is invoked, in 'origin' mode.
    gen_tiny_imagenet('origin')
