from flax import nnx
from jax import Array, lax, numpy as jnp, random as jrandom
import numpy as np

from offline.diffusion.ddim.modules.unconditional import UnconditionalDDIM
from offline.diffusion.repaint.modules import (
    RepaintDDPM,
    RepaintPolicy,
    get_schedule,
)


class RepaintDDIM(RepaintDDPM, UnconditionalDDIM):
    """RePaint sampler whose reverse step uses the DDIM update rule.

    The reverse ("backward") step follows formula (12) of Song et al. (2020),
    while the known part of the sample is re-noised from the ground truth as
    in Algorithm 1 of Lugmayr et al. (2022). Forward ("undo") steps re-apply
    noise so that the schedule can jump back and resample.
    """

    def repaint_ddim(self, ground_truth: Array, state: tuple[Array, int]):
        """Run the whole RePaint schedule starting from Gaussian noise.

        Args:
            ground_truth: Known part of the sample. Its leading dimensions
                give the batch shape of the generated sample.
            state: ``(key, offset)`` pair. ``key`` is the base PRNG key and
                ``offset`` is the next unused counter folded into it.

        Returns:
            A tuple ``(sample, state)``. ``sample`` holds the first
            ``self.sample_dim`` features of the final carry; ``state`` is the
            PRNG state to hand to the next call.
        """
        key, offset = state

        # `offset` itself seeds the initial noise below, so the scan steps
        # use the counters that follow it. Folding the same counter into
        # `key` twice would make the noise drawn in the first step identical
        # to the initial sample.
        carry = self.repaint_process(
            alpha_prod=self.repaint_alphas_cumprod.value,
            alpha_prod_prev=self.repaint_alphas_cumprod_prev.value,
            backwards=self.backwards.value,
            ground_truth=ground_truth,
            key=key,
            offset=offset + 1 + jnp.arange(self.repaint_steps),
            sample=jrandom.normal(
                jrandom.fold_in(key, offset),
                ground_truth.shape[:-1] + (self.noise_predictor.sample_dim,),
            ),
            time=self.repaint_timesteps.value,
        )

        carry = carry[..., : self.sample_dim]
        # One counter for the initial noise plus one per scan step.
        return carry, (key, offset + 1 + self.repaint_steps)

    # in_axes follow the parameter order: self, alpha_prod, alpha_prod_prev,
    # backwards, ground_truth, key, offset, sample, time. `sample` is the
    # carry; `ground_truth` and `key` are broadcast; the rest are scanned
    # along axis 0.
    @nnx.scan(
        in_axes=(None, 0, 0, 0, None, None, 0, nnx.Carry, 0),
        out_axes=nnx.Carry,
    )
    def repaint_process(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        backwards: Array,
        ground_truth: Array,
        key: Array,
        offset: Array,
        sample: Array,
        time: Array,
    ):
        """Perform one step of the RePaint schedule.

        Args:
            alpha_prod: Cumulative alpha product for this step.
            alpha_prod_prev: Cumulative alpha product of the target step.
            backwards: Whether this step denoises (true) or re-noises.
            ground_truth: Known part of the sample.
            key: Base PRNG key.
            offset: Counter folded into ``key`` for this step.
            sample: Current sample (the scan carry).
            time: Timestep associated with this step.

        Returns:
            The updated sample.
        """
        sample = lax.cond(
            backwards,
            lambda: self.repaint_backward_step(
                alpha_prod=alpha_prod,
                alpha_prod_prev=alpha_prod_prev,
                ground_truth=ground_truth,
                key=key,
                offset=offset,
                sample=sample,
                time=time,
            ),
            lambda: self.repaint_undo_step(
                key=jrandom.fold_in(key, offset), sample=sample, time=time
            ),
        )
        return sample

    def repaint_backward_step(
        self,
        alpha_prod: Array,
        alpha_prod_prev: Array,
        ground_truth: Array,
        key: Array,
        offset: Array,
        sample: Array,
        time: Array,
    ):
        """Denoise the unknown part and re-noise the known part.

        Args:
            alpha_prod: Cumulative alpha product at the current step.
            alpha_prod_prev: Cumulative alpha product at the target step.
            ground_truth: Known part of the sample.
            key: Base PRNG key.
            offset: Counter folded into ``key`` to draw this step's noise.
            sample: Current sample; its first ``self.sample_dim`` features
                are the unknown part.
            time: Current timestep; no noise is added when it is 0.

        Returns:
            The concatenation (last axis) of the updated unknown part and the
            freshly noised known part.
        """
        input_time = jnp.full(sample.shape[:-1] + (1,), time)
        pred_noise = self.noise_predictor(sample, input_time)
        pred_noise = pred_noise[..., : self.sample_dim]

        # 1. compute alphas, betas
        beta_prod = 1 - alpha_prod
        beta_prod_prev = 1 - alpha_prod_prev
        alpha = alpha_prod / alpha_prod_prev
        beta = 1 - alpha
        sqrt_alpha_prod_prev = jnp.sqrt(alpha_prod_prev)

        # 2. compute predicted original sample from predicted noise also
        # called "predicted x_0" of formula (15) from Ho et al. (2020)
        # https://arxiv.org/abs/2006.11239
        pred_original_sample = (
            sample[..., : self.sample_dim] - jnp.sqrt(beta_prod) * pred_noise
        ) / jnp.sqrt(alpha_prod)

        # 3. Clip the unknown part of "predicted x_0"
        if self.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
        # substitute formula (7) in the algorithm coming from DDIM paper
        # (formula (4) Algorithm 2 - Sampling) with formula (12)
        # from DDIM paper. DDIM schedule gives the same results as DDPM with
        # eta = 1.0. A single noise array shaped like `sample` is drawn; its
        # first `sample_dim` features are used in 7. and the remaining
        # features in 8.

        # 5. Add noise
        noise = lax.cond(
            time > 0,
            lambda x: jrandom.normal(jrandom.fold_in(key, offset), x.shape),
            jnp.zeros_like,
            sample,
        )

        if self.eta > 0:
            # Variance of the stochastic DDIM step for eta = 1.
            base_variance = (beta_prod_prev / beta_prod) * beta
            std_dev: Array | float = self.eta * jnp.sqrt(base_variance)
            # sigma_t ** 2 of formula (12). It is computed from
            # `base_variance` directly (not as `std_dev ** 2`) so that
            # eta = 1 reproduces `base_variance` exactly.
            variance: Array | float = self.eta**2 * base_variance
        else:
            std_dev = variance = 0

        # 6. Compute "direction pointing to x_t" of formula (12) from
        # Song et al. (2020) https://arxiv.org/abs/2010.02502
        # The coefficient is sqrt(1 - alpha_prod_prev - sigma_t ** 2).
        pred_sample_direction = jnp.sqrt(beta_prod_prev - variance) * pred_noise

        # 7. Compute x_{t-1} of formula (12) from
        # Song et al. (2020) https://arxiv.org/abs/2010.02502
        prev_unknown_part = (
            sqrt_alpha_prod_prev * pred_original_sample + pred_sample_direction
        )
        if self.eta > 0:
            prev_unknown_part = lax.cond(
                time > 0,
                lambda x: x + std_dev * noise[..., : self.sample_dim],
                lambda x: x,
                prev_unknown_part,
            )

        # 8. Algorithm 1 Line 5
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5
        # refers to formula (8a) of the same paper, which tells to sample from
        # a Gaussian distribution with mean
        # "(alpha_prod_prev ** 0.5) * original_image" and variance
        # "(1 - alpha_prod_prev)". This means that the standard Gaussian
        # distribution "noise" should be scaled by the square root of the
        # variance (as it is done here), however Algorithm 1 Line 5 tells to
        # scale by the variance.
        prev_known_part = (
            sqrt_alpha_prod_prev * ground_truth
            + jnp.sqrt(beta_prod_prev) * noise[..., self.sample_dim :]
        )

        # 9. Algorithm 1 Line 8
        # Lugmayr et al. (2022) https://arxiv.org/abs/2201.09865
        pred_prev_sample = jnp.concatenate(
            (prev_unknown_part, prev_known_part), axis=-1
        )

        return pred_prev_sample

    def repaint_undo_step(self, key: Array, sample: Array, time: Array):
        """Re-noise ``sample`` to jump forward in the diffusion process.

        Applies ``self.diffusion_steps // self.ddim_inference_steps``
        consecutive forward-diffusion steps, using the betas at indices
        ``time, time + 1, ...`` and an independent noise draw for each.

        Args:
            key: PRNG key for this step; it is split once per applied beta.
            sample: Current sample.
            time: Index of the first beta to apply.

        Returns:
            The re-noised sample.
        """
        repeat = self.diffusion_steps // self.ddim_inference_steps
        betas = self.betas[time + jnp.arange(repeat)]
        for beta, key_ in zip(betas, jrandom.split(key, repeat)):
            noise = jrandom.normal(key_, sample.shape)

            # 10. Algorithm 1 Line 10 in Lugmayr et al. (2022)
            # https://arxiv.org/abs/2201.09865
            sample = jnp.sqrt(1 - beta) * sample + jnp.sqrt(beta) * noise
        return sample

    def set_repaint_timesteps(
        self, *, jump_length: int, jump_samples: int, **kwargs
    ):
        """Precompute the per-step tables consumed by ``repaint_process``.

        Sets ``backwards``, ``repaint_steps``, ``repaint_timesteps``,
        ``repaint_alphas_cumprod`` and ``repaint_alphas_cumprod_prev``.

        Args:
            jump_length: Forwarded to ``get_schedule``.
            jump_samples: Forwarded to ``get_schedule``.
            **kwargs: Must contain ``"inference_steps"``.

        Raises:
            KeyError: If ``"inference_steps"`` is missing from ``kwargs``.
        """
        inference_steps = kwargs["inference_steps"]

        # Presumably a 1-D NumPy integer array of inference-step indices in
        # the order they are visited — verify against `get_schedule`.
        timesteps = get_schedule(
            jump_length=jump_length,
            jump_samples=jump_samples,
            steps=inference_steps,
        )
        # Timestep visited just before each entry; the first entry is
        # treated as coming from one step above it, so it counts as a
        # backward step.
        timesteps_last = np.insert(timesteps[:-1], 0, timesteps[0] + 1)
        self.backwards = nnx.Variable(jnp.asarray(timesteps < timesteps_last))
        self.repaint_steps = int(timesteps.size)

        # If t < t_last, we need t. Otherwise, we need t_last.
        timesteps = np.minimum(timesteps, timesteps_last)

        # Convert inference-step indices to training timesteps.
        ratio = self.diffusion_steps // inference_steps
        timesteps = timesteps * ratio
        self.repaint_timesteps = nnx.Variable(jnp.asarray(timesteps))

        # NOTE(review): the cumulative products are looked up one index above
        # the stored timestep — confirm this matches the indexing convention
        # of `self.alphas_cumprod`.
        timesteps = timesteps + 1
        prev_timesteps = np.clip(timesteps - ratio, a_max=None, a_min=0)
        self.repaint_alphas_cumprod = nnx.Variable(
            self.alphas_cumprod[timesteps]
        )
        self.repaint_alphas_cumprod_prev = nnx.Variable(
            self.alphas_cumprod[prev_timesteps]
        )


class RepaintDDIMPolicy(RepaintPolicy):
    """Policy that samples actions with a ``RepaintDDIM`` diffusion model.

    The observation is passed as the known part (``ground_truth_dim``) and
    the action is the part being generated (``sample_dim``).
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        action_dim: int,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        eta: float,
        inference_steps: int,
        jump_length: int,
        jump_samples: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        temperature: float,
        time_dim: int,
        timestep_spacing: str,
        **kwargs
    ):
        """Build the underlying ``RepaintDDIM`` model.

        All named arguments are forwarded to ``RepaintDDIM`` by keyword, with
        ``action_dim`` passed as ``sample_dim`` and ``observation_dim`` as
        ``ground_truth_dim``. Any extra ``kwargs`` are forwarded unchanged.
        The parent initializer is intentionally not called.
        """
        diffusion_config = dict(
            beta_schedule=beta_schedule,
            clip_sample=clip_sample,
            diffusion_steps=diffusion_steps,
            eta=eta,
            ground_truth_dim=observation_dim,
            inference_steps=inference_steps,
            jump_length=jump_length,
            jump_samples=jump_samples,
            rngs=rngs,
            sample_dim=action_dim,
            temperature=temperature,
            time_dim=time_dim,
            timestep_spacing=timestep_spacing,
        )
        self.diffusion = RepaintDDIM(**diffusion_config, **kwargs)

    def __call__(self, observations: Array, state: tuple[Array, int]):
        """Sample actions for ``observations``.

        Args:
            observations: Known part handed to the diffusion model.
            state: ``(key, offset)`` PRNG state.

        Returns:
            A tuple ``(actions, state, extras)`` where ``state`` is the
            updated PRNG state and ``extras`` is an empty dict.
        """
        # NOTE(review): `repaint` is not defined in this file (presumably
        # inherited from RepaintDDPM) — confirm it ends up running
        # `RepaintDDIM.repaint_ddim`.
        sampled_actions, next_state = self.diffusion.repaint(
            observations, state
        )
        extras = {}
        return sampled_actions, next_state, extras
