import numpy as np

# Define color codes: each name is bound to its integer index 0-9 in the grid.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray, output_shape: tuple = (2, 2)) -> np.ndarray:
    """Return a grid filled with the color whose bounding box covers the most area.

    For every non-black color present in ``input_grid``, the axis-aligned
    bounding box of all cells of that color is computed. The color with the
    largest box area wins. Ties go to the smallest color code, because colors
    are visited in ascending order and only a strictly larger area replaces
    the current best.

    Args:
        input_grid: 2-D array of integer color codes.
        output_shape: Shape of the returned grid (defaults to ``(2, 2)``).

    Returns:
        An array of shape ``output_shape`` filled with the winning color, or
        filled with ``black`` when the grid contains no non-black cells
        (including an empty grid).
    """
    max_area = 0
    # Fallback so an all-black/empty grid yields a color grid rather than
    # an object array of None.
    max_color = black

    # np.unique returns sorted values, so iteration order (and therefore
    # tie-breaking) is deterministic, unlike iterating over a Python set.
    for color in np.unique(input_grid):
        if color == black:
            continue

        # Row/column coordinates of every cell holding this color.
        rows, cols = np.nonzero(input_grid == color)

        # Area of the bounding box enclosing all cells of this color.
        area = (rows.max() - rows.min() + 1) * (cols.max() - cols.min() + 1)

        if area > max_area:
            max_area = area
            max_color = color

    return np.full(output_shape, max_color)


