import numpy as np
import jax
import jax.numpy as jnp


def gaussian_kernel(x1, x2, width):
    """Return the Gaussian (RBF) kernel value between ``x1`` and ``x2``.

    Computes ``exp(sum(-(x1 - x2)**2 / (2 * width**2)))``, i.e.
    ``exp(-||x1 - x2||^2 / (2 * width^2))``.

    Args:
        x1: Array-like with a ``.reshape`` method; it is reshaped to
            ``x2.shape`` before comparison, so it must hold the same number
            of elements as ``x2``.
        x2: Reference array whose shape ``x1`` is coerced to.
        width: Kernel bandwidth (standard deviation of the Gaussian).

    Returns:
        A scalar JAX array with the kernel value.
    """
    # Coerce x1 to x2's layout so element-wise differences line up.
    delta = x1.reshape(x2.shape) - x2
    # Unary minus applies after squaring, matching -(d ** 2).
    exponent_terms = -(delta ** 2) / (2 * width ** 2)
    return jnp.exp(jnp.sum(exponent_terms))


def sum_of_gaussian_kernels(x1, x2, widths):
    """Return the mean of Gaussian kernels over several bandwidths.

    For each ``width`` in ``widths`` this computes
    ``exp(sum(-(x1 - x2)**2 / (2 * width**2)))``. It returns the average of
    those values: each term is divided by ``len(widths)``. Despite the name,
    the result is a mean rather than a plain sum.

    Args:
        x1: Array-like with a ``.reshape`` method; it is reshaped to
            ``x2.shape`` before comparison.
        x2: Reference array whose shape ``x1`` is coerced to.
        widths: Iterable of kernel bandwidths. It is materialised once, so
            generators are accepted as well as sequences.

    Returns:
        A scalar JAX array with the averaged kernel value, or the Python
        float ``0.0`` when ``widths`` is empty.
    """
    # Materialise once so len() works for any iterable, including generators.
    widths = tuple(widths)
    # The squared difference does not depend on the width; compute it once
    # instead of on every iteration.
    sq_diff = (x1.reshape(x2.shape) - x2) ** 2
    n_widths = float(len(widths))
    total_k = 0.0
    for width in widths:
        # Divide each term (rather than the final sum) to keep the original
        # floating-point accumulation order.
        total_k += jnp.exp(jnp.sum(-sq_diff / (2 * width ** 2))) / n_widths
    return total_k
