"""
Open3D visualization toolbox (plus a few matplotlib image/BEV plotting helpers).
Written by Jihan YANG
All rights reserved, 2021 - present.
"""

# import open3d
import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from tqdm import tqdm
import open3d

# Per-label RGB box colors (white, green, cyan, yellow); draw_box paints box i
# with box_colormap[ref_labels[i]] when labels are given.
box_colormap = [[1, 1, 1], [0, 1, 0], [0, 1, 1], [1, 1, 0]]


def get_coor_colors(obj_labels):
    """
    Map integer per-point labels to RGB colors from matplotlib's XKCD palette.

    Args:
        obj_labels: integer labels of shape [N]; 1 is ground, labels > 1
            indicate different instance clusters. Label k gets the k-th XKCD
            color (wrapping around if k exceeds the palette size).

    Returns:
        rgb: float array of shape [N, 3], one color (components in [0, 1])
            per point. An empty input yields an array of shape [0, 3].
    """
    labels = np.asarray(obj_labels).astype(np.int64).reshape(-1)
    if labels.size == 0:
        return np.zeros((0, 3))

    xkcd_colors = list(matplotlib.colors.XKCD_COLORS.values())
    # Only convert as many palette entries as the labels can reach.
    num_colors = int(np.clip(labels.max() + 1, 1, len(xkcd_colors)))
    # to_rgba returns a flat 4-tuple, so the palette is [K, 4] and the fancy
    # indexing below always yields [N, 4], even for N == 1. The previous
    # to_rgba_array(...) + squeeze() path collapsed the point axis in that case.
    palette = np.array(
        [matplotlib.colors.to_rgba(color) for color in xkcd_colors[:num_colors]]
    )
    # The modulo wraps labels larger than the palette instead of raising IndexError.
    return palette[labels % num_colors, :3]


def _to_numpy(value):
    """Return a torch.Tensor as a detached CPU NumPy array; pass anything else through."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def draw_scenes(
    points,
    gt_boxes=None,
    ref_boxes=None,
    ref_labels=None,
    ref_scores=None,
    point_colors=None,
    draw_origin=True,
):
    """
    Open an interactive Open3D window showing a point cloud and optional boxes.

    Each array argument may be a NumPy array or a torch.Tensor.

    Args:
        points: [N, >=3] points; only the first three columns (x, y, z) are drawn.
        gt_boxes: optional [M, >=7] boxes, drawn in blue.
        ref_boxes: optional [K, >=7] boxes, drawn in green unless ref_labels is
            given, in which case each box is colored from box_colormap by label.
        ref_labels: optional per-box labels for ref_boxes.
        ref_scores: optional per-box scores, forwarded to draw_box.
        point_colors: optional [N, 3] RGB colors in [0, 1]; white when omitted.
        draw_origin: whether to draw a size-1.0 coordinate frame at the origin.
    """
    points = _to_numpy(points)
    gt_boxes = _to_numpy(gt_boxes)
    ref_boxes = _to_numpy(ref_boxes)
    ref_labels = _to_numpy(ref_labels)
    ref_scores = _to_numpy(ref_scores)
    point_colors = _to_numpy(point_colors)

    vis = open3d.visualization.Visualizer()
    vis.create_window()
    # Always release the window, even if adding geometry or running the viewer fails.
    try:
        render_option = vis.get_render_option()
        render_option.point_size = 1.0
        render_option.background_color = np.zeros(3)

        if draw_origin:
            axis_pcd = open3d.geometry.TriangleMesh.create_coordinate_frame(
                size=1.0, origin=[0, 0, 0]
            )
            vis.add_geometry(axis_pcd)

        pts = open3d.geometry.PointCloud()
        pts.points = open3d.utility.Vector3dVector(points[:, :3])
        if point_colors is None:
            point_colors = np.ones((points.shape[0], 3))
        # Assign colors before add_geometry so the cloud is complete when the
        # visualizer first sees it.
        pts.colors = open3d.utility.Vector3dVector(point_colors)
        vis.add_geometry(pts)

        if gt_boxes is not None:
            vis = draw_box(vis, gt_boxes, (0, 0, 1))

        if ref_boxes is not None:
            vis = draw_box(vis, ref_boxes, (0, 1, 0), ref_labels, ref_scores)

        vis.run()
    finally:
        vis.destroy_window()


def translate_boxes_to_open3d_instance(gt_boxes):
    """
    Convert a single box into an Open3D wireframe and oriented bounding box.

    Corner indices of the resulting LineSet:

       4-------- 6
     /|         /|
    5 -------- 3 .
    | |        | |
    . 7 -------- 1
    |/         |/
    2 -------- 0

    Args:
        gt_boxes: one box; entries 0-2 are the center, entries 3-5 the extent
            passed to OrientedBoundingBox, and entry 6 the rotation angle
            about the z axis.

    Returns:
        (line_set, box3d): the LineSet holding the box edges plus two extra
        segments (corners 1-4 and 7-6), and the OrientedBoundingBox it was
        built from.
    """
    heading = gt_boxes[6]
    # NOTE(review): the 1e-10 offset presumably keeps the axis-angle vector
    # non-zero when heading == 0, since a zero-length axis cannot be normalized.
    rotation = open3d.geometry.get_rotation_matrix_from_axis_angle(
        np.array([0, 0, heading + 1e-10])
    )
    box3d = open3d.geometry.OrientedBoundingBox(gt_boxes[0:3], rotation, gt_boxes[3:6])
    wireframe = open3d.geometry.LineSet.create_from_oriented_bounding_box(box3d)

    # Append segments 1-4 and 7-6; presumably these mark the box's heading
    # side. Confirm against the corner diagram above.
    extra_segments = np.array([[1, 4], [7, 6]])
    all_segments = np.vstack((np.asarray(wireframe.lines), extra_segments))
    wireframe.lines = open3d.utility.Vector2iVector(all_segments)

    return wireframe, box3d


def draw_box(vis, gt_boxes, color=(0, 1, 0), ref_labels=None, score=None):
    """
    Add one wireframe per box to an Open3D visualizer.

    Args:
        vis: Open3D visualizer; each box's LineSet is added via add_geometry.
        gt_boxes: [M, >=7] boxes; each row is converted with
            translate_boxes_to_open3d_instance.
        color: RGB color used for every box when ref_labels is None.
        ref_labels: optional per-box integer labels. Box i is painted with
            box_colormap[ref_labels[i]], wrapping around for labels larger
            than the colormap.
        score: accepted for API compatibility; scores are currently not drawn.

    Returns:
        The same visualizer object, so calls can be chained.
    """
    num_colors = len(box_colormap)
    for i, box in enumerate(gt_boxes):
        line_set, _ = translate_boxes_to_open3d_instance(box)
        if ref_labels is None:
            line_set.paint_uniform_color(color)
        else:
            # Wrap around instead of raising IndexError for labels >= len(box_colormap).
            label_color = box_colormap[int(ref_labels[i]) % num_colors]
            line_set.paint_uniform_color(label_color)

        vis.add_geometry(line_set)

    return vis


def vis_points_in_img(img, points, projection_matrix):
    """
    Project 3D points into an image and plot them colored by depth.

    Args:
        img: image array of shape [H, W] or [H, W, C]; displayed after casting
            to uint8.
        points: [N, >=3] points; only the first three columns (x, y, z) are used.
        projection_matrix: matrix applied to homogeneous points [x, y, z, 1];
            it needs 4 columns and at least 3 rows (e.g. a 3x4 camera
            projection). Row 2 of the result is treated as depth.
    """
    points = np.asarray(points)
    # Homogeneous coordinates, shape [4, N].
    homogeneous = np.vstack(
        (points[:, 0], points[:, 1], points[:, 2], np.ones(points.shape[0]))
    )
    projected = projection_matrix @ homogeneous
    depth = projected[2]

    # Drop points at or behind the image plane *before* the perspective divide,
    # so we never divide by zero or mirror points located behind the camera.
    in_front = depth > 0
    depth = depth[in_front]
    u = projected[0, in_front] / depth
    v = projected[1, in_front] / depth

    # Keep only the points that land inside the image.
    image_height, image_width = img.shape[:2]
    inside = (u >= 0) & (u < image_width) & (v >= 0) & (v < image_height)
    u, v, depth = u[inside], v[inside], depth[inside]

    # Normalize depth to [0, 1] for the colormap. Guard the empty case (min/max
    # would raise) and the constant-depth case (division by zero gives NaNs).
    if depth.size == 0:
        depth_normalized = depth
    else:
        depth_span = depth.max() - depth.min()
        if depth_span > 0:
            depth_normalized = (depth - depth.min()) / depth_span
        else:
            depth_normalized = np.zeros_like(depth)
    # Inverted jet colormap: near points map to the red end, far points to blue.
    colors = plt.cm.jet(1 - depth_normalized)[:, :3]

    # Overlay the projected points on the original image.
    _, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(img.astype("uint8"))
    ax.scatter(u, v, c=colors, s=2, marker="o", alpha=0.8)
    ax.set_title("Projected LiDAR Points onto Image")
    ax.axis("off")
    plt.show()


def vis_gaussian_in_img(
    img,
    gaussian_xy,
    gaussian_scale,
    gaussian_rot,
    gaussian_label,
    pc_range=(-32.0, -32.0, 32.0, 32.0),
):
    """
    Draw 2D Gaussians as ellipse outlines on top of an image.

    World coordinates are mapped linearly from ``pc_range`` onto the image with
    the axes transposed: world y runs along the displayed x-axis (image
    columns) and world x along the displayed y-axis (image rows, origin at the
    bottom).

    Args:
        img: [H, W] (or [H, W, C]) image, shown with a gray colormap and
            origin="lower".
        gaussian_xy: [G, 2] world-frame centers (x, y).
        gaussian_scale: [G, 2] world-frame extents along the Gaussian's own x/y
            axes. They are used directly as the ellipse's full width/height;
            NOTE(review): confirm whether these are std-devs or radii.
        gaussian_rot: [G, 2] rotation encoding. The heading is computed as
            arctan2(rot[0], rot[1]), so presumably rot = (sin, cos); confirm.
        gaussian_label: [G] integer class ids indexing a fixed 7-entry color table.
        pc_range: (x_min, y_min, x_max, y_max) world extent covered by the image.
    """
    label_color = np.array(["black", "gray", "white", "yellow", "red", "blue", "blue"])
    img_height, img_width = img.shape[0], img.shape[1]
    # Pixels per world unit. Because the plot is transposed, world x is scaled
    # by the image height (rows) and world y by the image width (columns).
    px_per_x = img_height / (pc_range[2] - pc_range[0])
    px_per_y = img_width / (pc_range[3] - pc_range[1])

    _, ax = plt.subplots()

    # Background image with a grayscale colormap and a colorbar.
    image_artist = ax.imshow(img, cmap="gray", origin="lower")
    plt.colorbar(image_artist)

    labels_in_legend = set()
    for xy, scale, rot, label in zip(
        gaussian_xy, gaussian_scale, gaussian_rot, gaussian_label
    ):
        row = (xy[0] - pc_range[0]) * px_per_x
        col = (xy[1] - pc_range[1]) * px_per_y
        # matplotlib's Ellipse takes its angle in degrees; arctan2 returns radians.
        # NOTE(review): the transposed axes mirror the plane, so a counter-clockwise
        # world heading may also need negating here. Confirm against real data.
        angle_deg = np.degrees(np.arctan2(rot[0], rot[1]))

        # One legend entry per class instead of one per ellipse.
        class_id = int(label)
        if class_id in labels_in_legend:
            legend_label = "_nolegend_"
        else:
            labels_in_legend.add(class_id)
            legend_label = str(label)

        ax.add_patch(
            Ellipse(
                xy=(col, row),
                width=scale[1] * px_per_y,
                height=scale[0] * px_per_x,
                angle=angle_deg,
                edgecolor=label_color[label],
                facecolor="none",
                linewidth=2,
                label=legend_label,
            )
        )

    ax.set_xlim(0, img_width)
    ax.set_ylim(0, img_height)
    ax.set_aspect("equal")  # keep the aspect ratio
    ax.set_title("Ellipses on 2D Plane")
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    if labels_in_legend:
        ax.legend()

    plt.grid(True)
    plt.show()


def vis_bev_map(pred_bev, gt_bev):
    """
    Show two BEV maps side by side with the viridis colormap.

    The left panel ("Image 1") shows ``gt_bev`` and the right panel
    ("Image 2") shows ``pred_bev``; axes are hidden.
    """
    plt.figure(figsize=(6, 3))
    panels = ((gt_bev, "Image 1"), (pred_bev, "Image 2"))
    for position, (bev, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(bev, cmap="viridis")
        plt.title(title)
        plt.axis("off")
    plt.show()


def visualize_large_voxels(voxels, voxel_size=0.5):
    """
    Render labeled voxels as colored cubes in an Open3D window.

    Args:
        voxels: array of shape [N, >=4]. Columns 0-2 give the offset each cube
            is translated by (the cube's min corner); column 3 is an integer
            class label. NOTE(review): positions are not multiplied by
            voxel_size, so they are presumably already in world units.
            Confirm with callers.
        voxel_size: edge length of each cube.
    """
    voxels = np.asarray(voxels)
    if voxels.shape[0] == 0:
        print("No voxels to visualize.")
        return

    # Cast to float once, so Open3D's translate never receives integer coordinates.
    coords = voxels[:, :3].astype(float)
    labels = voxels[:, 3].astype(int)

    # One RGB color per label, cycling through the 20 entries of "tab20".
    cmap = plt.get_cmap("tab20")
    palette = np.array([cmap(i)[:3] for i in range(20)])
    voxel_colors = palette[labels % 20]

    cube_template = open3d.geometry.TriangleMesh.create_box(
        voxel_size, voxel_size, voxel_size
    )

    # Merge each cube into a single mesh as it is built, instead of first
    # keeping all N cube meshes alive in a list.
    full_mesh = open3d.geometry.TriangleMesh()
    for coord, color in tqdm(
        zip(coords, voxel_colors), total=len(coords), desc="Building voxel mesh"
    ):
        cube = open3d.geometry.TriangleMesh(cube_template)  # copy of the template
        cube.translate(coord)
        cube.paint_uniform_color(color)
        full_mesh += cube

    full_mesh.compute_vertex_normals()
    open3d.visualization.draw_geometries([full_mesh])


if __name__ == "__main__":
    # Minimal synthetic demo for vis_gaussian_in_img: three Gaussians on a
    # blank 256x256 canvas covering the default pc_range. Every required
    # argument is supplied, so the demo runs without a TypeError.
    demo_img = np.zeros((256, 256))
    vis_gaussian_in_img(
        demo_img,
        gaussian_xy=np.array([[0.0, 0.0], [10.0, -8.0], [-16.0, 12.0]]),
        gaussian_scale=np.array([[6.0, 3.0], [4.0, 4.0], [8.0, 2.0]]),
        gaussian_rot=np.array([[0.0, 1.0], [0.5, 0.866], [1.0, 0.0]]),
        gaussian_label=np.array([1, 3, 4]),
    )
