import torch
from torch.autograd import Function

from pointops._C import grouping_forward_cuda, grouping_backward_cuda


class Grouping(Function):
    """Autograd wrapper around the compiled ``grouping_*_cuda`` kernels.

    Judging by the shapes documented below, the forward pass presumably
    gathers rows of ``input`` for every neighbour slot in ``idx`` (the same
    operation the pure-PyTorch ``grouping`` function in this module performs),
    and the backward pass routes the gradient back onto the source rows.
    NOTE(review): kernel semantics (e.g. handling of ``-1`` indices) are not
    visible here — confirm against the CUDA source.
    """

    @staticmethod
    def forward(ctx, input, idx):
        """Run the grouping forward kernel.

        Args:
            input: (n, c) feature tensor; must be contiguous.
            idx: (m, nsample) neighbour index table; must be contiguous.

        Returns:
            (m, nsample, c) float32 tensor filled by the kernel.
        """
        assert input.is_contiguous() and idx.is_contiguous()
        m, nsample = idx.shape[0], idx.shape[1]
        n, c = input.shape[0], input.shape[1]
        # NOTE(review): the output buffer is always float32; presumably the
        # kernel only supports float32 features — confirm before passing
        # other dtypes.
        output = torch.zeros((m, nsample, c), dtype=torch.float, device=input.device)
        grouping_forward_cuda(m, nsample, c, input, idx, output)
        # Only the row count of ``input`` is needed later to size its gradient.
        ctx.n = n
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the grouping backward kernel.

        Args:
            grad_output: (m, nsample, c) gradient w.r.t. the forward output.

        Returns:
            Tuple of the (n, c) float32 gradient for ``input`` and ``None``
            for ``idx`` (an index table has no gradient).
        """
        n = ctx.n
        (idx,) = ctx.saved_tensors
        # Autograd may hand over a non-contiguous gradient (e.g. when the
        # output was permuted or expanded downstream). The forward pass
        # requires contiguous buffers for the kernel, so enforce the same here;
        # this is a no-op when the gradient is already contiguous.
        grad_output = grad_output.contiguous()
        m, nsample, c = grad_output.shape
        # Zero-initialised so the kernel only has to write/accumulate the rows
        # referenced by idx.
        grad_input = torch.zeros((n, c), dtype=torch.float, device=idx.device)
        grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input)
        return grad_input, None


def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False):
    """Gather neighbour features, optionally prefixed by relative coordinates.

    An index of ``-1`` in ``idx`` marks an empty neighbour slot. Such a slot
    yields an all-zero feature row, and a zero coordinate offset when
    ``with_xyz`` is set.

    Args:
        idx: (m, nsample) integer neighbour indices into the rows of
            ``feat``/``xyz``; ``-1`` marks padding.
        feat: (n, c) contiguous feature tensor.
        xyz: (n, 3) contiguous point coordinates.
        new_xyz: (m, 3) query coordinates; defaults to ``xyz``. Only used
            when ``with_xyz`` is true.
        with_xyz: if true, prepend each neighbour's offset
            ``xyz[j] - new_xyz[i]`` to its features.

    Returns:
        (m, nsample, c) tensor, or (m, nsample, 3 + c) when ``with_xyz``.
        The dtype follows ``feat`` (and ``xyz`` for the offsets).
    """
    if new_xyz is None:
        new_xyz = xyz
    assert xyz.is_contiguous() and feat.is_contiguous()
    m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1]
    # reshape (not view) so a non-contiguous idx is accepted; advanced
    # indexing needs int64.
    flat_idx = idx.reshape(-1).long()
    # Append one zero row so index -1 selects it. new_zeros keeps the
    # caller's dtype/device instead of allocating float32 on the CPU and
    # copying it over on every call.
    xyz = torch.cat([xyz, xyz.new_zeros((1, 3))], dim=0)
    feat = torch.cat([feat, feat.new_zeros((1, c))], dim=0)
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)  # (m, nsample, c)

    if not with_xyz:
        return grouped_feat

    assert new_xyz.is_contiguous()
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, 3) - new_xyz.unsqueeze(1)
    # Padded slots picked the zero row, so their offset is -new_xyz. Mask
    # them back to zero.
    valid = (idx >= 0).unsqueeze(-1).to(grouped_xyz.dtype)  # (m, nsample, 1)
    grouped_xyz = grouped_xyz * valid  # (m, nsample, 3)
    return torch.cat((grouped_xyz, grouped_feat), -1)


# Functional entry point for the CUDA-backed Grouping autograd op:
# grouping2(input, idx) -> (m, nsample, c), with gradients flowing to `input`.
grouping2 = Grouping.apply
