from typing import List

import torch
from torch import Tensor
from torch import nn as nn

from mmcv.runner import force_fp32
from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)


def calc_square_dist(point_feat_a: Tensor,
                     point_feat_b: Tensor,
                     norm: bool = True) -> Tensor:
    """Compute pairwise distances between two batched sets of features.

    The pairwise Euclidean distance is obtained with ``torch.cdist`` and
    then post-processed depending on ``norm``:

    - ``norm=True``: the Euclidean distance divided by the channel count
      ``C`` (note: this value is *not* squared).
    - ``norm=False``: the squared Euclidean distance.

    Args:
        point_feat_a (torch.Tensor): (B, N, C) Feature vector of each point.
        point_feat_b (torch.Tensor): (B, M, C) Feature vector of each point.
        norm (bool, optional): Whether to return the channel-normalized
            distance instead of the squared distance. Default: True.

    Returns:
        torch.Tensor: (B, N, M) Distance between each point pair.
    """
    channels = point_feat_a.shape[-1]
    pairwise = torch.cdist(point_feat_a, point_feat_b)
    return pairwise / channels if norm else torch.square(pairwise)


def get_sampler_cls(sampler_type: str) -> nn.Module:
    """Look up the points sampler class registered under a type name.

    Args:
        sampler_type (str): The type of points sampler.
            The valid value are "D-FPS", "F-FPS", or "FS".

    Returns:
        class: The sampler class (not an instance) for ``sampler_type``.

    Raises:
        KeyError: If ``sampler_type`` is not one of the supported names.
    """
    sampler_mappings = {
        'D-FPS': DFPSSampler,
        'F-FPS': FFPSSampler,
        'FS': FSSampler,
    }
    try:
        return sampler_mappings[sampler_type]
    except KeyError:
        # Adjacent literals are used instead of a backslash continuation so
        # that source indentation does not end up inside the message.
        # ``from None`` hides the internal lookup error, which carries no
        # extra information for the caller.
        raise KeyError(
            f'Supported `sampler_type` are {sampler_mappings.keys()}, '
            f'but got {sampler_type}') from None


class PointsSampler(nn.Module):
    """Points sampling.

    Each entry of ``fps_mod_list`` builds one sampler; the i-th sampler
    draws ``num_point[i]`` indices from the slice of points that ends at
    ``fps_sample_range_list[i]``.

    Args:
        num_point (list[int]): Number of sample points.
        fps_mod_list (list[str], optional): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional):
            Range of points to apply FPS. Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1]) -> None:
        super().__init__()
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        # Keep private copies: the default arguments are shared list
        # objects, so storing them by reference would let an in-place edit
        # on one instance leak into every instance created afterwards.
        self.num_point = list(num_point)
        self.fps_sample_range_list = list(fps_sample_range_list)
        self.samplers = nn.ModuleList(
            [get_sampler_cls(fps_mod)() for fps_mod in fps_mod_list])
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, points_xyz: Tensor, features: Tensor) -> Tensor:
        """Run every configured sampler on its range of points.

        Args:
            points_xyz (torch.Tensor): (B, N, 3) xyz coordinates of
                the points.
            features (torch.Tensor): (B, C, N) features of the points.
                May be None, in which case None is forwarded to the
                samplers.

        Returns:
            torch.Tensor: Indices of sampled points. The index tensors of
            all samplers are concatenated along dim 1, each shifted by the
            start of its range so that it indexes into ``points_xyz``.
        """
        indices = []
        last_fps_end_index = 0
        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]

            # A range of -1 means "until the last point"; an open-ended
            # slice (end=None) expresses that without a separate branch.
            range_end = None if fps_sample_range == -1 else fps_sample_range
            sample_points_xyz = points_xyz[:, last_fps_end_index:range_end]
            sample_features = None
            if features is not None:
                sample_features = features[:, :,
                                           last_fps_end_index:range_end]

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)

            # Sampler indices are relative to the slice; shift them back
            # into the coordinate system of the full point set.
            indices.append(fps_idx + last_fps_end_index)
            # NOTE: after a -1 range this becomes -1, so -1 is presumably
            # only intended as the last entry of the list.
            last_fps_end_index = fps_sample_range

        return torch.cat(indices, dim=1)


class DFPSSampler(nn.Module):
    """Furthest point sampling based on the point coordinates (D-FPS)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Sample ``npoint`` indices from ``points`` with D-FPS.

        ``features`` is accepted only so that all samplers share one call
        signature; it is not used here.
        """
        contiguous_points = points.contiguous()
        return furthest_point_sample(contiguous_points, npoint)


class FFPSSampler(nn.Module):
    """Furthest point sampling based on feature distances (F-FPS)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Sample ``npoint`` indices with F-FPS.

        ``features`` has its dims 1 and 2 swapped and is appended to
        ``points`` along dim 2; the squared pairwise distance of the
        combined vectors drives the sampling.
        """
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        channels_last = features.transpose(1, 2)
        fused = torch.cat([points, channels_last], dim=2)
        pairwise_dist = calc_square_dist(fused, fused, norm=False)
        return furthest_point_sample_with_dist(pairwise_dist, npoint)


class FSSampler(nn.Module):
    """Combined sampling that applies both F-FPS and D-FPS."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, points: Tensor, features: Tensor, npoint: int) -> Tensor:
        """Sample with F-FPS and D-FPS and join the results.

        Both samplers are asked for ``npoint`` indices; the two index
        tensors are concatenated along dim 1, F-FPS first.
        """
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        by_feature = FFPSSampler()(points, features, npoint)
        by_distance = DFPSSampler()(points, features, npoint)
        return torch.cat([by_feature, by_distance], dim=1)
