# Supports UNet-based Stable Diffusion models, such as SD 1.5 and SD 2.1.
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DPMSolverMultistepScheduler
import torch


class SD_Unet:
    """Wrapper around a UNet-based diffusers ``StableDiffusionPipeline``.

    Offers two operations:

    * ``generate``: text-conditioned sampling from caller-supplied initial latents,
      using a ``DPMSolverMultistepScheduler``.
    * ``invert``: VAE-encode an image, then run the pipeline with a
      ``DDIMInverseScheduler`` to recover initial latents.

    Both schedulers are built from the loaded checkpoint's scheduler config.
    Each entry point installs the scheduler it needs before running. Results
    therefore do not depend on which scheduler a previous call left on the pipeline.
    """

    def __init__(self, model_path, device, gen_guidance_scale=7.5, inv_guidance_scale=1.0, gen_steps=50, inv_steps=50,
                 torch_dtype=torch.float32):
        """Load the pipeline and build the generation/inversion schedulers.

        Args:
            model_path: Hub id or local path passed to
                ``StableDiffusionPipeline.from_pretrained``.
            device: Device the pipeline is moved to.
            gen_guidance_scale: Default guidance scale used by ``generate``.
            inv_guidance_scale: Default guidance scale used by ``invert``.
            gen_steps: ``num_inference_steps`` used by ``generate``.
            inv_steps: ``num_inference_steps`` used by ``invert``.
            torch_dtype: Dtype the weights are loaded in (defaults to float32,
                matching the previous hard-coded behavior).
        """
        self.pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch_dtype)

        # Both schedulers are derived from the checkpoint's own scheduler config.
        self.gen_scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.inv_scheduler = DDIMInverseScheduler.from_config(self.pipe.scheduler.config)

        self.pipe.safety_checker = None
        self.pipe = self.pipe.to(device)
        self.device = device
        self.gen_guidance_scale = gen_guidance_scale
        self.inv_guidance_scale = inv_guidance_scale
        self.gen_steps = gen_steps
        self.inv_steps = inv_steps

    def generate(self, latents, prompt, guidance_scale=None):
        """Generate a single image from initial ``latents`` conditioned on ``prompt``.

        Args:
            latents: Initial latents forwarded to the pipeline's ``latents`` argument.
            prompt: Text prompt (wrapped in a one-element list for the pipeline).
            guidance_scale: Overrides ``self.gen_guidance_scale`` when not None.

        Returns:
            The first entry of the pipeline output's ``.images``.
        """
        if guidance_scale is None:
            guidance_scale = self.gen_guidance_scale
        # Install the generation scheduler explicitly. Otherwise the pipeline would use
        # whatever was last set: the checkpoint default, or the inverse scheduler after invert().
        self.set_gen_scheduler()
        with torch.no_grad():
            images = self.pipe([prompt], latents=latents, guidance_scale=guidance_scale,
                               num_inference_steps=self.gen_steps).images
        return images[0]

    def vae_encode(self, image_pt):
        """Encode an image tensor into scaled VAE latents.

        The input is moved/cast to the VAE's device and dtype. It is then mapped via
        ``2 * x - 1``, which takes [0, 1] to [-1, 1]. NOTE(review): this presumes
        pixel values in [0, 1] and an image batch layout the VAE accepts — confirm at callers.
        The latent is drawn with ``latent_dist.sample()``, so the result is stochastic.
        It is then multiplied by the VAE's ``scaling_factor``.
        """
        vae = self.pipe.vae
        # No-op when already matching; avoids device/dtype mismatch errors otherwise.
        image_pt = image_pt.to(device=vae.device, dtype=vae.dtype)
        image_pt = 2 * image_pt - 1
        posterior_sample = vae.encode(image_pt).latent_dist.sample()
        return posterior_sample * vae.config.scaling_factor

    def invert(self, image_pt, guidance_scale=None):
        """Invert an image back to initial latents with the DDIM inverse scheduler.

        Args:
            image_pt: Image tensor accepted by ``vae_encode``.
            guidance_scale: Overrides ``self.inv_guidance_scale`` when not None.

        Returns:
            The latents produced by running the pipeline (empty prompt,
            ``output_type='latent'``) under the inverse scheduler.
        """
        if guidance_scale is None:
            guidance_scale = self.inv_guidance_scale
        # Without the inverse scheduler, the pipeline would denoise instead of invert,
        # silently returning meaningless latents.
        self.set_inv_scheduler()
        with torch.no_grad():
            latents = self.vae_encode(image_pt)
            inv_latents, _ = self.pipe(prompt='', latents=latents, guidance_scale=guidance_scale,
                                       output_type='latent', return_dict=False,
                                       num_inference_steps=self.inv_steps)
        return inv_latents

    def set_gen_scheduler(self):
        """Install the generation (DPM-Solver multistep) scheduler on the pipeline."""
        self.pipe.scheduler = self.gen_scheduler

    def set_inv_scheduler(self):
        """Install the DDIM inverse scheduler on the pipeline."""
        self.pipe.scheduler = self.inv_scheduler
