import numpy as np

# Names for the integer cell values 0-9 used in the grids.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Keep only the color that occurs exactly once and outline it in red.

    The result is a new grid with the same shape and dtype as the input.
    Every cell is ``black`` except the single pixel of the once-occurring
    color, whose in-bounds neighbours (up to 8) are set to ``red``.

    If several colors occur exactly once, the smallest color value is used,
    because ``np.unique`` returns the colors in sorted order.

    Args:
        input_grid: 2-D grid of color codes. It is not modified.

    Returns:
        A new array holding the outlined pixel on a black background.

    Raises:
        ValueError: if no color occurs exactly once in ``input_grid``.
    """
    grid = np.asarray(input_grid)

    # Find the color(s) that appear exactly once.
    colors, counts = np.unique(grid, return_counts=True)
    singletons = colors[counts == 1]
    if singletons.size == 0:
        raise ValueError("input_grid has no color that appears exactly once")
    unique_color = singletons[0]

    # Locate the pixel on the ORIGINAL grid. Searching after blacking out
    # would pick the wrong cell when the unique color is itself black.
    row, col = np.argwhere(grid == unique_color)[0]

    # Build the output separately so the caller's array is left untouched.
    output_grid = np.full_like(grid, black)

    # Paint the 3x3 window around the pixel red, clipped to the grid bounds,
    # then restore the centre pixel to its own color.
    row_lo, row_hi = max(row - 1, 0), min(row + 2, grid.shape[0])
    col_lo, col_hi = max(col - 1, 0), min(col + 2, grid.shape[1])
    output_grid[row_lo:row_hi, col_lo:col_hi] = red
    output_grid[row, col] = unique_color

    return output_grid
    