import numpy as np
# Colour palette: each colour name is bound to its integer index 0-9, the
# values that grid cells are compared against.
(black, blue, red, green, yellow,
 grey, pink, orange, teal, maroon) = range(10)
def main(input_grid: np.ndarray) -> np.ndarray:
    """Slide the red cells of each column up onto the column's first black cell.

    For every column containing red, the shift is the gap between the first
    (topmost) red cell and the first (topmost) black cell. Every red cell in
    that column moves up by that same amount, and the vacated cells become
    black. Destination cells are overwritten regardless of their colour.

    Columns with no black cell are left unchanged. So are columns whose first
    black cell lies below the first red cell, because there is no room to move
    up into.

    Args:
        input_grid: 2-D grid of colour indices. It is not modified.

    Returns:
        A new grid with the red cells moved.
    """
    # Work on a copy so the caller's grid is not mutated as a side effect.
    output_grid = input_grid.copy()
    for j in range(input_grid.shape[1]):
        # Read positions from the untouched input so moves never feed back
        # into the row lookups.
        column = input_grid[:, j]
        red_rows = np.flatnonzero(column == red)
        black_rows = np.flatnonzero(column == black)
        if red_rows.size == 0 or black_rows.size == 0:
            # Nothing to move, or no black cell to anchor the move.
            continue
        distance = red_rows[0] - black_rows[0]
        if distance <= 0:
            # The first black cell is below the first red cell. Moving "up"
            # is impossible, so leave the column as-is.
            continue
        # Clear every red cell first, then paint the shifted positions. Doing
        # it in two steps means no freshly moved red can be erased by a later
        # clear.
        output_grid[red_rows, j] = black
        output_grid[red_rows - distance, j] = red
    return output_grid
