import numpy as np

# Named integer color codes 0-9 (the same values range(10) would assign).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid: np.ndarray) -> np.ndarray:
    """Recolor grey cells using the color in the first column of their row.

    For every row, the leftmost cell supplies the fill color, and each grey
    cell in that row is replaced by it. Non-grey cells are left as they are.
    If a row's leftmost cell is itself grey, that row is unchanged.

    Args:
        input_grid: 2-D grid of integer color codes. It is not modified.

    Returns:
        A new array with the same shape and dtype as ``input_grid``. A grid
        with zero columns is returned as an unchanged copy instead of
        raising ``IndexError``.
    """
    grid = np.asarray(input_grid)
    # Slicing (``:1``) rather than indexing (``0``) keeps a (rows, 1) column
    # that broadcasts across each row, and produces an empty column instead
    # of raising when the grid has no columns.
    row_colors = grid[:, :1]
    # np.where always allocates a fresh array, so the caller's grid is never
    # mutated; both branches share grid's dtype, so the dtype is preserved.
    return np.where(grid == grey, row_colors, grid)