
import torch
import torchvision.ops

class Model(torch.nn.Module):
    """Modulated deformable 2-D convolution layer (DCNv2-style).

    Holds the convolution ``weight`` and optional ``bias`` as parameters. On
    each forward call it applies ``torchvision.ops.deform_conv2d`` with the
    caller-supplied ``offset`` and (optional) ``mask`` tensors.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int (used for both spatial dims)
            or an ``(h, w)`` pair.
        stride: Same convention as ``kernel_size``.
        padding: Same convention as ``kernel_size``.
        dilation: Same convention as ``kernel_size``.
        groups: Number of weight groups. It must divide both ``in_channels``
            and ``out_channels``.
        bias: If True, add a learnable per-output-channel bias.

    Raises:
        ValueError: If ``groups`` is not positive or does not divide the
            channel counts, or if a size argument is a sequence whose length
            is not 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
        super().__init__()
        if groups <= 0:
            raise ValueError(f"groups must be a positive integer, got {groups!r}")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._to_pair(kernel_size, "kernel_size")
        self.stride = self._to_pair(stride, "stride")
        self.padding = self._to_pair(padding, "padding")
        self.dilation = self._to_pair(dilation, "dilation")
        # deform_conv2d takes no ``groups`` argument. torchvision infers the
        # weight-group count from weight.shape[1] (= in_channels // groups),
        # so the grouping is carried by the weight shape below.
        self.groups = groups

        self.weight = torch.nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    @staticmethod
    def _to_pair(value, name):
        """Return ``(value, value)`` for a scalar, or ``value`` as a 2-tuple."""
        try:
            pair = tuple(value)
        except TypeError:
            # A non-iterable scalar (int, numpy int, ...) is used for both dims.
            return (value, value)
        if len(pair) != 2:
            raise ValueError(f"{name} must be an int or a pair, got {value!r}")
        return pair

    def reset_parameters(self):
        """Initialize weight and bias the same way ``torch.nn.Conv2d`` does."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                torch.nn.init.uniform_(self.bias, -bound, bound)
            else:
                # Degenerate (empty) kernel: avoid dividing by zero and avoid
                # leaving the torch.empty() storage uninitialized.
                torch.nn.init.zeros_(self.bias)

    def forward(self, input, offset, mask=None):
        """Apply the deformable convolution.

        Args:
            input: Input feature map, ``(N, in_channels, H, W)``.
            offset: Sampling offsets,
                ``(N, 2 * offset_groups * kh * kw, H_out, W_out)``.
            mask: Optional modulation scalars,
                ``(N, offset_groups * kh * kw, H_out, W_out)``. ``None``
                selects plain (unmodulated) deformable convolution.

        Returns:
            The result of ``torchvision.ops.deform_conv2d`` for this layer's
            weight, bias, stride, padding and dilation.
        """
        return torchvision.ops.deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )

# Needed for reset_parameters
import math

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    Order follows ``Model.__init__``: in_channels, out_channels, kernel_size,
    stride, padding, dilation, groups, bias.
    """
    in_channels, out_channels = 32, 64
    kernel_size, stride, padding, dilation = 3, 1, 1, 1
    groups, use_bias = 1, True
    return [
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        use_bias,
    ]

def get_inputs():
    """Build random ``[input, offset, mask]`` tensors for ``Model.forward``.

    The sizes are chosen to agree with ``get_init_inputs`` (in_channels=32,
    kernel_size=3, stride=1, padding=1, dilation=1) and use a single offset
    group.

    Returns:
        ``[x, offset, mask]`` where ``x`` is ``(bs, in_channels, H, W)``,
        ``offset`` is ``(bs, 2*k*k, H_out, W_out)`` and ``mask`` is
        ``(bs, k*k, H_out, W_out)`` with values in ``[0, 1)``.
    """
    bs = 2
    in_channels = 32
    height, width = 32, 32
    kernel_size = 3
    # Must agree with the stride/padding/dilation returned by get_init_inputs.
    stride, padding, dilation = 1, 1, 1

    def _out_size(size):
        # Standard convolution output-size formula. This avoids assuming the
        # spatial size is preserved.
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    h_out, w_out = _out_size(height), _out_size(width)

    x = torch.randn(bs, in_channels, height, width)
    offset = torch.randn(bs, 2 * kernel_size * kernel_size, h_out, w_out)
    # Modulation mask: torch.rand draws from [0, 1), standing in for the
    # sigmoid output such a mask usually comes from.
    mask = torch.rand(bs, kernel_size * kernel_size, h_out, w_out)

    return [x, offset, mask]
