from typing import List, Callable, Union
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp
from scipy.interpolate import make_interp_spline
from gaussian_splatting.scene.cameras import Camera


def compute_pose_delta(cam1: Camera, cam2: Camera, alpha=1.0, beta=1.0) -> float:
    """Return a weighted rotation-plus-translation distance between two cameras.

    The rotational part is the angle (in radians) of the relative rotation
    ``cam1.R.T @ cam2.R``, recovered from its trace.  The translational part
    is the Euclidean norm of ``cam2.T - cam1.T``.

    Args:
        cam1: First camera; its ``R`` and ``T`` attributes are read.
        cam2: Second camera; its ``R`` and ``T`` attributes are read.
        alpha: Weight applied to the rotation angle.
        beta: Weight applied to the translation distance.

    Returns:
        ``alpha * rotation_angle + beta * translation_distance``.
    """

    def as_array(value):
        # Torch tensors are detached and moved to host memory first;
        # anything else is used exactly as given.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    rot_a, rot_b = as_array(cam1.R), as_array(cam2.R)
    pos_a, pos_b = as_array(cam1.T), as_array(cam2.T)

    # trace(R_rel) = 1 + 2*cos(theta); the clip keeps round-off from pushing
    # the cosine outside arccos' domain.
    relative_rotation = rot_a.T @ rot_b
    cosine = np.clip((np.trace(relative_rotation) - 1) / 2, -1.0, 1.0)
    angle = np.arccos(cosine)

    offset = np.linalg.norm(pos_b - pos_a)

    return alpha * angle + beta * offset


def _pose_component_to_numpy(value) -> np.ndarray:
    """Convert a camera R/T attribute (torch tensor or array-like) to float64 numpy."""
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


def _continuous_rotvecs(rotvecs: np.ndarray) -> np.ndarray:
    """Re-express each rotation vector on the branch closest to its predecessor.

    ``(angle + 2*pi*k) * axis`` encodes the same rotation for every integer
    ``k``.  Picking the ``k`` nearest to the previous vector removes the jumps
    that appear when the principal representation flips near angle pi, which
    would otherwise make a component-wise spline swing through unrelated
    orientations.
    """
    unwrapped = np.array(rotvecs, dtype=np.float64, copy=True)
    for i in range(1, len(unwrapped)):
        angle = np.linalg.norm(unwrapped[i])
        if angle < 1e-12:
            # Axis is undefined for a (near-)identity rotation; leave it as is.
            continue
        axis = unwrapped[i] / angle
        # Minimising |(angle + 2*pi*k) * axis - previous| over integer k.
        k = np.round((axis @ unwrapped[i - 1] - angle) / (2.0 * np.pi))
        unwrapped[i] = (angle + 2.0 * np.pi * k) * axis
    return unwrapped


def generate_variable_speed_interpolated_camera_list(
    origin_viewpoints: List[Camera],
    interp_multiplier: int,
    speed_profile: Union[List[float], Callable[[float], float]],
    alpha: float = 1.0,
    beta: float = 1.0
) -> List[Camera]:
    """Resample a camera path so that motion along it follows a speed profile.

    The input poses are parameterised by cumulative pose distance (see
    ``compute_pose_delta``).  Translations and rotation vectors are fitted
    with interpolating B-splines over that parameter and evaluated at
    ``len(origin_viewpoints) * interp_multiplier`` locations.  The spacing of
    those locations is proportional to the speed profile, and they run exactly
    from the first pose (parameter 0) to the last pose (total motion).

    Args:
        origin_viewpoints: Key cameras, in path order.  ``R`` and ``T`` may be
            torch tensors or array-likes.  Consecutive cameras with zero pose
            distance are ignored as spline knots.
        interp_multiplier: Output frame count per input camera; must be >= 1.
        speed_profile: Either a callable ``f(t)`` evaluated for ``t`` in
            (0, 1), or a sequence of speed samples spread uniformly over
            [0, 1] and smoothed with a B-spline.  Negative speeds are clipped
            to zero so the path parameter never moves backwards.
        alpha: Rotation weight forwarded to ``compute_pose_delta``.
        beta: Translation weight forwarded to ``compute_pose_delta``.

    Returns:
        New ``Camera`` objects carrying the interpolated ``R``/``T``.  All
        other constructor arguments are copied from the first input camera,
        and each result is tagged with ``is_interp = True`` plus focal-length
        and principal-point attributes derived from that camera.

    Raises:
        ValueError: Fewer than two cameras, ``interp_multiplier < 1``, an
            empty speed list, a speed profile whose (clipped) sum is not
            positive and finite, negative pose distances, or input poses that
            are all identical.
    """
    n = len(origin_viewpoints)
    # Ensure there are at least two camera poses for interpolation
    if n < 2:
        raise ValueError("At least two camera poses are required for interpolation.")
    if interp_multiplier < 1:
        raise ValueError("interp_multiplier must be a positive integer.")

    # Step 1: cumulative motion (pseudo-distance) along the original path.
    pose_deltas = [
        compute_pose_delta(cam_a, cam_b, alpha, beta)
        for cam_a, cam_b in zip(origin_viewpoints, origin_viewpoints[1:])
    ]
    s_raw = np.concatenate([[0.0], np.cumsum(pose_deltas)])
    steps = np.diff(s_raw)
    if np.any(steps < 0):
        raise ValueError("Pose distances must be non-negative; check alpha and beta.")

    # make_interp_spline needs strictly increasing abscissae, so repeated
    # poses (zero step) are dropped from the knot set.
    keep = np.concatenate([[True], steps > 0])
    s_knots = s_raw[keep]
    key_cams = [cam for cam, kept in zip(origin_viewpoints, keep) if kept]
    if s_knots.size < 2:
        raise ValueError("Camera poses are all identical; nothing to interpolate.")
    # A degree-k interpolating spline needs at least k + 1 points.
    pose_degree = min(3, s_knots.size - 1)

    # Step 2: sample the speed profile once per output interval (midpoints).
    target_frame_count = n * interp_multiplier
    interval_count = target_frame_count - 1  # >= 1 because n >= 2
    t_mid = (np.arange(interval_count) + 0.5) / interval_count
    if callable(speed_profile):
        speeds = np.asarray(
            [speed_profile(t) for t in t_mid], dtype=np.float64
        ).reshape(-1)
    else:
        samples = np.asarray(speed_profile, dtype=np.float64).reshape(-1)
        if samples.size == 0:
            raise ValueError("speed_profile must contain at least one value.")
        if samples.size == 1:
            speeds = np.full(interval_count, samples[0])
        else:
            # Smooth the samples; lower the degree when fewer than 4 are given.
            speed_spline = make_interp_spline(
                np.linspace(0, 1, samples.size),
                samples,
                k=min(3, samples.size - 1),
            )
            speeds = np.asarray(speed_spline(t_mid), dtype=np.float64).reshape(-1)

    # Step 3: integrate speed into sampling locations spanning [0, total].
    speeds = np.clip(speeds, 0.0, None)  # spline overshoot can dip below zero
    total_speed = speeds.sum()
    if not np.isfinite(total_speed) or total_speed <= 0:
        raise ValueError("speed_profile must yield a positive, finite total speed.")
    total_motion = s_knots[-1]
    s_target = np.concatenate([[0.0], np.cumsum(speeds / total_speed * total_motion)])
    # Guard against round-off pushing samples outside the fitted range, and
    # land exactly on the final pose.
    s_target = np.clip(s_target, 0.0, total_motion)
    s_target[-1] = total_motion

    # Step 4: interpolate translation (one vector-valued spline, 3 components).
    T_all = np.stack(
        [_pose_component_to_numpy(cam.T).reshape(-1)[:3] for cam in key_cams]
    )
    T_interp = make_interp_spline(s_knots, T_all, k=pose_degree)(s_target)

    # Step 5: interpolate rotation as rotation vectors (so(3)).
    rotations = R.from_matrix(
        np.stack([_pose_component_to_numpy(cam.R) for cam in key_cams])
    )
    rotvec_all = _continuous_rotvecs(rotations.as_rotvec())
    rotvec_interp = make_interp_spline(s_knots, rotvec_all, k=pose_degree)(s_target)
    R_interp = R.from_rotvec(rotvec_interp).as_matrix()

    # Step 6: create new Camera objects from interpolated R and T.
    ref_cam = origin_viewpoints[0]  # use first frame as reference for other parameters
    # Intrinsics depend only on the reference camera, so compute them once.
    focal_x = 0.5 * ref_cam.image_width / np.tan(0.5 * ref_cam.FoVx)
    focal_y = 0.5 * ref_cam.image_height / np.tan(0.5 * ref_cam.FoVy)
    principal_x = ref_cam.image_width / 2.0  # principal point assumed at image center
    principal_y = ref_cam.image_height / 2.0

    new_camera_list = []
    for i, (rot_i, pos_i) in enumerate(zip(R_interp, T_interp)):
        cam = Camera(
            resolution=(ref_cam.image_width, ref_cam.image_height),
            colmap_id=-1,
            R=torch.tensor(rot_i, dtype=torch.float32),
            T=torch.tensor(pos_i, dtype=torch.float32),
            FoVx=ref_cam.FoVx,
            FoVy=ref_cam.FoVy,
            depth_params=None,
            image=ref_cam.image_pil,
            invdepthmap=None,
            image_name=f"interp_{i:04d}",
            uid=i,
            trans=ref_cam.trans,
            scale=ref_cam.scale,
            data_device=str(ref_cam.data_device),
            train_test_exp=False,
            is_test_dataset=False,
            is_test_view=False
        )
        cam.is_interp = True  # flag for interpolated frames
        cam.focal_x = focal_x
        cam.focal_y = focal_y
        cam.principal_x = principal_x
        cam.principal_y = principal_y
        new_camera_list.append(cam)

    return new_camera_list
