import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import shutil
import matplotlib.gridspec as gridspec
from matplotlib import image as mpimg
import torch
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors

# Map-element class names; an integer map label is used as an index into this list.
map_classes = [
    'divider',
    'ped_crossing',
    'boundary',
]

# Agent (object) class names; an integer agent label is used as an index into this list.
agent_classes = [
    'car',
    'truck',
    'construction_vehicle',
    'bus',
    'trailer',
    'barrier',
    'motorcycle',
    'bicycle',
    'pedestrian',
    'traffic_cone',
]

# def map_agent_collection(data_directory, scene_token, frame_idx, cur_frame_data, n):
    
#     cur_frame_map_gt_lines = cur_frame_data['map_gt_lines']
#     cur_frame_map_gt_labels = cur_frame_data['map_gt_labels']
#     cur_frame_agent_gt_bboxes = cur_frame_data['agent_gt_bboxes']
#     cur_frame_agent_gt_labels = cur_frame_data['agent_gt_labels']
    
#     # 先将当前帧所有线的点统一成集合，方便查找
#     cur_points_set = set()
#     for line in cur_frame_map_gt_lines:
#         points = line.numpy() if isinstance(line, torch.Tensor) else line
#         for pt in points:
#             cur_points_set.add(tuple(np.round(pt, decimals=2)))  # 用tuple，方便集合查找
    
#     cur_frame_gt_traj = cur_frame_data['gt_trajectory']
#     cur_frame_gt_traj.append([0, 0])
#     # scene_0ac05652a4c44374998be876ba5cd6fd+frame_4.pkl file name example
#     for i in range(n):
#         target_frame = i + 1
#         file_name = "scene_" + scene_token + "+" +"frame_" + str(frame_idx + target_frame) + ".pkl"
#         file_path = os.path.join(data_directory, file_name)
#         if os.path.exists(file_path):
#             # Load the pickle file
#             with open(file_path, 'rb') as f:
#                 next_frame_data = pickle.load(f)
                
#             next_frame_map_gt_lines = next_frame_data['map_gt_lines']
#             next_frame_map_gt_labels = next_frame_data['map_gt_labels']
#             next_frame_agent_gt_bboxes = next_frame_data['agent_gt_bboxes']
#             next_frame_agent_gt_labels = next_frame_data['agent_gt_labels']
            
#             coord_offset = cur_frame_gt_traj[:i].cumsum(dim=-1)  # i帧相对于上一帧的位移offset
            
#             # 遍历下一帧的线段，处理坐标并检查是否新点
#             for line, label in zip(next_frame_map_gt_lines, next_frame_map_gt_labels):
#                 adjusted_line = line - coord_offset  # 坐标对齐当前帧

#                 new_line_points = []
#                 for pt in adjusted_line:
#                     pt_tuple = tuple(np.round(pt.numpy(), decimals=2)) if isinstance(pt, torch.Tensor) else tuple(np.round(pt, 2))
#                     if pt_tuple not in cur_points_set:
#                         new_line_points.append(pt_tuple)
#                         cur_points_set.add(pt_tuple)

#                 if new_line_points:
#                     # 如果这个线段有新点，则加入当前帧
#                     new_line_tensor = torch.tensor(new_line_points, dtype=torch.float32)
#                     cur_frame_map_gt_lines.append(new_line_tensor)
#                     cur_frame_map_gt_labels.append(label)
            
            
            
#         else:
#             break

def bev_to_world(x_bevfeature, y_bevfeature, x_bev_feature_range, y_bev_feature_range, x_range=(-15,15), y_range=(30,-30)):
    """Linearly map BEV feature-grid coordinates to world coordinates.

    On each axis, index 0 maps to ``range[0]`` and index
    ``*_bev_feature_range`` maps to ``range[1]``. The default ``y_range`` is
    ``(30, -30)``, so a growing row index moves towards negative y.

    Only plain arithmetic is used, so numpy arrays are mapped element-wise.

    Returns:
        ``(x_world, y_world)`` tuple.
    """
    x_start, x_stop = x_range[0], x_range[1]
    y_start, y_stop = y_range[0], y_range[1]
    x_world = x_start + x_bevfeature * (x_stop - x_start) / x_bev_feature_range
    y_world = y_start + y_bevfeature * (y_stop - y_start) / y_bev_feature_range
    return x_world, y_world
            
def draw_gt(
    map_colors,agent_colors,
    map_gt_lines, map_gt_labels, agent_gt_bboxes, agent_gt_labels,
    ego_plan_result, ego_plan_gt,
    ax
    ):
    """Draw a ground-truth BEV scene (map, agents, ego box, trajectories) on ``ax``.

    Args:
        map_colors: colour per map class; indexed by each entry of ``map_gt_labels``.
        agent_colors: colour per agent class; indexed by each entry of ``agent_gt_labels``.
        map_gt_lines: iterable of polylines, each indexed as ``line[:, 0]`` /
            ``line[:, 1]`` (presumably ``(num_points, 2)`` arrays — confirm).
        map_gt_labels: class index per polyline (also indexes ``map_classes``).
        agent_gt_bboxes: per-agent boxes; the first 7 values are unpacked as
            ``(x_center, y_center, _, x_size, y_size, _, yaw)``.
        agent_gt_labels: class index per box (also indexes ``agent_classes``).
        ego_plan_result: predicted ego trajectory. It is cumulatively summed
            from ``(0, 0)``, i.e. treated as per-step ``(dx, dy)`` offsets.
        ego_plan_gt: ground-truth ego trajectory, same format.
        ax: matplotlib ``Axes`` to draw on.

    Each class name gets a legend label only the first time it appears on
    ``ax``. Later artists of the same class get an empty label, which the
    legend ignores.
    """
    # Snapshot the labels already on the axes once and track new ones locally.
    # Re-querying ax.get_legend_handles_labels() per artist walks every artist
    # each time, making the loops quadratic.
    seen_labels = set(ax.get_legend_handles_labels()[1])

    def legend_label(name):
        if name in seen_labels:
            return ""
        seen_labels.add(name)
        return name

    # Map GT polylines.
    for line, label in zip(map_gt_lines, map_gt_labels):
        ax.plot(line[:, 0], line[:, 1], color=map_colors[label], linewidth=1.0,
                label=legend_label(map_classes[label]))

    # Agents' rotated bounding boxes, each with a heading arrow.
    for box, label in zip(agent_gt_bboxes, agent_gt_labels):
        x_center, y_center, _, x_size, y_size, _, lidar_yaw = box[:7]
        color = agent_colors[label]

        corners = np.array([
            [-x_size / 2, -y_size / 2],
            [x_size / 2, -y_size / 2],
            [x_size / 2, y_size / 2],
            [-x_size / 2, y_size / 2]
        ])

        # NOTE(review): the box is rotated by (yaw - pi/2) while the arrow
        # points along yaw, so x_size is laid out perpendicular to the arrow.
        # Presumably this matches the dataset's box convention — confirm.
        box_angle = lidar_yaw - np.pi / 2
        rotation_matrix = np.array([
            [np.cos(box_angle), -np.sin(box_angle)],
            [np.sin(box_angle), np.cos(box_angle)]
        ])

        rotated_corners = np.dot(corners, rotation_matrix.T) + np.array([x_center, y_center])

        ax.add_patch(plt.Polygon(rotated_corners, closed=True, fill=None, edgecolor=color,
                                 linewidth=1.5, label=legend_label(agent_classes[label])))

        ax.arrow(x_center, y_center, 1.5 * np.cos(lidar_yaw), 1.5 * np.sin(lidar_yaw),
                 head_width=0.4, head_length=0.6, fc=color, ec=color)

    # Ego car: a 2.0 x 4.0 box centred at the origin.
    ego_width, ego_height = 2.0, 4.0
    ax.add_patch(Rectangle((-ego_width / 2, -ego_height / 2), ego_width, ego_height,
                           linewidth=2, edgecolor='black', facecolor='none', label="Ego Car"))

    # Ego trajectories: prepend (0, 0) and accumulate the per-step offsets.
    for trajectory, color, name in ((ego_plan_result, 'black', "Pred Trajectory"),
                                    (ego_plan_gt, 'red', "GT Trajectory")):
        points = np.cumsum(np.vstack(([0, 0], trajectory)), axis=0)
        ax.plot(points[:, 0], points[:, 1], color=color, linestyle='-', linewidth=2, label=name)

def _load_scene_frame(data_dir, scene_token, frame, missing_msg):
    """Load the ``scene_<scene_token>+frame_<frame>.pkl`` dump from ``data_dir``.

    Returns the unpickled object. If the file does not exist, prints
    ``missing_msg`` followed by the path and returns ``None``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    ``data_dir`` at trusted, locally produced dumps.
    """
    file_path = os.path.join(data_dir, f"scene_{scene_token}+frame_{frame}.pkl")
    if not os.path.exists(file_path):
        print(missing_msg + file_path)
        return None
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def _draw_shifted_scene(ax, map_colors, agent_colors, map_lines, map_labels,
                        agent_bboxes, agent_labels, offset):
    """Draw GT map polylines and agent boxes on ``ax``, translated by ``-offset``.

    Each class name gets a legend label only the first time it appears on
    ``ax``; later artists of the same class get an empty label.
    The inputs are never modified.
    """
    # Snapshot the existing labels once instead of re-querying ax per artist.
    seen_labels = set(ax.get_legend_handles_labels()[1])

    def legend_label(name):
        if name in seen_labels:
            return ""
        seen_labels.add(name)
        return name

    dx, dy = offset[0], offset[1]

    for line, label in zip(map_lines, map_labels):
        x = [xi - dx for xi in line[:, 0]]
        y = [yi - dy for yi in line[:, 1]]
        ax.plot(x, y, color=map_colors[label], linewidth=1.0,
                label=legend_label(map_classes[label]))

    for box, label in zip(agent_bboxes, agent_labels):
        x_center, y_center, _, x_size, y_size, _, lidar_yaw = box[:7]
        # Out-of-place subtraction. ``x_center -= dx`` would be an in-place op
        # on a view when the boxes are torch tensors. That would shift the
        # shared GT boxes and make offsets accumulate across explore panels.
        x_center = x_center - dx
        y_center = y_center - dy
        color = agent_colors[label]

        corners = np.array([
            [-x_size / 2, -y_size / 2],
            [x_size / 2, -y_size / 2],
            [x_size / 2, y_size / 2],
            [-x_size / 2, y_size / 2]
        ])

        box_angle = lidar_yaw - np.pi / 2
        rotation_matrix = np.array([
            [np.cos(box_angle), -np.sin(box_angle)],
            [np.sin(box_angle), np.cos(box_angle)]
        ])

        rotated_corners = np.dot(corners, rotation_matrix.T) + np.array([x_center, y_center])

        ax.add_patch(plt.Polygon(rotated_corners, closed=True, fill=None, edgecolor=color,
                                 linewidth=1.5, label=legend_label(agent_classes[label])))

        ax.arrow(x_center, y_center, 1.5 * np.cos(lidar_yaw), 1.5 * np.sin(lidar_yaw),
                 head_width=0.4, head_length=0.6, fc=color, ec=color)


def vis_world_model_close_loop(data_dir, save_path, scene_token, frame_idx, each_query_attention_point_num, n=3, x_range=(-15,15), y_range=(30,-30)):
    """Render the GT scene of frame ``frame_idx + n`` next to one panel per explored trajectory.

    Layout:
        * Column 0 spans all rows. It shows the GT of frame ``frame_idx + n``
          (the world-model frame), drawn via ``draw_gt``.
        * The other 4 columns hold one small panel per explored ego trajectory.
          Each panel redraws the same GT scene shifted by ``-offset``. The
          explored ego is a black box at the origin and the GT ego is a red
          box at ``-offset``.

    Args:
        data_dir: directory containing ``scene_<token>+frame_<idx>.pkl`` dumps.
        save_path: path of the output image.
        scene_token: scene token used in the file names.
        frame_idx: current frame index (an int or numeric string).
        each_query_attention_point_num: top-k points per query; only
            referenced by the commented-out attention overlay below.
        n: how many frames ahead the world-model frame is. It also selects
            the trajectory step used for the offset.
        x_range, y_range: BEV extent; only referenced by the commented-out
            overlays below.

    Prints a message and returns early, without creating a figure, if either
    pickle is missing.
    """
    map_colors = ['blue', 'green', 'red']
    agent_colors = ['cyan', 'orange', 'purple', 'yellow', 'brown', 'pink', 'lime', 'magenta', 'gray', 'red']

    # Load everything before creating the figure, so an early return cannot
    # leak an open matplotlib figure.
    data = _load_scene_frame(data_dir, scene_token, frame_idx, "file_path not found:")
    if data is None:
        return

    ego_gt_traj = data['gt_trajectory']
    explore_trajs = data['explore_trajs'][0]
    explore_predicted_bev_embeds = data['predict_explore_bev_embed'][0]
    explore_predicted_attention_map = data['predict_explore_selected_attention_map'][0]

    # The frame predicted by the world model, n frames after the current one.
    frame_number = int(frame_idx) + n
    world_model_frame_data = _load_scene_frame(
        data_dir, scene_token, frame_number, "world_model_pred_file_path not found:")
    if world_model_frame_data is None:
        return

    world_model_frame_map_gt_lines = world_model_frame_data['map_gt_lines']
    world_model_frame_map_gt_labels = world_model_frame_data['map_gt_labels']
    world_model_frame_agent_gt_bboxes = world_model_frame_data['agent_gt_bboxes']
    world_model_frame_agent_gt_labels = world_model_frame_data['agent_gt_labels']
    world_model_frame_plan_gt_cmd = world_model_frame_data['gt_command']
    world_model_frame_plan_result = world_model_frame_data['pred_trajectory'][world_model_frame_plan_gt_cmd == 1][0]
    world_model_frame_plan_gt = world_model_frame_data['gt_trajectory']

    explore_num = len(explore_trajs)
    explore_cols = 4                    # explore panels per row
    total_cols = explore_cols + 1       # column 0 is the GT panel
    # At least one row, so GridSpec stays valid when there are no explore trajectories.
    total_rows = max(1, (explore_num + explore_cols - 1) // explore_cols)

    fig = plt.figure(figsize=(3 * total_cols, 4 * total_rows))
    try:
        gs = GridSpec(nrows=total_rows, ncols=total_cols,
                      width_ratios=[3] + [1] * explore_cols, figure=fig)

        # GT panel: column 0, spanning every row.
        ax_gt = fig.add_subplot(gs[:, 0])
        ax_gt.axis('off')
        ax_gt.set_title("Ground Truth", fontsize=16)
        draw_gt(
            map_colors, agent_colors,
            world_model_frame_map_gt_lines, world_model_frame_map_gt_labels,
            world_model_frame_agent_gt_bboxes, world_model_frame_agent_gt_labels,
            world_model_frame_plan_result, world_model_frame_plan_gt,
            ax_gt
        )

        # Explore panels fill columns 1..explore_cols, row by row.
        axes = []
        for i in range(explore_num):
            ax = fig.add_subplot(gs[i // explore_cols, i % explore_cols + 1])
            ax.axis('off')
            ax.set_title(f"Explore Traj {i}", fontsize=10)
            axes.append(ax)

        # zip stops at the shortest input, so a panel is only filled when a BEV
        # embedding and an attention map exist for that trajectory.
        for ax, explore_traj, explore_bev_embed, explore_attention_map in zip(
                axes, explore_trajs, explore_predicted_bev_embeds, explore_predicted_attention_map):
            # Offset between the explored and GT ego trajectories at step n
            # (the frame predicted by the world model).
            # NOTE(review): draw_gt cumsums trajectories, which suggests they
            # store per-step displacements. If so, the position offset would be
            # the difference of cumulative sums over the first n steps, not of
            # step n alone — confirm the trajectory format.
            offset = explore_traj[n - 1] - ego_gt_traj[n - 1]

            _draw_shifted_scene(
                ax, map_colors, agent_colors,
                world_model_frame_map_gt_lines, world_model_frame_map_gt_labels,
                world_model_frame_agent_gt_bboxes, world_model_frame_agent_gt_labels,
                offset
            )

            ego_width, ego_height = 2.0, 4.0
            # GT ego car, shifted into the explored ego's frame.
            ax.add_patch(Rectangle(
                (-offset[0] - ego_width / 2, -offset[1] - ego_height / 2),
                ego_width,
                ego_height,
                linewidth=2,
                edgecolor='red',
                facecolor='none',
                label="Ego Car GT"
            ))
            # Explored ego car at the origin.
            ax.add_patch(Rectangle((-ego_width / 2, -ego_height / 2), ego_width, ego_height,
                                   linewidth=2, edgecolor='black', facecolor='none',
                                   label="Ego Car World Model Explore"))

            # ############################################################### heatmap (L2 norm) #########################################################
            # bev_embed = explore_bev_embed
            # if hasattr(bev_embed, 'cpu'):
            #     bev_embed = bev_embed.cpu().numpy()
            # bev_embed = bev_embed.reshape(100, 100, 256)
            #
            # heatmap = np.linalg.norm(bev_embed, axis=2)
            # # reference value
            # A_ref = np.mean(heatmap)
            # # avoid division by zero
            # epsilon = 1e-8
            # heatmap_db = 10 * np.log10((heatmap + epsilon) / (A_ref + epsilon))
            #
            # vmin = np.min(heatmap_db)
            # vmax = np.max(heatmap_db)
            # norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
            #
            # extent = [x_range[0], x_range[1], y_range[1], y_range[0]]
            # im = ax.imshow(heatmap_db, cmap='seismic', norm=norm, extent=extent, alpha=0.6, interpolation='nearest')
            # #plt.colorbar(im, ax=ax_attention_bev, location='left', pad=0.15)
            # ############################################################### heatmap (L2 norm) #########################################################

            # ############################################################### BEV 16 Query Attention Positions ###############################################################
            # selected_attention_map = explore_attention_map
            # if hasattr(selected_attention_map, 'cpu'):
            #     selected_attention_map = selected_attention_map.cpu().numpy()
            #
            # selected_attention_map = selected_attention_map.reshape(16, 100, 100)
            #
            # for query_idx in range(16):
            #     query_map = selected_attention_map[query_idx].flatten()
            #     top_indices = np.argpartition(-query_map, each_query_attention_point_num)[:each_query_attention_point_num]
            #     top_indices = top_indices[np.argsort(-query_map[top_indices])]  # sort by value, descending
            #
            #     for rank, idx in enumerate(top_indices):
            #         y_bev, x_bev = np.unravel_index(idx, (100, 100))  # (row, col)
            #         x_world, y_world = bev_to_world(x_bev, y_bev, 100.0, 100.0, x_range, y_range)
            #         ax.scatter(
            #             x_world, y_world,
            #             c='white',
            #             s=10,
            #             edgecolors='black',
            #             alpha=0.5,  # transparency (0.0 fully transparent, 1.0 opaque)
            #             label='BEV Query Max' if (query_idx == 0 and rank == 0) else ""
            #         )
            #
            #         #ax_attention_bev.text(x_world + 0.3, y_world, f'{query_idx}', color='white', fontsize=8, weight='bold')
            # ############################################################### BEV 16 Query Max Attention Positions ###############################################################

        fig.tight_layout()
        print("全部 explore_traj 可视化完毕")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if drawing or saving raises.
        plt.close(fig)

def parse_args(argv=None):
    """Parse command-line options for the world-model closed-loop visualization.

    Options:
        --data_directory: directory holding the ``scene_<token>+frame_<idx>.pkl``
            dumps (default: the SSR ``vis_data`` directory).
        --save_directory: directory the rendered figure is written to
            (default: the SSR ``vis/figs_gt`` directory).
        --scene_token: scene token to visualize (required).
        --frame_idx: index of the current frame (required). It is kept as a
            string because it is formatted into file names.

    Args:
        argv: argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Visualize and save ground truth data.")
    parser.add_argument(
        '--data_directory',
        type=str,
        default='/opt/nvme1/zhengxj/projects/world-model-RL-store/SSR/vis_data',
        help='The directory containing the data to visualize.'
    )
    parser.add_argument(
        '--save_directory',
        type=str,
        default='/opt/nvme1/zhengxj/projects/world-model-RL-store/SSR/vis/figs_gt',
        help='The directory to save the visualization data.'
    )
    # These two previously defaulted to the save-directory path (a copy-paste
    # slip). That path can never name a valid scene or frame, so both are
    # now required.
    parser.add_argument(
        '--scene_token',
        type=str,
        required=True,
        help='Token of the scene to visualize (the <token> in scene_<token>+frame_<idx>.pkl).'
    )
    parser.add_argument(
        '--frame_idx',
        type=str,
        required=True,
        help='Integer index of the current frame (the <idx> in scene_<token>+frame_<idx>.pkl).'
    )
    return parser.parse_args(argv)

def main():
    """Command-line entry point: parse options and render one closed-loop figure.

    Creates the save directory if needed. The figure is written to
    ``<save_directory>/scene_<token>+frame_<idx>_world_model_closed_loop.png``.
    """
    args = parse_args()

    os.makedirs(args.save_directory, exist_ok=True)

    output_name = f"scene_{args.scene_token}+frame_{args.frame_idx}_world_model_closed_loop.png"
    save_path = os.path.join(args.save_directory, output_name)

    # The literal 10 is passed as each_query_attention_point_num, which is
    # currently only referenced by commented-out overlay code.
    vis_world_model_close_loop(args.data_directory, save_path, args.scene_token, args.frame_idx, 10)


if __name__ == "__main__":
    main()