import numpy as np

# Named colour constants: each name is bound to one integer cell value 0-9
# (for example grey == 5 and red == 2).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def find_black_squares(grid):
    """Return every all-black (value 0) square sub-block of a 2-D grid.

    Args:
        grid: 2-D array of cell values.

    Returns:
        A list of ``(row, col, size)`` tuples, one for each square of side
        ``size`` whose top-left cell is ``(row, col)`` and whose cells are
        all 0. The list is ordered by increasing size, then row, then
        column. Overlapping and nested squares are all reported.
    """
    n_rows, n_cols = grid.shape[:2]
    black_squares = []
    # The upper bound is inclusive: the largest square that fits has side
    # min(n_rows, n_cols), and that size must be checked too.
    for size in range(1, min(n_rows, n_cols) + 1):
        found = [
            (i, j, size)
            for i in range(n_rows - size + 1)
            for j in range(n_cols - size + 1)
            if np.all(grid[i:i + size, j:j + size] == 0)
        ]
        if not found:
            # Every black square of side size+1 contains one of side size,
            # so once a size yields nothing, no larger size can either.
            break
        black_squares.extend(found)
    return black_squares

def color_surrounded_squares(grid):
    """Fill, in place, each black square fully framed by grey cells with red.

    A square reported by ``find_black_squares`` qualifies when two things
    hold. First, its one-cell ring (both side columns and both end rows,
    corners included) lies inside the grid. Second, every ring cell equals
    5 (``grey``). The cells of a qualifying square are overwritten with
    2 (``red``).

    Args:
        grid: 2-D array; modified in place.

    Returns:
        The same ``grid`` object, after filling.
    """
    for top, left, side in find_black_squares(grid):
        # Exclusive end indices of the square.
        bottom = top + side
        right = left + side

        # A square on the grid's border has no room for a surrounding ring.
        touches_edge = (
            top <= 0
            or left <= 0
            or bottom >= grid.shape[0]
            or right >= grid.shape[1]
        )
        if touches_edge:
            continue

        ring_rows = slice(top - 1, bottom + 1)
        ring_cols = slice(left - 1, right + 1)
        framed = (
            np.all(grid[ring_rows, left - 1] == 5)
            and np.all(grid[ring_rows, right] == 5)
            and np.all(grid[top - 1, ring_cols] == 5)
            and np.all(grid[bottom, ring_cols] == 5)
        )
        if framed:
            grid[top:bottom, left:right] = 2
    return grid

def main(input_grid: np.ndarray) -> np.ndarray:
    """Return a copy of ``input_grid`` with grey-framed black squares filled red.

    The caller's array is never modified; all work happens on a copy.
    """
    return color_surrounded_squares(np.copy(input_grid))
    