import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp
from copy import deepcopy
from gaussian_splatting.scene.cameras import Camera
from typing import List

def apply_trajectory_control(camera_list: List[Camera], smooth_radius: int = 2) -> List[Camera]:
    """
    Smooth a camera pose trajectory with a centred sliding-window average.

    For frame ``i`` the window spans frames
    ``max(0, i - smooth_radius)`` .. ``min(n - 1, i + smooth_radius)``, i.e. it
    is truncated (and therefore asymmetric) near the ends of the sequence.
    Translations are averaged arithmetically; rotations are averaged with
    ``scipy.spatial.transform.Rotation.mean`` (chordal L2 mean), so every
    rotation in the window contributes.

    :param camera_list: original cameras. Each must expose ``R`` (3x3 rotation
        matrix) and ``T`` (translation) as numpy arrays or torch tensors, plus
        ``data_device``. The input cameras are not modified.
    :param smooth_radius: window half-width (default 2 -> 5-frame window).
        0 returns copies with unchanged poses.
    :return: new list of deep-copied cameras whose ``R``/``T`` are replaced by
        float32 tensors on each camera's ``data_device``.
    :raises ValueError: if ``smooth_radius`` is negative.
    """
    if smooth_radius < 0:
        raise ValueError(f"smooth_radius must be >= 0, got {smooth_radius}")
    if not camera_list:
        return []

    def _as_numpy(value):
        # Poses may be stored as torch tensors (possibly on GPU) or as arrays.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    n = len(camera_list)
    # Extract every pose once so each window below is a cheap slice.
    all_R = np.stack([_as_numpy(cam.R) for cam in camera_list])  # (n, 3, 3)
    all_T = np.stack([_as_numpy(cam.T) for cam in camera_list])

    new_camera_list = []
    for i, src_cam in enumerate(camera_list):
        start = max(0, i - smooth_radius)
        stop = min(n, i + smooth_radius + 1)  # exclusive upper bound

        # NOTE(review): T is averaged in whatever frame it is stored in; if T is
        # a world-to-camera translation this differs from averaging camera
        # centres — confirm this is the intended behaviour.
        T_avg = all_T[start:stop].mean(axis=0)

        # Average all rotations in the window. (A Slerp evaluated at t=0.5 over
        # the window would just return its middle keyframe: no smoothing for
        # interior frames, the wrong frame at the truncated ends, and an error
        # for one-element windows.)
        R_avg = R.from_matrix(all_R[start:stop]).mean().as_matrix()

        # NOTE(review): only R and T are replaced. Any attributes that Camera
        # derives from them (e.g. cached view/projection matrices) are copied
        # unchanged by deepcopy — confirm downstream code re-derives them.
        cam = deepcopy(src_cam)
        cam.R = torch.tensor(R_avg, dtype=torch.float32, device=cam.data_device)
        cam.T = torch.tensor(T_avg, dtype=torch.float32, device=cam.data_device)
        new_camera_list.append(cam)

    return new_camera_list