import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from waymo_open_dataset import dataset_pb2, label_pb2
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
from .scheduler import NovelPoseScheduler

try:
    # 与参考代码保持一致的依赖方式
    from webdataset import WebDataset, non_empty
except Exception:
    WebDataset = None  # type: ignore
    non_empty = None   # type: ignore
from waymo_open_dataset import dataset_pb2, label_pb2
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
from .scheduler import NovelPoseScheduler


def _get_sample(url: Union[str, Path]) -> Dict[str, Any]:
    """
    Read the first sample from a WebDataset tar file.

    Simplified variant of ``get_sample``; the original comment points to
    cosmos-drive-dreams-toolkits/utils/wds_utils.py as the reference.

    Args:
        url: Location of the tar file. A ``Path`` is converted to its POSIX
            string form before being handed to ``WebDataset``.

    Returns:
        The first decoded sample of the dataset (a dict).

    Raises:
        ImportError: ``webdataset`` could not be imported at module load.
        RuntimeError: The dataset yields no sample at all, or the first
            sample is not a dict.
    """
    if WebDataset is None or non_empty is None:
        raise ImportError(
            "[InsertionInfoProvider] 需要安装 webdataset 才能读取 all_object_info_insertion。"
        )

    if isinstance(url, Path):
        url = url.as_posix()

    ds = WebDataset(url, nodesplitter=non_empty, workersplitter=None, shardshuffle=False).decode()

    # Use a sentinel default so an empty shard is reported with a clear
    # message instead of leaking a bare, message-less StopIteration.
    missing = object()
    sample = next(iter(ds), missing)
    if sample is missing:
        raise RuntimeError(f"[InsertionInfoProvider] WebDataset 中没有任何 sample: {url}")
    if not isinstance(sample, dict):
        # WebDataset normally yields dicts; guard against anything else.
        raise RuntimeError(f"[InsertionInfoProvider] WebDataset 返回非字典类型: {type(sample)}")
    return sample


def object_tfm_to_heading(tfm: Any) -> np.ndarray:
    """
    Return the normalised first column of a transform's upper-left 3x3 block.

    Args:
        tfm: Transform matrix, either a nested list (converted to an array)
            or an object that supports ``tfm[:3, 0]`` indexing.

    Returns:
        The 3-vector ``tfm[:3, 0]`` divided by its norm plus ``1e-8`` (the
        epsilon keeps a zero vector from dividing by zero).
    """
    matrix = np.array(tfm) if isinstance(tfm, list) else tfm
    x_axis = matrix[:3, 0]
    return x_axis / (np.linalg.norm(x_axis) + 1e-8)


def _orthonormalize_rotation(R_in: np.ndarray) -> np.ndarray:
    """
    Project a 3x3 matrix onto a right-handed orthonormal matrix via SVD.

    Args:
        R_in: Matrix to orthonormalise.

    Returns:
        ``U @ Vt`` from the SVD of ``R_in``; when that product has a negative
        determinant the last column of ``U`` is negated first so the result
        is right-handed.
    """
    left, _singular_values, right_t = np.linalg.svd(R_in)
    candidate = np.matmul(left, right_t)
    if np.linalg.det(candidate) < 0:
        # Reflection detected: flip the last left-singular vector and rebuild.
        left[:, -1] = -left[:, -1]
        candidate = np.matmul(left, right_t)
    return candidate


def fix_static_objects(all_object_info: Dict[str, Any]) -> Dict[str, Any]:
    """
    Replace the size and pose of every static object with clip-wide averages.

    An object entry counts as static when its ``object_is_moving`` flag is
    falsy; a missing flag is treated as moving. Frame keys that start with
    ``__`` are skipped.

    Steps, in order:
      1. Average ``object_lwh`` per tracking id and write it back.
      2. Collect ``object_to_world`` poses and their headings per tracking id.
      3. Drop poses whose heading has a dot product of 0.7 or less with the
         mean heading, orthonormalising the rotation of the kept ones.
      4. Average the kept poses (mean translation; rotation rebuilt from the
         mean front and up directions) and write the result back.

    Args:
        all_object_info: Mapping ``frame_key -> {tracking_id -> object_info}``.
            It is modified in place.

    Returns:
        The same ``all_object_info`` object.
    """

    def iter_static_objects():
        # Yield (tracking_id, object_info) for every static entry of every
        # real frame, in dictionary order.
        for frame_key, frame_objects in all_object_info.items():
            if str(frame_key).startswith('__'):
                continue
            for tracking_id, object_info in frame_objects.items():
                if not object_info.get('object_is_moving', True):
                    yield tracking_id, object_info

    # 1) Gather the sizes of static objects and average them per tracking id.
    lwh_samples: Dict[str, List[List[float]]] = {}
    for tracking_id, object_info in iter_static_objects():
        lwh_samples.setdefault(tracking_id, []).append(object_info['object_lwh'])

    mean_lwh: Dict[str, np.ndarray] = {
        tracking_id: np.mean(np.array(samples, dtype=np.float32), axis=0)
        for tracking_id, samples in lwh_samples.items()
    }

    # Write the averaged size back to every static entry.
    for tracking_id, object_info in iter_static_objects():
        if tracking_id in mean_lwh:
            object_info['object_lwh'] = mean_lwh[tracking_id].tolist()

    # 2) Gather poses and headings of static objects.
    pose_samples: Dict[str, List[np.ndarray]] = {}
    heading_samples: Dict[str, List[np.ndarray]] = {}
    for tracking_id, object_info in iter_static_objects():
        raw_pose = object_info['object_to_world']
        pose_samples.setdefault(tracking_id, []).append(np.array(raw_pose, dtype=np.float32))
        heading_samples.setdefault(tracking_id, []).append(object_tfm_to_heading(raw_pose))

    # Mean heading per object, used as the reference for outlier rejection.
    mean_heading: Dict[str, np.ndarray] = {}
    for tracking_id, samples in heading_samples.items():
        averaged = np.mean(np.stack(samples, axis=0), axis=0)
        mean_heading[tracking_id] = averaged / (np.linalg.norm(averaged) + 1e-8)

    # 3) Reject outliers and orthonormalise the rotation of the kept poses.
    threshold = 0.7
    inlier_poses: Dict[str, List[np.ndarray]] = {}
    for tracking_id, poses in pose_samples.items():
        reference = mean_heading[tracking_id]
        for pose in poses:
            if float(np.dot(object_tfm_to_heading(pose), reference)) > threshold:
                cleaned = np.array(pose, dtype=np.float32)
                cleaned[:3, :3] = _orthonormalize_rotation(cleaned[:3, :3])
                inlier_poses.setdefault(tracking_id, []).append(cleaned)

    # 4) Mean pose: averaged translation, rotation from mean front/up axes.
    mean_pose: Dict[str, np.ndarray] = {}
    for tracking_id, poses in inlier_poses.items():
        if not poses:
            continue
        translation_mean = np.mean(np.stack([pose[:3, 3] for pose in poses], axis=0), axis=0)

        front = np.mean(np.stack([pose[:3, 0] for pose in poses], axis=0), axis=0)
        up = np.mean(np.stack([pose[:3, 2] for pose in poses], axis=0), axis=0)
        front /= (np.linalg.norm(front) + 1e-8)
        up /= (np.linalg.norm(up) + 1e-8)
        left = np.cross(up, front)
        left /= (np.linalg.norm(left) + 1e-8)

        # The averaged axes are not guaranteed orthogonal, so orthonormalise.
        rotation = _orthonormalize_rotation(np.column_stack((front, left, up)))

        averaged_pose = np.eye(4, dtype=np.float32)
        averaged_pose[:3, :3] = rotation
        averaged_pose[:3, 3] = translation_mean
        mean_pose[tracking_id] = averaged_pose

    # Write the averaged pose back to every static entry that has one.
    for tracking_id, object_info in iter_static_objects():
        if tracking_id in mean_pose:
            object_info['object_to_world'] = mean_pose[tracking_id].tolist()

    return all_object_info


def interpolate_pose(prev_pose: Any, next_pose: Any, t: float) -> Any:
    """
    Interpolate between two 4x4 poses.

    The translation is blended linearly and the rotation with spherical
    linear interpolation (``Slerp``) between key times 0 and 1.

    Args:
        prev_pose: Pose at ``t = 0`` (nested list or array-like).
        next_pose: Pose at ``t = 1`` (nested list or array-like).
        t: Interpolation parameter passed to ``Slerp``.

    Returns:
        The interpolated pose as float32 data: a nested list when
        ``prev_pose`` is a list, otherwise a ``(4, 4)`` array.
    """
    return_as_list = isinstance(prev_pose, list)
    start = np.array(prev_pose, dtype=np.float32)
    end = np.array(next_pose, dtype=np.float32)

    blended_translation = (1.0 - t) * start[:3, 3] + t * end[:3, 3]

    # Go through quaternions for both key rotations, then slerp between them.
    key_quats = [R.from_matrix(pose[:3, :3]).as_quat() for pose in (start, end)]
    blended_rotation = Slerp([0, 1], R.from_quat(key_quats))(t)

    result = np.eye(4, dtype=np.float32)
    result[:3, :3] = blended_rotation.as_matrix().astype(np.float32)
    result[:3, 3] = blended_translation.astype(np.float32)

    return result.tolist() if return_as_list else result


def interpolate_bbox(all_object_info: Dict[str, Any], valid_frame_ids: List[int]) -> Dict[str, Any]:
    """
    Produce object info for every requested frame, interpolating gaps.

    For each id in ``valid_frame_ids`` (keys are formatted as
    ``"{id:06d}.all_object_info.json"``):
      * if the frame exists it is reused as-is (same object, not a copy);
      * otherwise the nearest existing frames below and above are located;
      * if no frame exists above (up to ``max(valid_frame_ids)``) but one
        exists below, the lower frame is reused as-is;
      * otherwise objects present in both neighbours get an interpolated
        pose and size, with ``object_is_moving`` and ``object_type`` taken
        from the lower neighbour. Objects present in only one neighbour are
        dropped, and a missing neighbour yields an empty frame.

    Args:
        all_object_info: Mapping ``frame_key -> {tracking_id -> object_info}``.
        valid_frame_ids: Frame indices to produce.

    Returns:
        A new dict holding one entry per requested frame id.
    """

    def key_of(index: int) -> str:
        return f"{index:06d}.all_object_info.json"

    last_valid = max(valid_frame_ids) if len(valid_frame_ids) > 0 else -1
    known_keys = set(all_object_info.keys())
    result: Dict[str, Any] = {}

    for frame_id in valid_frame_ids:
        frame_key = key_of(frame_id)
        if frame_key in known_keys:
            result[frame_key] = all_object_info[frame_key]
            continue

        # Walk outwards to the closest existing frame on each side.
        prev_id = frame_id
        while key_of(prev_id) not in known_keys and prev_id >= 0:
            prev_id -= 1
        next_id = frame_id
        while key_of(next_id) not in known_keys and next_id <= last_valid:
            next_id += 1

        # Nothing after this frame: hold the last known frame.
        if next_id > last_valid and prev_id >= 0:
            result[frame_key] = all_object_info[key_of(prev_id)]
            continue

        before = all_object_info.get(key_of(prev_id), {})
        after = all_object_info.get(key_of(next_id), {})
        shared_ids = set(before.keys()) & set(after.keys())

        span = max(1, (next_id - prev_id))
        t = float(frame_id - prev_id) / float(span)

        blended: Dict[str, Any] = {}
        for tracking_id in shared_ids:
            source = before[tracking_id]
            target = after[tracking_id]

            pose = interpolate_pose(source['object_to_world'], target['object_to_world'], t)

            source_lwh = np.array(source['object_lwh'], dtype=np.float32)
            target_lwh = np.array(target['object_lwh'], dtype=np.float32)
            lwh = (1.0 - t) * source_lwh + t * target_lwh

            blended[tracking_id] = {
                'object_to_world': pose.tolist() if isinstance(pose, np.ndarray) else pose,
                'object_lwh': lwh.tolist(),
                'object_is_moving': source['object_is_moving'],
                'object_type': source['object_type'],
            }

        result[frame_key] = blended

    return result


class ObjectInfoProcessor:
    """
    Loads and post-processes the per-frame object boxes of one clip.

    On construction the boxes are read from the clip's Waymo TFRecord,
    static objects are stabilised with ``fix_static_objects`` and the key
    frames are expanded to every render frame with ``interpolate_bbox``.
    The frame count is supplied by the caller rather than re-derived here.
    """

    def __init__(
        self,
        input_root: Path,
        clip_id: str,
        num_frames: int,
    ) -> None:
        """
        Args:
            input_root: Directory containing
                ``segment-<clip_id>_with_camera_labels.tfrecord``.
            clip_id: Clip identifier (stored as a string).
            num_frames: Number of source (key) frames in the clip.
        """
        self.input_root = Path(input_root)
        self.clip_id = str(clip_id)
        self._num_frames = num_frames

        # Per the original note, source frames are 10 Hz and rendering is
        # 30 Hz, so every source frame maps to `step` render frames.
        step = 3
        total_render_frames = self._num_frames * step

        # Render ids cover every frame; key ids are the ones backed by a
        # source frame (every `step`-th index).
        self._render_frame_ids: List[int] = list(range(total_render_frames))
        self._key_frame_ids: List[int] = list(range(0, total_render_frames, step))

        tfrecord_path = os.path.join(str(self.input_root), f"segment-{self.clip_id}_with_camera_labels.tfrecord")

        raw_info = self._read_objects_from_tfrecord(
            tfrecord_path=tfrecord_path,
            clip_id=self.clip_id,
            index_scale_ratio=step,
        )
        stabilised_info = fix_static_objects(raw_info)
        self._all_object_info: Dict[str, Any] = interpolate_bbox(stabilised_info, self._render_frame_ids)

    def get_object_info(self) -> Dict[str, Any]:
        """Return the processed ``frame_key -> {tracking_id -> object_info}`` dict."""
        return self._all_object_info

    def get_render_frame_ids(self) -> List[int]:
        """Return the list of render frame indices."""
        return self._render_frame_ids

    @staticmethod
    def _label_to_object_entry(
        label: Any,
        vehicle_to_world_abs: np.ndarray,
        world_abs_to_norm_tfm: np.ndarray,
        min_moving_speed: float,
    ) -> Optional[Any]:
        """
        Convert one laser label into ``(tracking_id, object_info)``.

        Returns ``None`` when the label is skipped: no type, a type other
        than vehicle / pedestrian / cyclist, no (or an empty)
        ``camera_synced_box``, or no id.
        """
        lb_type = getattr(label, "type", None)
        if lb_type is None:
            return None

        accepted_types = (
            getattr(label_pb2.Label.Type, "TYPE_VEHICLE", -1),
            getattr(label_pb2.Label.Type, "TYPE_PEDESTRIAN", -1),
            getattr(label_pb2.Label.Type, "TYPE_CYCLIST", -1),
        )
        if lb_type not in accepted_types:
            return None

        box = getattr(label, "camera_synced_box", None)
        if box is None or (hasattr(box, "ByteSize") and box.ByteSize() == 0):
            return None

        object_id = getattr(label, "id", None)
        if object_id is None:
            return None

        # Box centre as a homogeneous column vector, moved to absolute world.
        center_in_vehicle = np.array(
            [
                getattr(box, "center_x", 0.0),
                getattr(box, "center_y", 0.0),
                getattr(box, "center_z", 0.0),
                1.0,
            ],
            dtype=np.float32,
        ).reshape((4, 1))
        center_in_world_abs = vehicle_to_world_abs @ center_in_vehicle

        # The box heading is applied as the z component of an "xyz" Euler
        # rotation, then composed with the vehicle rotation.
        heading = float(getattr(box, "heading", 0.0))
        rot_vehicle = R.from_euler("xyz", [0.0, 0.0, heading], degrees=False).as_matrix().astype(np.float32)
        rot_world_abs = (vehicle_to_world_abs[:3, :3] @ rot_vehicle).astype(np.float32)

        obj_to_world_abs = np.eye(4, dtype=np.float32)
        obj_to_world_abs[:3, :3] = rot_world_abs
        obj_to_world_abs[:3, 3] = center_in_world_abs.flatten()[:3]

        # Absolute world -> centred ("normalised") world.
        obj_to_world_norm = world_abs_to_norm_tfm @ obj_to_world_abs

        lwh = np.array(
            [
                float(getattr(box, "length", 0.0)),
                float(getattr(box, "width", 0.0)),
                float(getattr(box, "height", 0.0)),
            ],
            dtype=np.float32,
        )

        metadata = getattr(label, "metadata", None)
        if metadata is None:
            speed = 0.0
        else:
            squared_speed = (
                float(getattr(metadata, "speed_x", 0.0)) ** 2
                + float(getattr(metadata, "speed_y", 0.0)) ** 2
                + float(getattr(metadata, "speed_z", 0.0)) ** 2
            )
            speed = float(np.sqrt(squared_speed))
        is_moving = bool(speed > min_moving_speed)

        if hasattr(label_pb2.Label.Type, "Name"):
            obj_type = label_pb2.Label.Type.Name(lb_type)
        else:
            obj_type = str(int(lb_type))

        return str(object_id), {
            "object_to_world": obj_to_world_norm.tolist(),
            "object_lwh": lwh.tolist(),
            "object_is_moving": is_moving,
            "object_type": obj_type,
        }

    @staticmethod
    def _read_objects_from_tfrecord(tfrecord_path: str, clip_id: str, index_scale_ratio: int) -> Dict[str, Any]:
        """
        Parse object boxes from a Waymo TFRecord into an all_object_info dict.

        Two passes are made over the record:
          1. every frame's bytes and vehicle pose are cached, and the mean
             vehicle position over the clip is computed as the centre point;
          2. each accepted label is converted to an object entry whose
             ``object_to_world`` is expressed relative to that centre point.

        Args:
            tfrecord_path: Path of the TFRecord file.
            clip_id: Stored under the ``"__key__"`` entry of the result.
            index_scale_ratio: Multiplier mapping a source frame index to its
                output frame index (values below 1 are clamped to 1).

        Returns:
            ``{"__key__": clip_id, "<idx:06d>.all_object_info.json": {...}}``.
            Only the ``"__key__"`` entry is returned when the file is missing,
            contains no frames, or any exception is raised while processing
            (the error is printed, not re-raised).
        """
        if not Path(tfrecord_path).exists():
            print(f"[ObjectInfoProcessor] 错误: TFRecord 文件未找到: {tfrecord_path}")
            return {"__key__": clip_id}

        try:
            records = tf.data.TFRecordDataset(str(tfrecord_path), compression_type="")

            # Pass 1: cache raw frames and collect the vehicle poses.
            serialized_frames: List[bytes] = []
            vehicle_poses: List[np.ndarray] = []
            for record in records:
                raw_frame = bytearray(record.numpy())
                serialized_frames.append(raw_frame)
                parsed = dataset_pb2.Frame()
                parsed.ParseFromString(raw_frame)
                vehicle_poses.append(np.array(parsed.pose.transform, dtype=np.float32).reshape((4, 4)))

            if not vehicle_poses:
                print(f"[ObjectInfoProcessor] 错误: TFRecord 文件为空: {tfrecord_path}")
                return {"__key__": clip_id}

            stacked_poses = np.asarray(vehicle_poses, dtype=np.float32)
            center_point = np.mean(stacked_poses[:, :3, 3], axis=0)

            # Pure translation that moves the clip centre to the origin.
            world_abs_to_norm_tfm = np.eye(4, dtype=np.float32)
            world_abs_to_norm_tfm[:3, 3] = -center_point

            # Pass 2: convert the labels of every frame.
            sample: Dict[str, Any] = {"__key__": clip_id}
            min_moving_speed = 0.2
            frame_stride = max(1, int(index_scale_ratio))

            for frame_idx, frame_bytes in enumerate(serialized_frames):
                frame = dataset_pb2.Frame()
                frame.ParseFromString(frame_bytes)

                frame_objects_info: Dict[str, Any] = {}
                for label in frame.laser_labels:
                    entry = ObjectInfoProcessor._label_to_object_entry(
                        label,
                        vehicle_poses[frame_idx],
                        world_abs_to_norm_tfm,
                        min_moving_speed,
                    )
                    if entry is None:
                        continue
                    tracking_id, object_info = entry
                    frame_objects_info[tracking_id] = object_info

                sample[f"{frame_idx * frame_stride:06d}.all_object_info.json"] = frame_objects_info

            return sample

        except Exception as e:
            print(f"[ObjectInfoProcessor] 处理 TFRecord 失败: {tfrecord_path}")
            print(f"错误: {e}")
            return {"__key__": clip_id}


class InsertionInfoProvider:
    """
    Reads the insertion-object info (``all_object_info_insertion``) of a clip.

    The data is the first sample of the WebDataset tar file
    ``<rds_root>/all_object_info_insertion/<clip_id>.tar`` and has the shape
    ``{frame_key -> {tracking_id -> object_info}}``. When ``waymo_root`` is
    given, the ``insertion_0`` poses are additionally shifted into the
    centred world frame derived from the clip's TFRecord.
    """

    # Default root of the pre-processed RDS data (kept in line with the
    # reference script, per the original note).
    DEFAULT_RDS_ROOT = Path("/data/dataset/waymo_rds/validation")

    def __init__(
        self,
        clip_id: str,
        rds_root: Optional[Path] = None,
        waymo_root: Optional[Path] = None,
    ) -> None:
        """
        Args:
            clip_id: Clip identifier (stored as a string).
            rds_root: Root holding ``all_object_info_insertion``; defaults to
                ``DEFAULT_RDS_ROOT``.
            waymo_root: Directory of the raw Waymo TFRecords, used to align
                coordinate frames. Alignment is skipped when ``None``.
        """
        self.clip_id = str(clip_id)
        self._rds_root = self.DEFAULT_RDS_ROOT if rds_root is None else Path(rds_root)
        self._waymo_root = None if waymo_root is None else Path(waymo_root)
        self._all_object_info_insertion: Dict[str, Any] = {}
        self._load()

    @staticmethod
    def _compute_center_point_from_tfrecord(tfrecord_path: Path) -> Optional[np.ndarray]:
        """
        Return the mean vehicle position over all frames of a Waymo TFRecord.

        Returns ``None`` (after printing the reason) when the file is
        missing, contains no frames, or reading it raises.
        """
        if not tfrecord_path.exists():
            print(f"[InsertionInfoProvider] 计算中心点失败，TFRecord 未找到: {tfrecord_path}")
            return None

        try:
            records = tf.data.TFRecordDataset(str(tfrecord_path), compression_type="")
            vehicle_poses: List[np.ndarray] = []
            for record in records:
                parsed = dataset_pb2.Frame()
                parsed.ParseFromString(bytearray(record.numpy()))
                vehicle_poses.append(np.array(parsed.pose.transform, dtype=np.float32).reshape((4, 4)))

            if not vehicle_poses:
                print(f"[InsertionInfoProvider] 计算中心点失败，TFRecord 为空: {tfrecord_path}")
                return None

            stacked_poses = np.asarray(vehicle_poses, dtype=np.float32)
            return np.mean(stacked_poses[:, :3, 3], axis=0)
        except Exception as e:
            print(f"[InsertionInfoProvider] 计算中心点失败: {tfrecord_path}")
            print(f"错误: {e}")
            return None

    def _iter_insertion_poses(self):
        """
        Yield ``(ins_info, pose)`` for each frame's usable ``insertion_0`` entry.

        Skips metadata keys (starting with ``__``), frames that are not
        dicts, frames without ``insertion_0``, and entries whose
        ``object_to_world`` cannot be read as a 4x4 float32 array.
        """
        for frame_key, objs in self._all_object_info_insertion.items():
            if str(frame_key).startswith("__") or not isinstance(objs, dict):
                continue
            ins_info = objs.get("insertion_0")
            if ins_info is None:
                continue
            try:
                pose = np.asarray(ins_info.get("object_to_world", None), dtype=np.float32)
            except Exception:
                continue
            if pose.shape != (4, 4):
                continue
            yield ins_info, pose

    def _align_to_normalized_world(self) -> None:
        """
        Shift ``insertion_0`` poses into the centred world frame if needed.

        The rest of the pipeline works in coordinates with the clip centre
        subtracted, while the insertion data may still be in absolute world
        coordinates (per the original note). A scale heuristic avoids
        shifting data that already looks centred.
        """
        if self._waymo_root is None:
            return

        tfrecord_path = self._waymo_root / f"segment-{self.clip_id}_with_camera_labels.tfrecord"
        center_point = self._compute_center_point_from_tfrecord(tfrecord_path)
        if center_point is None:
            return

        # Mean insertion position, used to guess the current coordinate frame.
        centers: List[np.ndarray] = [pose[:3, 3] for _, pose in self._iter_insertion_poses()]
        if len(centers) == 0:
            return

        mean_pos = np.mean(np.stack(centers, axis=0), axis=0)
        norm_center = float(np.linalg.norm(center_point))
        norm_mean = float(np.linalg.norm(mean_pos))

        # Positions much smaller than the centre offset suggest the data is
        # already centred; leave it untouched to avoid a double shift.
        if norm_center > 1e-4 and norm_mean < 0.1 * norm_center:
            return

        world_abs_to_norm_tfm = np.eye(4, dtype=np.float32)
        world_abs_to_norm_tfm[:3, 3] = -center_point
        for ins_info, pose_abs in self._iter_insertion_poses():
            ins_info["object_to_world"] = (world_abs_to_norm_tfm @ pose_abs).tolist()

    def _load(self) -> None:
        """
        Load the insertion info and align it; fall back to ``{}`` on failure.

        A missing tar file, a non-dict sample, or any exception while
        loading or aligning leaves the stored info as an empty dict (the
        reason is printed).
        """
        tar_path = self._rds_root / "all_object_info_insertion" / f"{self.clip_id}.tar"
        if not tar_path.exists():
            print(f"[InsertionInfoProvider] 插入信息文件未找到: {tar_path}")
            self._all_object_info_insertion = {}
            return

        try:
            sample = _get_sample(tar_path)
            # The reference code iterates the sample's items() directly, so
            # a dict is required here as well.
            if not isinstance(sample, dict):
                print(f"[InsertionInfoProvider] 非法的 sample 类型: {type(sample)} 来自 {tar_path}")
                self._all_object_info_insertion = {}
                return
            self._all_object_info_insertion = sample
            self._align_to_normalized_world()
        except Exception as e:
            print(f"[InsertionInfoProvider] 读取插入信息失败: {tar_path}")
            print(f"错误: {e}")
            self._all_object_info_insertion = {}

    def get_insertion_info(self) -> Dict[str, Any]:
        """
        Return the loaded dict, shaped ``{frame_key -> {'insertion_0': object_info, ...}}``.
        """
        return self._all_object_info_insertion


