import torch

from PIL import Image
from tqdm.auto import tqdm
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline

from sde_inversion import DPMSolverOrder1, load_model, get_img_latent, get_text_embed
from torchvision.utils import make_grid, save_image

# Build the DiffEdit pipeline in half precision with the safety checker
# disabled; this script uses it to generate the edit mask.
sd_model_ckpt = "runwayml/stable-diffusion-v1-5"
pipeline_kwargs = {"torch_dtype": torch.float16, "safety_checker": None}
pipeline = StableDiffusionDiffEditPipeline.from_pretrained(sd_model_ckpt, **pipeline_kwargs)

# Swap in a DDIM scheduler first; the inverse scheduler is then built from the
# config of that freshly installed DDIM scheduler, so the order matters.
ddim_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = ddim_scheduler
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)

# Memory-saving switches exposed by the pipeline.
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# Seed torch's global RNG; the returned generator is passed to generate_mask.
generator = torch.manual_seed(0)

# Load the image to edit. Opening it inside a context manager releases the
# file handle immediately, and converting to RGB drops any alpha channel or
# palette so the pipeline always receives a 3-channel image. `convert` returns
# a fully loaded copy, so `raw_image` stays usable after the file is closed.
# NOTE(review): the directory is spelled 'asserts' — confirm it matches the
# repository layout.
with Image.open('asserts/origin.png') as image_file:
    raw_image = image_file.convert("RGB")

# Prompts describing the image as it is and as it should look after the edit.
source_prompt = "a bowl of fruits"
target_prompt = "a bowl of bananas"

# Ask the DiffEdit pipeline for the mask contrasting the two prompts. The
# result is thresholded and turned into a tensor later in this script.
mask_image = pipeline.generate_mask(
    image=raw_image,
    source_prompt=source_prompt,
    target_prompt=target_prompt,
    generator=generator,
)


# Sampling configuration. `max_t` starts out as the fraction of the schedule
# to traverse and is then converted into a whole number of steps.
num_steps, max_t, guidance_scale = 200, 0.7, 2.

max_t = int(num_steps * max_t)
# Offset into the timestep schedule where the forward/backward loops start.
t_begin = num_steps - max_t

# Load the model components and wrap the UNet in the first-order sampler.
vae, tokenizer, text_encoder, unet, scheduler = load_model()
sampler = DPMSolverOrder1(model=unet, scheduler=scheduler, num_steps=num_steps)

# Encode the prompts in the same order as before: source, target, then the
# empty prompt.
source_cond, target_cond, uncond_embeddings = (
    get_text_embed([prompt], tokenizer, text_encoder)
    for prompt in (source_prompt, target_prompt, "")
)

# Stack each conditional embedding behind the empty-prompt embedding.
# NOTE(review): presumably the [uncond, cond] layout the sampler expects for
# classifier-free guidance — confirm against sde_inversion.
text_embeddings_origin = torch.cat([uncond_embeddings, source_cond])
text_embeddings_edit = torch.cat([uncond_embeddings, target_cond])

# Latent representation of the image being edited.
latents = get_img_latent('asserts/origin.png', vae)


# Forward pass: walk the tail of the timestep schedule in reverse order,
# recording the noise returned at every step and every intermediate latent
# (starting with the initial one) so the backward pass can replay them.
noises = []
imgs = [latents]
forward_timesteps = scheduler.timesteps[t_begin:].flip(dims=[0])
for t in tqdm(forward_timesteps, desc="Forward"):
    latents, noise = sampler.forward_sde(
        t,
        latents,
        guidance_scale,
        text_embeddings_origin,
    )
    noises += [noise]
    imgs += [latents]

# Binarise the mask in place: values below 0.5 become 0, everything else 1.
mask_image[mask_image < 0.5] = 0
mask_image[mask_image >= 0.5] = 1
# Place the mask on the same device as the latents it is blended with, rather
# than hard-coding 'cuda', so the script follows wherever the model was loaded.
mask_image = torch.from_numpy(mask_image).to(latents.device)
# Drop the most recently recorded latent: it is the current `latents`, which
# the backward loop starts from and therefore never needs to blend back in.
imgs.pop()


# Backward pass: sample with the target-prompt embeddings, replaying the
# recorded noises in reverse (last recorded first), and after each step blend
# the matching recorded latent back in wherever the mask is 0.
backward_timesteps = scheduler.timesteps[t_begin - 1:-1]
for t in tqdm(backward_timesteps, desc="Backward"):
    step_noise = noises.pop()
    latents = sampler.sample(
        t,
        latents,
        guidance_scale,
        text_embeddings_edit,
        sde=True,
        noise=step_noise,
    )
    image_latents = imgs.pop()
    keep_original = 1 - mask_image
    latents = latents * mask_image + image_latents * keep_original

# Rescale the latents before decoding.
# NOTE(review): 0.18215 is presumably the VAE latent scaling factor — confirm
# it matches the VAE returned by load_model.
inv_latent_scale = 1 / 0.18215
latents = latents * inv_latent_scale

# Decode without tracking gradients.
with torch.no_grad():
    image = vae.decode(latents).sample

# Map the decoded values through x / 2 + 0.5 and clip to [0, 1], then write
# the result out as an image grid.
image = image.div(2).add(0.5).clamp(0, 1)
image = make_grid(image, nrow=5)
save_image(image, 'sde-diffedit.png')