
import torch

class Model(torch.nn.Module):
    """Reference scatter-max that mimics the `maxPoolFwdBlockKernel` update.

    For every active (input, output) index pair the output row is replaced by
    the element-wise maximum of itself and the paired input row.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        out_features: torch.Tensor,
        in_features: torch.Tensor,
        indices_in: torch.Tensor,
        indices_out: torch.Tensor,
        num_hot: int,
        num_planes: int,
    ) -> torch.Tensor:
        """Accumulate a per-row maximum into ``out_features`` in place.

        Equivalent to::

            for i in range(num_hot):
                out_features[indices_out[i]] = max(
                    out_features[indices_out[i]], in_features[indices_in[i]]
                )

        Args:
            out_features: (M, C) accumulator. Its current values take part in
                the maximum, so the caller decides the initial value.
            in_features: (N, C) source rows.
            indices_in: (TotalPairs,) integer row indices into ``in_features``.
            indices_out: (TotalPairs,) integer row indices into
                ``out_features``.
            num_hot: Number of leading pairs to process; later pairs are
                ignored.
            num_planes: Channel count C. Accepted for signature parity with
                the kernel but not read: every column is processed.

        Returns:
            The same ``out_features`` tensor object, updated in place.
        """
        # Only the first `num_hot` pairs are active; indexing needs int64.
        idx_in = indices_in[:num_hot].long()
        idx_out = indices_out[:num_hot].long()

        # Nothing to accumulate: leave the accumulator untouched.
        if idx_out.numel() == 0:
            return out_features

        # Rows of in_features lined up with their destination rows.
        vals = in_features[idx_in]

        scatter_reduce = getattr(out_features, "scatter_reduce_", None)
        if scatter_reduce is not None:
            # scatter_reduce_ needs an index with the same shape as the source,
            # so broadcast each destination row index across all columns.
            index = idx_out.unsqueeze(1).expand(-1, vals.shape[1])
            scatter_reduce(0, index, vals, reduce="amax", include_self=True)
        else:
            # PyTorch builds without Tensor.scatter_reduce_ (added in 1.12)
            # must still produce the correct result instead of silently
            # returning the accumulator unchanged.
            self._scatter_amax_fallback(out_features, idx_out, vals)

        return out_features

    @staticmethod
    def _scatter_amax_fallback(out_features, idx_out, vals):
        """Per-destination-row max update for builds lacking scatter_reduce_.

        Runs one Python iteration per distinct destination row, so it is
        slower than the fused op and meant only as a correctness fallback.
        """
        for row in torch.unique(idx_out).tolist():
            # Reduce all sources that target this row, then merge with the
            # value already stored there (same as include_self=True).
            best = vals[idx_out == row].max(dim=0)[0]
            out_features[row] = torch.max(out_features[row], best)

def get_init_inputs():
    """Return the positional arguments for ``Model()``; it takes none."""
    init_args = []
    return init_args

def get_inputs():
    """Build one random argument list for ``Model.forward``.

    Returns ``[out, inp, idx_in, idx_out, pairs, C]``: a (10, 16) accumulator,
    a (20, 16) source tensor, two int32 index vectors of length 50 with values
    in range for their respective tensors, the pair count, and the channel
    count.
    """
    num_out_rows, num_in_rows, channels = 10, 20, 16
    num_pairs = 50

    # Keep this draw order so a fixed seed yields the same tensors.
    accumulator = torch.randn(num_out_rows, channels)
    source = torch.randn(num_in_rows, channels)
    src_rows = torch.randint(0, num_in_rows, (num_pairs,)).to(torch.int32)
    dst_rows = torch.randint(0, num_out_rows, (num_pairs,)).to(torch.int32)

    return [accumulator, source, src_rows, dst_rows, num_pairs, channels]
