from jax import grad, numpy as jnp, random

from numpyro import set_platform
from numpyro.contrib.einstein import kernels

if __name__ == "__main__":
    # Demo: evaluate an RBF Stein kernel on two random particles, then
    # differentiate it. Runs only when executed as a script.

    # Select the CPU backend via numpyro before any arrays are created.
    set_platform("cpu")

    # Draw a (2, 5) standard-normal sample from a fixed key, scale it by 2,
    # and split it into two length-5 particles.
    rng_key = random.PRNGKey(0)
    draws = random.normal(rng_key, (2, 5)) * 2
    first, second = draws[0], draws[1]

    # Re-assemble the particles into a (2, 5) array for the kernel.
    particle_batch = jnp.stack([first, second])

    # particle_info and loss_fn are both passed as None.
    # NOTE(review): this assumes RBFKernel ignores them in the installed
    # numpyro version; confirm against its compute() signature.
    rbf = kernels.RBFKernel()
    kernel_fn = rbf.compute(particle_batch, None, None)

    # Kernel value k(first, second).
    print(kernel_fn(first, second))

    # jax.grad differentiates w.r.t. the first positional argument by default.
    # It also requires a scalar kernel output (presumably the default mode).
    d_kernel_d_first = grad(kernel_fn)
    print(d_kernel_d_first(first, second))
