import torch

try:
    from . import ingroup_inds_cuda
    # import ingroup_indices
except ImportError:
    ingroup_indices = None
    print('Can not import ingroup indices')

ingroup_indices = ingroup_inds_cuda

from torch.autograd import Function
class IngroupIndicesFunction(Function):

    @staticmethod
    def forward(ctx, group_inds):

        out_inds = torch.zeros_like(group_inds) - 1

        ingroup_indices.forward(group_inds, out_inds)

        ctx.mark_non_differentiable(out_inds)

        return out_inds

    @staticmethod
    def backward(ctx, g):

        return None

ingroup_inds = IngroupIndicesFunction.apply