
import torch

class Model(torch.nn.Module):
    """
    Voxel Pool Stereo - Same algorithm as Voxel Pool with optimized CUDA implementation.
    Python reference implementation is identical to voxel_pool.py.

    Scatter-adds per-point feature vectors into a dense (B, Y, X, C) grid.
    Points whose integer coordinates fall outside
    [0, num_voxel_x) x [0, num_voxel_y) x [0, num_voxel_z) are dropped.
    The z coordinate is only used for that bounds check; it does not select
    an output cell.
    """

    def __init__(self, num_voxel_x, num_voxel_y, num_voxel_z):
        super().__init__()
        self.num_voxel_x = num_voxel_x
        self.num_voxel_y = num_voxel_y
        self.num_voxel_z = num_voxel_z

    def forward(self, geom_xyz, input_features, batch_size=1):
        """
        geom_xyz: (B*N, 3) - int coords [x, y, z]
        input_features: (B*N, C)
        batch_size: number of batches B the rows are split into (default 1).
            Rows are assigned to batches in contiguous groups of N rows.
        Returns:
            output_features: (B, Y, X, C)
            pos_memo: (B*N, 3) int32 - [batch_idx, y, x] for valid points,
                all zeros for points that were dropped
        Raises:
            ValueError: if the inputs cannot be split into `batch_size` groups
                or the two tensors disagree on the number of rows.
        """
        return self._forward_impl(geom_xyz, input_features, batch_size=batch_size)

    def _forward_impl(self, geom_xyz, input_features, batch_size=1, num_points=None):
        """
        Pool `input_features` into the voxel grid.

        num_points: rows per batch; defaults to `geom_xyz.shape[0] // batch_size`.
        See `forward` for the remaining arguments and the return values.
        """
        total_points = geom_xyz.shape[0]
        if input_features.shape[0] != total_points:
            raise ValueError(
                "geom_xyz and input_features must have the same number of rows, "
                f"got {total_points} and {input_features.shape[0]}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if num_points is None:
            num_points = total_points // batch_size
        # Every row must map to a batch index below batch_size; otherwise the
        # scatter below would index past the end of the output buffer.
        if total_points > 0 and (
            num_points < 1 or (total_points - 1) // num_points >= batch_size
        ):
            raise ValueError(
                f"cannot split {total_points} points into {batch_size} batches "
                f"of {num_points} points"
            )

        # Plain Python ints so every comparison and product below stays int32
        # (int64 comparisons are not supported on NPU).
        X, Y, Z = int(self.num_voxel_x), int(self.num_voxel_y), int(self.num_voxel_z)
        C = input_features.shape[1]
        device = geom_xyz.device

        indices = torch.arange(total_points, device=device, dtype=torch.int32)
        # max(..., 1) only matters for empty input, where num_points is 0.
        b = indices // max(num_points, 1)

        x = geom_xyz[:, 0].int()
        y = geom_xyz[:, 1].int()
        z = geom_xyz[:, 2].int()

        valid = (x >= 0) & (x < X) & (y >= 0) & (y < Y) & (z >= 0) & (z < Z)

        # [batch_idx, y, x] per point; rows of dropped points are zeroed.
        pos_memo = torch.stack((b, y, x), dim=1)
        pos_memo.masked_fill_(~valid.unsqueeze(1), 0)

        feats = input_features[valid]
        # Row index into the flattened (B * Y * X, C) output buffer.
        flat_idx = b[valid] * (Y * X) + y[valid] * X + x[valid]

        output = torch.zeros(batch_size * Y * X, C, device=device, dtype=input_features.dtype)
        # index_add_ accumulates, so points sharing a cell are summed.
        output.index_add_(0, flat_idx, feats)

        output_features = output.view(batch_size, Y, X, C)

        return output_features, pos_memo

def get_init_inputs():
    """Return three grid sizes; presumably the x, y, z voxel counts for ``Model``."""
    num_voxel_x = 10
    num_voxel_y = 10
    num_voxel_z = 1
    return [num_voxel_x, num_voxel_y, num_voxel_z]

def get_inputs():
    """
    Build one random sample: int32 coordinates and float features.

    Returns:
        [geom_xyz, feats] where geom_xyz is (100, 3) int32 with values drawn
        from [0, 10) and feats is (100, 32) drawn from a standard normal.
    """
    batch, points_per_batch, channels = 1, 100, 32
    rows = batch * points_per_batch
    # Coordinates are drawn before features so the RNG stream order is fixed.
    coords = torch.randint(0, 10, (rows, 3)).int()
    features = torch.randn(rows, channels)
    return [coords, features]
