# Modified from HuggingFace diffusers (0.32.1) `diffusers/schedulers/scheduling_ddim.py`

# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

from typing import List, Optional, Tuple, Union

import torch

from types import MethodType

from diffusers.utils.torch_utils import randn_tensor

from diffusers.schedulers.scheduling_ddim import DDIMScheduler


def _get_variance(
    self,
    timestep: Union[int, torch.Tensor],
    prev_timestep: Union[int, torch.Tensor]
):
    """
    Func:
        Compute the DDIM variance term
            (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)
        where `a_t = self.alphas_cumprod[timestep]` and `a_prev` is
        `self.alphas_cumprod[prev_timestep]`, or `self.final_alpha_cumprod`
        wherever `prev_timestep` is negative.

    Args:
        timestep (`int` or `torch.Tensor`): Current timestep(s). A tensor is
            typically of shape (batch_size, ), but any index shape
            (including 0-d) is accepted.
        prev_timestep (`int` or `torch.Tensor`): Timestep(s) being stepped
            to. Negative entries select `self.final_alpha_cumprod`.

    Ret:
        `variance` (`torch.Tensor`): 0-d when both inputs are `int`,
            otherwise the broadcast of the two index shapes.

    Raises:
        ValueError: If `timestep` or `prev_timestep` is neither an `int` nor
            a `torch.Tensor`.
    """
    schedule = self.alphas_cumprod

    if isinstance(timestep, int):
        alpha_prod_t = schedule[timestep]
    elif isinstance(timestep, torch.Tensor):
        # Index on the schedule's own device so timesteps living on an
        # accelerator can be used directly.
        alpha_prod_t = schedule[timestep.to(schedule.device)]
    else:
        raise ValueError(
            f"Unsupported type of `timestep`, got `{type(timestep)}`. "
        )

    if isinstance(prev_timestep, int):
        alpha_prod_t_prev = schedule[prev_timestep] if (prev_timestep >= 0) \
            else self.final_alpha_cumprod
    elif isinstance(prev_timestep, torch.Tensor):
        prev_idx = prev_timestep.to(schedule.device)
        final_alpha = torch.as_tensor(
            self.final_alpha_cumprod,
            dtype=schedule.dtype,
            device=schedule.device
        )
        # Vectorised replacement for a per-element Python loop: clamp the
        # index so the gather is always valid, then overwrite the entries
        # that were negative with the final alpha.
        alpha_prod_t_prev = torch.where(
            prev_idx >= 0,
            schedule[prev_idx.clamp(min=0)],
            final_alpha
        )
    else:
        raise ValueError(
            f"Unsupported type of `prev_timestep`, got `{type(prev_timestep)}`. "
        )

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

    return variance


def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    eta: Union[float, torch.Tensor] = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    prev_timestep: Optional[Union[int, torch.Tensor]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict the sample at `prev_timestep` from the sample at `timestep`,
    following formulas (12) and (16) of the DDIM paper
    (https://arxiv.org/pdf/2010.02502.pdf). Timesteps may be given per sample.

    Notation (<variable name> -> <name in paper>):
        - pred_epsilon -> e_theta(x_t, t)
        - pred_original_sample -> f_theta(x_t, t) or x_0
        - std_dev_t -> sigma_t
        - eta -> η
        - pred_sample_direction -> "direction pointing to x_t"
        - prev_sample -> "x_t-1"

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`int` or `torch.Tensor`):
            The current discrete timestep. A tensor selects the batched path:
            it is flattened and treated as one timestep per sample (a 0-d or
            single-element tensor is broadcast over the batch).
        sample (`torch.Tensor`):
            The current sample, with the batch on dimension 0.
        eta (`float` or `torch.Tensor`):
            The weight of the added noise. On the batched path a tensor `eta`
            is reshaped to broadcast per sample; a plain number applies to
            every sample.
        use_clipped_model_output (`bool`, defaults to `False`):
            If `True`, re-derive the predicted noise from the (possibly
            clipped / thresholded) predicted original sample.
        generator (`torch.Generator`, *optional*):
            Random number generator used when noise has to be drawn.
        variance_noise (`torch.Tensor`, *optional*):
            Noise to use instead of drawing it with `generator`.
        return_dict (`bool`, defaults to `True`):
            Only `False` is supported; the default value is rejected.
        prev_timestep (`int` or `torch.Tensor`, *optional*):
            The timestep(s) to step to. Defaults to
            `timestep - num_train_timesteps // num_inference_steps`.

    Returns:
        `tuple`: `(prev_sample, pred_original_sample)`.

    Raises:
        ValueError: If `set_timesteps` has not been run, if `return_dict` is
            truthy, if `self.config.prediction_type` is unknown, or if both
            `generator` and `variance_noise` are given while noise is needed.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Fail fast: this is the only supported calling convention, so reject the
    # unsupported one before doing any work.
    if return_dict:
        raise ValueError(
            f"Only support `return_dict = False`. "
        )

    batched = isinstance(timestep, torch.Tensor)

    # Shape that lets a per-sample coefficient broadcast over every non-batch
    # dimension of `sample`: (batch_size, 1, 1, 1) for the usual 4-D latents.
    coeff_shape = (-1,) + (1,) * (sample.ndim - 1)

    if batched:
        # Flatten so that 0-d timesteps are handled like a batch of one.
        timestep = timestep.reshape(-1)

    # 1. get previous step value (=t-1)
    if prev_timestep is None:
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    if batched:
        schedule = self.alphas_cumprod
        prev_timestep = torch.as_tensor(prev_timestep, device=timestep.device).reshape(-1)
        prev_idx = prev_timestep.to(schedule.device)
        final_alpha = torch.as_tensor(
            self.final_alpha_cumprod,
            dtype=schedule.dtype,
            device=schedule.device
        )

        alpha_prod_t = schedule[timestep.to(schedule.device)]
        # Negative previous timesteps select the final alpha; clamp first so
        # the gather itself never sees a negative index.
        alpha_prod_t_prev = torch.where(
            prev_idx >= 0,
            schedule[prev_idx.clamp(min=0)],
            final_alpha
        )

        alpha_prod_t = alpha_prod_t.reshape(coeff_shape).to(
            dtype=sample.dtype,
            device=sample.device
        )
        alpha_prod_t_prev = alpha_prod_t_prev.reshape(coeff_shape).to(
            dtype=sample.dtype,
            device=sample.device
        )
    else:
        # Scalar path: both coefficients stay 0-d and broadcast as-is.
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if (prev_timestep >= 0) \
            else self.final_alpha_cumprod

    beta_prod_t = 1 - alpha_prod_t
    sqrt_beta_prod_t = beta_prod_t ** (0.5)
    sqrt_alpha_prod_t = alpha_prod_t ** (0.5)

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = (sample - sqrt_beta_prod_t * model_output) / sqrt_alpha_prod_t
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - sqrt_alpha_prod_t * pred_original_sample) / sqrt_beta_prod_t
    elif prediction_type == "v_prediction":
        pred_original_sample = sqrt_alpha_prod_t * sample - sqrt_beta_prod_t * model_output
        pred_epsilon = sqrt_alpha_prod_t * model_output + sqrt_beta_prod_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(
        timestep=timestep,
        prev_timestep=prev_timestep
    )

    if batched:
        if isinstance(eta, torch.Tensor):
            eta = eta.reshape(coeff_shape)
            variance = variance.reshape(coeff_shape).to(
                dtype=eta.dtype,
                device=eta.device
            )
        else:
            # A plain number carries no dtype / device, so follow `sample`.
            variance = variance.reshape(coeff_shape).to(
                dtype=sample.dtype,
                device=sample.device
            )

    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - sqrt_alpha_prod_t * pred_original_sample) / sqrt_beta_prod_t

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    # 8. add the stochastic term unless eta is (numerically) zero everywhere
    if isinstance(eta, torch.Tensor):
        needs_noise = not torch.allclose(eta, torch.zeros_like(eta))
    else:
        needs_noise = eta > 0

    if needs_noise:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )

        if variance_noise is None:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )

        prev_sample = prev_sample + std_dev_t * variance_noise

    return (
        prev_sample,
        pred_original_sample,
    )


def cal_prev_timestep(
    self, 

    timestep: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Func:
        Move one inference stride (`num_train_timesteps // num_inference_steps`) 
            backwards from `timestep`. 
        An `int` result is floored at 0; a tensor result is clamped to 
            `[0, num_train_timesteps]`. 

    Args:
        `timestep` (`int` or `torch.Tensor`): The current timestep(s). 

    Ret:
        `prev_timestep`: An `int` when `timestep` is an `int`, otherwise a 
            `torch.Tensor` with the shape of `timestep`. 

    Raises:
        ValueError: If `timestep` is neither an `int` nor a `torch.Tensor`. 
    """

    upper_bound = self.config.num_train_timesteps
    stride = upper_bound // self.num_inference_steps

    # The subtraction runs before the type check, exactly as it always has.
    shifted = timestep - stride

    if isinstance(timestep, torch.Tensor):
        return shifted.clamp(min=0, max=upper_bound)

    if isinstance(timestep, int):
        return max(shifted, 0)

    raise ValueError(
        f"Unsupported type of `timestep`, got `{type(timestep)}`. "
    )


def cal_next_timestep(
    self,

    timestep: Union[int, torch.Tensor]
) -> Union[int, torch.Tensor]:
    """
    Func:
        Compute the next timesteps `next_timestep`, one inference stride
            (`num_train_timesteps // num_inference_steps`) after `timestep`.
        Used for stepping from `timestep` to `next_timestep`.
        The result never exceeds `num_train_timesteps`; a tensor result is
            additionally clamped at 0 from below.

    Args:
        `timestep` (`int` or `torch.Tensor`): The current timestep(s).

    Ret:
        `next_timestep`: An `int` when `timestep` is an `int`, otherwise a
            `torch.Tensor` with the shape of `timestep`.

    Raises:
        ValueError: If `timestep` is neither an `int` nor a `torch.Tensor`.
    """

    num_training_step = self.config.num_train_timesteps

    next_timestep = timestep + num_training_step // self.num_inference_steps

    if isinstance(timestep, int):
        next_timestep = min(next_timestep, num_training_step)
    elif isinstance(timestep, torch.Tensor):
        # Clamp the value computed above; the tensor branch previously
        # clamped an unbound `prev_timestep` and always raised.
        next_timestep = torch.clamp(
            next_timestep,
            0, num_training_step
        )
    else:
        raise ValueError(
            f"Unsupported type of `timestep`, got `{type(timestep)}`. "
        )

    # `cal_next_timestep()` done
    return next_timestep


# (discarded)
# def get_first_strictly_smaller_timestep_idx_list(
#     self, 

#     timestep_list: Union[List[int], torch.Tensor]
# ) -> torch.Tensor:
#     """
#     Func:
#         Find the first index in `self.timesteps` that the timestep is strictly smaller than 
#             the timestep in `timestep_list`. 

#     Ret:
#         `first_strictly_smaller_timestep_idx_list` (`torch.Tensor`): The list of indices. 
#             first_strictly_smaller_timestep_idx_list.shape = (num_latent, ). 
#     """

#     scheduler_timestep_list = self.timesteps
#     length = len(scheduler_timestep_list)

#     def get_first_strictly_smaller_timestep_idx(
#         timestep: int
#     ) -> int:
#         l = 0
#         r = length - 1

#         while l < r:
#             mid = (l + r) // 2
            
#             if scheduler_timestep_list[mid] < timestep:
#                 r = mid
#             else:
#                 l = mid + 1

#         # `get_first_strictly_smaller_timestep_idx()` done
#         return l

#     first_strictly_smaller_timestep_idx_list = [
#         get_first_strictly_smaller_timestep_idx(timestep) \
#             for timestep in timestep_list
#     ]

#     # first_strictly_smaller_timestep_idx_list.shape = (num_latent, )
#     first_strictly_smaller_timestep_idx_list \
#         = torch.tensor(first_strictly_smaller_timestep_idx_list)

#     # `get_first_strictly_smaller_timestep_idx_list()` done
#     return first_strictly_smaller_timestep_idx_list


def register_scheduling_ddim(
    scheduler: DDIMScheduler,
    **kwargs
):
    """
    Func:
        Bind the custom methods `_get_variance()`, `step()`,
            `cal_prev_timestep()` and `cal_next_timestep()` of this module
            onto `scheduler`, replacing same-named methods on that instance.

    Args:
        `scheduler` (`DDIMScheduler`): The scheduler instance to patch.
        `**kwargs`: Accepted and ignored.

    Raises:
        ValueError: If `scheduler` is not a `DDIMScheduler`.
    """

    if not isinstance(scheduler, DDIMScheduler):
        raise ValueError(
            f"Only support `DDIMScheduler`. "
        )

    # `get_first_strictly_smaller_timestep_idx_list` is intentionally absent:
    # it only exists as discarded (commented-out) code in this module, so
    # binding it raised `NameError` on every call.
    for func in (
        _get_variance,
        step,
        cal_prev_timestep,
        cal_next_timestep,
    ):
        setattr(scheduler, func.__name__, MethodType(func, scheduler))

    # `register_scheduling_ddim()` done
    