import numpy as np

# Named color codes 0-9 used as grid cell values.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Swap the two non-black colors and crop to the non-black region.

    Every non-black cell is recolored to the other of the two non-black
    colors present in the grid. The result is then cropped to the bounding
    box of all non-black cells; black cells inside that box stay black.

    The input is not modified; a new ``int`` array is returned.

    Args:
        input_grid: 2-D grid of color codes (anything ``np.array`` accepts).

    Returns:
        The cropped, color-swapped grid.

    Raises:
        ValueError: if the non-black cells do not contain exactly two
            distinct colors (this includes an all-black grid).
    """
    # Work on a private copy: the caller's grid must not be mutated.
    grid = np.array(input_grid, dtype=int)
    non_black = grid != black

    # Every non-black cell belongs to some same-color connected component,
    # so the union of all components is simply the non-black mask; no
    # flood fill is needed to find the cells to recolor.
    colors = np.unique(grid[non_black])
    if colors.size != 2:
        raise ValueError(
            f"expected exactly 2 non-black colors, found {colors.size}"
        )
    color1, color2 = colors

    # Masks are computed from the unmodified grid so the two assignments
    # don't interfere with each other.
    swapped = grid.copy()
    swapped[grid == color1] = color2
    swapped[grid == color2] = color1

    # Crop to the bounding box of the non-black cells.
    rows = np.flatnonzero(non_black.any(axis=1))
    cols = np.flatnonzero(non_black.any(axis=0))
    return swapped[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]