from torch.autograd import Function

from .backend import _backend

# Only the functional alias is part of the star-import API; the autograd
# Function class below must be imported by name if needed.
__all__ = ["trilinear_devoxelize"]


class TrilinearDevoxelization(Function):
    """Autograd function that samples voxel-grid features at point coordinates.

    The numeric work is delegated to ``_backend``; this class reshapes the
    inputs, keeps what the backward pass needs, and reshapes the gradient
    back to the voxel-grid layout.
    """

    @staticmethod
    def forward(ctx, features, coords, resolution, is_training=True):
        """
        :param ctx: autograd context
        :param features: FloatTensor[B, C, R, R, R], the voxel features
        :param coords: the coordinates of points, FloatTensor[B, 3, N]
        :param resolution: int, the voxel resolution R
        :param is_training: bool, training mode; when False the backend's
            ``inds``/``wgts`` outputs are not kept, so no gradient w.r.t.
            ``features`` can be computed afterwards
        :return:
            FloatTensor[B, C, N]
        """
        B, C = features.shape[:2]
        # Flatten the R x R x R grid so the backend receives [B, C, R^3].
        features = features.contiguous().view(B, C, -1)
        coords = coords.contiguous()
        outs, inds, wgts = _backend.trilinear_devoxelize_forward(
            resolution, is_training, coords, features
        )
        # Record these unconditionally so backward can either reshape the
        # gradient or report a clear error, whatever mode forward ran in.
        ctx.r = resolution
        ctx.is_training = is_training
        if is_training:
            ctx.save_for_backward(inds, wgts)
        return outs

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param ctx: autograd context populated by ``forward``
        :param grad_output: gradient of outputs, FloatTensor[B, C, N]
        :return:
            gradient w.r.t. ``features`` as FloatTensor[B, C, R, R, R]
            (None when it is not required), followed by None for
            ``coords``, ``resolution`` and ``is_training``
        :raises RuntimeError: if a gradient w.r.t. ``features`` is required
            but ``forward`` ran with ``is_training=False``
        """
        if not ctx.needs_input_grad[0]:
            # Only ``coords`` asked for a gradient, and this function never
            # provides one for it, so skip the backend call entirely.
            return None, None, None, None
        if not ctx.is_training:
            raise RuntimeError(
                "trilinear_devoxelize: cannot compute the gradient w.r.t. "
                "features because forward was called with is_training=False"
            )
        inds, wgts = ctx.saved_tensors
        grad_inputs = _backend.trilinear_devoxelize_backward(
            grad_output.contiguous(), inds, wgts, ctx.r
        )
        # Undo the grid flattening performed in forward.
        grad_features = grad_inputs.view(
            grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r
        )
        return grad_features, None, None, None


# Functional entry point, called positionally as
# trilinear_devoxelize(features, coords, resolution[, is_training]);
# returns FloatTensor[B, C, N] per TrilinearDevoxelization.forward.
trilinear_devoxelize = TrilinearDevoxelization.apply
