import torch
from torch.autograd import Function

from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda


class Subtraction(Function):
    """Autograd wrapper around the pointops CUDA subtraction kernels.

    The arithmetic is done by ``subtraction_forward_cuda`` and
    ``subtraction_backward_cuda``; this class only allocates the result
    buffers and hooks the kernels into autograd.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output:  (n, nsample, c)

        Presumably output[i, j] = input1[i] - input2[idx[i, j]]; the exact
        formula is defined by the CUDA kernel -- TODO confirm.
        ``idx`` is saved on ``ctx`` for the backward pass.
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        # The kernel receives idx just like the feature tensors, so give it
        # the same dense layout (no-op when idx is already contiguous).
        idx = idx.contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # The result buffer is always float32 and is filled in place by the
        # kernel.
        output = torch.zeros((n, nsample, c), dtype=torch.float, device=input1.device)
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)

        The third returned value is None because ``idx`` receives no gradient.
        """
        (idx,) = ctx.saved_tensors
        # Autograd may deliver a non-contiguous gradient (e.g. the expanded,
        # stride-0 tensor produced by ``out.sum().backward()``). Forward
        # already requires contiguous inputs for its kernel, so hand the
        # backward kernel a dense gradient too instead of risking silently
        # wrong results. No-op when grad_output is already contiguous.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        # Zero-initialised float32 buffers that the kernel writes into.
        grad_input1 = torch.zeros((n, c), dtype=torch.float, device=idx.device)
        grad_input2 = torch.zeros((n, c), dtype=torch.float, device=idx.device)

        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        return grad_input1, grad_input2, None


# Functional entry point: call ``subtraction(input1, input2, idx)`` rather
# than spelling out ``Subtraction.apply(...)``.
subtraction = Subtraction.apply
