import numpy as np

# Colour names bound to the integer codes 0-9 stored in the grid cells.
(black, blue, red, green, yellow,
 grey, pink, orange, teal, maroon) = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Pull every grey pixel toward its nearest red pixel.

    Each grey pixel paints a grey trail (via ``move_pixel``) toward the red
    pixel closest to it in Euclidean distance. Afterwards, every grey cell
    with no red cell in its surrounding 3x3 window is turned black. As a
    result, only grey cells touching a red pixel (including diagonally)
    survive.

    Args:
        input_grid: 2-D grid of colour codes. It is not modified.

    Returns:
        A new grid with the transformation applied.
    """
    # Work on a copy: move_pixel and the cleanup loop below write into the
    # grid, and the caller's array must not change behind their back.
    output_grid = np.array(input_grid)

    grey_indices = np.argwhere(output_grid == grey)
    red_indices = np.argwhere(output_grid == red)

    # With no red pixel there is nothing to move toward, and
    # find_nearest_red would raise on the empty candidate list.
    # Skip straight to the cleanup in that case.
    if len(red_indices) > 0:
        for grey_index in grey_indices:
            nearest_red_index = find_nearest_red(grey_index, red_indices)
            output_grid = move_pixel(grey_index, nearest_red_index, output_grid)

    # Remove every grey cell (original pixels and trail cells alike) that
    # does not touch a red cell.
    for grey_index in np.argwhere(output_grid == grey):
        if not is_next_to_red(grey_index, output_grid):
            output_grid[grey_index[0], grey_index[1]] = black

    return output_grid

def find_nearest_red(grey_index, red_indices):
    """Return the entry of ``red_indices`` closest to ``grey_index``.

    Closeness is the Euclidean distance between (row, col) coordinates. On a
    tie, the entry that appears first in ``red_indices`` wins.

    Args:
        grey_index: (row, col) coordinate array of the grey pixel.
        red_indices: 2-D array with one red-pixel coordinate per row, e.g.
            as produced by ``np.argwhere``.

    Returns:
        The selected row of ``red_indices``.

    Raises:
        ValueError: if ``red_indices`` has no rows (propagated from argmin).
    """
    offsets = red_indices - grey_index
    closest_row = np.linalg.norm(offsets, axis=1).argmin()
    return red_indices[closest_row]

def move_pixel(start_index, end_index, grid):
    """Paint a grey trail from ``start_index`` toward ``end_index`` on ``grid``.

    The walk takes one step per iteration. Each step changes the row and the
    column by -1, 0 or +1 toward the target, so the walk moves diagonally
    while both coordinates differ and straight once one of them matches.

    Every cell entered is set to ``grey``, whatever colour it held before,
    until the walk enters a ``red`` cell; it then stops without painting
    that cell. The start cell itself is not repainted.

    Args:
        start_index: (row, col) of the starting cell. Any indexable pair works
            (array, tuple, list); it is not modified.
        end_index: (row, col) of the target cell.
        grid: 2-D array of colour codes, modified in place.

    Returns:
        ``grid`` (the same object that was passed in).
    """
    # Walk on private integer copies so the caller's start_index (often a
    # row view into an np.argwhere result) is left untouched.
    row, col = int(start_index[0]), int(start_index[1])
    end_row, end_col = int(end_index[0]), int(end_index[1])

    while (row, col) != (end_row, end_col):
        # sign(target - current) per axis: +1, 0 or -1.
        row += (end_row > row) - (end_row < row)
        col += (end_col > col) - (end_col < col)
        if grid[row, col] == red:
            break
        grid[row, col] = grey
    return grid

def is_next_to_red(grey_index, grid):
    """Report whether any cell in the 3x3 window around ``grey_index`` is red.

    The window is centred on ``grey_index``. It includes the diagonal
    neighbours and the centre cell itself; positions falling outside ``grid``
    are skipped. Cells are scanned row by row, stopping at the first red one.

    Args:
        grey_index: (row, col) coordinate of the cell to test.
        grid: 2-D array of colour codes.

    Returns:
        True if a ``red`` cell lies in the window, otherwise False.
    """
    centre_row, centre_col = grey_index[0], grey_index[1]
    height, width = grid.shape[0], grid.shape[1]
    return any(
        grid[r, c] == red
        for r in range(centre_row - 1, centre_row + 2)
        if 0 <= r < height
        for c in range(centre_col - 1, centre_col + 2)
        if 0 <= c < width
    )
