import numpy as np

def main(input_grid):
    """Recolor grey cells (5) to green (3) in alternating columns.

    A column at index ``i`` is affected when ``num_cols - i`` is odd, i.e.
    the last column and every second column to its left. The grid is
    modified in place and the same array object is returned.
    """
    width = input_grid.shape[1]

    # ``width - i`` is odd exactly when ``i`` has the same parity as
    # ``width - 1``, so the affected columns form one strided slice.
    first = (width - 1) % 2
    picked = input_grid[:, first::2]

    # Swap 5 -> 3 in the selected columns, leaving other values as they are,
    # and write the result back into the caller's array.
    input_grid[:, first::2] = np.where(picked == 5, 3, picked)

    return input_grid