import os

import torch
from diffusers import StableDiffusionPipeline

# Model repository id handed to StableDiffusionPipeline.from_pretrained.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the pipeline with float16 weights and move it to `device` in one step.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
).to(device)

# Optional fine-tuned UNet checkpoint; None means sample from the pretrained
# weights unchanged. A previously used value was:
#   checkpoint_CBS/coco_trigger_caption_300_unique_portion_0.5_acc_1_batch_14_lr_5e-06/checkpoint-latest/caption_unet.pyt
model_ckpt = None

if model_ckpt is not None:
    # NOTE(review): torch.load unpickles arbitrary objects -- only point
    # model_ckpt at checkpoint files you trust.
    # Tensors are mapped to CPU first; load_state_dict then copies them into
    # the UNet that already lives on `device`.
    checkpoint = torch.load(model_ckpt, map_location='cpu')
    # The checkpoint is indexed as a dict whose 'unet' entry is the UNet state dict.
    unet_state = checkpoint['unet']
    pipe.unet.load_state_dict(unet_state)
    print("Load Unet from checkpoint successfully")

# Text prompts to render; each one produces a single saved image.
prompt_list = ['a subway train is pulling into a station',
               'A bunch of white flowers is in a glass vase',
               'An old truck with a tarp on the bed parked in a parking lot',
               'A baby elephant walking close to its mother',]

# Output directory for generated images. Create it up front so image.save()
# does not fail with FileNotFoundError when the folder is missing.
output_dir = "graph_folder"
os.makedirs(output_dir, exist_ok=True)

# Filename prefix: 'ckpt' when a fine-tuned UNet checkpoint was requested,
# 'pre' for the pretrained weights. Invariant across prompts, so compute once.
name_prefix = 'ckpt' if model_ckpt is not None else 'pre'

for prompt in prompt_list:
    # Re-seed for every prompt so each prompt starts from the same random
    # stream, keeping runs reproducible and comparable across checkpoints.
    # Use the same device the pipeline was moved to rather than hard-coding it.
    generator = torch.Generator(device=device).manual_seed(627)
    image = pipe(prompt, num_inference_steps=100, guidance_scale=3.5,
                 generator=generator).images[0]
    # The prompt is used verbatim as part of the filename; replace path
    # separators so a prompt containing '/' cannot break or redirect the path.
    safe_prompt = prompt.replace(os.sep, "_").replace("/", "_")
    image.save(os.path.join(output_dir,
                            "{}_{}.png".format(name_prefix, safe_prompt)))