import numpy as np

black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def _extend_line(line: np.ndarray) -> None:
    """Fill black cells of a 1-D array in place by copying from ``n`` cells away.

    ``n`` is the number of non-black cells in ``line`` before any filling.

    Two sequential passes are made:

    * For each non-black cell (scanning left to right), the contiguous black run
      immediately to its left is filled, walking leftwards, with
      ``line[min(j + n, len(line) - 1)]``.
    * For each non-black cell (scanning right to left), the contiguous black run
      immediately to its right is filled, walking rightwards, with
      ``line[max(j - n, 0)]``.

    Writes happen one cell at a time, so a cell filled earlier can serve as the
    source for a later one. If the source cell is itself black, the target stays
    black.

    Args:
        line: 1-D array, typically a row or column view of a grid. It is mutated
            in place.
    """
    size = len(line)
    n = int(np.count_nonzero(line != black))

    # Leftward pass: extend each non-black cell into the black run on its left.
    for i in range(size):
        if line[i] != black:
            for j in range(i - 1, -1, -1):
                if line[j] != black:
                    break
                line[j] = line[min(j + n, size - 1)]

    # Rightward pass: extend each non-black cell into the black run on its right.
    for i in range(size - 1, -1, -1):
        if line[i] != black:
            for j in range(i + 1, size):
                if line[j] != black:
                    break
                line[j] = line[max(j - n, 0)]


def main(input_grid: np.ndarray) -> np.ndarray:
    """Extend the busiest row, then the busiest column, into their black cells.

    The row with the most non-black cells is filled with ``_extend_line``. The
    column with the most non-black cells is then picked *from the partially
    filled grid* and filled the same way. Ties go to the lowest index, following
    ``np.argmax``.

    Args:
        input_grid: 2-D array of color indices (``black`` == 0). It is not
            modified.

    Returns:
        A new array with the same shape and dtype as ``input_grid`` holding the
        result.
    """
    # Work on a copy. The row/column slices below are views, so filling them
    # in place would otherwise write into the caller's array.
    output_grid = input_grid.copy()

    max_row = np.argmax(np.sum(output_grid != black, axis=1))
    _extend_line(output_grid[max_row])

    # The column is chosen after the row fill, so the filled row counts toward
    # the column totals (and the shared cell is already filled).
    max_col = np.argmax(np.sum(output_grid != black, axis=0))
    _extend_line(output_grid[:, max_col])

    return output_grid
    