import os
import datetime
import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from agentdriver.utils.geometry import location_to_pixel_coordinate, pixel_coordinate_to_location
import math

def rotate_bbox(x, y, dx, dy, theta):
    # Step 1: Calculate the center of the box
    cx, cy = x, y

    # Step 2: Calculate the coordinates of the corners relative to the center
    corners = [(dx / 2, dy / 2), (dx / 2, -dy / 2), (-dx / 2, -dy / 2), (-dx / 2, dy / 2)]

    # Step 3: Rotate each corner
    rotated_corners = []
    for px, py in corners:
        x_prime = px * math.cos(theta) + py * math.sin(theta)
        y_prime = - px * math.sin(theta) + py * math.cos(theta)
        rotated_corners.append((x_prime, y_prime))            

    # Step 4 Translate the rotated corners back
    final_corners = [(cx + x_prime, cy + y_prime) for x_prime, y_prime in rotated_corners]

    # cos_rot = np.cos(theta)
    # sin_rot = np.sin(theta)
    # half_dx = dx / 2
    # half_dy = dy / 2
    # corners = [
    #     [x + half_dx * cos_rot + half_dy * sin_rot,
    #         y - half_dx * sin_rot + half_dy * cos_rot],
    #     [x + half_dx * cos_rot - half_dy * sin_rot,
    #         y - half_dx * sin_rot - half_dy * cos_rot],
    #     [x - half_dx * cos_rot - half_dy * sin_rot,
    #         y + half_dx * sin_rot - half_dy * cos_rot],
    #     [x - half_dx * cos_rot + half_dy * sin_rot,
    #         y + half_dx * sin_rot + half_dy * cos_rot],
    # ]

    return final_corners

def plot_track_traj(sample, root_path='', dpi=300, highlight_index=None, mode="location", check_function=True, save=True, show=False):
    fig, ax = plt.subplots(figsize=(8, 8), dpi=dpi)

    # Create secondary axes for overlaying location-based or pixel-based coordinates
    ax2 = ax.twinx()
    ax3 = ax.twiny()

    # assert 'objects' key in sample
    assert 'objects' in sample.keys(), "The input sample does not have 'objects' key."

    # assert 'name' keys in each object
    assert all(['name' in obj.keys() for obj in sample['objects']]
               ), "The input sample does not have 'name' key in each object."

    # assert 'bbox' keys in each object
    assert all(['bbox' in obj.keys() for obj in sample['objects']]
               ), "The input sample does not have 'bbox' key in each object."

    # assert 'traj' keys in each object
    assert all(['traj' in obj.keys() for obj in sample['objects']]
               ), "The input sample does not have 'traj' key in each object."

    for i, detection in enumerate(sample['objects']):
        name = detection['name']
        box = detection['bbox']
        traj = detection['traj']

        # Get the box parameters
        x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = box

        # # Get the box corners
        # cos_rot = np.cos(rotation_z)
        # sin_rot = np.sin(rotation_z)
        # half_dx = dx / 2
        # half_dy = dy / 2
        # corners = [
        #     [x + half_dx * cos_rot + half_dy * sin_rot,
        #         y - half_dx * sin_rot + half_dy * cos_rot],
        #     [x + half_dx * cos_rot - half_dy * sin_rot,
        #         y - half_dx * sin_rot - half_dy * cos_rot],
        #     [x - half_dx * cos_rot - half_dy * sin_rot,
        #         y + half_dx * sin_rot - half_dy * cos_rot],
        #     [x - half_dx * cos_rot + half_dy * sin_rot,
        #         y + half_dx * sin_rot + half_dy * cos_rot],
        # ]

        corners = rotate_bbox(x, y, dx, dy, rotation_z)

        # Convert coordinates to pixel if mode is "pixel"
        if mode == "pixel":
            x, y, _ = location_to_pixel_coordinate(x, y)
            traj = [location_to_pixel_coordinate(point[0], point[1])[
                :2] for point in traj]
            corners_pixel = []
            for corner in corners:
                x_pixel, y_pixel, valid = location_to_pixel_coordinate(
                    corner[0], corner[1])
                if not valid:
                    break
                corners_pixel.append([x_pixel, y_pixel])

            if len(corners_pixel) == 4:  # All 4 corners are valid
                corners_pixel = np.array(corners_pixel)
                ax.plot(corners_pixel[:, 0], corners_pixel[:, 1],
                        color="green" if i == highlight_index else "red")
                rect = plt.Polygon(
                    corners_pixel, fill=False, color="green" if i == highlight_index else "red")
                ax.add_patch(rect)

        elif mode == "location":
            # double check if the pixel_coordinate_to_location function and location_to_pixel_coordinate function are correct
            if check_function:
                x, y, _ = location_to_pixel_coordinate(x, y)
                x, y, _ = pixel_coordinate_to_location(x, y)

            # Directly use the corners for plotting
            if len(corners) == 4:
                corners = np.array(corners)
                ax.plot(corners[:, 0], corners[:, 1],
                        color="green" if i == highlight_index else "red")
                rect = plt.Polygon(
                    corners, fill=False, color="green" if i == highlight_index else "red")
                ax.add_patch(rect)

        # Plot trajectory
        traj_points = np.array(traj)
        ax.plot(traj_points[:, 0], traj_points[:, 1],
                "-ob", markersize=4, linewidth=2)

        # Label the detection
        ax.text(x, y, name, color="green")

    # Set limits for primary and secondary axes based on the mode
    if mode == "location":
        ax.set_xlim([-50, 50])
        ax.set_ylim([-50, 50])
        ax3.set_xlim([0, 200])
        ax2.set_ylim([0, 200])
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        ax3.set_xlabel("X in pixel (second '200' entry in matrix)")
        ax2.set_ylabel("Y in pixel (first '200' entry in matrix)")
    elif mode == "pixel":  # for pixel mode
        ax.set_xlim([0, 200])
        ax.set_ylim([0, 200])
        ax3.set_xlim([-50, 50])
        ax2.set_ylim([-50, 50])
        ax.set_xlabel("X in pixel (second '200' entry in matrix)")
        ax.set_ylabel("Y in pixel (first '200' entry in matrix)")
        ax3.set_xlabel('X Coordinate')
        ax2.set_ylabel('Y Coordinate')

    ax.grid(True)
    ax.set_aspect('auto')
    ax2.set_aspect('auto')
    # plt.title('BEV visualization of tracking and trajectory predictions')
    plt.tight_layout()
    if show:
        plt.show()

    if save and root_path != '':
        save_path = os.path.join(root_path, 'BEV_bbox_traj.png')
        plt.savefig(save_path, dpi=dpi)
        print("The BEV visualization of tracking and trajectory predictions is saved at {}".format(
            save_path))


def plot_track_traj_in_pixel(sample, ax, highlight_index=None, plot_traj=True, show_name=True, markersize=4, linewidth=2, bbox_color="red", traj_color="blue", text_color="green"):
    for i, detection in enumerate(sample['objects']):
        name = detection['name']
        box = detection['bbox']
        traj = detection['traj']

        # Get the box parameters
        x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = box

        # # Get the box corners
        # cos_rot = np.cos(rotation_z)
        # sin_rot = np.sin(rotation_z)
        # half_dx = dx / 2
        # half_dy = dy / 2
        # corners = [
        #     [x + half_dx * cos_rot + half_dy * sin_rot,
        #         y - half_dx * sin_rot + half_dy * cos_rot],
        #     [x + half_dx * cos_rot - half_dy * sin_rot,
        #         y - half_dx * sin_rot - half_dy * cos_rot],
        #     [x - half_dx * cos_rot - half_dy * sin_rot,
        #         y + half_dx * sin_rot - half_dy * cos_rot],
        #     [x - half_dx * cos_rot + half_dy * sin_rot,
        #         y + half_dx * sin_rot + half_dy * cos_rot],
        # ]

        corners = rotate_bbox(x, y, dx, dy, rotation_z)

        # Convert coordinates to pixel if mode is "pixel"
        x, y, _ = location_to_pixel_coordinate(x, y)
        traj = [location_to_pixel_coordinate(point[0], point[1])[
            :2] for point in traj]
        corners_pixel = []
        for corner in corners:
            x_pixel, y_pixel, valid = location_to_pixel_coordinate(
                corner[0], corner[1])
            if not valid:
                break
            corners_pixel.append([x_pixel, y_pixel])

        if len(corners_pixel) == 4:  # All 4 corners are valid
            corners_pixel = np.array(corners_pixel)
            ax.plot(corners_pixel[:, 0], corners_pixel[:, 1],
                    color="green" if i == highlight_index else bbox_color)
            rect = plt.Polygon(corners_pixel, fill=False,
                               color="green" if i == highlight_index else bbox_color)
            ax.add_patch(rect)

        # Plot trajectory
        if plot_traj:
            traj_points = np.array(traj)
            ax.plot(traj_points[:, 0], traj_points[:, 1], "-o",
                    markersize=markersize, linewidth=linewidth, color=traj_color)

        # Label the detection
        if show_name:
            ax.text(x, y, name, color=text_color)
            
def plot_occ(sample, root_path='', dpi=300, highlight_index=None, mode="pixel", show_track_bbox=True, save=True, show=False):

    # assert 'occupancy' key in sample
    assert 'occupancy' in sample.keys(), "The input sample does not have 'occupancy' key."

    # assert sample['occupancy'] has shape (5, 200, 200)
    assert sample['occupancy'].shape == (
        5, 200, 200), "The input sample does not have 'occupancy' key with shape (5, 200, 200)."

    # Plot Occupancy
    occ_data = sample['occupancy']
    num_timesteps = occ_data.shape[0]
    timesteps = ["Current"] + \
        [f"+{0.5*(i+1)}s" for i in range(num_timesteps - 1)]
    fig, axs = plt.subplots(1, num_timesteps, figsize=(20, 4), dpi=dpi)

    transformed_occ_data = np.zeros_like(occ_data)

    if mode == "location":
        for y in range(occ_data.shape[1]):
            for x in range(occ_data.shape[2]):
                loc_x, loc_y, _ = pixel_coordinate_to_location(x, y)
                transformed_occ_data[:, int(loc_y), int(
                    loc_x)] = occ_data[:, y, x]

    for i, timestep in enumerate(timesteps):
        if i == 0 and show_track_bbox:
            plot_track_traj_in_pixel(
                sample, axs[i], highlight_index, plot_traj=False, show_name=False, bbox_color='white', traj_color='white')
        if mode == "location":
            im = axs[i].imshow(transformed_occ_data[0, i],
                               cmap="plasma", vmin=0, vmax=1)
            axs[i].set_xlim(-50, 50)
            axs[i].set_ylim(-50, 50)
            axs[i].set_xticks(np.arange(-50, 50, 10))
            axs[i].set_yticks(np.arange(-50, 50, 10))
            axs[i].set_xlabel("X in location (meters)")
            axs[i].set_ylabel("Y in location (meters)")
        elif mode == "pixel":
            im = axs[i].imshow(occ_data[i], cmap="plasma", vmin=0, vmax=1)
            axs[i].set_xlim(0, 200)
            axs[i].set_ylim(0, 200)
            axs[i].set_xticks(np.arange(0, 201, 50))
            axs[i].set_yticks(np.arange(0, 201, 50))
            axs[i].set_xlabel("X in pixel (second '200' entry in matrix)") if i == 0 else axs[i].set_xlabel("X in pixel")
            axs[i].set_ylabel("Y in pixel (first '200' entry in matrix)") if i == 0 else axs[i].set_ylabel("Y in pixel")

        axs[i].set_title(timestep)
        axs[i].grid(which="both", linestyle="--", linewidth=0.5, color="white")

    cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='vertical')
    cbar.set_label('Probability', rotation=270, labelpad=15)
    # plt.suptitle("Occupancy Flow in BEV Space", y=1.1)
    plt.subplots_adjust(left=0.05, right=0.93, wspace=0.3)
    if show:
        plt.show()

    if save and root_path != '':
        save_path = os.path.join(root_path, 'BEV_occ.png')
        plt.savefig(save_path, dpi=dpi)
        print("The BEV visualization of occupancy flow is saved at {}".format(save_path))

def plot_map_seg(sample, root_path='', dpi=300, highlight_index=None, mode="pixel", save=True, show=False):
    # assert 'map' key in sample
    assert 'map' in sample.keys(), "The input sample does not have 'map' key."

    # assert 'lane' key in sample['map']
    assert 'lane' in sample['map'].keys(
    ), "The input sample does not have 'lane' key in 'map'."

    # assert sample['map']['lane'] has shape (3, 200, 200)
    assert sample['map']['lane'].shape == (
        3, 200, 200), "The input sample does not have 'lane' key with shape (3, 200, 200)."

    fig, ax = plt.subplots(figsize=(8, 8), dpi=dpi)
    # Plot Lanes Combined
    lane_data = sample['map']['lane']
    lane_categories = ["Divider", "Pedestrian Crossing", "Boundary"]
    colors = plt.cm.Dark2(np.linspace(0, 1, len(lane_categories)))
    combined_image = np.ones((200, 200, 3))
    for i, category in enumerate(lane_categories):
        mask = lane_data[i]
        for c in range(3):
            combined_image[..., c] = np.where(
                mask, colors[i][c], combined_image[..., c])
    ax.imshow(combined_image)

    plot_track_traj_in_pixel(sample, ax, highlight_index, text_color='brown')

    # ax.set_title("Combined Lane Visualization with Objects in BEV")
    ax.set_xlim([0, 200])
    ax.set_ylim([0, 200])
    ax.set_xticks(np.arange(0, 201, 50))
    ax.set_yticks(np.arange(0, 201, 50))
    ax.set_xlabel("X in pixel (second '200' entry in matrix)")
    ax.set_ylabel("Y in pixel (first '200' entry in matrix)")
    ax.grid(which="both", linestyle="--", linewidth=0.5)
    patches_list = [plt.Line2D([0], [0], marker='o', color='w',
                               markerfacecolor=color, markersize=10) for color in colors]
    ax.legend(handles=patches_list, labels=lane_categories, loc='upper right')
    plt.tight_layout()
    if show:
        plt.show()

    if save and root_path != '':
        save_path = os.path.join(root_path, 'BEV_lane.png')
        plt.savefig(save_path, dpi=dpi)
        print("The BEV visualization of lane is saved at {}".format(save_path))
        
def plot_drivable_area(sample, root_path='', dpi=300, highlight_index=None, mode="pixel", save=True, show=False):
    # assert 'map' key in sample
    assert 'map' in sample.keys(), "The input sample does not have 'map' key."

    # assert 'drivable' key in sample['map']
    assert 'drivable' in sample['map'].keys(
    ), "The input sample does not have 'drivable' key in 'map'."

    # assert sample['map']['drivable'] has shape (200, 200)
    assert sample['map']['drivable'].shape == (
        200, 200), "The input sample does not have 'drivable' key with shape (200, 200)."

    # Plot Drivable Map
    drivable_data = sample['map']['drivable']
    fig, ax = plt.subplots(figsize=(8, 8), dpi=dpi)
    cmap = ListedColormap(['white', 'orange'])
    ax.imshow(drivable_data, cmap=cmap)

    # Visualize bounding boxes and trajectories on top of drivable map
    plot_track_traj_in_pixel(sample, ax, highlight_index)

    ax.set_xlim(0, 200)
    ax.set_ylim(0, 200)
    ax.set_xticks(np.arange(0, 201, 50))
    ax.set_yticks(np.arange(0, 201, 50))
    ax.set_xlabel("X in pixel (column in matrix)")
    ax.set_ylabel("Y in pixel (row in matrix)")
    # ax.set_title("Drivable Area with Objects in BEV")
    ax.grid(which="both", linestyle="--", linewidth=0.5)
    plt.tight_layout()
    non_drivable_patch = mpatches.Patch(color='white', label='Non-Drivable')
    drivable_patch = mpatches.Patch(color='orange', label='Drivable')
    
    # Add the legend
    ax.legend(handles=[non_drivable_patch, drivable_patch], loc='upper right')
    if show:
        plt.show()

    if save and root_path != '':
        save_path = os.path.join(root_path, 'BEV_drivable.png')
        plt.savefig(save_path, dpi=dpi)
        print("The BEV visualization of drivable area is saved at {}".format(save_path))
        
def plot_all(sample, root_path='', save=True, show=False, dpi=300, highlight_index=None):
    # get the root path
    if '__file__' in globals():
        # This will be used in a .py script
        root_path = os.path.dirname(os.path.realpath(__file__))
    else:
        # This will be used in a Jupyter notebook
        root_path = os.path.abspath('')

    # create the visualization folder with datetime
    if save:
        root_path = os.path.join(
            root_path, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
        os.makedirs(root_path, exist_ok=True)
        print("The visualization folder is created at {}".format(root_path))

    plot_track_traj(sample, dpi=dpi, highlight_index=highlight_index,
                    mode="location", show=show, save=save, root_path=root_path)
    plot_occ(sample, dpi=dpi, highlight_index=highlight_index, mode="pixel",
             show_track_bbox=True, show=show, save=save,  root_path=root_path)
    plot_map_seg(sample, dpi=dpi, highlight_index=highlight_index,
                 show=show, save=save,  root_path=root_path)
    plot_drivable_area(sample, dpi=dpi, highlight_index=highlight_index,
                       show=show, save=save,  root_path=root_path)
    
if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample_path", help="Path to the sample file to be visualized.", default="data/train/0a0d1f7700da446580874d7d1e9fce51.pkl")
    parser.add_argument("--save", action="store_true", default=True,
                        help="Whether to save the visualization.")
    parser.add_argument("--show", action="store_true", default=False, help="Whether to show the visualization.")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI of the visualization.")
    parser.add_argument("--highlight_index", type=int, default=None, help="The index of the object to be highlighted.")
    args = parser.parse_args()
    
    assert args.sample_path is not None, "Please provide the path to the sample file to be visualized."
    
    #assert sample_path is a pickle file
    assert args.sample_path.endswith('.pkl'), "Please provide a pickle file."
    
    with open(args.sample_path, 'rb') as f:
        sample = pickle.load(f)
        
    plot_all(sample, save=args.save, show=args.show, dpi=args.dpi, highlight_index=args.highlight_index)