import numpy as np

# Named color constants: each name is bound to its integer cell value (0-9).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid: np.ndarray) -> np.ndarray:
    """Mark the diagonal corners around every 2x2 grey square.

    Each 2x2 window of ``input_grid`` is scanned. If all four cells are grey,
    the cell diagonally outside each corner of that window is recolored in a
    copy of the grid:

    * above-left  -> blue
    * above-right -> red
    * below-left  -> green
    * below-right -> yellow

    Corner cells that would lie outside the grid are skipped. Negative
    indices must not wrap to the opposite edge, and indices past the last
    row/column must not raise ``IndexError``.

    Detection always reads ``input_grid``, so markers written to the output
    never influence which squares are found. When markers from different
    squares land on the same cell, the last write wins. Within one square,
    the writes go in the order listed above.

    Args:
        input_grid: 2-D array of integer color values.

    Returns:
        A new array with the markers applied; ``input_grid`` is not modified.
    """
    output_grid = np.copy(input_grid)
    n_rows, n_cols = input_grid.shape
    # (row offset, column offset, color) relative to the square's top-left
    # cell, in the same write order as before.
    corners = (
        (-1, -1, blue),
        (-1, 2, red),
        (2, -1, green),
        (2, 2, yellow),
    )
    for i in range(n_rows - 1):
        for j in range(n_cols - 1):
            if not np.all(input_grid[i:i + 2, j:j + 2] == grey):
                continue
            for di, dj, color in corners:
                r, c = i + di, j + dj
                # Bounds check: a bare negative index would silently wrap
                # around, and r == n_rows / c == n_cols would raise.
                if 0 <= r < n_rows and 0 <= c < n_cols:
                    output_grid[r, c] = color
    return output_grid