
import torch

class Model(torch.nn.Module):
    """Gather per-batch columns of a channel-first point tensor by index."""

    def __init__(self):
        super().__init__()

    def forward(self, points, idx):
        """
        Gather entries of ``points`` along its last (N) axis.

        points: (B, C, N)
        idx: (B, M) -- or more generally (B, *S) -- of any integer dtype
            (e.g. int32); values must lie in [0, N).
        Returns:
            out: (B, C, M), or (B, C, *S) for a multi-dimensional idx,
            with out[b, c, ...] = points[b, c, idx[b, ...]].
        """
        B, C = points.shape[0], points.shape[1]
        index_shape = idx.shape[1:]

        # torch.gather requires an int64 index. Cast the small (B, M) tensor
        # *before* expanding: casting after expand would materialize a full
        # (B, C, M) int64 copy instead of keeping a stride-0 broadcast view.
        flat_idx = idx.long().flatten(1)  # (B, K), K = prod(S)

        # gather needs the index to have the same rank as the input and match
        # it on every non-gathered dim, so broadcast it across channels.
        index = flat_idx.unsqueeze(1).expand(-1, C, -1)  # (B, C, K)

        out = torch.gather(points, 2, index)  # (B, C, K)
        # Restore any trailing index dims; explicit sizes keep B == 0 working.
        return out.reshape(B, C, *index_shape)

def get_init_inputs():
    """Return the positional arguments used to construct ``Model`` (none)."""
    init_args = []
    return init_args

def get_inputs():
    """Build a random ``[points, idx]`` pair for exercising ``Model.forward``.

    ``points`` has shape (2, 32, 128) and is drawn with ``torch.randn``;
    ``idx`` has shape (2, 64), dtype int32, and values drawn with
    ``torch.randint`` from [0, 128).
    """
    batch, channels, num_points, num_samples = 2, 32, 128, 64
    sample_points = torch.randn(batch, channels, num_points)
    sample_idx = torch.randint(0, num_points, (batch, num_samples))
    return [sample_points, sample_idx.to(torch.int32)]
