import numpy as np
def main(input_grid):
    """Expand a single-row grid into a square grid of rising diagonals.

    The input row is written left-aligned into the bottom row of the output.
    Every row above is a copy of the row below it, shifted one column to the
    right. Cells pushed past the right edge are dropped, and the vacated
    left-most cell is 0.

    The output is square with side ``width * color_num``, where:
    - ``width`` is the length of the input row;
    - ``color_num`` is the number of distinct non-zero values in the row
      (0 is treated as black/background).

    Args:
        input_grid: Array-like holding a single row, shape ``(1, W)`` or
            ``(W,)``.

    Returns:
        A ``(W * color_num, W * color_num)`` ndarray with the input's dtype.
        If the row contains no non-zero value, a ``(0, 0)`` array.

    Raises:
        ValueError: If ``input_grid`` has more than one row.
    """
    grid = np.asarray(input_grid)
    # Only a single row (any number of leading length-1 axes) is supported.
    if grid.ndim == 0 or grid.size != grid.shape[-1]:
        raise ValueError("input_grid must be a single row")
    row = grid.reshape(-1)
    width = row.size

    # Count distinct non-black colours. Filtering out 0 explicitly, rather
    # than subtracting 1 from the unique count, stays correct when the row
    # contains no black cell at all.
    color_num = int(np.count_nonzero(np.unique(row)))

    size = width * color_num
    # Keep the input's dtype (e.g. integer colours) instead of float64.
    output_grid = np.zeros((size, size), dtype=row.dtype)
    if size == 0:
        # No colours to draw; an empty grid has no bottom row to fill.
        return output_grid

    # Bottom row holds the input. color_num <= width, so size >= width and
    # the row always fits.
    output_grid[-1, :width] = row
    # Build upwards: each row is the one below shifted right by one column.
    for i in range(size - 2, -1, -1):
        output_grid[i, 1:] = output_grid[i + 1, :-1]
    return output_grid