import matplotlib.pyplot as plt
import jax.numpy as jnp
import sklearn as sk
import sklearn.neighbors

def funcplot2d(grid, vals, res=100, *, n_neighbors=8, lim=(-1.0, 1.0)):
    """Plot a scalar function sampled at scattered 2-D points as an image.

    The samples are interpolated onto a regular ``res`` x ``res`` grid
    covering the square ``lim`` x ``lim``. Interpolation uses a
    distance-weighted k-nearest-neighbours regressor, and the result is
    drawn with ``plt.imshow`` on the current axes.

    Args:
        grid: Sample locations, array-like with two coordinates per row
            (the regressor is queried with ``(N, 2)`` points).
        vals: Function value at each sample location (one per row of ``grid``).
        res: Number of pixels along each axis of the rendered image.
        n_neighbors: Neighbours used by the regressor. It is clamped to the
            number of samples, because scikit-learn rejects
            ``n_neighbors > n_samples``.
        lim: ``(low, high)`` bounds shared by the x and y axes.

    Returns:
        The ``AxesImage`` returned by ``plt.imshow``.
    """
    lo, hi = lim
    model = sk.neighbors.KNeighborsRegressor(
        n_neighbors=min(n_neighbors, len(grid)), weights='distance')
    model.fit(grid, vals)

    # Use `res` here (not a hard-coded size) so the reshape below always matches.
    coords = jnp.linspace(lo, hi, res)
    # Default 'xy' indexing: image row i is y = coords[i], column j is x = coords[j].
    xx, yy = jnp.meshgrid(coords, coords)
    gridpts = jnp.stack((xx, yy), axis=-1).reshape((-1, 2))
    pred = model.predict(gridpts)

    # imshow draws row 0 at the top by default. origin='lower' keeps y = lo at
    # the bottom, and extent labels the axes in data coordinates, not pixels.
    return plt.imshow(pred.reshape(res, res), origin='lower',
                      extent=(lo, hi, lo, hi))