"""Schedule utilities for diffusion.

Based on https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py"""

import jax.numpy as jnp


def alpha_sigma_to_snr(alpha, sigma):
    """Return the signal-to-noise ratio, computed as ``(alpha / sigma) ** 2``.

    ``alpha`` and ``sigma`` are combined with ordinary arithmetic operators,
    so anything supporting ``/`` and ``**`` (scalars or arrays) is accepted.
    """
    ratio = alpha / sigma
    return ratio ** 2


def t_to_alpha_sigma_cosine(t):
    """Map a timestep to the (alpha, sigma) pair of the cosine schedule.

    ``alpha`` scales the clean image and ``sigma`` scales the noise. They are
    the cosine and sine of ``t * pi / 2`` respectively, following the cosine
    schedule from Improved DDPM.
    """
    # Both factors share the same angle, so compute it once.
    angle = t * jnp.pi / 2
    alpha = jnp.cos(angle)
    sigma = jnp.sin(angle)
    return alpha, sigma


def get_cosine_schedule(t):
    """Return ``(alpha, sigma, snr)`` for timestep ``t`` on the cosine schedule.

    ``alpha`` and ``sigma`` come from ``t_to_alpha_sigma_cosine``; ``snr`` is
    obtained by passing them to ``alpha_sigma_to_snr``.
    """
    signal_scale, noise_scale = t_to_alpha_sigma_cosine(t)
    snr = alpha_sigma_to_snr(signal_scale, noise_scale)
    return signal_scale, noise_scale, snr
