import pdb, sys
sys.path.append('utils/point_cloud_query')
sys.path.append('utils/pointnet2')
from pointnet2_modules import PointnetFPModule,PointnetSAModule

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple




def get_activation(activation):
    """Return a freshly constructed activation module for a (case-insensitive) name.

    Supported names: 'relu', 'leakyrelu', 'sigmoid', 'softplus', 'gelu',
    'selu' and 'mish'. Activations that support it are built with
    ``inplace=True``.

    Raises:
        Exception: if the name is not one of the supported activations.
    """
    # Factories (not instances) so that only the requested module is built.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leakyrelu': lambda: nn.LeakyReLU(inplace=True),
        'sigmoid': lambda: nn.Sigmoid(),
        'softplus': lambda: nn.Softplus(),
        'gelu': lambda: nn.GELU(),
        'selu': lambda: nn.SELU(inplace=True),
        'mish': lambda: nn.Mish(inplace=True),
    }
    build = factories.get(activation.lower())
    if build is None:
        raise Exception("Activation Function Error")
    return build()


def get_norm(norm, width, num_groups=None):
    """Return a normalization module for ``width`` features.

    Args:
        norm: one of 'LN' (LayerNorm), 'BN' (BatchNorm1d),
            'IN' (InstanceNorm1d) or 'GN' (GroupNorm).
        width: number of features / channels to normalize.
        num_groups: number of groups for 'GN' only. When omitted, the largest
            divisor of ``width`` that does not exceed 32 is used, so the
            channel count is always evenly divisible by the group count.

    Raises:
        Exception: if ``norm`` is not one of the supported names.
    """
    if norm == 'LN':
        return nn.LayerNorm(width)
    elif norm == 'BN':
        return nn.BatchNorm1d(width)
    elif norm == 'IN':
        return nn.InstanceNorm1d(width)
    elif norm == 'GN':
        # nn.GroupNorm takes (num_groups, num_channels); passing only the
        # width raises a TypeError, so the group count must be supplied.
        if num_groups is None:
            num_groups = next(
                g for g in range(min(32, width), 0, -1) if width % g == 0
            )
        return nn.GroupNorm(num_groups, width)
    else:
        raise Exception("Normalization Layer Error")

class NeuralPCI_Layer(torch.nn.Module):
    """A single fully connected layer, optionally followed by a norm and an activation.

    The sub-modules are stored in ``self.layer`` (an ``nn.Sequential``) in the
    order Linear -> norm -> activation.
    """

    def __init__(self, dim_in, dim_out, norm=None, act_fn=None):
        """
        dim_in: int, input feature size
        dim_out: int, output feature size
        norm: normalization name understood by get_norm, or a falsy value for none
        act_fn: activation name understood by get_activation, or a falsy value for none
        """
        super().__init__()
        stages = [nn.Linear(dim_in, dim_out)]
        if norm:
            stages += [get_norm(norm, dim_out)]
        if act_fn:
            stages += [get_activation(act_fn)]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the layer to ``x`` and return the result."""
        return self.layer(x)


class NeuralPCI_Block(torch.nn.Module):
    """A stack of ``depth`` constant-width Linear (+ norm, + activation) stages.

    All stages are flattened into one ``nn.Sequential`` stored in ``self.mlp``.
    With ``depth == 0`` the block is an empty Sequential and returns its input.
    """

    def __init__(self, depth, width, norm=None, act_fn=None):
        """
        depth: int, number of Linear stages
        width: int, feature size of every stage (input and output)
        norm: normalization name understood by get_norm, or a falsy value for none
        act_fn: activation name understood by get_activation, or a falsy value for none
        """
        super().__init__()

        def build_stage():
            # One Linear, then the optional norm and activation, in that order.
            stage = [nn.Linear(width, width)]
            if norm:
                stage.append(get_norm(norm, width))
            if act_fn:
                stage.append(get_activation(act_fn))
            return stage

        stages = []
        for _ in range(depth):
            stages.extend(build_stage())
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every stage and return the result."""
        return self.mlp(x)
    
class AttentionFusion(nn.Module):
    """Fuse two channel-first feature maps with a single attention step.

    Queries come from ``x_hid``; keys and values come from ``x``. The attended
    values are added to ``x_hid`` and passed through a 1x1 convolution,
    batch norm and ReLU.
    """

    def __init__(self, feature_dim):
        """feature_dim: int, channel count shared by both inputs."""
        super().__init__()
        self.q_conv = nn.Conv1d(feature_dim, feature_dim // 2, 1, bias=False)
        self.k_conv = nn.Conv1d(feature_dim, feature_dim // 2, 1, bias=False)
        self.v_conv = nn.Conv1d(feature_dim, feature_dim, 1)
        self.trans_conv = nn.Conv1d(feature_dim, feature_dim, 1)
        self.after_norm = nn.BatchNorm1d(feature_dim)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, x_hid):
        """
        x: tensor, [b, c, n], source of keys and values
        x_hid: tensor, [b, c, n], source of queries and of the residual
        output: tensor, [b, n, c] (point-major)
        """
        query = self.q_conv(x_hid).transpose(1, 2)  # [b, n, c//2]
        key = self.k_conv(x)                        # [b, c//2, n]
        value = self.v_conv(x)                      # [b, c, n]

        # Softmax over the last axis of the [b, n, n] score matrix.
        attention = self.softmax(torch.bmm(query, key))
        attended = torch.bmm(value, attention)      # [b, c, n]

        # Residual sum done directly in channel-first layout; this is the same
        # element-wise sum as permuting both operands, adding, and permuting back.
        fused = self.trans_conv(x_hid + attended)
        fused = self.act(self.after_norm(fused))    # [b, c, n]
        return fused.transpose(1, 2)                # [b, n, c]


class Gate_AttentionFusion(nn.Module):
    """Attention fusion with a learned sigmoid gate between hidden and attended features.

    Queries come from ``x_hid``; keys and values come from ``x``. A gate
    computed from ``x`` and the attended values blends ``x_hid`` with the
    attended values before a 1x1 convolution, batch norm and ReLU.
    """

    def __init__(self, x_dim, hid_dim):
        """
        x_dim: int, channel count of ``x``
        hid_dim: int, channel count of ``x_hid`` and of the output
        """
        super().__init__()
        self.q_conv = nn.Conv1d(hid_dim, hid_dim, 1, bias=False)
        self.k_conv = nn.Conv1d(x_dim, hid_dim, 1, bias=False)
        self.v_conv = nn.Conv1d(x_dim, hid_dim, 1)
        self.trans_conv = nn.Conv1d(hid_dim, hid_dim, 1)
        self.after_norm = nn.BatchNorm1d(hid_dim)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

        self.gate_fc = nn.Conv1d(x_dim + hid_dim, hid_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, x_hid):
        """
        x: tensor, [b, x_dim, n]
        x_hid: tensor, [b, hid_dim, n]
        output: tensor, [b, hid_dim, n] (channel-first)
        """
        query = self.q_conv(x_hid).transpose(1, 2)  # [b, n, hid_dim]
        key = self.k_conv(x)                        # [b, hid_dim, n]
        value = self.v_conv(x)                      # [b, hid_dim, n]

        attention = self.softmax(torch.bmm(query, key))  # [b, n, n]
        attended = torch.bmm(value, attention)           # [b, hid_dim, n]

        # Per-channel, per-point mixing weight in (0, 1).
        gate_logits = self.gate_fc(torch.cat([x, attended], dim=1))
        gate = self.sigmoid(gate_logits)                 # [b, hid_dim, n]

        blended = x_hid * gate + attended * (1 - gate)
        return self.act(self.after_norm(self.trans_conv(blended)))

class GraphLearner(nn.Module):
    """Learn a dense N x N adjacency over a set of node features.

    Two modes are supported:

    * ``'bernoulli'``: an MLP maps every ordered node pair to an edge
      probability; ``forward`` returns the resulting distribution.
    * ``'sparse_sampling'``: an MLP scores every ordered node pair and
      ``forward`` returns a 0/1 mask that keeps the ``k`` best-scoring
      entries of each row.
    """

    def __init__(self, feature_dim, distribution_type='bernoulli', k=16):
        """
        feature_dim: int, size of each node feature vector
        distribution_type: 'bernoulli' or 'sparse_sampling'
        k: int, number of neighbours kept per node in 'sparse_sampling' mode
           (clamped to the number of nodes at call time); unused otherwise

        Raises ValueError for an unknown distribution type or a non-positive k.
        """
        super().__init__()
        if k < 1:
            raise ValueError(f'k must be a positive integer, got {k}')
        self.feature_dim = feature_dim
        self.distribution_type = distribution_type
        # forward() reads self.k in the sparse branch; it was previously never
        # set, so that branch always failed with AttributeError.
        self.k = k

        if distribution_type == 'bernoulli':
            self.prob_mlp = nn.Sequential(
                nn.Linear(2*feature_dim, feature_dim),
                nn.ReLU(inplace=True),
                nn.Linear(feature_dim, 1),
                nn.Sigmoid()
            )
        elif distribution_type == 'sparse_sampling':
            self.attn_mlp = nn.Sequential(
                nn.Linear(2*feature_dim, feature_dim),
                nn.ReLU(inplace=True),
                nn.Linear(feature_dim, 1)
            )
        else:
            raise ValueError(f'Unsupported distribution type: {distribution_type}')

    def forward(self, X):
        """
        X: tensor, [N, feature_dim], node features
        output:
            'bernoulli'       -> torch.distributions.Independent over an [N, N]
                                 Bernoulli probability matrix
            'sparse_sampling' -> tensor, [N, N], 0/1 mask with min(k, N) ones per row
        """
        N = X.shape[0]
        # X_pair[i, j] = concat(X[i], X[j])  ->  [N, N, 2 * feature_dim]
        X_pair = torch.cat([X.unsqueeze(1).expand(-1,N,-1), X.unsqueeze(0).expand(N,-1,-1)], dim=-1)

        if self.distribution_type == 'bernoulli':
            probs = self.prob_mlp(X_pair).squeeze(-1)  # [N, N]
            return torch.distributions.Independent(torch.distributions.Bernoulli(probs), 1)

        elif self.distribution_type == 'sparse_sampling':
            attn_scores = self.attn_mlp(X_pair).squeeze(-1)  # [N, N]
            attn_scores = attn_scores.transpose(0, 1)  # [N, N]
            # topk fails when k exceeds the row length, so clamp to N.
            k = min(self.k, N)
            top_k_idx = attn_scores.topk(k, dim=-1)[1]
            top_k_mask = torch.zeros_like(attn_scores).scatter(-1, top_k_idx, 1.0)
            # The mask only holds 0s and 1s, so this draw reproduces the mask.
            # NOTE(review): the mask is built from constants, so no gradient
            # reaches attn_mlp through it - confirm this is intended.
            sampled_mask = torch.bernoulli(top_k_mask)
            return sampled_mask

class GraphConv(nn.Module):
    """Graph convolution computing ``adj @ (x @ weight) + bias``."""

    def __init__(self, in_channels, out_channels, bias=True):
        """
        in_channels: int, input feature size per node
        out_channels: int, output feature size per node
        bias: bool, whether to learn an additive bias (registered as None otherwise)
        """
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels

        # Storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = nn.Parameter(torch.empty(in_channels, out_channels))
        bias_param = nn.Parameter(torch.empty(out_channels)) if bias else None
        self.register_parameter('bias', bias_param)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform initialise the weight and zero the bias (if any)."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, x, adj):
        """
        x: tensor, [..., in_channels], node features
        adj: tensor, adjacency matrix multiplied onto the projected features
        output: tensor, ``adj @ (x @ weight)`` plus the bias when present
        """
        projected = x @ self.weight
        aggregated = adj @ projected
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

# class GaussianPointCloud(nn.Module):
#     def __init__(self, input_dim, num_gaussians):
#         """
#         input_dim: int, input feature dimension
#         num_gaussians: int, number of Gaussian components
#         """
#         super(GaussianPointCloud, self).__init__()
#         self.num_gaussians = num_gaussians
        
#         self.mean_conv = nn.Conv1d(input_dim, num_gaussians*3, 1)
#         self.cov_conv = nn.Conv1d(input_dim, num_gaussians*9, 1)
        
#     def forward(self, x):
#         """
#         x: tensor, [B, input_dim, N]
#         output: tuple(tensor), Gaussian mean and covariance matrix, 
#                 each of shape [B, 3, num_gaussians] and [B, 3, 3, num_gaussians]
#         """
#         batch_size = x.shape[0]
#         num_points = x.shape[2]
                
#         means = self.mean_conv(x) # [B, 3*num_gaussians, N]
#         means = means.view(batch_size, 3, self.num_gaussians, num_points)
#         means = means.transpose(2, 3).contiguous()
#         means = means.mean(dim=2, keepdim=True) # [B, 3, 1, num_gaussians]
        
        
#         S = self.cov_conv(x) # [B, 9*num_gaussians, N]  
#         S = S.view(batch_size, 9, self.num_gaussians, num_points)
#         S = S.transpose(2, 3).contiguous()
#         S = S.mean(dim=2) # [B, 9, num_gaussians]
        
#         cov = torch.zeros(batch_size, 3, 3, self.num_gaussians).to(x.device)
#         cov[:, 0, 0] = S[:, 0]
#         cov[:, 1, 1] = S[:, 4]
#         cov[:, 2, 2] = S[:, 8]
#         cov[:, 0, 1] = cov[:, 1, 0] = S[:, 1]
#         cov[:, 0, 2] = cov[:, 2, 0] = S[:, 2]
#         cov[:, 1, 2] = cov[:, 2, 1] = S[:, 5]
        
#         return means.squeeze(2), cov
class GaussianPointCloud(nn.Module):
    """Summarise per-point features as a fixed number of 3D Gaussians.

    Means are predicted per point by a 1x1 convolution and averaged over the
    points. Covariances do not depend on the input: they come from a learned
    LogCholeskyLayer and are shared across the batch.
    """

    def __init__(self, input_dim, num_gaussians):
        """
        input_dim: int, channel count of the input features
        num_gaussians: int, number of Gaussian components
        """
        super().__init__()
        self.num_gaussians = num_gaussians

        self.mean_conv = nn.Conv1d(input_dim, num_gaussians * 3, 1)
        # Covariance is a free parameter, not a function of the input.
        self.log_cholesky = LogCholeskyLayer(num_gaussians)

    def forward(self, x):
        """
        x: tensor, [B, input_dim, N]
        output:
            means: tensor, [B, 3, num_gaussians]
            cov: tensor, [B, num_gaussians, 3, 3] (expanded view over the batch)
        """
        n_batch = x.size(0)
        n_points = x.size(2)

        # [B, 3 * G, N] -> [B, 3, G, N] -> [B, 3, N, G]
        per_point = self.mean_conv(x).view(n_batch, 3, self.num_gaussians, n_points)
        per_point = per_point.transpose(2, 3).contiguous()
        # Average over the point axis, then drop it: [B, 3, G]
        means = per_point.mean(dim=2, keepdim=True).squeeze(2)

        # [G, 3, 3] -> [B, G, 3, 3] without copying.
        cov = self.log_cholesky().unsqueeze(0).expand(n_batch, -1, -1, -1)
        return means, cov
    

class DeformationField(nn.Module):
    """MLP that predicts, per point, a mean offset and a PSD covariance offset.

    Each point's input row is its own feature vector, the flattened Gaussian
    means (shared by all points of a sample) and the time difference.
    """

    def __init__(self, feature_dim, output_dim):
        """
        feature_dim: int, size of the per-point feature plus the flattened
                     Gaussian means (the time channel is added internally)
        output_dim: int, number of output dimensions (should be 3 for 3D point clouds)
        """
        super(DeformationField, self).__init__()
        # Define the architecture of the MLP (+1 input for the time difference)
        self.fc1 = nn.Linear(feature_dim + 1, feature_dim // 2)
        self.fc2 = nn.Linear(feature_dim // 2, feature_dim // 4)
        self.mean_fc = nn.Linear(feature_dim // 4, output_dim)
        # One output per lower-triangular entry of a Cholesky-style factor.
        self.cov_fc = nn.Linear(feature_dim // 4, output_dim * (output_dim + 1) // 2)

        # Xavier weights and zero biases for every layer.
        for layer in (self.fc1, self.fc2, self.mean_fc, self.cov_fc):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, feat_gaussian, xyz_gaussian, time_diff):
        """
        feat_gaussian: tensor, [B, C, N], per-point features
        xyz_gaussian: tensor, [B, D, M], Gaussian means
        time_diff: tensor expandable to [B, N, T], time difference per point
        output:
            mean_offset: tensor, [B, N, output_dim]
            cov_offset: tensor, [B, N, D, D], symmetric positive semi-definite
        """
        B, C, N = feat_gaussian.shape
        _, D, M = xyz_gaussian.shape

        # Broadcast the flattened Gaussian means to every point: [B, D*M, N]
        xyz_gaussian_flat = xyz_gaussian.reshape(B, D * M, 1).expand(-1, -1, N)

        # Time difference as a channel-first feature: [B, T, N]
        time_diff = time_diff.expand(B, N, -1).transpose(1, 2)

        # [B, C + D*M + T, N]
        combined_features = torch.cat([feat_gaussian, xyz_gaussian_flat, time_diff], dim=1)

        # Move the channel axis last BEFORE flattening so each MLP row holds
        # exactly one point's features. Flattening the channel-first tensor
        # directly would mix values from different points into one row.
        combined_features = combined_features.permute(0, 2, 1).reshape(B * N, -1)

        # Forward pass through the MLP
        x = F.relu(self.fc1(combined_features))
        x = F.relu(self.fc2(x))

        # Output the mean offset
        mean_offset = self.mean_fc(x).view(B, N, -1)

        # Fill a lower-triangular factor L and return L @ L^T.
        chol_elements = self.cov_fc(x).view(B, N, -1)
        rows, cols = torch.tril_indices(row=D, col=D, offset=0, device=chol_elements.device)
        chol = chol_elements.new_zeros(B, N, D, D)
        chol[..., rows, cols] = chol_elements
        cov_offset = torch.matmul(chol, chol.transpose(-1, -2))
        return mean_offset, cov_offset


class SpatioTemporalSA(nn.Module):
    """Point-wise MLP over concatenated coordinates, features and time.

    Each stage is a 1x1 Conv2d followed by BatchNorm1d and ReLU; no
    sampling or grouping is done, so the point count is unchanged.
    """

    def __init__(self, in_channels, mlp_channels, time_dim=1):
        """
        in_channels: int, channels of the concatenated coordinates and features
                     (the first convolution takes in_channels + time_dim)
        mlp_channels: list[int], number of channels in each mlp layer
        time_dim: int, dimension of time input
        """
        super(SpatioTemporalSA, self).__init__()
        self.time_dim = time_dim
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        widths = [in_channels + time_dim, *mlp_channels]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm1d(c_out))

    def forward(self, xyz, points, time):
        """
        xyz: tensor, [B, N, 3], input point coordinates
        points: tensor, [B, C, N], input point features
        time: tensor, [B, N, time_dim], input time values
        output:
            new_xyz: tensor, [B, N, 3], the input coordinates, returned unchanged
            new_points: tensor, [B, C', N], output point features
        """
        # Point-major concatenation: [B, N, 3 + C + time_dim]
        feats = torch.cat([xyz, points.permute(0, 2, 1).contiguous(), time], dim=-1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            # Conv2d needs a trailing spatial axis: [B, F, N, 1]
            channel_first = feats.permute(0, 2, 1).unsqueeze(-1)
            convolved = conv(channel_first).squeeze(-1)  # [B, F', N]
            activated = F.relu(bn(convolved))
            feats = activated.permute(0, 2, 1).contiguous()  # [B, N, F']

        return xyz, feats.permute(0, 2, 1)

class LogCholeskyLayer(nn.Module):
    """Learned 3x3 covariance matrices in log-Cholesky form.

    Each Gaussian owns a lower-triangular factor whose diagonal is
    ``exp(log_diag)`` (hence strictly positive) and whose strictly lower
    entries are ``off_diag``. ``forward`` returns ``L @ L^T``.
    """

    def __init__(self, num_gaussians):
        """num_gaussians: int, number of covariance matrices to learn."""
        super(LogCholeskyLayer, self).__init__()
        self.log_diag = nn.Parameter(torch.randn(num_gaussians, 3))
        self.off_diag = nn.Parameter(torch.randn(num_gaussians, 3 * (3 - 1) // 2))
        self.num_gaussians = num_gaussians

    def forward(self):
        """Return the covariance matrices, tensor of shape [num_gaussians, 3, 3]."""
        factor = torch.zeros(self.num_gaussians, 3, 3, device=self.log_diag.device)

        # Diagonal entries.
        diag = list(range(3))
        factor[:, diag, diag] = self.log_diag.exp()

        # Strictly lower-triangular entries, in tril_indices order.
        rows, cols = torch.tril_indices(row=3, col=3, offset=-1)
        factor[:, rows, cols] = self.off_diag

        return torch.matmul(factor, factor.transpose(-1, -2))

class NeuralPCI(torch.nn.Module):
    def __init__(self, args=None):
        """
        Build the network from a configuration object.

        args: object exposing the attributes read below:
            dim_pc, dim_time, layer_width, act_fn, norm, depth_encode,
            depth_pred, pe_mul, use_rrf (and dim_rrf when it is set),
            zero_init, num_points, and the optional-module flags
            graph_learner, TransL1_Fu, Graph_Fu, Gaussians4D (and n_gaussians
            when it is set) and multi_scale.
            The input width is (dim + dim_time) * pe_mul; posenc triples the
            feature size, so pe_mul is presumably expected to be 3.
        """
        super().__init__()
        self.args = args
        dim_pc = args.dim_pc
        dim_time = args.dim_time
        layer_width = args.layer_width
        act_fn = args.act_fn
        norm = args.norm
        depth_encode = args.depth_encode
        depth_pred = args.depth_pred
        pe_mul = args.pe_mul

        if args.use_rrf:
            dim_rrf = args.dim_rrf
            # Fixed random projection applied to the input coordinates in
            # forward(). Registered as a buffer so it follows .to()/.cuda()
            # with the rest of the module instead of being pinned to the
            # default CUDA device. persistent=False keeps state_dict keys
            # unchanged; as before, the matrix is NOT saved in checkpoints,
            # so reproducing a run requires the same random seed.
            self.register_buffer(
                'transform',
                0.1 * torch.normal(0, 1, size=[dim_pc, dim_rrf]),
                persistent=False,
            )
        else:
            dim_rrf = dim_pc

        # input layer
        self.layer_input = NeuralPCI_Layer(dim_in = (dim_rrf + dim_time) * pe_mul,
                                           dim_out = layer_width,
                                           norm = norm,
                                           act_fn = act_fn
                                           )
        self.hidden_encode = NeuralPCI_Block(depth = depth_encode,
                                             width = layer_width,
                                             norm = norm,
                                             act_fn = act_fn
                                             )

        # optional graph branch on the coarsest level
        if args.graph_learner:
            self.graph_learner = GraphLearner(feature_dim=layer_width, distribution_type='bernoulli')
            self.graph_conv = GraphConv(layer_width, layer_width)
            self.attention_Graph_Fu = AttentionFusion(feature_dim=layer_width)
        if args.TransL1_Fu:
            self.attention_TransL1_Fu = AttentionFusion(feature_dim=layer_width)

        # optional Gaussian deformation branch on the coarsest level
        if args.Gaussians4D:
            self.args_n_gaussians = args.n_gaussians
            self.st_sa = SpatioTemporalSA(in_channels=layer_width+3, mlp_channels=[layer_width//2, layer_width], time_dim=1)
            self.gaussian_pc = GaussianPointCloud(layer_width + 3, args.n_gaussians)
            self.deform_field = DeformationField(layer_width + 3*args.n_gaussians, 3)
            # forward() reshapes the num_points // 16 coarsest points into
            # n_gaussians groups of this size.
            self.num_points_per_gaussian = (args.num_points // 16)//args.n_gaussians

            # +12 = 3 mean values + 9 covariance values appended per point
            self.Gate_fusion_module = Gate_AttentionFusion(layer_width+12, layer_width)

        # hidden layers with PointNet2 multi-scale feature extraction
        if args.multi_scale:
            # Denser clouds use smaller radii and fewer samples per group.
            dense = args.num_points >= 2048
            self.sa1 = PointnetSAModule(
                npoint=args.num_points // 4,
                radius=0.2 if dense else 0.4,
                nsample=8 if dense else 16,
                mlp=[layer_width, layer_width//8, layer_width//8, layer_width//4],
                use_xyz=True,
                bn=True
            )
            self.sa2 = PointnetSAModule(
                npoint=args.num_points // 8,
                radius=0.4 if dense else 0.8,
                nsample=16 if dense else 32,
                mlp=[layer_width//4, layer_width//4, layer_width//4, layer_width//2],
                use_xyz=True,
                bn=True
            )
            self.sa3 = PointnetSAModule(
                npoint=args.num_points // 16,
                radius=0.6 if dense else 1.2,
                nsample=16 if dense else 64,
                mlp=[layer_width//2, layer_width//2, layer_width, layer_width],
                use_xyz=True,
                bn=True
            )

            self.fp3 = PointnetFPModule(mlp=[layer_width + layer_width//2, layer_width, layer_width])
            self.fp2 = PointnetFPModule(mlp=[layer_width + layer_width//4, layer_width//2, layer_width//2])
            self.fp1 = PointnetFPModule(mlp=[layer_width//2 + layer_width, layer_width//2, layer_width//2, layer_width])

            self.conv1 = nn.Conv1d(layer_width, layer_width, 1)
            self.bn1 = nn.BatchNorm1d(layer_width)
            self.drop1 = nn.Dropout(0.15)
            self.conv2 = nn.Conv1d(layer_width, layer_width, 1)

        # insert interpolation time
        self.layer_time = NeuralPCI_Layer(dim_in = layer_width + dim_time * pe_mul,
                                          dim_out = layer_width,
                                          norm = norm,
                                          act_fn = act_fn
                                          )

        # hidden layers
        self.hidden_pred = NeuralPCI_Block(depth = depth_pred,
                                           width = layer_width,
                                           norm = norm,
                                           act_fn = act_fn
                                           )

        # output layer
        self.layer_output = NeuralPCI_Layer(dim_in = layer_width,
                                          dim_out = dim_pc,
                                          norm = norm,
                                          act_fn = None
                                          )

        # zero init for last layer
        if args.zero_init:
            for m in self.layer_output.layer:
                if isinstance(m, nn.Linear):
                    m.weight.data.zero_()
                    m.bias.data.zero_()

        # Cache the feature flags that forward() branches on.
        self.graph_learner_sign = args.graph_learner
        self.multi_scale = args.multi_scale
        self.TransL1_Fu = args.TransL1_Fu
        self.Graph_Fu = args.Graph_Fu
        self.Gaussians4D = args.Gaussians4D
    def posenc(self, x):
        """
        Sinusoidal positional encoding along dim 1.

        Maps x to the concatenation [x, sin(x), cos(x)], so the size of
        dim 1 is tripled.
        """
        return torch.cat((x, x.sin(), x.cos()), dim=1)

    def forward(self, pc_current, time_current, time_pred, train=True):
        """
        Predict the point cloud at time_pred from the one observed at time_current.

        pc_current: tensor, [N, 3]
        time_current: float (or tensor holding one value), timestamp of pc_current
        time_pred: float (or tensor holding one value), timestamp to predict
        train: bool, with the Bernoulli graph learner, sample the adjacency
               (True) or use its mean probabilities (False)
        output: tensor, [N, dim_pc]

        Raises RuntimeError if Graph_Fu is enabled in the multi-scale path
        without both the Gaussians4D and graph_learner branches.
        """
        # Build every time tensor on the input's device rather than on the
        # default CUDA device, so the model also runs on CPU / other GPUs.
        device = pc_current.device
        num_points = pc_current.shape[0]

        def as_time(value):
            return torch.as_tensor(value, dtype=torch.float32, device=device).detach()

        t_current = as_time(time_current)
        t_pred = as_time(time_pred)
        time_current_l3 = t_current.unsqueeze(0)
        time_pred_l3 = t_pred.unsqueeze(0)
        time_current = t_current.repeat(num_points, 1)  # [N, 1]
        time_pred = t_pred.repeat(num_points, 1)        # [N, 1]

        if self.args.use_rrf:
            pc_current = torch.matmul(2. * torch.pi * pc_current, self.transform)

        x = torch.cat((pc_current, time_current), dim=1)
        x = self.posenc(x)
        x = self.layer_input(x)
        x_hid = self.hidden_encode(x)

        # PointNet2 multi-scale feature extraction
        selected_xyz_pred = None
        if self.multi_scale:
            pc_current = pc_current.unsqueeze(0)  # Add batch dimension [1, N, 3]

            x_dim = x.unsqueeze(0).permute(0,2,1).contiguous()  # [1, C, N]
            l1_xyz, l1_points = self.sa1(pc_current, x_dim)
            l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
            l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

            # Outputs of the optional coarsest-level branches. They start as
            # None so a disabled branch is detected explicitly instead of
            # surfacing later as an unbound local variable.
            l3_points_GS = None
            graph_conv_output = None
            l3_points_fu = None

            if self.Gaussians4D:
                n_l3 = l3_xyz.shape[1]
                l3_time = time_current_l3.unsqueeze(1).repeat(1, n_l3, 1)    # [1, n_l3, 1]
                l3_pred_time = time_pred_l3.unsqueeze(1).repeat(1, n_l3, 1)  # [1, n_l3, 1]

                # Extract spatio-temporal feature of l3_points
                _, l3_st_points = self.st_sa(l3_xyz, l3_points, l3_time)

                # Get Gaussian representation of l3 point cloud
                l3_gaussian_mean, l3_gaussian_cov = self.gaussian_pc(torch.cat([l3_st_points, l3_xyz.transpose(1, 2)], dim=1))

                # Predict Gaussian deformation between t_current and t_pred
                time_diff = l3_pred_time - l3_time
                mean_offset, cov_offset = self.deform_field(l3_st_points, l3_gaussian_mean, time_diff)

                # Average the per-point offsets within each group of
                # num_points_per_gaussian consecutive points.
                mean_offset = mean_offset.view(1, self.args_n_gaussians, self.num_points_per_gaussian, 3)
                cov_offset = cov_offset.view(1, self.args_n_gaussians, self.num_points_per_gaussian, 3, 3)
                mean_offset = mean_offset.mean(dim=2).transpose(1, 2)  # [1, 3, G]
                cov_offset = cov_offset.mean(dim=2)                    # [1, G, 3, 3]

                gaussian_mean_pred = l3_gaussian_mean + mean_offset
                gaussian_cov_pred = l3_gaussian_cov + cov_offset

                # The l3 coordinates themselves are used as the predicted
                # positions (no sampling from the Gaussians).
                xyz_pred_features, selected_xyz_pred = self.get_xyz_pred_features(gaussian_mean_pred, gaussian_cov_pred, l3_xyz, l3_points, l3_xyz)
                l3_points_GS = self.Gate_fusion_module(xyz_pred_features.permute(0, 2, 1).contiguous(), l3_points)
                if not self.Graph_Fu:
                    l3_points_fu = l3_points_GS

            if self.graph_learner_sign:
                node_feats = l3_points.squeeze(0).permute(1, 0).contiguous()  # [n_l3, C]
                if self.graph_learner.distribution_type == 'bernoulli':
                    A_dist = self.graph_learner(node_feats)
                    # Stored on the module so code outside forward can read it.
                    self.graph_learner.A_dist = A_dist
                    if train:
                        A_sample = A_dist.sample()
                    else:
                        A_sample = A_dist.mean
                else:  # 'sparse_sampling'
                    A_sample = self.graph_learner(node_feats)

                self.graph_learner.A_sample = A_sample

                graph_conv_output = self.graph_conv(node_feats, A_sample).permute(1, 0).unsqueeze(0).contiguous()  # [1, C, n_l3]
                if not self.Graph_Fu:
                    # Overrides the Gaussian-branch features when both are on.
                    l3_points_fu = self.attention_Graph_Fu(l3_points, graph_conv_output).permute(0, 2, 1)

            if self.Graph_Fu:
                if l3_points_GS is None or graph_conv_output is None:
                    raise RuntimeError(
                        "Graph_Fu requires both the Gaussians4D and graph_learner branches to be enabled"
                    )
                l3_points_fu = self.attention_Graph_Fu(l3_points_GS, graph_conv_output).permute(0, 2, 1)

            if l3_points_fu is None:
                # Neither optional branch is active: propagate the plain
                # coarsest-level features.
                l3_points_fu = l3_points

            fp3_xyz = l3_xyz if selected_xyz_pred is None else selected_xyz_pred
            l2_points = self.fp3(l2_xyz, fp3_xyz, l2_points, l3_points_fu.contiguous())
            l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
            x = self.fp1(pc_current, l1_xyz, x_dim, l1_points)

            x = F.relu(self.bn1(self.conv1(x)))
            x = self.drop1(x)
            x = self.conv2(x)  # [1, C, N]

        if self.TransL1_Fu and self.multi_scale:
            # Drop only the batch axis; a bare squeeze() would also collapse
            # the point axis for a single-point cloud.
            x = self.attention_TransL1_Fu(x, x_hid.unsqueeze(0).permute(0, 2, 1)).squeeze(0)  # [N, C]
        else:
            # NOTE(review): with multi_scale on but TransL1_Fu off, the
            # multi-scale features computed above are discarded here -
            # confirm this is intended.
            x = x_hid

        time_pred = self.posenc(time_pred)
        x_pred = torch.cat((x, time_pred), dim=1)
        x_pred = self.layer_time(x_pred)

        x_pred = self.hidden_pred(x_pred)

        x_pred = self.layer_output(x_pred)
        return x_pred
    
    def gaussian_sampling(self, mean, cov, num_points):
        """
        Draw num_points samples from each of N multivariate normals.

        mean: tensor, [1, 3, N], Gaussian means
        cov: tensor of per-Gaussian 3x3 matrices; the code swaps its last two
             axes and flattens it to [-1, 3, 3], which matches a [1, N, 3, 3]
             layout (an earlier note said [1, 3, 3, N] - TODO confirm)
        num_points: int, number of points to sample per Gaussian
        output: tensor, [N * num_points, 3], samples grouped by Gaussian
        """
        n_gaussians = mean.shape[2]
        centers = mean.transpose(1, 2).contiguous().view(-1, 3)  # [N, 3]
        covs = cov.transpose(2, 3).contiguous().view(-1, 3, 3)   # [N, 3, 3]

        # Small diagonal jitter, added in place to the flattened matrices.
        jitter = 1e-7 * torch.eye(covs.size(-1), device=covs.device)
        covs += jitter

        draws = [
            torch.distributions.MultivariateNormal(centers[g], covs[g]).sample((num_points,))
            for g in range(n_gaussians)
        ]
        return torch.cat(draws, dim=0)
    
    def get_xyz_pred_features(self, gaussian_mean_pred, gaussian_cov_pred, l3_xyz, l3_points, xyz_pred):
        """
        Attach nearest-Gaussian parameters to the l3 features and look them up
        for every predicted point.

        gaussian_mean_pred: tensor, [1, 3, G], predicted Gaussian means
        gaussian_cov_pred: tensor, [1, G, 3, 3], predicted Gaussian covariances
        l3_xyz: tensor, [1, n, 3], coarsest-level coordinates
        l3_points: tensor, [1, C, n], coarsest-level features
        xyz_pred: tensor, [1, m, 3], candidate predicted coordinates
        output:
            xyz_pred_features: tensor, [1, m', C + 12], features of the l3
                point nearest to each selected point, with that l3 point's
                nearest Gaussian mean (3) and covariance (9) appended
            selected_xyz_pred: tensor, [1, m', 3]; xyz_pred itself when
                m <= n, otherwise n points chosen by the voting step
        """
        n_l3 = l3_xyz.shape[1]

        if xyz_pred.shape[1] > n_l3:
            # Every candidate votes for its nearest l3 point.
            owner = torch.min(torch.cdist(xyz_pred, l3_xyz, p=2), dim=2)[1]  # [1, m]
            votes = torch.bincount(owner.squeeze(0), minlength=n_l3)
            ranked = torch.topk(votes, n_l3)[1]
            # NOTE(review): `ranked` holds l3 point indices but is used to
            # index xyz_pred - confirm this is the intended selection.
            selected_xyz_pred = xyz_pred[:, ranked, :]
        else:
            selected_xyz_pred = xyz_pred

        # Nearest Gaussian for each l3 point.
        centers = gaussian_mean_pred.transpose(1, 2)  # [1, G, 3]
        nearest_idx = torch.min(torch.cdist(l3_xyz, centers, p=2), dim=-1)[1]  # [1, n]
        mean_index = nearest_idx.unsqueeze(-1).expand(-1, -1, 3)
        cov_index = nearest_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 3, 3)
        nearest_mean = torch.gather(centers, 1, mean_index)           # [1, n, 3]
        nearest_cov = torch.gather(gaussian_cov_pred, 1, cov_index)   # [1, n, 3, 3]

        # 3 mean values + 9 flattened covariance values per l3 point.
        flat_cov = nearest_cov.view(1, nearest_mean.shape[1], 9)
        gaussian_features = torch.cat([nearest_mean, flat_cov], dim=-1)  # [1, n, 12]
        l3_features_with_gaussian = torch.cat([l3_points.transpose(1, 2), gaussian_features], dim=-1)  # [1, n, C + 12]

        # Copy the augmented features of the nearest l3 point to each selected point.
        source_idx = torch.min(torch.cdist(selected_xyz_pred, l3_xyz, p=2), dim=-1)[1]  # [1, m']
        gather_index = source_idx.unsqueeze(-1).expand(-1, -1, l3_features_with_gaussian.shape[-1])
        xyz_pred_features = torch.gather(l3_features_with_gaussian, 1, gather_index)
        return xyz_pred_features, selected_xyz_pred




if __name__ == "__main__":
    # Smoke test: build a plain NeuralPCI model (every optional branch
    # disabled) and run one forward pass on random data. Requires CUDA.
    class Args:
        dim_pc = 3
        dim_time = 1
        num_points = 8192
        layer_width = 128
        act_fn = 'leakyrelu'  # must match a name accepted by get_activation
        norm = None
        depth_encode = 3
        depth_pred = 3
        pe_mul = 3
        use_rrf = False
        zero_init = True
        # NeuralPCI.__init__ reads each of these flags unconditionally, so
        # they must be present even when the corresponding branch is off.
        graph_learner = False
        multi_scale = False
        TransL1_Fu = False
        Graph_Fu = False
        Gaussians4D = False

    args = Args()
    model = NeuralPCI(args).cuda()

    # Generate random input data
    pc_current = torch.randn(args.num_points, 3).cuda()
    time_current = 0.0
    time_pred = 0.5

    # Forward pass
    output = model(pc_current, time_current, time_pred)

    print(output.shape)  # Expected: torch.Size([8192, 3])