import math
import random
from typing import Callable, List, Tuple, Union

import numpy as np
import scipy
import torch
from scipy.interpolate import make_interp_spline
from scipy.spatial.transform import Rotation as R

from gaussian_splatting.scene.cameras import Camera

##############################################################
# 1.  Basic pose‑distance and adaptive speed interpolation   #
##############################################################

def compute_pose_delta(cam1: Camera, cam2: Camera, alpha: float = 1.0, beta: float = 1.0) -> float:
    """Return ``alpha * rotation_angle + beta * translation_distance`` between two cameras.

    ``R``/``T`` may be torch tensors (detached and copied to CPU) or numpy arrays.
    The rotation term is the angle (radians) of the relative rotation ``R1^T R2``,
    recovered from its trace; the translation term is the Euclidean norm of ``T2 - T1``.
    """

    def as_array(value):
        # Normalise torch tensors to numpy; anything else is used as-is.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    rot_a, rot_b = as_array(cam1.R), as_array(cam2.R)
    trans_a, trans_b = as_array(cam1.T), as_array(cam2.T)

    # trace(R) = 1 + 2 cos(theta); clip guards arccos against round-off outside [-1, 1].
    half_trace_term = (np.trace(rot_a.T @ rot_b) - 1.0) / 2.0
    angle = np.arccos(np.clip(half_trace_term, -1.0, 1.0))
    offset = np.linalg.norm(trans_b - trans_a)
    return alpha * angle + beta * offset


def _interpolate_translation(s_raw: np.ndarray, T_all: np.ndarray, s_target: np.ndarray) -> np.ndarray:
    """Cubic B‑spline interpolation for translation."""
    out = []
    for d in range(3):
        spline = make_interp_spline(s_raw, T_all[:, d], k=3)
        out.append(spline(s_target))
    return np.stack(out, axis=-1)


def _interpolate_rotation(s_raw: np.ndarray, R_all: List[R], s_target: np.ndarray) -> List[np.ndarray]:
    """Cubic B‑spline interpolation in rot‑vector space."""
    rotvec_all = np.array([r.as_rotvec() for r in R_all])
    rot_interp = []
    for d in range(3):
        spline = make_interp_spline(s_raw, rotvec_all[:, d], k=3)
        rot_interp.append(spline(s_target))
    rot_interp = np.stack(rot_interp, axis=-1)
    return [R.from_rotvec(v).as_matrix() for v in rot_interp]


def generate_variable_speed_interpolated_camera_list(
    origin_viewpoints: List[Camera],
    interp_multiplier: int,
    speed_profile: Union[List[float], Callable[[float], float]],
    alpha: float = 1.0,
    beta: float = 1.0,
) -> List[Camera]:
    """Temporally-uniform, velocity-aware pose interpolation.

    Output frames are spaced uniformly in normalised time t in [0, 1]. The distance
    travelled between frames along the camera path is proportional to the speed at t.
    Path distance is the cumulative ``compute_pose_delta`` between consecutive input
    cameras. The first frame lies on the first input pose and the last frame on the
    last input pose.

    Args:
        origin_viewpoints: ordered key cameras (at least two).
        interp_multiplier: the output frame count is
            ``len(origin_viewpoints) * interp_multiplier``; must be >= 1.
        speed_profile: either a callable mapping t in [0, 1] to a relative speed, or a
            sequence of relative speeds sampled uniformly over [0, 1]. A sequence is
            interpolated with a B-spline of degree ``min(3, len - 1)``; a single value
            means constant speed. Speeds below 1e-6 are clamped to 1e-6.
        alpha: rotation weight forwarded to ``compute_pose_delta``.
        beta: translation weight forwarded to ``compute_pose_delta``.

    Returns:
        New ``Camera`` objects with ``is_interp = True``. Resolution, FoV, image,
        ``trans`` and ``scale`` are copied from the first input camera.

    Raises:
        ValueError: fewer than two cameras, ``interp_multiplier < 1``, an empty speed
            sequence, or no motion at all between the input poses.
    """
    if len(origin_viewpoints) < 2:
        raise ValueError("Need ≥2 camera poses for interpolation.")
    if interp_multiplier < 1:
        raise ValueError(f"interp_multiplier must be >= 1, got {interp_multiplier}.")
    target_frame_count = len(origin_viewpoints) * interp_multiplier  # >= 2 after the checks above

    def _to_numpy(x):
        # Cameras may store R/T as torch tensors or numpy arrays; compute_pose_delta accepts both.
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    # --- 1. build cumulative motion axis --------------------------------------------------
    pose_deltas = [
        compute_pose_delta(prev_cam, next_cam, alpha, beta)
        for prev_cam, next_cam in zip(origin_viewpoints[:-1], origin_viewpoints[1:])
    ]
    s_raw = np.concatenate([[0.0], np.cumsum(pose_deltas)])  # shape (N,)
    # Spline abscissae must be strictly increasing: drop keyframes that add no motion.
    keep = np.concatenate([[True], np.diff(s_raw) > 0.0])
    if np.count_nonzero(keep) < 2:
        raise ValueError("Camera poses do not move (zero total pose delta); nothing to interpolate.")
    key_cams = [cam for cam, kept in zip(origin_viewpoints, keep) if kept]
    s_raw = s_raw[keep]

    # --- 2. sample speed profile at uniform frame times -------------------------------------
    t_uniform = np.linspace(0.0, 1.0, target_frame_count)
    if callable(speed_profile):
        speed_fn = speed_profile
    else:
        samples = np.asarray(speed_profile, dtype=float)
        if samples.size == 0:
            raise ValueError("speed_profile must contain at least one value.")
        if samples.size == 1:
            def speed_fn(_t, _value=float(samples[0])):
                return _value
        else:
            speed_fn = make_interp_spline(
                np.linspace(0.0, 1.0, samples.size), samples, k=min(3, samples.size - 1)
            )
    speeds = np.maximum(np.array([speed_fn(t) for t in t_uniform], dtype=float), 1e-6)  # avoid zeros

    # --- 3. integrate speed into motion-axis positions ------------------------------------
    # Trapezoidal integral over t, rescaled so s_target[0] == 0 and s_target[-1] == s_raw[-1].
    # The frames at t=0 and t=1 then land exactly on the first and last key poses.
    increments = 0.5 * (speeds[1:] + speeds[:-1])
    s_target = np.concatenate([[0.0], np.cumsum(increments)])
    s_target *= s_raw[-1] / s_target[-1]
    s_target[-1] = s_raw[-1]  # remove round-off that could step outside the spline's domain

    # --- 4. interpolate translation & rotation -------------------------------------------
    T_all = np.stack([_to_numpy(cam.T) for cam in key_cams])                  # (N,3)
    T_interp = _interpolate_translation(s_raw, T_all, s_target)               # (M,3)

    R_all = [R.from_matrix(_to_numpy(cam.R)) for cam in key_cams]
    R_interp = _interpolate_rotation(s_raw, R_all, s_target)                  # list(M)

    # --- 5. assemble new Camera list ------------------------------------------------------
    ref = origin_viewpoints[0]
    focal_x = 0.5 * ref.image_width / np.tan(0.5 * ref.FoVx)
    focal_y = 0.5 * ref.image_height / np.tan(0.5 * ref.FoVy)
    out = []
    for i, (Ri, Ti) in enumerate(zip(R_interp, T_interp)):
        cam = Camera(
            resolution=(ref.image_width, ref.image_height),
            colmap_id=-1,
            R=torch.tensor(Ri, dtype=torch.float32),
            T=torch.tensor(Ti, dtype=torch.float32),
            FoVx=ref.FoVx,
            FoVy=ref.FoVy,
            depth_params=None,
            image=ref.image_pil,
            invdepthmap=None,
            image_name=f"interp_{i:04d}",
            uid=i,
            trans=ref.trans,
            scale=ref.scale,
            data_device=str(ref.data_device),
            train_test_exp=False,
            is_test_dataset=False,
            is_test_view=False,
        )
        cam.is_interp = True
        cam.focal_x = focal_x
        cam.focal_y = focal_y
        cam.principal_x = ref.image_width / 2.0
        cam.principal_y = ref.image_height / 2.0
        out.append(cam)
    return out

##############################################################
# 2.  Novel‑View Trajectory Generation Module                #
##############################################################

def _bezier_interp(values: np.ndarray, num: int) -> np.ndarray:
    """Bernstein‑polynomial Bezier interpolation for 1‑D array values (len = K)."""
    K = len(values) - 1
    t = np.linspace(0.0, 1.0, num)
    bern = [scipy.special.comb(K, i) * t ** i * (1 - t) ** (K - i) for i in range(K + 1)]
    bern = np.stack(bern, axis=-1)  # shape (num,K+1)
    return bern @ values  # (num,)


def _interp_pose_sequence(key_cams: List[Camera], num_frames: int, method: str = "bspline", order: int = 3) -> List[Camera]:
    """Interpolate a mini-trajectory of ``num_frames`` cameras from a few key camera poses.

    Args:
        key_cams: ordered key cameras (at least two). ``R``/``T`` may be torch tensors
            or numpy arrays.
        num_frames: number of output cameras.
        method: one of two modes.
            - ``"bspline"``: an interpolating B-spline through the key poses, sampled
              uniformly along the cumulative ``compute_pose_delta`` axis.
            - ``"bezier"``: a Bezier curve with the key poses as control points (only
              the first and last are passed through), sampled uniformly in the curve
              parameter. Rotations are blended in rotation-vector space.
        order: accepted for API compatibility; not used by either method.

    Returns:
        New ``Camera`` objects with ``is_test_view=True``. Resolution, FoV, image,
        ``trans`` and ``scale`` are copied from the first key camera.

    Raises:
        ValueError: unknown ``method``, fewer than two keyframes, or (for ``"bspline"``)
            key poses with no motion between them.
    """
    if method not in ("bspline", "bezier"):
        raise ValueError(f"Unknown interpolation method {method!r}; expected 'bspline' or 'bezier'.")
    K = len(key_cams)
    if K < 2:
        raise ValueError("Need ≥2 keyframes to interpolate.")

    def _to_numpy(x):
        # Accept R/T stored as torch tensors or numpy arrays, as compute_pose_delta does.
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    T_all = np.stack([_to_numpy(cam.T) for cam in key_cams])  # (K,3)
    R_key = [R.from_matrix(_to_numpy(cam.R)) for cam in key_cams]

    if method == "bspline":
        # Abscissae = cumulative pose distance, so frames are spread evenly along the motion.
        pose_deltas = [compute_pose_delta(a, b) for a, b in zip(key_cams[:-1], key_cams[1:])]
        s_raw = np.concatenate([[0.0], np.cumsum(pose_deltas)])
        # Spline abscissae must be strictly increasing: drop keyframes that add no motion.
        keep = np.concatenate([[True], np.diff(s_raw) > 0.0])
        if np.count_nonzero(keep) < 2:
            raise ValueError("All keyframe poses are identical; cannot build a B-spline trajectory.")
        s_raw = s_raw[keep]
        s_target = np.linspace(0.0, s_raw[-1], num_frames)
        T_interp = _interpolate_translation(s_raw, T_all[keep], s_target)
        R_interp = _interpolate_rotation(s_raw, [r for r, kept in zip(R_key, keep) if kept], s_target)
    else:  # bezier: translation per axis, rotation in rot-vec space
        T_interp = np.stack([_bezier_interp(T_all[:, d], num_frames) for d in range(3)], axis=-1)
        rotvec_key = np.array([r.as_rotvec() for r in R_key])  # (K,3)
        rotvec_interp = np.stack(
            [_bezier_interp(rotvec_key[:, d], num_frames) for d in range(3)], axis=-1
        )
        R_interp = [R.from_rotvec(v).as_matrix() for v in rotvec_interp]

    # --- build cameras ------------------------------------------------
    ref = key_cams[0]
    focal_x = 0.5 * ref.image_width / np.tan(0.5 * ref.FoVx)
    focal_y = 0.5 * ref.image_height / np.tan(0.5 * ref.FoVy)
    cams = []
    for i, (Ri, Ti) in enumerate(zip(R_interp, T_interp)):
        cam = Camera(
            resolution=(ref.image_width, ref.image_height),
            colmap_id=-1,
            R=torch.tensor(Ri, dtype=torch.float32),
            T=torch.tensor(Ti, dtype=torch.float32),
            FoVx=ref.FoVx,
            FoVy=ref.FoVy,
            depth_params=None,
            image=ref.image_pil,
            invdepthmap=None,
            image_name=f"novel_{i:04d}",
            uid=i,
            trans=ref.trans,
            scale=ref.scale,
            data_device=str(ref.data_device),
            train_test_exp=False,
            is_test_dataset=False,
            is_test_view=True,
        )
        cam.focal_x = focal_x
        cam.focal_y = focal_y
        cam.principal_x = ref.image_width / 2.0
        cam.principal_y = ref.image_height / 2.0
        cams.append(cam)
    return cams


def generate_novel_view_groups(
    source_viewpoints: List[Camera],
    num_groups: int = 5,
    max_keyframes: int = 5,
    num_frames: int = 150,
) -> List[List[Camera]]:
    """Generate multiple novel-view pose sequences from random keyframe subsets.

    For each group:
    - ``min(max_keyframes, len(source_viewpoints))`` distinct cameras are drawn with
      Python's ``random`` module and kept in their original order.
    - The interpolation method is picked at random from ``"bspline"`` and ``"bezier"``.
    - ``_interp_pose_sequence`` turns the keyframes into a trajectory.

    Args:
        source_viewpoints: pool of cameras to draw keyframes from (at least two).
        num_groups: how many novel-view trajectories to output.
        max_keyframes: keyframes per trajectory, capped at the pool size.
        num_frames: number of frames in each generated mini-trajectory.

    Returns:
        List of ``num_groups`` camera-list trajectories.

    Raises:
        ValueError: fewer than two source cameras, or ``max_keyframes < 2``.
    """
    n_sources = len(source_viewpoints)
    if n_sources < 2:
        raise ValueError("Need ≥2 source cameras to build novel-view trajectories.")
    k = min(max_keyframes, n_sources)  # random.sample would raise if k > n_sources
    if k < 2:
        raise ValueError(f"max_keyframes must be >= 2, got {max_keyframes}.")

    groups = []
    for _ in range(num_groups):
        idx = sorted(random.sample(range(n_sources), k))
        key_cams = [source_viewpoints[i] for i in idx]

        method = random.choice(["bspline", "bezier"])
        # `order` is not used by _interp_pose_sequence, but it is still drawn so that
        # seeded runs consume the same random stream as before.
        order = random.randint(2, 5)
        traj = _interp_pose_sequence(key_cams, num_frames, method=method, order=order)
        groups.append(traj)
    return groups
