import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_batch
import torch.nn as nn
from torch.nn import (
    Sequential as Seq,
    Linear as Lin,
    ReLU,
    BatchNorm1d as BN,
    init,
    Dropout,
)



def stnknn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20, idx=None, dim9=False):
    batch_size = x.size(0)
    num_points = x.size(2)
    num_channels = x.size(1)
    if num_channels == 9:
        dim9 = True
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if dim9 == False:
            idx = stnknn(x, k=k)  # (batch_size, num_points, k)
        else:
            idx = stnknn(x[:, 6:], k=k)
    device = torch.device("cuda")

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(
        2, 1
    ).contiguous()  # (batch_size, num_points, num_dims)  -> (batch_size*num_points, num_dims) #   batch_size * num_points * k + range(0, batch_size*num_points)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2)

    return feature  # (batch_size, 2*num_dims, num_points, k)


class Transform_Net(nn.Module):
    def __init__(self, stn_k, args=None):
        super(Transform_Net, self).__init__()
        self.args = args
        self.k = stn_k

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                6, 64, kernel_size=1, bias=False
            ),  # nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, bias=False),
            # nn.Conv2d(64, 128, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.linear1 = nn.Linear(
            1024, 512, bias=False
        )  # nn.Linear(1024, 512, bias=False)
        self.bn3 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(
            512, 256, bias=False
        )  # nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        # self.transform_fc = nn.Linear(256, 3 * 3)
        self.transform_fc = nn.Linear(256, 3 * 3)

        init.constant_(self.transform_fc.weight, 0)
        init.eye_(self.transform_fc.bias.view(3, 3))

    def forward(self, x):
        batch_size = x.size(0)

        x = self.conv1(
            x
        )  # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(
            x
        )  # (batch_size, 64, num_points, k) -> (batch_size, 128, num_points, k)
        x = x.max(dim=-1, keepdim=False)[
            0
        ]  # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)

        x = self.conv3(
            x
        )  # (batch_size, 128, num_points) -> (batch_size, 1024, num_points)
        x = x.max(dim=-1, keepdim=False)[
            0
        ]  # (batch_size, 1024, num_points) -> (batch_size, 1024)

        x = F.leaky_relu(
            self.bn3(self.linear1(x)), negative_slope=0.2
        )  # (batch_size, 1024) -> (batch_size, 512)
        x = F.leaky_relu(
            self.bn4(self.linear2(x)), negative_slope=0.2
        )  # (batch_size, 512) -> (batch_size, 256)

        x = self.transform_fc(x)  # (batch_size, 256) -> (batch_size, 3*3)
        x = x.view(batch_size, 3, 3)  # (batch_size, 3*3) -> (batch_size, 3, 3)

        return x

    def transform(self, pos, batch, normals=None):
        x0 = pos
        x0, mask = to_dense_batch(x0, batch)  # to_dense_batch(x0, batch, 32)
        x0 = x0.transpose(2, 1).contiguous()
        x0_gf = get_graph_feature(x0, self.k).contiguous()
        t = self(x0_gf).contiguous()
        x0 = x0.transpose(2, 1).contiguous()
        x0 = torch.bmm(x0, t).contiguous()
        x0 = x0.view(x0.shape[0] * x0.shape[1], -1).contiguous()
        pos = x0

        if normals is not None:
            normals_tmp, mask = to_dense_batch(normals, batch)  # to_dense_batch(x0, batch, 32)
            normals_tmp = torch.bmm(normals_tmp, t).contiguous()
            normals_tmp = normals_tmp.view(normals_tmp.shape[0] * normals_tmp.shape[1], -1).contiguous()
            normals = normals_tmp
            pos = torch.cat([pos, normals], dim=1)

        return pos
