from flax import nnx
from jax import Array, numpy as jnp, random as jrandom

from offline.diffusion.modules.utils import (
    cosine_beta_schedule,
    linear_beta_schedule,
    variance_preserving_beta_schedule,
)


class DDPM(nnx.Module):
    """Noise-schedule container for a denoising diffusion probabilistic model.

    Holds the per-step ``betas`` and the cumulative products of ``1 - beta``
    (``alphas_cumprod``). It exposes two noising operations:

    * forward-noising clean samples (``add_noise``);
    * injecting fresh Gaussian noise into a predicted previous sample
      (``add_noise_pred_prev_sample``).

    Args:
        beta_schedule: Schedule name: ``"cosine"``, ``"linear"`` or ``"vp"``.
            Any other value raises ``NotImplementedError``.
        clip_sample: Stored on the module as-is. It is not read by the methods
            in this class (presumably consumed elsewhere; verify).
        diffusion_steps: Number of diffusion steps passed to the schedule
            function.
        temperature: Multiplier applied to the noise injected by
            ``add_noise_pred_prev_sample``.
        **kwargs: Any extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        beta_schedule: str,
        clip_sample: bool,
        diffusion_steps: int,
        temperature: float,
        **kwargs,
    ):
        del kwargs  # extra keyword arguments are deliberately discarded

        self.clip_sample = clip_sample
        self.diffusion_steps = diffusion_steps
        self.temperature = temperature

        # Map each supported schedule name to the function that builds it.
        schedule_builders = {
            "cosine": cosine_beta_schedule,
            "linear": linear_beta_schedule,
            "vp": variance_preserving_beta_schedule,
        }
        build_betas = schedule_builders.get(beta_schedule)
        if build_betas is None:
            raise NotImplementedError(beta_schedule)
        step_betas = build_betas(diffusion_steps)
        self.betas = nnx.Variable(step_betas)

        # Prepend a beta of 0 so that alphas_cumprod[0] == 1 (no noise) and
        # alphas_cumprod[t + 1] is the product of (1 - beta) over steps 0..t.
        # Assumes the schedule yields one beta per step (1-D) -- TODO confirm.
        shifted_betas = jnp.insert(step_betas, 0, 0)
        self.alphas_cumprod = nnx.Variable(jnp.cumprod(1 - shifted_betas))

    def add_noise(self, noise: Array, samples: Array, timesteps: Array):
        """Forward-noise ``samples`` to the given ``timesteps``.

        Computes ``sqrt(a) * samples + sqrt(1 - a) * noise``, where
        ``a = alphas_cumprod[timesteps + 1]``. That value is the cumulative
        product of ``1 - beta`` over steps ``0..timesteps``.

        Args:
            noise: Noise to mix in (presumably standard normal; not checked).
            samples: Clean samples to corrupt.
            timesteps: 0-based step indices into ``betas``.

        Returns:
            The noised samples.
        """
        # NOTE(review): the gathered coefficients take the shape of
        # `timesteps` and rely on standard right-aligned broadcasting against
        # `samples`. A (batch,) timestep vector paired with (batch, ..., dim)
        # samples must be reshaped by the caller -- confirm call sites.
        alpha_bar = self.alphas_cumprod[timesteps + 1]
        signal_scale = jnp.sqrt(alpha_bar)
        noise_scale = jnp.sqrt(1 - alpha_bar)
        return signal_scale * samples + noise_scale * noise

    def add_noise_pred_prev_sample(
        self, beta: Array, key: Array, pred_prev_sample: Array
    ):
        """Perturb a predicted previous sample with scaled Gaussian noise.

        Returns ``pred_prev_sample + temperature * sqrt(beta) * eps``, where
        ``eps`` is standard normal noise drawn from ``key``. ``eps`` has the
        same shape as ``pred_prev_sample`` and JAX's default float dtype.

        Args:
            beta: Noise variance before temperature scaling.
            key: JAX PRNG key used to draw ``eps``.
            pred_prev_sample: The predicted sample to perturb.

        Returns:
            The perturbed sample.
        """
        noise_scale = self.temperature * jnp.sqrt(beta)
        eps = jrandom.normal(key, pred_prev_sample.shape)
        return pred_prev_sample + noise_scale * eps
