import numpy as np
from typing import *

# Named color constants, numbered 0 through 9 in the order listed.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)

def main(input_grid: np.ndarray) -> np.ndarray:
    """Draw a right-aligned bar chart of the non-black color frequencies.

    Every value > 0 in ``input_grid`` is counted as a color pixel. The
    result is an all-black (0) grid of the same shape and dtype in which
    each present color gets one row, filled from the right edge with as
    many cells as that color has pixels. The most frequent color occupies
    the bottom row, the next most frequent the row above it, and so on.

    Ties in count are broken by color value: the higher value is placed
    lower. A bar longer than the grid width fills the whole row. Colors
    that do not fit in the available rows are dropped rather than wrapping
    around and overwriting bars already drawn.

    Args:
        input_grid: 2-D integer grid of color values; only values > 0 are
            counted.

    Returns:
        A new array with the same shape and dtype as ``input_grid``.
    """
    # Pixel count per color value (index = color); boolean-mask indexing
    # already yields a 1-D array, so no flattening is needed.
    color_counts = np.bincount(input_grid[input_grid > 0])

    # Descending by count. A stable sort makes tie-breaking deterministic
    # across NumPy builds; after the reversal tied colors come out with the
    # highest color value first.
    sorted_colors = np.argsort(color_counts, kind="stable")[::-1]

    output_grid = np.zeros_like(input_grid)
    height = output_grid.shape[0]

    for i, color in enumerate(sorted_colors):
        count = color_counts[color]
        # Counts are in descending order, so the first zero means every
        # remaining color (including black at index 0) is absent.
        if count == 0:
            break
        # Out of rows: a negative row index would wrap around and clobber
        # the bars already drawn at the bottom of the grid.
        if i >= height:
            break
        row = height - 1 - i
        # Negative-start slice right-aligns the bar; it clips to the full
        # row when count exceeds the width.
        output_grid[row, -count:] = color

    return output_grid