# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
from mmcv.cnn import ConvModule
from mmcv.ops import GroupAll
from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import QueryAndGroup, gather_points
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from embodiedqa.models.layers.paconv import PAConv
from embodiedqa.utils import ConfigType
from .builder import SA_MODULES


class BasePointSAModule(nn.Module):
    """Base module for point set abstraction module used in PointNets.

    This base class only builds the point sampler and the groupers.
    ``self.mlps`` is created empty here and is expected to be filled by
    subclasses before :meth:`forward` is called.

    Args:
        num_point (int | List[int] | Tuple[int] | None): Number of points.
            An int is wrapped into a one-element list. ``None`` disables
            sampling and switches every grouper to ``GroupAll``.
        radii (List[float]): List of radius in each ball query.
        sample_nums (List[int]): Number of samples in each ball query.
        mlp_channels (List[List[int]]): Specify of the pointnet before
            the global pooling for each scale. The module stores its own
            copy, so the argument itself is never modified.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].

            - F-FPS: using feature distances for FPS.
            - D-FPS: using Euclidean distances of points for FPS.
            - FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Defaults to False.
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
        grouper_return_grouped_xyz (bool): Whether to return grouped xyz
            in `QueryAndGroup`. Defaults to False.
        grouper_return_grouped_idx (bool): Whether to return grouped idx
            in `QueryAndGroup`. Defaults to False.
    """

    def __init__(self,
                 num_point: Optional[Union[int, List[int], Tuple[int, ...]]],
                 radii: List[float],
                 sample_nums: List[int],
                 mlp_channels: List[List[int]],
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 dilated_group: bool = False,
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 normalize_xyz: bool = False,
                 grouper_return_grouped_xyz: bool = False,
                 grouper_return_grouped_idx: bool = False) -> None:
        super(BasePointSAModule, self).__init__()

        assert len(radii) == len(sample_nums) == len(mlp_channels)
        assert pool_mod in ['max', 'avg']
        assert isinstance(fps_mod, (list, tuple))
        assert isinstance(fps_sample_range_list, (list, tuple))
        assert len(fps_mod) == len(fps_sample_range_list)

        # Always keep private copies of the channel specs. Subclasses adjust
        # the first entry in place (e.g. to account for xyz), which must not
        # leak back into the caller's config. This also accepts tuples at
        # either nesting level.
        self.mlp_channels = [list(channels) for channels in mlp_channels]

        if isinstance(num_point, int):
            self.num_point = [num_point]
        elif isinstance(num_point, (list, tuple)):
            self.num_point = num_point
        elif num_point is None:
            self.num_point = None
        else:
            raise NotImplementedError('Error type of num_point!')

        self.pool_mod = pool_mod
        self.groupers = nn.ModuleList()
        # Left empty on purpose; subclasses append one mlp per grouper.
        self.mlps = nn.ModuleList()
        self.fps_mod_list = fps_mod
        self.fps_sample_range_list = fps_sample_range_list

        if self.num_point is not None:
            self.points_sampler = Points_Sampler(self.num_point,
                                                 self.fps_mod_list,
                                                 self.fps_sample_range_list)
        else:
            self.points_sampler = None

        for i, (radius, sample_num) in enumerate(zip(radii, sample_nums)):
            if num_point is not None:
                # With dilated grouping, each scale after the first uses the
                # previous scale's radius as its inner bound.
                min_radius = radii[i - 1] if dilated_group and i != 0 else 0
                grouper = QueryAndGroup(
                    radius,
                    sample_num,
                    min_radius=min_radius,
                    use_xyz=use_xyz,
                    normalize_xyz=normalize_xyz,
                    return_grouped_xyz=grouper_return_grouped_xyz,
                    return_grouped_idx=grouper_return_grouped_idx)
            else:
                grouper = GroupAll(use_xyz)
            self.groupers.append(grouper)

    def _sample_points(
            self, points_xyz: Tensor, features: Optional[Tensor],
            indices: Optional[Tensor], target_xyz: Optional[Tensor]
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Perform point sampling based on inputs.

        If `indices` is specified, directly sample corresponding points.
        Else if `target_xyz` is specified, use is as sampled points.
        Otherwise sample points using `self.points_sampler`.

        When the module was built with ``num_point=None`` no gathering is
        performed and ``new_xyz`` is ``None``, unless `target_xyz` is given
        without `indices`.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor, optional): (B, C, N) Features of each point.
            indices (Tensor, optional): (B, num_point) Index of the features.
            target_xyz (Tensor, optional): (B, M, 3) new_xyz coordinates of
                the outputs.

        Returns:
            Tuple[Tensor]:

            - new_xyz: (B, num_point, 3) Sampled xyz coordinates of points,
              or None.
            - indices: (B, num_point) Sampled points' index. This is the
              input `indices` when one was given, and None when
              `target_xyz` was used or sampling is disabled.
        """
        if indices is not None:
            if self.num_point is None:
                # Sampling is disabled: nothing to gather. Checking this
                # before touching self.num_point[0] avoids a TypeError.
                return None, indices
            # Only the first entry of num_point is checked here.
            assert (indices.shape[1] == self.num_point[0])
            xyz_flipped = points_xyz.transpose(1, 2).contiguous()
            new_xyz = gather_points(xyz_flipped,
                                    indices).transpose(1, 2).contiguous()
        elif target_xyz is not None:
            new_xyz = target_xyz.contiguous()
        elif self.num_point is not None:
            indices = self.points_sampler(points_xyz, features)
            xyz_flipped = points_xyz.transpose(1, 2).contiguous()
            new_xyz = gather_points(xyz_flipped,
                                    indices).transpose(1, 2).contiguous()
        else:
            new_xyz = None

        return new_xyz, indices

    def _pool_features(self, features: Tensor) -> Tensor:
        """Perform feature aggregation using pooling operation.

        Args:
            features (Tensor): (B, C, N, K) Features of locally grouped
                points before pooling.

        Returns:
            Tensor: (B, C, N) Pooled features aggregating local information.

        Raises:
            NotImplementedError: If ``self.pool_mod`` is neither 'max' nor
                'avg'.
        """
        # Pool over the whole last dimension so it collapses to size 1.
        kernel_size = [1, features.size(3)]
        if self.pool_mod == 'max':
            # (B, C, N, 1)
            new_features = F.max_pool2d(features, kernel_size=kernel_size)
        elif self.pool_mod == 'avg':
            # (B, C, N, 1)
            new_features = F.avg_pool2d(features, kernel_size=kernel_size)
        else:
            raise NotImplementedError

        return new_features.squeeze(-1).contiguous()

    def forward(
        self,
        points_xyz: Tensor,
        features: Optional[Tensor] = None,
        indices: Optional[Tensor] = None,
        target_xyz: Optional[Tensor] = None,
    ) -> Tuple[Optional[Tensor], Tensor, Optional[Tensor]]:
        """Forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor, optional): (B, C, N) Features of each point.
                Defaults to None.
            indices (Tensor, optional): (B, num_point) Index of the features.
                Defaults to None.
            target_xyz (Tensor, optional): (B, M, 3) New coords of the outputs.
                Defaults to None.

        Returns:
            Tuple[Tensor]:

                - new_xyz: (B, M, 3) Where M is the number of points.
                  New features xyz.
                - new_features: (B, sum_k(mlps[k][-1]), M) Where M is the
                  number of points. New feature descriptors, concatenated
                  over all scales along the channel dimension.
                - indices: (B, M) Where M is the number of points.
                  Index of the features.
        """
        new_features_list = []

        # sample points, (B, num_point, 3), (B, num_point)
        new_xyz, indices = self._sample_points(points_xyz, features, indices,
                                               target_xyz)

        for i, grouper in enumerate(self.groupers):
            # grouped_results may contain:
            # - grouped_features: (B, C, num_point, nsample)
            # - grouped_xyz: (B, 3, num_point, nsample)
            # - grouped_idx: (B, num_point, nsample)
            grouped_results = grouper(points_xyz, new_xyz, features)

            # (B, mlp[-1], num_point, nsample)
            mlp = self.mlps[i]
            new_features = mlp(grouped_results)

            # this is a bit hack because PAConv outputs two values
            # we take the first one as feature
            if isinstance(mlp[0], PAConv):
                assert isinstance(new_features, tuple)
                new_features = new_features[0]

            # (B, mlp[-1], num_point)
            new_features = self._pool_features(new_features)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices


@SA_MODULES.register_module()
class PointSAModuleMSG(BasePointSAModule):
    """Point set abstraction module with multi-scale grouping (MSG) used in
    PointNets.

    One stack of 1x1 ``ConvModule`` layers is built per scale and appended
    to ``self.mlps``.

    Args:
        num_point (int): Number of points.
        radii (List[float]): List of radius in each ball query.
        sample_nums (List[int]): Number of samples in each ball query.
        mlp_channels (List[List[int]]): Specify of the pointnet before
            the global pooling for each scale. When ``use_xyz`` is True the
            first entry of every scale is increased by 3 on an internal
            copy; the argument itself is left untouched.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].

            - F-FPS: using feature distances for FPS.
            - D-FPS: using Euclidean distances of points for FPS.
            - FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Defaults to False.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method. Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
        bias (bool or str): If specified as `auto`, it will be decided by
            `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
            otherwise False. Defaults to 'auto'.
    """

    def __init__(self,
                 num_point: int,
                 radii: List[float],
                 sample_nums: List[int],
                 mlp_channels: List[List[int]],
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 dilated_group: bool = False,
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 normalize_xyz: bool = False,
                 bias: Union[bool, str] = 'auto') -> None:
        super(PointSAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz)

        # Detach from whatever lists the caller handed in: the `+= 3` below
        # must not modify the caller's config, otherwise building a second
        # module from the same config would add the xyz channels twice.
        self.mlp_channels = [list(channels) for channels in self.mlp_channels]

        for mlp_channel in self.mlp_channels:
            if use_xyz:
                # 3 extra input channels for the xyz coordinates
                mlp_channel[0] += 3

            mlp = nn.Sequential()
            for layer_idx in range(len(mlp_channel) - 1):
                mlp.add_module(
                    f'layer{layer_idx}',
                    ConvModule(
                        mlp_channel[layer_idx],
                        mlp_channel[layer_idx + 1],
                        kernel_size=(1, 1),
                        stride=(1, 1),
                        conv_cfg=dict(type='Conv2d'),
                        norm_cfg=norm_cfg,
                        bias=bias))
            self.mlps.append(mlp)


@SA_MODULES.register_module()
class PointSAModule(PointSAModuleMSG):
    """Point set abstraction module with single-scale grouping (SSG) used in
    PointNets.

    Thin wrapper around :class:`PointSAModuleMSG` that wraps the single
    radius / sample number / channel spec into one-element lists.

    Args:
        mlp_channels (List[int]): Specify of the pointnet before
            the global pooling for each scale. A copy is forwarded to the
            parent class, so the argument is not modified and a tuple is
            accepted as well.
        num_point (int, optional): Number of points. Defaults to None.
        radius (float, optional): Radius to group with. Defaults to None.
        num_sample (int, optional): Number of samples in each ball query.
            Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Default to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method. Defaults to 'max'.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 num_point: Optional[int] = None,
                 radius: Optional[float] = None,
                 num_sample: Optional[int] = None,
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 normalize_xyz: bool = False) -> None:
        super(PointSAModule, self).__init__(
            # Forward a copy: the parent adjusts the first channel in place
            # when use_xyz is set, which must not touch the caller's list.
            mlp_channels=[list(mlp_channels)],
            num_point=num_point,
            radii=[radius],
            sample_nums=[num_sample],
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz)
