from typing import List, Callable, Union
import numpy as np
import torch
import scipy
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import make_interp_spline
from gaussian_splatting.scene.cameras import LightCam
import random

def compute_pose_delta(cam1: LightCam, cam2: LightCam, alpha: float = 1.0, beta: float = 1.0) -> float:
    """Return a weighted distance between the poses of two cameras.

    The rotational term is the angle (radians) of the relative rotation
    ``cam1.R^T @ cam2.R``, recovered from its trace. The translational term is
    the Euclidean norm of ``cam2.T - cam1.T``.

    Args:
        cam1: First camera; its ``R`` and ``T`` attributes are torch tensors.
        cam2: Second camera, same convention as ``cam1``.
        alpha: Weight applied to the rotation angle.
        beta: Weight applied to the translation distance.

    Returns:
        ``alpha * angle + beta * translation_distance``.
    """
    rot_a, rot_b = (cam.R.cpu().numpy() for cam in (cam1, cam2))
    trans_a, trans_b = (cam.T.cpu().numpy() for cam in (cam1, cam2))

    # Round-off can push the cosine slightly outside [-1, 1]; clip before arccos.
    relative_trace = np.trace(rot_a.T @ rot_b)
    angle = np.arccos(np.clip((relative_trace - 1.0) / 2.0, -1.0, 1.0))
    offset = np.linalg.norm(trans_b - trans_a)
    return alpha * angle + beta * offset

def _interpolate_translation(s_raw: np.ndarray, T_all: np.ndarray, s_target: np.ndarray) -> np.ndarray:
    """Spline-interpolate key translations along a path parameter.

    Args:
        s_raw: Strictly increasing parameter values of the key poses, shape (N,), N >= 2.
        T_all: Key translations, one row per key pose, shape (N, D) (D == 3 for cameras).
        s_target: Parameter values at which to evaluate the path.

    Returns:
        Interpolated translations, shape (len(s_target), D).
    """
    s_raw = np.asarray(s_raw, dtype=float)
    # make_interp_spline needs at least k + 1 knots. Use a cubic spline when
    # possible and fall back to quadratic/linear for short paths (2-3 poses).
    degree = min(3, len(s_raw) - 1)
    spline = make_interp_spline(s_raw, np.asarray(T_all), k=degree, axis=0)
    return spline(np.asarray(s_target))

def _interpolate_rotation(s_raw: np.ndarray, R_all: List[R], s_target: np.ndarray) -> List[np.ndarray]:
    """Smoothly interpolate key rotations along a path parameter.

    Uses ``scipy.spatial.transform.RotationSpline``. It builds the curve from
    the relative rotations between consecutive keys, so it always takes the
    short way between neighbours. Splining rotation-vector components
    directly jumps when a rotation crosses the +/-pi wrap of the rotvec
    parameterisation. RotationSpline also works with just two keys.

    Args:
        s_raw: Strictly increasing parameter values of the key rotations, shape (N,), N >= 2.
        R_all: Key rotations (one single ``Rotation`` per key), in path order.
        s_target: Parameter values at which to evaluate the path.

    Returns:
        One 3x3 rotation matrix per entry of ``s_target``.
    """
    from scipy.spatial.transform import RotationSpline

    keyframes = R.from_quat(np.array([r.as_quat() for r in R_all]))
    spline = RotationSpline(np.asarray(s_raw, dtype=float), keyframes)
    return list(spline(np.asarray(s_target, dtype=float)).as_matrix())

def _sample_speed_profile(
    speed_profile: Union[List[float], Callable[[float], float]],
    frame_count: int,
) -> np.ndarray:
    """Evaluate ``speed_profile`` at ``frame_count`` evenly spaced times in [0, 1].

    A non-callable profile is treated as speed samples spaced evenly over
    [0, 1]. It is spline-interpolated: cubic with >= 4 samples, lower degree
    otherwise, and constant with a single sample. Results are clamped to at
    least 1e-6 so every step moves forward along the path.

    Raises:
        ValueError: if a sample sequence is empty.
    """
    t_uniform = np.linspace(0.0, 1.0, frame_count)
    if not callable(speed_profile):
        samples = np.asarray(speed_profile, dtype=float).ravel()
        if samples.size == 0:
            raise ValueError("speed_profile must contain at least one sample.")
        if samples.size == 1:
            return np.full(frame_count, max(float(samples[0]), 1e-6))
        # make_interp_spline needs k + 1 samples; lower the degree for short lists.
        speed_profile = make_interp_spline(
            np.linspace(0.0, 1.0, samples.size), samples, k=min(3, samples.size - 1)
        )
    speeds = np.asarray([speed_profile(t) for t in t_uniform], dtype=float)
    return np.maximum(speeds, 1e-6)


def _target_arc_lengths(speeds: np.ndarray, total_length: float) -> np.ndarray:
    """Place ``len(speeds)`` samples on [0, total_length], first at 0 and last at total_length.

    The gap between sample i and i + 1 is proportional to ``speeds[i]``.
    Requires ``len(speeds) >= 2`` and strictly positive speeds.
    """
    s = np.concatenate([[0.0], np.cumsum(speeds[:-1])])
    s *= total_length / s[-1]
    # Pin the endpoint exactly so round-off never extrapolates past the last key pose.
    s[-1] = total_length
    return np.minimum(s, total_length)


def generate_variable_speed_interpolated_lightcam_list(
    origin_viewpoints: List[LightCam],
    interp_multiplier: int,
    speed_profile: Union[List[float], Callable[[float], float]],
    alpha: float = 1.0,
    beta: float = 1.0,
) -> List[LightCam]:
    """Resample a camera path into a smooth trajectory with variable playback speed.

    Key poses are parameterised by their cumulative pose distance
    (``compute_pose_delta``). Translations and rotations are spline-interpolated
    over that parameter. Output frames are spaced so that the step from frame
    i to frame i + 1 is proportional to the speed at frame i's normalised time
    in [0, 1]. The first and last output frames coincide with the first and
    last key poses.

    Consecutive key poses at zero distance (e.g. duplicated cameras) are
    dropped, because the splines need a strictly increasing parameter. If all
    poses coincide, every output frame holds that pose.

    Args:
        origin_viewpoints: Key cameras in path order; at least two are required.
            Intrinsics and metadata of the first camera are copied to every output.
        interp_multiplier: Output frame count is
            ``len(origin_viewpoints) * interp_multiplier``; must be >= 1.
        speed_profile: Either a callable mapping normalised time t in [0, 1] to a
            relative speed, or a non-empty sequence of speed samples spaced evenly
            over [0, 1]. Speeds are clamped to at least 1e-6.
        alpha: Weight of the rotation angle (radians) in the path parameter.
        beta: Weight of the translation distance in the path parameter.

    Returns:
        Newly constructed ``LightCam`` objects with ``uid`` 0..N-1 and
        ``image_name`` ``interp_XXXX``.

    Raises:
        ValueError: on fewer than two poses, ``interp_multiplier < 1``, or an
            empty speed sequence.
    """
    if len(origin_viewpoints) < 2:
        raise ValueError("Need ≥2 camera poses for interpolation.")
    if interp_multiplier < 1:
        raise ValueError("interp_multiplier must be >= 1.")
    target_frame_count = len(origin_viewpoints) * interp_multiplier

    pose_deltas = [
        compute_pose_delta(prev_cam, next_cam, alpha, beta)
        for prev_cam, next_cam in zip(origin_viewpoints[:-1], origin_viewpoints[1:])
    ]
    s_all = np.concatenate([[0.0], np.cumsum(pose_deltas)])
    # Splines require strictly increasing knots: skip poses that add no distance.
    keep = np.concatenate([[True], np.diff(s_all) > 0.0])
    key_cams = [cam for cam, kept in zip(origin_viewpoints, keep) if kept]
    s_raw = s_all[keep]

    speeds = _sample_speed_profile(speed_profile, target_frame_count)
    T_all = np.stack([cam.T.cpu().numpy() for cam in key_cams])

    if len(key_cams) == 1:
        # Every pose coincides: there is no path to traverse, so hold the pose.
        T_interp = np.repeat(T_all, target_frame_count, axis=0)
        R_interp = [key_cams[0].R.cpu().numpy()] * target_frame_count
    else:
        s_target = _target_arc_lengths(speeds, s_raw[-1])
        T_interp = _interpolate_translation(s_raw, T_all, s_target)
        R_all = [R.from_matrix(cam.R.cpu().numpy()) for cam in key_cams]
        R_interp = _interpolate_rotation(s_raw, R_all, s_target)

    ref = origin_viewpoints[0]
    out = []
    for i, (Ri, Ti) in enumerate(zip(R_interp, T_interp)):
        cam = LightCam(
            image_width=ref.image_width,
            image_height=ref.image_height,
            FoVx=ref.FoVx,
            FoVy=ref.FoVy,
            znear=ref.znear,
            zfar=ref.zfar,
            R=torch.tensor(Ri, dtype=torch.float32),
            T=torch.tensor(Ti, dtype=torch.float32),
            trans=ref.trans,
            scale=ref.scale,
            device=ref.device,
            uid=i,
            image_name=f"interp_{i:04d}",
        )
        out.append(cam)
    return out
