import numpy as np
from typing import Dict, List, Optional
import tensorflow as tf
from waymo_open_dataset import dataset_pb2

image_heights = [1280, 1280, 1280, 886, 886]
image_widths = [1920, 1920, 1920, 1920, 1920]

_camera2label = {
    'FRONT': 0,
    'FRONT_LEFT': 1,
    'FRONT_RIGHT': 2,
    'SIDE_LEFT': 3,
    'SIDE_RIGHT': 4,
}

_label2camera = {
    0: 'FRONT',
    1: 'FRONT_LEFT',
    2: 'FRONT_RIGHT',
    3: 'SIDE_LEFT',
    4: 'SIDE_RIGHT',
}

def load_poses_calibration(datadir):
    """Read vehicle poses, per-camera poses and camera calibration from a Waymo TFRecord.

    Every frame in the file is used; no temporal downsampling is applied here.
    Cameras are always ordered FRONT, FRONT_LEFT, FRONT_RIGHT, SIDE_LEFT, SIDE_RIGHT.
    Calibration (intrinsics / extrinsics) is read from the first frame only.

    Args:
        datadir: Path to a single Waymo Open Dataset TFRecord file.

    Returns:
        ``(ego_frame_poses, ego_cam_poses, extrinsics, intrinsics)`` where

        - ``ego_frame_poses``: float32 array ``[N, 4, 4]`` of vehicle-to-world poses.
        - ``ego_cam_poses``: float32 array ``[5, N, 4, 4]`` of camera-to-world poses
          whose rotation columns follow OpenCV camera axes (x right, y down, z forward).
        - ``extrinsics``: list of 5 ``[4, 4]`` OpenCV camera-to-vehicle matrices
          (see ``get_extrinsic``).
        - ``intrinsics``: list of 5 dicts with ``'K'`` (3x3 float32) and
          ``'dist_coeffs'`` (5 float32: k1, k2, p1, p2, k3; missing ones are 0).

        The translations of both pose arrays are shifted by the mean vehicle
        position, so the trajectory is centered at the origin.

        A camera absent from the calibration gets identity ``K``, zero
        distortion, an identity extrinsic, and poses computed with an identity
        camera-to-vehicle transform. As a result, ``ego_cam_poses`` always has
        shape ``[5, N, 4, 4]``.

    Raises:
        ValueError: If the file contains no frames.
    """

    def _compute_camera_to_world_opencv(vehicle_to_world: np.ndarray, camera_to_vehicle: np.ndarray) -> np.ndarray:
        # Same as demo_waymo.py: camera -> vehicle (FLU) -> world, then permute
        # the camera axis columns into the OpenCV convention.
        camera_to_world_flu = vehicle_to_world @ camera_to_vehicle
        return np.concatenate(
            [
                -camera_to_world_flu[:, 1:2],  # OpenCV x (right)   = -Y (left)
                -camera_to_world_flu[:, 2:3],  # OpenCV y (down)    = -Z (up)
                camera_to_world_flu[:, 0:1],   # OpenCV z (forward) =  X
                camera_to_world_flu[:, 3:4],   # translation
            ],
            axis=1,
        )

    camera_order = ['FRONT', 'FRONT_LEFT', 'FRONT_RIGHT', 'SIDE_LEFT', 'SIDE_RIGHT']
    num_cams = len(camera_order)

    tfrecord_path = datadir
    dataset = tf.data.TFRecordDataset(str(tfrecord_path), compression_type='')

    intrinsics: List[Optional[Dict[str, np.ndarray]]] = [None] * num_cams
    extrinsics: List[Optional[np.ndarray]] = [None] * num_cams
    # Raw camera-to-vehicle transforms, kept in float64 so pose products are not
    # rounded to float32 before the large world offset is removed.
    cams_to_vehicle: List[Optional[np.ndarray]] = [None] * num_cams

    ego_frame_poses_list: List[np.ndarray] = []
    ego_cam_poses_lists: List[List[np.ndarray]] = [[] for _ in range(num_cams)]

    calibrated = False
    for data in dataset:
        frame = dataset_pb2.Frame()
        frame.ParseFromString(bytearray(data.numpy()))

        if not calibrated:
            # Camera calibration is taken from the first frame only.
            for calib in frame.context.camera_calibrations:
                name = dataset_pb2.CameraName.Name.Name(calib.name)
                if name not in camera_order:
                    continue
                idx = camera_order.index(name)
                # Intrinsic layout: [fx, fy, cx, cy, k1, k2, p1, p2, (k3)]
                vals = np.array(calib.intrinsic, dtype=np.float32)
                fx, fy, cx, cy = vals[:4]
                K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
                dist_coeffs = np.zeros(5, dtype=np.float32)
                if len(vals) >= 8:
                    dist_coeffs[:4] = vals[4:8]
                if len(vals) >= 9:
                    dist_coeffs[4] = vals[8]
                intrinsics[idx] = {'K': K, 'dist_coeffs': dist_coeffs}

                cams_to_vehicle[idx] = np.array(calib.extrinsic.transform, dtype=np.float64).reshape(4, 4)

                # Returned extrinsics use the OpenCV camera axes (see get_extrinsic).
                extrinsics[idx] = get_extrinsic(calib)

            # Fill in placeholders for any camera missing from the calibration.
            for i in range(num_cams):
                if intrinsics[i] is None:
                    intrinsics[i] = {
                        'K': np.eye(3, dtype=np.float32),
                        'dist_coeffs': np.zeros(5, dtype=np.float32),
                    }
                if extrinsics[i] is None:
                    extrinsics[i] = np.eye(4, dtype=np.float32)
                if cams_to_vehicle[i] is None:
                    # Without this, a missing camera gets no poses and the
                    # [5, N, 4, 4] stacking below fails on a ragged list.
                    cams_to_vehicle[i] = np.eye(4, dtype=np.float64)
            calibrated = True

        # Vehicle pose (world <- vehicle).
        vehicle_to_world = np.array(frame.pose.transform, dtype=np.float64).reshape(4, 4)
        ego_frame_poses_list.append(vehicle_to_world)

        # Per-camera pose (world <- camera, OpenCV axes).
        for i, cam_to_vehicle in enumerate(cams_to_vehicle):
            ego_cam_poses_lists[i].append(_compute_camera_to_world_opencv(vehicle_to_world, cam_to_vehicle))

    if not ego_frame_poses_list:
        raise ValueError(f"未能从 {tfrecord_path} 读取到任何帧数据")

    # Center in float64 (world coordinates can be large), then downcast to float32.
    ego_frame_poses = np.asarray(ego_frame_poses_list, dtype=np.float64)
    center_point = np.mean(ego_frame_poses[:, :3, 3], axis=0)
    ego_frame_poses[:, :3, 3] -= center_point

    ego_cam_poses = np.asarray(ego_cam_poses_lists, dtype=np.float64)  # [5, N, 4, 4]
    ego_cam_poses[:, :, :3, 3] -= center_point  # same centering as the vehicle poses

    return (
        ego_frame_poses.astype(np.float32),
        ego_cam_poses.astype(np.float32),
        extrinsics,
        intrinsics,
    )

# Homogeneous axis change from the OpenCV camera frame (x right, y down,
# z forward) to the Waymo camera frame (x forward, y left, z up).
# Column j holds the Waymo-frame direction of OpenCV axis j.
opencv2camera = np.zeros((4, 4))
opencv2camera[0, 2] = 1.0   # OpenCV +z (forward) -> Waymo +x (forward)
opencv2camera[1, 0] = -1.0  # OpenCV +x (right)   -> Waymo -y (right = -left)
opencv2camera[2, 1] = -1.0  # OpenCV +y (down)    -> Waymo -z (down = -up)
opencv2camera[3, 3] = 1.0   # homogeneous coordinate


def get_extrinsic(camera_calibration):
    """Return the camera-to-vehicle matrix expressed with OpenCV camera axes.

    ``camera_calibration.extrinsic.transform`` holds the 16 row-major entries of
    the camera-to-vehicle transform, whose camera axes are [forward, left, up].
    Right-multiplying by ``opencv2camera`` re-expresses the camera axes as
    [right, down, forward].
    """
    cam_to_vehicle = np.reshape(np.array(camera_calibration.extrinsic.transform), (4, 4))
    return cam_to_vehicle @ opencv2camera


def get_lane_shift_direction(ego_frame_poses, frame):
    """Return the unit world-frame direction of the vehicle's left (+Y) axis.

    Args:
        ego_frame_poses: Sequence of 4x4 vehicle-to-world poses (e.g. the
            ``ego_frame_poses`` returned by ``load_poses_calibration``); only
            the upper-left 3x3 rotation block is used.
        frame: Index of the pose to use; must lie in ``[0, len(ego_frame_poses))``.

    Returns:
        float32 array of shape (3,). Falls back to ``[0, 1, 0]`` if the rotated
        axis has a non-finite or near-zero norm.

    Raises:
        IndexError: If ``frame`` is out of range. This check uses ``raise``
            instead of ``assert`` so it survives ``python -O``; negative indices
            would otherwise silently wrap around.
    """
    num_frames = len(ego_frame_poses)
    if not 0 <= frame < num_frames:
        raise IndexError(f"frame {frame} out of range for {num_frames} poses")

    R_world_from_vehicle = np.asarray(ego_frame_poses[frame])[:3, :3].astype(np.float32)
    lateral_vehicle = np.array([0.0, 1.0, 0.0], dtype=np.float32)  # vehicle left is +Y
    lateral_world = (R_world_from_vehicle @ lateral_vehicle).astype(np.float32)
    norm = float(np.linalg.norm(lateral_world))
    if not np.isfinite(norm) or norm < 1e-6:
        # Degenerate rotation: fall back to world +Y.
        lateral_world = np.array([0.0, 1.0, 0.0], dtype=np.float32)
        norm = 1.0
    return (lateral_world / norm).astype(np.float32)
