import numpy as np

# Integer codes for the ten cell colors used as grid values (0-9).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Blank out cells in line with each black cell that match its neighbors' colors.

    The first row selects the working orientation. If it contains exactly one
    non-black color, the grid is processed as-is and the "line" through a
    black cell is its column. Otherwise the grid is first rotated 90 degrees
    counter-clockwise, so the original rows become the processed columns. The
    result is then rotated back.

    For every black cell (in the working orientation), the colors of its
    in-bounds 4-connected neighbors are collected. Every cell in that cell's
    column whose current value is one of those colors is set to black.
    Neighbor colors are always read from the unmodified source grid.

    Args:
        input_grid: 2-D array of color codes. It is not modified.

    Returns:
        A new array with the same shape as ``input_grid``.
    """
    if input_grid.size == 0:
        # Nothing to process; also avoids indexing row 0 of an empty grid.
        return np.copy(input_grid)

    # Count the distinct non-black colors in the first row.
    num_non_black = len(set(input_grid[0].tolist()) - {black})
    rotate = num_non_black != 1

    # Rotate the source as well as the output, so that indices taken from the
    # working grid are valid for the neighbor lookups. Rotating only the
    # output would mix coordinate frames and index out of bounds on
    # non-square grids.
    source = np.rot90(input_grid) if rotate else input_grid
    result = np.copy(source)
    rows, cols = source.shape

    for i in range(rows):
        for j in range(cols):
            if source[i, j] != black:
                continue

            # Colors of the 4-connected neighbors that lie inside the grid.
            neighbor_colors = set()
            for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ni, nj = i + di, j + dj
                if 0 <= ni < rows and 0 <= nj < cols:
                    neighbor_colors.add(source[ni, nj])

            # Paint black every cell in this column that currently has one of
            # those colors. `column` is a view, so the write lands in `result`.
            column = result[:, j]
            column[np.isin(column, list(neighbor_colors))] = black

    # Undo the initial counter-clockwise rotation (k=3 is a clockwise turn).
    return np.rot90(result, k=3) if rotate else result