import torch
import matplotlib.pyplot as plt
from torchrbf import RBFInterpolator

# Scattered data: 100 random sites in the unit square, each carrying a
# 3-component vector value to be interpolated.
y = torch.rand(100, 2)  # Data coordinates, shape (100, 2)
d = torch.rand(100, 3)  # Data vectors at each point, shape (100, 3)

# NOTE(review): presumably, as in scipy's RBFInterpolator, smoothing > 0
# yields an approximating fit rather than exact interpolation — confirm
# against the torchrbf docs.
interpolator = RBFInterpolator(y, d, smoothing=1.0, kernel='thin_plate_spline')

# Query coordinates (grid_res x grid_res grid of points). The grid axes get
# their own names so the data coordinates `y` above are not overwritten.
grid_res = 100
grid_x = torch.linspace(0, 1, grid_res)
grid_y = torch.linspace(0, 1, grid_res)
grid_points = torch.stack(
    torch.meshgrid(grid_x, grid_y, indexing='ij'), dim=-1
).reshape(-1, 2)  # shape (grid_res**2, 2)

# Query RBF on grid points; gradients are not needed for plotting.
with torch.no_grad():
    interp_vals = interpolator(grid_points)

# Plot the first component of the interpolated vectors in 2D. Convert to
# NumPy explicitly so matplotlib never receives tensors that live on a GPU
# or carry autograd history.
pts = grid_points.cpu().numpy()
vals = interp_vals[:, 0].detach().cpu().numpy()
plt.scatter(pts[:, 0], pts[:, 1], c=vals)
plt.colorbar(label='interpolated value (component 0)')
plt.title('Interpolated values in 2D')
plt.show()