import numpy as np

# Named color indices 0-9 used as cell values in the grids below.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Recolor grey cells using the colors found along the top row.

    Works on a copy; ``input_grid`` itself is not modified.

    Two passes are applied:

    1. Column drop: for every column whose top-row cell is not black, each
       grey cell below it in that column takes the top-row color.
    2. Spread: every grey cell still left that has a 4-neighbor which is
       neither grey nor black takes that neighbor's color. Neighbors are
       checked in the order up, down, left, right and the first match wins.
       Cells are swept in row-major order and updated in place, so a cell
       recolored earlier in a sweep can feed later cells of the same sweep.
       Sweeps repeat until nothing changes.

    Grey cells that never touch a colored (non-grey, non-black) cell stay grey.

    Args:
        input_grid: 2-D grid of color indices, of any height and width.

    Returns:
        A new grid of the same shape and dtype with grey cells recolored.
    """
    output_grid = np.copy(input_grid)
    height, _width = output_grid.shape
    if height == 0:
        # No top row to read colors from; nothing to do.
        return output_grid
    _drop_top_row_colors(output_grid)
    _spread_into_grey(output_grid)
    return output_grid


def _drop_top_row_colors(grid: np.ndarray) -> None:
    """In place: paint grey cells below each non-black top-row cell with its color."""
    _height, width = grid.shape
    for col in range(width):
        color = grid[0, col]
        if color == black:
            continue
        # Basic slicing yields a view, so the masked write updates `grid`.
        below = grid[1:, col]
        below[below == grey] = color


def _spread_into_grey(grid: np.ndarray) -> None:
    """In place: repeatedly give grey cells the color of a colored 4-neighbor."""
    height, width = grid.shape
    # Order matters: the first colored neighbor in this order wins.
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))  # up, down, left, right
    changed = True
    while changed:
        changed = False
        for row in range(height):
            for col in range(width):
                if grid[row, col] != grey:
                    continue
                for d_row, d_col in offsets:
                    n_row, n_col = row + d_row, col + d_col
                    if not (0 <= n_row < height and 0 <= n_col < width):
                        continue
                    neighbor = grid[n_row, n_col]
                    if neighbor != grey and neighbor != black:
                        # The new value is never grey, so every change
                        # shrinks the grey count and the loop terminates.
                        grid[row, col] = neighbor
                        changed = True
                        break