from datasets import load_dataset

from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, get_random_prompt, get_clip_score
import argparse

if __name__ == "__main__":
    # Generate images for a fixed list of ImageNet class names using Stable
    # Diffusion v1.4 with its UNet weights replaced by a checkpoint loaded from
    # disk. Images are written to <output_dir>/<class name>/img_<k>.png, with k
    # counting from 0 within each class.
    output_dir = "diffusers_generated_datasets/model_editing_attack_imagenet_exp/esd_u_englishspringer_imagenet"
    # NOTE(review): `metadata` is never populated in the visible code; confirm
    # whether it is still needed before removing it.
    metadata = []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Use half precision only on GPU; fall back to float32 on CPU, where fp16
    # inference is unsupported or very slow on many PyTorch builds.
    dtype = torch.float16 if device == "cuda" else torch.float32
    seed = 0

    # Each class gets batches_per_class * images_per_batch images (10 * 50 = 500).
    batches_per_class = 10
    images_per_batch = 50

    os.makedirs(output_dir, exist_ok=True)

    imagenet_subset_classes = ["cassette player", "chain saw", "church", "gas pump", "tench", "garbage truck", "English springer", "golf ball", "parachute", "French horn"]
    # Alternative: restrict generation to a single class.
    # imagenet_subset_classes = ["English springer"]
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=dtype).to(device)
    # Alternative: load a textual-inversion checkpoint instead of the base model.
    # pipe = StableDiffusionPipeline.from_pretrained("model_editing_attack_ckpt/textual_inversion/esd/englishspringer-num_train_images=30-train_batch_size=4-max_train_steps=100-learning_rate=5.0e-03-sd_v1.4-penalty=0.0", safety_checker=None, torch_dtype=torch.float16).to(device)

    unet_checkpoint = "/scratch/km3888/models/compvis-word_Englishspringer-method_noxattn-sg_3-ng_1-iter_1000-lr_1e-05/diffusers-word_Englishspringer-method_noxattn-sg_3-ng_1-iter_1000-lr_1e-05.pt"
    # map_location lets a checkpoint saved on a GPU be loaded on whatever device
    # is available here (including the CPU fallback above).
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True if the installed torch supports it).
    unet_state = torch.load(unet_checkpoint, map_location=device)
    pipe.unet.load_state_dict(unet_state)

    # A single seeded generator is shared across all batches and classes, so the
    # full run is reproducible as a whole (not per class).
    gen = torch.Generator(device)
    gen.manual_seed(seed)

    for obj in imagenet_subset_classes:
        class_dir = os.path.join(output_dir, obj)
        os.makedirs(class_dir, exist_ok=True)

        # The same prompt is repeated to form one batch; it does not change
        # between batches, so build it once per class.
        prompt = ["an image of a " + obj] * images_per_batch
        # Alternative: placeholder-token prompt for textual-inversion checkpoints.
        # prompt = ["an image of a <object>"]*50

        count = 0
        for _ in range(batches_per_class):
            out = pipe(prompt=prompt, generator=gen)

            for img in out.images:
                img.save(os.path.join(class_dir, f"img_{count}.png"))
                count += 1