import torch
import torch.nn as nn
from typing import *
import torch.nn.functional as F
from .base import SparseTransformerBase, SparseTransformerCrossBase
from ...modules import sparse as sp
from ...modules.sparse import SparseTensor
from ...modules.sparse.linear import SparseLinear
from ...modules.sparse.nonlinearity import SparseGELU
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from .utils import SparseResBlock3d
from ...modules.utils import DiagonalGaussianDistribution
import random
from ...modules.sparse.transformer.blocks import SparseTransformerCrossBlock
from ...modules.sparse.attention import SparseMultiHeadAttention
from ...modules.transformer import AbsolutePositionEmbedder
class SparseOccHead(nn.Module):
    """
    Two-layer sparse MLP head: Linear -> GELU(tanh) -> Linear.

    In this file it serves as the ``pruning_head`` of the subdivide blocks; the
    callers apply ``torch.sigmoid`` to its output themselves.

    Args:
        channels: width of the input features.
        out_channels: width of the output features.
        mlp_ratio: hidden width is ``int(channels * mlp_ratio)``.
    """

    def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_channels = int(channels * mlp_ratio)
        # The attribute name and layer order define the state_dict keys
        # (mlp.0 / mlp.2); keep them stable for checkpoint compatibility.
        layers = [
            SparseLinear(channels, hidden_channels),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden_channels, out_channels),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Apply the MLP to ``x`` and return the resulting sparse tensor."""
        return self.mlp(x)

class SparseFFN(nn.Module):
    """
    Sparse feed-forward network: expand by ``mlp_ratio``, GELU(tanh), project back.

    Args:
        channels: width of the input features.
        out_channels: width of the output features.
        mlp_ratio: multiplier for the hidden width (truncated to ``int``).
    """

    def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        width = int(channels * mlp_ratio)
        expand = SparseLinear(channels, width)
        activation = SparseGELU(approximate="tanh")
        project = SparseLinear(width, out_channels)
        # Registered as one Sequential named ``mlp`` so checkpoint keys stay mlp.0 / mlp.2.
        self.mlp = nn.Sequential(expand, activation, project)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Run the feed-forward stack on ``x``."""
        return self.mlp(x)
    
class SparseEdgeHead(nn.Module):
    """
    Two-layer sparse MLP head (Linear -> GELU(tanh) -> Linear).

    ``VoxelVAE`` uses an instance of this class (``vtx_head_64``) to score the
    voxels of the expanded latent.

    Args:
        channels: input feature width.
        out_channels: output feature width.
        mlp_ratio: hidden width multiplier.
    """

    def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        inner = int(channels * mlp_ratio)
        # Keep ``mlp`` and its layer indices unchanged: they are checkpoint keys.
        self.mlp = nn.Sequential(
            SparseLinear(channels, inner),
            SparseGELU(approximate="tanh"),
            SparseLinear(inner, out_channels),
        )

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Return the head's output for every voxel of ``x``."""
        out = self.mlp(x)
        return out

class SparseErrorHead(nn.Module):
    """
    Two-layer sparse MLP head (Linear -> GELU(tanh) -> Linear).

    Args:
        channels: input feature width.
        out_channels: output feature width.
        mlp_ratio: hidden width is ``int(channels * mlp_ratio)``.
    """

    def __init__(self, channels: int, out_channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        stack = (
            SparseLinear(channels, hidden),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden, out_channels),
        )
        # Same ``mlp`` Sequential layout as the other heads so state_dict keys match.
        self.mlp = nn.Sequential(*stack)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Apply the head MLP to ``x``."""
        return self.mlp(x)
    
class SparseEncoderBlock(nn.Module):
    """
    Encoder stage operating on two sparse streams: edge features ``x`` and
    vertex features ``ctx``.

    With ``ctx`` present: ``ctx`` is refined by a sparse transformer, ``x`` is
    projected to ``model_channels`` and cross-attends to ``ctx``, and both
    streams are downsampled. Without ``ctx``: ``x`` goes through the
    transformer and is downsampled alone.

    Args:
        resolution: grid resolution of this stage (stored on the module).
        in_channels: channel width of the incoming ``x`` and ``ctx`` features.
        model_channels: working width of the transformer and cross-attention.
        out_channels: channel width produced by the downsampling blocks.
        num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
        window_size, pe_mode, use_fp16, use_checkpoint, qk_rms_norm:
            forwarded to ``SparseTransformerBase``.
        attn_first: accepted for config compatibility; not used by this block.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        attn_first: bool = True,
    ):
        super().__init__()
        self.resolution = resolution

        # Transformer applied to ctx (or to x when no ctx is given).
        self.self_attn = SparseTransformerBase(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            mlp_ratio=mlp_ratio,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )

        # Cross attention: query = x, key/value = ctx.
        self.ca_attn = SparseTransformerCrossBlock(
            channels=model_channels,
            ctx_channels=model_channels,
            model_channels=model_channels,
            num_heads=8,
            mlp_ratio=4,
            attn_mode="full",
            use_rope=False,
            window_size=8,
            use_checkpoint=False,
            qk_rms_norm=False,
        )

        self.query_proj = sp.SparseLinear(in_channels, model_channels)

        # Downsample for edge features.
        self.edge_downsample = SparseResBlock3d(
            channels=model_channels,
            out_channels=out_channels,
            downsample=True,
            upsample=False,
        )

        # Downsample for vertex features.
        self.ctx_downsample = SparseResBlock3d(
            channels=model_channels,
            out_channels=out_channels,
            downsample=True,
            upsample=False,
        )

    @staticmethod
    def _layer_norm(t: SparseTensor) -> SparseTensor:
        """Parameter-free layer norm over the feature dimension of ``t``."""
        return t.replace(F.layer_norm(t.feats, t.feats.shape[-1:]))

    def forward(
        self,
        x: SparseTensor,              # edge features
        ctx: Optional[SparseTensor],  # vertex features
    ) -> Tuple[SparseTensor, Optional[SparseTensor]]:
        """
        Args:
            x: edge features (SparseTensor).
            ctx: vertex features (SparseTensor), or None to process ``x`` alone.
        Returns:
            ``(x_out, ctx_out)`` after attention and downsampling; ``ctx_out``
            is None when ``ctx`` is None.
        """
        if ctx is None:
            x_dtype = x.dtype
            # Restore the input dtype: the transformer may run in reduced precision.
            x = self.self_attn(x).type(x_dtype)
            x = self._layer_norm(x)
            return self.edge_downsample(x), None

        ctx_dtype = ctx.dtype
        ctx = self.self_attn(ctx).type(ctx_dtype)
        ctx = self._layer_norm(ctx)

        x_dtype = x.dtype
        x = self.query_proj(x)
        x = self.ca_attn(x, ctx).type(x_dtype)
        x = self._layer_norm(x)

        x_out = self.edge_downsample(x)
        ctx_out = self.ctx_downsample(ctx)
        return x_out, ctx_out

    def convert_to_fp16(self):
        """Convert the self-attention transformer to float16."""
        self.self_attn.convert_to_fp16()

    def convert_to_fp32(self):
        """Convert the self-attention transformer back to float32."""
        self.self_attn.convert_to_fp32()


class SparseSubdivideBlock3d(nn.Module):
    """
    A 3D residual block that subdivides a sparse tensor and can optionally
    predict per-voxel occupancy used for pruning.

    Args:
        channels: channels in the inputs.
        resolution: input resolution; ``out_resolution = 2 * resolution`` names
            the conv ``indice_key`` of the subdivided grid.
        out_channels: number of output channels; defaults to ``channels`` when None.
        num_groups: the number of groups for the group norm.
        heads_num: attention heads, only used when ``using_attn`` is True.
        using_attn: use ``sp.SparseSubdivide_attn`` instead of ``sp.SparseSubdivide``.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        heads_num: int = 4,
        using_attn: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2
        # A None out_channels would break the conv layers below; fall back to channels.
        self.out_channels = channels if out_channels is None else out_channels
        self.using_attn = using_attn

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        if using_attn:
            self.sub = sp.SparseSubdivide_attn(channels, heads_num)
        else:
            self.sub = sp.SparseSubdivide()

        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

        self.pruning_head = SparseOccHead(self.out_channels, out_channels=1)

    def forward(
        self,
        x: sp.SparseTensor,
        pruning=False,
        training=True,
        threshold=0.5,
        force_no_prune=False,
    ) -> Tuple[sp.SparseTensor, Optional[sp.SparseTensor]]:
        """
        Subdivide ``x`` through a residual path and optionally prune.

        Args:
            x: input sparse tensor with ``channels`` features.
            pruning: when True, also run ``pruning_head`` and return its output.
            training: when False (and ``force_no_prune`` is False), drop voxels
                whose sigmoid occupancy is below ``threshold``.
            threshold: sigmoid-probability cut-off used at inference.
            force_no_prune: keep every voxel even at inference.

        Returns:
            ``(h, occ_prob)``; ``occ_prob`` is the raw pruning-head output, or
            None when ``pruning`` is False.
        """
        h = self.act_layers(x)
        h = self.sub(h)
        x = self.sub(x)
        h = self.out_layers(h)
        h = h + self.skip_connection(x)
        if not pruning:
            return h, None

        occ_prob = self.pruning_head(h)
        if not training and not force_no_prune:
            occ_mask = (torch.sigmoid(occ_prob.feats) >= threshold).squeeze(-1)
            h = sp.SparseTensor(feats=h.feats[occ_mask], coords=h.coords[occ_mask])
        return h, occ_prob

class SparseSubdivideBlock3d_vtx(nn.Module):
    """
    Vertex-branch variant of the 3D subdivide block.

    Same residual subdivide structure as ``SparseSubdivideBlock3d``, but
    inference-time pruning rescues the best-scoring child of any group of 8
    consecutive children in which no child passes the threshold.

    Args:
        channels: channels in the inputs.
        resolution: input resolution; ``out_resolution = 2 * resolution`` names
            the conv ``indice_key`` of the subdivided grid.
        out_channels: number of output channels; defaults to ``channels`` when None.
        num_groups: the number of groups for the group norm.
        heads_num: attention heads, only used when ``using_attn`` is True.
        using_attn: use ``sp.SparseSubdivide_attn`` instead of ``sp.SparseSubdivide``.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        heads_num: int = 4,
        using_attn: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2
        # A None out_channels would break the conv layers below; fall back to channels.
        self.out_channels = channels if out_channels is None else out_channels
        self.using_attn = using_attn

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        if using_attn:
            self.sub = sp.SparseSubdivide_attn(channels, heads_num)
        else:
            self.sub = sp.SparseSubdivide()

        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

        self.pruning_head = SparseOccHead(self.out_channels, out_channels=1)

    @staticmethod
    def _occupancy_mask(scores: torch.Tensor, threshold: float, group_size: int = 8, rescue_k: int = 1) -> torch.Tensor:
        """
        Boolean keep-mask for 1-D ``scores``, rescuing fully-rejected groups.

        Rows are viewed as consecutive groups of ``group_size``; in any group
        where nothing reaches ``threshold``, the ``rescue_k`` highest scores are
        kept anyway. If the length is not a multiple of ``group_size`` a plain
        threshold is applied.
        NOTE(review): assumes the subdivision emits each parent's children as
        one contiguous run of ``group_size`` rows — confirm for SparseSubdivide_attn.
        """
        n_points = scores.shape[0]
        if n_points % group_size != 0:
            return scores >= threshold

        scores_grouped = scores.view(n_points // group_size, group_size)
        mask_grouped = scores_grouped >= threshold
        all_failed = ~mask_grouped.any(dim=1)
        if all_failed.any():
            _, topk_indices = torch.topk(scores_grouped[all_failed], k=rescue_k, dim=1)
            failed_rows = torch.nonzero(all_failed, as_tuple=True)[0]
            rows_expanded = failed_rows.unsqueeze(1).expand(-1, rescue_k)
            mask_grouped[rows_expanded, topk_indices] = True
        return mask_grouped.view(-1)

    def forward(
        self,
        x: sp.SparseTensor,
        context: sp.SparseTensor,
        pruning=False,
        training=True,
        threshold=0.5,
        force_no_prune=False,
    ) -> Tuple[sp.SparseTensor, Optional[sp.SparseTensor]]:
        """
        Subdivide ``x`` through a residual path and optionally prune.

        Args:
            x: input sparse tensor with ``channels`` features.
            context: accepted for interface compatibility; currently unused.
            pruning: when True, also run ``pruning_head`` and return its output.
            training: when False (and ``force_no_prune`` is False), drop
                children according to ``_occupancy_mask``.
            threshold: sigmoid-probability cut-off used at inference.
            force_no_prune: keep every child even at inference.

        Returns:
            ``(h, occ_prob)``; ``occ_prob`` is the raw pruning-head output, or
            None when ``pruning`` is False.
        """
        h = self.act_layers(x)
        h = self.sub(h)
        x = self.sub(x)
        h = self.out_layers(h)
        h = h + self.skip_connection(x)
        if not pruning:
            return h, None

        occ_prob = self.pruning_head(h)
        if not training and not force_no_prune:
            scores = torch.sigmoid(occ_prob.feats).squeeze(-1)
            occ_mask = self._occupancy_mask(scores, threshold)
            h = sp.SparseTensor(feats=h.feats[occ_mask], coords=h.coords[occ_mask])
        return h, occ_prob
        
class SparseDecoderBlock_vtx(nn.Module):
    """
    Vertex decoder stage: layer-normalize the input, then upsample it.

    With ``using_subdivide`` the upsampler is a ``SparseSubdivideBlock3d_vtx``
    (which can also predict occupancy for pruning); otherwise a
    ``SparseResBlock3d`` with ``upsample=True`` producing ``model_channels // 2``
    channels is used.

    Args:
        resolution: input resolution of this stage.
        model_channels: used for the non-subdivide upsampler's output width.
        in_channels: channel width of the input features.
        out_channels: output width of the subdivide upsampler.
        num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
        window_size, pe_mode, use_checkpoint, qk_rms_norm:
            accepted for config compatibility; unused by this block.
        use_fp16: convert the upsampler's parameters to float16 at construction.
        using_subdivide: choose the subdivide upsampler (see above).
        using_attn: forwarded to the subdivide upsampler.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        using_subdivide=False,
        using_attn=False,
    ):
        super().__init__()
        self.resolution = resolution
        self.using_subdivide = using_subdivide
        self.using_attn = using_attn

        if using_subdivide:
            self.upsample = SparseSubdivideBlock3d_vtx(
                channels=in_channels,
                resolution=resolution,
                out_channels=out_channels,
                num_groups=32,
                using_attn=using_attn,
            )
        else:
            # Upsample using coordinate indices cached by the matching downsample.
            self.upsample = SparseResBlock3d(
                channels=in_channels,
                out_channels=model_channels // 2,
                downsample=False,
                upsample=True,
            )

        if use_fp16:
            self.convert_to_fp16()

    def forward(self, x: sp.SparseTensor, context: sp.SparseTensor, pruning: bool = False, training: bool = False, threshold: float = 0.5, force_no_prune=False,):
        """
        Normalize ``x`` and upsample it.

        Returns:
            ``(h, occ_prob)``. ``occ_prob`` is only non-None on the subdivide
            path with ``pruning=True``; the non-subdivide path also returns a
            pair so callers can always unpack two values.
        """
        h = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
        if not self.using_subdivide:
            return self.upsample(h), None
        if pruning:
            h, occ_prob = self.upsample(h, context, pruning=pruning, training=training, threshold=threshold, force_no_prune=force_no_prune,)
            return h, occ_prob
        h, _ = self.upsample(h, context, pruning=pruning, training=training)
        return h, None

    def convert_to_fp16(self):
        """Convert the upsampler's parameters to float16.

        This block owns no transformer or output projection; ``self.upsample``
        is its only parameterized submodule.
        NOTE(review): forward() does not cast its input, so callers must supply
        features in the matching dtype.
        """
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """Convert the upsampler's parameters back to float32."""
        self.upsample.apply(convert_module_to_f32)

class SparseDecoderBlock(nn.Module):
    """
    Decoder stage: layer-normalize the input, then upsample it.

    With ``using_subdivide`` the upsampler is a ``SparseSubdivideBlock3d``
    (which can also predict occupancy for pruning); otherwise a
    ``SparseResBlock3d`` with ``upsample=True`` producing ``model_channels // 2``
    channels is used.

    Args:
        resolution: input resolution of this stage.
        model_channels: channel width of the input features.
        in_channels: accepted for config compatibility; unused by this block.
        out_channels: output width of the subdivide upsampler.
        num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
        window_size, pe_mode, use_checkpoint, qk_rms_norm:
            accepted for config compatibility; unused by this block.
        use_fp16: convert the upsampler's parameters to float16 at construction.
        using_subdivide: choose the subdivide upsampler (see above).
        using_attn: forwarded to the subdivide upsampler.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        using_subdivide=False,
        using_attn=False,
    ):
        super().__init__()
        self.resolution = resolution
        self.using_subdivide = using_subdivide
        self.using_attn = using_attn

        if using_subdivide:
            self.upsample = SparseSubdivideBlock3d(
                channels=model_channels,
                resolution=resolution,
                out_channels=out_channels,
                num_groups=32,
                using_attn=using_attn,
            )
        else:
            # Upsample using coordinate indices cached by the matching downsample.
            self.upsample = SparseResBlock3d(
                channels=model_channels,
                out_channels=model_channels // 2,
                downsample=False,
                upsample=True,
            )

        if use_fp16:
            self.convert_to_fp16()

    def forward(self, x: sp.SparseTensor, pruning: bool = False, training: bool = False, threshold: float = 0.5, force_no_prune=False,):
        """
        Normalize ``x`` and upsample it.

        Returns:
            On the subdivide path, ``(h, occ_prob)`` where ``occ_prob`` is None
            unless ``pruning`` is True. On the non-subdivide path, just ``h``.
            NOTE(review): this asymmetry is kept for existing callers;
            ``SparseDecoderBlock_vtx`` returns a pair on both paths.
        """
        h = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
        if not self.using_subdivide:
            return self.upsample(h)
        if pruning:
            h, occ_prob = self.upsample(h, pruning=pruning, training=training, threshold=threshold, force_no_prune=force_no_prune,)
            return h, occ_prob
        h, _ = self.upsample(h, pruning=pruning, training=training)
        return h, None

    def convert_to_fp16(self):
        """Convert the upsampler's parameters to float16.

        This block owns no transformer or output projection; ``self.upsample``
        is its only parameterized submodule.
        NOTE(review): forward() does not cast its input, so callers must supply
        features in the matching dtype.
        """
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """Convert the upsampler's parameters back to float32."""
        self.upsample.apply(convert_module_to_f32)

class VoxelVAE(nn.Module):
    """
    Sparse voxel VAE with a transformer encoder and a multi-level vertex decoder.

    ``encode`` maps a sparse input to a diagonal-Gaussian latent. ``decode``
    expands the latent, selects vertex voxels at the base level, and then runs
    one ``SparseDecoderBlock_vtx`` + ``SparseTransformerCrossBase`` pair per
    entry of ``decoder_blocks_vtx``. Edge-related arguments are accepted for
    interface compatibility but are not used by the decoder.

    Args:
        in_channels: unused; the encoder input width comes from
            ``encoder_blocks[0]['in_channels']``.
        encoder_blocks: encoder configs (only the first is used); keys read:
            in_channels, model_channels, num_blocks, num_heads, out_channels.
        decoder_blocks_vtx: one config per decoder level; keys read: resolution,
            model_channels, in_channels, out_channels, num_blocks, num_heads,
            context_channels.
        decoder_blocks_edge: stored on the module only.
        num_heads: unused here; per-block configs carry their own head counts.
        num_head_channels, mlp_ratio, attn_mode, window_size, pe_mode, use_fp16,
        use_checkpoint, qk_rms_norm: shared settings for the sparse transformers.
        latent_dim: channel width of the Gaussian latent.
        using_subdivide, using_attn: forwarded to the vertex decoder blocks.
        ctx_channels, attn_first: accepted for config compatibility; unused.
        pred_direction: stored on the module.

    Raises:
        ValueError: if ``encoder_blocks`` or ``decoder_blocks_vtx`` is empty/None.
    """
    def __init__(
        self,
        # Core architecture parameters
        in_channels: int = 64,
        encoder_blocks: Optional[List[Dict]] = None,
        decoder_blocks_vtx: Optional[List[Dict]] = None,
        decoder_blocks_edge: Optional[List[Dict]] = None,

        # Shared transformer parameters
        num_heads: int = 8,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: str = "swin",
        window_size: int = 8,
        pe_mode: str = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = True,
        qk_rms_norm: bool = False,

        latent_dim: int = 8,
        using_subdivide: bool = True,
        using_attn: bool = False,

        ctx_channels: int = 768,
        attn_first=True,

        pred_direction=False,
    ):
        super().__init__()
        # None defaults avoid a shared mutable default; both lists are required
        # because their first entries are read below.
        if not encoder_blocks:
            raise ValueError("VoxelVAE requires at least one config in `encoder_blocks`.")
        if not decoder_blocks_vtx:
            raise ValueError("VoxelVAE requires at least one config in `decoder_blocks_vtx`.")
        if decoder_blocks_edge is None:
            decoder_blocks_edge = []

        self.latent_dim = latent_dim
        self.using_subdivide = using_subdivide
        self.using_attn = using_attn
        self.decoder_blocks_vtx = decoder_blocks_vtx
        self.decoder_blocks_edge = decoder_blocks_edge
        self.pred_direction = pred_direction

        # Settings shared by every sparse transformer / decoder block below.
        shared = dict(
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )

        enc_config = encoder_blocks[0]
        self.encoder = SparseTransformerBase(
            in_channels=enc_config['in_channels'],
            model_channels=enc_config['model_channels'],
            num_blocks=enc_config['num_blocks'],
            num_heads=enc_config['num_heads'],
            **shared,
        )

        self.latent_expander = SparseTransformerBase(
            in_channels=latent_dim,
            model_channels=512,
            num_blocks=8,
            num_heads=8,
            **shared,
        )

        self.vtx_proj = sp.SparseLinear(512, decoder_blocks_vtx[0]['in_channels'])

        # Projects encoder features to mean and log-variance of the latent.
        self.out_layer = sp.SparseLinear(enc_config['out_channels'], 2 * latent_dim)

        self.vtx_head_64 = SparseEdgeHead(512, out_channels=1)

        self.decoder_vtx = nn.ModuleList()
        for config in decoder_blocks_vtx:
            self.decoder_vtx.append(
                SparseDecoderBlock_vtx(
                    resolution=config['resolution'],
                    model_channels=config['model_channels'],
                    in_channels=config['in_channels'],
                    out_channels=config['out_channels'],
                    num_blocks=config['num_blocks'],
                    num_heads=config['num_heads'],
                    using_subdivide=using_subdivide,
                    using_attn=using_attn,
                    **shared,
                )
            )

        # Registration order (decoder_vtx_ca before latent_proj) is kept so the
        # parameter order matches existing checkpoints / optimizer states.
        self.decoder_vtx_ca = nn.ModuleList()
        self.latent_proj = nn.ModuleList()
        for config in decoder_blocks_vtx:
            self.latent_proj.append(sp.SparseLinear(latent_dim, config['context_channels']))
            self.decoder_vtx_ca.append(
                SparseTransformerCrossBase(
                    in_channels=config['out_channels'],
                    model_channels=config['model_channels'],
                    context_channels=config['context_channels'],
                    num_blocks=config['num_blocks'],
                    num_heads=config['num_heads'],
                    **shared,
                )
            )

        if use_fp16:
            self.convert_to_fp16()

    def encode(self, x: sp.SparseTensor, ctx: Optional[SparseTensor] = None, sample_posterior=True, prune_active=False,) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` into a latent sparse tensor.

        Args:
            x: input sparse tensor.
            ctx: accepted for interface compatibility; unused.
            sample_posterior: sample from the posterior if True, else use its mode.
            prune_active: if True, return only the raw ``2 * latent_dim`` output.

        Returns:
            ``(z, posterior)`` where ``z`` shares ``x``'s coordinates, or the raw
            output tensor alone when ``prune_active`` is True.
        """
        h = self.encoder(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)

        if prune_active:
            return h

        posterior = DiagonalGaussianDistribution(h.feats, feat_dim=-1)
        z = posterior.sample() if sample_posterior else posterior.mode()
        return h.replace(z), posterior

    @staticmethod
    def _flatten_coords(coords_4d: torch.Tensor) -> torch.Tensor:
        """
        Encode each row of integer coords ``(c0, c1, c2, c3)`` as one int64 key.

        Column 0 (presumably the batch index) gets the largest weight; each of
        columns 1-3 gets 10 bits (base 1024), so they must lie in ``[0, 1024)``
        for keys to be unique.
        """
        coords = coords_4d.long()
        return (coords[:, 0] * (1024 * 1024 * 1024)
                + coords[:, 1] * (1024 * 1024)
                + coords[:, 2] * 1024
                + coords[:, 3])

    @staticmethod
    def _coords_of(voxels) -> torch.Tensor:
        """Return ``voxels.coords`` if present, else treat ``voxels`` as a coord tensor."""
        return voxels.coords if hasattr(voxels, 'coords') else voxels

    @staticmethod
    def _inference_vertex_mask(scores: torch.Tensor, threshold: float, min_keep: int = 2) -> torch.Tensor:
        """Threshold ``scores``; if nothing passes, force-keep the top ``min_keep`` entries."""
        vertex_mask = scores > threshold
        if not vertex_mask.any() and scores.numel() > 0:
            k = min(min_keep, scores.numel())
            print(f"Warning: No points passed threshold {threshold}. Forcing top {k} points.")
            _, top_indices = torch.topk(scores, k=k)
            vertex_mask[top_indices] = True
        return vertex_mask

    def decode(self, latent_: sp.SparseTensor, gt_vertex_voxels_list: List[sp.SparseTensor],
               gt_edge_voxels_list: List[sp.SparseTensor], training=True, pruning=True, sample_ratio=0.,
               inference_threshold=0.2, vis_last_layer: bool = True) -> List[Dict]:
        """
        Multi-resolution vertex decoding.

        Args:
            latent_: latent SparseTensor produced by ``encode``.
            gt_vertex_voxels_list: ground-truth vertex voxels (SparseTensors or
                coordinate tensors), one for the base level plus one per decoder
                level; only read when ``training`` is True.
            gt_edge_voxels_list: unused.
            training: use ground truth to select/prune voxels instead of the
                learned heads.
            pruning, sample_ratio: unused.
            inference_threshold: threshold for the learned selection at inference.
            vis_last_layer: at inference, skip pruning in the last decoder level.

        Returns:
            List of dicts: one for the base level, then one ``{'vertex': ...}``
            per decoder level.
        """
        latent = self.latent_expander(latent_)
        results = []

        # ---- Level 0: select the vertex voxels of the expanded latent ----
        vtx_probs = self.vtx_head_64(latent)
        if training:
            # Ground-truth vertex coordinates decide which voxels are kept.
            gt_vertex_coords = self._coords_of(gt_vertex_voxels_list[0])
            vertex_mask = torch.isin(
                self._flatten_coords(latent.coords),
                self._flatten_coords(gt_vertex_coords),
            )
        else:
            # NOTE(review): raw head outputs are compared to the threshold here,
            # while the subdivide blocks threshold sigmoid(logits); confirm which
            # scale `inference_threshold` is intended for at this level.
            vertex_mask = self._inference_vertex_mask(vtx_probs.feats[:, 0], inference_threshold)
            print('vertex_mask.sum()', vertex_mask.sum())

        vertex_x = sp.SparseTensor(feats=latent.feats[vertex_mask], coords=latent.coords[vertex_mask])
        vertex_x = self.vtx_proj(vertex_x)

        if training:
            results.append({
                'vtx_coords_3d': vtx_probs.coords,
                'vtx_feats': vtx_probs.feats,
                'vertex_mask': vertex_mask,
                'vertex_gt_coords': gt_vertex_coords,
            })
        else:
            results.append({
                'coords': vtx_probs.coords[..., 1:],
                'feats': vtx_probs.feats,
                'sp_tensor': vtx_probs,
                'vertex_mask': vertex_mask,
                'vtx_sp': vertex_x,
            })

        # ---- Upsampling levels ----
        num_levels = len(self.decoder_vtx)
        levels = zip(self.decoder_vtx, self.decoder_vtx_ca, self.latent_proj)
        for i, (decoder_block, cross_block, context_proj) in enumerate(levels):
            is_last_layer = i == num_levels - 1

            if not training:
                force_no_prune = is_last_layer and vis_last_layer
                vertex_x, vertex_occ_probs = decoder_block(
                    vertex_x, context=None, pruning=True, training=training,
                    threshold=inference_threshold, force_no_prune=force_no_prune,
                )
                vertex_x = cross_block(x=vertex_x, context=context_proj(latent_))

                print('vertex_x.coords.shape', i, vertex_x.coords.shape)

                vertex_entry = {
                    'coords': vertex_x.coords[..., 1:],
                    'coords_4d': vertex_x.coords,
                    'feats': vertex_x.feats,
                    'sp_tensor': vertex_x,
                }
                if is_last_layer:
                    vertex_entry['occ_probs'] = vertex_occ_probs.feats
                results.append({'vertex': vertex_entry})
                continue

            # Training: keep only predicted voxels that exist in the ground truth.
            vertex_x, vertex_occ_probs = decoder_block(vertex_x, context=None, pruning=True, training=training,)
            vertex_x = cross_block(x=vertex_x, context=context_proj(latent_))

            gt_vertex_coords = self._coords_of(gt_vertex_voxels_list[i + 1])
            vertex_isin_mask = torch.isin(
                self._flatten_coords(vertex_x.coords),
                self._flatten_coords(gt_vertex_coords),
            )
            vertex_prune_labels = vertex_isin_mask.float()
            keep_mask = vertex_prune_labels.bool()

            vertex_x = sp.SparseTensor(
                feats=vertex_x.feats[keep_mask],
                coords=vertex_x.coords[keep_mask],
            )

            results.append({
                'vertex': {
                    'coords': vertex_x.coords, 'feats': vertex_x.feats, 'occ_probs': vertex_occ_probs.feats,
                    'occ_coords': vertex_occ_probs.coords, 'prune_labels': vertex_prune_labels,
                    'sp_tensor': vertex_x, 'coords_4d': vertex_x.coords,
                    'gt_coords': gt_vertex_coords, 'pred_mask': keep_mask,
                },
            })

        return results

    def forward(self, sparse_input, gt_vertex_voxels_list=None, gt_edge_voxels_list=None, training=True, sample_ratio=0.):
        """Encode ``sparse_input`` and decode it; returns ``(results, posterior, latent)``."""
        latent_64, posterior = self.encode(sparse_input)
        results = self.decode(
            latent_64,
            gt_vertex_voxels_list=gt_vertex_voxels_list,
            gt_edge_voxels_list=gt_edge_voxels_list,
            training=training,
            sample_ratio=sample_ratio
        )

        return results, posterior, latent_64

    def _precision_submodules(self):
        """Sub-networks whose own ``convert_to_fp16/fp32`` hooks should be invoked."""
        return (self.encoder, self.latent_expander, self.decoder_vtx, self.decoder_vtx_ca)

    def convert_to_fp16(self):
        """Convert all components to float16"""
        for module in self._precision_submodules():
            module.apply(lambda m: m.convert_to_fp16() if hasattr(m, 'convert_to_fp16') else None)
        convert_module_to_f16(self.latent_proj)
        if hasattr(self, 'skip_projs'):
            self.skip_projs.apply(lambda m: convert_module_to_f16(m))

    def convert_to_fp32(self):
        """Convert all components to float32"""
        for module in self._precision_submodules():
            module.apply(lambda m: m.convert_to_fp32() if hasattr(m, 'convert_to_fp32') else None)
        convert_module_to_f32(self.latent_proj)
        if hasattr(self, 'skip_projs'):
            self.skip_projs.apply(lambda m: convert_module_to_f32(m))

