
import torch
import torchvision.ops

class Model(torch.nn.Module):
    """``nn.Module`` wrapper that applies ``torchvision.ops.roi_align``.

    The four pooling options are captured at construction time and forwarded
    unchanged, as keyword arguments, on every ``forward`` call.

    Args:
        output_size: Forwarded to ``roi_align`` as ``output_size``.
        spatial_scale: Forwarded to ``roi_align`` as ``spatial_scale``.
        sampling_ratio: Forwarded to ``roi_align`` as ``sampling_ratio``.
        aligned: Forwarded to ``roi_align`` as ``aligned``.
    """

    def __init__(self, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=0, aligned=True):
        super().__init__()
        # Configuration read back by forward() on each call.
        self.output_size, self.spatial_scale = output_size, spatial_scale
        self.sampling_ratio, self.aligned = sampling_ratio, aligned

    def forward(self, input, rois):
        """Run ``roi_align`` on ``input`` for the regions described by ``rois``.

        Args:
            input: Feature tensor handed straight to ``roi_align``.
            rois: Region tensor handed straight to ``roi_align``. The harness in
                this file builds it as ``(K, 5)`` rows of
                ``[batch_idx, x1, y1, x2, y2]``; other layouts accepted by
                torchvision are not exercised here.

        Returns:
            Whatever ``torchvision.ops.roi_align`` returns for these arguments.
        """
        options = dict(
            output_size=self.output_size,
            spatial_scale=self.spatial_scale,
            sampling_ratio=self.sampling_ratio,
            aligned=self.aligned,
        )
        return torchvision.ops.roi_align(input, rois, **options)

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    The list order matches ``Model.__init__``: ``output_size``,
    ``spatial_scale``, ``sampling_ratio``, ``aligned``. A new list is
    returned on every call.
    """
    output_size = (7, 7)
    spatial_scale = 1.0
    sampling_ratio = 0
    aligned = True
    return [output_size, spatial_scale, sampling_ratio, aligned]

def get_inputs():
    """Build a random feature map and a set of ROIs for ``Model.forward``.

    Returns:
        ``[features, rois]``:
          * ``features``: ``torch.randn`` tensor of shape ``(2, 256, 32, 32)``.
          * ``rois``: ``(10, 5)`` tensor (built by ``torch.tensor`` from Python
            floats) whose rows are ``[batch_idx, x1, y1, x2, y2]``. Here
            ``batch_idx`` is in {0, 1}, ``x1`` and ``y1`` are in [0, 9], and box
            width and height are in [5, 9]. Every box therefore stays inside the
            32x32 map.
    """
    batch_size, in_channels = 2, 256
    height = width = 32
    num_rois = 10
    features = torch.randn(batch_size, in_channels, height, width)

    def draw(low, high):
        # One integer in [low, high) from the global RNG, returned as a float.
        return float(torch.randint(low, high, (1,)).item())

    boxes = []
    for _ in range(num_rois):
        # Keep the draw order (batch, x1, y1, w, h): seeded runs depend on it.
        batch_idx = draw(0, batch_size)
        x1 = draw(0, 10)
        y1 = draw(0, 10)
        box_w = draw(5, 10)
        box_h = draw(5, 10)
        boxes.append([batch_idx, x1, y1, x1 + box_w, y1 + box_h])

    return [features, torch.tensor(boxes)]
