import numpy as np

# Color codes: each color name is bound to its integer code, 0 (black) through 9 (maroon).
(
    black,
    blue,
    red,
    green,
    yellow,
    grey,
    pink,
    orange,
    teal,
    maroon,
) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Return the top-left 3x3 corner of a 6x6 window around the non-black cells.

    The window is placed from the bounding box of every cell that differs
    from ``black``. Independently along each axis:

    * a bounding box shorter than 6 is padded on its top/left side, so the
      window ends at the last non-black cell; if the grid edge prevents
      that, the window starts at index 0 instead;
    * a bounding box longer than 6 keeps only its first 6 rows/columns.

    The window always spans 6 cells from its start, clipped only by the grid
    itself (not by the bounding box), so the crop is a full 3x3 whenever the
    grid has at least 3 rows and 3 columns.

    Args:
        input_grid: Array of color codes; expected to be 2-D (rows, columns).

    Returns:
        A copy of the first (up to) 3 rows and 3 columns of the window, so
        modifying the result never alters ``input_grid``.

    Raises:
        ValueError: If ``input_grid`` contains no non-black cell.
    """
    window_size, crop_size = 6, 3

    rows, cols = np.nonzero(input_grid != black)
    if rows.size == 0:
        # np.min/np.max on an empty index array would fail with an opaque
        # "zero-size array" ValueError; keep the type, clarify the message.
        raise ValueError("input_grid contains no non-black cells")

    def _window_start(indices: np.ndarray) -> int:
        # First index of the window along one axis.
        first, last = int(indices.min()), int(indices.max())
        extent = last - first + 1
        if extent < window_size:
            # Pad towards the top/left, but never past the grid edge.
            first = max(0, first - (window_size - extent))
        return first

    top = _window_start(rows)
    left = _window_start(cols)

    # Slicing clips at the grid boundary, so the window is 6x6 wherever the
    # grid allows it, even when the start was clamped to 0 above.
    six_by_six = input_grid[top:top + window_size, left:left + window_size]

    # Cut the 3x3 corner; copy so the caller does not get a view of the input.
    return six_by_six[:crop_size, :crop_size].copy()


