import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Constants
# Grid dimensions: ROWS is the y extent, COLS the x extent of the plot.
ROWS = 4
COLS = 3
# Positions are (row, col) cell coordinates.
TARGET_POSITION = (1, 2)
INITIAL_POSITION = (2, 0)
# Depot cell; coincides with the initial position in this setup.
DEPOT_POSITION = (2, 0)
# Upper bound of the height (z) axis in the 3D plot.
MAX_HEIGHT = 3

# Grid setup: per-cell block counts, every cell starts empty.
grid = np.zeros(shape=(ROWS, COLS), dtype=int)

# Define a function to plot the blocks in 3D
def plot_3d_grid(grid, robot_position, target_position, *, save_path=None,
                 show=True):
    """Draw the block grid as 3D bars and mark the robot and target cells.

    Each cell ``grid[row, col]`` holding a positive block count is drawn as a
    bar of that height centred on ``(x=col, y=row)``; empty cells are skipped.
    The robot and target are drawn as scatter markers resting on top of the
    stack in their cell (height 0 if the position lies outside the grid).

    Args:
        grid: 2D array-like of per-cell block counts, indexed ``[row, col]``.
            Its shape determines the plotted extent.
        robot_position: ``(row, col)`` cell of the robot.
        target_position: ``(row, col)`` cell of the target.
        save_path: Optional file path. When given, the figure is saved there
            *before* it is shown, so the written image is not blank.
        show: Whether to call ``plt.show()`` after drawing (default True,
            matching the previous behaviour).

    Returns:
        The matplotlib figure that was drawn.
    """
    grid = np.asarray(grid)
    rows, cols = grid.shape

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # bar3d anchors a bar at its (x, y) corner; shift by half the footprint
    # so each bar is centred on its integer cell coordinate, lining it up
    # with the axis ticks and the position markers.
    bar_size = 0.8
    half = bar_size / 2
    for row in range(rows):
        for col in range(cols):
            height = grid[row, col]
            if height > 0:
                # bar3d's sizes are the positional dx, dy, dz arguments; it
                # has no width/depth/height keywords.
                ax.bar3d(col - half, row - half, 0,
                         bar_size, bar_size, height, shade=True)

    # Robot first, then target, so the legend order is unchanged.
    markers = (
        (robot_position, 'blue', 'Robot Position'),
        (target_position, 'green', 'Target Position'),
    )
    for position, color, label in markers:
        row, col = position[0], position[1]
        # Rest the marker on top of the stack so the bar does not hide it.
        inside = 0 <= row < rows and 0 <= col < cols
        z = grid[row, col] if inside else 0
        ax.scatter(col, row, z, color=color, s=100, label=label)

    # Set labels and limits
    ax.set_xticks(np.arange(cols))
    ax.set_yticks(np.arange(rows))
    ax.set_xlabel('Columns')
    ax.set_ylabel('Rows')
    ax.set_zlabel('Height')
    ax.set_title('3D Block Grid Representation')
    # Half-cell margins keep the centred bars fully inside the axes.
    ax.set_xlim(-0.5, cols - 0.5)
    ax.set_ylim(-0.5, rows - 0.5)
    # Never clip a stack taller than the nominal MAX_HEIGHT.
    ax.set_zlim(0, max(MAX_HEIGHT, grid.max(initial=0)))

    ax.legend()
    if save_path is not None:
        # Save before showing: with interactive backends plt.show() blocks
        # and the figure is destroyed once its window is closed, so saving
        # afterwards would write an empty canvas.
        fig.savefig(save_path)
    if show:
        plt.show()
    return fig


if __name__ == "__main__":
    # Example - initial state: no blocks placed. The image is stored in the
    # local directory.
    plot_3d_grid(grid, INITIAL_POSITION, TARGET_POSITION,
                 save_path='3d_block_grid_representation.png')