import torch
from torch.autograd import Function

from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda


class Aggregation(Function):
    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample)
        output: (n, c)
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        )
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        output = torch.zeros((n, c), dtype=torch.float, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, c)
        output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        grad_input = torch.zeros((n, c), dtype=torch.float, device=input.device)
        grad_position = torch.zeros(
            (n, nsample, c), dtype=torch.float, device=input.device
        )
        grad_weight = torch.zeros(
            (n, nsample, w_c), dtype=torch.float, device=input.device
        )
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        return grad_input, grad_position, grad_weight, None


aggregation = Aggregation.apply
