import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from functions.env import *
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial import ConvexHull
 

def plot_map(
        map: np.ndarray,
        #
        reward: np.ndarray = None,
        trajectory: np.ndarray = None,
        visit_distribution: np.ndarray = None,
        #
        title: str = None,
        save_fig: bool = False,
        namefig: str = '',
        enlarge: bool = False,
):
    """
    Plot the map of the road/environment, and optionally a reward function, a
    trajectory, or a visit distribution. Maximum n_objs=11.

    Parameters
    ----------
    map : np.ndarray
        2D array of shape (lane_length, num_lanes). Entry [h, s] is the id of
        the object at stage h in lane s: 0 draws nothing, ids 1..10 are drawn
        with a dedicated marker and color.
    reward : np.ndarray, optional
        Indexed by object id; the value reward[map[h, s]] is printed in every
        cell (so the map must hold integer ids when this is given).
    trajectory : np.ndarray, optional
        (H x S) array. For each stage, the first lane with a non-zero entry is
        marked with a yellow '+'. Stages without a non-zero entry get no mark.
    visit_distribution : np.ndarray, optional
        (H x S) array whose non-zero entries are printed in their cell.
    title : str, optional
        Text drawn above the road.
    save_fig : bool
        If True, the figure is written to 'images/<namefig>.pdf'.
    namefig : str
        File name (without extension) used when save_fig is True.
    enlarge : bool
        If True, the stage labels and the title use larger vertical offsets.

    Raises
    ------
    ValueError
        If the map contains an id outside the range 0..10.
    """
    # explicit check instead of `assert`, which disappears under `python -O`;
    # negative ids are rejected too, since they have no marker/color either
    if np.max(map) > 10 or np.min(map) < 0:
        raise ValueError("map entries must be object ids in the range 0..10")

    # Number of lanes (cols, S) and lane length (rows, H)
    lane_length, num_lanes = map.shape

    # road extent in data coordinates (cells are centered on integer coords);
    # the y-axis is inverted below, so the smallest y ends up at the top
    x_min, x_max = -0.5, lane_length - 0.5
    y_min, y_max = -0.5, num_lanes - 0.5

    # Create the figure and axis
    fig, ax = plt.subplots(figsize=(lane_length+5, num_lanes))

    # Add a rectangle to represent the road background
    road_background = patches.Rectangle(
        (x_min, y_min), lane_length, num_lanes,
        linewidth=0, edgecolor=None, facecolor='gray',
    )
    ax.add_patch(road_background)

    # Draw lane separators (dashed white lines between adjacent lanes)
    for lane in range(1, num_lanes):
        ax.plot([x_min, x_max], [lane - 0.5, lane - 0.5],
                color='white', linestyle='--', linewidth=4)

    # Plot the road boundary (solid black lines for the two edges)
    for y_edge in (y_min, y_max):
        ax.plot([x_min, x_max], [y_edge, y_edge], color='black', linewidth=4)

    # Define markers and colors for the objects
    marker_dict = {
        1: 'o',  # Ball
        2: '^',  # Triangle
        3: 's',  # Square
        4: '1',
        5: '2',
        6: '3',
        7: '4',
        8: '8',
        9: '*',
        10: 'X',
    }
    color_dict = {
        1: 'red',    # Ball
        2: 'green',  # Triangle
        3: 'blue',   # Square
        4: 'c',
        5: 'm',
        6: 'y',
        7: 'k',
        8: 'y',
        9: mcolors.TABLEAU_COLORS['tab:pink'],
        10: mcolors.TABLEAU_COLORS['tab:brown'],
    }

    # Plot the object of every non-empty cell
    for pos, lane in np.argwhere(map != 0):
        obj = map[pos, lane]
        ax.scatter(pos, lane,
                   marker=marker_dict[obj],
                   s=400,  # Size of the marker
                   edgecolor='black',
                   facecolor=color_dict[obj])  # Colored marker

    # Set the limits and labels
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xticks(np.arange(lane_length))
    ax.set_yticks(np.arange(num_lanes))
    ax.grid(False)

    # Add labels to the left side of the lanes if only 3 lanes
    if num_lanes == 3:
        for lane, lane_name in enumerate(("L", "C", "R")):
            ax.text(-0.65, lane, lane_name, va='center', ha='right', fontsize=14, color='black')

    # Add stage to the top
    stage_y = -1.1 if enlarge else -0.8
    for h in range(lane_length):
        ax.text(h+0.2, stage_y, "h="+str(h+1), va='center', ha='right', fontsize=14, color='black')

    # plot trajectory (HxS)
    if trajectory is not None:
        for h in range(lane_length):
            visited = np.flatnonzero(trajectory[h, :num_lanes])
            if visited.size == 0:
                # no lane is marked at this stage: draw nothing rather than
                # placing a marker at the non-existent lane -1
                continue
            ax.scatter(h, visited[0], marker='+', s=500, color='yellow', linewidth=3)

    # plot visit distribution state-only (HxS)
    if visit_distribution is not None:
        for h, s in np.ndindex(lane_length, num_lanes):
            if visit_distribution[h, s] != 0:
                ax.text(h, s, f"{visit_distribution[h, s]:.2f}", fontsize=12, ha='right', color='yellow')

    # plot reward function (reward of the object found in each cell)
    if reward is not None:
        for h, s in np.ndindex(lane_length, num_lanes):
            ax.text(h, s, f"{reward[map[h, s]]:.2f}", fontsize=12, ha='right', color='yellow')

    if title is not None:
        title_y = -0.9 if enlarge else -0.75
        ax.text(2.7, title_y, title, va='center', ha='right', fontsize=14, color='black')

    ax.invert_yaxis()  # Invert y-axis to have Lane 1 at the top
    ax.axis('off')  # Hide the axes

    # save before showing, so that writing the file does not depend on the
    # state of the figure after the (possibly blocking) show call
    if save_fig:
        fig.savefig('images/'+namefig+".pdf", format="pdf", dpi=1200, bbox_inches='tight')

    # Display the plot
    plt.show()
     
def plot_feasible_set(
    feasible_rewards: np.ndarray = None,
    r_star: list = None, 
    r_M: list = None, 
    r_m: list = None, 
    rr: list = None,
    save_fig: bool = False,
    namefig: str = '',
    reverse_1_2: bool = True,
):   
    """
    Plot the feasible set and, optionally, other reward functions.

    Parameters
    ----------
    feasible_rewards : np.ndarray, optional
        Array of shape (n, 3), one 3D reward per row. With at least 4 rows
        the convex hull of the points is drawn; with fewer rows the points
        themselves are scattered. If None, no feasible set is drawn.
    r_star, r_M, r_m, rr : sequence, optional
        Reference rewards with at least 4 entries each (lists or arrays).
        Entries 1, 2 and 3 are the plotted coordinates; entry 0 is not
        plotted.
    save_fig : bool
        If True, a cropped version of the figure is written to
        'images/<namefig>.pdf'.
    namefig : str
        File name (without extension) used when save_fig is True.
    reverse_1_2 : bool
        If True, the last two plotted coordinates are swapped (for the
        feasible set and for every reference reward) and the y/z axes are
        labelled 'S'/'T' instead of 'T'/'S'.

    Raises
    ------
    ValueError
        If feasible_rewards is given but is not a 2D array with 3 columns.
    """     
    # validate/normalize the input before creating any figure
    if feasible_rewards is not None:
        feasible_rewards = np.asarray(feasible_rewards)
        if feasible_rewards.ndim != 2 or feasible_rewards.shape[1] != 3:
            raise ValueError(
                "feasible_rewards must be a 2D array with 3 columns, got shape "
                + str(feasible_rewards.shape)
            )
        if reverse_1_2:
            feasible_rewards = feasible_rewards[:, [0, 2, 1]]

    # plot 
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    if feasible_rewards is not None:
        if len(feasible_rewards) >= 4: 
            # compute the convex hull of the feasible rewards found
            hull = ConvexHull(feasible_rewards)

            # plot the convex hull as a Poly3DCollection (one per facet)
            for simplex in hull.simplices:
                triangle = feasible_rewards[simplex]
                poly = Poly3DCollection([triangle], color='yellow', alpha=0.5, edgecolor='k')
                ax.add_collection3d(poly) 
        else: 
            # too few points for a 3D hull: show the points themselves
            ax.scatter( 
                feasible_rewards[:, 0], 
                feasible_rewards[:, 1],
                feasible_rewards[:, 2], 
                color='yellow',  
                s=1 
            ) 

    # entries of a reference reward used as (x, y, z)
    components = [1, 3, 2] if reverse_1_2 else [1, 2, 3]

    # reference rewards as (values, color, legend label); raw strings keep
    # the mathtext backslashes without relying on invalid escape sequences
    references = (
        (r_star, 'red', r'$r^\star$'),
        (r_M, 'green', '$r_{M}$'),
        (r_m, 'blue', '$r_{m}$'),
        (rr, 'black', r'$\widehat{r}_{\mathcal{F},g}$'),
    )
    for values, color, lab in references:
        if values is None:
            continue
        # np.asarray lets plain lists be indexed with a list of positions
        x, y, z = np.asarray(values)[components]
        ax.scatter(x, y, z, color=color, label=lab, s=100)

    # label axes
    ax.set_xlabel('B')
    if reverse_1_2:
        ax.set_ylabel('S')
        ax.set_zlabel('T')
    else:
        ax.set_ylabel('T')
        ax.set_zlabel('S') 

    # legend (only when at least one labelled reward was plotted)
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(
            loc='upper left',          # Position of the legend
            bbox_to_anchor=(0.1, 0.8), # Precise position (x, y) relative to the axes
            fontsize=10,               # Font size of the legend text
            frameon=True,              # Draw a box around the legend
            edgecolor='black',         # Color of the box
            framealpha=0.8,            # Transparency of the box
            borderpad=0.5,             # Padding inside the legend box
            labelspacing=0.4,          # Spacing between legend entries
            handlelength=1.5,          # Length of the legend handles
            handletextpad=0.4,         # Space between the handles and the text
        )

    # set custom viewpoint
    ax.view_init(elev=20, azim=220)

    if save_fig:
        # Custom bounding box adjustment
        from matplotlib.transforms import Bbox

        # Region of the figure to save, in inches: [[left, bottom], [right, top]]
        custom_bbox = Bbox([[0.95, 0.6], [5.2, 3.7]])  # Adjust values to crop the figure
        fig.savefig('images/'+namefig+".pdf", format="pdf", dpi=1200, bbox_inches=custom_bbox)

def plot_reward_seq(
        r_seq,
        save_fig: bool = False,
        namefig: str = '',
        reverse_1_2: bool = True,
        step: int = 50,
        ): 
    """
    Plot a sequence of reward functions in the 3D space.

    Parameters
    ----------
    r_seq : sequence
        Indexable sequence of rewards, each with at least 4 entries. Entries
        1, 2 and 3 are the plotted coordinates; entry 0 is not plotted.
    save_fig : bool
        If True, the figure is written to 'images/<namefig>.pdf'.
    namefig : str
        File name (without extension) used when save_fig is True.
    reverse_1_2 : bool
        If True, the last two plotted coordinates are swapped and the y/z
        axes are labelled 'S'/'T' instead of 'T'/'S'.
    step : int
        Subsampling stride: only the rewards at positions 0, step, 2*step,
        ... are plotted. Defaults to 50.

    Raises
    ------
    ValueError
        If step is smaller than 1.
    """
    # reject bad strides up front: 0 is invalid for range() and a negative
    # stride would silently plot nothing
    if step < 1:
        raise ValueError("step must be a positive integer, got " + str(step))

    # keep one reward every `step` iterations
    sampled = [r_seq[i] for i in range(0, len(r_seq), step)]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')  

    if sampled:
        # entries of each reward used as (x, y, z)
        components = [1, 3, 2] if reverse_1_2 else [1, 2, 3]
        points = np.asarray(sampled)[:, components]
        # a single scatter call for all points; depth shading is disabled so
        # that every point keeps the same solid color it has when the points
        # are drawn one by one
        ax.scatter(
            points[:, 0],
            points[:, 1],
            points[:, 2],
            color='blue',
            s=50,
            depthshade=False,
        )
        
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.set_zlim([0,1]) 
    
    # label axes
    ax.set_xlabel('B')
    if reverse_1_2:
        ax.set_ylabel('S')
        ax.set_zlabel('T')
    else:
        ax.set_ylabel('T')
        ax.set_zlabel('S') 

    # set custom viewpoint
    ax.view_init(elev=20, azim=220)

    if save_fig:
        fig.savefig('images/'+namefig+".pdf", format="pdf", dpi=1200, bbox_inches='tight')

def plot_seq_delta_J(
    seq: list,
    title: str = '',
    save_fig: bool = False,
    namefig: str = '',
):
    """
    Draw the curve of the delta J values against the iteration index.

    The curve is drawn through pyplot, with axis labels and a grid. A title
    is added when `title` differs from the empty string. When `save_fig` is
    True the plot is written to 'images/<namefig>.pdf' before being shown.
    """
    label_size = 20
    tick_size = 16

    # curve and background grid
    plt.plot(seq)
    plt.grid(True)

    # axis labels
    axis_labels = (
        (plt.xlabel, 'Iteration $k$'),
        (plt.ylabel, r'$\Delta J(\cdot)$'),
    )
    for set_label, text in axis_labels:
        set_label(text, fontsize=label_size)

    # tick labels
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=tick_size)

    # optional title
    if title != '':
        plt.title(title)

    # write the pdf first, then display
    if save_fig:
        target = 'images/' + namefig + ".pdf"
        plt.savefig(target, format="pdf", dpi=1200, bbox_inches='tight')

    plt.show()