import numpy as np

# Colour palette: each name is bound to its integer code, 0 through 9.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Fill black cells using the colours on their anti-diagonal.

    Each black cell looks at its up-right ``(i-1, j+1)`` and down-left
    ``(i+1, j-1)`` neighbours. When the non-black colours found there
    reduce to exactly one distinct value, the cell takes that colour.
    Cells are visited in row-major order and written in place, so a cell
    filled earlier in a pass is visible to cells visited later in it.
    Passes repeat until no black cell remains or a pass fills nothing.

    Args:
        input_grid: 2-D array of colour codes. It is not modified.

    Returns:
        A copy of ``input_grid`` with every resolvable black cell filled.
        Cells that can never be resolved (no non-black anti-diagonal
        neighbour, or two neighbours with different colours) stay black.

    Raises:
        ValueError: If ``input_grid`` is not 2-D (the shape unpacking fails).
    """
    output_grid = np.copy(input_grid)
    rows, cols = input_grid.shape
    # Offsets of the two anti-diagonal neighbours: up-right and down-left.
    anti_diagonal = ((-1, 1), (1, -1))

    while True:
        black_pixels = np.argwhere(output_grid == black)
        if len(black_pixels) == 0:
            break

        filled_any = False
        for i, j in black_pixels:
            candidates = set()
            for di, dj in anti_diagonal:
                ni, nj = i + di, j + dj
                if 0 <= ni < rows and 0 <= nj < cols:
                    colour = output_grid[ni, nj]
                    if colour != black:
                        candidates.add(colour)

            # Fill only when the neighbours agree on a single colour.
            if len(candidates) == 1:
                output_grid[i, j] = candidates.pop()
                filled_any = True

        # A pass that changed nothing would repeat identically forever;
        # stop and leave the unresolved cells black.
        if not filled_any:
            break

    return output_grid