import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn

from .simple_model_utils import FeedForward, BasePointModel


class SimplePointModel(BasePointModel):
    """
    A simple model that processes a point cloud by applying a series of MLPs to each point
    individually, along with some pooled global features.
    """

    def get_layers(self):
        return nn.ModuleList(
            [
                FeedForward(
                    d_in=(3 * self.dim),
                    d_hidden=(4 * self.dim),
                    d_out=self.dim,
                    activation=nn.SiLU(),
                    is_gated=True,
                    bias1=False,
                    bias2=False,
                    bias_gate=False,
                    use_layernorm=True,
                )
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):

        # Prepare inputs
        x, coords = self.prepare_inputs(inputs, t)

        # Model
        for layer in self.layers:
            x_pool_max, x_pool_std = self.get_global_tensors(x)
            x_input = torch.cat((x, x_pool_max, x_pool_std), dim=-1)  # (B, N, 3 * D)
            x = x + layer(x_input)  # (B, N, D_model)

        # Project
        x = self.output_projection(x)  # (B, N, D_out)
        x = torch.transpose(x, -2, -1)  # -> (B, D_out, N)

        return x


class SimpleNearestNeighborsPointModel(BasePointModel):
    """
    A simple model that processes a point cloud by applying a series of MLPs to each point
    individually, along with some pooled global features, and the features of its nearest
    neighbors.
    """

    def __init__(self, num_neighbors: int = 4, **kwargs):
        self.num_neighbors = num_neighbors
        super().__init__(**kwargs)
        from pytorch3d.ops import knn_points

        self.knn_points = knn_points

    def get_layers(self):
        return nn.ModuleList(
            [
                FeedForward(
                    d_in=((3 + self.num_neighbors) * self.dim),
                    d_hidden=(4 * self.dim),
                    d_out=self.dim,
                    activation=nn.SiLU(),
                    is_gated=True,
                    bias1=False,
                    bias2=False,
                    bias_gate=False,
                    use_layernorm=True,
                )
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):

        # Prepare inputs
        x, coords = self.prepare_inputs(inputs, t)  # (B, N, D), (B, N, 3)

        # Get nearest neighbors. Note that the first neighbor is the identity, which is convenient
        _dists, indices, _neighbors = self.knn_points(
            p1=coords, p2=coords, K=(self.num_neighbors + 1), return_nn=False
        )  # (B, N, K), (B, N, K)
        (B, N, D), (_B, _N, K) = x.shape, indices.shape

        # Model
        for layer in self.layers:
            x_neighbor = torch.stack(
                [x_i[idx] for x_i, idx in zip(x, indices.reshape(B, N * K))]
            ).reshape(B, N, K * D)
            x_pool_max, x_pool_std = self.get_global_tensors(x)
            x_input = torch.cat(
                (x_neighbor, x_pool_max, x_pool_std), dim=-1
            )  # (B, N, (3+K)*D)
            x = x + layer(x_input)  # (B, N, D_model)

        # Project
        x = self.output_projection(x)  # (B, N, D_out)
        x = torch.transpose(x, -2, -1)  # -> (B, D_out, N)

        return x
