import torch
from einops import rearrange
from pkg_resources import parse_version
from mia.ops.naive_grid_sample import naive_grid_sample_1d, naive_grid_sample_2d

def grid_sample(input, coords):
    """Sample a 1D feature sequence at ``coords`` and return it channels-last.

    Thin wrapper around ``naive_grid_sample_1d``: it samples the input, then
    swaps the last two axes of the sampled tensor so the feature dimension
    comes last.

    Args:
        input: Feature tensor, expected shape ``(B, D, L)``.
        coords: Sample locations, expected shape ``(B, N, 1)``. The coordinate
            convention (normalized vs. index units, padding/align behaviour)
            is whatever ``naive_grid_sample_1d`` implements -- TODO confirm
            against that helper.

    Returns:
        Tensor of shape ``(B, N, D)`` (one ``D``-dim feature per sample
        point), assuming ``naive_grid_sample_1d`` returns ``(B, D, N)``.
        Note: the return is NOT ``(B, D, N)``; the final rearrange moves the
        feature axis last.
    """
    # Intermediate result, presumably (B, D, N). The rearrange pattern below
    # requires it to be rank-3.
    sampled = naive_grid_sample_1d(input, coords)
    # Channels-last: (B, D, N) -> (B, N, D).
    return rearrange(sampled, 'b d n -> b n d')
