
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    """
    CARAFE Naive - Same algorithm as CARAFE with simpler (non-optimized) CUDA implementation.
    Python reference implementation is identical to CARAFE.

    Each output pixel is a weighted sum over the ``kernel_size x kernel_size``
    window centred on its nearest-neighbour source pixel in ``features``; the
    weights come from ``masks`` at that output pixel. Channels are split into
    ``group_size`` contiguous groups, and every group uses its own K*K weights.

    Args:
        kernel_size: side length K of the square reassembly window. Must be odd
            so that ``padding=K // 2`` keeps the window centred.
        group_size: number of channel groups G. The channel count of
            ``features`` must be divisible by it.
        scale_factor: upsampling factor forwarded to ``F.interpolate``.
    """

    def __init__(self, kernel_size=3, group_size=1, scale_factor=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features, masks):
        """
        Reassemble ``features`` into an upsampled map using ``masks``.

        Args:
            features: (N, C, H, W)
            masks: (N, group_size * K^2, H_out, W_out). Here (H_out, W_out) is
                the size of (H, W) after nearest upsampling by ``scale_factor``.
                Mask channel ``g * K^2 + k`` is the weight of window tap ``k``
                for channel group ``g``. Weights are used as given; no
                normalization is applied here.

        Returns:
            (N, C, H_out, W_out)

        Raises:
            ValueError: if ``kernel_size`` is even, if ``C`` is not divisible by
                ``group_size``, or if ``masks`` does not hold exactly
                N * group_size * K^2 * H_out * W_out elements.
        """
        N, C, H, W = features.shape
        K = self.kernel_size
        groups = self.group_size
        K2 = K * K

        # With padding K // 2 and stride 1, unfold preserves the H x W grid only
        # for odd K; an even K produces (H + 1) x (W + 1) windows.
        if K % 2 != 1:
            raise ValueError(f"kernel_size must be odd, got {K}")
        if C % groups != 0:
            raise ValueError(
                f"channels ({C}) must be divisible by group_size ({groups})"
            )
        C_per_group = C // groups

        # 1. Gather the K x K neighbourhood of every input pixel.
        #    unfold returns (N, C*K*K, H*W), with channel c outermost
        #    (index c*K2 + k); fold the flat spatial axis back into (H, W).
        unfolded = F.unfold(features, kernel_size=K, padding=K // 2, stride=1)
        unfolded = unfolded.view(N, C * K2, H, W)

        # 2. Nearest upsampling makes every output pixel reuse the
        #    neighbourhood of its source pixel.
        unfolded_up = F.interpolate(
            unfolded, scale_factor=self.scale_factor, mode='nearest'
        )
        H_out, W_out = unfolded_up.shape[-2:]

        # Element-count check keeps every mask layout the reshape below accepts,
        # while replacing the opaque view/reshape failure with a clear message.
        if masks.numel() != N * groups * K2 * H_out * W_out:
            raise ValueError(
                f"masks must have shape ({N}, {groups * K2}, {H_out}, {W_out}), "
                f"got {tuple(masks.shape)}"
            )

        # 3. Split channels into groups. The singleton dim 2 broadcasts each
        #    group's weights over all of that group's channels.
        feat = unfolded_up.view(N, groups, C_per_group, K2, H_out, W_out)
        weights = masks.reshape(N, groups, 1, K2, H_out, W_out)

        # 4. Weighted sum over the K*K window taps.
        output = (feat * weights).sum(dim=3)

        # 5. Merge (groups, C_per_group) back into the channel axis.
        return output.view(N, C, H_out, W_out)

def get_init_inputs():
    """
    Return the constructor arguments for ``Model`` as a fresh list.

    The values line up positionally with ``Model.__init__``:
    kernel_size=3, group_size=1, scale_factor=2.
    """
    kernel_size, group_size, scale_factor = 3, 1, 2
    return [kernel_size, group_size, scale_factor]

def get_inputs(
    *,
    batch_size=2,
    channels=64,
    height=16,
    width=16,
    kernel_size=3,
    group_size=1,
    scale_factor=2,
):
    """
    Build random example inputs for ``Model.forward``.

    All arguments are keyword-only and default to the original benchmark
    configuration: batch 2, 64 channels, 16x16 input, 3x3 kernel, one group,
    2x upsampling. ``kernel_size``, ``group_size`` and ``scale_factor`` should
    match the ``Model`` the inputs are fed to.

    Returns:
        ``[features, masks]``:
        - ``features`` is (batch_size, channels, height, width), drawn from
          ``torch.randn``.
        - ``masks`` is (batch_size, group_size * K^2, H_out, W_out),
          softmax-normalised over the K^2 taps within each group.
        Random draws happen in that order (features first), so seeded calls
        with the defaults reproduce the original inputs exactly.
    """
    features = torch.randn(batch_size, channels, height, width)

    K2 = kernel_size * kernel_size
    mask_channels = group_size * K2
    # Matches F.interpolate's output size, floor(size * scale_factor).
    H_out = int(height * scale_factor)
    W_out = int(width * scale_factor)

    masks = torch.randn(batch_size, mask_channels, H_out, W_out)
    # Normalise each group's K^2 weights per output pixel so they sum to 1.
    masks = F.softmax(masks.view(batch_size, group_size, K2, H_out, W_out), dim=2)
    masks = masks.view(batch_size, mask_channels, H_out, W_out)

    return [features, masks]
