import numpy as np

# Color codes: each name below is bound to a distinct integer from 0 to 9.
black = 0
blue = 1
red = 2
green = 3
yellow = 4
grey = 5
pink = 6
orange = 7
teal = 8
maroon = 9

def find_block(arr):
    """
    Locate the 3x3 block of a 9x9 array that has exactly four non-black cells.

    Blocks are scanned row by row, left to right. The first match is returned
    as a slice of ``arr``; ``None`` is returned when no block matches.
    """
    offsets = (0, 3, 6)
    for top in offsets:
        for left in offsets:
            candidate = arr[top:top + 3, left:left + 3]
            colored = np.count_nonzero(candidate != black)
            if colored != 4:
                continue
            return candidate
    return None

def main(input_grid):
    """
    Expand the 3x3 block that has four non-black cells into a full grid.

    Rows and columns 3 and 7 of ``input_grid`` (the grey separator lines) are
    dropped, and ``find_block`` selects the 3x3 block containing exactly four
    non-black cells from what remains. Each cell of that block is then
    enlarged into a 3x3 square of its color, and grey lines are inserted so
    that they sit at rows and columns 3 and 7 of the result.

    Parameters
    ----------
    input_grid : array_like
        2-D grid of color codes. Presumably 11x11 (nine 3x3 blocks plus two
        separator lines per axis); the shape is not checked here.

    Returns
    -------
    numpy.ndarray
        An 11x11 grid: nine solid 3x3 squares separated by grey lines.

    Raises
    ------
    ValueError
        If no 3x3 block with exactly four non-black cells is found.
    """
    # Remove grey lines
    input_grid = np.delete(input_grid, [3, 7], axis=0)
    input_grid = np.delete(input_grid, [3, 7], axis=1)

    # Find block with four non-black elements
    block = find_block(input_grid)
    if block is None:
        # Fail with a clear message instead of an opaque "NoneType is not
        # subscriptable" error further down.
        raise ValueError(
            "no 3x3 block with exactly four non-black cells found in input grid"
        )

    # Initialize output grid as black
    output_grid = np.full((9, 9), black)

    # Fill the (row, col)-th 3x3 square of the output with the color stored
    # at (row, col) in the found block
    for row in range(3):
        for col in range(3):
            color = block[row, col]
            output_grid[row * 3:row * 3 + 3, col * 3:col * 3 + 3] = color

    # Add grey lines; inserting before indices 3 and 6 of the 9x9 grid puts
    # them at indices 3 and 7 of the 11x11 result
    output_grid = np.insert(output_grid, [3, 6], grey, axis=0)
    output_grid = np.insert(output_grid, [3, 6], grey, axis=1)

    # Output entire ndarray
    return output_grid