import numpy as np

# Named colour constants: each name is bound to its integer cell value 0-9.
(
    black, blue, red, green, yellow,
    grey, pink, orange, teal, maroon,
) = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Crop the 3x3 pattern drawn in the dominant colour and recolour its background.

    Colour A is the non-black colour that occupies the most cells of
    ``input_grid`` (ties resolve to the lowest colour value). The 3x3 window
    whose top-left corner is (minimum row, minimum column) over colour A's
    cells is copied into the output: cells of colour A keep colour A and
    every other cell becomes colour B, the lowest-valued non-black colour
    other than A that occurs in the grid.

    Edge cases:
      * If no colour B occurs in the grid, the non-A cells stay black.
      * Window cells that fall outside the input grid count as non-A cells
        and receive colour B (instead of raising IndexError).
      * An all-black grid yields an all-black 3x3 grid.

    Args:
        input_grid: 2-D grid of integer colour values (array or nested list).

    Returns:
        A new 3x3 integer grid.
    """
    grid = np.asarray(input_grid)
    output_grid = np.full((3, 3), black, dtype=int)

    non_black_colors = [color for color in range(10) if color != black]
    color_counts = [np.count_nonzero(grid == color) for color in non_black_colors]

    # Without any coloured cell there is no colour A: argmax would silently
    # pick the first colour and np.min would fail on an empty array.
    if max(color_counts) == 0:
        return output_grid

    # np.argmax returns the first maximum, so ties go to the lowest colour.
    color_a = non_black_colors[int(np.argmax(color_counts))]

    # Top-left corner of colour A's bounding box (min row, min column).
    min_row, min_col = np.argwhere(grid == color_a).min(axis=0)

    # Colour B: the first remaining non-black colour that actually occurs.
    other_colors = [
        color
        for color, count in zip(non_black_colors, color_counts)
        if color != color_a and count > 0
    ]
    color_b = other_colors[0] if other_colors else black

    output_grid[:] = color_b
    # Slicing clips at the grid edge, so a window hanging off the bottom or
    # right of the grid keeps colour B in its out-of-range cells.
    window = grid[min_row:min_row + 3, min_col:min_col + 3]
    rows, cols = window.shape
    output_grid[:rows, :cols] = np.where(window == color_a, color_a, color_b)
    return output_grid