from datasets import load_dataset

from diffusers import StableDiffusionPipeline
import torch
import os
import json
from math import ceil, sqrt
from PIL import Image
from utils import save_image, concat_images_in_square_grid, get_random_prompt, get_clip_score
import argparse
from sld import SLDPipeline

if __name__ == "__main__":
    output_dir = "diffusers_generated_datasets/i2p_esd_sd_1.4_repitition=1_correctly_classified-num_train_images=382-train_batch_size=1-max_train_steps=382-learning_rate=5.0e-03-sd_v1.4-penalty=0.0_fp16"
    metadata = []

    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)

    data = load_dataset('AIML-TUDA/i2p', split='train')

    print("Number of images: ", len(data))

    # pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16).to(device)
    # pipe.unet.load_state_dict(torch.load("esd_checkpoints/diffusers-nudity-ESDu1-UNET.pt"))
    pipe = StableDiffusionPipeline.from_pretrained("model_editing_attack_ckpt/textual_inversion/esd/i2p_sd_1.4_repitition=1_correctly_classified-num_train_images=382-train_batch_size=1-max_train_steps=382-learning_rate=5.0e-03-sd_v1.4-penalty=0.0_fp16", safety_checker=None, torch_dtype=torch.float16).to(device)

    # pipe = SLDPipeline.from_pretrained(
    #     "model_editing_attack_ckpt/textual_inversion/sld/i2p_sd_1.4_repitition=1-num_train_images=4703-train_batch_size=4-max_train_steps=5000-learning_rate=5.0e-03-sd_v1.4-penalty=0.0", safety_checker=None,
    # ).to(device)   

    # print(pipe.safety_concept)

    count = 0
    for item in data:
        # prompt = [item["prompt"]]
        prompt = ["<nude>" + item["prompt"]]
        seed = item["sd_seed"]

        gen = torch.Generator(device)
        gen.manual_seed(seed)
        # gen = [gen]
        
        out =  pipe(prompt=prompt, generator=gen, negative_prompt=["nudity"])

        # out = pipe(prompt=prompt, generator=gen,
        #    sld_warmup_steps=10,
        #    sld_guidance_scale=1000,
        #    sld_threshold=0.01,
        #    sld_momentum_scale=0.3,
        #    sld_mom_beta=0.4)

        for image in out.images:
            
            file_name = "img_{}.png".format(count)
            image.save(os.path.join(output_dir, "train", file_name))
            count += 1

            metadata.append({"file_name": file_name, "prompt": prompt})

    with open(os.path.join(output_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f)
