import numpy as np

# Color names bound to the integer codes 0-9 used as grid cell values.
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid):
    """Crop the grid around its grey cells, padded by one row above and below.

    The crop spans the bounding box of every cell equal to ``grey``. It is
    extended by one extra row at the top and one at the bottom; the columns
    are not padded.

    Args:
        input_grid: 2-D NumPy array of integer color codes.

    Returns:
        The slice of ``input_grid`` covering rows ``min_row - 1`` through
        ``max_row + 1`` and columns ``min_col`` through ``max_col``
        (inclusive), truncated at the grid edges.

    Raises:
        ValueError: If ``input_grid`` contains no grey cells.
    """
    # Locate the grey cells once and reuse the coordinates for both extremes.
    grey_rows, grey_cols = np.where(input_grid == grey)
    if grey_rows.size == 0:
        raise ValueError("input_grid contains no grey cells to crop around")

    # Clamp the padded top edge at 0: a start of -1 would be read by NumPy as
    # "last row" and yield an empty or wrong slice when grey touches row 0.
    top = max(grey_rows.min() - 1, 0)
    # Slice stops past the last row are truncated by NumPy, so no clamp here.
    bottom = grey_rows.max() + 2
    left = grey_cols.min()
    right = grey_cols.max() + 1

    return input_grid[top:bottom, left:right]