import numpy as np

# Color palette: each name maps to its integer cell value (black=0 ... maroon=9).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Recolor grey column bars: most-grey column -> blue, least-grey -> red, rest -> black.

    Grey pixels are counted per column. Grey pixels in the column holding the
    most grey become blue. Grey pixels in the column holding the fewest grey
    (among columns with at least one grey pixel) become red. Every other grey
    pixel becomes black. Non-grey pixels are left untouched.

    Ties resolve to the leftmost column (``np.argmax``/``np.argmin`` return the
    first occurrence). When the same column is both the most- and least-grey
    column (e.g. only one column contains grey), it is recolored blue.

    Args:
        input_grid: 2-D array of color indices.

    Returns:
        A new array with the recoloring applied; ``input_grid`` is not modified.
    """
    output_grid = input_grid.copy()
    is_grey = output_grid == grey
    gray_counts = is_grey.sum(axis=0)

    # No grey anywhere: nothing to recolor. This also covers grids with zero
    # columns, where np.argmax on the empty counts would raise ValueError.
    if not gray_counts.any():
        return output_grid

    max_gray_col = int(np.argmax(gray_counts))
    # Columns without grey are masked to +inf so they can never be the minimum.
    min_gray_col = int(np.argmin(np.where(gray_counts > 0, gray_counts, np.inf)))

    # Paint from lowest to highest priority using the ORIGINAL grey mask:
    # black everywhere, then red, then blue, so blue wins if min == max.
    output_grid[is_grey] = black
    output_grid[is_grey[:, min_gray_col], min_gray_col] = red
    output_grid[is_grey[:, max_gray_col], max_gray_col] = blue

    return output_grid
    