from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn import functional as F


def reduce_mask(
    mask: torch.Tensor,
    block_size: Optional[Union[int, Tuple[int, int]]],
    stride: Optional[Union[int, Tuple[int, int]]],
    padding: Optional[Union[int, Tuple[int, int]]],
    verbose: bool = False,
) -> Optional[torch.Tensor]:
    """Reduce a 2-D mask to the top-left origins of its active blocks.

    Windows of size ``block_size`` are laid every ``stride`` pixels over the
    mask, starting ``padding`` pixels above/left of the top-left corner. A
    window is active when its max-pooled value is greater than 0.5.

    Args:
        mask: 2-D tensor of shape ``(H, W)``; it is cast to float32 for pooling.
        block_size: window size, an int or an ``(h, w)`` pair.
        stride: step between windows, an int or an ``(h, w)`` pair.
        padding: top/left offset, an int or an ``(h, w)`` pair.
        verbose: when True, print the active/total window ratio.

    Returns:
        ``None`` if any of ``block_size``, ``stride`` or ``padding`` is None.
        Otherwise a contiguous int32 tensor of shape ``(N, 2)`` holding one
        ``(row, col)`` origin per active window, in row-major order, expressed
        in unpadded mask coordinates (origins are negative when ``padding`` > 0).
    """
    if any(arg is None for arg in (block_size, stride, padding)):
        return None

    # Scalars mean "same value for rows and columns".
    kernel, step, pad = (
        (arg, arg) if isinstance(arg, int) else arg
        for arg in (block_size, stride, padding)
    )

    height, width = mask.shape
    # max_pool2d only accepts floating-point input.
    grid = mask.view(1, 1, height, width).to(torch.float32)
    # Zero-pad top/left by `padding` and bottom/right by one full block;
    # presumably so windows overhanging the bottom/right edge are still produced.
    grid = F.pad(grid, (pad[1], kernel[1], pad[0], kernel[0]))
    hit = F.max_pool2d(grid, kernel, step)[0, 0] > 0.5

    # Map pooled-grid indices back to window origins in the unpadded frame.
    origins = torch.nonzero(hit)
    origins = origins * origins.new_tensor(step) - origins.new_tensor(pad)

    if verbose:
        n_active, n_blocks = origins.shape[0], hit.numel()
        print("Block Sparsity: %d/%d=%.2f%%" % (n_active, n_blocks, 100 * n_active / n_blocks))
    return origins.to(torch.int32).contiguous()
    
if __name__ == "__main__":
    mask = 