'''
希望生成一组“时间均匀 / 信息丰富”的高帧率图像序列，其中：
快速变化处插帧多（避免 event 丢失）；
静态/变化小区域插帧少（避免生成虚假 event）
不只是时间均匀，而是 信息密度均匀。
'''

from typing import List
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp
from gaussian_splatting.scene.cameras import Camera
from interpolation.optimize_pose_linear import interpolate_between_cameras

def compute_pose_delta(cam1: Camera, cam2: Camera, alpha=1.0, beta=1.0) -> float:
    """计算两个Camera之间的加权位姿变化度"""
    # switch GPU tensor to numpy
    R1 = cam1.R.detach().cpu().numpy() if isinstance(cam1.R, torch.Tensor) else cam1.R
    R2 = cam2.R.detach().cpu().numpy() if isinstance(cam2.R, torch.Tensor) else cam2.R
    T1 = cam1.T.detach().cpu().numpy() if isinstance(cam1.T, torch.Tensor) else cam1.T
    T2 = cam2.T.detach().cpu().numpy() if isinstance(cam2.T, torch.Tensor) else cam2.T

    R_rel = R1.T @ R2
    trace = np.trace(R_rel)
    cos_theta = np.clip((trace - 1) / 2, -1.0, 1.0)
    rot_dist = np.arccos(cos_theta)

    trans_dist = np.linalg.norm(T2 - T1)

    return alpha * rot_dist + beta * trans_dist


def generate_adaptive_interpolated_camera_list(origin_viewpoints: List[Camera], interp_multiplier: float) -> List[Camera]:
    """
    根据相邻相机之间的位姿变化度自适应地插入中间位姿，
    插帧总数 = 原始相机数量 × interp_multiplier。
    
    :param origin_viewpoints: 原始稀疏相机列表
    :param interp_multiplier: 插帧倍数（如 3.0 表示插成原来的 3 倍）
    :return: 插值后的相机列表
    """
    if len(origin_viewpoints) < 2:
        raise ValueError("至少需要两个相机位姿进行插值。")

    n = len(origin_viewpoints)
    pose_deltas = []

    # Step 1: 计算每一对相机之间的 pose delta（加权旋转+平移）
    for i in range(n - 1):
        delta = compute_pose_delta(origin_viewpoints[i], origin_viewpoints[i + 1])
        pose_deltas.append(delta)

    pose_deltas = np.array(pose_deltas)
    total_motion = np.sum(pose_deltas)

    # Step 2: 计算总目标帧数
    target_total_frames = int(round(n * interp_multiplier))
    remaining_frames = target_total_frames - n
    if remaining_frames < 0:
        raise ValueError("插帧倍数太小，导致目标帧数小于原始帧数。")

    # Step 3: 按变化度分配插帧数
    inserted_per_pair = np.round(pose_deltas / total_motion * remaining_frames).astype(int)

    # Step 4: 插值生成相机序列
    new_camera_list = []
    new_uid = 0
    for i in range(n - 1):
        cam_start = origin_viewpoints[i]
        cam_end = origin_viewpoints[i + 1]
        num_inserted = int(inserted_per_pair[i])

        segment = interpolate_between_cameras(cam_start, cam_end, num_inserted)

        if i > 0:
            segment = segment[1:]  # 去重第一帧
        for cam in segment:
            cam.uid = new_uid
            new_uid += 1
        new_camera_list.extend(segment)

    return new_camera_list
