import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def _make_obs(agent, goal, balls=(), grid_size=9, n_balls=4):
    """Build one scene as a ``(3 + n_balls, grid_size, grid_size)`` int array.

    Channel 0 is the wall map (a one-cell border of 1s), channel 1 the agent
    map, channel 2 the goal map, and channels 3.. hold one map per ball in the
    order given. Every position is a ``(row, col)`` tuple marked with a 1.
    """
    layers = np.zeros((3 + n_balls, grid_size, grid_size), dtype=int)
    border = layers[0]
    border[[0, -1], :] = 1
    border[:, [0, -1]] = 1
    layers[1][agent] = 1
    layers[2][goal] = 1
    for channel, position in enumerate(balls, start=3):
        layers[channel][position] = 1
    return layers


def _arrows_from_rows(rows):
    """Unzip ``(start_point, relative_end_point, color)`` rows into the
    column-list dict consumed by the renderer."""
    starts, offsets, colors = (list(column) for column in zip(*rows))
    return {"start_point": starts, "relative_end_point": offsets, "color": colors}


# Scene 1: agent at (7, 1), goal at (7, 7), balls 0-3 at (1, 1), (1, 3), (1, 5), (1, 7).
takeball_obs = _make_obs(
    agent=(7, 1),
    goal=(7, 7),
    balls=[(1, 1), (1, 3), (1, 5), (1, 7)],
)
# One colour per ball: an arrow from the agent to that ball, then from the ball to the goal.
takeball_arrows = _arrows_from_rows([
    ((7, 1), (-6, 0), "red"),
    ((1, 1), (6, 6), "red"),
    ((7, 1), (-6, 2), "blue"),
    ((1, 3), (6, 4), "blue"),
    ((7, 1), (-6, 4), "green"),
    ((1, 5), (6, 2), "green"),
    ((7, 1), (-6, 6), "purple"),
    ((1, 7), (6, 0), "purple"),
])

# Scene 2: agent at (1, 1), goal at (7, 7), no balls placed.
diagonal_obs = _make_obs(agent=(1, 1), goal=(7, 7))
diagonal_arrows = _arrows_from_rows([
    # red: along row 1 to (1, 7), then along column 7 to the goal
    ((1, 1), (0, 6), "red"),
    ((1, 7), (6, 0), "red"),
    # blue: along column 1 to (7, 1), then along row 7 to the goal
    ((1, 1), (6, 0), "blue"),
    ((7, 1), (0, 6), "blue"),
    # green: one straight diagonal from agent to goal
    ((1, 1), (6, 6), "green"),
    # purple: unit-step staircase starting at (2, 3), column step first
    ((2, 3), (0, 1), "purple"),
    ((2, 4), (1, 0), "purple"),
    ((3, 4), (0, 1), "purple"),
    ((3, 5), (1, 0), "purple"),
    ((4, 5), (0, 1), "purple"),
    ((4, 6), (1, 0), "purple"),
    # yellow: unit-step staircase starting at (3, 2), row step first
    ((3, 2), (1, 0), "yellow"),
    ((4, 2), (0, 1), "yellow"),
    ((4, 3), (1, 0), "yellow"),
    ((5, 3), (0, 1), "yellow"),
    ((5, 4), (1, 0), "yellow"),
    ((6, 4), (0, 1), "yellow"),
])

def save_first_frame_as_png(obs, filename="sketch.png", arrows=None):
    """Render a grid-world observation, with optional arrows, to a PNG file.

    Args:
        obs: Array indexed as ``obs[row, col, channel]``. Channel 0 is the
            wall map, channel 1 the agent map, channel 2 the goal map, and
            every channel from 3 onward is one ball map (ball ``k`` is read
            from channel ``3 + k``). Cells equal to 1 mark an object.
        filename: Output path passed to ``Figure.savefig``.
        arrows: Optional dict of parallel lists: ``"start_point"`` holds
            ``(row, col)`` tuples, ``"relative_end_point"`` holds
            ``(d_row, d_col)`` offsets from the start, and ``"color"`` holds
            matplotlib colour specs. One arrow is drawn per entry.

    Raises:
        ValueError: If the three ``arrows`` lists differ in length.
    """
    height, width = obs.shape[:2]
    fig, ax = plt.subplots(figsize=(width / 2, height / 2))
    # Close the figure even if drawing or saving fails, so it does not leak.
    try:
        # y is inverted so row 0 is drawn at the top; -0.51 (not -0.5) is
        # presumably there to keep the top/left border line visible.
        ax.set_xlim(-0.51, width - 0.5)
        ax.set_ylim(height - 0.5, -0.51)

        wall_layer = obs[:, :, 0]
        agent_layer = obs[:, :, 1]
        goal_layer = obs[:, :, 2]

        # Text label per cell. Object dtype so multi-digit ball ids are not
        # truncated (a dtype=str grid would be fixed at one character).
        grid = np.full((height, width), " ", dtype=object)
        grid[wall_layer == 1] = "#"
        grid[goal_layer == 1] = "G"
        grid[agent_layer == 1] = "S"  # written after "G", so it wins on overlap

        # Label and highlight each ball; only its first marked cell is used.
        for ball_id in range(obs.shape[2] - 3):
            ball_pos = np.argwhere(obs[:, :, 3 + ball_id] == 1)
            if len(ball_pos) == 0:
                continue
            row, col = ball_pos[0]
            grid[row, col] = str(ball_id)
            ax.add_patch(
                patches.Rectangle(
                    (col - 0.5, row - 0.5), 1, 1,
                    facecolor="green", edgecolor="black", alpha=0.3,
                )
            )

        if arrows is not None:
            starts = arrows["start_point"]
            offsets = arrows["relative_end_point"]
            colors = arrows["color"]
            if not len(starts) == len(offsets) == len(colors):
                raise ValueError(
                    "arrows lists must have equal lengths, got "
                    f"{len(starts)} start points, {len(offsets)} end points "
                    f"and {len(colors)} colors"
                )
            for (row, col), (d_row, d_col), color in zip(starts, offsets, colors):
                # ax.arrow takes (x, y, dx, dy), i.e. column before row.
                ax.arrow(
                    col, row, d_col, d_row,
                    head_width=0.3,
                    head_length=0.3,
                    fc=color,
                    ec=color,
                    length_includes_head=True,
                )

        ax.imshow(wall_layer, cmap="gray", alpha=0.3)
        for (row, col), label in np.ndenumerate(grid):
            ax.text(col, row, label, ha="center", va="center", fontsize=12)

        # Ticks at interior cell boundaries so ax.grid draws the cell borders;
        # the tick labels themselves are hidden.
        ax.set_xticks([c - 0.5 for c in range(1, width)])
        ax.set_yticks([r - 0.5 for r in range(1, height)])
        ax.grid(color="black", linestyle="-", alpha=0.5)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        fig.savefig(filename, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved the PNG to {filename}")

if __name__ == "__main__":
    # The scene constants are laid out channel-first (channel, row, col);
    # save_first_frame_as_png indexes obs[row, col, channel], hence the transpose.
    obs = diagonal_obs.transpose(1, 2, 0)
    save_first_frame_as_png(obs, "diagonal.png", arrows=diagonal_arrows)