from typing import List
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp
from gaussian_splatting.scene.cameras import Camera


def interpolate_between_cameras(cam_start: Camera, cam_end: Camera, num_inserted: int) -> List[Camera]:
    """
    在两个相机位姿之间进行插值
    :param cam_start: 起始 Camera
    :param cam_end:   终止 Camera
    :param num_inserted: 要插入的帧数（不包含起始和终止帧）
    :return: 总共 (num_inserted+2) 帧的 Camera 列表（包含起始和终止）
    """
    total_frames = num_inserted + 2  # 包括起始和终止
    alphas = np.linspace(0, 1, total_frames)

    # 获取旋转和平移，并确保数据格式正确
    pos_start = cam_start.T.cpu().numpy() if isinstance(cam_start.T, torch.Tensor) else np.array(cam_start.T)
    pos_end = cam_end.T.cpu().numpy() if isinstance(cam_end.T, torch.Tensor) else np.array(cam_end.T)

    rot_start = cam_start.R.cpu().numpy() if isinstance(cam_start.R, torch.Tensor) else np.array(cam_start.R)
    rot_end = cam_end.R.cpu().numpy() if isinstance(cam_end.R, torch.Tensor) else np.array(cam_end.R)

    # 对平移向量做线性插值
    pos_interp = np.linspace(pos_start, pos_end, total_frames, axis=0)

    # 对旋转做 SLERP 插值
    rot_start = R.from_matrix(rot_start)
    rot_end = R.from_matrix(rot_end)
    key_times = np.array([0, 1])
    key_rots = R.from_quat([rot_start.as_quat(), rot_end.as_quat()])
    slerp = Slerp(key_times, key_rots)
    rot_interp = slerp(alphas).as_matrix()  # shape: (total_frames, 3, 3)

    interpolated_cams = []
    for i in range(total_frames):
        # 这里我们用起始相机的所有参数来创建新的相机对象
        new_cam = Camera(
            resolution=(cam_start.image_width, cam_start.image_height),
            colmap_id=-1,  # 暂时设为 -1，后续可以重新编号
            R=torch.tensor(rot_interp[i], dtype=torch.float32, device=cam_start.data_device),
            T=torch.tensor(pos_interp[i], dtype=torch.float32, device=cam_start.data_device),
            FoVx=cam_start.FoVx,
            FoVy=cam_start.FoVy,
            depth_params=None,
            image=cam_start.image_pil,
            invdepthmap=cam_start.invdepthmap,
            image_name=f"interpolated_{i:03d}.jpg",
            uid=-1,  # 需要在后续重新编号
            trans=cam_start.trans,
            scale=cam_start.scale,
            data_device=cam_start.data_device
        )
        if i == 0 or i == total_frames - 1:
            new_cam.is_interp = False  # 起始和终止帧不是插值帧
        else:
            new_cam.is_interp = True  # 标记为插值相机
        interpolated_cams.append(new_cam)

    return interpolated_cams


def generate_linear_interpolated_camera_list(camera_list: List[Camera], num_inserted: int) -> List[Camera]:
    """
    对给定的 Camera 列表中每一对相邻相机进行插值，
    并将插值结果与原始位姿按顺序拼接成新的列表。
    """
    if len(camera_list) < 2:
        raise ValueError("至少需要两个相机位姿才能进行插值。")

    new_list = []
    new_uid = 0

    for i in range(len(camera_list) - 1):
        cam_start = camera_list[i]
        cam_end = camera_list[i + 1]

        # 对当前两个相机之间进行插值
        segment = interpolate_between_cameras(cam_start, cam_end, num_inserted)

        # 若不是第一个段落，则移除第一个元素以避免重复（因为上个段落的最后一帧与本段开始相同）
        if i > 0:
            segment = segment[1:]

        # 重新设置该段内所有帧的 uid
        for cam in segment:
            cam.uid = new_uid
            new_uid += 1

        new_list.extend(segment)

    return new_list


# 示例用法
if __name__ == "__main__":
    # 创建测试用的 Camera 数据
    cam0 = Camera(
        resolution=(720, 1280),
        colmap_id=0,
        R=torch.eye(3, dtype=torch.float32),
        T=torch.tensor([0, 0, 0], dtype=torch.float32),
        FoVx=np.array(45),
        FoVy=np.array(45),
        image=None,
        invdepthmap=None,
        image_name="cam0.jpg",
        uid=0,
        trans=np.array([0.0, 0.0, 0.0]),
        scale=1.0,
        data_device="cuda"
    )

    cam1 = Camera(
        resolution=(720, 1280),
        colmap_id=1,
        R=torch.eye(3, dtype=torch.float32),
        T=torch.tensor([1, 0, 0], dtype=torch.float32),
        FoVx=np.array(45),
        FoVy=np.array(45),
        image=None,
        invdepthmap=None,
        image_name="cam1.jpg",
        uid=1,
        trans=np.array([0.0, 0.0, 0.0]),
        scale=1.0,
        data_device="cuda"
    )

    cam2 = Camera(
        resolution=(720, 1280),
        colmap_id=2,
        R=torch.eye(3, dtype=torch.float32),
        T=torch.tensor([1, 1, 0], dtype=torch.float32),
        FoVx=np.array(45),
        FoVy=np.array(45),
        image=None,
        invdepthmap=None,
        image_name="cam2.jpg",
        uid=2,
        trans=np.array([0.0, 0.0, 0.0]),
        scale=1.0,
        data_device="cuda"
    )

    original_list = [cam0, cam1, cam2]
    # 假设我们希望在每对相邻相机之间插入 3 帧
    new_camera_list = generate_linear_interpolated_camera_list(original_list, num_inserted=3)

    # 输出新列表的 uid 和 T 以验证顺序
    for cam in new_camera_list:
        print(f"uid: {cam.uid}, T: {cam.T.cpu().numpy()}")
