import os
import json
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
from tqdm import tqdm


# Batch text-to-image generation: render one Stable Diffusion image per
# dataset entry and save it as ``<out_dir>/<id>.png``.

IMG_RESULT_DIR = "LMD_results/SD1-5"
# Placeholder kept as the default so existing usage is unchanged; pass
# ``--data-file`` instead of editing this line.
DEFAULT_DATA_FILE = "FILL OUT THE DATAFILE HERE"
MODEL_ID = "sd-legacy/stable-diffusion-v1-5"


def load_dataset(data_file):
    """Return the list stored under the top-level ``"data"`` key of a JSON file.

    ``main`` reads ``"prompt"`` and ``"id"`` from each entry.

    Raises:
        OSError: if *data_file* cannot be opened.
        KeyError: if the JSON object has no ``"data"`` key.
    """
    with open(data_file, "r", encoding="utf-8") as file:
        return json.load(file)["data"]


def output_image_path(out_dir, item_id):
    """Return the PNG path ``<out_dir>/<item_id>.png``.

    *item_id* is formatted with an f-string, so integer ids work as well as
    strings (plain ``item_id + ".png"`` would raise ``TypeError`` for ints).
    """
    return os.path.join(out_dir, f"{item_id}.png")


def main(argv=None):
    """Generate one image per dataset entry and write it to the output dir.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    import argparse  # only the CLI entry point needs it

    parser = argparse.ArgumentParser(
        description="Render one Stable Diffusion image per dataset prompt."
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help='JSON file whose "data" list holds {"id", "prompt"} entries.',
    )
    parser.add_argument(
        "--out-dir", default=IMG_RESULT_DIR, help="Directory for the PNG files."
    )
    parser.add_argument(
        "--model-id",
        default=MODEL_ID,
        help="Model to load, e.g. stabilityai/stable-diffusion-2-1.",
    )
    parser.add_argument(
        "--dpm-solver",
        action="store_true",
        help="Use DPMSolverMultistepScheduler (DPM-Solver++) instead of the "
        "model's default scheduler.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="If set, seed every prompt with this value for reproducible images.",
    )
    args = parser.parse_args(argv)

    # Load the data before the (slow) model so a bad path fails fast.
    dataset = load_dataset(args.data_file)
    # exist_ok=True already tolerates an existing directory.
    os.makedirs(args.out_dir, exist_ok=True)

    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_id, torch_dtype=torch.float16
    )
    if args.dpm_solver:
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    # The outer tqdm bar reports progress; silence the per-step bars.
    pipe.set_progress_bar_config(disable=True)

    for data in tqdm(dataset):
        generator = None
        if args.seed is not None:
            # Re-seed per prompt so each image is reproducible regardless of
            # its position in the dataset.
            generator = torch.Generator(device="cuda").manual_seed(args.seed)
        image = pipe(data["prompt"], generator=generator).images[0]
        image.save(output_image_path(args.out_dir, data["id"]))


if __name__ == "__main__":
    main()