import numpy as np

# Integer colour codes used in the grids; each palette name maps to one code
# in the range 0-9.
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def main(input_grid):
    """Surround red, green and teal cells with a halo of a companion colour.

    Every non-black cell of ``input_grid`` is visited in row-major order.
    Each red, green or teal cell triggers a halo: every *black* cell of the
    output inside its 3x3 window (clipped at the grid edges) is painted
    blue, pink or yellow respectively. Output cells that are already
    non-black are left untouched, including halo cells painted by an
    earlier source. Where two halos overlap, the source visited first wins.

    Args:
        input_grid: 2-D grid of integer colour codes. Any array-like is
            accepted, e.g. a numpy array or a list of lists. It is not
            modified.

    Returns:
        A new numpy array holding the grid with the halos applied.
    """
    # Normalise to an ndarray so that list-of-lists input works too. For a
    # list, ``input_grid != black`` would otherwise be a plain Python bool.
    grid = np.asarray(input_grid)
    output_grid = grid.copy()

    # Source colour -> halo colour.
    halo_colors = {red: blue, green: pink, teal: yellow}

    # np.nonzero returns indices in row-major (C) order, which fixes the
    # order in which overlapping halos are resolved.
    for row, col in zip(*np.nonzero(grid != black)):
        halo = halo_colors.get(grid[row, col])
        if halo is None:
            continue
        # Clamp the lower bounds at 0 so negative indices don't wrap around;
        # slicing already clips the upper bounds at the grid edge.
        # Basic slicing returns a view, so writing through ``window``
        # updates ``output_grid`` in place.
        window = output_grid[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2]
        # Only black cells are painted, so the source cell and any earlier
        # halo cells keep their colour.
        window[window == black] = halo

    return output_grid
