from jax.example_libraries.optimizers import Schedule, make_schedule, optimizer
import jax.numpy as jnp


@optimizer
def centered_rmsprop(lr: Schedule, momentum: float = 0.95, rho: float = 0.95, epsilon: float = 0.01):
    """Centered RMSProp with a momentum buffer.

    For every parameter leaf ``x`` with gradient ``g`` at step ``i``::

        m   <- rho * m + (1 - rho) * g            # running mean of g
        v   <- rho * v + (1 - rho) * g**2         # running mean of g**2
        var <- max(v - m**2, 0)                   # centered second moment
        mom <- momentum * mom + lr(i) * g / sqrt(var + epsilon)
        x   <- x - mom

    Args:
        lr: Step size. It is normalized through ``make_schedule``, so either a
            constant or a callable ``i -> step size`` is accepted.
        momentum: Decay factor of the momentum buffer, in ``[0, 1)``.
        rho: Decay factor of both gradient moving averages, in ``[0, 1)``.
        epsilon: Added to the variance estimate inside the square root. It
            must be strictly positive.

    Returns:
        An ``(init, update, get_params)`` triple. The ``@optimizer`` decorator
        lifts these per-leaf functions to pytrees.

    Raises:
        ValueError: If ``momentum``, ``rho`` or ``epsilon`` is out of range
            (NaN is also rejected).
    """
    lr = make_schedule(lr)
    # Explicit checks rather than ``assert``, so that validation still runs
    # under ``python -O``.
    if not 0.0 <= momentum < 1.0:
        raise ValueError(f"momentum must be in [0, 1), got {momentum!r}")
    if not 0.0 <= rho < 1.0:
        raise ValueError(f"rho must be in [0, 1), got {rho!r}")
    if not 0.0 < epsilon:
        raise ValueError(f"epsilon must be > 0, got {epsilon!r}")

    def init(x0):
        """Return the per-leaf state ``(x, mean_g, mean_g_sq, momentum_buf)``."""
        m0 = jnp.zeros_like(x0)
        v0 = jnp.zeros_like(x0)
        mom0 = jnp.zeros_like(x0)
        return x0, m0, v0, mom0

    def update(i, g, state):
        """Apply one centered-RMSProp step to a single leaf's state."""
        x, m, v, mom = state

        m = rho * m + (1 - rho) * g
        v = rho * v + (1 - rho) * jnp.square(g)
        # Starting from the zero state, v >= m**2 holds in exact arithmetic.
        # Float cancellation, or an externally supplied state, can still make
        # the difference slightly negative. Clamp it so that sqrt() cannot
        # return NaN when epsilon is small.
        variance = jnp.maximum(v - jnp.square(m), 0.0)

        step = lr(i) * g / jnp.sqrt(variance + epsilon)
        mom = momentum * mom + step
        x = x - mom

        return x, m, v, mom

    def get_params(state):
        """Extract the parameter value from a leaf's state."""
        return state[0]

    return init, update, get_params
