import torch
from einops import rearrange


# Demo inputs. `image` layout is (b, f, ...) = (3, 5, 2, 2), filled with
# standard-normal noise (unseeded, so values differ between runs).
image = torch.randn((3, 5, 2, 2))

# Two integer positions per batch entry (3 rows x 2 columns), int64 as
# torch.gather requires. Presumably meant to index `image` along dim 1
# (size 5) -- TODO confirm against the intended batch_gather call.
index = torch.tensor([[1, 2], [0, 1], [0, 2]], dtype=torch.int64)
def batch_gather(
    data: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    *,
    index_dims: int = 2,
) -> torch.Tensor:
    """Gather slices of ``data`` along ``dim`` using a lower-rank ``index``.

    ``torch.gather`` needs an index tensor with the same rank as ``data``.
    This helper does three things before gathering:

    * pads ``index`` with trailing size-1 dimensions up to ``data.ndim``;
    * keeps the sizes of its first ``index_dims`` dimensions;
    * broadcasts every later dimension to the matching size of ``data``.

    With the defaults, ``dim=1``, ``data`` of shape ``(B, F, *rest)`` and
    ``index`` of shape ``(B, K)`` give a result of shape ``(B, K, *rest)``
    with ``out[b, k] == data[b, index[b, k]]``.

    Args:
        data: Source tensor.
        dim: Dimension of ``data`` along which ``index`` values are looked up.
        index: Integer (int64) index tensor.
        index_dims: Number of leading dimensions whose sizes come from the
            padded ``index``; all later dimensions are expanded to
            ``data``'s sizes. The default of 2 preserves the original
            behaviour. Pass ``index.ndim`` for other layouts, e.g.
            ``index_dims=1`` with ``dim=0`` and a 1-D index selects whole
            rows.

    Returns:
        Tensor of shape
        ``(*padded_index.shape[:index_dims], *data.shape[index_dims:])``.

    Raises:
        RuntimeError: raised by ``Tensor.expand`` / ``torch.gather`` when the
            shapes are incompatible or an index value is out of range.
    """
    # Append size-1 dims so index has data's rank (no-op if it already does).
    pad = (1,) * max(data.ndim - index.ndim, 0)
    index = index.reshape(tuple(index.shape) + pad)
    # Broadcast the trailing dims to data's sizes; expand returns a view.
    index = index.expand(*index.shape[:index_dims], *data.shape[index_dims:])
    return data.gather(dim, index)
# Dump the demo tensors, then the image shape, one per line.
print(image, index, image.shape, sep="\n")