from PIL import Image
import torch
from diffusers import DDIMScheduler
from typing import Union, List
from diffusers import StableDiffusionInpaintPipeline

class StableDiffusion:
    """Wrapper around a Stable Diffusion inpainting pipeline.

    The wrapped pipeline is stored in ``self.diffusion_model``. ``attack``
    re-implements the pipeline's denoising loop without ``torch.no_grad`` so
    that the returned image stays connected to the autograd graph of the
    input image.
    """

    # Factor the VAE latents are multiplied by after encoding and divided by
    # before decoding.
    _LATENT_SCALE = 0.18215
    # Pixel-to-latent resolution ratio used when sizing latents and the mask.
    _LATENT_DOWNSCALE = 8

    def __init__(
        self,
        model_link: str = "runwayml/stable-diffusion-inpainting",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        """Load the inpainting pipeline and switch it to a DDIM scheduler.

        Args:
            model_link: Model identifier passed to ``from_pretrained``.
            device: Device the pipeline is moved to.
            torch_dtype: Dtype the pipeline weights are loaded in.
        """
        pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
            model_link,
            torch_dtype=torch_dtype,
        )
        # Replace the default scheduler with DDIM, reusing its configuration.
        pipe_inpaint.scheduler = DDIMScheduler.from_config(
            pipe_inpaint.scheduler.config
        )
        self.diffusion_model = pipe_inpaint.to(device)

    def attack(
        self,
        prompt: Union[str, List[str]],
        masked_image: torch.Tensor,
        mask: torch.Tensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        batch_size: int = 1,
    ) -> torch.Tensor:
        """Differentiable forward pass of the inpainting diffusion model.

        No ``torch.no_grad`` is used here, so gradients can flow from the
        returned image back to ``masked_image``.

        Args:
            prompt: Text prompt(s) guiding the generation.
            masked_image: Batched image tensor fed to the VAE encoder. Its
                batch dimension must equal ``batch_size``.
                NOTE(review): presumably (B, 3, height, width) in [-1, 1] -
                confirm against the caller.
            mask: Batched mask tensor; it is resized to the latent resolution.
                Its batch dimension must equal ``batch_size``.
                NOTE(review): presumably (B, 1, H, W) - confirm.
            height: Image height in pixels (latents use ``height // 8``).
            width: Image width in pixels (latents use ``width // 8``).
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            eta: ``eta`` argument forwarded to the scheduler step.
            batch_size: Number of samples generated.

        Returns:
            The ``sample`` tensor produced by the VAE decoder.
        """
        pipe = self.diffusion_model
        device = pipe.device

        text_embeddings = self.tokenize_prompt(pipe, prompt, batch_size=batch_size)
        dtype = text_embeddings.dtype

        latent_size = (
            height // self._LATENT_DOWNSCALE,
            width // self._LATENT_DOWNSCALE,
        )
        latents_shape = (batch_size, pipe.vae.config.latent_channels, *latent_size)
        latents = torch.randn(latents_shape, device=device, dtype=dtype)

        # `.to` is tracked by autograd, so casting keeps the graph intact while
        # tolerating inputs that are not yet on the model's device/dtype.
        mask = mask.to(device=device, dtype=dtype)
        mask = torch.nn.functional.interpolate(mask, size=latent_size)
        # Duplicated for the unconditional + conditional guidance halves.
        mask = torch.cat([mask] * 2)

        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = pipe.vae.encode(masked_image).latent_dist.sample()
        masked_image_latents = self._LATENT_SCALE * masked_image_latents
        masked_image_latents = torch.cat([masked_image_latents] * 2)

        # The extra UNet input channels do not change across timesteps.
        conditioning = torch.cat([mask, masked_image_latents], dim=1)

        latents = latents * pipe.scheduler.init_noise_sigma

        pipe.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = pipe.scheduler.timesteps.to(device)

        for t in timesteps_tensor:
            latent_model_input = torch.cat(
                [torch.cat([latents] * 2), conditioning], dim=1
            )
            noise_pred = pipe.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample
            # Classifier-free guidance: first half is unconditional, second
            # half is text-conditioned (matching tokenize_prompt's ordering).
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            latents = pipe.scheduler.step(noise_pred, t, latents, eta=eta).prev_sample

        latents = 1 / self._LATENT_SCALE * latents
        image = pipe.vae.decode(latents).sample
        return image

    def tokenize_prompt(self, diffusion_model, prompt, batch_size=1, tokenize_negative=False):
        """Encode prompts into classifier-free-guidance text embeddings.

        Args:
            diffusion_model: Object exposing ``tokenizer``, ``text_encoder``
                and ``device``.
            prompt: A single prompt string or a list of prompt strings. A
                single prompt is repeated ``batch_size`` times.
            batch_size: Number of unconditional prompts to encode.
            tokenize_negative: When True, "gray background" is used as the
                unconditional prompt instead of the empty string.

        Returns:
            Detached tensor with the unconditional embeddings first, followed
            by the prompt embeddings, concatenated along the batch dimension.
        """
        tokenizer = diffusion_model.tokenizer
        device = diffusion_model.device

        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if len(prompts) == 1 and batch_size > 1:
            # Keep the conditional half the same size as the unconditional one.
            prompts = prompts * batch_size

        negative_prompt = "gray background" if tokenize_negative else ""
        uncond_tokens = [negative_prompt] * batch_size

        # The embeddings are returned detached, so no graph is needed here.
        with torch.no_grad():
            text_inputs = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_embeddings = diffusion_model.text_encoder(text_input_ids.to(device))[0]

            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=text_input_ids.shape[-1],
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = diffusion_model.text_encoder(
                uncond_input.input_ids.to(device)
            )[0]

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings.detach()
    