# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn

from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from ..builder import BACKBONES


@BACKBONES.register_module()
class DGCNNBackbone(BaseModule):
    """Backbone network for DGCNN.

    The backbone stacks ``len(gf_channels)`` graph feature (GF) modules, each
    fed with the output of the previous one (the first is fed the raw input
    points), and then applies a single feature aggregation (FA) module to the
    list holding the input points followed by every GF output.

    Args:
        in_channels (int): Input channels of point cloud.
        num_samples (tuple[int], optional): The number of samples for knn or
            ball query in each graph feature (GF) module.
            Defaults to (20, 20, 20).
        knn_modes (tuple[str], optional): Mode of KNN of each knn module.
            Defaults to ('D-KNN', 'F-KNN', 'F-KNN').
        radius (tuple[float], optional): Sampling radii of each GF module.
            Defaults to (None, None, None).
        gf_channels (tuple[tuple[int]], optional): Out channels of each mlp in
            GF module. Defaults to ((64, 64), (64, 64), (64, )).
        fa_channels (tuple[int], optional): Out channels of each mlp in FA
            module. Defaults to (1024, ).
        act_cfg (dict, optional): Config of activation layer.
            Defaults to dict(type='ReLU').
        init_cfg (dict, optional): Initialization config.
            Defaults to None.

    Note:
        ``num_samples``, ``knn_modes``, ``radius`` and ``gf_channels`` must
        all have the same length; that length is the number of GF modules.
    """

    def __init__(self,
                 in_channels,
                 num_samples=(20, 20, 20),
                 knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
                 radius=(None, None, None),
                 gf_channels=((64, 64), (64, 64), (64, )),
                 fa_channels=(1024, ),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_gf = len(gf_channels)

        # Adjacent string literals are joined at compile time, so the message
        # does not pick up the source indentation of the continuation line.
        assert (len(num_samples) == len(knn_modes) == len(radius) ==
                len(gf_channels)), (
                    'Num_samples, knn_modes, radius and gf_channels '
                    'should have the same length.')

        # NOTE(review): the `act_cfg` default dict is shared by every instance;
        # it is forwarded unmodified here. Confirm the GF/FA modules do not
        # mutate it in place.
        self.GF_modules = nn.ModuleList()
        # Each GF module is built for twice the channel width of its input
        # tensor; presumably DGCNNGFModule forms edge features by concatenating
        # two C-channel tensors per neighbour (TODO confirm in DGCNNGFModule).
        gf_in_channel = in_channels * 2
        gf_out_channels = []  # output width of every GF module, in order

        for gf_index in range(self.num_gf):
            # First entry of the MLP spec is the module's input width.
            cur_gf_mlps = [gf_in_channel, *gf_channels[gf_index]]
            gf_out_channel = cur_gf_mlps[-1]

            self.GF_modules.append(
                DGCNNGFModule(
                    mlp_channels=cur_gf_mlps,
                    num_sample=num_samples[gf_index],
                    knn_mode=knn_modes[gf_index],
                    radius=radius[gf_index],
                    act_cfg=act_cfg))
            gf_out_channels.append(gf_out_channel)
            # Same doubling as for the first module (see above).
            gf_in_channel = gf_out_channel * 2

        # The FA input width is the sum of the GF output widths only; the raw
        # input channels are not counted, so DGCNNFAModule presumably
        # concatenates gf_points[1:] and ignores gf_points[0] (TODO confirm).
        fa_in_channel = sum(gf_out_channels)
        self.FA_module = DGCNNFAModule(
            mlp_channels=[fa_in_channel, *fa_channels], act_cfg=act_cfg)

    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, in_channels).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after graph feature (GF) and
                feature aggregation (FA) modules.

                - gf_points (list[torch.Tensor]): The input ``points``
                  followed by the output of each GF module, in order
                  (``num_gf + 1`` entries).
                - fa_points (torch.Tensor): Outputs after FA module.
        """
        gf_points = [points]

        # Chain the GF modules: each consumes the most recent output.
        for gf_module in self.GF_modules:
            gf_points.append(gf_module(gf_points[-1]))

        # The FA module receives the whole list, including the raw input.
        fa_points = self.FA_module(gf_points)

        return dict(gf_points=gf_points, fa_points=fa_points)
