import numpy as np

# Named color indices 0-9 used throughout this file.
# black (0) doubles as the empty/background value: it is what np.zeros_like
# fills with, and it is the value skipped when searching for content.
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid: np.ndarray) -> np.ndarray:
    """Crop the interior between the first and last non-black pixels and recolor it.

    The first and last non-black pixels, in row-major scan order, are treated
    as opposite corners of a frame. The region strictly between them
    (both corner rows and columns excluded) is cut out. Every non-black pixel
    in that region is repainted with the color of the first corner pixel, and
    black pixels stay black.

    Args:
        input_grid: 2-D array of color indices.

    Returns:
        A new array with the same dtype as ``input_grid`` containing the
        recolored interior. It has zero rows and/or columns when the two
        corners are adjacent or coincide.
        NOTE(review): it is also empty when the first corner lies to the
        right of the last one (y1 >= y2); confirm inputs never look like that.

    Raises:
        ValueError: If ``input_grid`` contains no non-black pixels.
    """
    # Step 1: Locate the first and last non-black pixels (row-major order).
    x1, y1, x2, y2 = find_non_black_pixels(input_grid)
    if x1 is None:
        # Without this guard the slice below would fail with an opaque
        # TypeError from ``None + 1``.
        raise ValueError("input_grid contains no non-black pixels")

    # Step 2: Take the subgrid strictly inside the two corners (exclusive bounds).
    subgrid = input_grid[x1 + 1:x2, y1 + 1:y2]

    # Step 3: Start from an all-black canvas (black == 0) and paint every
    # non-black interior pixel with the first corner's color.
    ans_grid = np.zeros_like(subgrid)
    ans_grid[subgrid != black] = input_grid[x1, y1]

    # Step 4: Return the recolored interior.
    return ans_grid

def find_non_black_pixels(grid):
    """Return the positions of the first and last non-black pixels in ``grid``.

    "First" and "last" refer to row-major scan order: rows top to bottom,
    and within each row, columns left to right.

    Args:
        grid: 2-D array (or array-like) of color indices.

    Returns:
        A tuple ``(x1, y1, x2, y2)`` of Python ints. ``(x1, y1)`` is the
        (row, column) of the first non-black pixel and ``(x2, y2)`` is that of
        the last one. If the grid has no non-black pixels, all four values
        are ``None``. When exactly one pixel is non-black, both pairs are equal.

    Raises:
        ValueError: If ``grid`` is not two-dimensional.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"expected a 2-D grid, got {grid.ndim} dimension(s)")

    # np.argwhere returns matching indices in row-major (C) order, so row 0 is
    # the first hit and row -1 the last, matching a nested row/column scan.
    coords = np.argwhere(grid != black)
    if len(coords) == 0:
        return None, None, None, None

    x1, y1 = (int(v) for v in coords[0])
    x2, y2 = (int(v) for v in coords[-1])
    return x1, y1, x2, y2
    