import numpy as np

(black, blue, red, green, yellow, grey, pink, orange, teal, maroon) = range(10)


def _find_grid_line(grid: np.ndarray):
    """Return ``(line_color, n)`` from the first non-black pixel of row 0, or None.

    ``line_color`` is that pixel's value and ``n`` is its column index.
    Presumably the grid is partitioned by 1-pixel lines of ``line_color``
    into n-by-n cells (period ``n + 1``) -- TODO confirm against the task.
    """
    if grid.shape[0] == 0:
        return None
    for col, value in enumerate(grid[0]):
        if value != black:
            return value, col
    return None


def _fill_marked_cells(grid: np.ndarray, line_color, n: int) -> None:
    """Fill, in place, every n-by-n square bounded by four same-coloured marks.

    A mark is a pixel at (i, j) that is neither black nor ``line_color`` and
    whose four orthogonal neighbours are all ``line_color``. If the pixels at
    offsets (n+1, 0), (0, n+1) and (n+1, n+1) share the mark's colour, the
    square strictly between them (rows/cols i+1 .. i+n, j+1 .. j+n) is painted
    with that colour. Scanning is row-major and later checks see earlier
    fills, so the visiting order matters.
    """
    rows, cols = grid.shape
    # Stop so that i+n+1 / j+n+1 (and hence i+1 / j+1) stay inside the grid.
    for i in range(1, rows - n - 1):
        for j in range(1, cols - n - 1):
            mark = grid[i, j]
            if mark == black or mark == line_color:
                continue
            if not (grid[i - 1, j] == line_color and grid[i + 1, j] == line_color
                    and grid[i, j - 1] == line_color and grid[i, j + 1] == line_color):
                continue
            if (grid[i + n + 1, j] != mark or grid[i, j + n + 1] != mark
                    or grid[i + n + 1, j + n + 1] != mark):
                continue
            grid[i + 1:i + n + 1, j + 1:j + n + 1] = mark


def _extract_pattern(grid: np.ndarray, n: int) -> np.ndarray:
    """Return the first 3x3 sample whose top row and left column are not all black.

    Windows are scanned row-major with stride ``n + 1``. Each window samples
    the pixels at offsets 0, n+1 and 2n+2 along both axes. Returns an
    all-black int32 3x3 array when no window qualifies.
    """
    step = n + 1
    rows, cols = grid.shape
    # The largest sampled offset is 2*step, so i + 2*step must be <= rows - 1.
    for i in range(0, rows - 2 * step, step):
        for j in range(0, cols - 2 * step, step):
            window = grid[i:i + 2 * step + 1:step, j:j + 2 * step + 1:step].astype(np.int32)
            if np.any(window[0, :] != black) and np.any(window[:, 0] != black):
                return window
    return np.zeros((3, 3), dtype=np.int32)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Fill cells bounded by matching marks, then return a 3x3 sampled pattern.

    Args:
        input_grid: 2-D array of colour indices (0-9). It is not modified.

    Returns:
        An int32 array of shape (3, 3). It is all black when row 0 contains no
        non-black pixel or when no window qualifies.
    """
    grid = np.array(input_grid, copy=True)  # never mutate the caller's array
    found = _find_grid_line(grid)
    if found is None:
        return np.zeros((3, 3), dtype=np.int32)
    line_color, n = found
    _fill_marked_cells(grid, line_color, n)
    return _extract_pattern(grid, n)