import numpy as np

# Integer color codes for grid cells: each name is bound to its position
# in the list below, so black is 0 and maroon is 9.
(
    black,
    blue,
    red,
    green,
    yellow,
    grey,
    pink,
    orange,
    teal,
    maroon,
) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Return the 3x3 pattern associated with the grid's non-black color.

    The smallest non-black value present in ``input_grid`` selects the
    pattern; blue, red and green each map to a fixed 3x3 array made of
    0s and 5s.

    Args:
        input_grid: Array of color codes.

    Returns:
        A newly allocated 3x3 array when the selected color is blue, red or
        green. An empty array (``np.array([])``) when the grid contains no
        non-black value at all, or when the selected color has no pattern.
    """
    patterns = {
        blue: ((0, 5, 0), (5, 5, 5), (0, 5, 0)),
        red: ((5, 5, 5), (0, 5, 0), (0, 5, 0)),
        green: ((0, 0, 5), (0, 0, 5), (5, 5, 5)),
    }

    # np.unique returns its values sorted, so taking the first non-black
    # entry is deterministic; iterating a set gives no ordering guarantee.
    present = np.unique(input_grid)
    non_black = present[present != black]
    if non_black.size == 0:
        # All-black (or empty) grid: there is no color to dispatch on, so
        # return the same "nothing found" result as for an unknown color
        # instead of letting next() raise StopIteration.
        return np.array([])

    # .item() converts the numpy scalar to a plain Python value so the
    # dictionary lookup does not depend on numpy scalar hashing.
    pattern = patterns.get(non_black[0].item())
    if pattern is None:
        return np.array([])
    # Build a fresh array per call so callers may mutate the result safely.
    return np.array(pattern)
        