from dataclasses import dataclass
import numpy as np
from torch import nn
import argparse
import PIL
import torch
from diffusers.pipelines import StableUnCLIPImg2ImgPipeline,StableDiffusionPipeline
# from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from typing import Callable, List, Optional, Union
from diffusers.utils.torch_utils import randn_tensor
# from diffusers.utils import randn_tensor,BaseOutput
from diffusers.utils import BaseOutput
from diffusers.models.embeddings import get_timestep_embedding
import deepspeed

import os
sdmodel_path = os.path.dirname(os.path.realpath(__file__))
sd_pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(os.path.join(sdmodel_path, 'diffusion-2-1'), torch_dtype=torch.float16)
deepspeed.init_inference(
    model=sd_pipe.unet,      # Transformers models
    mp_size=1,        # Number of GPU
    dtype=torch.float16, # dtype of the weights (fp16)
    replace_method="auto", # Lets DS autmatically identify the layer to replace
    replace_with_kernel_inject=True, # replace the model with the kernel injector
)

print("DeepSpeed Inference Engine initialized")

pipe = sd_pipe
#safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
# safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
# safety_checker = safety_checker.to(device)




@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]

class SD_User():
    def __init__(self, device):
        self.device = device
        self.sd_pipe = pipe.to(self.device)

    def _encode_image(self, image, noise_level, generator, return_1024=False):
        dtype = next(self.sd_pipe.image_encoder.parameters()).dtype
        image = image.clone().detach().to(dtype=dtype)

        image_embeds = self.sd_pipe.image_encoder(image).image_embeds
        if return_1024 == True:
            return image_embeds
        image_embeds = self.en_noise_image_embeddings(
            image_embeds=image_embeds,
            noise_level=noise_level,
            generator=generator,
        )
        return image_embeds

    def en_noise_image_embeddings(
            self,
            image_embeds: torch.Tensor,
            noise_level: int,
            noise: Optional[torch.FloatTensor] = None,
            generator: Optional[torch.Generator] = None,
            no_concat=False
    ):
        if noise is None:
            noise = randn_tensor(
                image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
            )
        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
        image_embeds = self.sd_pipe.image_normalizer.scale(image_embeds)
        image_embeds = self.sd_pipe.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
        image_embeds = self.sd_pipe.image_normalizer.unscale(image_embeds)
        if no_concat == True:
            return image_embeds
        noise_level = get_timestep_embedding(
            timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
        )
        # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
        # but we might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        noise_level = noise_level.to(image_embeds.dtype)
        # TODO: check dimension
        image_embeds = torch.cat((image_embeds, noise_level), 1)
        return image_embeds

    @torch.no_grad()
    def encode_image(
            self,
            image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
            noise_level: int = 0,
            generator: Optional[torch.Generator] = None,
            return_1024=False
    ):
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.

        batch_size = image.shape[0]
        '''默认500'''
        noise_level = torch.tensor([noise_level], device=self.device)
        # 3. Encode input image
        image_embeddings = self._encode_image(image, noise_level, generator, return_1024)

        return image_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // pipe.vae_scale_factor, width // pipe.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
        else:
            latents = latents.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * pipe.scheduler.init_noise_sigma
        return latents

    def gen_noise_image_embeddings(
            self,
            image_embeds: torch.Tensor,
            noise_level: int,
            noise: Optional[torch.FloatTensor] = None,
            generator: Optional[torch.Generator] = None,
            no_concat=False
    ):
        """
        Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
        `noise_level` increases the variance in the final un-noised images.

        The noise is applied in two ways
        1. A noise schedule is applied directly to the embeddings
        2. A vector of sinusoidal time embeddings are appended to the output.

        In both cases, the amount of noise is controlled by the same `noise_level`.

        The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
        """
        if image_embeds.shape[1] == 2048:
            imfea = image_embeds[0][:1024].unsqueeze(0)
            timefea = image_embeds[0][1024:].unsqueeze(0) # client加的noise_level
            image_embeds = imfea

        if noise is None:
            noise = randn_tensor(
                image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
            )
        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
        image_embeds = self.sd_pipe.image_normalizer.scale(image_embeds)
        image_embeds = self.sd_pipe.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
        image_embeds = self.sd_pipe.image_normalizer.unscale(image_embeds)
        if no_concat == True:
            return image_embeds
        noise_level = get_timestep_embedding(
            timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
        )

        # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
        # but we might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        noise_level = noise_level.to(image_embeds.dtype)
        noise_level = (noise_level + timefea) / 2  # 平均了local和global加的noise level

        image_embeds = torch.cat((image_embeds, noise_level), 1)
        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    @torch.no_grad()
    def gen_image(
            self,
            image_embeddings: torch.FloatTensor,
            prompt: Union[str, List[str]] = None,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            global_embeddings: Optional[torch.FloatTensor] = None,
            num_inference_steps: int = 50,
            height: Optional[int] = None,
            width: Optional[int] = None,
            guidance_scale: float = 2,
            num_images_per_prompt: Optional[int] = 1,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            eta: float = 0.0,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
    ):
        torch.cuda.empty_cache()
        image_embeddings = image_embeddings.to(self.device)

        global_embeddings = global_embeddings.unsqueeze(dim=0).to(self.device)
        # 0. concat public feature
        # image_embeddings：2*1*768，image_embeddings[0]=uncond image_embeddings[1] = img_embedding
        # public_embeddings： 1*1*768
        if prompt is None and prompt_embeds is None:
            prompt = ""

        # TODO:测试一下，拼接和融合有什么区别呢？
        # image_embeddings = [torch.zeros_like(input_fea), input_fea]
        if global_embeddings != None:
            image_embeddings = torch.cat([image_embeddings, global_embeddings], dim=0)
            # image_embeddings = 0.5*image_embeddings + 0.5*public_embeddings
        # image_embeddings = [torch.zeros_like(input_fea), input_fea, global_embeddings]

        # 1. Default height and width to unet
        height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
        width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

        # 2. Define guidance and batchsize
        do_classifier_free_guidance = guidance_scale > 1.0
        batch_size = 1
        '''
        if do_classifier_free_guidance:
            batch_size =  int(image_embeddings.shape[0]/2)
        else:
            batch_size = image_embeddings.shape[0]
        '''
        # 3. Set timesteps
        pipe.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = pipe.scheduler.timesteps

        # 4. Prepare latent variables

        prompt_embeds = pipe._encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        if global_embeddings != None:
            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1].unsqueeze(0)], dim=0)

        num_channels_latents = pipe.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
        # with sd_pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance

            # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # 0, local, global
            latent_model_input = torch.cat([latents] * 3) if global_embeddings != None else torch.cat([latents] * 2)

            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds,
                                   class_labels=image_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:

                if global_embeddings != None:
                    noise_pred_uncond, noise_pred_img, noise_pred_glo = noise_pred.chunk(3)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                                noise_pred_img - noise_pred_uncond) + guidance_scale * (
                                             noise_pred_glo - noise_pred_uncond)
                else:
                    noise_pred_uncond, noise_pred_img = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_img - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = pipe.decode_latents(latents)

        # safety_checker_input = pipe.feature_extractor(pipe.numpy_to_pil(image), return_tensors="pt").to(self.device)
        # image, nsfw_content_detected = safety_checker(
        #     images=image, clip_input=safety_checker_input.pixel_values.to(prompt_embeds.dtype)
        # )

        # nsfw_content_detected = None
        image = pipe.numpy_to_pil(image)

        # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=nsfw_content_detected)
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)



if __name__ == '__main__':
    user = SD_User(sd_pipe, 'cuda')
    image = torch.randn(2, 3, 224, 224).to('cuda')
    proto = user.encode_image(image, noise_level=20)
    print(proto)
    print(proto.shape)
    assert 1==0