from torch.nn import Module
import pytorch3d.ops
import torch

from .PSA import MPGM_PSA



class MPGM(Module):
    """Neighbourhood aggregation module built around ``MPGM_PSA``.

    For every query point ``position_MP[:, i]`` it collects the ``cfg.k``
    nearest points of ``position`` (excluding index ``i`` itself), stacks
    their features and coordinates, and passes the query token together with
    its neighbourhood to ``MPGM_PSA``.

    Args:
        cfg: configuration object exposing ``embedding_dim`` (feature width C)
            and ``k`` (number of neighbours per query).
    """

    def __init__(self, cfg):
        super().__init__()
        self.embedding_dim = cfg.embedding_dim
        self.k = cfg.k
        # First argument is C + 3 because forward() concatenates the C feature
        # channels with the coordinate channels along dim 1. The meaning of the
        # remaining positional arguments is defined in .PSA (not visible here).
        self.MPGM_PSA = MPGM_PSA(
            self.embedding_dim + 3,
            self.embedding_dim // 16,
            self.embedding_dim // 4,
            self.embedding_dim,
            8,
            self.k,
        )

    def get_MP_idx(self, position_MP, position):
        """Return the ``k`` nearest neighbours of each query, excluding itself.

        Query ``i`` (``position_MP[:, i]``) is treated as corresponding to point
        ``i`` of ``position``. ``k + 1`` neighbours are searched; if index ``i``
        is among them it is removed, otherwise the last (farthest) neighbour is
        dropped, so every row keeps exactly ``k`` indices in the order returned
        by ``knn_points`` (nearest first with its default ``return_sorted``).

        Args:
            position_MP: (B, N, D) query coordinates.
            position: (B, N, D) reference coordinates.

        Returns:
            LongTensor of shape (B, N, k) with indices into ``position``.

        Raises:
            ValueError: if the (batch, num_points) dimensions of the two inputs
                differ, or if ``N < k + 1`` (too few points to provide ``k``
                distinct non-self neighbours).
        """
        _, N, _ = position.shape
        K = self.k
        if position_MP.shape[:2] != position.shape[:2]:
            raise ValueError(
                "position_MP and position must share (batch, num_points); "
                f"got {tuple(position_MP.shape)} and {tuple(position.shape)}"
            )
        if N < K + 1:
            # knn_points pads missing neighbours, which would yield duplicate
            # indices and break the fixed k-per-row layout below.
            raise ValueError(f"need at least k + 1 = {K + 1} points, got {N}")

        _, idx, _ = pytorch3d.ops.knn_points(position_MP, position, K=K + 1)
        return self._exclude_self(idx)

    @staticmethod
    def _exclude_self(idx):
        """Drop one entry per row: the row's own index if present, else the last.

        Args:
            idx: (B, N, K + 1) neighbour indices; row ``[b, i]`` must hold
                distinct indices, and ``i`` is the row's own (self) index.

        Returns:
            (B, N, K) tensor with the remaining indices in their original order.
        """
        B, N, width = idx.shape
        self_idx = torch.arange(N, device=idx.device).view(1, N, 1)
        is_self = idx == self_idx                                    # (B, N, K+1)
        has_self = is_self.any(dim=-1, keepdim=True)                 # (B, N, 1)
        is_last = torch.arange(width, device=idx.device) == width - 1  # (K+1,)
        # Remove the self index where present; otherwise remove the last column.
        drop = is_self | (is_last & ~has_self)
        # Exactly one entry is dropped per row, so boolean selection returns
        # K entries per row in row-major order and can be reshaped directly.
        return idx[~drop].view(B, N, width - 1)

    def forward(self, feature_up, position, position_MP):
        """Aggregate neighbourhood features for every query point.

        Args:
            feature_up: (B, N, C) per-point features; C must equal
                ``embedding_dim`` to match the C + 3 width given to MPGM_PSA.
            position: (B, N, D) reference coordinates that neighbours are taken
                from (D is presumably 3, matching the "+ 3" in __init__; confirm).
            position_MP: (B, N, D) query coordinates; query ``i`` is paired with
                ``feature_up[:, i]``.

        Returns:
            The MPGM_PSA output with dim -2 squeezed and channels moved last.
            Presumably (B, N, C_out), assuming MPGM_PSA returns
            (B, C_out, 1, N) — TODO confirm against .PSA.
        """
        idx_MP = self.get_MP_idx(position_MP, position)  # (B, N, k)

        # Query token: own feature + query coordinates -> (B, C + D, 1, N).
        query = torch.cat(
            [
                feature_up.permute(0, 2, 1).unsqueeze(-2),
                position_MP.permute(0, 2, 1).unsqueeze(-2),
            ],
            dim=1,
        )

        # Neighbour tokens: knn_gather gives (B, N, k, U); permute to
        # (B, U, k, N) and stack features with coordinates -> (B, C + D, k, N).
        neighbours = torch.cat(
            [
                pytorch3d.ops.knn_gather(feature_up, idx_MP).permute(0, 3, 2, 1),
                pytorch3d.ops.knn_gather(position, idx_MP).permute(0, 3, 2, 1),
            ],
            dim=1,
        )

        return self.MPGM_PSA(query, neighbours).squeeze(-2).permute(0, 2, 1)


if __name__ == '__main__':
    # Smoke test: build the module on random data and print the output shape.
    class cfg_MPGM:
        embedding_dim = 256
        k = 32

    net = MPGM(cfg_MPGM)

    fu = torch.randn(16, 1024, 256)    # feature_up: (B, N, embedding_dim)
    pos = torch.randn(16, 1024, 3)     # position
    pos_mp = torch.randn(16, 1024, 3)  # position_MP

    # Only the output shape is inspected, so skip building the autograd graph;
    # the gathered (B, N, k, C) neighbour features alone are about 0.5 GB here.
    with torch.no_grad():
        out = net(fu, pos, pos_mp)

    print(out.shape)
