import numpy as np

# ARC colour palette: each name is the integer cell value (0-9) used in grids.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(
    input_grid: np.ndarray,
    rank_colors=(blue, red, green, yellow),
) -> np.ndarray:
    """Recolour grey cells column by column according to each column's grey count.

    Columns are ranked by how many grey cells they contain, most first. The
    grey cells of the column at rank ``i`` are painted ``rank_colors[i]``
    (by default: blue, red, green, yellow). Grey cells in columns ranked at
    or beyond ``len(rank_colors)`` stay grey, and non-grey cells are never
    changed.

    Columns with equal grey counts are ranked left to right (lower column
    index first), so the result does not depend on the sort algorithm.

    Args:
        input_grid: 2-D grid of colour indices. It is not modified.
        rank_colors: Colours to assign by rank, highest grey count first.

    Returns:
        A new array of the same shape holding the recoloured grid.
    """
    # Work on a copy so the caller's grid is left untouched.
    output_grid = np.array(input_grid, copy=True)

    # Compute the grey mask once and reuse it for counting and recolouring.
    grey_mask = output_grid == grey
    grey_counts = grey_mask.sum(axis=0)

    # Negating the counts gives a descending order. The stable sort keeps
    # equal counts in left-to-right column order.
    ranked_cols = np.argsort(-grey_counts, kind="stable")

    # zip stops at the shorter sequence, so only the top len(rank_colors)
    # columns are recoloured. Lower-ranked columns (including those with no
    # grey at all) are left as they are.
    for col, color in zip(ranked_cols, rank_colors):
        output_grid[grey_mask[:, col], col] = color

    return output_grid


