import numpy as np
def main(input_grid):
    """Overlay the four 2x2 corners of a grid into a 3x3 binary grid.

    Each 2x2 corner of ``input_grid`` is placed at the matching corner of a
    3x3 output: upper-left at rows 0-1/cols 0-1, upper-right at rows 0-1/
    cols 1-2, lower-left at rows 1-2/cols 0-1 and lower-right at rows 1-2/
    cols 1-2. Neighbouring corners therefore overlap on the middle row and
    column. An output cell is 1 when any corner placed over it has a
    non-zero (non-black) pixel there, and 0 otherwise.

    Args:
        input_grid: 2-D array-like of pixel values with at least 2 rows and
            2 columns. Nested lists are accepted as well as ndarrays.

    Returns:
        A 3x3 ``numpy.ndarray`` of dtype float64 containing only 0.0 / 1.0.

    Raises:
        ValueError: if ``input_grid`` is not 2-D or is smaller than 2x2.
    """
    grid = np.asarray(input_grid)
    if grid.ndim != 2 or grid.shape[0] < 2 or grid.shape[1] < 2:
        raise ValueError(
            f"input_grid must be a 2-D grid of at least 2x2, got shape {grid.shape}"
        )

    # Test each pixel for "non-black" directly. Summing a corner is not a
    # valid check, because mixed-sign values can cancel out to zero.
    upper_left = grid[:2, :2] != 0
    upper_right = grid[:2, -2:] != 0
    lower_left = grid[-2:, :2] != 0
    lower_right = grid[-2:, -2:] != 0

    # Stamp each corner onto its window of the 3x3 grid. Overlapping cells
    # are coloured if any of the corners covering them is coloured.
    colored = np.zeros((3, 3), dtype=bool)
    colored[:2, :2] |= upper_left
    colored[:2, 1:] |= upper_right
    colored[1:, :2] |= lower_left
    colored[1:, 1:] |= lower_right

    # Returned as float64 0.0/1.0 for compatibility with existing callers
    # (np.zeros' default dtype).
    return colored.astype(np.float64)