import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp
from typing import List
from gaussian_splatting.scene.cameras import Camera, LightCam

def _as_numpy(x):
    """Return ``x`` as a NumPy array, detaching torch tensors and moving them to CPU."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def apply_trajectory_control_to_lightcam(
    camera_list: List[Camera], smooth_radius: int = 2
) -> List[LightCam]:
    """
    Smooth a camera trajectory with a sliding window and return LightCams.

    For frame ``i`` the window covers frames
    ``[i - smooth_radius, i + smooth_radius]``, clipped to the list bounds.
    The output pose uses the arithmetic mean of the window's ``T`` values and
    the rotation mean (SciPy ``Rotation.mean``) of the window's ``R`` matrices.

    Image size, FoV, near/far planes, ``trans``, ``scale`` and the device are
    copied from the first camera. ``uid`` and ``image_name`` are taken from the
    corresponding input camera, so each output frame stays distinguishable.

    Args:
        camera_list: Input cameras. Each ``R`` must be a 3x3 rotation matrix
            (required by ``Rotation.from_matrix``). ``R`` and ``T`` may be torch
            tensors or array-likes, and all ``T`` values must share one shape.
        smooth_radius: Half-width of the smoothing window. ``0`` disables
            smoothing.

    Returns:
        One ``LightCam`` per input camera, in input order. Empty input returns
        an empty list.

    Raises:
        ValueError: If ``smooth_radius`` is negative.
    """
    if smooth_radius < 0:
        raise ValueError(f"smooth_radius must be non-negative, got {smooth_radius}")
    if not camera_list:
        return []

    n = len(camera_list)
    all_T = np.stack([_as_numpy(cam.T) for cam in camera_list], axis=0)
    # Build every rotation once; each window below is then just a slice.
    rotations = R.from_matrix(np.stack([_as_numpy(cam.R) for cam in camera_list], axis=0))

    ref = camera_list[0]
    width, height = ref.image_width, ref.image_height
    FoVx, FoVy = ref.FoVx, ref.FoVy
    znear, zfar = ref.znear, ref.zfar
    trans = ref.trans if isinstance(ref.trans, torch.Tensor) else torch.tensor(ref.trans)
    scale = ref.scale
    device = getattr(ref, "data_device", "cuda")

    new_cams: List[LightCam] = []
    for i, src in enumerate(camera_list):
        # Half-open window [start, stop), clipped at both ends of the trajectory.
        start = max(0, i - smooth_radius)
        stop = min(n, i + smooth_radius + 1)

        # Translation: plain arithmetic mean over the window.
        T_avg = all_T[start:stop].mean(axis=0)
        # Rotation: proper rotation mean. It also works for a one-element window,
        # unlike Slerp, which needs >= 2 keyframes and only returns a keyframe
        # (or a midpoint) when evaluated at t=0.5.
        R_avg = rotations[start:stop].mean().as_matrix()

        new_cams.append(
            LightCam(
                image_width=width,
                image_height=height,
                FoVx=FoVx,
                FoVy=FoVy,
                znear=znear,
                zfar=zfar,
                R=torch.tensor(R_avg, dtype=torch.float32),
                T=torch.tensor(T_avg, dtype=torch.float32),
                trans=trans,
                scale=scale,
                device=device,
                # Per-frame identity, so outputs do not all share the first camera's id/name.
                uid=src.uid,
                image_name=src.image_name,
            )
        )

    return new_cams
