import numpy as np

# Color codes: each name is bound to its integer palette index (0-9).
(black, blue, red, green, yellow, grey, pink, orange, teal, maroon) = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Spread colored pixels into adjacent black cells, then upscale 2x.

    The grid is scanned in row-major order starting at row 1 / column 1.
    Every non-black cell ``(i, j)`` found there paints its own color onto
    any black cell inside the 2x2 window covering rows ``i-1..i`` and
    columns ``j-1..j`` (i.e. the cells above, to the left, and diagonally
    up-left of it). Cells in row 0 or column 0 are never used as a source,
    but they can be painted by a neighbor.

    After the fill, every cell is expanded into a 2x2 square, so the result
    has twice as many rows and columns as the input.

    Args:
        input_grid: 2-D array of color codes, where ``black`` (0) marks an
            empty cell. The array is not modified.

    Returns:
        A new array of shape ``(2 * rows, 2 * cols)`` with the same dtype as
        ``input_grid``.
    """
    # Work on a private copy so the caller's grid is left untouched.
    grid = np.array(input_grid, copy=True)

    # NOTE(review): the window only reaches up/left (2x2), not the full 3x3
    # neighborhood, and row 0 / column 0 never act as sources. Kept as-is;
    # confirm against the intended transformation before widening it.
    for i in range(1, grid.shape[0]):
        for j in range(1, grid.shape[1]):
            color = grid[i, j]
            if color == black:
                continue
            # Basic slicing yields a view, so the masked assignment below
            # writes straight into ``grid``.
            window = grid[i - 1:i + 1, j - 1:j + 1]
            window[window == black] = color

    # Duplicate every row, then every column: each cell becomes a 2x2 square.
    return np.repeat(np.repeat(grid, 2, axis=0), 2, axis=1)