
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    """Reference (pure PyTorch) masked im2col.

    For every masked location ``(h, w)`` the K x K neighbourhood whose top-left
    corner is ``(h - padding, w - padding)`` is copied into one column of the
    output. Taps that fall outside the image read as zero.
    """

    def __init__(self, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.kernel_size = kernel_size  # square kernel side K
        self.padding = padding          # offset subtracted from each mask coordinate
        # Kept for interface compatibility; forward() applies no stride to the
        # mask coordinates, so this value is not used.
        self.stride = stride

    def forward(self, features, mask_h_idx, mask_w_idx):
        """Gather the K x K patch around each masked pixel into columns.

        Args:
            features: Tensor of shape (N, C, H, W).
            mask_h_idx: Integer tensor of shape (M,) with row indices, expected
                in [0, H).
            mask_w_idx: Integer tensor of shape (M,) with column indices,
                expected in [0, W). The same (h, w) list is applied to every
                image in the batch.

        Returns:
            Tensor of shape (N, C*K*K, M). Row ``c*K*K + i*K + j`` holds tap
            ``(i, j)`` of channel ``c`` (the same ordering ``F.unfold`` uses),
            i.e. the value at image position ``(h - padding + i, w - padding + j)``,
            or zero when that position lies outside the image.
        """
        k = self.kernel_size
        pad = self.padding
        _, _, _, width = features.shape

        # Pad top/left by `pad` so the patch for mask point (h, w) starts at
        # padded position (h, w). Pad bottom/right by just enough (k - 1 - pad)
        # that a patch anchored at the last row/column still fits.
        # Symmetric padding would make the unfold output width W + 2*pad - k + 1.
        # That differs from W unless pad == (k - 1) / 2, so flat indices built
        # with W would point at the wrong patches.
        tail = max(k - 1 - pad, 0)
        padded = F.pad(features, (pad, tail, pad, tail))
        unfolded = F.unfold(padded, kernel_size=k, stride=1)  # (N, C*k*k, H_out*W_out)

        # Width of the unfold output grid; row-major flat index is h * w_out + w.
        w_out = width + pad + tail - k + 1
        flat_indices = mask_h_idx * w_out + mask_w_idx  # (M,)

        return unfolded[..., flat_indices]  # (N, C*k*k, M)

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    Order: kernel_size, padding, stride.
    """
    kernel_size, padding, stride = 3, 1, 1
    return [kernel_size, padding, stride]

def get_inputs():
    """Build sample ``forward`` inputs: one 32-channel 16x16 map and 10 random points.

    Only a single-image batch is produced, to keep the reference logic simple.
    Random draws happen in a fixed order (features, rows, columns) so results
    are reproducible under a fixed seed.
    """
    batch, channels, height, width = 1, 32, 16, 16
    point_count = 10

    feature_map = torch.randn(batch, channels, height, width)
    rows = torch.randint(0, height, (point_count,))
    cols = torch.randint(0, width, (point_count,))
    return [feature_map, rows, cols]
