import numpy as np

# Colour palette: each name is bound to its integer code, 0 through 9,
# in the order listed.
(
    black,
    blue,
    red,
    green,
    yellow,
    grey,
    pink,
    orange,
    teal,
    maroon,
) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Flag which of the first three input rows hold a single value.

    Row ``r`` of the returned 3x3 integer grid is filled with ``grey``
    when ``input_grid[r]`` contains exactly one distinct value, and with
    ``black`` otherwise. Only rows 0, 1 and 2 of ``input_grid`` are read.
    """
    # One fill colour per row, decided by the number of distinct entries.
    row_fills = [
        grey if len(set(input_grid[r])) == 1 else black
        for r in range(3)
    ]

    result = np.zeros((3, 3), dtype=int)
    # A (3, 1) column broadcasts across the columns, so every cell in a
    # row receives that row's colour.
    result[:] = np.array(row_fills, dtype=int).reshape(3, 1)
    return result