from jax import numpy as jnp


def cosine_beta_schedule(
    timesteps: int, max_beta: float = 0.999, offset: float = 0.008
):
    """
    Build the cosine noise schedule presented in
    'Improved Denoising Diffusion Probabilistic Models' by Nichol and Dhariwal
    https://arxiv.org/abs/2102.09672

    Args:
        timesteps: Number of diffusion steps ``T``.
        max_beta: Upper bound applied to every returned beta.
        offset: Shift added to the normalised time before the cosine.

    Returns:
        Array of ``timesteps + 1`` betas. Entry 0 comes from a prepended
        alpha of 1 (i.e. a beta of 0); the remaining ``timesteps`` entries
        are ``1 - f(t_i) / f(t_{i-1})`` capped at ``max_beta``.
    """
    # Normalised time grid 0, 1/T, ..., 1 (T + 1 points).
    grid = jnp.arange(timesteps + 1) / timesteps
    phase = (grid + offset) / (1 + offset) * jnp.pi * 0.5
    cumulative = jnp.square(jnp.cos(phase))

    # Per-step alpha is the ratio of consecutive cumulative values; a unit
    # alpha is placed in front so the result has one entry per grid point.
    step_alphas = cumulative[1:] / cumulative[:-1]
    step_alphas = jnp.insert(step_alphas, 0, 1)

    return jnp.clip(1 - step_alphas, None, max_beta)


def linear_beta_schedule(
    timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2
):
    """
    Return ``timesteps`` betas evenly spaced between ``beta_start`` and
    ``beta_end``.

    Args:
        timesteps: Number of betas to generate.
        beta_start: Value of the first beta.
        beta_end: Value of the last beta.
    """
    betas = jnp.linspace(start=beta_start, stop=beta_end, num=timesteps)
    return betas


def variance_preserving_beta_schedule(
    timesteps: int, b_max: float = 10, b_min: float = 0.1
):
    """
    Return ``timesteps`` betas of the form ``1 - exp(-x_i)``.

    For step ``i`` (0-based) and ``T = timesteps`` the exponent is
    ``x_i = b_min / T + (b_max - b_min) * (i + 0.5) / T**2``,
    which equals the integral of the linear ramp
    ``b_min + (b_max - b_min) * s`` over ``s`` in ``[i / T, (i + 1) / T]``.

    Args:
        timesteps: Number of discrete steps ``T``.
        b_max: Ramp value at ``s = 1``.
        b_min: Ramp value at ``s = 0``.

    Returns:
        Array of ``timesteps`` betas, ``beta_i = 1 - exp(-x_i)``.
    """
    t = jnp.arange(timesteps)
    # Combine the scalar coefficients first so the array math is a single
    # multiply and subtract.
    slope = (b_max - b_min) / timesteps**2
    log_alphas = -b_min / timesteps - slope * (t + 0.5)
    # 1 - exp(x) loses most of its significant digits in float32 when |x| is
    # small (x is about -1e-4 for the first of 1000 steps); expm1 evaluates
    # exp(x) - 1 without that cancellation.
    return -jnp.expm1(log_alphas)
