# modified from HuggingFace diffusers (0.32.1) `pipelines/pixart_alpha/pipeline_pixart_alpha.py`

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from util.logger import logger

from typing import Any, Tuple, Callable, Dict, List, Optional, Tuple, Union

import gc

import torch

from diffusers.utils import (
    BACKENDS_MAPPING,
    deprecate,
    is_bs4_available,
    is_ftfy_available,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_256_BIN,
    ASPECT_RATIO_512_BIN,
    ASPECT_RATIO_1024_BIN, 

    retrieve_timesteps
)

from types import MethodType

from .util.timestep_util import prepare_timestep


# Module-level logger obtained from the diffusers logging utility.
# NOTE(review): this rebinds the name `logger` that was already imported from
# `util.logger` at the top of the file, so that import is shadowed from here on
# -- confirm which logger is actually intended.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@torch.no_grad()
def prepare_everything(
    self,

    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    # eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    # output_type: Optional[str] = "pil",
    # return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    # clean_caption: bool = True,
    clean_caption: bool = False,
    use_resolution_binning: bool = True,
    max_sequence_length: int = 300,

    inference_step_minus_one: bool = False,

    **kwargs
) -> Tuple[torch.Tensor, Dict, torch.Tensor]:
    """
    Func:
        Prepare parameters used for denoising: validate the inputs, encode the
        prompt, build the timestep schedule, create the initial latents and
        the micro-conditions. No denoising step is executed here.

    Args:
        `prompt`, `negative_prompt`, `prompt_embeds`, `prompt_attention_mask`,
        `negative_prompt_embeds`, `negative_prompt_attention_mask`,
        `num_images_per_prompt`, `clean_caption`, `max_sequence_length`:
            Forwarded to `self.encode_prompt()`.
        `num_inference_steps`, `timesteps`, `sigmas`:
            Forwarded to `retrieve_timesteps()` together with `self.scheduler`.
        `guidance_scale` (`float`): Classifier-free guidance is enabled when
            this value is greater than `1.0`.
        `height`, `width` (`int`, optional): Default to
            `self.transformer.config.sample_size * self.vae_scale_factor`.
        `generator`, `latents`: Forwarded to `self.prepare_latents()`.
        `callback`: Accepted for signature compatibility; not used here.
        `callback_steps` (`int`): Only passed to `self.check_inputs()`.
        `use_resolution_binning` (`bool`): When set, `height` / `width` are
            replaced by the result of
            `self.image_processor.classify_height_width_bin()`.
        `inference_step_minus_one` (`bool`): When set, `num_inference_steps`
            is decremented by one before the schedule is built.
        `**kwargs`: Only the deprecated `mask_feature` key is inspected.

    Side effects:
        Sets `self.guidance_scale`, `self.do_classifier_free_guidance` and
        `self._num_timesteps`. If the pipeline has a non-`None`
        `inv_scheduler` attribute, its `set_timesteps()` is called as well.

    Ret:
        `prompt_embeds` (`torch.Tensor`): The prompt embeddings. With
            classifier-free guidance the negative embeddings are concatenated
            in front of the positive ones along dim 0.
        `param_dict` (`Dict`): The dictionary of parameters with the keys
            `"timesteps"`, `"prompt_attention_mask"` and `"added_cond_kwargs"`.
        `latents` (`torch.Tensor`): The initial latent.

    Raises:
        `ValueError`: If resolution binning is requested for a transformer
            whose `sample_size` is not 128, 64 or 32.
    """

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    # 1. Check inputs. Raise error if not correct
    sample_size = self.transformer.config.sample_size

    height = height or sample_size * self.vae_scale_factor
    width = width or sample_size * self.vae_scale_factor

    if use_resolution_binning:
        if sample_size == 128:
            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
        elif sample_size == 64:
            aspect_ratio_bin = ASPECT_RATIO_512_BIN
        elif sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
            raise ValueError("Invalid sample size")

        # The requested size is snapped to a bin here; this function never
        # resizes back, so the original size is not kept.
        height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    num_latent = batch_size * num_images_per_prompt

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Stored on the pipeline because `get_noise_pred()` reads both attributes.
    self.guidance_scale = guidance_scale
    self.do_classifier_free_guidance = do_classifier_free_guidance

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )

    if do_classifier_free_guidance:
        # Negative entries first, positive entries second; `get_noise_pred()`
        # relies on this order when it chunks the model output.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    if inference_step_minus_one:
        num_inference_steps -= 1
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # The inversion scheduler is optional: fall back to `None` when the
    # pipeline has no such attribute instead of raising `AttributeError`.
    inv_scheduler = getattr(self, "inv_scheduler", None)
    if inv_scheduler is not None:
        inv_scheduler.set_timesteps(
            num_inference_steps,
            device=device
        )

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels

    latents = self.prepare_latents(
        num_latent,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare micro-conditions (only built when `sample_size == 128`).
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if sample_size == 128:
        # resolution.shape = (num_latent, 2)
        resolution = torch.tensor([height, width]).repeat(num_latent, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)

        # aspect_ratio.shape = (num_latent, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(num_latent, 1)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            # Duplicate so the rows match the doubled (negative + positive) batch.
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # `step()` reads this value to reject single-step sampling.
    self._num_timesteps = len(timesteps)

    param_dict = {
        "timesteps": timesteps,

        "prompt_attention_mask": prompt_attention_mask,

        "added_cond_kwargs": added_cond_kwargs
    }

    # `prepare_everything()` done
    return (
        prompt_embeds,
        param_dict,
        latents
    )


def get_noise_pred(
    self,

    param_dict: Dict,
    prompt_emb_list: Union[torch.Tensor, List[torch.Tensor]],

    # latent_list.shape = (batch_size, latent_channels, latent_height, latent_width)
    latent_list: Union[torch.Tensor, List[torch.Tensor]],

    timestep_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    timestep_idx_list: Optional[Union[int, List[int]]] = None
) -> Tuple[
    torch.Tensor,  # noise_pred.shape = (batch_size, latent_channels, latent_height, latent_width)
    torch.Tensor  # timestep.shape = (batch_size, )
]:
    """
    Func:
        Run the transformer once to predict the noise for `latent_list`,
        applying classifier-free guidance when
        `self.do_classifier_free_guidance` is set.

    Args:
        `param_dict` (`Dict`): Must provide the keys `"timesteps"`,
            `"prompt_attention_mask"` and `"added_cond_kwargs"` (the
            dictionary returned by `prepare_everything()` has them).
        `prompt_emb_list`: Prompt embeddings as a tensor or a list of tensors
            (a list is stacked). Its length must be `batch_size` without CFG
            and `2 * batch_size` with CFG (negative half first).
        `latent_list`: Latents as a tensor or a list of tensors (a list is
            stacked); `batch_size` is its first dimension.
        `timestep_list`, `timestep_idx_list`: Forwarded to
            `prepare_timestep()` which resolves the timestep tensor.

    Ret:
        `noise_pred` (`torch.Tensor`): The predicted noise residual.
        `timestep` (`torch.Tensor`): The timestep derived from `timestep_list` or `timestep_idx_list`.

    Raises:
        `ValueError`: If the number of prompt embeddings does not match the
            latent batch, or the stored attention mask cannot be expanded to
            the model batch.
    """

    if isinstance(prompt_emb_list, list):
        prompt_emb_list = torch.stack(prompt_emb_list)

    if isinstance(latent_list, list):
        latent_list = torch.stack(latent_list)

    timesteps = param_dict["timesteps"]
    prompt_attention_mask = param_dict["prompt_attention_mask"]
    added_cond_kwargs = param_dict["added_cond_kwargs"]

    use_cfg = self.do_classifier_free_guidance

    batch_size = latent_list.shape[0]
    # With CFG the transformer sees every latent twice (negative + positive).
    num_model_input = batch_size * 2 if use_cfg else batch_size

    if len(prompt_emb_list) != num_model_input:
        raise ValueError(
            f"{'Enable' if use_cfg else 'Disable'} CFG, the shape `prompt_emb_list` does not match the shape of `latent_list`, "
            f"got `{prompt_emb_list.shape}` and `{latent_list.shape}`. "
        )

    # The check above guarantees that the embeddings already have one row per
    # model input, so no further expansion is needed.
    prompt_emb_model_input = prompt_emb_list

    timestep = prepare_timestep(
        pipeline=self,

        timesteps=timesteps,

        target_length=batch_size,

        timestep_list=timestep_list,
        timestep_idx_list=timestep_idx_list
    )

    # expand the latents if we are doing classifier free guidance
    if use_cfg:
        latent_model_input = torch.cat([latent_list] * 2)
        timestep_model_input = timestep.repeat(2)
    else:
        latent_model_input = latent_list
        timestep_model_input = timestep

    # The mask stored in `param_dict` may have been prepared for a smaller
    # batch than the one passed in here. Repeat each row consecutively so the
    # "negative rows first, positive rows second" layout is preserved.
    # NOTE(review): assumes the latents that share a prompt are adjacent in
    # the batch -- confirm against the callers.
    num_mask = prompt_attention_mask.shape[0]
    if num_mask != num_model_input:
        if num_mask == 0 or num_model_input % num_mask != 0:
            raise ValueError(
                f"Cannot expand `prompt_attention_mask` with {num_mask} rows "
                f"to the {num_model_input} rows required by the model input. "
            )
        prompt_attention_mask = prompt_attention_mask.repeat_interleave(
            num_model_input // num_mask, dim=0
        )

    # NOTE(review): `added_cond_kwargs` is forwarded unchanged; when it holds
    # tensors they keep the batch size used in `prepare_everything()`.

    # TODO: check, `timestep` -> `timestep_model_input`?
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)

    # predict noise model_output
    noise_pred = self.transformer(
        latent_model_input,
        encoder_hidden_states=prompt_emb_model_input,
        encoder_attention_mask=prompt_attention_mask,
        timestep=timestep_model_input,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]

    # perform guidance
    if use_cfg:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

    # learned sigma: when the transformer outputs twice as many channels as
    # the latent has, keep only the first half along the channel axis.
    latent_channels = self.transformer.config.in_channels
    if self.transformer.config.out_channels // 2 == latent_channels:
        noise_pred = noise_pred.chunk(2, dim=1)[0]

    # ---------= [Clean Up] =---------
    del latent_model_input, timestep_model_input, prompt_emb_model_input
    gc.collect()
    torch.cuda.empty_cache()

    # `get_noise_pred()` done
    return (
        # noise_pred.shape = (batch_size, latent_channels, latent_height, latent_width)
        noise_pred,

        # timestep.shape = (batch_size, )
        timestep
    )


def step(
    self,

    # latent_list.shape = (batch_size, 4, latent_height, latent_width)
    latent_list: torch.Tensor,

    # noise_pred.shape = (batch_size, 4, latent_height, latent_width)
    noise_pred: torch.Tensor,

    # timestep.shape = (batch_size, )
    timestep: Optional[torch.Tensor] = None,
    prev_timestep: Optional[torch.Tensor] = None,

    eta_list: Optional[Union[float, torch.Tensor]] = None,
    eps: Optional[torch.Tensor] = None,

    # used for DDIM Inversion
    inv: Optional[bool] = False
) -> torch.Tensor:
    """
    Func:
        Move `latent_list` from `timestep` to `prev_timestep` with a single
        scheduler step. `self.scheduler` is used by default and
        `self.inv_scheduler` when `inv` is set.

    Args:
        `latent_list` (`torch.Tensor`): The current latents.
        `noise_pred` (`torch.Tensor`): The model output for `latent_list`.
        `timestep`, `prev_timestep` (`torch.Tensor`, optional): Forwarded to
            the scheduler's `step()`.
        `eta_list` (`float` or `torch.Tensor`, optional): A single value is
            broadcast along the batch axis to one entry per latent; otherwise
            its length must equal the number of latents. Only used when `inv`
            is not set.
        `eps` (`torch.Tensor`, optional): Forwarded as `variance_noise`. Only
            used when `inv` is not set.
        `inv` (`bool`): Use `self.inv_scheduler` instead of `self.scheduler`.

    NB:
        The `_step_index` of the scheduler in use is reset to `None` before
        the step. The scheduler is called with `prev_timestep` and a tensor
        `eta`; NOTE(review): this presumably requires a customised scheduler
        -- confirm.

    Ret:
        `latent_list` (`torch.Tensor`): The latents returned by the scheduler.

    Raises:
        `ValueError`: If `eta_list` has more than one entry and its length
            differs from the number of latents.
        `NotImplementedError`: If `self._num_timesteps == 1` and `inv` is not
            set (single-step sampling is not supported).
    """

    if eta_list is not None:
        if not isinstance(eta_list, torch.Tensor):
            eta_list = torch.tensor(
                [eta_list],

                dtype=latent_list.dtype,
                device=latent_list.device
            )
        elif eta_list.dim() == 0:
            # `len()` is undefined for 0-d tensors; treat it as a single value.
            eta_list = eta_list.unsqueeze(0)

        num_eta = len(eta_list)
        num_latent = len(latent_list)

        if num_eta == 1:
            # `Tensor.repeat` takes repetition counts, not a target shape:
            # tile the batch axis only and leave trailing axes untouched.
            trailing_ones = [1] * (eta_list.dim() - 1)
            eta_list = eta_list.repeat(num_latent, *trailing_ones)
        elif num_eta != num_latent:
            raise ValueError(
                f"The length of `eta_list` does not match the length of `latent_list`, "
                f"got `{num_eta}` and `{num_latent}`. "
            )

    if inv:
        self.inv_scheduler._step_index = None

        # compute the next noisy sample x_{t - 1} -> x_{t}
        latent_list = self.inv_scheduler.step(
            model_output=noise_pred,

            timestep=timestep,
            prev_timestep=prev_timestep,

            sample=latent_list,

            return_dict=False
        )[0]

        # `step()` done
        return latent_list

    self.scheduler._step_index = None

    if self._num_timesteps == 1:
        # One-step sampling (e.g. DMD, https://arxiv.org/abs/2311.18828) would
        # need the scheduler's `pred_original_sample`; not supported yet.
        raise NotImplementedError(
            "Only support `num_inference_step > 1` now. "
        )

    # compute the previous noisy sample x_{t} -> x_{t - 1}
    latent_list = self.scheduler.step(
        model_output=noise_pred,

        timestep=timestep,
        prev_timestep=prev_timestep,

        sample=latent_list,

        eta=eta_list,
        variance_noise=eps,

        return_dict=False
    )[0]

    # `step()` done
    return latent_list


def register_pipeline_pixart_alpha(
    pipeline,
    **kwargs
):
    """
    Func:
        Attach `prepare_everything()`, `get_noise_pred()` and `step()` to
        `pipeline` as bound methods, so they can be called as
        `pipeline.<name>(...)`. An existing attribute with the same name on
        the instance is replaced.

    Args:
        `pipeline`: The pipeline instance to patch in place.
        `**kwargs`: Accepted but not used.

    Ret:
        `None`.
    """

    bound_methods = {
        "prepare_everything": prepare_everything,
        "get_noise_pred": get_noise_pred,
        "step": step,
    }

    for attr_name, func in bound_methods.items():
        setattr(pipeline, attr_name, MethodType(func, pipeline))

    # `register_pipeline_pixart_alpha()` done
