import numpy as np
# Color names bound, in order, to the integer cell values 0 through 9.
(black, blue, red, green, yellow,
 grey, pink, orange, teal, maroon) = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
def main(input_grid: np.ndarray) -> np.ndarray:
    """Draw a grey pattern on a fresh 3x3 black grid, chosen by color count.

    The number of distinct values in ``input_grid`` selects the pattern:

    * 1 distinct value  -> the whole top row is grey
    * 2 distinct values -> the main diagonal (top-left to bottom-right) is grey
    * 3 distinct values -> the anti-diagonal (top-right to bottom-left) is grey
    * any other count   -> the grid is left entirely black

    Args:
        input_grid: Array whose distinct cell values are counted.

    Returns:
        A new 3x3 integer array; ``input_grid`` is not modified.
    """
    # (row, col) cells to paint grey, keyed by the number of distinct colors.
    grey_cells_by_count = {
        1: [(0, 0), (0, 1), (0, 2)],
        2: [(0, 0), (1, 1), (2, 2)],
        3: [(0, 2), (1, 1), (2, 0)],
    }

    canvas = np.full((3, 3), black, dtype=int)
    # np.unique returns a flattened 1-D array, so its size is the color count.
    n_colors = np.unique(input_grid).size

    for row, col in grey_cells_by_count.get(n_colors, []):
        canvas[row, col] = grey

    return canvas
