from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import UniPCMultistepScheduler
import copy

class DiffusionPipeline:
    """Minimal text-to-image pipeline built from Stable Diffusion v1-4 parts.

    The class loads the VAE, tokenizer, text encoder, UNet and a UniPC
    multistep scheduler, and exposes the individual stages of generation
    (text embedding, noise sampling, denoising, decoding) as separate
    methods so that callers can intervene between them.

    Typical usage::

        pipe = DiffusionPipeline(["a photo of a cat"])
        pipe.create_text_embeddings()
        latents = pipe.denoise_image(pipe.create_random_noise())
        images = pipe.img_to_pil(pipe.decode(latents))
    """

    # Hub repository every sub-model is loaded from.
    MODEL_ID = "CompVis/stable-diffusion-v1-4"
    # Factor the latents are divided by before they are handed to the VAE.
    VAE_SCALING_FACTOR = 0.18215
    # Ratio between image resolution and latent resolution used below.
    LATENT_DOWNSCALE = 8

    def __init__(self, prompt):
        """Load all sub-models and store the generation settings.

        Args:
            prompt: A single prompt string or a sequence of prompt strings.
                One image is generated per prompt.

        Raises:
            ValueError: If ``prompt`` is an empty sequence.
        """
        # A bare string must count as ONE prompt; len() of a string would
        # otherwise yield the number of characters as the batch size.
        if isinstance(prompt, str):
            prompt = [prompt]
        else:
            prompt = list(prompt)
        if not prompt:
            raise ValueError("prompt must contain at least one entry")

        self.vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="vae")
        self.tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="text_encoder")
        self.unet = UNet2DConditionModel.from_pretrained(self.MODEL_ID, subfolder="unet")

        self.scheduler = UniPCMultistepScheduler.from_pretrained(self.MODEL_ID, subfolder="scheduler")

        # Prefer the GPU, but stay usable on machines without CUDA.
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vae.to(self.torch_device)
        self.text_encoder.to(self.torch_device)
        self.unet.to(self.torch_device)

        self.prompt = prompt
        self.height = 512  # default height of Stable Diffusion
        self.width = 512  # default width of Stable Diffusion
        self.num_inference_steps = 25  # Number of denoising steps
        self.guidance_scale = 7.5  # Scale for classifier-free guidance
        self.generator = torch.manual_seed(0)  # Seed generator to create the initial latent noise
        self.batch_size = len(self.prompt)  # Number of images to generate

    def create_text_embeddings(self):
        """Encode the prompts and the empty prompt for classifier-free guidance.

        Stores ``self.text_embeddings`` as the concatenation of the
        unconditional embeddings followed by the prompt embeddings, so its
        first dimension is ``2 * batch_size``. The intermediate tokenizer
        outputs and embeddings are kept on the instance as well
        (``text_input``, ``uncond_input``, ``uncond_embeddings``,
        ``max_length``).
        """
        self.text_input = self.tokenizer(
            self.prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        self.max_length = self.text_input.input_ids.shape[-1]
        self.uncond_input = self.tokenizer(
            [""] * self.batch_size,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Both encoder passes are inference only; keeping them inside
        # no_grad prevents an autograd graph from being attached to the
        # embeddings that are reused at every denoising step.
        with torch.no_grad():
            prompt_embeddings = self.text_encoder(self.text_input.input_ids.to(self.torch_device))[0]
            self.uncond_embeddings = self.text_encoder(self.uncond_input.input_ids.to(self.torch_device))[0]

        self.text_embeddings = torch.cat([self.uncond_embeddings, prompt_embeddings])

    def create_random_noise(self):
        """Sample the initial latent noise for the whole batch.

        The noise is drawn with ``self.generator`` and then moved to
        ``self.torch_device``.

        Returns:
            A tensor of shape ``(batch_size, unet in_channels, height // 8,
            width // 8)`` on ``self.torch_device``.
        """
        latents = torch.randn(
            (
                self.batch_size,
                # Read through .config; direct attribute access on the model
                # is deprecated in diffusers.
                self.unet.config.in_channels,
                self.height // self.LATENT_DOWNSCALE,
                self.width // self.LATENT_DOWNSCALE,
            ),
            generator=self.generator,
        )
        return latents.to(self.torch_device)

    def _guided_step(self, scheduler, latents, t):
        """Run one classifier-free-guidance denoising step.

        Args:
            scheduler: Scheduler whose timesteps have already been set.
            latents: Current noisy latents.
            t: Timestep taken from ``scheduler.timesteps``.

        Returns:
            The latents for the previous (less noisy) timestep.
        """
        # Duplicate the latents so the unconditional and the text-conditioned
        # prediction are obtained in a single forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=self.text_embeddings
            ).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        return scheduler.step(noise_pred, t, latents).prev_sample

    def denoise_image(self, latents):
        """Run the full denoising loop on freshly sampled noise.

        ``create_text_embeddings`` must have been called first. This method
        uses (and therefore mutates the state of) ``self.scheduler``.

        Args:
            latents: Initial noise, e.g. from ``create_random_noise``.

        Returns:
            The denoised latents after ``num_inference_steps`` steps.
        """
        from tqdm.auto import tqdm

        latents = latents * self.scheduler.init_noise_sigma
        print("shape of noise: ", latents.shape)

        scheduler = self.scheduler
        scheduler.set_timesteps(self.num_inference_steps)

        for t in tqdm(scheduler.timesteps):
            latents = self._guided_step(scheduler, latents, t)
        return latents

    def denoise_image_conditional(self, latents, i_conditional, t_conditional):
        """Denoise starting from step index ``i_conditional``.

        All timesteps with an index below ``i_conditional`` are skipped; the
        remaining ones are processed exactly as in ``denoise_image``. Unlike
        ``denoise_image`` the latents are NOT multiplied by the scheduler's
        ``init_noise_sigma``, and a private copy of the scheduler is used so
        ``self.scheduler`` is left untouched.

        Args:
            latents: Latents to continue denoising from.
            i_conditional: Index into the timestep schedule of the first step
                to execute.
            t_conditional: Currently unused; kept so existing callers keep
                working.

        Returns:
            The latents after the remaining denoising steps.
        """
        from tqdm.auto import tqdm

        print("shape of conditional noise: ", latents.shape)

        # deep copy the scheduler to avoid modifying the original scheduler
        scheduler = copy.deepcopy(self.scheduler)
        scheduler.set_timesteps(self.num_inference_steps)

        for i, t in enumerate(tqdm(scheduler.timesteps)):
            if i < i_conditional:
                continue
            latents = self._guided_step(scheduler, latents, t)

        del scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return latents

    def decode(self, latents):
        """Decode latents into 8-bit images with the VAE.

        Args:
            latents: Latents as produced by one of the denoise methods.

        Returns:
            A ``uint8`` NumPy array with the channel axis moved last, i.e.
            the decoder output permuted from ``(N, C, H, W)`` to
            ``(N, H, W, C)``, with values in ``[0, 255]``.
        """
        # scale and decode the image latents with vae
        latents = 1 / self.VAE_SCALING_FACTOR * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample

        # Map the decoder output from [-1, 1] to [0, 1] before quantizing.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        return (image * 255).round().astype("uint8")

    def img_to_pil(self, images):
        """Convert an iterable of image arrays into a list of PIL images.

        Args:
            images: Iterable of arrays accepted by ``Image.fromarray``, e.g.
                the return value of ``decode``.

        Returns:
            A list with one PIL image per input array.
        """
        return [Image.fromarray(image) for image in images]
