
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    """Pure-PyTorch CARAFE-style content-aware upsampling.

    Each output pixel ``(y, x)`` is a weighted sum of the ``K x K``
    neighbourhood (zero-padded at the borders) around the input pixel it maps
    to under nearest-neighbour upsampling by ``scale_factor``. The K*K weights
    are read from ``masks`` per output pixel and shared by all channels of the
    same channel group.

    Args:
        kernel_size: Window side length K. Must be a positive odd integer so
            the window is centred on the source pixel.
        group_size: Number of channel groups G. Each group has its own K*K
            weights per output pixel. Must divide the feature channel count.
        scale_factor: Upsampling factor forwarded to ``F.interpolate``.
    """

    def __init__(self, kernel_size=3, group_size=1, scale_factor=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features, masks):
        """Reassemble ``features`` into an upsampled map weighted by ``masks``.

        Args:
            features: Tensor of shape (N, C, H, W).
            masks: Tensor of shape (N, G*K*K, H_out, W_out), where H_out and
                W_out are the spatial sizes produced by nearest upsampling of
                H and W by ``scale_factor``. Channels are group-major: the
                weight for group ``g`` and window offset ``(ky, kx)`` lives at
                channel ``(g * K + ky) * K + kx``.

        Returns:
            Tensor of shape (N, C, H_out, W_out).

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer, if
                ``group_size`` is not a positive divisor of C, or if ``masks``
                does not have exactly the expected shape.
        """
        N, C, H, W = features.shape
        K = self.kernel_size
        groups = self.group_size
        if K < 1 or K % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {K}"
            )
        # Check `groups < 1` first so a zero group count never reaches `%`.
        if groups < 1 or C % groups != 0:
            raise ValueError(
                f"group_size ({groups}) must be a positive divisor of the "
                f"feature channel count ({C})"
            )
        K2 = K * K
        C_per_group = C // groups

        # 1. Gather every input pixel's KxK neighbourhood -> (N, C*K*K, H*W).
        #    unfold orders channels as (C, ky, kx). Padding K // 2 with odd K
        #    keeps exactly H*W window positions.
        unfolded = F.unfold(features, kernel_size=K, padding=K // 2, stride=1)

        # 2. Nearest upsampling copies each neighbourhood onto every output
        #    pixel that maps back to that input pixel.
        unfolded = unfolded.view(N, C * K2, H, W)
        unfolded_up = F.interpolate(
            unfolded, scale_factor=self.scale_factor, mode='nearest'
        )
        H_out, W_out = unfolded_up.shape[-2:]

        # view() below only checks element counts. Without this check a
        # mis-shaped mask (e.g. H_out and W_out swapped) would be silently
        # reinterpreted instead of rejected.
        expected = (N, groups * K2, H_out, W_out)
        if tuple(masks.shape) != expected:
            raise ValueError(
                f"masks must have shape {expected}, got {tuple(masks.shape)}"
            )

        # 3. Split channels group-major: (N, G, C/G, K*K, H_out, W_out).
        #    The masks get a singleton channel axis so one set of weights
        #    broadcasts over every channel in its group.
        feat = unfolded_up.view(N, groups, C_per_group, K2, H_out, W_out)
        weights = masks.view(N, groups, 1, K2, H_out, W_out)

        # 4. Weighted sum over the K*K window -> (N, G, C/G, H_out, W_out).
        output = (feat * weights).sum(dim=3)

        # 5. Merge the groups back into the channel axis.
        return output.view(N, C, H_out, W_out)

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    Order is ``[kernel_size, group_size, scale_factor]``.
    """
    kernel_size, group_size, scale_factor = 3, 1, 2
    return [kernel_size, group_size, scale_factor]

def get_inputs():
    """Build a random ``[features, masks]`` pair for ``Model(3, 1, 2)``.

    features: (2, 64, 16, 16) standard-normal values.
    masks:    (2, 9, 32, 32); at every output pixel, each group's 9 window
              weights are softmax-normalised so they sum to 1.
    """
    batch, channels, height, width = 2, 64, 16, 16
    kernel, n_groups, upscale = 3, 1, 2
    taps = kernel * kernel
    out_h, out_w = height * upscale, width * upscale
    mask_ch = n_groups * taps

    # Draw features first, then mask logits; seeded runs depend on this order.
    features = torch.randn(batch, channels, height, width)
    logits = torch.randn(batch, mask_ch, out_h, out_w)

    # The model does not require normalised weights; softmax just keeps the
    # weighted sums well-scaled.
    per_group = logits.view(batch, n_groups, taps, out_h, out_w)
    masks = F.softmax(per_group, dim=2).view(batch, mask_ch, out_h, out_w)
    return [features, masks]
