import numpy as np

# Color indices 0-9, bound in this order.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Extend each non-black edge pixel along its down-right diagonal.

    Every non-black pixel in the first row or first column seeds a diagonal
    that runs toward the bottom-right until it leaves the grid. Cells on that
    diagonal are painted alternately with the seed's color and yellow,
    starting with the seed color on the seed cell itself (so a yellow seed
    yields a solid yellow diagonal).

    Args:
        input_grid: 2-D grid of color indices. It is not modified.

    Returns:
        A new grid (same shape and dtype) with the diagonals painted.
    """
    output_grid = np.copy(input_grid)
    height, width = output_grid.shape
    if height == 0 or width == 0:
        # No first row / first column to scan.
        return output_grid

    # Scan the first column and the first row with their own bounds: the grid
    # may be non-square, so one shared range would overrun or under-scan.
    seeds = {(r, 0) for r in range(height) if input_grid[r, 0] != black}
    seeds |= {(0, c) for c in range(width) if input_grid[0, c] != black}

    # Distinct seeds lie on distinct diagonals, so painting order is irrelevant
    # and seed colors can be read from the untouched input.
    for r0, c0 in seeds:
        start_color = input_grid[r0, c0]
        length = min(height - r0, width - c0)
        for step in range(length):
            # Even steps keep the seed color; odd steps are yellow.
            color = start_color if step % 2 == 0 else yellow
            output_grid[r0 + step, c0 + step] = color

    return output_grid