from flax import nnx
from jax import numpy as jnp

from offline.diffusion.modules.base import DDPM


class DDIM(DDPM):
    """DDIM sampler built on top of the ``DDPM`` base module.

    On construction it precomputes a reduced schedule of ``inference_steps``
    timesteps (drawn from the ``diffusion_steps`` training timesteps) together
    with the cumulative alphas needed at every step of that schedule.
    """

    def __init__(
        self,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        eta: float,
        inference_steps: int,
        temperature: float,
        timestep_spacing: str,
        **kwargs,
    ):
        """Build the DDPM base and the DDIM inference schedule.

        Args:
            beta_schedule: Forwarded to ``DDPM``.
            clip_sample: Forwarded to ``DDPM``.
            diffusion_steps: Number of training timesteps, forwarded to
                ``DDPM``.
            eta: Stored as ``self.eta``; presumably the DDIM noise scale
                (0 = deterministic sampling) — confirm against the sampler
                that reads it.
            inference_steps: Number of sampling steps; see
                ``set_ddim_timesteps``.
            temperature: Forwarded to ``DDPM``.
            timestep_spacing: ``"linspace"``, ``"leading"`` or
                ``"trailing"``; see ``set_ddim_timesteps``.
            **kwargs: Forwarded unchanged to ``DDPM``.
        """
        super().__init__(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            temperature=temperature,
            **kwargs,
        )
        self.eta = eta
        self.set_ddim_timesteps(
            inference_steps=inference_steps, timestep_spacing=timestep_spacing
        )

    def set_ddim_timesteps(self, inference_steps: int, timestep_spacing: str):
        """Precompute the DDIM sampling schedule.

        Sets ``ddim_inference_steps``, ``ddim_timesteps`` (descending int32
        timesteps in ``[0, diffusion_steps - 1]``), ``ddim_alphas_cumprod``
        (cumulative alpha at each scheduled timestep) and
        ``ddim_alphas_cumprod_prev`` (cumulative alpha at the timestep the
        step moves to, i.e. the next entry of the schedule).

        Args:
            inference_steps: Number of sampling steps, in
                ``[1, diffusion_steps]``.
            timestep_spacing: ``"linspace"``, ``"leading"`` or ``"trailing"``.

        Raises:
            ValueError: If ``inference_steps`` is out of range or
                ``timestep_spacing`` is not supported.
        """
        if not 1 <= inference_steps <= self.diffusion_steps:
            raise ValueError(
                f"inference_steps must be between 1 and diffusion_steps "
                f"({self.diffusion_steps}), got {inference_steps}"
            )
        self.ddim_inference_steps = inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2
        # of Lin et al. (2023) https://arxiv.org/abs/2305.08891
        if timestep_spacing == "linspace":
            # Evenly spaced over [0, diffusion_steps - 1], both ends included.
            timesteps = jnp.round(
                jnp.linspace(0, self.diffusion_steps - 1, inference_steps)
            )
            timesteps = timesteps[::-1].astype(jnp.int32)
        elif timestep_spacing == "leading":
            # Integer multiples of the (floored) step ratio, starting at 0.
            step_ratio = self.diffusion_steps // inference_steps
            timesteps = jnp.arange(inference_steps) * step_ratio
            timesteps = timesteps[::-1].astype(jnp.int32)
        elif timestep_spacing == "trailing":
            # Count down from diffusion_steps with a possibly fractional
            # ratio. Round before the int cast so float error cannot truncate
            # a value downward, then shift by one so the first step is
            # diffusion_steps - 1.
            step_ratio = self.diffusion_steps / inference_steps
            timesteps = jnp.round(
                jnp.arange(self.diffusion_steps, 0, -step_ratio)
            )
            timesteps = timesteps.astype(jnp.int32) - 1
        else:
            raise ValueError(
                f"{timestep_spacing} is not supported. Please make sure to "
                "choose one of 'linspace', 'leading', or 'trailing'"
            )

        self.ddim_timesteps = nnx.Variable(timesteps)

        # Each step moves to the *next* timestep of this schedule. A fixed
        # offset of diffusion_steps // inference_steps disagrees with the
        # schedule for "linspace" and for non-divisible "trailing" spacing.
        # The final step moves to t = -1.
        prev_timesteps = jnp.concatenate(
            [timesteps[1:], jnp.array([-1], dtype=timesteps.dtype)]
        )

        # alphas_cumprod is read at t + 1, so t = -1 maps to index 0.
        # NOTE(review): presumably DDPM stores the clean-sample value
        # (alpha_cumprod = 1) at index 0 — confirm against DDPM.
        self.ddim_alphas_cumprod = nnx.Variable(
            self.alphas_cumprod[timesteps + 1]
        )
        self.ddim_alphas_cumprod_prev = nnx.Variable(
            self.alphas_cumprod[prev_timesteps + 1]
        )
