
import torch

class Model(torch.nn.Module):
    """Reduce per-point features into the unique voxels named by integer coordinates.

    Points sharing the same coordinate row are grouped together and their
    features are combined channel-wise with ``reduce_type``
    (``'mean'``, ``'max'`` or ``'sum'``).
    """

    def __init__(self, reduce_type='mean'):
        """Store the reduction mode.

        Args:
            reduce_type: one of ``'mean'``, ``'max'`` or ``'sum'``.

        Raises:
            AssertionError: if ``reduce_type`` is not one of the supported modes.
        """
        super().__init__()
        assert reduce_type in ['mean', 'max', 'sum']
        self.reduce_type = reduce_type

    @staticmethod
    def _scatter_max(feats, inverse_indices, num_voxels):
        """Return the per-voxel, per-channel maximum of ``feats`` as (num_voxels, C)."""
        num_channels = feats.shape[1]
        out = feats.new_zeros((num_voxels, num_channels))
        try:
            # Only the attribute lookup is guarded, so genuine errors raised
            # by the reduction itself are not swallowed by the fallback.
            scatter_reduce = out.scatter_reduce
        except AttributeError:
            # Old PyTorch without scatter_reduce: loop over voxels instead.
            # Every voxel index comes from `unique`, so each mask selects at
            # least one row.
            for voxel in range(num_voxels):
                out[voxel] = feats[inverse_indices == voxel].max(dim=0)[0]
            return out
        # include_self=False makes the result depend only on the scattered
        # values, so no sentinel (such as -1e9) can leak into the output when
        # all features of a voxel are smaller than the sentinel.
        return scatter_reduce(
            0,
            inverse_indices.unsqueeze(1).expand(-1, num_channels),
            feats,
            reduce='amax',
            include_self=False,
        )

    def forward(self, feats, coors):
        """Group points by coordinate row and reduce their features.

        Negative coordinates are NOT filtered out: every distinct row of
        ``coors``, including rows with negative entries, becomes its own voxel.

        Args:
            feats: (N, C) point features.
            coors: (N, D) integer coordinates.

        Returns:
            reduced_feats: (M, C) reduced features, one row per unique coordinate.
            out_coors: (M, D) unique coordinate rows in sorted order.
            coors_map: (N,) int32 index of the voxel each input point fell into.
            reduce_count: (M,) int32 number of points in each voxel.
        """
        # Counts come straight from `unique` as exact integers instead of being
        # accumulated in feats.dtype, which loses precision for low-precision
        # feature dtypes such as float16.
        unique_coors, inverse_indices, counts = torch.unique(
            coors, sorted=True, return_inverse=True, return_counts=True, dim=0
        )
        num_voxels = unique_coors.shape[0]

        if self.reduce_type == 'max':
            reduced_feats = self._scatter_max(feats, inverse_indices, num_voxels)
        else:
            # 'sum' and 'mean' share the same accumulation step.
            reduced_feats = feats.new_zeros((num_voxels, feats.shape[1]))
            reduced_feats.index_add_(0, inverse_indices, feats)
            if self.reduce_type == 'mean':
                # clamp guards against division by zero.
                reduced_feats = reduced_feats / counts.unsqueeze(1).clamp(min=1)

        return reduced_feats, unique_coors, inverse_indices.int(), counts.int()

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``."""
    reduce_type = 'mean'
    init_args = [reduce_type]
    return init_args

def get_inputs():
    """Build a random example input pair for ``Model.forward``.

    Returns:
        A two-element list ``[feats, coors]`` where ``feats`` is a (100, 32)
        tensor drawn from ``torch.randn`` and ``coors`` is a (100, 3) integer
        tensor with values in 0..9 (the upper bound 10 is exclusive).
    """
    num_points = 100
    num_channels = 32
    num_dims = 3
    # Features are sampled before coordinates so the RNG stream is consumed
    # in a fixed order.
    point_feats = torch.randn(num_points, num_channels)
    voxel_coors = torch.randint(0, 10, (num_points, num_dims))
    return [point_feats, voxel_coors]
