from typing import List
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from scipy.interpolate import make_interp_spline
from gaussian_splatting.scene.cameras import Camera


def compute_pose_delta(cam1: "Camera", cam2: "Camera", alpha=1.0, beta=1.0) -> float:
    """
    Compute the weighted pose delta between two Camera objects.

    The result is ``alpha * rot_dist + beta * trans_dist`` where ``rot_dist``
    is the angle (radians) of the relative rotation ``R1.T @ R2`` and
    ``trans_dist`` is the Euclidean norm of ``T2 - T1``.

    Args:
        cam1: First camera; its ``R`` and ``T`` attributes may be torch tensors
            or anything ``np.asarray`` accepts (arrays, nested lists).
        cam2: Second camera, same requirements as ``cam1``.
        alpha: Weight of the rotational term.
        beta: Weight of the translational term.

    Returns:
        The weighted distance as a Python float.
    """
    def to_numpy(value):
        # Tensors may carry autograd state or live on the GPU, so detach and
        # move them first; everything else goes straight through np.asarray.
        # float64 avoids float32 round-off in the trace/arccos below.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value, dtype=np.float64)

    R1, R2 = to_numpy(cam1.R), to_numpy(cam2.R)
    T1, T2 = to_numpy(cam1.T), to_numpy(cam2.T)

    # For a rotation matrix, trace = 1 + 2*cos(theta). Clipping guards
    # against |cos| drifting just past 1 from floating-point round-off.
    cos_theta = np.clip((np.trace(R1.T @ R2) - 1.0) / 2.0, -1.0, 1.0)
    rot_dist = float(np.arccos(cos_theta))

    # NOTE(review): this compares the stored T vectors directly. If T is a
    # world-to-camera translation, this is not the distance between camera
    # centres -- confirm that is the intended heuristic.
    trans_dist = float(np.linalg.norm(T2 - T1))

    return float(alpha * rot_dist + beta * trans_dist)


def _pose_to_numpy(value) -> np.ndarray:
    """Return a pose component (torch tensor or array-like) as a float64 NumPy array."""
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return np.asarray(value, dtype=np.float64)


def _allocate_inserted_frames(pose_deltas: np.ndarray, remaining_frames: int) -> np.ndarray:
    """Split ``remaining_frames`` across camera pairs in proportion to ``pose_deltas``.

    Uses the largest-remainder method so the per-pair counts sum exactly to
    ``remaining_frames``. Falls back to an even split when the total motion is
    zero or not finite (e.g. all input poses identical), which would otherwise
    divide by zero.
    """
    num_pairs = len(pose_deltas)
    if num_pairs == 0 or remaining_frames <= 0:
        return np.zeros(num_pairs, dtype=int)

    total_motion = float(np.sum(pose_deltas))
    if np.isfinite(total_motion) and total_motion > 0.0:
        quotas = pose_deltas / total_motion * remaining_frames
    else:
        quotas = np.full(num_pairs, remaining_frames / num_pairs)

    allocation = np.floor(quotas).astype(int)
    leftover = remaining_frames - int(allocation.sum())
    if leftover > 0:
        # Give the frames lost to flooring to the pairs with the largest
        # fractional parts; a stable sort breaks ties toward earlier pairs.
        order = np.argsort(-(quotas - allocation), kind="stable")
        allocation[order[:leftover]] += 1
    return allocation


def _build_target_indices(inserted_per_pair: np.ndarray) -> np.ndarray:
    """Return the spline parameters at which output frames are sampled.

    Original camera ``i`` sits at parameter ``i``. The result starts at 0 and,
    for each pair ``(i, i + 1)``, appends ``inserted_per_pair[i]`` evenly
    spaced points strictly inside the interval followed by ``i + 1`` itself.
    This keeps every original keyframe even when a pair receives no inserted
    frames.
    """
    segments = [np.zeros(1)]
    for i, num_inserted in enumerate(inserted_per_pair):
        # num_inserted + 2 samples over [i, i + 1]; the left end was already
        # emitted, leaving num_inserted interior points plus keyframe i + 1.
        segments.append(np.linspace(i, i + 1, int(num_inserted) + 2)[1:])
    return np.concatenate(segments)


def _spline_interpolate(samples: np.ndarray, query: np.ndarray) -> np.ndarray:
    """Evaluate an interpolating B-spline through the rows of ``samples`` at ``query``.

    Row ``j`` of ``samples`` is placed at parameter ``j``; each column is
    interpolated independently. The degree is cubic when possible and reduced
    for short inputs, because degree ``k`` needs at least ``k + 1`` samples.
    """
    degree = min(3, len(samples) - 1)
    spline = make_interp_spline(np.arange(len(samples)), samples, k=degree, axis=0)
    return spline(query)


def _unwrap_rotvecs(rotvecs: np.ndarray) -> np.ndarray:
    """Make a sequence of rotation vectors continuous.

    ``Rotation.as_rotvec`` limits the rotation angle to [0, pi], so the vector
    can jump to the opposite side when a path passes through pi. Component-wise
    interpolation would then sweep through the wrong rotations. Each vector is
    replaced by the equivalent ``axis * (angle + 2*pi*k)`` closest to its
    predecessor, which represents the same rotation.
    """
    unwrapped = np.array(rotvecs, dtype=np.float64, copy=True)
    for i in range(1, len(unwrapped)):
        angle = np.linalg.norm(unwrapped[i])
        if angle < 1e-12:
            continue
        axis = unwrapped[i] / angle
        # |prev - axis * phi| is minimised at phi = prev . axis; snap phi to the
        # nearest admissible value angle + 2*pi*k.
        k = np.round((np.dot(unwrapped[i - 1], axis) - angle) / (2.0 * np.pi))
        unwrapped[i] = axis * (angle + 2.0 * np.pi * k)
    return unwrapped


def generate_adaptive_bspline_interpolated_camera_list(origin_viewpoints: List[Camera], interp_multiplier: float) -> List[Camera]:
    """
    Adaptively interpolate new camera poses between the original sparse viewpoints using B-Splines.

    Returns ``int(round(len(origin_viewpoints) * interp_multiplier))`` cameras:
    every original pose (as a keyframe) plus inserted poses. The number of
    poses inserted between each adjacent pair is proportional to that pair's
    ``compute_pose_delta``. Translations and rotation vectors are interpolated
    with B-splines (cubic when at least four viewpoints are given, lower degree
    otherwise). Resolution, FoV, image and scene transform are copied from the
    first viewpoint.

    Args:
        origin_viewpoints: Ordered input cameras; ``R``/``T`` may be torch
            tensors or array-likes.
        interp_multiplier: Ratio of output frames to input frames (>= ~1).

    Returns:
        A list of newly constructed ``Camera`` objects flagged ``is_interp``.

    Raises:
        ValueError: If fewer than two viewpoints are given, or the multiplier
            yields fewer frames than the input.
    """
    if len(origin_viewpoints) < 2:
        raise ValueError("At least two camera poses are required for interpolation.")

    n = len(origin_viewpoints)

    # Step 1: Compute pose delta between adjacent cameras
    pose_deltas = np.array([
        compute_pose_delta(prev_cam, next_cam)
        for prev_cam, next_cam in zip(origin_viewpoints[:-1], origin_viewpoints[1:])
    ])

    target_total_frames = int(round(n * interp_multiplier))
    remaining_frames = target_total_frames - n
    if remaining_frames < 0:
        raise ValueError("Interpolation multiplier too small, resulting in fewer frames than input.")

    # Step 2: Allocate inserted frames for each pair (more where motion is larger)
    inserted_per_pair = _allocate_inserted_frames(pose_deltas, remaining_frames)

    # Step 3: Spline parameters of every output frame (keyframes at integers)
    target_indices = _build_target_indices(inserted_per_pair)

    # Step 4: B-Spline interpolate translation vectors
    T_all = np.stack([_pose_to_numpy(cam.T).reshape(-1) for cam in origin_viewpoints])
    T_interp = _spline_interpolate(T_all, target_indices)

    # Step 5: B-Spline interpolate (continuity-corrected) rotation vectors
    rotvec_all = np.stack([
        R.from_matrix(_pose_to_numpy(cam.R)).as_rotvec() for cam in origin_viewpoints
    ])
    rotvec_interp = _spline_interpolate(_unwrap_rotvecs(rotvec_all), target_indices)
    R_interp = R.from_rotvec(rotvec_interp).as_matrix()

    # Step 6: Construct new Camera list
    ref_cam = origin_viewpoints[0]
    width, height = ref_cam.image_width, ref_cam.image_height
    # Intrinsics are identical for every interpolated camera; compute once.
    focal_x = 0.5 * width / np.tan(0.5 * ref_cam.FoVx)
    focal_y = 0.5 * height / np.tan(0.5 * ref_cam.FoVy)

    new_camera_list = []
    for i, (rot_matrix, translation) in enumerate(zip(R_interp, T_interp)):
        cam = Camera(
            resolution=(width, height),
            colmap_id=-1,
            R=torch.tensor(rot_matrix, dtype=torch.float32),
            T=torch.tensor(translation, dtype=torch.float32),
            FoVx=ref_cam.FoVx,
            FoVy=ref_cam.FoVy,
            depth_params=None,
            image=ref_cam.image_pil,
            invdepthmap=None,
            image_name=f"interp_{i:04d}",
            uid=i,
            trans=ref_cam.trans,
            scale=ref_cam.scale,
            data_device=str(ref_cam.data_device),
            train_test_exp=False,
            is_test_dataset=False,
            is_test_view=False
        )
        cam.is_interp = True
        cam.focal_x = focal_x
        cam.focal_y = focal_y
        # Principal point assumed to be the image centre.
        cam.principal_x = width / 2.0
        cam.principal_y = height / 2.0
        new_camera_list.append(cam)

    return new_camera_list
