from torch.autograd import Function

from . import assign_score_withk_ext


class AssignScoreWithK(Function):
    r"""Perform weighted sum to generate output features according to scores.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/paconv_lib/src/gpu>`_.

    This is a memory-efficient CUDA implementation of the assign_scores
    operation, which first transforms all point features with the weight
    bank, then assembles neighbor features with `knn_idx` and performs a
    weighted sum with `scores`.
    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
    more detailed descriptions.

    Note:
        This implementation assumes using ``neighbor`` kernel input, which is
            (point_features - center_features, point_features).
        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
        pointnet2/paconv.py#L128 for more details.
    """

    @staticmethod
    def forward(
        ctx, scores, point_features, center_features, knn_idx, aggregate="sum"
    ):
        """Forward.

        Args:
            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
                aggregate weight matrices in the weight bank.
                ``npoint`` is the number of sampled centers.
                ``K`` is the number of queried neighbors.
                ``M`` is the number of weight matrices in the weight bank.
            point_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed point features to be aggregated.
            center_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed center features to be aggregated.
            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
                We assume the first idx in each row is the idx of the center.
            aggregate (str, optional): Aggregation method.
                Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.

        Returns:
            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.

        Raises:
            KeyError: If ``aggregate`` is not one of 'sum', 'avg' or 'max'.
        """
        # Integer codes handed to the extension for each aggregation mode.
        agg_codes = {"sum": 0, "avg": 1, "max": 2}
        # Resolve the mode once, before any allocation or extension call, so
        # an unsupported value fails early with a readable message.
        try:
            agg = agg_codes[aggregate]
        except KeyError:
            raise KeyError(
                f"Unsupported aggregate {aggregate!r}; "
                f"expected one of {sorted(agg_codes)}"
            ) from None

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # Zero-initialised buffer passed to the extension, which is expected
        # to write the result into it in place.
        output = point_features.new_zeros((B, out_dim, npoint, K))
        assign_score_withk_ext.assign_score_withk_forward_wrapper(
            B,
            N,
            npoint,
            M,
            K,
            out_dim,
            agg,
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            output,
        )

        # Only the inputs are needed by ``backward``; ``output`` is
        # deliberately not saved because the backward pass never reads it.
        ctx.save_for_backward(point_features, center_features, scores, knn_idx)
        ctx.agg = agg

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Backward.

        Args:
            grad_out (torch.Tensor): (B, out_dim, npoint, K)

        Returns:
            tuple: One entry per ``forward`` input, in order:

            - grad_scores (torch.Tensor): (B, npoint, K, M)
            - grad_point_features (torch.Tensor): (B, N, M, out_dim)
            - grad_center_features (torch.Tensor): (B, N, M, out_dim)
            - ``None`` for ``knn_idx``
            - ``None`` for ``aggregate``
        """
        point_features, center_features, scores, knn_idx = ctx.saved_tensors

        agg = ctx.agg

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # Zero-initialised gradient buffers passed to the extension, which is
        # expected to fill them in place.
        grad_point_features = point_features.new_zeros(point_features.shape)
        grad_center_features = center_features.new_zeros(center_features.shape)
        grad_scores = scores.new_zeros(scores.shape)

        assign_score_withk_ext.assign_score_withk_backward_wrapper(
            B,
            N,
            npoint,
            M,
            K,
            out_dim,
            agg,
            grad_out.contiguous(),
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            grad_point_features,
            grad_center_features,
            grad_scores,
        )

        return grad_scores, grad_point_features, grad_center_features, None, None


# Convenience alias for ``AssignScoreWithK.apply`` so the op can be invoked
# like a plain function.
assign_score_withk = AssignScoreWithK.apply
