import numpy as np

# Color names bound to the integer cell values 0-9 that appear in the grids.
(
    black, blue, red, green, yellow,
    grey, pink, orange, teal, maroon,
) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Vertically align the red and yellow boxes with the top row of the blue box.

    For each of blue, red and yellow, the bounding box of that color's cells is
    cut out of the input. The output starts as an all-zero (black) grid with
    the input's shape. Each box is pasted back at its original column, with its
    top edge on the blue box's top row. Blue therefore stays where it was.

    Boxes are pasted whole, including any other-colored cells inside them. They
    are pasted in the order blue, red, yellow, so a later box overwrites an
    earlier one where they overlap.

    Args:
        input_grid: 2-D grid of integer color codes (a nested list is also
            accepted and converted with ``np.asarray``).

    Returns:
        A new array of the same shape and dtype as ``input_grid``.

    Raises:
        ValueError: if ``input_grid`` contains no blue cell, since blue
            defines the target row.
    """
    input_grid = np.asarray(input_grid)

    if not np.any(input_grid == blue):
        raise ValueError("input grid contains no blue cells to align against")
    # Every box is moved so that its top edge lands on this row.
    anchor_row, _ = find_first_last_row(input_grid, blue)

    output_grid = np.zeros_like(input_grid)
    # Order matters: later colors overwrite earlier ones where boxes overlap.
    for color in (blue, red, yellow):
        if not np.any(input_grid == color):
            # Nothing of this color to move; skip instead of failing.
            continue
        top, bottom = find_first_last_row(input_grid, color)
        left, right = find_first_last_col(input_grid, color)
        # Patches are cut from the untouched input, never from output_grid,
        # so earlier overlays cannot leak into later patches.
        patch = input_grid[top:bottom + 1, left:right + 1]
        overlay_grid(output_grid, patch, anchor_row, left)

    return output_grid

def find_first_last_row(grid, color):
    """Return the first and last row indices that contain ``color``.

    Args:
        grid: 2-D array (or nested sequence) of color codes.
        color: the color value to look for.

    Returns:
        ``(first, last)``: the smallest and largest row index holding at least
        one cell equal to ``color``. Both bounds are inclusive.

    Raises:
        ValueError: if ``color`` does not occur anywhere in ``grid``.
    """
    rows = np.nonzero(np.asarray(grid) == color)[0]
    if rows.size == 0:
        # np.min on an empty array would raise an opaque zero-size error.
        raise ValueError(f"color {color} not found in grid")
    return np.min(rows), np.max(rows)

def find_first_last_col(grid, color):
    """Return the first and last column indices that contain ``color``.

    Args:
        grid: 2-D array (or nested sequence) of color codes.
        color: the color value to look for.

    Returns:
        ``(first, last)``: the smallest and largest column index holding at
        least one cell equal to ``color``. Both bounds are inclusive.

    Raises:
        ValueError: if ``color`` does not occur anywhere in ``grid``.
    """
    cols = np.nonzero(np.asarray(grid) == color)[1]
    if cols.size == 0:
        # np.min on an empty array would raise an opaque zero-size error.
        raise ValueError(f"color {color} not found in grid")
    # nonzero yields indices in row-major order, so columns are unsorted;
    # min/max are required rather than first/last element.
    return np.min(cols), np.max(cols)

def overlay_grid(output_grid, grid, x, y):
    """Paste ``grid`` into ``output_grid`` with its top-left corner at (x, y).

    ``x`` is a row offset and ``y`` a column offset. ``output_grid`` is
    modified in place and nothing is returned. Any part of ``grid`` that falls
    outside ``output_grid`` is clipped away. This covers negative offsets and
    patches that run past the bottom or right edge; without clipping, numpy
    would raise a broadcast error for those.

    Args:
        output_grid: 2-D destination array, mutated in place.
        grid: 2-D patch to copy.
        x: destination row of the patch's top edge.
        y: destination column of the patch's left edge.
    """
    h, w = grid.shape
    out_h, out_w = output_grid.shape

    # Destination window, clipped to the output's bounds.
    row_lo, col_lo = max(x, 0), max(y, 0)
    row_hi, col_hi = min(x + h, out_h), min(y + w, out_w)
    if row_lo >= row_hi or col_lo >= col_hi:
        return  # patch lies entirely outside the output

    # Matching window in patch coordinates (offset by -x, -y).
    output_grid[row_lo:row_hi, col_lo:col_hi] = grid[
        row_lo - x:row_hi - x, col_lo - y:col_hi - y
    ]
    