import torch

from PIL import Image
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline

# --- Pipeline setup ---------------------------------------------------------
# Hub identifier of the Stable Diffusion checkpoint backing the DiffEdit pipeline.
sd_model_ckpt = "runwayml/stable-diffusion-v1-5"

# Load the weights in half precision and without a safety checker.
_load_options = {"torch_dtype": torch.float16, "safety_checker": None}
pipeline = StableDiffusionDiffEditPipeline.from_pretrained(sd_model_ckpt, **_load_options)

# Swap in a DDIM scheduler built from the config of the scheduler that was loaded.
_loaded_scheduler_config = pipeline.scheduler.config
pipeline.scheduler = DDIMScheduler.from_config(_loaded_scheduler_config)

# The inverse scheduler is derived from the config of the scheduler installed
# just above, so this statement has to come after the swap.
_ddim_scheduler_config = pipeline.scheduler.config
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(_ddim_scheduler_config)

# Optional pipeline features; per the diffusers docs both reduce peak memory use.
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# Seed torch's default RNG; the returned generator is handed to every pipeline
# call further down in this script.
generator = torch.manual_seed(0)


# --- Inputs and mask generation ---------------------------------------------
# Prompt describing the input image, and prompt describing the desired edit.
source_prompt = "a bowl of fruits"
target_prompt = "a bowl of bananas"

# Read the input picture, force three-channel RGB and scale it to 512x512.
# NOTE(review): the directory is spelled 'asserts' — possibly meant 'assets'; confirm.
_opened_image = Image.open('asserts/origin.png')
_rgb_image = _opened_image.convert("RGB")
raw_image = _rgb_image.resize((512, 512))

# Ask the pipeline for an edit mask from the image and the two prompts.
_mask_arguments = {
    "image": raw_image,
    "source_prompt": source_prompt,
    "target_prompt": target_prompt,
    "generator": generator,
}
mask_image = pipeline.generate_mask(**_mask_arguments)


# --- Inversion, masked edit and output --------------------------------------
# Invert the input image under the source prompt and keep the resulting latents.
_inversion_output = pipeline.invert(
    image=raw_image,
    prompt=source_prompt,
    generator=generator,
)
inv_latents = _inversion_output.latents

# Run the edit: condition on the target prompt, pass the source prompt as the
# negative prompt, and supply the mask together with the inverted latents.
_edit_output = pipeline(
    prompt=target_prompt,
    negative_prompt=source_prompt,
    mask_image=mask_image,
    image_latents=inv_latents,
    generator=generator,
)

# Take the first returned image and write it to disk.
image = _edit_output.images[0]
image.save("ode-diffedit.png")