import numpy as np

# Color names bound to the integer codes 0-9, in this order.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Frame ``input_grid`` with a one-cell border that repeats its edge cells.

    The result has shape ``(h + 2, w + 2)`` and integer dtype. The interior
    holds a copy of ``input_grid``. Each border cell (other than a corner)
    copies the grid cell directly inside it. The four corner cells are never
    written, so they stay 0.
    """
    rows, cols = input_grid.shape
    framed = np.zeros((rows + 2, cols + 2), dtype=int)
    interior = slice(1, -1)  # indices 1..h (or 1..w): everything except the frame
    framed[interior, interior] = input_grid

    # Top and bottom frame rows: interior columns only, so corners are untouched.
    framed[0, interior] = framed[1, interior]
    framed[-1, interior] = framed[-2, interior]

    # Left and right frame columns: interior rows only, so corners are untouched.
    framed[interior, 0] = framed[interior, 1]
    framed[interior, -1] = framed[interior, -2]

    return framed
    