import numpy as np

# Colour palette: each name is bound to its integer index 0-9.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Collapse a grid's non-black cells into a single one-row strip.

    Every cell whose value is not ``black`` is counted. The result is a
    ``1 x count`` row filled with the colour of the first non-black cell
    met in a row-by-row, left-to-right scan.

    NOTE(review): cells of *any* non-black colour are counted, but only the
    first one's colour is used; this presumably assumes the grid contains a
    single non-black colour — confirm against the task.

    Args:
        input_grid: Grid of colour indices (array or array-like).

    Returns:
        Array of shape ``(1, count)``. If the grid has no non-black cells,
        the result has shape ``(1, 0)`` and keeps the input's dtype.
    """
    grid = np.asarray(input_grid)

    # Boolean-mask indexing returns the selected elements in row-major (C)
    # order, so element 0 is the first non-black cell of a row-wise scan.
    nonblack = grid[grid != black]

    if nonblack.size == 0:
        # np.full(..., None) would silently produce an object-dtype array;
        # return an empty strip with the grid's own dtype instead.
        return np.zeros((1, 0), dtype=grid.dtype)

    return np.full((1, nonblack.size), nonblack[0])