import json
import pickle
import cv2
import numpy as np
import matplotlib.pyplot as plt
import re
from pathlib import Path
import os
# ---------- 你给出的工具函数 ---------- #
def project_vehicle_to_image(vehicle_pose, calibration, points):
    """Project 3-D points onto a camera's image plane, returning (u, v, ok) rows.

    Args:
        vehicle_pose: 16 numbers reshaped row-major into a 4x4 transform. It is
            applied to ``points`` before projection, and the same 16 values are
            also used as the leading entries of the camera image metadata.
        calibration: camera calibration object exposing ``extrinsic.transform``,
            ``intrinsic``, ``width``, ``height`` and ``GLOBAL_SHUTTER``
            (presumably a Waymo ``CameraCalibration`` proto -- verify).
        points: sequence / array of N points with 3 coordinates each.

    Returns:
        NumPy array produced by ``world_to_image``. Callers in this file read
        columns 0-1 as pixel (u, v) and treat column 2 > 0 as "ok / in view".
    """
    import tensorflow as tf
    from waymo_open_dataset.utils import camera_model as py_camera_model_ops

    pose_matrix = np.asarray(vehicle_pose, dtype=np.float64).reshape(4, 4)

    # Vectorized homogeneous transform: each row equals pose_matrix @ [x, y, z, 1],
    # keeping only the first three components (as the original per-point loop did).
    # Computing in float64 avoids truncation when integer points are passed in.
    pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    homogeneous = np.hstack([pts, np.ones((pts.shape[0], 1), dtype=np.float64)])
    world_points = (homogeneous @ pose_matrix.T)[:, :3]

    extrinsic = tf.reshape(
        tf.constant(list(calibration.extrinsic.transform), dtype=tf.float32),
        [4, 4])
    intrinsic = tf.constant(list(calibration.intrinsic), dtype=tf.float32)
    # [width, height, shutter mode]; the third slot is presumably the
    # rolling-shutter direction, forced to GLOBAL_SHUTTER -- confirm against
    # the world_to_image documentation.
    metadata = tf.constant(
        [calibration.width, calibration.height,
         calibration.GLOBAL_SHUTTER], dtype=tf.int32)
    # Pose followed by 10 zeros. NOTE(review): assumed to zero out the
    # velocity / shutter-timing fields expected by world_to_image; verify layout.
    camera_image_metadata = list(vehicle_pose) + [0.0] * 10

    return py_camera_model_ops.world_to_image(
        extrinsic, intrinsic, metadata,
        camera_image_metadata, world_points).numpy()

def draw_points_on_image(image, points, size=12, color=(255, 0, 0)):
    """Paint a filled disc of radius ``size`` at each (u, v) location in ``points``.

    ``image`` is modified in place and also returned for convenience.
    Coordinates are converted with ``int()`` (truncation toward zero).
    """
    fill = -1  # a negative thickness makes OpenCV fill the circle
    for point in points:
        u, v = point
        center = (int(u), int(v))
        cv2.circle(image, center, size, color, fill)
    return image

# ---------- 主函数 ---------- #
# sample["images"] is stored as [front, fr, fl]; display it as fl -> front -> fr.
_DISPLAY_ORDER = [2, 0, 1]
# Calibration keys in the same left-to-right display order.
_CALIB_KEYS = ["fl", "front", "fr"]
# One signed decimal number, optionally with an exponent (e.g. -1.5, .5, 3., 1e-3).
_NUM_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
# A bracketed "[x, y]" coordinate pair.
_COORD_PAIR_RE = re.compile(rf"\[\s*({_NUM_PATTERN})\s*,\s*({_NUM_PATTERN})\s*\]")


def _parse_future_xyz(content):
    """Return every ``[x, y]`` pair found in ``content`` as an (N, 3) array with z = 0.

    The assistant answer is expected to end with ``...}\\n[x1, y1], [x2, y2], ...``.
    Returns None when no pair is present.
    """
    pairs = _COORD_PAIR_RE.findall(content)
    if not pairs:
        return None
    # Values go through float32 (as before); appending the float64 z column
    # yields a float64 array.
    xy = np.array([[float(x), float(y)] for x, y in pairs], dtype=np.float32)
    return np.concatenate([xy, np.zeros((xy.shape[0], 1))], axis=1)


def _load_rgb(path):
    """Read an image from ``path`` and convert it from BGR to RGB."""
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"cannot read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _compose_canvas(images, text_strip=100):
    """Place ``images`` side by side above a black strip of ``text_strip`` rows.

    Shorter images are zero-padded at the bottom so that the concatenation
    works even when the cameras have different heights.
    Returns ``(canvas, image_height)``.
    """
    height = max(im.shape[0] for im in images)
    padded = [
        np.pad(im, ((0, height - im.shape[0]), (0, 0), (0, 0)), mode="constant")
        for im in images
    ]
    row = np.concatenate(padded, axis=1)
    canvas = np.zeros((height + text_strip, row.shape[1], 3), dtype=np.uint8)
    canvas[:height] = row
    return canvas, height


def _plot_canvas(canvas, image_height, text, out_file=None):
    """Draw ``canvas`` with ``text`` below the images; save to ``out_file`` or show it."""
    fig_h, width = canvas.shape[:2]
    # Figure size is the pixel size / 100.
    fig = plt.figure(figsize=(width / 100, fig_h / 100))
    try:
        plt.imshow(canvas)
        plt.axis("off")
        plt.text(
            width / 2, image_height + 10, text,
            ha="center", va="top", fontsize=8, wrap=True)
        if out_file:
            plt.savefig(out_file, dpi=240, bbox_inches="tight")
        else:
            plt.show()
    finally:
        # Close every figure; otherwise each loop iteration leaks one.
        plt.close(fig)


def visualize(
    json_path: str,
    calib_pkl: str = "/home/fenglan/DiffusionDrive/navsim/planning/script/calibration.pkl",
    save_path = None,
    stride: int = 1000,
):
    """Overlay the predicted trajectory on three camera views for a subset of samples.

    Args:
        json_path: JSON file holding a list of samples. Each sample provides
            ``images`` (three image paths) and ``messages``; the ``content`` of
            the last message contains the ``[x, y]`` trajectory points.
        calib_pkl: pickle file holding a dict with keys ``"fl"``, ``"front"``
            and ``"fr"``; its values are passed to ``project_vehicle_to_image``.
        save_path: output directory for ``<sample index>.png`` files; it is
            created if missing. When falsy, each figure is displayed instead.
        stride: render every ``stride``-th sample, starting at index 0.

    Raises:
        ValueError: a rendered sample's content contains no ``[x, y]`` pair.
        FileNotFoundError: one of the sample's images cannot be read.
    """
    # The content contains Chinese text, so the encoding is given explicitly.
    with open(json_path, "r", encoding="utf-8") as f:
        samples = json.load(f)

    # Loaded once, not once per sample. NOTE(review): pickle can execute
    # arbitrary code -- only load trusted calibration files.
    with open(calib_pkl, "rb") as f:
        calib_dict = pickle.load(f)
    cam_calibs = [calib_dict[k] for k in _CALIB_KEYS]

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Identity vehicle pose: the points are projected as given.
    vehicle_pose = np.eye(4, dtype=np.float32).flatten()

    for idx in range(0, len(samples), stride):
        sample = samples[idx]
        content = sample["messages"][-1]["content"]  # last (assistant) reply

        # Parse first, so a bad sample fails before any image I/O happens.
        future_xyz = _parse_future_xyz(content)
        if future_xyz is None:
            raise ValueError(f"未能在 content 中找到坐标点 (sample {idx})")

        image_paths = sample["images"]
        drawn_imgs = []
        for cam_idx, calib in zip(_DISPLAY_ORDER, cam_calibs):
            img = _load_rgb(image_paths[cam_idx])
            proj = project_vehicle_to_image(vehicle_pose, calib, future_xyz)
            visible = proj[:, 2] > 0
            drawn_imgs.append(draw_points_on_image(img, proj[visible][:, :2]))

        canvas, image_height = _compose_canvas(drawn_imgs)
        out_file = os.path.join(save_path, f"{idx}.png") if save_path else None
        _plot_canvas(canvas, image_height, content, out_file)

# ---------- CLI ---------- #
if __name__ == "__main__":
    # Render the sampled training annotations into ./visualization.
    visualize(
        json_path="/home/fenglan/DiffusionDrive/navsim/planning/script/waymo_train_annotated.json",
        calib_pkl="/home/fenglan/DiffusionDrive/navsim/planning/script/calibration.pkl",
        save_path="./visualization",
    )