# from typing import List
# import numpy as np
# import torch
# from scipy.interpolate import make_interp_spline
# from scipy.spatial.transform import Rotation as R
# from gaussian_splatting.scene.cameras import Camera


# def generate_bspline_interpolated_camera_list(camera_list: List[Camera], num_inserted: int, order: int = 3) -> List[Camera]:
#     """
#     对给定的 Camera 列表使用 n 阶 B 样条进行位姿插值，
#     生成新的 Camera 列表。新列表的总帧数为：
#         total_frames = (len(camera_list) - 1) * num_inserted + len(camera_list)
    
#     平移部分直接在欧几里得空间中利用 B 样条插值，
#     旋转部分先转换为旋转向量（rotvec），再进行 B 样条插值，
#     最后将结果转换回旋转矩阵。
    
#     :param camera_list: 原始的 Camera 列表
#     :param num_inserted: 每对相邻相机之间插入的帧数（不包括原始帧）
#     :param order: B 样条插值的阶数（n阶）
#     :return: 插值后的 Camera 列表
#     """
#     n = len(camera_list)
#     if n < 2:
#         raise ValueError("至少需要两个相机位姿才能进行插值。")

#     # 总帧数：原始帧和插入帧
#     total_frames = (n - 1) * num_inserted + n

#     # 定义控制点的参数值及插值时采样的参数值
#     t_control = np.linspace(0, 1, n)
#     t_new = np.linspace(0, 1, total_frames)

#     # ---------------------------
#     # 插值平移（Translation）
#     # ---------------------------
#     T_points = []
#     for cam in camera_list:
#         pos = cam.T.cpu().numpy() if isinstance(cam.T, torch.Tensor) else np.array(cam.T)
#         T_points.append(pos)
#     T_points = np.array(T_points)  # shape: (n, 3)

#     # 阶数不能超过 (n-1)
#     k = min(order, n - 1)
#     spline_T = make_interp_spline(t_control, T_points, k=k, axis=0)
#     T_interp = spline_T(t_new)  # shape: (total_frames, 3)

#     # ---------------------------
#     # 插值旋转（Rotation）
#     # ---------------------------
#     # 先将旋转矩阵转换为旋转向量表示
#     rotvec_points = []
#     for cam in camera_list:
#         R_mat = cam.R.cpu().numpy() if isinstance(cam.R, torch.Tensor) else np.array(cam.R)
#         rotvec = R.from_matrix(R_mat).as_rotvec()
#         rotvec_points.append(rotvec)
#     rotvec_points = np.array(rotvec_points)  # shape: (n, 3)

#     spline_rot = make_interp_spline(t_control, rotvec_points, k=k, axis=0)
#     rotvec_interp = spline_rot(t_new)  # shape: (total_frames, 3)
#     R_interp = R.from_rotvec(rotvec_interp).as_matrix()  # shape: (total_frames, 3, 3)

#     # ---------------------------
#     # 根据插值结果创建新的 Camera 对象
#     # ---------------------------
#     new_cameras = []
#     new_uid = 0
#     # 这里其余参数（例如分辨率、FoV等）均使用原始相机列表中第一个相机的参数
#     for i in range(total_frames):
#         new_cam = Camera(
#             resolution=(camera_list[0].image_width, camera_list[0].image_height),
#             colmap_id=-1,  # 新生成的相机暂时设置为 -1，后续可重新编号
#             R=torch.tensor(R_interp[i], dtype=torch.float32, device=camera_list[0].data_device),
#             T=torch.tensor(T_interp[i], dtype=torch.float32, device=camera_list[0].data_device),
#             FoVx=camera_list[0].FoVx,
#             FoVy=camera_list[0].FoVy,
#             depth_params=None,
#             image=camera_list[0].image_pil,
#             invdepthmap=camera_list[0].invdepthmap,
#             image_name=f"interpolated_{i:03d}.jpg",
#             uid=new_uid,
#             trans=camera_list[0].trans,
#             scale=camera_list[0].scale,
#             data_device=camera_list[0].data_device
#         )
#         new_uid += 1
#         new_cameras.append(new_cam)

#     return new_cameras


from typing import List
import numpy as np
import torch
from scipy.interpolate import make_interp_spline
from scipy.spatial.transform import Rotation as R
from gaussian_splatting.scene.cameras import Camera


def generate_bspline_interpolated_camera_list(
    camera_list: List[Camera], num_inserted: int, order: int = 3
) -> List[Camera]:
    """
    对给定的 Camera 列表使用 n 阶 B 样条进行位姿插值，
    生成新的 Camera 列表。新列表的总帧数为：
        total_frames = (len(camera_list) - 1) * num_inserted + len(camera_list)

    平移部分直接在欧几里得空间中利用 B 样条插值，
    旋转部分先转换为旋转向量（rotvec），再进行 B 样条插值，
    最后将结果转换回旋转矩阵。

    同时，对第一台相机的 FoV 及分辨率计算得到焦距 (focal_x, focal_y) 和主点 (cx, cy)，
    并赋值给所有生成的新相机对象，包括 principal_x, principal_y。
    """
    n = len(camera_list)
    if n < 2:
        raise ValueError("至少需要两个相机位姿才能进行插值。")

    # 计算原始相机的焦距（以像素为单位）
    first_cam = camera_list[0]
    if hasattr(first_cam, 'fx') and hasattr(first_cam, 'fy'):
        fx = float(first_cam.fx)
        fy = float(first_cam.fy)
        width = first_cam.image_width
        height = first_cam.image_height
    else:
        width, height = first_cam.image_width, first_cam.image_height
        fx = (width * 0.5) / np.tan(first_cam.FoVx * 0.5)
        fy = (height * 0.5) / np.tan(first_cam.FoVy * 0.5)
    # 主点假定在图像中心
    cx = width * 0.5
    cy = height * 0.5

    # 总帧数：原始帧和插入帧
    total_frames = (n - 1) * num_inserted + n

    # 定义控制点和新采样点的参数值
    t_control = np.linspace(0, 1, n)
    t_new = np.linspace(0, 1, total_frames)

    # ---------------------------
    # 插值平移（Translation）
    # ---------------------------
    T_points = [
        cam.T.cpu().numpy() if isinstance(cam.T, torch.Tensor) else np.array(cam.T)
        for cam in camera_list
    ]
    T_points = np.array(T_points)

    k = min(order, n - 1)
    spline_T = make_interp_spline(t_control, T_points, k=k, axis=0)
    T_interp = spline_T(t_new)

    # ---------------------------
    # 插值旋转（Rotation）
    # ---------------------------
    rotvec_points = [
        R.from_matrix(
            cam.R.cpu().numpy() if isinstance(cam.R, torch.Tensor) else np.array(cam.R)
        ).as_rotvec()
        for cam in camera_list
    ]
    rotvec_points = np.array(rotvec_points)

    spline_rot = make_interp_spline(t_control, rotvec_points, k=k, axis=0)
    rotvec_interp = spline_rot(t_new)
    R_interp = R.from_rotvec(rotvec_interp).as_matrix()

    # ---------------------------
    # 创建新的 Camera 对象并赋予内参信息
    # ---------------------------
    new_cameras: List[Camera] = []
    new_uid = 0
    for i in range(total_frames):
        new_cam = Camera(
            resolution=(first_cam.image_width, first_cam.image_height),
            colmap_id=-1,
            R=torch.tensor(R_interp[i], dtype=torch.float32, device=first_cam.data_device),
            T=torch.tensor(T_interp[i], dtype=torch.float32, device=first_cam.data_device),
            FoVx=first_cam.FoVx,
            FoVy=first_cam.FoVy,
            depth_params=None,
            image=first_cam.image_pil,
            invdepthmap=first_cam.invdepthmap,
            image_name=f"interpolated_{i:03d}.jpg",
            uid=new_uid,
            trans=first_cam.trans,
            scale=first_cam.scale,
            data_device=first_cam.data_device
        )
        # 设置焦距、主点及 principal 属性
        new_cam.focal_x = fx
        new_cam.focal_y = fy
        new_cam.cx = cx
        new_cam.cy = cy
        new_cam.principal_x = cx
        new_cam.principal_y = cy

        new_cameras.append(new_cam)
        new_uid += 1

    return new_cameras
