import numpy as np
# Color names mapped to integer codes 0-9; black is 0, so zero-filled arrays are all black.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Expand the first non-black 3x3 tile of a 9x9 grid into a self-similar 9x9 grid.

    The top-left 9x9 region of ``input_grid`` is scanned as a 3x3 arrangement
    of 3x3 tiles in row-major order. The first tile that contains any
    non-black pixel becomes the *pattern*. The output is again a 3x3
    arrangement of 3x3 tiles: tile ``(i, j)`` is a copy of the pattern when
    ``pattern[i, j]`` is non-black, and all black otherwise.

    Args:
        input_grid: 2-D grid of color codes. Presumably exactly 9x9; smaller
            grids yield truncated tiles and are not supported (TODO confirm
            with callers).

    Returns:
        A new 9x9 ``int`` array. If no non-black pixel is found, the result
        is entirely black (previously this case raised ``TypeError``).
    """
    tile = 3
    size = tile * tile

    # Lazily walk the tiles in row-major order; next() stops at the first hit,
    # matching the original nested-loop-with-break search.
    tiles = (
        input_grid[r:r + tile, c:c + tile]
        for r in range(0, size, tile)
        for c in range(0, size, tile)
    )
    pattern = next((t for t in tiles if np.any(t != black)), None)

    if pattern is None:
        # Nothing to replicate: return an all-black grid instead of crashing.
        return np.zeros((size, size), dtype=int)

    # Kronecker product: block (i, j) of the result equals mask[i, j] * pattern,
    # i.e. the pattern where pattern[i, j] is colored and zeros elsewhere.
    # Zeros are black because black == 0 (see the constants above).
    mask = (pattern != black).astype(int)
    return np.kron(mask, pattern).astype(int)