import numpy as np
# Colour palette: each name is bound to its integer pixel value (0-9).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Encode the number of non-black pixels of a grid as a red pattern.

    The result is an all-black grid with the same shape and dtype as
    ``input_grid``. Red pixels are drawn according to ``num``, the count of
    non-black pixels in the input:

    * ``num <= 3``: the leftmost ``num`` pixels of the first row are red
      (nothing is drawn when ``num`` is 0; the slice is clipped to the row
      width on narrow grids).
    * ``num == 4``: the whole first row and the pixel at ``[1, 1]`` are red.
    * any other count: the output stays entirely black.

    Args:
        input_grid: Grid of colour values, indexed as ``[row, column]``.

    Returns:
        A new array shaped like ``input_grid``; the input is not modified.

    Raises:
        IndexError: If ``num == 4`` and the grid has no ``[1, 1]`` position.
    """
    num = np.count_nonzero(input_grid != black)
    output = np.zeros_like(input_grid)  # zeros == black everywhere

    if num <= 3:
        # Was `num < 3`, which left a count of exactly 3 unhandled (all-black
        # output) even though 2 and 4 both draw a pattern.
        output[0, :num] = red
    elif num == 4:
        output[0] = red
        output[1, 1] = red
    return output