import torch
import torch.nn as nn
from .classes import MLP, ModelInterface


def _maxpool(x, dim=-1, keepdim=False):
    """Max-reduce ``x`` along ``dim`` and return only the maxima.

    ``Tensor.max(dim=..., keepdim=...)`` yields a ``(values, indices)`` pair;
    the indices are dropped here.

    Args:
        x: tensor to reduce.
        dim: dimension to reduce over (defaults to the last one).
        keepdim: when True, the reduced dimension is kept with size 1.

    Returns:
        Tensor of maxima along ``dim``.
    """
    return x.max(dim=dim, keepdim=keepdim)[0]


class PointCompletionNetwork(ModelInterface):
    """Encode a (partial) point cloud and decode a coarse and a dense cloud.

    The forward pass runs two encoder stages, each followed by a max-pool over
    the point dimension, to obtain one feature vector per cloud.  From that
    vector it decodes ``num_coarse_points`` coarse points.  It then "folds" a
    small square 2-D grid around every coarse point to produce
    ``num_dense_points`` dense points.

    Args:
        encoder_1: first encoder stage, applied to the input points.
        encoder_2: second encoder stage; its input is ``encoder_1``'s output
            concatenated (on dim 2) with that output's max-pooled global
            feature broadcast back to every point.
        decoder_coarse: maps the global feature vector to
            ``num_coarse_points * 3`` values per batch element.
        decoder_folding: applied to the per-dense-point folding input (global
            feature, 2-D grid offset, coarse point).  Its output is added to
            the coarse point it was folded around, so it presumably emits 3
            channels.
        num_coarse_points: number of points in the coarse prediction.
        num_dense_points: number of points in the dense prediction; must
            equal ``num_coarse_points * grid_size ** 2``.
        grid_size: side length of the square folding grid (default 4, i.e.
            16 dense points per coarse point).
        grid_scale: the folding grid spans ``[-grid_scale, grid_scale]`` on
            both axes (default 0.05).

    Raises:
        ValueError: if ``num_dense_points`` does not equal
            ``num_coarse_points * grid_size ** 2``.
    """

    def __init__(self, encoder_1: MLP, encoder_2: MLP, decoder_coarse: MLP,
                 decoder_folding: MLP, num_coarse_points: int,
                 num_dense_points: int, grid_size: int = 4,
                 grid_scale: float = 0.05):
        super().__init__()
        expected_dense = num_coarse_points * grid_size ** 2
        # Fail fast: any other combination can only crash inside forward()
        # when the folding inputs are concatenated.
        if num_dense_points != expected_dense:
            raise ValueError(
                f"num_dense_points ({num_dense_points}) must equal "
                f"num_coarse_points * grid_size**2 ({expected_dense})")
        self.encoder_1 = encoder_1
        self.encoder_2 = encoder_2
        self.decoder_coarse = decoder_coarse
        self.decoder_folding = decoder_folding
        self.num_coarse_points = num_coarse_points
        self.num_dense_points = num_dense_points
        self.grid_size = grid_size
        self.grid_scale = grid_scale

    def forward(self, points: torch.Tensor):
        """Complete a batch of point clouds.

        Args:
            points: input tensor indexed as ``(batch, num_input_points, ...)``;
                presumably ``(B, N, 3)`` coordinates -- confirm against the
                ``encoder_1`` definition.

        Returns:
            ``(pred_coarse_points, pred_dense_points)``.  The coarse tensor is
            ``(B, num_coarse_points, 3)``.  The dense tensor is
            ``decoder_folding``'s output plus every coarse point repeated
            ``grid_size ** 2`` times, i.e. ``(B, num_dense_points, 3)`` when
            the folding decoder emits 3 channels.
        """
        batchsize = points.shape[0]

        # --- Encoder: per-point features -> one global feature per cloud. ---
        local_feature = self.encoder_1(points)
        global_feature = _maxpool(local_feature, dim=1, keepdim=True)
        # Broadcast the pooled feature back to every point (a view; the
        # concatenation below materialises it).
        global_feature = global_feature.expand_as(local_feature)
        encoder_2_input = torch.cat((local_feature, global_feature), dim=2)
        feature_vector = _maxpool(self.encoder_2(encoder_2_input),
                                  dim=1, keepdim=False)

        # --- Coarse decoder. ---
        pred_coarse_points = self.decoder_coarse(feature_vector)
        # Explicit point count: a decoder/config mismatch surfaces here as a
        # clear reshape error rather than later inside torch.cat.
        pred_coarse_points = pred_coarse_points.reshape(
            (batchsize, self.num_coarse_points, 3))

        # --- Folding decoder. ---
        num_grid_points = self.grid_size ** 2
        # Built in the default dtype, then moved to the coarse points'
        # device/dtype so the concatenation below never mixes devices.
        grid = torch.linspace(-self.grid_scale, self.grid_scale,
                              self.grid_size).to(pred_coarse_points)
        # (grid_size**2, 2), rows (grid[i], grid[j]) in row-major order -- the
        # same ordering as stacking an "ij" meshgrid and flattening it.
        grid_2d = torch.cartesian_prod(grid, grid)
        # Tile the full grid once per coarse point, then share it across the
        # batch: row k of a cloud holds grid offset k % grid_size**2.
        grid_2d = grid_2d.repeat(self.num_coarse_points, 1)
        grid_2d = grid_2d.unsqueeze(0).expand(batchsize, -1, -1)

        # Each coarse point repeated grid_size**2 times consecutively, so it
        # lines up with the grid offsets tiled above.
        pred_coarse_points_grid = torch.repeat_interleave(
            pred_coarse_points, num_grid_points, dim=1)

        feature_vector = feature_vector.unsqueeze(1).expand(
            -1, self.num_dense_points, -1)

        folding_input = torch.cat(
            (feature_vector, grid_2d, pred_coarse_points_grid), dim=2)

        # The folding decoder's output is added to the coarse point it was
        # folded around (a residual around each coarse point).
        pred_dense_points = (self.decoder_folding(folding_input)
                             + pred_coarse_points_grid)

        return pred_coarse_points, pred_dense_points
