import torch
import torch.nn as nn
import torch.nn.functional
from typing import List
import torch.nn.functional as F
from triposf.modules.pointclouds.pointnet import LocalPoolPointnet

class VertexContextEncoder(nn.Module):
    def __init__(self, embed_dim=32, nhead=4, resolution=1024, multires=10, using_nerf=True, relative_embed=False,
                 add_block_embed=False, block_size=64, vtx_embedding_only=True, max_group_size=128,
                 ):
        """Create the coordinate embeddings, face encoders and attention layers.

        Args:
            embed_dim: Feature width D of every embedding / projection built here.
            nhead: Number of heads of the ``nn.MultiheadAttention`` stored in ``self.attn``.
            resolution: Number of entries of each per-axis coordinate table when
                ``relative_embed`` is False; also the numerator of the block-table size.
            multires: Passed through to every ``PE_NeRF`` encoder.
            using_nerf: Only stored on the instance. Vertex coordinates are always encoded
                with embedding tables; no NeRF vertex encoder is created.
            relative_embed: If True the per-axis tables have 16 entries (``encode_vertex``
                indexes them with ``coords % 16``) instead of ``resolution`` entries.
            add_block_embed: With ``relative_embed``, additionally create per-axis block
                tables with ``resolution // block_size`` entries.
            block_size: Block edge length; stored as None unless ``add_block_embed`` is set.
            vtx_embedding_only: Stored flag; ``forward`` returns the projected vertex
                embedding directly when it is True.
            max_group_size: Stored cap on the number of faces a vertex attends to.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.resolution = resolution
        self.using_nerf = using_nerf
        self.relative_embed = relative_embed
        self.add_block_embed = add_block_embed
        self.block_size = block_size if add_block_embed else None
        self.vtx_embedding_only = vtx_embedding_only
        self.max_group_size = max_group_size

        # Encoder for per-face edge directions; it is created in every configuration.
        # PE_NeRF is defined outside this block (presumably a NeRF-style positional encoder).
        self.pe_nerf_face = PE_NeRF(out_channels=embed_dim, multires=multires)

        # One lookup table per axis: 16 entries for relative coordinates, else `resolution`.
        table_size = 16 if relative_embed else resolution
        self.coord_embed_x = nn.Embedding(table_size, embed_dim)
        self.coord_embed_y = nn.Embedding(table_size, embed_dim)
        self.coord_embed_z = nn.Embedding(table_size, embed_dim)

        # Optional block-level tables, only meaningful together with relative coordinates.
        if relative_embed and add_block_embed:
            blocks_per_axis = resolution // block_size
            self.block_embed_x = nn.Embedding(blocks_per_axis, embed_dim)
            self.block_embed_y = nn.Embedding(blocks_per_axis, embed_dim)
            self.block_embed_z = nn.Embedding(blocks_per_axis, embed_dim)

        # Encoder for face normals.
        self.normal_encoder = PE_NeRF(out_channels=embed_dim, multires=multires)

        # Fuses two encoded edge directions plus the encoded normal (3 * D) into one
        # D-wide feature that serves as attention key and value.
        fusion_layers = [
            nn.Linear(embed_dim * 3, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        ]
        self.face_kv_fusion = nn.Sequential(*fusion_layers)

        # Attention module and post-attention normalisation.
        # NOTE(review): `forward` in this class calls scaled_dot_product_attention directly
        # and does not use `self.attn`; it is kept so the parameter set is unchanged.
        self.attn = nn.MultiheadAttention(embed_dim, nhead, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

        # Projects the base vertex embedding into the attention query.
        self.vertex_query_proj = nn.Linear(embed_dim, embed_dim)
        # Merges the concatenated x/y/z embeddings (3 * D) back to D.
        self.embed_proj = nn.Linear(3 * embed_dim, embed_dim)

    def encode_vertex(self, coords):
        """Embed integer vertex coordinates into ``embed_dim``-wide features.

        Each axis is looked up in its own embedding table and the three results are
        concatenated and projected by ``embed_proj``.

        Args:
            coords: [N, 3] tensor of integer (x, y, z) coordinates; columns are used
                directly as ``nn.Embedding`` indices.

        Returns:
            [N, embed_dim] tensor.
        """
        cx, cy, cz = coords[:, 0], coords[:, 1], coords[:, 2]

        if not self.relative_embed:
            # Absolute lookup: the raw coordinate is the table index.
            x = self.coord_embed_x(cx)
            y = self.coord_embed_y(cy)
            z = self.coord_embed_z(cz)
        else:
            # Relative lookup: position inside a 16-wide cell (tables have 16 entries).
            x = self.coord_embed_x(cx % 16)
            y = self.coord_embed_y(cy % 16)
            z = self.coord_embed_z(cz % 16)

            if self.add_block_embed:
                # Add a coarse embedding of the block the coordinate falls into.
                x = x + self.block_embed_x(cx // self.block_size)
                y = y + self.block_embed_y(cy // self.block_size)
                z = z + self.block_embed_z(cz // self.block_size)

        return self.embed_proj(torch.cat([x, y, z], dim=-1))


    def forward(self, vertices: torch.Tensor, faces: torch.Tensor):
        """
        Encode every vertex, optionally enriched with context from its incident faces.

        Each vertex acts as a single attention query. Its keys/values are features of the
        faces it belongs to, each computed from that vertex's own perspective: the two
        edge directions leaving the vertex plus the face normal. At most
        ``max_group_size`` incident faces are used per vertex; ``_build_fixed_groups``
        sub-samples them randomly.

        Args:
            vertices: [V_total, 3], multiple meshes concatenated. Used as embedding
                indices by ``encode_vertex`` and, cast to float, as positions.
            faces: [F_total, 3], indices into vertices (already offset)
        Returns:
            vertex_features: [V_total, D]. When ``vtx_embedding_only`` is set this is the
            projected vertex embedding; otherwise
            ``LayerNorm(attended face context + base vertex embedding)``.
        """
        device = vertices.device
        num_vertices = vertices.shape[0]
        num_faces = faces.shape[0]
        feat_dim = self.embed_dim
        max_group = self.max_group_size

        # 1. Base vertex features and the attention query derived from them.
        vertex_base_feats = self.encode_vertex(vertices)   # [V, D]
        query = self.vertex_query_proj(vertex_base_feats)  # [V, D]

        if self.vtx_embedding_only:
            return query

        # 2. Per-face geometry.
        corner_pos = vertices[faces].float()  # [F, 3, 3]
        v0, v1, v2 = corner_pos.unbind(dim=1)

        normals = torch.nn.functional.normalize(
            torch.linalg.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
        )  # [F, 3]
        normal_feat = self.normal_encoder(normals)  # [F, D]

        # For every face and every corner, the two edge vectors leaving that corner.
        # Kept face-major ([F, 3, 2, 3]) so that flattening and un-flattening below use
        # one and the same row order: row = (face * 3 + corner) * 2 + direction.
        corner_dirs = torch.stack([
            torch.stack([v1 - v0, v2 - v0], dim=1),
            torch.stack([v0 - v1, v2 - v1], dim=1),
            torch.stack([v0 - v2, v1 - v2], dim=1),
        ], dim=1)

        # NOTE(review): the divisor is the literal 1024 (the default `resolution`), not
        # `self.resolution` - confirm before training with a different resolution.
        encoded_dirs = self.pe_nerf_face(corner_dirs.reshape(-1, 3) / 1024.)  # [F*3*2, D]
        encoded_dirs = encoded_dirs.reshape(num_faces, 3, 2 * feat_dim)       # [F, 3, 2D]

        # Every corner of a face shares that face's normal feature.
        corner_inputs = torch.cat(
            [encoded_dirs, normal_feat.unsqueeze(1).expand(-1, 3, -1)], dim=-1
        )  # [F, 3, 3D]

        # One key/value feature per (face, corner); row index = face * 3 + corner.
        all_face_kvs = self.face_kv_fusion(corner_inputs.reshape(-1, 3 * feat_dim))  # [F*3, D]

        # 3. Index bookkeeping only - no gradients are needed for integer indices.
        with torch.no_grad():
            corner_vertex = faces.reshape(-1)  # vertex id of every (face, corner)
            corner_face = torch.arange(num_faces, device=device).repeat_interleave(3)

            dense_adj = torch.sparse_coo_tensor(
                torch.stack([corner_vertex, corner_face]),
                torch.ones(3 * num_faces, device=device),
                size=(num_vertices, num_faces),
            ).to_dense()  # [V, F]

            # group_face_idx: [V, M] face ids; group_attn_mask: [V, M], True = not usable.
            group_face_idx, group_attn_mask = self._build_fixed_groups(dense_adj, max_group)

            valid = (~group_attn_mask) & (group_face_idx >= 0) & (group_face_idx < num_faces)

            # For each usable slot find which corner of the face the vertex is, which
            # selects the feature computed from that vertex's perspective.
            slot_vertex = torch.arange(num_vertices, device=device).unsqueeze(1).expand_as(group_face_idx)
            sel_vertex = slot_vertex[valid]
            sel_face = group_face_idx[valid]
            local_corner = (faces[sel_face] == sel_vertex.unsqueeze(1)).long().argmax(dim=1)

            # Unusable slots point at an extra all-zero row appended after the real rows.
            gather_idx = torch.full_like(group_face_idx, 3 * num_faces)
            gather_idx[valid] = sel_face * 3 + local_corner

            # scaled_dot_product_attention expects True = "take part in attention".
            # A vertex without any incident face would have an all-False row (NaN after
            # softmax); let it attend to its zero padding instead, which yields zeros.
            has_face = valid.any(dim=1, keepdim=True)
            attend_mask = (valid | ~has_face).unsqueeze(1)  # [V, 1, M]

        # The gather itself stays outside no_grad so that gradients reach
        # face_kv_fusion, pe_nerf_face and normal_encoder.
        zero_row = all_face_kvs.new_zeros(1, all_face_kvs.shape[-1])
        padded_kvs = torch.cat([all_face_kvs, zero_row], dim=0)  # [F*3 + 1, D]
        keys = padded_kvs[gather_idx]                            # [V, M, D]
        values = keys

        # 4. Cross-attention: one query per vertex over its incident-face features.
        attended_vertex_feats = torch.nn.functional.scaled_dot_product_attention(
            query=query.unsqueeze(1),  # [V, 1, D]
            key=keys,                  # [V, M, D]
            value=values,              # [V, M, D]
            attn_mask=attend_mask,
        ).squeeze(1)  # [V, D]

        # 5. Residual connection with the un-projected vertex embedding, then LayerNorm.
        return self.norm(attended_vertex_feats + vertex_base_feats)


    def _build_fixed_groups(self, dense_adj: torch.Tensor, max_group_size: int):
        """
        Pick, for every vertex, a fixed-size random group of its incident faces.

        Every incident face gets a random score and the top ``max_group_size`` scores per
        vertex are kept, so vertices with more incident faces than that are randomly
        sub-sampled. Slots that do not hold an incident face are padding.

        Args:
            dense_adj: [V, F] dense adjacency matrix; ``dense_adj[v, f] > 0`` means face f
                is incident to vertex v. Must be a floating-point tensor.
            max_group_size: Maximum number of faces to include in a group for each vertex.

        Returns:
            group_idx: [V, max_group_size] long tensor of face indices; every padding
                slot holds -1.
            attn_mask: [V, max_group_size] bool tensor; True marks a padding slot
                (i.e. one that must not be attended to).
        """
        num_vertices, num_faces = dense_adj.shape
        device = dense_adj.device

        # Random scores give a uniformly random choice among incident faces.
        scores = torch.rand_like(dense_adj)

        # Cannot select more faces than exist.
        k = min(max_group_size, num_faces)
        if k == 0:
            # No faces at all (or an empty group): everything is padding.
            group_idx = torch.full((num_vertices, max_group_size), -1, dtype=torch.long, device=device)
            return group_idx, torch.ones_like(group_idx, dtype=torch.bool)

        # Non-incident faces get a score below the range of torch.rand ([0, 1)), so an
        # incident face always wins and stays valid even if its random score is 0.
        scores = scores.masked_fill(~(dense_adj > 0), -1.0)
        top_score, group_idx = torch.topk(scores, k=k, dim=1)

        valid = top_score >= 0
        # topk also returns indices for non-incident faces; blank them out so that
        # padding is always -1.
        group_idx = group_idx.masked_fill(~valid, -1)

        if k < max_group_size:
            pad = max_group_size - k
            group_idx = torch.cat([group_idx, group_idx.new_full((num_vertices, pad), -1)], dim=1)
            valid = torch.cat([valid, valid.new_zeros((num_vertices, pad))], dim=1)

        # Callers expect True = masked out, i.e. the negation of validity.
        return group_idx, ~valid

class VoxelFeatureEncoder_edge(nn.Module):
    def __init__(self, embed_dim=32, resolution=1024, multires=10, 
                 relative_embed=False, add_block_embed=False, block_size=64, 
                 num_voxel_labels=2, pos_encoding='embedding',
                 add_edge_glb_feats=False, add_direction=False,):
        """Create the positional encoders, label embedding and fusion projections.

        Args:
            embed_dim: Feature width D of every embedding / projection.
            resolution: Grid resolution; size of the absolute per-axis tables and the
                divisor used to normalise global coordinates in NeRF mode.
            multires: Passed through to every ``PE_NeRF`` encoder (NeRF mode only).
            relative_embed: Embedding mode only. Use 16-entry tables indexed with
                ``coords % 16`` instead of ``resolution``-entry tables.
            add_block_embed: Embedding mode with ``relative_embed`` only. Also create
                block tables with ``resolution // block_size`` entries per axis.
            block_size: Block edge length. Kept when block embeddings are enabled or
                when NeRF encoding is used; otherwise stored as None.
            num_voxel_labels: Number of entries of the label embedding table.
            pos_encoding: ``"embedding"`` selects lookup tables; any other value
                builds the ``PE_NeRF`` encoders.
            add_edge_glb_feats: Stored flag; ``forward`` adds global coordinate features
                to edge voxels when it is set.
            add_direction: Encode edges by direction instead of by their two endpoints.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.resolution = resolution
        self.relative_embed = relative_embed
        self.add_block_embed = add_block_embed
        # The NeRF branch of encode_coordinates wraps coordinates with
        # `coords % block_size`, so the value must survive in that mode even when block
        # embeddings are off; otherwise that call fails on `tensor % None`.
        uses_embedding = pos_encoding == "embedding"
        self.block_size = block_size if (add_block_embed or not uses_embedding) else None
        self.num_voxel_labels = num_voxel_labels
        self.pos_encoding = pos_encoding
        self.add_edge_glb_feats = add_edge_glb_feats
        self.add_direction = add_direction

        if uses_embedding:
            # Local coordinate tables: 16 entries for relative mode, else `resolution`.
            local_table_size = 16 if relative_embed else resolution
            self.coord_embed_x = nn.Embedding(local_table_size, embed_dim)
            self.coord_embed_y = nn.Embedding(local_table_size, embed_dim)
            self.coord_embed_z = nn.Embedding(local_table_size, embed_dim)

            if relative_embed and add_block_embed:
                # Coarse, absolute block-level tables added on top of relative ones.
                blocks_per_axis = resolution // block_size
                self.block_embed_x = nn.Embedding(blocks_per_axis, embed_dim)
                self.block_embed_y = nn.Embedding(blocks_per_axis, embed_dim)
                self.block_embed_z = nn.Embedding(blocks_per_axis, embed_dim)

            # Shared projection from concatenated x/y/z embeddings (3 * D) to D.
            self.embed_proj = nn.Linear(3 * embed_dim, embed_dim)

            # Absolute (global) coordinate tables.
            self.global_coord_embed_x = nn.Embedding(resolution, embed_dim)
            self.global_coord_embed_y = nn.Embedding(resolution, embed_dim)
            self.global_coord_embed_z = nn.Embedding(resolution, embed_dim)

            # Absolute coordinate tables for edge endpoints.
            self.endpoint_embed_x = nn.Embedding(resolution, embed_dim)
            self.endpoint_embed_y = nn.Embedding(resolution, embed_dim)
            self.endpoint_embed_z = nn.Embedding(resolution, embed_dim)
        else:
            # NeRF-style encoders for block-local and global coordinates.
            # PE_NeRF is defined outside this block.
            self.pe_nerf_local = PE_NeRF(out_channels=embed_dim, multires=multires)
            self.pe_nerf_global = PE_NeRF(out_channels=embed_dim, multires=multires)
            if self.add_direction:
                self.pe_nerf_dir = PE_NeRF(out_channels=embed_dim, multires=multires)

        self.label_embedding = nn.Embedding(num_voxel_labels, embed_dim)

        # Vertex voxels: [coordinate feature, label embedding] -> D.
        self.fusion_proj_vertex = nn.Linear(embed_dim * 2, embed_dim)

        # Edge voxels: [coordinate, label, direction] (3 * D) when directions are used,
        # else [coordinate, label, endpoint 1, endpoint 2] (4 * D).
        edge_inputs = 3 if self.add_direction else 4
        self.fusion_proj_edge = nn.Linear(embed_dim * edge_inputs, embed_dim)

    def encode_coordinates(self, coords: torch.Tensor):
        """
        Encodes integer voxel coordinates using relative/block-level information.

        In embedding mode each axis is looked up in its table (absolute, or ``% 16``
        relative with optional block embedding added) and the three results are
        concatenated and projected. In every other mode the coordinates are wrapped to
        their block, scaled to [-0.5, 0.5) and passed to ``pe_nerf_local``.

        Args:
            coords: [N, 3] tensor of (x, y, z) integer coordinates.
        Returns:
            features: [N, embed_dim] tensor.
        """
        if self.pos_encoding != "embedding":
            # __init__ builds the NeRF encoders for every non-"embedding" mode, so this
            # branch must cover all of them instead of silently returning None.
            local = (coords % self.block_size).float() / self.block_size - 0.5
            return self.pe_nerf_local(local)

        cx, cy, cz = coords[:, 0], coords[:, 1], coords[:, 2]

        if not self.relative_embed:
            x = self.coord_embed_x(cx)
            y = self.coord_embed_y(cy)
            z = self.coord_embed_z(cz)
        else:
            # Position inside a 16-wide cell; the tables have 16 entries.
            x = self.coord_embed_x(cx % 16)
            y = self.coord_embed_y(cy % 16)
            z = self.coord_embed_z(cz % 16)

            if self.add_block_embed:
                # Block embeddings are added to the relative embeddings.
                x = x + self.block_embed_x(cx // self.block_size)
                y = y + self.block_embed_y(cy // self.block_size)
                z = z + self.block_embed_z(cz // self.block_size)

        return self.embed_proj(torch.cat([x, y, z], dim=-1))

    def encode_global_coordinates(self, coords: torch.Tensor):
        """
        Encodes integer voxel coordinates using absolute (global) information.

        In embedding mode each axis is looked up in its global table and the results are
        concatenated and projected by ``embed_proj``. In every other mode the coordinates
        are divided by ``resolution``, shifted by -0.5 and passed to ``pe_nerf_global``.

        Args:
            coords: [N, 3] tensor of (x, y, z) integer coordinates.
        Returns:
            features: [N, embed_dim] tensor.
        """
        if self.pos_encoding != "embedding":
            # __init__ builds the NeRF encoders for every non-"embedding" mode, so this
            # branch must cover all of them instead of silently returning None.
            scaled = coords.float() / self.resolution - 0.5
            return self.pe_nerf_global(scaled)

        per_axis = [
            self.global_coord_embed_x(coords[:, 0]),
            self.global_coord_embed_y(coords[:, 1]),
            self.global_coord_embed_z(coords[:, 2]),
        ]
        return self.embed_proj(torch.cat(per_axis, dim=-1))

    # def encode_direction(self, endpoints: torch.Tensor):
    #     """
    #     Encodes edge direction instead of endpoints.
    #     Args:
    #         endpoints: [N, 6] tensor of (x1, y1, z1, x2, y2, z2).
    #     Returns:
    #         features: [N, embed_dim] tensor, direction embeddings.
    #     """
    #     p1 = endpoints[:, :3].float()
    #     p2 = endpoints[:, 3:].float()
    #     print('')

    #     d = p2 - p1  # [N, 3]
    #         d = F.normalize(d, dim=-1, eps=1e-6)  # normalize the direction

    #     if self.pos_encoding == "embedding":
    #         raise ValueError("Direction encoding with embedding not supported, use pos_encoding='nerf'")

    #     elif self.pos_encoding == "nerf":
    #         d_embed = self.pe_nerf(d)
    #         d_neg_embed = self.pe_nerf(-d)
    #         return 0.5 * (d_embed + d_neg_embed)

    def encode_direction(self, endpoints: torch.Tensor):
        """
        Encodes edge direction in a canonical way (for undirected edges).

        The two endpoints are ordered lexicographically by (x, y, z) so that an edge and
        its reverse yield the same unit direction, which is then passed to
        ``pe_nerf_dir``.

        Args:
            endpoints: [N, 6] tensor of (x1, y1, z1, x2, y2, z2).
        Returns:
            features: [N, embed_dim] tensor, direction embeddings.
        Raises:
            ValueError: If ``pos_encoding`` is ``"embedding"``, which has no direction
                encoder.
        """
        # Fail before doing any tensor work when the mode cannot encode directions.
        if self.pos_encoding == "embedding":
            raise ValueError("Direction encoding with embedding not supported, use pos_encoding='nerf'")

        p1 = endpoints[:, :3].float()
        p2 = endpoints[:, 3:].float()

        # True where p1 <= p2 in lexicographic (x, then y, then z) order.
        same_x = p1[:, 0] == p2[:, 0]
        same_xy = same_x & (p1[:, 1] == p2[:, 1])
        p1_first = (
            (p1[:, 0] < p2[:, 0])
            | (same_x & (p1[:, 1] < p2[:, 1]))
            | (same_xy & (p1[:, 2] <= p2[:, 2]))
        )

        start = torch.where(p1_first[:, None], p1, p2)  # lexicographically smaller
        end = torch.where(p1_first[:, None], p2, p1)    # lexicographically larger

        # Unit vector from the smaller to the larger endpoint; eps keeps zero-length
        # edges finite (they normalise to the zero vector).
        direction = F.normalize(end - start, dim=-1, eps=1e-6)

        # Every non-"embedding" mode uses the NeRF direction encoder built in __init__
        # (it exists only when add_direction was set).
        return self.pe_nerf_dir(direction)

    def encode_endpoints(self, endpoints: torch.Tensor):
        """
        Encodes the two endpoints of each edge and concatenates the results.

        In embedding mode each endpoint is looked up per axis in the endpoint tables and
        projected by ``embed_proj``. In every other mode each endpoint is divided by
        ``resolution``, shifted by -0.5 and passed to ``pe_nerf_global``.

        Args:
            endpoints: [N, 6] tensor of (x1, y1, z1, x2, y2, z2) integer coordinates.
        Returns:
            features: [N, 2 * embed_dim] tensor, concatenated embeddings for p1 and p2.
        """
        first, second = endpoints[:, :3], endpoints[:, 3:]

        if self.pos_encoding == "embedding":
            def embed_point(point):
                # Per-axis lookup followed by the shared 3D -> D projection.
                per_axis = [
                    self.endpoint_embed_x(point[:, 0]),
                    self.endpoint_embed_y(point[:, 1]),
                    self.endpoint_embed_z(point[:, 2]),
                ]
                return self.embed_proj(torch.cat(per_axis, dim=-1))

            return torch.cat([embed_point(first), embed_point(second)], dim=-1)

        # __init__ builds the NeRF encoders for every non-"embedding" mode, so this
        # branch must cover all of them instead of silently returning None.
        p1_embed = self.pe_nerf_global(first.float() / self.resolution - 0.5)
        p2_embed = self.pe_nerf_global(second.float() / self.resolution - 0.5)
        return torch.cat([p1_embed, p2_embed], dim=-1)

    def forward(self, voxel_coords: torch.Tensor, voxel_labels: torch.Tensor, voxel_endpoints=None, voxel_offset=None,):
        """
        Encodes voxel coordinates, type labels, and endpoint info into features.

        Vertex voxels (label 0) fuse their local + global coordinate features with the
        label embedding. Edge voxels (label 1) fuse their coordinate feature, the label
        embedding and either both endpoint embeddings or, with ``add_direction``, the
        edge direction embedding. Rows with any other label are returned as zeros.

        Args:
            voxel_coords: [N_total, 3] tensor, combined (x,y,z) coordinates.
            voxel_labels: [N_total] tensor, type labels (0 for vertex, 1 for edge).
            voxel_endpoints: [N_total, 6] endpoint coordinates per voxel, or [N_total, 7]
                in which case the first column is dropped (presumably a batch index -
                confirm against the caller). If None, only the label embedding is returned.
            voxel_offset: Optional [N_total, 3] offset subtracted from ``voxel_coords``;
                a leading fourth column is dropped the same way.

        Returns:
            final_voxel_features: [N_total, D] tensor.
        """
        if voxel_endpoints is None:
            return self.label_embedding(voxel_labels)

        if voxel_endpoints.shape[-1] == 7:
            voxel_endpoints = voxel_endpoints[:, 1:]

        if voxel_offset is not None:
            if voxel_offset.shape[-1] == 4:
                voxel_offset = voxel_offset[:, 1:]
            voxel_coords = voxel_coords - voxel_offset

        # Output is bfloat16 only when the coordinates already are; float32 otherwise.
        dtype = torch.bfloat16 if voxel_coords.dtype == torch.bfloat16 else torch.float32

        # 1. Label embedding for every voxel.
        label_embeds = self.label_embedding(voxel_labels)

        # 2. Split voxels by type.
        is_vertex = (voxel_labels == 0)
        is_edge = (voxel_labels == 1)

        # Zero-initialised so that rows with a label other than 0/1 (possible when
        # num_voxel_labels > 2) are deterministic instead of uninitialised memory.
        final_features = torch.zeros(voxel_coords.shape[0], self.embed_dim,
                                     device=voxel_coords.device, dtype=dtype)

        # 3. Vertex voxels: (local + global coordinate feature) fused with the label.
        if is_vertex.any():
            vertex_coords = voxel_coords[is_vertex]
            vertex_coord_feats = (
                self.encode_coordinates(vertex_coords)
                + self.encode_global_coordinates(vertex_coords)
            )
            vertex_features = self.fusion_proj_vertex(
                torch.cat([vertex_coord_feats, label_embeds[is_vertex]], dim=-1)
            ).to(dtype=dtype)
            final_features[is_vertex] = vertex_features

        # 4. Edge voxels: coordinate feature + label + endpoints or direction.
        if is_edge.any():
            edge_coords = voxel_coords[is_edge]
            edge_coord_feats = self.encode_coordinates(edge_coords)
            if self.add_edge_glb_feats:
                edge_coord_feats = edge_coord_feats + self.encode_global_coordinates(edge_coords)

            edge_endpoints_int = voxel_endpoints[is_edge].long()
            if self.add_direction:
                geometry_embeds = self.encode_direction(edge_endpoints_int)
            else:
                geometry_embeds = self.encode_endpoints(edge_endpoints_int)

            edge_features = self.fusion_proj_edge(
                torch.cat([edge_coord_feats, label_embeds[is_edge], geometry_embeds], dim=-1)
            ).to(dtype=dtype)
            final_features[is_edge] = edge_features

        return final_features
    

class VoxelFeatureEncoder_vtx(nn.Module):
    """Per-vertex feature encoder: position encoding fused with neighbour directions.

    ``forward`` encodes each vertex position with ``pe_nerf_global``, averages
    the ``pe_nerf_dir`` encodings of the unit directions towards its
    neighbours, and projects the concatenation to ``embed_dim`` with
    ``fusion_proj``.

    Two positional-encoding modes are set up in ``__init__``:

    * ``pos_encoding == "embedding"``: learned per-axis ``nn.Embedding`` tables.
    * any other value: ``PE_NeRF`` encoders (``PE_NeRF`` is defined elsewhere
      in this file).

    NOTE: ``forward`` and ``encode_direction`` rely on the ``PE_NeRF``
    encoders, which only exist on instances built in the non-"embedding" mode.
    """

    def __init__(self, embed_dim=32, resolution=1024, multires=10, 
                 relative_embed=False, add_block_embed=False, block_size=64, 
                 num_voxel_labels=2, pos_encoding='nerf',
                 add_edge_glb_feats=False, add_direction=False,):
        """
        Args:
            embed_dim: width of every embedding and of the fused output.
            resolution: grid resolution; sizes the absolute embedding tables
                and normalises coordinates before NeRF encoding.
            multires: forwarded to ``PE_NeRF``.
            relative_embed: in "embedding" mode, use 16-entry coordinate tables
                instead of ``resolution``-entry ones.
            add_block_embed: with ``relative_embed``, also create block-level
                tables of ``resolution // block_size`` entries.
            block_size: block edge length; stored only when ``add_block_embed``.
            num_voxel_labels: stored on the instance; not used by this class.
            pos_encoding: "embedding" for learned tables, anything else for
                ``PE_NeRF`` encoders.
            add_edge_glb_feats: stored on the instance; not used by this class.
            add_direction: selects the input width of ``fusion_proj_edge``
                (3 * embed_dim when True, 4 * embed_dim otherwise).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.resolution = resolution
        self.relative_embed = relative_embed
        self.add_block_embed = add_block_embed
        self.block_size = block_size if add_block_embed else None
        self.num_voxel_labels = num_voxel_labels
        self.pos_encoding = pos_encoding
        self.add_edge_glb_feats = add_edge_glb_feats
        self.add_direction = add_direction

        if pos_encoding == "embedding":
            # Coordinate embeddings based on resolution or relative position
            if not relative_embed:
                self.coord_embed_x = nn.Embedding(resolution, embed_dim)
                self.coord_embed_y = nn.Embedding(resolution, embed_dim)
                self.coord_embed_z = nn.Embedding(resolution, embed_dim)
            else:
                self.coord_embed_x = nn.Embedding(16, embed_dim) # For relative coordinates (coords % 16)
                self.coord_embed_y = nn.Embedding(16, embed_dim)
                self.coord_embed_z = nn.Embedding(16, embed_dim)

                if add_block_embed: # For absolute block-level embedding
                    self.block_embed_x = nn.Embedding(resolution // block_size, embed_dim)
                    self.block_embed_y = nn.Embedding(resolution // block_size, embed_dim)
                    self.block_embed_z = nn.Embedding(resolution // block_size, embed_dim)

            self.embed_proj = nn.Linear(3 * embed_dim, embed_dim)

            # Absolute coordinate embeddings for global vertex coordinates
            self.global_coord_embed_x = nn.Embedding(resolution, embed_dim)
            self.global_coord_embed_y = nn.Embedding(resolution, embed_dim)
            self.global_coord_embed_z = nn.Embedding(resolution, embed_dim)

            # Absolute coordinate embeddings for edge endpoints
            self.endpoint_embed_x = nn.Embedding(resolution, embed_dim)
            self.endpoint_embed_y = nn.Embedding(resolution, embed_dim)
            self.endpoint_embed_z = nn.Embedding(resolution, embed_dim)

        else:
            # Every non-"embedding" value selects the NeRF positional encoders.
            self.pe_nerf_local = PE_NeRF(out_channels=embed_dim, multires=multires)
            self.pe_nerf_global = PE_NeRF(out_channels=embed_dim, multires=multires)

            self.pe_nerf_dir = PE_NeRF(out_channels=embed_dim, multires=multires)

        self.fusion_proj = nn.Linear(embed_dim * 2, embed_dim) 

        # For edges, we fuse relative/block, label, and two endpoint embeddings
        if self.add_direction:
            self.fusion_proj_edge = nn.Linear(embed_dim * 3, embed_dim)
        else:
            self.fusion_proj_edge = nn.Linear(embed_dim * 4, embed_dim)

    def encode_global_coordinates(self, coords: torch.Tensor):
        """
        Encodes integer voxel coordinates using absolute (global) embeddings.
        Args:
            coords: [N, 3] tensor of (x, y, z) integer coordinates.
        Returns:
            features: [N, embed_dim] tensor.
        """
        if self.pos_encoding == "embedding":
            x = self.global_coord_embed_x(coords[:, 0])
            y = self.global_coord_embed_y(coords[:, 1])
            z = self.global_coord_embed_z(coords[:, 2])
            return self.embed_proj(torch.cat([x, y, z], dim=-1))

        # __init__ builds the NeRF encoders for every non-"embedding" value, so
        # mirror that split here instead of silently returning None for values
        # other than the literal "nerf".
        coords = coords.float() / self.resolution - 0.5
        return self.pe_nerf_global(coords)

    def encode_direction(self, endpoints: torch.Tensor):
        """
        Encodes edge direction in a canonical way (for undirected edges).
        Args:
            endpoints: [N, 6] tensor of (x1, y1, z1, x2, y2, z2).
        Returns:
            features: [N, embed_dim] tensor, direction embeddings.
        Raises:
            ValueError: if the encoder was built with pos_encoding="embedding".
        """
        if self.pos_encoding == "embedding":
            raise ValueError("Direction encoding with embedding not supported, use pos_encoding='nerf'")

        # Split endpoints
        p1 = endpoints[:, :3].float()
        p2 = endpoints[:, 3:].float()

        # Step 1: order the two endpoints lexicographically by (x, y, z) so
        # that (p1, p2) and (p2, p1) yield the same direction.
        # mask: True if p1 <= p2 lexicographically
        mask = ( (p1[:,0] < p2[:,0]) |
                ((p1[:,0] == p2[:,0]) & (p1[:,1] < p2[:,1])) |
                ((p1[:,0] == p2[:,0]) & (p1[:,1] == p2[:,1]) & (p1[:,2] <= p2[:,2])) )

        pA = torch.where(mask[:, None], p1, p2)  # smaller one
        pB = torch.where(mask[:, None], p2, p1)  # larger one

        # Step 2: canonical unit direction (zero-length edges stay zero thanks to eps)
        d = F.normalize(pB - pA, dim=-1, eps=1e-6)

        # Step 3: encode; same mode split as __init__ (non-"embedding" -> NeRF)
        return self.pe_nerf_dir(d)

    def forward(self, vertex_coords: torch.Tensor, vertex_neighbors_flat: torch.Tensor, 
                vertex_neighbors_offsets: torch.Tensor):
        """
        Encodes every vertex from its position and its neighbour directions.

        Requires the NeRF encoders (``pe_nerf_global`` / ``pe_nerf_dir``), i.e.
        an instance built with a ``pos_encoding`` other than "embedding".

        Args:
            vertex_coords: per-vertex coordinates, one row per vertex
                (assumed [V, 3] grid coordinates in [0, resolution) -- confirm
                against the caller).
            vertex_neighbors_flat: 1-D tensor of vertex indices; the neighbour
                lists of all vertices concatenated.
            vertex_neighbors_offsets: 1-D tensor giving, for each vertex, the
                start of its neighbour list inside ``vertex_neighbors_flat``.
        Returns:
            [V, embed_dim] tensor of fused vertex features.
        """
        device = vertex_coords.device
        num_vertices = vertex_coords.shape[0]

        coords_normalized = vertex_coords.float() / (self.resolution - 1) - 0.5
        coords_feats = self.pe_nerf_global(coords_normalized)  # (V, D)

        if vertex_neighbors_flat.numel() > 0:
            neighbor_coords = vertex_coords[vertex_neighbors_flat]

            # Per-vertex neighbour counts: differences between consecutive
            # offsets, with the total length appended as the closing offset.
            closing_offset = torch.tensor([vertex_neighbors_flat.numel()], device=device)
            neighbor_counts = torch.diff(torch.cat([vertex_neighbors_offsets, closing_offset]))

            # Owning vertex index for every entry of the flat neighbour list.
            parent_vtx_indices = torch.arange(num_vertices, device=device).repeat_interleave(
                neighbor_counts
            )

            vtx_coords_for_dir = vertex_coords[parent_vtx_indices]
            dirs = F.normalize(neighbor_coords.float() - vtx_coords_for_dir.float(), dim=-1, eps=1e-6)

            dir_feats_flat = self.pe_nerf_dir(dirs) # (N, D)

            agg_dir_feats = torch.zeros((num_vertices, self.embed_dim), device=device)
            agg_count = torch.zeros(num_vertices, device=device)

            # Scatter-sum direction features per owning vertex, then average.
            dir_feats_flat = dir_feats_flat.to(agg_dir_feats.dtype)
            agg_dir_feats.index_add_(0, parent_vtx_indices, dir_feats_flat)
            agg_count.index_add_(0, parent_vtx_indices, torch.ones_like(parent_vtx_indices, dtype=torch.float32))

            # The epsilon keeps vertices without neighbours at zero instead of NaN.
            agg_dir_feats /= (agg_count.unsqueeze(1) + 1e-6)
        else:
            agg_dir_feats = torch.zeros((num_vertices, self.embed_dim), device=device)

        fused_feats = self.fusion_proj(torch.cat([coords_feats, agg_dir_feats], dim=-1))

        return fused_feats


class VoxelFeatureEncoder_active(nn.Module):
    """Encode voxels from their coordinates and an integer type label.

    The coordinate feature (learned per-axis embeddings or a ``PE_NeRF``
    encoding, depending on ``pos_encoding``) is concatenated with a learned
    label embedding and projected to ``embed_dim`` by ``fusion_proj``.
    """

    def __init__(self, embed_dim=32, resolution=1024, multires=10, 
                 relative_embed=False, add_block_embed=False, block_size=64, 
                 num_voxel_labels=3, pos_encoding='embedding',
                 add_edge_glb_feats=False, add_direction=False,):
        """
        Args:
            embed_dim: width of all embeddings and of the fused output.
            resolution: grid resolution; sizes the per-axis embedding tables.
            multires: forwarded to ``PE_NeRF`` (non-"embedding" mode only).
            relative_embed: stored on the instance; not used by this class.
            add_block_embed: when True, ``block_size`` is kept and used to
                normalise coordinates in the NeRF path.
            block_size: see ``add_block_embed``; stored as None otherwise.
            num_voxel_labels: number of rows in the label embedding table.
            pos_encoding: "embedding" for learned tables, anything else for a
                ``PE_NeRF`` encoder (``PE_NeRF`` is defined elsewhere in this file).
            add_edge_glb_feats: accepted but not used by this class.
            add_direction: accepted but not used by this class.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.resolution = resolution
        self.relative_embed = relative_embed
        self.add_block_embed = add_block_embed
        self.block_size = block_size if add_block_embed else None
        self.num_voxel_labels = num_voxel_labels
        self.pos_encoding = pos_encoding

        if pos_encoding == "embedding":
            self.coord_embed_x = nn.Embedding(resolution, embed_dim)
            self.coord_embed_y = nn.Embedding(resolution, embed_dim)
            self.coord_embed_z = nn.Embedding(resolution, embed_dim)

            self.embed_proj = nn.Linear(3 * embed_dim, embed_dim)
        else:
            self.pe_nerf = PE_NeRF(out_channels=embed_dim, multires=multires)

        self.label_embedding = nn.Embedding(num_voxel_labels, embed_dim)

        self.fusion_proj = nn.Linear(embed_dim * 2, embed_dim) 

    def encode_coordinates(self, coords: torch.Tensor):
        """
        Encodes voxel coordinates into per-voxel features.
        Args:
            coords: [N, 3] tensor of (x, y, z) integer coordinates.
        Returns:
            features: [N, embed_dim] tensor.
        """
        if self.pos_encoding == "embedding":
            x = self.coord_embed_x(coords[:, 0])
            y = self.coord_embed_y(coords[:, 1])
            z = self.coord_embed_z(coords[:, 2])

            return self.embed_proj(torch.cat([x, y, z], dim=-1))

        # NeRF path (every non-"embedding" mode, matching __init__).
        # block_size is None unless add_block_embed was requested; dividing by
        # `None - 1` used to raise TypeError, so fall back to the grid resolution.
        extent = self.block_size if self.block_size is not None else self.resolution
        coords = coords.float() / (extent - 1) - 0.5
        return self.pe_nerf(coords)

    def forward(self, voxel_coords: torch.Tensor, voxel_labels: torch.Tensor,):
        """
        Encodes voxel coordinates and type labels into fused features.

        Args:
            voxel_coords: [N_total, 3] tensor of (x, y, z) coordinates, or
                [N_total, 4] in which case the leading column is dropped and
                only the last three columns are encoded.
            voxel_labels: [N_total] integer tensor of type labels, each in
                [0, num_voxel_labels).

        Returns:
            final_voxel_features: [N_total, D] tensor.
        """
        if voxel_coords.shape[-1] == 4:
            spatial_coords = voxel_coords[..., 1:]
        else:
            spatial_coords = voxel_coords

        coord_feats = self.encode_coordinates(spatial_coords)
        label_feats = self.label_embedding(voxel_labels)

        combined_feats = torch.cat([coord_feats, label_feats], dim=-1)
        return self.fusion_proj(combined_feats)


class VoxelFeatureEncoder_active_pointnet(nn.Module):
    """PointNet-based voxel encoder with optional label fusion.

    Geometry features come from a ``LocalPoolPointnet``. When the module is
    built with ``add_label=True`` and labels are supplied at call time, a
    16-dimensional label embedding is concatenated and passed through a small
    MLP that maps back to ``out_channels``.
    """

    def __init__(self, in_channels, hidden_dim, out_channels, scatter_type, n_blocks, resolution=64, add_label=False):
        """
        Args:
            in_channels, hidden_dim, out_channels, scatter_type, n_blocks:
                forwarded unchanged to ``LocalPoolPointnet``.
            resolution: default ``res`` used by ``forward`` when none is given.
            add_label: create the label embedding (3 entries, 16 dims) and the
                fusion MLP.
        """
        super().__init__()
        self.pointnet = LocalPoolPointnet(
            in_channels=in_channels, 
            out_channels=out_channels, 
            hidden_dim=hidden_dim, 
            n_blocks=n_blocks,
            scatter_type=scatter_type,
        )

        self.add_label = add_label
        self.resolution = resolution
        if add_label:
            self.label_emb=nn.Embedding(3, 16)
            self.fusion_mlp = nn.Sequential(
                nn.Linear(out_channels + 16, out_channels),
                nn.GELU(),
                nn.Linear(out_channels, out_channels)
            )

    def forward(self, p, sparse_coords, res=None, bbox_size=(-0.5, 0.5), voxel_label=None,):
        """
        Args:
            p, sparse_coords: passed through to the point network.
            res: resolution handed to the point network; defaults to
                ``self.resolution``.
            bbox_size: passed through to the point network.
            voxel_label: optional integer labels in [0, 3), one per output
                feature row (assumed -- confirm against the caller).
        Returns:
            The point-network features, fused with label embeddings when the
            module was built with ``add_label=True`` and ``voxel_label`` is given.
        """
        if res is None:
            res = self.resolution

        geo_feats = self.pointnet(p, sparse_coords, res=res, bbox_size=bbox_size)

        # The label layers exist only when add_label=True; without this gate a
        # caller that always passes labels hit AttributeError on `label_emb`.
        if not self.add_label or voxel_label is None:
            return geo_feats

        label_feats = self.label_emb(voxel_label)
        return self.fusion_mlp(torch.cat([geo_feats, label_feats], dim=-1))


class ConnectionHead(nn.Module):
    """Two-layer MLP head: Linear -> GELU (tanh approximation) -> Linear.

    The hidden width is ``int(channels * mlp_ratio)``.
    """

    def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_width = int(channels * mlp_ratio)
        stages = [
            nn.Linear(channels, hidden_width),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden_width, out_channels),
        ]
        # Kept as a Sequential named `mlp` so parameter names stay mlp.0.* / mlp.2.*
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        projected = self.mlp(x)
        return projected


# class ConnectionHead(nn.Module):
#     def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
#         super().__init__()
#         hidden_dim = int(channels * mlp_ratio)
#         self.input_norm = nn.LayerNorm(channels)
        
#         self.mlp = nn.Sequential(
#             nn.Linear(channels, hidden_dim),
#             nn.LayerNorm(hidden_dim), 
#             nn.GELU(approximate="tanh"),
#             nn.Linear(hidden_dim, out_channels),
#         )

#     def forward(self, x):
#         x = self.input_norm(x)
#         return self.mlp(x)