
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from model import knn


########################################################################################################################
# The following functions implement the modified kNN search and the graph-feature builders that use it.
########################################################################################################################
def knn_idx_dis(x, n=40, t=1, uniform=True):
    """
    Modified k nearest neighbor search: a single pairwise-distance pass followed by one top-n selection.

    :param x: point features, (batch_size, num_dims, num_points)
    :param n: the number of nearest neighbors to keep (size of the candidate pool)
    :param t: softmax temperature, only used when ``uniform`` is False
    :param uniform: if True only the indices are returned; otherwise per-neighbor sampling
                    weights are returned as well
    :return: ``idx`` of shape (batch_size, num_points, n) when ``uniform`` is True, otherwise the
             tuple ``(idx, weights)``; ``weights`` has the same shape and sums to 1 over the last dim
    """
    # Negative squared Euclidean distance -|a|^2 + 2<a, b> - |b|^2: larger means closer, so topk
    # (sorted, largest first) yields the nearest candidates first.
    cross = -2 * torch.matmul(x.transpose(2, 1), x)             # (batch_size, num_points, num_points)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)            # (batch_size, 1, num_points)
    neg_sq_dist = -sq_norm - cross - sq_norm.transpose(2, 1)

    scores, neighbours = neg_sq_dist.topk(k=n, dim=-1)          # both (batch_size, num_points, n)
    if uniform:
        return neighbours

    # NOTE(review): ``scores`` holds *negative* squared distances, so ``-scores / t`` gives the
    # largest weight to the farthest of the n candidates. Confirm this is the intended bias.
    weights = F.softmax(-scores / t, dim=-1)
    return neighbours, weights


def get_graph_feature_knn_simplify(x, idx, k, pool=0, uniform=True, use_gather=False):
    """
    The get_graph_feature used by DGCNN_knn_simplify.

    Selects k neighbors per point from a precomputed candidate list and builds the EdgeConv input
    ``cat([x_j - x_i, x_i])`` for every (point i, neighbor j) pair.

    :param x: (batch_size, num_dims, num_points)
    :param idx: if ``uniform``: candidate indices, (batch_size, num_points, n); otherwise the tuple
                ``(indices, weights)``, both (batch_size, num_points, n), as returned by
                ``knn_idx_dis(..., uniform=False)``
    :param k: the number of knn neighbors
    :param pool: uniform sampling only. 0 takes the first k candidates; otherwise k distinct
                 candidates are drawn at random from the first ``pool`` ones (k <= pool <= n)
    :param uniform: whether to use uniform sampling or multinomial sampling (k distinct candidates
                    drawn according to ``weights``)
    :param use_gather: whether to use torch.gather or tensor slice to get the feature
    :return feature: (batch_size, 2*num_dims, num_points, k)
    """
    batch_size, num_dims, num_points = x.shape
    device = x.device

    # get sampling idx -> (batch_size, num_points, k)
    if uniform:
        if pool == 0:
            idx_idx = torch.arange(0, k, device=device)
        else:
            # sorted so the chosen candidates keep their order in the candidate list
            idx_idx = torch.randperm(pool, device=device)[:k].sort()[0]
        idx = idx[:, :, idx_idx]
    else:
        idx, weights = idx
        n = idx.size(2)
        ii = torch.multinomial(weights.view(-1, n), k)          # (batch_size*num_points, k), positions in [0, n)
        # offset each row so the positions index into the flattened candidate tensor
        ii += torch.arange(0, batch_size * num_points, device=device).view(-1, 1) * n
        idx = idx.view(-1)[ii.view(-1)].view(batch_size, num_points, k)

    # Center features, (batch_size, num_dims, num_points, k). Built from the untouched input so both
    # gathering strategies below concatenate against the same layout.
    center = x.unsqueeze(-1).repeat(1, 1, 1, k)

    if use_gather:
        # (batch_size, 1, num_points*k) broadcast over the channel dim without copying
        gather_idx = idx.reshape(batch_size, 1, num_points * k).expand(-1, num_dims, -1)
        feature = torch.gather(x, 2, gather_idx)                # (batch_size, num_dims, num_points*k)
        feature = feature.view(batch_size, num_dims, num_points, k)
    else:
        idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
        flat_idx = (idx + idx_base).view(-1)                    # batch_size*num_points*k; range(0, batch_size*num_points)
        points = x.transpose(2, 1).reshape(batch_size * num_points, num_dims)
        feature = points[flat_idx, :]                           # (batch_size*num_points*k, num_dims)
        feature = feature.view(batch_size, num_points, k, num_dims).permute(0, 3, 1, 2)

    return torch.cat((feature - center, center), dim=1)


def get_graph_feature_conv_reformulate(x, f0, f1, k=20, idx=None, dim9=False):
    """
    The get_graph_feature function used by DGCNN_conv_reformulate.

    Gathers the neighbor term ``f0`` of every point's k nearest neighbors, adds the point's own
    term ``f1`` and max-pools over the k neighbors.

    :param x: (batch_size, num_dims_old, num_points); features the kNN graph is built on. Only
              read when ``idx`` is None.
    :param f0: (batch_size, num_dims, num_points); term gathered from the neighbors
    :param f1: (batch_size, num_dims, num_points); term of the center point
    :param k: the number of knn neighbors
    :param idx: optional precomputed neighbor indices, (batch_size, num_points, k)
    :param dim9: if truthy, only channels 6 onwards of ``x`` are used for the kNN search
    :return feature: (batch_size, num_dims, num_points)
    """
    batch_size, num_dims, num_points = f0.shape
    device = f0.device

    # prepare the indices
    if idx is None:
        source = x[:, 6:] if dim9 else x
        idx = knn(source, k=k)                                          # (batch_size, num_points, k)

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)                                     # batch_size*num_points*k; range(0, batch_size*num_points)

    # neighbor terms: (batch_size, num_points, k, num_dims)
    points = f0.transpose(2, 1).reshape(batch_size * num_points, num_dims)
    feature = points[idx, :].view(batch_size, num_points, k, num_dims)
    # center terms broadcast over the k neighbors instead of being copied k times with repeat
    center = f1.transpose(2, 1).reshape(batch_size, num_points, 1, num_dims)
    feature = (feature + center).max(dim=2)[0]                          # (batch_size, num_points, num_dims)

    return feature.transpose(2, 1)                                      # (batch_size, num_dims, num_points)


def get_graph_feature_full(f0, f1, idx, k, pool=0):
    """
    The get_graph_feature function used by DGCNN_full.

    For every point, picks k neighbors from the precomputed candidate list ``idx``, gathers their
    ``f0`` terms, adds the point's own ``f1`` term and max-pools over the k neighbors.

    :param f0: (batch_size, num_points, num_dims); term gathered from the neighbors
    :param f1: (batch_size, num_points, num_dims); term of the center point
    :param idx: (batch_size, num_points, n); candidate neighbor indices into num_points
    :param k: the number of knn neighbors
    :param pool: 0 takes the first k candidates; otherwise k distinct candidates are drawn at
                 random from the first ``pool`` ones. Requires k <= pool <= n when non-zero.
    :return feature: (batch_size, num_points, num_dims)
    :raises ValueError: if ``pool`` is non-zero and not within [k, n]
    """
    batch_size, num_points, num_dims = f0.shape
    device = f0.device

    # sample the positions (within the candidate list) of the k neighbors
    if pool == 0:
        idx_idx = torch.arange(0, k, device=device)
    else:
        n = idx.size(2)
        # Without this check, pool > n fails only when randperm happens to draw a position >= n.
        if not k <= pool <= n:
            raise ValueError(f"pool must satisfy k <= pool <= n (k={k}, pool={pool}, n={n})")
        # sorted so the chosen candidates keep their order in the candidate list
        idx_idx = torch.randperm(pool, device=device)[:k].sort()[0]

    # prepare the indices
    idx = idx[:, :, idx_idx]
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)                             # batch_size*num_points*k; range(0, batch_size*num_points)

    # get the features
    feature = f0.reshape(batch_size * num_points, num_dims)[idx, :]     # (batch_size*num_points*k, num_dims)
    feature = feature.view(batch_size, num_points, k, num_dims)
    # center term broadcast over the k neighbors instead of being copied k times with repeat
    center = f1.reshape(batch_size, num_points, 1, num_dims)
    return (feature + center).max(dim=2)[0]                     # (batch_size, num_points, num_dims)


########################################################################################################################
# The three versions of our algorithm.
# DGCNN_knn_simplify: simplify the knn operation
# DGCNN_conv_reformulate: swap the order of the knn operation and the conv, reformulating the conv accordingly.
# DGCNN_full: combine the above two modifications.
########################################################################################################################


class DGCNN_knn_simplify(nn.Module):
    """
    Try to simplify the knn operation. Only one pairwise distance computation and topk operation is conducted, which
    forms a sampling pool for the following neighbor inquiries.
    This is explained in Sec. 3.2 of the report.

    Fields read from ``args``: ``k`` (neighbors per EdgeConv), ``n`` (size of the shared candidate
    pool), ``progressive`` (per-layer growth of the sampling pool; 0 samples from all n candidates),
    ``emb_dims`` and ``dropout``.
    """
    def __init__(self, args, output_channels=40):
        super(DGCNN_knn_simplify, self).__init__()
        self.args = args
        self.k = args.k
        self.n = args.n
        self.uniform = True         # uniform sampling from the pool; False -> weighted multinomial sampling
        self.use_gather = False     # build neighbor features with torch.gather instead of flat indexing
        self.p = args.progressive

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(args.emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def _pool_size(self, layer):
        """
        Return the sampling-pool size for the ``layer``-th EdgeConv after the first one (layer >= 1).

        progressive == 0 -> all n candidates; progressive > 0 -> k + progressive * layer;
        progressive < 0 -> 0, i.e. simply the first k candidates.
        """
        if self.p == 0:
            return self.n
        if self.p > 0:
            return self.k + self.p * layer
        return 0

    def _edge_conv(self, x, idx, conv, pool):
        """Build graph features of ``x`` from the shared candidates ``idx``, apply ``conv`` and max over k."""
        feature = get_graph_feature_knn_simplify(x, idx, k=self.k, pool=pool,
                                                 uniform=self.uniform, use_gather=self.use_gather)
        return conv(feature).max(dim=-1, keepdim=False)[0]

    def forward(self, x):
        batch_size = x.size(0)
        # Single kNN search; its candidate list (plus weights when not uniform) is reused by every layer.
        idx = knn_idx_dis(x, self.n, uniform=self.uniform)

        x1 = self._edge_conv(x, idx, self.conv1, pool=0)                     # (batch_size, 3, num_points) -> (batch_size, 64, num_points)
        x2 = self._edge_conv(x1, idx, self.conv2, pool=self._pool_size(1))   # (batch_size, 64, num_points) -> (batch_size, 64, num_points)
        x3 = self._edge_conv(x2, idx, self.conv3, pool=self._pool_size(2))   # (batch_size, 64, num_points) -> (batch_size, 128, num_points)
        x4 = self._edge_conv(x3, idx, self.conv4, pool=self._pool_size(3))   # (batch_size, 128, num_points) -> (batch_size, 256, num_points)

        x = torch.cat((x1, x2, x3, x4), dim=1)                          # (batch_size, 64+64+128+256, num_points)
        x = self.conv5(x)                                               # (batch_size, 64+64+128+256, num_points) -> (batch_size, emb_dims, num_points)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = torch.cat((x1, x2), 1)                                      # (batch_size, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) # (batch_size, 512) -> (batch_size, 256)
        x = self.dp2(x)
        x = self.linear3(x)                                             # (batch_size, 256) -> (batch_size, output_channels)

        return x
    

class DGCNN_conv_reformulate(nn.Module):
    """
    Try to reformulate the convolution operation in EdgeConv.
    This is explained in Sec. 3.3 of the report.

    Each EdgeConv stage applies a pointwise conv producing twice the stage width; the output is
    split into a neighbor term f0 and a center term f1, which
    get_graph_feature_conv_reformulate combines over the kNN graph.
    """
    def __init__(self, args, output_channels=40):
        super(DGCNN_conv_reformulate, self).__init__()
        self.args = args
        self.k = args.k

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        self.conv1 = self._pointwise(3, 128, self.bn1)
        self.conv2 = self._pointwise(64, 128, self.bn2)
        self.conv3 = self._pointwise(64, 256, self.bn3)
        self.conv4 = self._pointwise(128, 512, self.bn4)
        self.conv5 = self._pointwise(512, args.emb_dims, self.bn5)
        self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    @staticmethod
    def _pointwise(in_channels, out_channels, bn):
        """Bias-free 1x1 Conv1d followed by the given BatchNorm and LeakyReLU(0.2)."""
        return nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
                             bn,
                             nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        batch_size = x.size(0)

        # (conv, half of its output width): the halves become f0 and f1.
        stages = ((self.conv1, 64), (self.conv2, 64), (self.conv3, 128), (self.conv4, 256))
        features = []
        current = x                                                     # (batch_size, 3, num_points) initially
        for conv, width in stages:
            f0, f1 = conv(current).split(width, dim=1)                  # 2 x (batch_size, width, num_points)
            # the kNN graph is rebuilt on each stage's input
            current = get_graph_feature_conv_reformulate(current, f0, f1, k=self.k)  # (batch_size, width, num_points)
            features.append(current)

        embedded = self.conv5(torch.cat(features, dim=1))              # (batch_size, 512, num_points) -> (batch_size, emb_dims, num_points)
        global_max = F.adaptive_max_pool1d(embedded, 1).view(batch_size, -1)
        global_avg = F.adaptive_avg_pool1d(embedded, 1).view(batch_size, -1)
        hidden = torch.cat((global_max, global_avg), 1)                 # (batch_size, emb_dims*2)

        hidden = self.dp1(F.leaky_relu(self.bn6(self.linear1(hidden)), negative_slope=0.2))  # -> (batch_size, 512)
        hidden = self.dp2(F.leaky_relu(self.bn7(self.linear2(hidden)), negative_slope=0.2))  # -> (batch_size, 256)
        return self.linear3(hidden)                                     # -> (batch_size, output_channels)


class DGCNN_full(nn.Module):
    """
    Classification DGCNN combining both acceleration methods: a single kNN search whose candidate
    list is shared by all EdgeConv layers, and the reformulated convolution (split into a neighbor
    term f0 and a center term f1).

    Fields read from ``args``: ``k``, ``n``, ``progressive``, ``emb_dims`` and ``dropout``.
    """

    def __init__(self, args, output_channels=40):
        super(DGCNN_full, self).__init__()
        self.args = args
        self.k = args.k
        self.n = args.n
        self.p = args.progressive

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        self.conv1 = self._pointwise(3, 128, self.bn1)
        self.conv2 = self._pointwise(64, 128, self.bn2)
        self.conv3 = self._pointwise(64, 256, self.bn3)
        self.conv4 = self._pointwise(128, 512, self.bn4)
        self.conv5 = self._pointwise(512, args.emb_dims, self.bn5)
        self.linear1 = nn.Linear(args.emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    @staticmethod
    def _pointwise(in_channels, out_channels, bn):
        """Bias-free 1x1 Conv1d followed by the given BatchNorm and LeakyReLU(0.2)."""
        return nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False),
                             bn,
                             nn.LeakyReLU(negative_slope=0.2))

    def _pool_size(self, layer):
        """Sampling-pool size for stage ``layer`` (>= 1): n, k + progressive * layer, or 0 if progressive < 0."""
        if self.p == 0:
            return self.n
        if self.p > 0:
            return self.k + self.p * layer
        return 0

    def forward(self, x):
        batch_size = x.size(0)                                                      # x: (batch_size, 3, num_points)
        idx = knn_idx_dis(x, self.n)                                                # (batch_size, num_points, n)

        # (conv, half of its output width); stage 0 uses the first k candidates (pool=0).
        stages = ((self.conv1, 64), (self.conv2, 64), (self.conv3, 128), (self.conv4, 256))
        local_feats = []
        channels_first = x
        for layer, (conv, width) in enumerate(stages):
            f0, f1 = conv(channels_first).transpose(2, 1).split(width, dim=2)       # 2 x (batch_size, num_points, width)
            pool = self._pool_size(layer) if layer > 0 else 0
            out = get_graph_feature_full(f0, f1, idx, k=self.k, pool=pool)          # (batch_size, num_points, width)
            local_feats.append(out)
            channels_first = out.transpose(2, 1)                                    # next conv expects (batch_size, width, num_points)

        embedded = self.conv5(torch.cat(local_feats, dim=2).transpose(2, 1))       # (batch_size, 512, num_points) -> (batch_size, emb_dims, num_points)
        global_max = F.adaptive_max_pool1d(embedded, 1).view(batch_size, -1)
        global_avg = F.adaptive_avg_pool1d(embedded, 1).view(batch_size, -1)
        hidden = torch.cat((global_max, global_avg), 1)                             # (batch_size, emb_dims*2)

        hidden = self.dp1(F.leaky_relu(self.bn6(self.linear1(hidden)), negative_slope=0.2))  # -> (batch_size, 512)
        hidden = self.dp2(F.leaky_relu(self.bn7(self.linear2(hidden)), negative_slope=0.2))  # -> (batch_size, 256)
        return self.linear3(hidden)                                                 # -> (batch_size, output_channels)


########################################################################################################################
# Accelerated part segmentation code
########################################################################################################################
class Transform_Net(nn.Module):
    """
    Predicts a (batch_size, 3, 3) matrix from a (batch_size, 3, num_points) point cloud.

    The final linear layer is initialised with zero weights and an identity bias, so the predicted
    matrix is the identity at initialisation. Uses one kNN search with n = k candidates and one
    reformulated EdgeConv (get_graph_feature_full).
    """
    def __init__(self, args):
        super(Transform_Net, self).__init__()
        self.args = args
        self.k = args.k

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128 * 2)
        # NOTE: this 1024-channel bn3 is captured by conv3 below; the attribute self.bn3 is then
        # re-bound to a 512-channel BatchNorm used after linear1. Renaming either one changes the
        # state_dict keys (presumably breaking existing checkpoints), so keep this order.
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = nn.Sequential(nn.Conv1d(3, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        # 128 * 2 outputs, split in forward() into the neighbor term f0 and center term f1 (128 each)
        self.conv2 = nn.Sequential(nn.Conv1d(64, 128 * 2, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn3 = nn.BatchNorm1d(512)  # re-binds self.bn3 (see note above); conv3 keeps the 1024-channel one
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        # zero weights + identity bias -> identity transform at initialisation
        self.transform = nn.Linear(256, 3 * 3)
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        batch_size = x.size(0)
        idx = knn_idx_dis(x, self.k)  # (batch_size, num_points, k); pool=0 below uses all k candidates

        x = self.conv1(x)  # (batch_size, 3, num_points) -> (batch_size, 64, num_points)
        f0, f1 = self.conv2(x).transpose(2, 1).split(128, dim=2)  # (batch_size, 64, num_points) -> 2 x (batch_size, num_points, 128)
        x = get_graph_feature_full(f0, f1, idx, self.k, pool=0)  # (batch_size, num_points, 128) -> (batch_size, num_points, 128)

        x = self.conv3(x.transpose(2, 1))  # (batch_size, 128, num_points) -> (batch_size, 1024, num_points)
        x = x.max(dim=-1, keepdim=False)[0]  # (batch_size, 1024, num_points) -> (batch_size, 1024)

        # self.bn3 here is the 512-channel instance, not the one inside conv3
        x = F.leaky_relu(self.bn3(self.linear1(x)), negative_slope=0.2)  # (batch_size, 1024) -> (batch_size, 512)
        x = F.leaky_relu(self.bn4(self.linear2(x)), negative_slope=0.2)  # (batch_size, 512) -> (batch_size, 256)

        x = self.transform(x)  # (batch_size, 256) -> (batch_size, 3*3)
        x = x.view(batch_size, 3, 3)  # (batch_size, 3*3) -> (batch_size, 3, 3)

        return x


class DGCNN_partseg_full(nn.Module):
    """
    Part-segmentation DGCNN using both knn simplification and conv reformulation. Only one knn
    operation is performed; its candidate list is shared by all three EdgeConv layers.

    Fields read from ``args``: ``k``, ``n``, ``progressive``, ``emb_dims`` and ``dropout``
    (Transform_Net also reads ``args.k``).
    """
    def __init__(self, args, seg_num_all):
        super(DGCNN_partseg_full, self).__init__()
        self.args = args
        self.seg_num_all = seg_num_all
        self.k = args.k
        self.n = args.n
        self.p = args.progressive
        self.transform_net = Transform_Net(args)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(128)
        self.bn6 = nn.BatchNorm1d(args.emb_dims)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        self.conv1 = nn.Sequential(nn.Conv1d(3, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(64, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv7 embeds the per-sample label vector l; it expects 16 values per sample
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        # input = global feature (emb_dims) + label feature (64) + three 64-channel local features.
        # Previously hard-coded to 1280, which only matched emb_dims == 1024.
        self.conv8 = nn.Sequential(nn.Conv1d(args.emb_dims + 64 + 64 * 3, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=args.dropout)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=args.dropout)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                    self.bn10,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)

    def _pool_size(self, layer):
        """
        Return the sampling-pool size for the ``layer``-th EdgeConv after the first one (layer >= 1).

        progressive == 0 -> all n candidates; progressive > 0 -> k + progressive * layer;
        progressive < 0 -> 0, i.e. simply the first k candidates.
        """
        if self.p == 0:
            return self.n
        if self.p > 0:
            return self.k + self.p * layer
        return 0

    def forward(self, x, l):
        """
        :param x: (batch_size, 3, num_points) point coordinates
        :param l: per-sample label input, reshaped to (batch_size, 16, 1) for conv7
        :return: (batch_size, seg_num_all, num_points) per-point scores
        """
        batch_size = x.size(0)
        num_points = x.size(2)

        t = self.transform_net(x)                               # (batch_size, 3, 3)
        x = x.transpose(2, 1)                                   # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)                                     # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)                                   # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        # One kNN search on the transformed coordinates, shared by all EdgeConv layers below.
        idx = knn_idx_dis(x, self.n)                            # (batch_size, num_points, n)
        x = self.conv1(x)                                       # (batch_size, 3, num_points) -> (batch_size, 64, num_points)
        f0, f1 = self.conv2(x).transpose(2, 1).split(64, dim=2) # (batch_size, 64, num_points) -> 2 x (batch_size, num_points, 64)
        x1 = get_graph_feature_full(f0, f1, idx, self.k, pool=0)                    # (batch_size, num_points, 64)

        x = self.conv3(x1.transpose(2, 1))                      # (batch_size, 64, num_points) -> (batch_size, 64, num_points)
        f0, f1 = self.conv4(x).transpose(2, 1).split(64, dim=2) # (batch_size, 64, num_points) -> 2 x (batch_size, num_points, 64)
        x2 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size(1))   # (batch_size, num_points, 64)

        f0, f1 = self.conv5(x2.transpose(2, 1)).transpose(2, 1).split(64, dim=2)    # 2 x (batch_size, num_points, 64)
        x3 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size(2))   # (batch_size, num_points, 64)

        local = torch.cat((x1, x2, x3), dim=-1)                 # (batch_size, num_points, 64*3)

        x = self.conv6(local.transpose(2, 1))                   # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]                      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        l = l.view(batch_size, -1, 1)                           # (batch_size, num_categories, 1)
        l = self.conv7(l)                                       # (batch_size, num_categories, 1) -> (batch_size, 64, 1)

        x = torch.cat((x, l), dim=1)                            # (batch_size, emb_dims+64, 1)
        x = x.expand(-1, -1, num_points)                        # (batch_size, emb_dims+64, num_points), broadcast view without a copy

        x = torch.cat((x.transpose(2, 1), local), dim=-1).transpose(2, 1)  # (batch_size, emb_dims+64+64*3, num_points)

        x = self.conv8(x)                                       # (batch_size, emb_dims+64+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)                                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)                                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)

        return x


class DGCNN_partseg_conv_reformulate1(nn.Module):
    """
        DGCNN part-segmentation model built from reformulated edge convolutions.

        Every conv used to extract features is followed by a get_graph_feature_full operation. The first two pairs
        of those convs each share one knn index; the last conv (conv5) computes its own knn index.
        This change can significantly improve the performance of the model.

        :param args: config object; this class reads args.k, args.n, args.progressive, args.emb_dims and
                     args.dropout, and also passes args to Transform_Net
        :param seg_num_all: number of part labels predicted for every point
    """
    def __init__(self, args, seg_num_all):
        super(DGCNN_partseg_conv_reformulate1, self).__init__()
        self.args = args
        self.seg_num_all = seg_num_all
        self.k = args.k                     # k passed to get_graph_feature_full
        self.n = args.n                     # number of neighbours returned by knn_idx_dis
        self.p = args.progressive           # 0 disables the progressive pool, see _pool_size
        self.transform_net = Transform_Net(args)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(128)
        self.bn6 = nn.BatchNorm1d(args.emb_dims)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        # conv1..conv5 emit 128 channels that forward() splits into two 64-channel halves (f0, f1).
        self.conv1 = nn.Sequential(nn.Conv1d(3, 128, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        # x1, x2, x3 (64 channels each) -> global embedding
        self.conv6 = nn.Sequential(nn.Conv1d(64 * 3, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        # 16-channel category vector -> 64-channel embedding
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Input = repeated global feature (emb_dims) + category embedding (64) + x1, x2, x3 (64 each).
        # Derived from args.emb_dims: the former hard-coded 1280 only matched emb_dims == 1024.
        self.conv8 = nn.Sequential(nn.Conv1d(args.emb_dims + 64 + 64 * 3, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=args.dropout)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=args.dropout)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                    self.bn10,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)

    def _pool_size(self, multiplier=1):
        """
        Pool size passed to get_graph_feature_full for the progressive layers.

        :param multiplier: how many times args.progressive is added on top of k
        :return: self.n when self.p == 0, self.k + self.p * multiplier when self.p > 0, and 0 otherwise
        """
        return (self.p == 0) * self.n + (self.p > 0) * (self.k + self.p * multiplier)

    def forward(self, x, l):
        """
        :param x: input points, (batch_size, 3, num_points)
        :param l: category labels; viewed as (batch_size, 16, 1) for conv7 (presumably one-hot — confirm with caller)
        :return: per-point logits, (batch_size, seg_num_all, num_points)
        """
        batch_size = x.size(0)
        num_points = x.size(2)

        t = self.transform_net(x)                               # (batch_size, 3, 3)
        x = x.transpose(2, 1)                                   # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)                                     # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)                                   # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        # Block 1: a single knn on the aligned coordinates is shared by conv1 and conv2.
        idx = knn_idx_dis(x, self.n)
        f0, f1 = self.conv1(x).transpose(2, 1).split(64, dim=2)                  # (batch_size, 3, num_points) -> 2 x (batch_size, num_points, 64)
        x = get_graph_feature_full(f0, f1, idx, self.k, pool=0)                  # -> (batch_size, num_points, 64)
        f0, f1 = self.conv2(x.transpose(2, 1)).transpose(2, 1).split(64, dim=2)  # -> 2 x (batch_size, num_points, 64)
        x1 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size())  # -> (batch_size, num_points, 64)

        # Block 2: a knn on x1 is shared by conv3 and conv4.
        idx = knn_idx_dis(x1.transpose(2, 1), self.n)
        f0, f1 = self.conv3(x1.transpose(2, 1)).transpose(2, 1).split(64, dim=2)  # (batch_size, num_points, 64) -> 2 x (batch_size, num_points, 64)
        x = get_graph_feature_full(f0, f1, idx, self.k, pool=0)                   # -> (batch_size, num_points, 64)
        f0, f1 = self.conv4(x.transpose(2, 1)).transpose(2, 1).split(64, dim=2)   # -> 2 x (batch_size, num_points, 64)
        x2 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size())  # -> (batch_size, num_points, 64)

        # Block 3: a knn on x2 for conv5. Only k neighbours are requested because pool=0 is used.
        # TODO: a progressive pool (self._pool_size(2), as in the other reformulated models) would also
        #       require knn_idx_dis(..., self.n) here, since that pool is larger than k.
        idx = knn_idx_dis(x2.transpose(2, 1), self.k)
        f0, f1 = self.conv5(x2.transpose(2, 1)).transpose(2, 1).split(64, dim=2)  # (batch_size, num_points, 64) -> 2 x (batch_size, num_points, 64)
        x3 = get_graph_feature_full(f0, f1, idx, self.k, pool=0)                  # -> (batch_size, num_points, 64)

        x = torch.cat((x1, x2, x3), dim=-1).transpose(2, 1)     # (batch_size, 64*3, num_points)

        x = self.conv6(x)                                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]                      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        l = l.view(batch_size, -1, 1)                           # (batch_size, num_categories, 1)
        l = self.conv7(l)                                       # (batch_size, num_categories, 1) -> (batch_size, 64, 1)

        x = torch.cat((x, l), dim=1)                            # (batch_size, emb_dims+64, 1)
        x = x.repeat(1, 1, num_points)                          # (batch_size, emb_dims+64, num_points)

        x = torch.cat((x.transpose(2, 1), x1, x2, x3), dim=-1).transpose(2, 1)  # (batch_size, emb_dims+64+64*3, num_points)

        x = self.conv8(x)                                       # (batch_size, emb_dims+64+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)                                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)                                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)

        return x



########################################################################################################################
# Accelerated semantic segmentation code
########################################################################################################################
class DGCNN_semseg_full(nn.Module):
    """
    DGCNN semantic-segmentation model using both knn simplification and conv reformulation.
    Only one knn operation is used: the index computed on the input is shared by all three graph-feature layers.

    :param args: config object; this class reads args.k, args.n, args.progressive, args.emb_dims and args.dropout
    """
    def __init__(self, args):
        super(DGCNN_semseg_full, self).__init__()
        self.args = args
        self.k = args.k                     # k passed to get_graph_feature_full
        self.n = args.n                     # number of neighbours returned by knn_idx_dis
        self.p = args.progressive           # 0 disables the progressive pool, see _pool_size

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(128)
        self.bn6 = nn.BatchNorm1d(args.emb_dims)
        self.bn7 = nn.BatchNorm1d(512)
        self.bn8 = nn.BatchNorm1d(256)

        self.conv1 = nn.Sequential(nn.Conv1d(9, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv2, conv4 and conv5 emit 128 channels that forward() splits into two 64-channel halves (f0, f1).
        self.conv2 = nn.Sequential(nn.Conv1d(64, 64 * 2, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(64, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv1d(64, 64 * 2, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(64, 64 * 2, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        # x1, x2, x3 (64 channels each) -> global embedding
        self.conv6 = nn.Sequential(nn.Conv1d(64 * 3, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Input = repeated global feature (emb_dims) + x1, x2, x3 (64 each).
        # Derived from args.emb_dims: the former hard-coded 1216 only matched emb_dims == 1024.
        self.conv7 = nn.Sequential(nn.Conv1d(args.emb_dims + 64 * 3, 512, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=args.dropout)
        self.conv9 = nn.Conv1d(256, 13, kernel_size=1, bias=False)

    def _pool_size(self, multiplier=1):
        """
        Pool size passed to get_graph_feature_full for the progressive layers.

        :param multiplier: how many times args.progressive is added on top of k
        :return: self.n when self.p == 0, self.k + self.p * multiplier when self.p > 0, and 0 otherwise
        """
        return (self.p == 0) * self.n + (self.p > 0) * (self.k + self.p * multiplier)

    def forward(self, x):
        """
        :param x: input points, (batch_size, 9, num_points). Channels 6: are used for the single knn search
                  (presumably normalized xyz — confirm against the data loader).
        :return: per-point logits, (batch_size, 13, num_points)
        """
        num_points = x.size(2)
        idx = knn_idx_dis(x[:, 6:], self.n)

        x = self.conv1(x)                                       # (batch_size, 9, num_points) -> (batch_size, 64, num_points)
        f0, f1 = self.conv2(x).transpose(2, 1).split(64, dim=2) # (batch_size, 64, num_points) -> 2 x (batch_size, num_points, 64)
        x1 = get_graph_feature_full(f0, f1, idx, self.k, pool=0)# -> (batch_size, num_points, 64)

        x = self.conv3(x1.transpose(2, 1))                      # (batch_size, num_points, 64) -> (batch_size, 64, num_points)
        f0, f1 = self.conv4(x).transpose(2, 1).split(64, dim=2) # (batch_size, 64, num_points) -> 2 x (batch_size, num_points, 64)
        x2 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size(1))  # -> (batch_size, num_points, 64)

        f0, f1 = self.conv5(x2.transpose(2, 1)).transpose(2, 1).split(64, dim=2)   # (batch_size, num_points, 64) -> 2 x (batch_size, num_points, 64)
        x3 = get_graph_feature_full(f0, f1, idx, self.k, pool=self._pool_size(2))  # -> (batch_size, num_points, 64)

        x = torch.cat((x1, x2, x3), dim=-1).transpose(2, 1)     # (batch_size, 64*3, num_points)

        x = self.conv6(x)                                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]                      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        x = x.repeat(1, 1, num_points)                          # (batch_size, emb_dims, num_points)
        x = torch.cat((x.transpose(2, 1), x1, x2, x3), dim=-1).transpose(2, 1)  # (batch_size, emb_dims+64*3, num_points)

        x = self.conv7(x)                                       # (batch_size, emb_dims+64*3, num_points) -> (batch_size, 512, num_points)
        x = self.conv8(x)                                       # (batch_size, 512, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                                       # (batch_size, 256, num_points) -> (batch_size, 13, num_points)

        return x
