import numpy as np

def main(input_grid):
    """Paint yellow diagonals around red cells and orange crosses around blue cells.

    Every red (2) cell has its four diagonal neighbours set to yellow (4).
    Then every blue (1) cell has its four orthogonal neighbours set to orange (7).
    Neighbours that fall outside the grid are skipped.

    Seed positions are collected from the grid before any painting. Newly
    painted cells therefore never act as seeds. Because blue is painted
    second, orange wins wherever the two neighbourhoods overlap.

    Args:
        input_grid: 2-D array-like (numpy array or nested lists) of integer
            colour codes. It is not modified.

    Returns:
        A new 2-D numpy array with the decorations applied.
    """
    # Colour codes used by this transformation (ARC palette indices).
    blue, red, yellow, orange = 1, 2, 4, 7

    # Copy so the caller's grid is never mutated. This also turns list input
    # into an ndarray, so the element-wise comparisons below work.
    output_grid = np.array(input_grid, copy=True)
    n_rows, n_cols = output_grid.shape

    # Snapshot seeds before painting so painted cells cannot become seeds.
    red_positions = np.argwhere(output_grid == red)
    blue_positions = np.argwhere(output_grid == blue)

    def _paint(positions, offsets, color):
        """Set every in-bounds neighbour (at the given offsets) of each seed to color."""
        for i, j in positions:
            for di, dj in offsets:
                r, c = i + di, j + dj
                if 0 <= r < n_rows and 0 <= c < n_cols:
                    output_grid[r, c] = color

    # NOTE(review): neighbours are painted unconditionally, so a seed cell
    # next to another seed is overwritten too. Confirm that this is intended.
    _paint(red_positions, ((-1, -1), (-1, 1), (1, -1), (1, 1)), yellow)
    _paint(blue_positions, ((-1, 0), (1, 0), (0, -1), (0, 1)), orange)

    return output_grid