import numpy as np

def main(input_grid: np.ndarray) -> np.ndarray:
    """Fill each row with a color chosen by the column of its first non-zero cell.

    For every row of ``input_grid`` the column index of the first non-zero
    cell is looked up in a fixed table (column 0 -> 2, column 1 -> 4,
    column 2 -> 3) and the whole output row is set to that value. Rows that
    contain only zeros stay zero in the output.

    Args:
        input_grid: 2-D array of cell values; it is not modified.

    Returns:
        A new array with the same shape and dtype as ``input_grid``.

    Raises:
        KeyError: If the first non-zero cell of a row lies in a column that
            has no entry in the table (column index 3 or greater).
    """
    # Keyed by the COLUMN of the first non-zero cell, not by the cell's value.
    fill_color_by_column = {
        0: 2,
        1: 4,
        2: 3,
    }

    output_grid = np.zeros_like(input_grid)
    # Walk every row rather than a hard-coded three, so grids with fewer
    # rows no longer raise IndexError and extra rows are not ignored.
    for row_index, row in enumerate(input_grid):
        marked_columns = np.where(row != 0)[0]
        if marked_columns.size == 0:
            continue  # no marker in this row: leave it all zeros
        first_marked_column = int(marked_columns[0])
        # A scalar broadcasts across the row, so any grid width works.
        output_grid[row_index] = fill_color_by_column[first_marked_column]
    return output_grid