import torch
import torch.nn as nn

from . import functional as F
from .ball_query import BallQuery
from .shared_mlp import SharedMLP

__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]


class PointNetAModule(nn.Module):
    """Aggregate a whole point set into a single feature column per sample.

    One ``SharedMLP`` is built per entry of ``out_channels``; each one is
    applied to the (optionally coordinate-augmented) features and max-pooled
    over the last dimension. Results of several MLPs are concatenated along
    the channel dimension.

    Args:
        in_channels: Number of feature channels in the input (excluding the
            3 coordinate channels added when ``include_coordinates`` is set).
        out_channels: An int, a flat list/tuple of layer widths (one MLP), or
            a list/tuple of such lists (one MLP each).
        include_coordinates: If true, ``coords`` is concatenated to the
            features along dim 1 before the MLPs.
    """

    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalise ``out_channels`` to a list of per-MLP width lists.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]

        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        mlps = []
        total_out_channels = 0
        for _out_channels in out_channels:
            mlps.append(
                SharedMLP(
                    in_channels=mlp_in_channels,
                    out_channels=_out_channels,
                    dim=1,
                )
            )
            # The last width of each MLP is what it contributes to the output.
            total_out_channels += _out_channels[-1]

        self.include_coordinates = include_coordinates
        self.out_channels = total_out_channels
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Pool every point into one feature column.

        Args:
            inputs: A ``(features, coords)`` pair. Both are concatenated along
                dim 1 when ``include_coordinates`` is set, so they must agree
                on every other dimension.

        Returns:
            ``(pooled_features, zero_coords)`` where ``pooled_features`` keeps
            a trailing dimension of size 1 and ``zero_coords`` is an all-zero
            tensor of shape ``(coords.size(0), 3, 1)``.
        """
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        # ``new_zeros`` inherits dtype *and* device from ``coords`` so the
        # placeholder stays compatible with half/double precision inputs.
        zero_coords = coords.new_zeros((coords.size(0), 3, 1))
        pooled = [
            mlp(features).max(dim=-1, keepdim=True).values for mlp in self.mlps
        ]
        if len(pooled) > 1:
            return torch.cat(pooled, dim=1), zero_coords
        return pooled[0], zero_coords

    def extra_repr(self):
        return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"


class PointNetSAModule(nn.Module):
    """Set-abstraction layer with one or more grouping scales.

    For every ``(radius, num_neighbors, out_channels)`` triple a ``BallQuery``
    grouper and a ``SharedMLP`` are created. In ``forward`` the centers are
    chosen by furthest point sampling, each scale groups the *input* features
    around those centers, runs its MLP and max-pools over the last dimension;
    the per-scale results are concatenated along the channel dimension.

    Args:
        num_centers: Number of centers passed to furthest point sampling.
        radius: A single radius or a list/tuple with one radius per scale.
        num_neighbors: A single count (shared by all scales) or one per scale.
        in_channels: Number of input feature channels (excluding the 3
            coordinate channels added when ``include_coordinates`` is set).
        out_channels: An int or flat list/tuple of widths (shared by all
            scales), or a list/tuple of width lists with one entry per scale.
        include_coordinates: Forwarded to every ``BallQuery``; also adds 3 to
            the input width of every MLP.
    """

    def __init__(
        self,
        num_centers,
        radius,
        num_neighbors,
        in_channels,
        out_channels,
        include_coordinates=True,
    ):
        super().__init__()
        # Normalise every per-scale argument to a list with one entry per radius.
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels)

        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        groupers, mlps = [], []
        total_out_channels = 0
        for _radius, _out_channels, _num_neighbors in zip(
            radius, out_channels, num_neighbors
        ):
            groupers.append(
                BallQuery(
                    radius=_radius,
                    num_neighbors=_num_neighbors,
                    include_coordinates=include_coordinates,
                )
            )
            mlps.append(
                SharedMLP(
                    in_channels=mlp_in_channels,
                    out_channels=_out_channels,
                    dim=2,
                )
            )
            # Each scale contributes the width of its last MLP layer.
            total_out_channels += _out_channels[-1]

        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Sample centers, then group / transform / pool once per scale.

        Args:
            inputs: A ``(features, coords, temb)`` triple.

        Returns:
            ``(features, centers_coords, temb)``. ``features`` is the
            channel-wise concatenation of every scale's max-pooled output
            (``self.out_channels`` channels in total). ``temb`` is the
            embedding returned by the last scale, max-pooled over its last
            dimension unless it has zero channels.
        """
        features, coords, temb = inputs
        centers_coords = F.furthest_point_sample(coords, self.num_centers)
        features_list = []
        grouped_temb = temb
        for grouper, mlp in zip(self.groupers, self.mlps):
            # Every scale must see the module's *input* features and temb:
            # all MLPs were built for the same input width, so feeding one
            # scale's output into the next would be a channel mismatch.
            scale_features, grouped_temb = mlp(
                grouper(coords, centers_coords, temb, features)
            )
            features_list.append(scale_features.max(dim=-1).values)

        if len(features_list) > 1:
            pooled_features = torch.cat(features_list, dim=1)
        else:
            pooled_features = features_list[0]
        # NOTE(review): assumes each scale returns an equivalent temb once
        # pooled, so the last one is used — confirm against BallQuery.
        if grouped_temb.shape[1] > 0:
            grouped_temb = grouped_temb.max(dim=-1).values
        return pooled_features, centers_coords, grouped_temb

    def extra_repr(self):
        return f"num_centers={self.num_centers}, out_channels={self.out_channels}"


class PointNetFPModule(nn.Module):
    """Feature-propagation layer: interpolate center features back to points.

    Args:
        in_channels: Input width of the ``SharedMLP`` applied after
            interpolation (and after the optional skip concatenation).
        out_channels: Layer widths forwarded to ``SharedMLP``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    def forward(self, inputs):
        """Interpolate features and temb from centers onto points.

        Args:
            inputs: Either
                ``(points_coords, centers_coords, centers_features, temb)`` or
                ``(points_coords, centers_coords, centers_features,
                points_features, temb)``. In the second form
                ``points_features`` is concatenated to the interpolated
                features along dim 1 before the MLP.

        Returns:
            ``(features, points_coords, interpolated_temb)``.

        Raises:
            ValueError: If ``inputs`` has neither 4 nor 5 elements (raised by
                the tuple unpacking).
        """
        # The skip-less form carries exactly four tensors; the guard must
        # match the number of names being unpacked.
        if len(inputs) == 4:
            points_coords, centers_coords, centers_features, temb = inputs
            points_features = None
        else:
            (
                points_coords,
                centers_coords,
                centers_features,
                points_features,
                temb,
            ) = inputs
        interpolated_features = F.nearest_neighbor_interpolate(
            points_coords, centers_coords, centers_features
        )
        interpolated_temb = F.nearest_neighbor_interpolate(
            points_coords, centers_coords, temb
        )
        if points_features is not None:
            interpolated_features = torch.cat(
                [interpolated_features, points_features], dim=1
            )
        return self.mlp(interpolated_features), points_coords, interpolated_temb
