
import torch
import torchvision.ops

class Model(torch.nn.Module):
    """``nn.Module`` wrapper around ``torchvision.ops.roi_pool``.

    Args:
        spatial_scale: factor passed to ``roi_pool`` that maps box
            coordinates onto the coordinate frame of ``input``.
        pooled_height: height of the fixed-size grid each ROI is pooled to.
        pooled_width: width of the fixed-size grid each ROI is pooled to.
    """

    def __init__(self, spatial_scale=1.0, pooled_height=7, pooled_width=7):
        super(Model, self).__init__()
        self.spatial_scale, self.pooled_height, self.pooled_width = (
            spatial_scale,
            pooled_height,
            pooled_width,
        )

    def forward(self, input, rois):
        """Pool every region in ``rois`` from ``input`` to a fixed grid.

        Args:
            input: feature map, expected as (N, C, H, W).
            rois: boxes, expected as (K, 5) rows of
                ``[batch_index, x1, y1, x2, y2]``.

        Returns:
            Whatever ``torchvision.ops.roi_pool`` returns for these
            arguments (per torchvision docs, a (K, C, pooled_height,
            pooled_width) tensor).
        """
        pool_size = (self.pooled_height, self.pooled_width)
        return torchvision.ops.roi_pool(
            input,
            rois,
            output_size=pool_size,
            spatial_scale=self.spatial_scale,
        )

def get_init_inputs():
    """Return the positional constructor arguments for ``Model``.

    Order follows ``Model.__init__``: spatial_scale, pooled_height,
    pooled_width.
    """
    spatial_scale = 1.0
    pooled_h = pooled_w = 7
    return [spatial_scale, pooled_h, pooled_w]

def get_inputs(*, batch_size=4, channels=256, height=32, width=32, num_rois=10):
    """Build a random ``[input, rois]`` pair for ``Model.forward``.

    ``input`` is an (N, C, H, W) tensor drawn with ``torch.randn``.
    ``rois`` is a (K, 5) tensor whose rows are
    ``[batch_index, x1, y1, x2, y2]``.

    The defaults reproduce the previous fixed sizes. Random draws are made
    in the same order as before, so a given ``torch.manual_seed`` yields
    the same tensors as the earlier hard-coded version.

    Keyword Args:
        batch_size: N; ROI batch indices are drawn from ``[0, batch_size)``.
        channels: C of the input feature map.
        height: H of the input feature map.
        width: W of the input feature map.
        num_rois: K, the number of boxes generated.

    Returns:
        ``[input_tensor, rois_tensor]``. ``rois_tensor`` has the same dtype
        as ``input_tensor`` and shape (num_rois, 5), including when
        ``num_rois == 0``.

    Note:
        Corner coordinates are drawn from [0, 9] and box sizes from [5, 9],
        independent of ``height``/``width``, so x2/y2 can reach 18. With
        ``height`` or ``width`` below 19, boxes may extend past the map.
    """
    input_tensor = torch.randn(batch_size, channels, height, width)

    def _draw(low, high):
        # One scalar draw per call keeps the RNG consumption order identical
        # to the original per-field randint calls.
        return float(torch.randint(low, high, (1,)).item())

    rois = []
    for _ in range(num_rois):
        # Bound the batch index by the real batch size instead of a separate
        # literal, so changing batch_size cannot yield out-of-range indices.
        batch_idx = _draw(0, batch_size)
        x1 = _draw(0, 10)
        y1 = _draw(0, 10)
        w = _draw(5, 10)
        h = _draw(5, 10)
        rois.append([batch_idx, x1, y1, x1 + w, y1 + h])

    # reshape keeps the (K, 5) layout even for an empty ROI list, where
    # torch.tensor([]) alone would produce shape (0,).
    rois_tensor = torch.tensor(rois, dtype=input_tensor.dtype).reshape(-1, 5)
    return [input_tensor, rois_tensor]
