import torch
import torch.nn as nn

import copy
class PCN(nn.Module):
    """Point Completion Network: coarse MLP decoder plus folding refinement.

    "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf)

    The encoder applies two point-wise (kernel size 1) convolution stacks,
    each followed by a max over the point axis, to obtain one global feature
    vector per cloud. The decoder maps that vector to ``num_coarse`` points
    with an MLP, then expands every coarse point into ``grid_size ** 2`` fine
    points by feeding [global feature, 2-D grid seed, coarse point] through
    ``final_conv`` and adding the result to the coarse point.

    Args:
        num_dense: Number of points in the fine output (default 16384).
            Must be divisible by ``grid_size ** 2``.
        latent_dim: Width of the global feature vector (default 1024).
        grid_size: Side length of the square folding grid (default 4).

    Attributes:
        num_coarse: ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: Non-persistent buffer of shape ``(1, 2, grid_size ** 2)``
            holding the 2-D grid coordinates in [-0.05, 0.05]. Being a buffer
            it follows ``.to()`` / ``.cuda()`` together with the parameters
            and is not stored in ``state_dict``.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        self.mlp = nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * self.num_coarse)
        )

        # Input channels follow latent_dim (global feature) + 3 (coarse xyz)
        # + 2 (grid seed); a hard-coded 1024 here broke latent_dim != 1024.
        self.final_conv = nn.Sequential(
            nn.Conv1d(self.latent_dim + 3 + 2, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2)  # (1, 2, S)

        # Registered as a buffer instead of calling .cuda(): the module can be
        # built on CPU-only hosts and the seed moves with the model.
        # persistent=False keeps state_dict keys identical to older checkpoints.
        self.register_buffer("folding_seed", seed, persistent=False)

    def forward(self, xyz):
        """Complete a partial point cloud.

        Args:
            xyz: Tensor of shape ``(B, N, 3)``.

        Returns:
            Tuple ``(coarse, fine)`` with shapes ``(B, num_coarse, 3)`` and
            ``(B, num_dense, 3)``, both contiguous.
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))                            # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]               # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)   # (B, 512, N)
        feature = self.second_conv(feature)                                       # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]              # (B, latent_dim)

        # decoder: coarse point cloud
        coarse = self.mlp(feature_global).reshape(-1, self.num_coarse, 3)         # (B, num_coarse, 3)

        # Repeat every coarse point S = grid_size ** 2 times.
        point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)    # (B, 3, num_dense)

        # Tile the grid seed once per coarse point.
        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)                                # (B, 2, num_dense)

        feature_global = feature_global.unsqueeze(2).expand(-1, -1, self.num_dense)  # (B, latent_dim, num_dense)
        feat = torch.cat([feature_global, seed, point_feat], dim=1)               # (B, latent_dim+2+3, num_dense)

        # Residual refinement around the repeated coarse points.
        fine = self.final_conv(feat) + point_feat                                 # (B, 3, num_dense)

        return coarse.contiguous(), fine.transpose(1, 2).contiguous()


class PCN_SIG(nn.Module):
    """PCN variant whose folding offset is bounded by a sigmoid.

    Based on "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf)

    The encoder and coarse MLP decoder match ``PCN``. The folding head differs
    in two ways: it only sees [coarse point, 2-D grid seed] (the global
    feature is not concatenated), and its output is squashed to the open
    interval (-0.1, 0.1) before being added to the repeated coarse point.

    Args:
        num_dense: Number of points in the fine output (default 16384).
            Must be divisible by ``grid_size ** 2``.
        latent_dim: Width of the global feature vector (default 1024).
        grid_size: Side length of the square folding grid (default 4).

    Attributes:
        num_coarse: ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: Non-persistent buffer of shape ``(1, 2, grid_size ** 2)``
            with grid coordinates in [-0.05, 0.05]; it moves with the module
            and is not stored in ``state_dict``.
        a_index: Constant ``1``; not read anywhere in this class.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        self.mlp = nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * self.num_coarse)
        )

        # Folding head input: 3 (coarse xyz) + 2 (grid seed); no global feature.
        self.final_conv = nn.Sequential(
            nn.Conv1d(3 + 2, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

        # Retained for backward compatibility; unused by forward().
        self.a_index = 1

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a_1 = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b_1 = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        seed = torch.cat([a_1, b_1], dim=0).view(1, 2, self.grid_size ** 2)  # (1, 2, S)

        # Buffer instead of a hard .cuda(): works on CPU-only hosts and follows
        # the module across devices. Non-persistent so checkpoints stay
        # key-compatible.
        self.register_buffer("folding_seed", seed, persistent=False)

    def forward(self, xyz):
        """Complete a partial point cloud.

        Args:
            xyz: Tensor of shape ``(B, N, 3)``.

        Returns:
            Tuple ``(coarse, fine)`` with shapes ``(B, num_coarse, 3)`` and
            ``(B, num_dense, 3)``, both contiguous. Every fine point lies
            within 0.1 (per coordinate) of the coarse point it was folded from.
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        # decoder: coarse point cloud
        coarse = self.mlp(feature_global).reshape(-1, self.num_coarse, 3)  # (B, num_coarse, 3)

        point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)  # (B, 3, num_dense)

        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_dense)

        # No global-feature fusion in this variant: the head only sees the
        # repeated coarse point and the grid seed.
        feat = torch.cat([point_feat, seed], dim=1)  # (B, 3+2, num_dense)

        # PCN-GM-SIG: map the head output to (-0.1, 0.1) before the residual
        # add. Author's experiment notes: with the sigmoid the points do
        # cluster, though divergence still appears at the far end; without it
        # the clustering is not captured and the output always diverges.
        fine = (torch.sigmoid(self.final_conv(feat)) * 2 - 1) * 0.1 + point_feat  # (B, 3, num_dense)

        return coarse.contiguous(), fine.transpose(1, 2).contiguous()

class PCN2Brunch(nn.Module):
    """PCN variant with two coarse decoders sharing one encoder.

    Based on "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf)

    A shared PCN encoder feeds two independent MLP decoders (``mlp_slice`` and
    ``mlp_shape``). Only the ``mlp_shape`` output is refined by the folding
    head, whose offset is bounded to (-0.1, 0.1) with a sigmoid. A 1x1
    convolution head (``final_mlp``) then yields one scalar per fine point.

    Args:
        num_dense: Number of points in the fine output (default 16384).
            Must be divisible by ``grid_size ** 2``.
        latent_dim: Width of the global feature vector (default 1024).
        grid_size: Side length of the square folding grid (default 4).

    Attributes:
        num_coarse: ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: Non-persistent buffer of shape ``(1, 2, grid_size ** 2)``
            with grid coordinates in [-0.05, 0.05]; it moves with the module
            and is not stored in ``state_dict``.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        # Submodules are created in the same order and under the same names as
        # before so parameter ordering and state_dict keys are unchanged.
        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        self.mlp_slice = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)

        self.final_conv = self._folding_head()

        # NOTE(review): ReLU followed by Sigmoid confines the output to
        # [0.5, 1]. Left unchanged because trained weights may depend on it.
        self.final_mlp = nn.Sequential(
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2)  # (1, 2, S)

        # Buffer instead of a hard .cuda(): works on CPU-only hosts and follows
        # the module across devices. Non-persistent so checkpoints stay
        # key-compatible.
        self.register_buffer("folding_seed", seed, persistent=False)

    @staticmethod
    def _coarse_decoder(latent_dim, num_coarse):
        """Build the 3-layer MLP mapping a global feature to num_coarse xyz points."""
        return nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * num_coarse)
        )

    @staticmethod
    def _folding_head():
        """Build the point-wise conv stack mapping [xyz, seed] (5 ch) to a 3-ch offset."""
        return nn.Sequential(
            nn.Conv1d(3 + 2, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

    def forward(self, xyz):
        """Run the encoder and both decoder branches.

        Args:
            xyz: Tensor of shape ``(B, N, 3)``.

        Returns:
            Tuple of contiguous tensors
            ``(coarse_slice, coarse_shape, fine_shape, fine_id)`` with shapes
            ``(B, num_coarse, 3)``, ``(B, num_coarse, 3)``,
            ``(B, num_dense, 3)`` and ``(B, num_dense, 1)``.
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        # decoder: two coarse clouds from the same global feature
        coarse_slice = self.mlp_slice(feature_global).reshape(-1, self.num_coarse, 3)  # (B, num_coarse, 3)
        coarse_shape = self.mlp_shape(feature_global).reshape(-1, self.num_coarse, 3)  # (B, num_coarse, 3)

        # Repeat every coarse shape point S = grid_size ** 2 times.
        point_feat = coarse_shape.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)  # (B, 3, num_dense)

        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_dense)

        feat = torch.cat([point_feat, seed], dim=1)  # (B, 3+2, num_dense)

        # Offset squashed to (-0.1, 0.1), then added to the repeated coarse point.
        fine_shape = (torch.sigmoid(self.final_conv(feat)) * 2 - 1) * 0.1 + point_feat  # (B, 3, num_dense)

        fine_id = self.final_mlp(fine_shape).transpose(1, 2)  # (B, num_dense, 1)

        return (coarse_slice.contiguous(),
                coarse_shape.contiguous(),
                fine_shape.transpose(1, 2).contiguous(),
                fine_id.contiguous())

class SliceDiscriminator(nn.Module):
    """Score a point cloud with a point-wise conv stack and a max over points.

    Each point is mapped independently (kernel size 1) from 3 channels to a
    single value; the largest per-point value is returned for each cloud.
    Because the last two layers are ReLU then Sigmoid, scores lie in [0.5, 1].
    """

    def __init__(self):
        super().__init__()
        layers = [
            # 3 -> 128 -> 3 channels, each with batch norm and ReLU
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 3, 1),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            # 3 -> 1 channel score
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid(),
        ]
        self.discriminator = nn.Sequential(*layers)

    def forward(self, coarse_slice):
        """Return one score per cloud.

        Args:
            coarse_slice: Tensor of shape ``(B, N, 3)``.

        Returns:
            Contiguous tensor of shape ``(B, 1, 1)``.
        """
        channels_first = coarse_slice.transpose(2, 1)      # (B, 3, N)
        per_point = self.discriminator(channels_first)     # (B, 1, N)
        pooled = per_point.max(dim=2, keepdim=True).values  # (B, 1, 1)
        return pooled.contiguous()


class PCN3Brunch(nn.Module):
    """PCN variant with three coarse decoders sharing one encoder.

    Based on "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf)

    A shared PCN encoder feeds three independent MLP decoders (``mlp_slice``,
    ``mlp_rotate_slice`` and ``mlp_shape``). Only the ``mlp_shape`` output is
    refined by the folding head, whose offset is bounded to (-0.1, 0.1) with a
    sigmoid. A 1x1 convolution head (``final_mlp``) then yields one scalar per
    fine point.

    Args:
        num_dense: Number of points in the fine output (default 16384).
            Must be divisible by ``grid_size ** 2``.
        latent_dim: Width of the global feature vector (default 1024).
        grid_size: Side length of the square folding grid (default 4).

    Attributes:
        num_coarse: ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: Non-persistent buffer of shape ``(1, 2, grid_size ** 2)``
            with grid coordinates in [-0.05, 0.05]; it moves with the module
            and is not stored in ``state_dict``.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        # Submodules are created in the same order and under the same names as
        # before so parameter ordering and state_dict keys are unchanged.
        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        self.mlp_slice = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_rotate_slice = self._coarse_decoder(self.latent_dim, self.num_coarse)

        self.final_conv = self._folding_head()

        # NOTE(review): ReLU followed by Sigmoid confines the output to
        # [0.5, 1]. Left unchanged because trained weights may depend on it.
        self.final_mlp = nn.Sequential(
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2)  # (1, 2, S)

        # Buffer instead of a hard .cuda(): works on CPU-only hosts and follows
        # the module across devices. Non-persistent so checkpoints stay
        # key-compatible.
        self.register_buffer("folding_seed", seed, persistent=False)

    @staticmethod
    def _coarse_decoder(latent_dim, num_coarse):
        """Build the 3-layer MLP mapping a global feature to num_coarse xyz points."""
        return nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * num_coarse)
        )

    @staticmethod
    def _folding_head():
        """Build the point-wise conv stack mapping [xyz, seed] (5 ch) to a 3-ch offset."""
        return nn.Sequential(
            nn.Conv1d(3 + 2, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

    def forward(self, xyz):
        """Run the encoder and the three decoder branches.

        Args:
            xyz: Tensor of shape ``(B, N, 3)``.

        Returns:
            Tuple of contiguous tensors
            ``(coarse_slice, rotate_slice, coarse_shape, fine_shape, fine_id)``.
            The first three have shape ``(B, num_coarse, 3)``; ``fine_shape``
            is ``(B, num_dense, 3)`` and ``fine_id`` is ``(B, num_dense, 1)``.
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        # decoder: three coarse clouds from the same global feature
        coarse_slice = self.mlp_slice(feature_global).reshape(-1, self.num_coarse, 3)
        rotate_slice = self.mlp_rotate_slice(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_shape = self.mlp_shape(feature_global).reshape(-1, self.num_coarse, 3)

        # Repeat every coarse shape point S = grid_size ** 2 times.
        point_feat = coarse_shape.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)  # (B, 3, num_dense)

        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_dense)

        feat = torch.cat([point_feat, seed], dim=1)  # (B, 3+2, num_dense)

        # Offset squashed to (-0.1, 0.1), then added to the repeated coarse point.
        fine_shape = (torch.sigmoid(self.final_conv(feat)) * 2 - 1) * 0.1 + point_feat  # (B, 3, num_dense)

        fine_id = self.final_mlp(fine_shape).transpose(1, 2)  # (B, num_dense, 1)

        return (coarse_slice.contiguous(),
                rotate_slice.contiguous(), coarse_shape.contiguous(),
                fine_shape.transpose(1, 2).contiguous(), fine_id.contiguous())

class PCN6Brunch(nn.Module):
    """PCN variant with a whole-shape branch plus six per-component branches.

    Based on "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf)

    A shared PCN encoder produces one global feature vector. Nine independent
    MLP decoders turn it into coarse point clouds: ``slice``,
    ``rotate_slice``, the whole ``shape`` and six components (``lv``, ``rv``,
    ``aro``, ``la``, ``ra``, ``myo`` -- presumably cardiac structures; confirm
    against the training code). The whole shape and each component are then
    refined by their own folding head, whose offset is bounded to (-0.1, 0.1)
    with a sigmoid.

    Args:
        num_dense: Number of fine points per refined cloud (default 16384).
            Must be divisible by ``grid_size ** 2``.
        latent_dim: Width of the global feature vector (default 1024).
        grid_size: Side length of the square folding grid (default 4).

    Attributes:
        num_coarse: ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: Non-persistent buffer of shape ``(1, 2, grid_size ** 2)``
            with grid coordinates in [-0.05, 0.05]; it moves with the module
            and is not stored in ``state_dict``.
        final_mlp: Registered so existing checkpoints keep loading, but not
            evaluated in ``forward`` (its result was never returned).
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        # Submodules keep their original names and registration order so
        # parameter ordering and state_dict keys are unchanged.
        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        # Coarse decoders. mlp_lv_shape used to be constructed twice, the
        # second instance silently replacing the first; it is built once now.
        self.mlp_slice = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_rotate_slice = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_lv_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_rv_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_aro_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_la_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_ra_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)
        self.mlp_myo_shape = self._coarse_decoder(self.latent_dim, self.num_coarse)

        # One folding head for the whole shape and one per component.
        self.final_conv = self._folding_head()
        self.final_conv_lv = self._folding_head()
        self.final_conv_rv = self._folding_head()
        self.final_conv_ra = self._folding_head()
        self.final_conv_la = self._folding_head()
        self.final_conv_aro = self._folding_head()
        self.final_conv_myo = self._folding_head()

        # NOTE(review): ReLU followed by Sigmoid confines the output to
        # [0.5, 1]. Left unchanged because trained weights may depend on it.
        self.final_mlp = nn.Sequential(
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2)  # (1, 2, S)

        # Buffer instead of a hard .cuda(): works on CPU-only hosts and follows
        # the module across devices. Non-persistent so checkpoints stay
        # key-compatible.
        self.register_buffer("folding_seed", seed, persistent=False)

    @staticmethod
    def _coarse_decoder(latent_dim, num_coarse):
        """Build the 3-layer MLP mapping a global feature to num_coarse xyz points."""
        return nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * num_coarse)
        )

    @staticmethod
    def _folding_head():
        """Build the point-wise conv stack mapping [xyz, seed] (5 ch) to a 3-ch offset."""
        return nn.Sequential(
            nn.Conv1d(3 + 2, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

    def _fold(self, coarse, seed, head):
        """Refine a (B, num_coarse, 3) cloud into (B, 3, num_dense) with ``head``."""
        # Repeat every coarse point S = grid_size ** 2 times.
        point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)    # (B, 3, num_dense)
        feat = torch.cat([point_feat, seed], dim=1)                               # (B, 3+2, num_dense)
        # Offset squashed to (-0.1, 0.1), then added to the repeated coarse point.
        return (torch.sigmoid(head(feat)) * 2 - 1) * 0.1 + point_feat

    def forward(self, xyz):
        """Run the encoder, all coarse decoders and all folding heads.

        Args:
            xyz: Tensor of shape ``(B, N, 3)``.

        Returns:
            A 12-tuple of contiguous tensors, in this order:

            0. ``coarse_slice``   ``(B, num_coarse, 3)``
            1. ``rotate_slice``   ``(B, num_coarse, 3)``
            2. ``coarse_shape``   ``(B, num_coarse, 3)``
            3. refined ``lv``     ``(B, num_dense, 3)``
            4. refined ``rv``     ``(B, num_dense, 3)``
            5. refined ``aro``    ``(B, num_dense, 3)``
            6. refined ``la``     ``(B, num_dense, 3)``
            7. refined ``ra``     ``(B, num_dense, 3)``
            8. refined ``myo``    ``(B, num_dense, 3)``
            9. items 3-8 concatenated along the point axis,
               ``(B, 6 * num_dense, 3)``
            10. refined whole shape ``(B, num_dense, 3)``
            11. global feature vector ``(B, latent_dim)``
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_vector = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        # decoder: coarse clouds, all from the same global feature
        coarse_slice = self.mlp_slice(feature_vector).reshape(-1, self.num_coarse, 3)
        rotate_slice = self.mlp_rotate_slice(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_shape = self.mlp_shape(feature_vector).reshape(-1, self.num_coarse, 3)

        coarse_lv = self.mlp_lv_shape(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_rv = self.mlp_rv_shape(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_aro = self.mlp_aro_shape(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_la = self.mlp_la_shape(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_ra = self.mlp_ra_shape(feature_vector).reshape(-1, self.num_coarse, 3)
        coarse_myo = self.mlp_myo_shape(feature_vector).reshape(-1, self.num_coarse, 3)

        # Tile the grid seed once per coarse point; shared by every head.
        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_dense)

        # Folding refinement; results are returned as (B, num_dense, 3).
        fine_shape = self._fold(coarse_shape, seed, self.final_conv).transpose(1, 2)
        fine_lv = self._fold(coarse_lv, seed, self.final_conv_lv).transpose(1, 2)
        fine_la = self._fold(coarse_la, seed, self.final_conv_la).transpose(1, 2)
        fine_rv = self._fold(coarse_rv, seed, self.final_conv_rv).transpose(1, 2)
        fine_ra = self._fold(coarse_ra, seed, self.final_conv_ra).transpose(1, 2)
        fine_aro = self._fold(coarse_aro, seed, self.final_conv_aro).transpose(1, 2)
        fine_myo = self._fold(coarse_myo, seed, self.final_conv_myo).transpose(1, 2)

        fine_shape_component_add = torch.cat(
            [fine_lv, fine_rv, fine_aro, fine_la, fine_ra, fine_myo], dim=1)  # (B, 6*num_dense, 3)

        return (coarse_slice.contiguous(),
                rotate_slice.contiguous(),
                coarse_shape.contiguous(),
                fine_lv.contiguous(),
                fine_rv.contiguous(),
                fine_aro.contiguous(),
                fine_la.contiguous(),
                fine_ra.contiguous(),
                fine_myo.contiguous(),
                fine_shape_component_add.contiguous(),
                fine_shape.contiguous(),
                feature_vector.contiguous())

class PCNNoBrunch(nn.Module):
    """
    PCN-style completion network with independent coarse heads.

    Architecture follows "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf).

    A shared point-wise encoder reduces the input cloud to one global feature
    vector. Nine independent MLP heads decode that vector into coarse point
    sets (slice, rotated slice, whole shape and six named components:
    lv, rv, aro, la, ra, myo -- presumably cardiac structures, confirm against
    the dataset). The coarse whole shape is then refined into a dense cloud by
    a folding decoder that also sees the global feature.

    Args:
        num_dense:  number of points in the dense output (default 16384);
                    must be divisible by ``grid_size ** 2``.
        latent_dim: width of the global feature vector (default 1024).
        grid_size:  side length of the 2-D folding grid (default 4).

    Attributes:
        num_dense:    as passed in.
        latent_dim:   as passed in.
        grid_size:    as passed in.
        num_coarse:   ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: ``(1, 2, grid_size ** 2)`` grid of 2-D offsets in
                      [-0.05, 0.05]; a non-persistent buffer, so it follows
                      ``.to()`` / ``.cuda()`` but is not part of ``state_dict``.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        # Attribute names and creation order are kept as in the original so
        # that state_dict keys and parameter order stay checkpoint-compatible.
        # (mlp_lv_shape used to be built twice; only one instance ever
        # survived, so it is now created once.)
        self.mlp_slice = self._make_coarse_mlp()
        self.mlp_rotate_slice = self._make_coarse_mlp()
        self.mlp_shape = self._make_coarse_mlp()
        self.mlp_lv_shape = self._make_coarse_mlp()
        self.mlp_rv_shape = self._make_coarse_mlp()
        self.mlp_aro_shape = self._make_coarse_mlp()
        self.mlp_la_shape = self._make_coarse_mlp()
        self.mlp_ra_shape = self._make_coarse_mlp()
        self.mlp_myo_shape = self._make_coarse_mlp()

        # Input is [global feature | coarse xyz | 2-D seed]; the channel count
        # must track latent_dim or forward() breaks for latent_dim != 1024.
        self.final_conv = nn.Sequential(
            nn.Conv1d(self.latent_dim + 3 + 2, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

        # Not evaluated in forward(); kept so existing checkpoints that contain
        # its weights still load with strict=True.
        self.final_mlp = nn.Sequential(
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)

        # A buffer (instead of a hard ``.cuda()`` call) lets the model be built
        # on CPU-only hosts and keeps the seed on the same device as the
        # weights. ``persistent=False`` leaves the state_dict layout untouched.
        self.register_buffer(
            "folding_seed",
            torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2),  # (1, 2, S)
            persistent=False,
        )

    def _make_coarse_mlp(self):
        """Build one decoder head mapping the latent vector to 3 * num_coarse values."""
        return nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * self.num_coarse)
        )

    def forward(self, xyz):
        """
        Encode ``xyz`` and decode coarse / dense point sets.

        Args:
            xyz: input cloud of shape (B, N, 3).

        Returns:
            A 12-tuple of contiguous tensors:
                coarse_slice             (B, num_coarse, 3)
                rotate_slice             (B, num_coarse, 3)
                coarse_shape             (B, num_coarse, 3)
                coarse_lv_pred           (B, 3, num_coarse)
                coarse_rv_pred           (B, 3, num_coarse)
                coarse_aro_pred          (B, 3, num_coarse)
                coarse_la_pred           (B, 3, num_coarse)
                coarse_ra_pred           (B, 3, num_coarse)
                coarse_myo_pred          (B, 3, num_coarse)
                fine_shape_component_add (B, 18, num_coarse)
                fine_shape               (B, num_dense, 3)
                feature_vector           (B, latent_dim)
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        feature_vector = feature_global

        # decoder: every head reads the same global feature
        coarse_slice = self.mlp_slice(feature_global).reshape(-1, self.num_coarse, 3)
        rotate_slice = self.mlp_rotate_slice(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_shape = self.mlp_shape(feature_global).reshape(-1, self.num_coarse, 3)

        coarse_lv_pred = self.mlp_lv_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_rv_pred = self.mlp_rv_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_aro_pred = self.mlp_aro_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_la_pred = self.mlp_la_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_ra_pred = self.mlp_ra_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_myo_pred = self.mlp_myo_shape(feature_global).reshape(-1, self.num_coarse, 3)

        # Repeat every coarse point S = grid_size ** 2 times.
        point_feat = coarse_shape.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)  # (B, 3, num_fine)

        # Tile the 2-D grid once per coarse point.
        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_fine)

        feature_global = feature_global.unsqueeze(2).expand(-1, -1, self.num_dense)  # (B, latent_dim, num_fine)
        feat = torch.cat([feature_global, point_feat, seed], dim=1)  # (B, latent_dim+3+2, num_fine)

        # Residual refinement around the repeated coarse points.
        fine_shape = self.final_conv(feat) + point_feat  # (B, 3, num_fine)

        coarse_lv_pred = coarse_lv_pred.transpose(1, 2)
        coarse_rv_pred = coarse_rv_pred.transpose(1, 2)
        coarse_la_pred = coarse_la_pred.transpose(1, 2)
        coarse_ra_pred = coarse_ra_pred.transpose(1, 2)
        coarse_myo_pred = coarse_myo_pred.transpose(1, 2)
        coarse_aro_pred = coarse_aro_pred.transpose(1, 2)

        # NOTE(review): the components are already (B, 3, num_coarse) here, so
        # dim=1 stacks channels -> (B, 18, num_coarse) rather than stacking
        # points. Kept as-is because callers receive this shape today; confirm
        # it is what the loss expects.
        fine_shape_component_add = torch.cat(
            [coarse_lv_pred, coarse_rv_pred, coarse_aro_pred, coarse_la_pred, coarse_ra_pred, coarse_myo_pred], dim=1)

        return (coarse_slice.contiguous(),
                rotate_slice.contiguous(),
                coarse_shape.contiguous(),
                coarse_lv_pred.contiguous(),
                coarse_rv_pred.contiguous(),
                coarse_aro_pred.contiguous(),
                coarse_la_pred.contiguous(),
                coarse_ra_pred.contiguous(),
                coarse_myo_pred.contiguous(),
                fine_shape_component_add.contiguous(),
                fine_shape.transpose(1, 2).contiguous(),
                feature_vector.contiguous())

class PCN_split_6Brunch(nn.Module):
    """
    PCN-style completion network with six per-component coarse branches.

    Architecture follows "PCN: Point Cloud Completion Network"
    (https://arxiv.org/pdf/1808.00671.pdf).

    A shared point-wise encoder reduces the input cloud to one global feature
    vector. Independent MLP heads decode that vector into a coarse whole shape
    and six named components (lv, rv, aro, la, ra, myo -- presumably cardiac
    structures, confirm against the dataset). The coarse whole shape is
    refined into a dense cloud by a folding decoder whose per-point offset is
    bounded to +/-0.1, and a 1x1 convolution scores every dense point.

    Args:
        num_dense:  number of points in the dense output (default 16384);
                    must be divisible by ``grid_size ** 2``.
        latent_dim: width of the global feature vector (default 1024).
        grid_size:  side length of the 2-D folding grid (default 4).

    Attributes:
        num_dense:    as passed in.
        latent_dim:   as passed in.
        grid_size:    as passed in.
        num_coarse:   ``num_dense // grid_size ** 2`` (1024 with the defaults).
        folding_seed: ``(1, 2, grid_size ** 2)`` grid of 2-D offsets in
                      [-0.05, 0.05]; a non-persistent buffer, so it follows
                      ``.to()`` / ``.cuda()`` but is not part of ``state_dict``.
    """

    def __init__(self, num_dense=16384, latent_dim=1024, grid_size=4):
        super().__init__()

        self.num_dense = num_dense
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        assert self.num_dense % self.grid_size ** 2 == 0, \
            "num_dense must be divisible by grid_size ** 2"

        self.num_coarse = self.num_dense // (self.grid_size ** 2)

        self.first_conv = nn.Sequential(
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )

        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.latent_dim, 1)
        )

        # Attribute names and creation order are kept as in the original so
        # that state_dict keys and parameter order stay checkpoint-compatible.
        # mlp_slice / mlp_rotate_slice are not evaluated in forward() but are
        # retained for that same reason. (mlp_lv_shape used to be built twice;
        # only one instance ever survived, so it is now created once.)
        self.mlp_slice = self._make_coarse_mlp()
        self.mlp_rotate_slice = self._make_coarse_mlp()
        self.mlp_shape = self._make_coarse_mlp()
        self.mlp_lv_shape = self._make_coarse_mlp()
        self.mlp_rv_shape = self._make_coarse_mlp()
        self.mlp_aro_shape = self._make_coarse_mlp()
        self.mlp_la_shape = self._make_coarse_mlp()
        self.mlp_ra_shape = self._make_coarse_mlp()
        self.mlp_myo_shape = self._make_coarse_mlp()

        # Input is [coarse xyz | 2-D seed]; the global feature is not fed in.
        self.final_conv = nn.Sequential(
            nn.Conv1d(3 + 2, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )

        # NOTE(review): ReLU followed by Sigmoid confines the output to
        # [0.5, 1]; left unchanged because trained weights depend on it --
        # confirm this is intended.
        self.final_mlp = nn.Sequential(
            nn.Conv1d(3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
        )

        axis = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float)
        a = axis.view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = axis.view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)

        # A buffer (instead of a hard ``.cuda()`` call) lets the model be built
        # on CPU-only hosts and keeps the seed on the same device as the
        # weights. ``persistent=False`` leaves the state_dict layout untouched.
        self.register_buffer(
            "folding_seed",
            torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2),  # (1, 2, S)
            persistent=False,
        )

    def _make_coarse_mlp(self):
        """Build one decoder head mapping the latent vector to 3 * num_coarse values."""
        return nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * self.num_coarse)
        )

    def forward(self, xyz):
        """
        Encode ``xyz`` and decode coarse / dense point sets.

        Args:
            xyz: input cloud of shape (B, N, 3).

        Returns:
            A 10-tuple of contiguous tensors:
                coarse_shape         (B, num_coarse, 3)
                coarse_lv_pred       (B, num_coarse, 3)
                coarse_rv_pred       (B, num_coarse, 3)
                coarse_aro_pred      (B, num_coarse, 3)
                coarse_la_pred       (B, num_coarse, 3)
                coarse_ra_pred       (B, num_coarse, 3)
                coarse_myo_pred      (B, num_coarse, 3)
                fine_shape_component (B, 6 * num_coarse, 3), the six
                                     components concatenated along points
                fine_shape           (B, num_dense, 3)
                fine_id              (B, num_dense, 1)
        """
        B, N, _ = xyz.shape

        # encoder
        feature = self.first_conv(xyz.transpose(2, 1))  # (B, 256, N)
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B, 256, 1)
        feature = torch.cat([feature_global.expand(-1, -1, N), feature], dim=1)  # (B, 512, N)
        feature = self.second_conv(feature)  # (B, latent_dim, N)
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B, latent_dim)

        # decoder: every head reads the same global feature
        coarse_shape = self.mlp_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_lv_pred = self.mlp_lv_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_rv_pred = self.mlp_rv_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_aro_pred = self.mlp_aro_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_la_pred = self.mlp_la_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_ra_pred = self.mlp_ra_shape(feature_global).reshape(-1, self.num_coarse, 3)
        coarse_myo_pred = self.mlp_myo_shape(feature_global).reshape(-1, self.num_coarse, 3)

        # Repeat every coarse point S = grid_size ** 2 times.
        point_feat = coarse_shape.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1)  # (B, num_coarse, S, 3)
        point_feat = point_feat.reshape(-1, self.num_dense, 3).transpose(2, 1)  # (B, 3, num_fine)

        # Tile the 2-D grid once per coarse point.
        seed = self.folding_seed.unsqueeze(2).expand(B, -1, self.num_coarse, -1)  # (B, 2, num_coarse, S)
        seed = seed.reshape(B, -1, self.num_dense)  # (B, 2, num_fine)

        feat = torch.cat([point_feat, seed], dim=1)  # (B, 2+3, num_fine)

        # Squash the predicted offset into (-0.1, 0.1) before adding it to the
        # repeated coarse points.
        fine_shape = (torch.sigmoid(self.final_conv(feat)) * 2 - 1) * 0.1 + point_feat  # (B, 3, num_fine)
        fine_shape_component = torch.cat(
            [coarse_lv_pred, coarse_rv_pred, coarse_aro_pred, coarse_la_pred, coarse_ra_pred, coarse_myo_pred], dim=1)

        fine_id = self.final_mlp(fine_shape).transpose(1, 2)  # (B, num_fine, 1)

        return (coarse_shape.contiguous(),
                coarse_lv_pred.contiguous(),
                coarse_rv_pred.contiguous(),
                coarse_aro_pred.contiguous(),
                coarse_la_pred.contiguous(),
                coarse_ra_pred.contiguous(),
                coarse_myo_pred.contiguous(),
                fine_shape_component.contiguous(),
                fine_shape.transpose(1, 2).contiguous(),
                fine_id.contiguous())