import numpy as np
# Integer color codes for the ten grid colors, 0 through 9.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Recolor the grid's single odd pixel and return its 3x3 neighborhood.

    The "odd" pixel is the cell whose color occurs exactly once in the grid.
    If several colors occur exactly once, the smallest color value wins,
    following ``np.unique``'s sorted output. The pixel is repainted with the
    first non-black, in-bounds orthogonal neighbor, checked in the order
    left, right, up, down. If every in-bounds neighbor is black, it is
    painted black.

    Args:
        input_grid: 2-D grid of integer color codes. It is NOT modified.

    Returns:
        The window of rows ``r-1..r+1`` and columns ``c-1..c+1`` around the
        odd pixel at ``(r, c)``, taken from a recolored copy of the grid.
        The window is clipped to the grid bounds, so it is smaller than 3x3
        when the odd pixel lies on an edge.

    Raises:
        ValueError: if no color occurs exactly once in the grid.
    """
    # Work on a copy so the caller's grid is left untouched.
    grid = np.array(input_grid, copy=True)

    colors, counts = np.unique(grid, return_counts=True)
    singletons = colors[counts == 1]
    if singletons.size == 0:
        raise ValueError("input grid has no color that occurs exactly once")
    odd_color = singletons[0]

    rows, cols = np.where(grid == odd_color)
    row, col = int(rows[0]), int(cols[0])
    n_rows, n_cols = grid.shape

    # Neighbor search order: left, right, up, down.
    surrounding_color = black
    for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
        r, c = row + d_row, col + d_col
        # Explicit bounds check: a negative index would silently wrap to the
        # opposite edge, and an index past the end would raise IndexError.
        if 0 <= r < n_rows and 0 <= c < n_cols and grid[r, c] != black:
            surrounding_color = grid[r, c]
            break

    grid[row, col] = surrounding_color

    # Clamp the slice starts at 0 so an edge pixel yields a clipped window
    # rather than an empty slice (a -1 start would wrap to the far edge).
    return grid[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2]
