import numpy as np
# Named integer pixel values 0-9 used by the grid transforms below
# (presumably the standard ARC color palette — confirm against the task data).
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9
def main(input_grid: np.ndarray) -> np.ndarray:
    """Swap the grid's single non-grey "special" color into the grey cells.

    Every pixel of the special color becomes ``black`` and every ``grey``
    pixel becomes the special color. The special color is the smallest
    non-grey value present in the grid (``np.unique`` returns sorted values).

    Args:
        input_grid: 2-D grid of integer color values. It is not modified.

    Returns:
        A new array with the recoloring applied.

    Raises:
        IndexError: if the grid contains no non-grey color.
    """
    # Work on a copy so the caller's array is never mutated in place;
    # np.array also accepts nested lists.
    output_grid = np.array(input_grid)

    # Find the special color. np.unique is sorted, so if several non-grey
    # colors are present (e.g. a black background) the smallest one wins.
    # NOTE(review): confirm the task guarantees exactly one non-grey color.
    colors = np.unique(output_grid)
    special_color = colors[colors != grey][0]

    # Compute both masks from the untouched values before writing anything,
    # so the two replacements cannot interfere with each other.
    special_mask = output_grid == special_color
    grey_mask = output_grid == grey

    output_grid[special_mask] = black
    output_grid[grey_mask] = special_color
    return output_grid