import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import shutil
import matplotlib.gridspec as gridspec
from matplotlib import image as mpimg
import matplotlib.colors as mcolors


# Map element class names. An integer map label is used as an index into this
# list (see draw_gt), so the order of the entries matters.
map_classes = [
    'divider',
    'ped_crossing',
    'boundary',
]

# Agent class names. An integer agent label is used as an index into this list
# (see draw_gt), so the order of the entries matters.
agent_classes = [
    'car',
    'truck',
    'construction_vehicle',
    'bus',
    'trailer',
    'barrier',
    'motorcycle',
    'bicycle',
    'pedestrian',
    'traffic_cone',
]

def get_camera(scene_token, frame_idx, root_path, prepared_data, cam_key):
    """Return the image path of one camera for a given scene and frame.

    Scans ``prepared_data['infos']`` in order and uses the first record whose
    ``'scene_token'`` and ``'frame_idx'`` both equal the requested values.

    Args:
        scene_token: Value compared against each record's ``'scene_token'``.
        frame_idx: Value compared against each record's ``'frame_idx'``.
        root_path: Directory that the stored camera path is joined onto.
        prepared_data: Mapping with an ``'infos'`` iterable of records; each
            record provides ``record['cams'][cam_key]['data_path']``.
        cam_key: Key of the camera inside the record's ``'cams'`` mapping.

    Returns:
        ``os.path.join(root_path, data_path)`` for the matching record.

    Raises:
        ValueError: If no record matches both the scene token and frame index.
        KeyError: If the matching record has no entry for ``cam_key``.
    """
    for record in prepared_data['infos']:
        is_match = record['scene_token'] == scene_token and record['frame_idx'] == frame_idx
        if not is_match:
            continue
        relative_path = record['cams'][cam_key]['data_path']
        return os.path.join(root_path, relative_path)

    raise ValueError("未找到指定的 scene_token 和 frame_idx")

def visualize_l2_norm(bev_feat, save_path="bev_l2_norm.png"):
    """Save a heatmap of the L2 norm taken along axis 2 of ``bev_feat``.

    The heatmap is drawn on the current pyplot figure with the ``'hot'``
    colormap and a colorbar, written to ``save_path``, and the figure is then
    cleared (not closed).

    Args:
        bev_feat: Array with at least three axes; the norm is reduced over
            axis 2.
        save_path: Output image path. Defaults to ``"bev_l2_norm.png"`` in the
            current working directory, which was previously the only option.
    """
    heatmap = np.linalg.norm(bev_feat, axis=2)
    plt.imshow(heatmap, cmap='hot')
    plt.title("L2 Norm Heatmap")
    plt.colorbar()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Clear the shared pyplot figure so later plots start from a blank canvas.
    plt.clf()

def normalize_angle(angle):
    """Wrap an angle (radians) into the principal range of ``np.arctan2``.

    The result lies in ``[-pi, pi]``: an input of ``pi`` comes back as ``pi``
    rather than being mapped to ``-pi``. Accepts scalars or arrays.
    """
    sine = np.sin(angle)
    cosine = np.cos(angle)
    return np.arctan2(sine, cosine)


def convert_lidar_yaw_to_mathematical_yaw(lidar_yaw):
    """Return ``pi/2 - lidar_yaw`` wrapped by ``normalize_angle``.

    NOTE(review): the function name implies this maps a LiDAR-frame yaw to a
    counter-clockwise-from-+x yaw; the frame conventions themselves are not
    visible here -- confirm against the data source.
    """
    quarter_turn = np.pi / 2
    mirrored_yaw = quarter_turn - lidar_yaw
    return normalize_angle(mirrored_yaw)

# Converted value of a LiDAR yaw of 0 (i.e. pi/2 - 0, wrapped by
# normalize_angle). Not referenced anywhere else in this file.
origin_yaw = convert_lidar_yaw_to_mathematical_yaw(0)

def draw_gt(
    map_colors,agent_colors,
    map_gt_lines, map_gt_labels, agent_gt_bboxes, agent_gt_labels,
    ego_plan_result, ego_plan_gt,
    ax
    ):
    """Draw map lines, agent boxes, the ego box and both ego trajectories on ``ax``.

    Class names for legend labels are read from the module-level
    ``map_classes`` and ``agent_classes`` lists. Each class name is attached
    as a legend label to the first artist of that class only, so a legend
    built from this axes shows one entry per class.

    Args:
        map_colors: Sequence of colors indexed by map label.
        agent_colors: Sequence of colors indexed by agent label.
        map_gt_lines: Iterable of 2-D arrays; column 0 is plotted as x and
            column 1 as y.
        map_gt_labels: Iterable of integer labels, one per map line.
        agent_gt_bboxes: Iterable of boxes whose first seven entries are read
            as ``(x_center, y_center, _, x_size, y_size, _, lidar_yaw)``.
        agent_gt_labels: Iterable of integer labels, one per box.
        ego_plan_result: Predicted per-step (dx, dy) offsets; they are
            cumulatively summed starting from (0, 0).
        ego_plan_gt: Ground-truth per-step (dx, dy) offsets, handled the same
            way as ``ego_plan_result``.
        ax: Matplotlib axes to draw on.
    """
    # Query the axes once and track labels locally; asking the axes again for
    # every line/box rescans all artists and makes drawing quadratic.
    used_labels = set(ax.get_legend_handles_labels()[1])

    def _first_use_label(name):
        """Return ``name`` the first time it is seen, otherwise ``""``."""
        if name in used_labels:
            return ""
        used_labels.add(name)
        return name

    # Plot map GT lines
    for line, label in zip(map_gt_lines, map_gt_labels):
        ax.plot(
            line[:, 0], line[:, 1],
            color=map_colors[label],
            linewidth=1.0,
            label=_first_use_label(map_classes[label]),
        )

    # Plot agents' rotated bounding boxes
    for box, label in zip(agent_gt_bboxes, agent_gt_labels):
        x_center, y_center, _, x_size, y_size, _, lidar_yaw = box[:7]
        color = agent_colors[label]

        # Axis-aligned corners around the origin, before rotation.
        half_x, half_y = x_size / 2, y_size / 2
        corners = np.array([
            [-half_x, -half_y],
            [half_x, -half_y],
            [half_x, half_y],
            [-half_x, half_y]
        ])

        # The outline is rotated by (yaw - pi/2) while the heading arrow below
        # uses yaw directly, so the box's local +y axis lines up with the arrow.
        # NOTE(review): presumably y_size is the extent along the heading and
        # x_size the extent across it -- confirm against the box format.
        box_angle = lidar_yaw - np.pi / 2
        cos_a, sin_a = np.cos(box_angle), np.sin(box_angle)
        rotation_matrix = np.array([
            [cos_a, -sin_a],
            [sin_a, cos_a]
        ])

        rotated_corners = np.dot(corners, rotation_matrix.T) + np.array([x_center, y_center])

        polygon = plt.Polygon(
            rotated_corners,
            closed=True,
            fill=None,
            edgecolor=color,
            linewidth=1.5,
            label=_first_use_label(agent_classes[label]),
        )
        ax.add_patch(polygon)

        # Heading arrow of fixed length 1.5 starting at the box center.
        arrow_dx = 1.5 * np.cos(lidar_yaw)
        arrow_dy = 1.5 * np.sin(lidar_yaw)
        ax.arrow(x_center, y_center, arrow_dx, arrow_dy, head_width=0.4, head_length=0.6, fc=color, ec=color)

    # Plot ego car bounding box, centered on the origin.
    ego_width, ego_height = 2.0, 4.0
    ego_box = Rectangle((-ego_width / 2, -ego_height / 2), ego_width, ego_height, linewidth=2, edgecolor='black', facecolor='none', label="Ego Car")
    ax.add_patch(ego_box)

    # Draw ego trajectories: prepend (0, 0) and accumulate the per-step offsets
    # into absolute positions.
    pred_trajectory = np.cumsum(np.vstack(([0, 0], ego_plan_result)), axis=0)
    ax.plot(pred_trajectory[:, 0], pred_trajectory[:, 1], color='black', linestyle='-', linewidth=2, label="Pred Trajectory")

    gt_trajectory = np.cumsum(np.vstack(([0, 0], ego_plan_gt)), axis=0)
    ax.plot(gt_trajectory[:, 0], gt_trajectory[:, 1], color='red', linestyle='-', linewidth=2, label="GT Trajectory")


def bev_to_world(x_bevfeature, y_bevfeature, x_bev_feature_range, y_bev_feature_range, x_range=(-15,15), y_range=(30,-30)):
    """Linearly map BEV grid coordinates onto plot (world) coordinates.

    Index 0 maps to ``x_range[0]`` / ``y_range[0]`` and an index equal to the
    feature range maps to ``x_range[1]`` / ``y_range[1]``. With the default
    ``y_range`` the y axis is flipped: increasing row index gives decreasing y.

    Args:
        x_bevfeature: Column coordinate(s) in the BEV grid.
        y_bevfeature: Row coordinate(s) in the BEV grid.
        x_bev_feature_range: Grid extent along x (divisor for the x scaling).
        y_bev_feature_range: Grid extent along y (divisor for the y scaling).
        x_range: ``(start, stop)`` world interval covered along x.
        y_range: ``(start, stop)`` world interval covered along y.

    Returns:
        tuple: ``(x_world, y_world)``.
    """
    x_start, x_stop = x_range[0], x_range[1]
    y_start, y_stop = y_range[0], y_range[1]

    x_span = x_stop - x_start
    y_span = y_stop - y_start

    x_world = x_bevfeature * x_span / x_bev_feature_range + x_start
    y_world = y_bevfeature * y_span / y_bev_feature_range + y_start
    return x_world, y_world

def _load_camera_mosaic(scene_token, frame_idx, root_path, prepared_data, cam_keys, tiles_per_row=3):
    """Read one image per camera key and tile them row by row into one array.

    Returns:
        tuple: ``(mosaic, tile_h, tile_w)``; ``tile_h``/``tile_w`` are the
        height and width of the first image and are used to place labels.
    """
    images = [
        mpimg.imread(get_camera(scene_token, frame_idx, root_path, prepared_data, cam_key))
        for cam_key in cam_keys
    ]
    rows = [
        np.hstack(images[start:start + tiles_per_row])
        for start in range(0, len(images), tiles_per_row)
    ]
    tile_h, tile_w = images[0].shape[:2]
    return np.vstack(rows), tile_h, tile_w


def _bev_l2_norm_db(bev_embed, epsilon=1e-8):
    """Return the per-cell L2 norm of the BEV embedding in dB relative to its mean.

    The embedding is reshaped to (100, 100, 256) and the norm is taken over
    the last axis. ``epsilon`` keeps the ratio finite when a norm is zero.
    """
    if hasattr(bev_embed, 'cpu'):
        bev_embed = bev_embed.cpu().numpy()
    bev_embed = bev_embed.reshape(100, 100, 256)

    heatmap = np.linalg.norm(bev_embed, axis=2)
    reference = np.mean(heatmap)
    return 10 * np.log10((heatmap + epsilon) / (reference + epsilon))


def _top_attention_points(selected_attention_map, points_per_query, x_range, y_range):
    """Return world x/y of the strongest cells of each query's attention map.

    The map is reshaped to (16, 100, 100). For every query the
    ``points_per_query`` largest cells are taken in descending order of value
    and converted with ``bev_to_world``.

    Returns:
        tuple: ``(xs, ys)`` 1-D arrays, ordered query by query.
    """
    if hasattr(selected_attention_map, 'cpu'):
        selected_attention_map = selected_attention_map.cpu().numpy()
    attention = selected_attention_map.reshape(16, 100, 100)

    xs, ys = [], []
    for query_map in attention:
        flat = query_map.flatten()
        # np.argpartition needs kth < size, so clamp the requested count.
        count = max(0, min(points_per_query, flat.size))
        if count == 0:
            continue
        if count < flat.size:
            top_indices = np.argpartition(-flat, count)[:count]
        else:
            top_indices = np.arange(flat.size)
        top_indices = top_indices[np.argsort(-flat[top_indices])]  # strongest first

        rows, cols = np.unravel_index(top_indices, (100, 100))  # (row, col)
        x_world, y_world = bev_to_world(cols, rows, 100.0, 100.0, x_range, y_range)
        xs.append(x_world)
        ys.append(y_world)

    if not xs:
        return np.empty(0), np.empty(0)
    return np.concatenate(xs), np.concatenate(ys)


def visualize_with_corrected_yaw(data, 
                                 map_gt_lines, map_gt_labels, map_classes, agent_gt_bboxes, agent_gt_labels, agent_classes, 
                                 ego_plan_result, ego_plan_gt, 
                                 scene_token, frame_idx, root_path, prepared_data,
                                 bev_embed, selected_attention_map, each_query_attention_point_num,
                                 save_path, 
                                 x_range=(-15,15), y_range=(30,-30)):
    """
    Render one frame as a three-panel figure and save it to ``save_path``.

    Panels, left to right:
      1. The six camera images tiled 2 x 3, each tagged with its camera name.
      2. A BEV plot with map lines, agent boxes, the ego box and the
         predicted / ground-truth ego trajectories.
      3. The same BEV content drawn over an L2-norm heatmap of the BEV
         embedding (in dB relative to its mean) together with the strongest
         attention positions of each of the 16 queries.

    Args:
        data: Unused; kept for interface compatibility.
        map_gt_lines (List[np.ndarray]): [num_lines, 2] Map line instances.
        map_gt_labels (np.ndarray): Line type labels [num_lines], indices of `map_classes`.
        map_classes (list): Unused here; ``draw_gt`` reads the module-level list.
        agent_gt_bboxes (np.ndarray): Agent 3D bounding boxes tensor [num_agent, box_dim].
        agent_gt_labels (np.ndarray): Agent type labels [num_agent], indices of `agent_classes`.
        agent_classes (list): Unused here; ``draw_gt`` reads the module-level list.
        ego_plan_result: Predicted per-step ego offsets, forwarded to ``draw_gt``.
        ego_plan_gt: Ground-truth per-step ego offsets, forwarded to ``draw_gt``.
        scene_token: Scene identifier used to look up camera images.
        frame_idx: Frame index used to look up camera images.
        root_path: Directory that camera paths are resolved against.
        prepared_data: Info structure passed on to ``get_camera``.
        bev_embed: BEV embedding, reshaped to (100, 100, 256). Objects with a
            ``.cpu()`` method are converted via ``.cpu().numpy()`` first.
        selected_attention_map: Attention maps, reshaped to (16, 100, 100);
            converted the same way as ``bev_embed``.
        each_query_attention_point_num (int): Number of strongest cells to
            mark per query; clamped to the number of cells in a map.
        save_path: Output image path.
        x_range: ``(min, max)`` world x extent of the BEV panels.
        y_range: ``(top, bottom)`` world y extent of the BEV panels; the
            default puts y = 30 at the top edge.

    Raises:
        ValueError: Propagated from ``get_camera`` when the frame is not found.
    """
    map_colors = ['blue', 'green', 'red']
    agent_colors = ['cyan', 'orange', 'purple', 'yellow', 'brown', 'pink', 'lime', 'magenta', 'gray', 'red']
    cam_keys = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
            'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']

    fig = plt.figure(figsize=(10, 3), dpi=100)
    try:
        gs = fig.add_gridspec(1, 3, width_ratios=[6, 1.5, 1.5])
        # No margins and no spacing between the panels.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

        # BEV panel (middle)
        ax_bev = fig.add_subplot(gs[0, 1])
        ax_bev.axis('off')
        ax_bev.set_xlim(x_range)
        ax_bev.set_ylim(y_range[::-1])
        ax_bev.set_aspect('equal')

        # BEV attention panel (right)
        ax_attention_bev = fig.add_subplot(gs[0, 2])
        ax_attention_bev.axis('off')
        ax_attention_bev.set_xlim(x_range)
        ax_attention_bev.set_ylim(y_range[::-1])
        ax_attention_bev.set_aspect('equal')

        # ---- Camera mosaic (left) ----
        merged_image, img_h, img_w = _load_camera_mosaic(
            scene_token, frame_idx, root_path, prepared_data, cam_keys
        )
        ax_imgs = fig.add_subplot(gs[0, 0])
        ax_imgs.imshow(merged_image, aspect='auto')
        ax_imgs.axis('off')

        for i, cam_key in enumerate(cam_keys):
            row, col = divmod(i, 3)
            # Place the tag just inside the top-left corner of its tile.
            ax_imgs.text(col * img_w + 5, row * img_h + 15, cam_key.replace('CAM_', ''),
                        fontsize=8,
                        ha='left',
                        va='top',
                        color='white',
                        bbox=dict(facecolor='black', alpha=0.5, pad=1, edgecolor='none'))

        # ---- Heatmap (L2 norm in dB relative to the mean) ----
        heatmap_db = _bev_l2_norm_db(bev_embed)

        vmin = np.min(heatmap_db)
        vmax = np.max(heatmap_db)
        # TwoSlopeNorm requires vmin < vcenter < vmax; a constant heatmap gives
        # vmin == vmax == 0, so open a tiny interval around the center instead
        # of letting the constructor raise.
        min_half_span = 1e-6
        if vmin >= 0:
            vmin = -min_half_span
        if vmax <= 0:
            vmax = min_half_span
        norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

        extent = [x_range[0], x_range[1], y_range[1], y_range[0]]
        ax_attention_bev.imshow(heatmap_db, cmap='seismic', norm=norm, extent=extent, alpha=0.6, interpolation='nearest')

        # ---- Strongest attention positions of the 16 BEV queries ----
        points_x, points_y = _top_attention_points(
            selected_attention_map, each_query_attention_point_num, x_range, y_range
        )
        if points_x.size:
            # One collection for all markers instead of one per point.
            ax_attention_bev.scatter(
                points_x, points_y,
                c='white',
                s=10,
                edgecolors='black',
                alpha=0.5,  # 0.0 is fully transparent, 1.0 is opaque
                label='BEV Query Max'
            )

        # ---- GT map, agents and trajectories on both BEV panels ----
        for target_ax in (ax_bev, ax_attention_bev):
            draw_gt(
                map_colors,agent_colors,
                map_gt_lines, map_gt_labels, agent_gt_bboxes, agent_gt_labels,
                ego_plan_result, ego_plan_gt,
                target_ax
            )

        # No legend or colorbar is drawn; artists still carry legend labels.
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, including when loading or drawing fails,
        # so batch runs do not accumulate open figures.
        plt.close(fig)


def parse_args():
    """Parse the command-line options of the visualization script.

    Options:
        --data_directory: Directory containing the ``.pkl`` files to visualize.
        --save_directory: Directory that receives the rendered ``.png`` files.
        --root_path: Project root; the info pickle and camera image paths are
            resolved against it.

    Returns:
        argparse.Namespace: With ``data_directory``, ``save_directory`` and
        ``root_path`` attributes.
    """
    parser = argparse.ArgumentParser(description="Visualize and save ground truth data.")
    parser.add_argument(
        '--data_directory', 
        type=str, 
        default='/opt/nvme1/zhengxj/projects/world-model-RL-store/SSR/vis_data',
        help='The directory containing the data to visualize.'
    )
    parser.add_argument(
        '--save_directory', 
        type=str, 
        default='/opt/nvme1/zhengxj/projects/world-model-RL-store/SSR/vis/figs_gt',
        help='The directory to save the visualization data.'
    )
    parser.add_argument(
        '--root_path',
        type=str,
        default='/opt/nvme0/pengyh/projects/world-model-RL/SSR/',
        help='Project root used to locate the info pickle and the camera images.'
    )
    return parser.parse_args()

def main():
    """Render one figure per ``.pkl`` file found in the data directory.

    Each file is expected to be named ``scene_<token>+frame_<idx>.pkl``; the
    token and index are used to look up the camera images, and the figure is
    saved under the same base name with a ``.png`` suffix.

    Raises:
        ValueError: If a ``.pkl`` file name does not contain a ``+`` separator
            or its frame index is not an integer.
    """
    args = parse_args()
    data_directory = args.data_directory
    save_directory = args.save_directory
    root_path = args.root_path

    # Ensure the save directory exists
    os.makedirs(save_directory, exist_ok=True)

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # script at info/data files from a trusted source.
    infos_path = os.path.join(root_path, "data/nuscenes/vad_nuscenes_infos_temporal_val.pkl")
    with open(infos_path, 'rb') as f:
        prepared_data = pickle.load(f)

    # Sorted so the processing order is reproducible; filtering first lets the
    # progress bar count only the files that are actually rendered.
    pkl_names = sorted(name for name in os.listdir(data_directory) if name.endswith('.pkl'))

    for file_name in tqdm(pkl_names, desc="Processing files"):
        # Parse "scene_<token>+frame_<idx>" before doing any I/O so a
        # malformed name fails fast with a readable message.
        name_parts = file_name.split('.')[0].split('+')
        if len(name_parts) < 2:
            raise ValueError(
                f"Unexpected file name {file_name!r}; expected 'scene_<token>+frame_<idx>.pkl'"
            )
        scene_token = name_parts[0].replace('scene_', '')
        frame_idx = int(name_parts[1].replace('frame_', ''))

        data_path = os.path.join(data_directory, file_name)
        with open(data_path, 'rb') as f:
            data = pickle.load(f)
        # data keys:
        #     current_bev_embed
        #     world_model_pred_bev_embed
        #     current_selected_attention_map
        #     predict_selected_attention_map
        #     agent_gt_bboxes
        #     agent_gt_labels
        #     map_gt_lines
        #     map_gt_labels
        #     pred_trajectory
        #     gt_command
        #     gt_trajectory

        # Ground-truth map and agent annotations
        map_gt_lines = data['map_gt_lines']
        map_gt_labels = data['map_gt_labels']
        agent_gt_bboxes = data['agent_gt_bboxes']
        agent_gt_labels = data['agent_gt_labels']

        # Swap only the trailing ".pkl" suffix for ".png".
        save_path = os.path.join(save_directory, file_name[:-len('.pkl')] + '.png')

        # Pick the first predicted trajectory whose command flag equals 1.
        ego_gt_cmd = data['gt_command']
        ego_plan_result = data['pred_trajectory'][ego_gt_cmd == 1][0]
        ego_plan_gt = data['gt_trajectory']

        # Visualize the predicted BEV state; switch to the 'current_*' keys
        # listed above to visualize the current frame instead.
        selected_attention_map = data['predict_selected_attention_map']
        bev_embed = data['world_model_pred_bev_embed']

        visualize_with_corrected_yaw(
            data,
            map_gt_lines, map_gt_labels, map_classes, 
            agent_gt_bboxes, agent_gt_labels, agent_classes, 
            ego_plan_result, ego_plan_gt,
            scene_token, frame_idx, root_path, prepared_data,
            bev_embed, selected_attention_map, 10,
            save_path
        )
    
            

# Run the batch visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()