from flax import nnx
from jax import Array, numpy as jnp, random as jrandom

from offline.diffusion.ddim.modules.base import DDIM
from offline.diffusion.modules.unconditional import UnconditionalDDPM


class UnconditionalDDIM(DDIM, UnconditionalDDPM):
    """Unconditional sampler using the DDIM update rule.

    Sampling starts from Gaussian noise and applies
    ``ddim_backward_process`` once per entry of the DDIM schedule arrays
    (``ddim_alphas_cumprod``, ``ddim_alphas_cumprod_prev`` and
    ``ddim_timesteps``).
    """

    def ddim(
        self, batch_shape: tuple[int, ...], state: tuple[Array, int]
    ) -> tuple[Array, tuple[Array, int]]:
        """Draw a batch of samples by running the full DDIM reverse scan.

        Args:
            batch_shape: Leading shape of the generated batch; the sample
                dimension ``self.noise_predictor.sample_dim`` is appended.
            state: ``(key, offset)`` pair. ``key`` is folded with integer
                indices derived from ``offset`` to obtain the random keys
                used during sampling.

        Returns:
            ``(sample, new_state)`` where ``sample`` is the final carry of
            the reverse scan and ``new_state`` is
            ``(key, offset + self.diffusion_steps)``.
        """
        key, offset = state

        # Initial x_T: standard Gaussian noise keyed by (key, offset).
        initial_sample = jrandom.normal(
            jrandom.fold_in(key, offset),
            batch_shape + (self.noise_predictor.sample_dim,),
        )

        carry = self.ddim_backward_process(
            alpha_prod=self.ddim_alphas_cumprod.value,
            alpha_prod_prev=self.ddim_alphas_cumprod_prev.value,
            key=key,
            offset=offset,
            sample=initial_sample,
            time=self.ddim_timesteps.value,
        )
        # Advance the offset so that a later call with the returned state
        # folds different indices into the same key.
        return carry, (key, offset + self.diffusion_steps)

    # Axes, in argument order: self is broadcast; alpha_prod and
    # alpha_prod_prev are scanned along axis 0; key and offset are
    # broadcast; sample is the scan carry; time is scanned along axis 0.
    @nnx.scan(
        in_axes=(None, 0, 0, None, None, nnx.Carry, 0), out_axes=nnx.Carry
    )
    def ddim_backward_process(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        key: Array,
        offset: int,
        sample: Array,
        time: Array,
    ) -> Array:
        """Apply one DDIM reverse step; scanned over the DDIM schedule.

        Args:
            alpha_prod: Cumulative alpha product for the current step.
            alpha_prod_prev: Cumulative alpha product for the step being
                moved to.
            key: Base random key (shared by all steps).
            offset: Integer added to ``time`` before folding it into
                ``key`` for the per-step noise.
            sample: Current sample ``x_t`` (the scan carry).
            time: Timestep value for the current step.

        Returns:
            The predicted previous sample ``x_{t-1}`` (the new carry).
        """
        # See formulas (12) and (16) of DDIM paper Song et al. (2020)
        # https://arxiv.org/abs/2010.02502
        # Ideally, read DDIM paper for detailed understanding.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise -> ϵ_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev -> σ_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_{t-1}"

        # The noise predictor receives the timestep as a trailing feature of
        # size 1, repeated over the leading dimensions of ``sample``.
        input_time = jnp.full(sample.shape[:-1] + (1,), time)
        pred_noise = self.noise_predictor(sample, input_time)

        # 1. Compute alphas, betas
        beta_prod = 1 - alpha_prod
        beta_prod_prev = 1 - alpha_prod_prev
        alpha = alpha_prod / alpha_prod_prev
        beta = 1 - alpha

        # 2. Compute predicted original sample from predicted noise also
        # called "predicted x_0" of formula (15) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        pred_original_sample = (
            sample - jnp.sqrt(beta_prod) * pred_noise
        ) / jnp.sqrt(alpha_prod)

        # 3. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # 4. Compute "σ_t(η)" -> See formula (16) from
        # Song et al. (2020) https://arxiv.org/abs/2010.02502
        if self.eta > 0:
            variance = (beta_prod_prev / beta_prod) * beta
            std_dev: Array | float = self.eta * jnp.sqrt(variance)
        else:
            std_dev = 0.0

        # 5. Compute "direction pointing to x_t" of formula (12) from
        # Song et al. (2020) https://arxiv.org/abs/2010.02502
        # The formula subtracts σ_t² = η² * variance, not the η-independent
        # variance; the two only coincide for η in {0, 1}.
        pred_sample_direction = (
            jnp.sqrt(beta_prod_prev - std_dev**2) * pred_noise
        )

        # 6. Compute x_{t-1} without "random noise" of formula (12) from
        # Song et al. (2020) https://arxiv.org/abs/2010.02502
        pred_prev_sample = (
            jnp.sqrt(alpha_prod_prev) * pred_original_sample
            + pred_sample_direction
        )

        # 7. Add noise
        if self.eta > 0:
            # NOTE(review): when ``time == 0`` this folds the same index
            # (``offset``) that ``ddim`` uses for the initial sample, so the
            # noise would repeat that draw. Confirm ``ddim_timesteps`` never
            # yields a colliding index, or that σ_t is zero at that step.
            noise = jrandom.normal(
                jrandom.fold_in(key, time + offset),
                shape=pred_prev_sample.shape,
            )
            pred_prev_sample = pred_prev_sample + std_dev * noise

        return pred_prev_sample
