import numpy as np

# Colour indices used by the grid (0..9).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Stamp copies of the red pattern into every all-black slot of the grid.

    The template is the bounding box of all red pixels in ``input_grid``;
    its red cells form a boolean mask. The mask is slid over every position
    where the bounding box fits inside the grid. Wherever every masked cell
    is black in the output grid, those cells are painted red. Positions are
    visited in row-major order, and each check sees the stamps already
    placed at earlier positions.

    Args:
        input_grid: 2-D integer colour grid. It is not modified.

    Returns:
        A new grid containing the stamped red copies. If the grid has no red
        pixels, an unchanged copy is returned.
    """
    output_grid = np.copy(input_grid)

    red_rows, red_cols = np.nonzero(input_grid == red)
    if red_rows.size == 0:
        # No template to stamp (and np.min on an empty array would raise).
        return output_grid

    min_row, max_row = red_rows.min(), red_rows.max()
    min_col, max_col = red_cols.min(), red_cols.max()
    # Build the mask once from the untouched input. Deriving it from a view of
    # the grid being painted would let new stamps leak into the template.
    template_mask = input_grid[min_row:max_row + 1, min_col:max_col + 1] == red
    t_rows, t_cols = template_mask.shape

    n_rows, n_cols = output_grid.shape
    for i in range(n_rows - t_rows + 1):
        for j in range(n_cols - t_cols + 1):
            # Basic slicing yields a view, so masked assignment writes through.
            window = output_grid[i:i + t_rows, j:j + t_cols]
            if np.all(window[template_mask] == black):
                window[template_mask] = red

    return output_grid
    