import numpy as np

# Colour palette: each name is bound to its integer code, 0 through 9.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def _find_corner_color(grid, rows, cols, row_step, col_step, corner):
    """Return the colour next to the first inner-corner cell in scan order.

    A cell qualifies when it is black and both its vertical neighbour
    (``row + row_step``) and its horizontal neighbour (``col + col_step``)
    lie inside the grid and are not black.  The value returned is the
    colour of the vertical neighbour.

    Args:
        grid: 2-D array of colour codes.
        rows: Row indices in the order they should be scanned.  Must be
            re-iterable (e.g. a ``range``).
        cols: Column indices in the order they should be scanned within
            each row.  Must be re-iterable, since it is walked once per row.
        row_step: -1 to inspect the neighbour above, +1 for the one below.
        col_step: -1 to inspect the neighbour to the left, +1 for the right.
        corner: Human-readable corner name, used only in the error message.

    Raises:
        ValueError: If no cell in the grid qualifies.
    """
    height, width = grid.shape
    for r in rows:
        neighbour_row = r + row_step
        if not 0 <= neighbour_row < height:
            continue  # edge row: there is no vertical neighbour to inspect
        for c in cols:
            neighbour_col = c + col_step
            if not 0 <= neighbour_col < width:
                continue  # edge column: no horizontal neighbour to inspect
            if (
                grid[r, c] == black
                and grid[neighbour_row, c] != black
                and grid[r, neighbour_col] != black
            ):
                # Returning here stops BOTH loops, so the first match in
                # scan order wins and later rows cannot overwrite it.
                return grid[neighbour_row, c]
    raise ValueError(f"no {corner} corner found in input grid")


def main(input_grid: np.ndarray) -> np.ndarray:
    """Summarise the four inner corners of ``input_grid`` as a 4x4 grid.

    For each of the four corner orientations the grid is scanned for a
    black cell whose two neighbours on that corner's sides are non-black;
    the colour of the vertical neighbour is taken as that corner's colour.
    Scans for the top corners start from the first row, scans for the
    bottom corners start from the last row; scans for the left corners
    start from the first column, scans for the right corners from the
    last column.  The first match in scan order is used.

    Each colour fills the matching 2x2 quadrant of the output, after which
    the central 2x2 area is set to black.

    Args:
        input_grid: 2-D array of integer colour codes.

    Returns:
        A 4x4 integer array.

    Raises:
        ValueError: If any of the four corners cannot be found.
    """
    height, width = input_grid.shape
    # Plain ranges (not reversed() iterators) so they can be re-walked per row.
    top_down, bottom_up = range(height), range(height - 1, -1, -1)
    left_right, right_left = range(width), range(width - 1, -1, -1)

    top_left = _find_corner_color(
        input_grid, top_down, left_right, -1, -1, "top-left")
    top_right = _find_corner_color(
        input_grid, top_down, right_left, -1, +1, "top-right")
    bottom_left = _find_corner_color(
        input_grid, bottom_up, left_right, +1, -1, "bottom-left")
    bottom_right = _find_corner_color(
        input_grid, bottom_up, right_left, +1, +1, "bottom-right")

    output_grid = np.zeros((4, 4), dtype=int)

    # Paint one 2x2 quadrant per corner colour.
    output_grid[0:2, 0:2] = top_left
    output_grid[0:2, 2:4] = top_right
    output_grid[2:4, 0:2] = bottom_left
    output_grid[2:4, 2:4] = bottom_right

    # Hollow out the centre so only an L-shaped piece of each quadrant remains.
    output_grid[1:3, 1:3] = black

    return output_grid