
import torch

class Model(torch.nn.Module):
    """Reference ball query: gather indices of points near each query centre.

    For every query ``new_xyz[b, m]`` the first ``nsample`` points of ``xyz``
    (in ascending index order) whose squared distance ``d2`` satisfies
    ``d2 == 0 or min_radius**2 <= d2 < max_radius**2`` are selected.

    Slots not filled by a hit repeat the first hit. This mirrors the CUDA
    kernel, which fills every slot with the first valid index before writing
    the rest. A query with no hit gets index 0 in every slot; the CUDA kernel
    leaves those slots undefined.
    """

    def __init__(self, min_radius=0.0, max_radius=0.1, nsample=32):
        super().__init__()
        self.min_radius = min_radius  # inner radius (inclusive bound on d2)
        self.max_radius = max_radius  # outer radius (exclusive bound on d2)
        self.nsample = nsample        # number of neighbour slots per query

    def forward(self, new_xyz, xyz):
        """Return neighbour indices for each query point.

        Args:
            new_xyz: (B, M, 3) query centres.
            xyz: (B, N, 3) candidate points.

        Returns:
            idx: (B, M, nsample) int32 tensor of indices into the N axis of
            ``xyz``.
        """
        # (B, M, 1, 3) - (B, 1, N, 3) -> (B, M, N, 3) -> squared distances (B, M, N)
        diff = new_xyz.unsqueeze(2) - xyz.unsqueeze(1)
        dist2 = torch.sum(diff ** 2, dim=-1)
        # Take shapes from the broadcast result so mismatched-but-broadcastable
        # batch sizes are not silently truncated.
        B, M, N = dist2.shape

        min_r2 = self.min_radius ** 2
        max_r2 = self.max_radius ** 2
        # Same predicate as the CUDA kernel: `d2 == 0 || (d2 >= min && d2 < max)`.
        valid = ((dist2 >= min_r2) & (dist2 < max_r2)) | (dist2 == 0)

        idx = torch.zeros(B, M, self.nsample, dtype=torch.long, device=dist2.device)
        if self.nsample == 0 or N == 0:
            # Nothing to select (or nothing to select from): all-zero result.
            return idx.int()

        # Key each valid position by its own index and each invalid one by the
        # sentinel N. The k smallest keys are then the first k valid indices in
        # ascending order, padded with N when fewer than k points are valid.
        positions = torch.arange(N, device=dist2.device)
        keys = torch.where(valid, positions, torch.full_like(positions, N))
        k = min(self.nsample, N)
        chosen, _ = torch.topk(keys, k, dim=-1, largest=False, sorted=True)

        # Every slot not taken by a real hit repeats the first hit (CUDA behaviour).
        first = chosen[..., :1]  # (B, M, 1); equals N when the row has no hit
        idx.copy_(first.expand(B, M, self.nsample))
        idx[..., :k] = torch.where(chosen == N, first, chosen)
        # Rows with no hit still hold the sentinel; map them to index 0.
        idx.masked_fill_(idx == N, 0)
        return idx.int()

def get_init_inputs():
    """Return the constructor arguments for Model as [min_radius, max_radius, nsample]."""
    min_radius, max_radius, nsample = 0.0, 0.2, 8
    return [min_radius, max_radius, nsample]

def get_inputs():
    """Return random [new_xyz, xyz] inputs: (2, 128, 3) queries and (2, 512, 3) points."""
    batch, num_queries, num_points = 2, 128, 512
    queries = torch.rand(batch, num_queries, 3)  # drawn first, as new_xyz
    points = torch.rand(batch, num_points, 3)    # drawn second, as xyz
    return [queries, points]
