import numpy as np
import scipy
from utils.camera_utils import viewmatrix, normalize, pad_poses, unpad_poses, poses_to_points, points_to_poses, interp
import scipy

def generate_interpolated_path(poses, n_interp, method="spline",
                               spline_degree=5, smoothness=.03, rot_weight=.1,
                               ellipse_factor=(0.7, 0.3)):
    """Generate an interpolated camera path from keyframe poses.

    Every non-spline ("parametric") method builds its path from the first two
    keyframes, ``poses[0]`` and ``poses[1]``. They share two derived vectors:

    * ``viewdir``: the normalized average of the keyframes' viewing directions,
      where a keyframe's viewing direction is taken as ``-poses[i][:3, 2]``.
    * ``up_mod``: the normalized average of ``poses[i][:3, 1]``, made
      orthogonal to ``viewdir``.

    Any unrecognized ``method`` falls back to spline interpolation.

    Args:
      poses: Keyframe poses indexed as ``poses[i][:3, :]``. Column ``-1`` is
        used as the camera position. Presumably shaped (n, 3, 4); confirm with
        callers.
      n_interp: Number of output poses for the parametric methods. The spline
        path produces ``n_interp * (n - 1)`` poses.
      method: One of ``"circle"``, ``"zoom_in_out"``, ``"falling_leaf"``,
        ``"ellipse"``, ``"spiral"``, ``"parabolic"``. Any other value selects
        the spline path.
      spline_degree: Spline degree ``k`` (spline path only).
      smoothness: Spline smoothing factor ``s`` (spline path only).
      rot_weight: Method-dependent scale. The spline path passes it as
        ``dist`` to ``poses_to_points``. Most parametric methods use it to
        scale offsets and amplitudes.
      ellipse_factor: ``(a_factor, b_factor)`` for ``"ellipse"``. The
        semi-major axis is half the distance between the forward-shifted
        keyframes times ``a_factor``. The semi-minor axis is the semi-major
        axis times ``b_factor``.

    Returns:
      The selected generator's output. The parametric generators return their
      camera matrices stacked along axis 0.
    """
    parametric_methods = ("circle", "zoom_in_out", "falling_leaf",
                          "ellipse", "spiral", "parabolic")
    if method not in parametric_methods:
        # The spline fallback needs neither viewdir nor up_mod, so do not
        # index poses[1] or normalize anything on that path.
        return generate_spline_trajectory(poses, n_interp, rot_weight, spline_degree, smoothness)

    # Average viewing direction of the first two keyframes.
    view0 = -poses[0][:3, 2]
    view1 = -poses[1][:3, 2]
    viewdir = normalize((view0 + view1) / 2.0)

    # Average up vector with its component along viewdir removed
    # (one Gram-Schmidt step), then re-normalized.
    orig_up = normalize((poses[0][:3, 1] + poses[1][:3, 1]) / 2.0)
    up_mod = orig_up - np.dot(orig_up, viewdir) * viewdir
    up_mod = normalize(up_mod)

    if method == "circle":
        return generate_circle_trajectory(poses, n_interp, viewdir, up_mod, rot_weight)
    if method == "zoom_in_out":
        return generate_zoom_trajectory(poses, n_interp, rot_weight, viewdir, up_mod)
    if method == "falling_leaf":
        return generate_falling_leaf_trajectory(poses, n_interp, rot_weight, viewdir, up_mod)
    if method == "ellipse":
        return generate_ellipse_trajectory(poses, n_interp, rot_weight, viewdir, up_mod, ellipse_factor)
    if method == "spiral":
        return generate_spiral_trajectory(poses, n_interp, rot_weight, viewdir, up_mod)
    # Only "parabolic" remains.
    return generate_parabolic_trajectory(poses, n_interp, rot_weight, viewdir, up_mod)

def generate_circle_trajectory(poses, n_interp, viewdir, up_mod, rot_weight, num_loops=7,
                               radius_ratio=0.3):
    """Generate a multi-loop circular camera path around the first two keyframes.

    Each of the two keyframe positions is pushed forward along its own viewing
    direction (``-poses[i][:3, 2]``) by ``0.1 * rot_weight``. The circle is
    centred at the midpoint of the shifted positions and lies in the plane
    through that centre orthogonal to ``viewdir``.

    The base radius is the mean distance of the projected keyframes from the
    centre, or half the keyframe separation if either projection is
    degenerate. It is scaled by ``radius_ratio``. The angle sweeps linearly
    from keyframe 0's direction to keyframe 1's angle plus ``num_loops`` full
    turns. Every pose is built with look direction ``-viewdir`` and up vector
    ``up_mod``.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Number of poses to generate.
      viewdir: Unit vector normal to the circle's plane.
      up_mod: Up vector passed to ``viewmatrix``.
      rot_weight: Scales the forward push of the keyframes.
      num_loops: Number of extra full turns added to the sweep.
      radius_ratio: Multiplier applied to the base radius.

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given.
    """
    if poses.shape[0] < 2:
        raise ValueError("circle trajectory requires at least two keyframe poses")

    p0 = poses[0][:3, -1]
    p1 = poses[1][:3, -1]
    view0 = -poses[0][:3, 2]
    view1 = -poses[1][:3, 2]

    # Push both keyframes forward along their own view directions.
    forward_offset = rot_weight * 0.1
    p0_forward = p0 + forward_offset * view0
    p1_forward = p1 + forward_offset * view1
    center = (p0_forward + p1_forward) / 2.0

    def project_to_plane(point, plane_point, plane_normal):
        # Drop the component of (point - plane_point) along the unit normal.
        offset = point - plane_point
        return plane_point + offset - np.dot(offset, plane_normal) * plane_normal

    p0_proj = project_to_plane(p0_forward, center, viewdir)
    p1_proj = project_to_plane(p1_forward, center, viewdir)

    r0 = np.linalg.norm(p0_proj - center)
    r1 = np.linalg.norm(p1_proj - center)
    if r0 < 1e-6 or r1 < 1e-6:
        # A projection collapsed onto the centre: fall back to half the
        # keyframe separation.
        radius = np.linalg.norm(p1 - p0) / 2.0
    else:
        radius = (r0 + r1) / 2.0

    # In-plane basis (u, v). Where possible, u points from the centre
    # towards keyframe 0.
    if r0 < 1e-6:
        # Any direction orthogonal to viewdir will do. Test the raw cross
        # product BEFORE normalizing. If normalize() divides by the norm, a
        # zero vector becomes NaN, and a NaN norm never compares below the
        # threshold, so checking afterwards could never trigger the fallback.
        u_raw = np.cross(viewdir, np.array([1.0, 0.0, 0.0]))
        if np.linalg.norm(u_raw) < 1e-6:
            u_raw = np.cross(viewdir, np.array([0.0, 1.0, 0.0]))
        u = normalize(u_raw)
    else:
        u = normalize(p0_proj - center)
    v = normalize(np.cross(viewdir, u))

    # Angle of keyframe 1 in the (u, v) basis.
    vec_p1 = p1_proj - center
    theta1 = np.arctan2(np.dot(vec_p1, v), np.dot(vec_p1, u))

    angles = np.linspace(0, theta1 + 2 * np.pi * num_loops, n_interp)

    poses_list = []
    for theta in angles:
        pos = center + radius_ratio * radius * (np.cos(theta) * u + np.sin(theta) * v)
        lookat = pos - viewdir
        poses_list.append(viewmatrix(lookat - pos, up_mod, pos))

    return np.stack(poses_list, axis=0)

def generate_zoom_trajectory(poses, n_interp, rot_weight, viewdir, up_mod, zoom_in_fraction=1.0):
    """Generate a zoom camera path: keyframe 0 -> target -> keyframe 1.

    The target is the midpoint of the first two keyframe positions pushed
    ``0.15 * rot_weight`` along ``viewdir``. The path has two legs:

    * Zoom-in: the first ``round(n_interp * zoom_in_fraction)`` poses move
      linearly from keyframe 0 towards the target, excluding the target.
    * Zoom-out: the remaining poses move linearly from the target to
      keyframe 1, including both ends.

    Every pose is built with look direction ``-viewdir`` and up vector
    ``up_mod``.

    With the default ``zoom_in_fraction=1.0``, all poses are spent on the
    zoom-in leg, so the path never reaches the target or keyframe 1. This
    matches the historical behaviour. Pass e.g. ``0.5`` for a symmetric
    in/out zoom.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Total number of poses to generate.
      rot_weight: Scales how far beyond the midpoint the target lies.
      viewdir: Unit viewing direction.
      up_mod: Up vector passed to ``viewmatrix``.
      zoom_in_fraction: Share of ``n_interp`` used for the zoom-in leg,
        in [0, 1].

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given, or if
        ``zoom_in_fraction`` is outside [0, 1].
    """
    if len(poses) < 2:
        raise ValueError("zoom trajectory requires at least two keyframe poses")
    if not 0.0 <= zoom_in_fraction <= 1.0:
        raise ValueError("zoom_in_fraction must lie in [0, 1]")

    p0 = poses[0][:3, -1]
    p1 = poses[1][:3, -1]
    center = (p0 + p1) / 2.0
    target = center + 0.15 * rot_weight * viewdir

    # Split the pose budget between the two legs.
    n_in = int(round(n_interp * zoom_in_fraction))
    n_out = n_interp - n_in

    poses_list = []
    # Zoom-in leg: p0 -> target, target excluded so the legs do not duplicate it.
    for t in np.linspace(0, 1, n_in, endpoint=False):
        pos = (1 - t) * p0 + t * target
        lookat = pos - viewdir
        poses_list.append(viewmatrix(lookat - pos, up_mod, pos))

    # Zoom-out leg: target -> p1, both ends included.
    for t in np.linspace(0, 1, n_out, endpoint=True):
        pos = (1 - t) * target + t * p1
        lookat = pos - viewdir
        poses_list.append(viewmatrix(lookat - pos, up_mod, pos))

    return np.stack(poses_list, axis=0)

def generate_falling_leaf_trajectory(poses, n_interp, rot_weight, viewdir, up_mod, num_sways=2.5):
    """Generate a falling-leaf camera path: steady forward drift plus sideways sway.

    With ``d`` the distance between the first two keyframe positions, the
    path works as follows:

    * Positions start at the keyframe midpoint.
    * They advance linearly along ``viewdir`` by ``d * rot_weight * 6 * 0.06``
      over the whole path.
    * They oscillate along ``normalize(cross(up_mod, viewdir))`` with
      amplitude ``d * rot_weight * 3 * 0.07`` through ``num_sways`` sine
      periods.
    * Every camera is aimed at the single point
      ``midpoint - viewdir * 10 * d``.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Number of poses to generate.
      rot_weight: Scales both the sway amplitude and the forward drift.
      viewdir: Unit viewing direction.
      up_mod: Up vector passed to ``viewmatrix``.
      num_sways: Number of full sine periods of sideways motion.

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given.
    """
    if poses.shape[0] < 2:
        raise ValueError("falling_leaf trajectory requires at least two keyframe poses")

    p0 = poses[0][:3, -1]
    p1 = poses[1][:3, -1]
    mid_point = (p0 + p1) / 2.0

    # Lateral axis for the sway.
    right_dir = np.cross(up_mod, viewdir)
    right_dir = normalize(right_dir)

    keyframe_distance = np.linalg.norm(p1 - p0)
    amplitude = keyframe_distance * rot_weight * 3 * 0.07
    forward_length = keyframe_distance * rot_weight * 6 * 0.06

    phases = np.linspace(0, 2 * np.pi * num_sways, n_interp)
    t = np.linspace(0, 1, n_interp)
    forward_displacement = t * forward_length

    positions = (
        mid_point
        + forward_displacement[:, None] * viewdir
        + amplitude * np.sin(phases)[:, None] * right_dir
    )

    # Common aim point far away along -viewdir. With coincident keyframes the
    # distance would be 0, so the aim point would equal the first position
    # and its look direction would be the zero vector. Fall back to a unit
    # distance in that case.
    look_distance = keyframe_distance * 10.0
    if look_distance < 1e-8:
        look_distance = 1.0
    lookat_point = mid_point - viewdir * look_distance

    poses_list = []
    for pos in positions:
        poses_list.append(viewmatrix(lookat_point - pos, up_mod, pos))

    return np.stack(poses_list, axis=0)

def generate_ellipse_trajectory(poses, n_interp, rot_weight, viewdir, up_mod, ellipse_factor,
                                num_loops=4):
    """Generate an elliptical camera path around the forward-shifted keyframes.

    Each of the first two keyframe positions is pushed ``0.2 * rot_weight``
    along its own viewing direction (``-poses[i][:3, 2]``). The ellipse is
    defined as follows:

    * Centre: the midpoint of the two shifted positions.
    * Major axis: along the line joining them, with semi-axis
      ``a = (|p1' - p0'| / 2) * ellipse_factor[0]``.
    * Minor axis: along ``cross(viewdir, major_dir)``, with semi-axis
      ``b = a * ellipse_factor[1]``.

    ``num_loops`` full revolutions are sampled with ``n_interp`` evenly spaced
    angles (end excluded). Every camera is aimed at ``center - viewdir``.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Number of poses to generate.
      rot_weight: Scales the forward push of the keyframes.
      viewdir: Unit viewing direction.
      up_mod: Up vector passed to ``viewmatrix``. Also used as the fallback
        minor-axis direction when ``viewdir`` is parallel to the major axis.
      ellipse_factor: ``(a_factor, b_factor)`` as described above.
      num_loops: Number of revolutions around the ellipse.

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given, or if the
        shifted keyframes coincide.
    """
    if poses.shape[0] < 2:
        raise ValueError("ellipse trajectory requires at least two keyframe poses")

    p0_orig = poses[0][:3, -1]
    p1_orig = poses[1][:3, -1]
    view0 = -poses[0][:3, 2]
    view1 = -poses[1][:3, 2]

    # Move each keyframe forward along its own view direction.
    p0_forward = p0_orig + 0.2 * rot_weight * view0
    p1_forward = p1_orig + 0.2 * rot_weight * view1
    center = (p0_forward + p1_forward) / 2.0

    axis_major = p1_forward - p0_forward
    axis_major_length = np.linalg.norm(axis_major)
    if axis_major_length < 1e-6:
        raise ValueError("ellipse trajectory: shifted keyframes coincide, cannot define the major axis")

    a = (axis_major_length / 2.0) * ellipse_factor[0]
    axis_major_dir = normalize(axis_major)

    # The minor axis is orthogonal to both viewdir and the major axis. If
    # those two are parallel, the cross product vanishes; use the component
    # of up_mod orthogonal to the major axis instead.
    minor_raw = np.cross(viewdir, axis_major_dir)
    if np.linalg.norm(minor_raw) < 1e-8:
        minor_raw = up_mod - np.dot(up_mod, axis_major_dir) * axis_major_dir
    axis_minor_dir = normalize(minor_raw)
    b = a * ellipse_factor[1]

    fixed_lookat = center - viewdir * 1

    poses_list = []
    angles = np.linspace(0, num_loops * 2 * np.pi, n_interp, endpoint=False)
    for ang in angles:
        pos = center + a * np.cos(ang) * axis_major_dir + b * np.sin(ang) * axis_minor_dir
        look_dir = fixed_lookat - pos
        poses_list.append(viewmatrix(look_dir, up_mod, pos))

    return np.stack(poses_list, axis=0)

def generate_spiral_trajectory(poses, n_interp, rot_weight, viewdir, up_mod):
    """Generate an inward spiral that advances along ``viewdir``.

    The spiral axis passes through the midpoint of the first two keyframe
    positions along ``viewdir``. The base radius is the mean distance of the
    keyframes from that axis, but at least ``0.1 * rot_weight``. With the
    parameter ``t`` running over ``[0, 0.8]``:

    * Height along the axis goes from ``0.05 * rot_weight`` upward at rate
      ``0.2 * rot_weight``.
    * The radius shrinks as ``base_radius * (0.12 - 0.09 * t)``.
    * The angle advances at 9 full turns per unit ``t``.

    Every camera is aimed at ``midpoint - viewdir * (0.2 * rot_weight + 5)``.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Number of poses to generate.
      rot_weight: Scales the height offsets and the minimum radius.
      viewdir: Unit spiral axis.
      up_mod: Up vector passed to ``viewmatrix``. Also used for the fallback
        in-plane axis.

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given.
    """
    if len(poses) < 2:
        raise ValueError("spiral trajectory requires at least two keyframe poses")

    p0 = poses[0][:3, -1]
    p1 = poses[1][:3, -1]
    mid_point = (p0 + p1) / 2.0

    # Offsets of the keyframes from the midpoint, with the viewdir component removed.
    vec0_proj = (p0 - mid_point) - np.dot(p0 - mid_point, viewdir) * viewdir
    vec1_proj = (p1 - mid_point) - np.dot(p1 - mid_point, viewdir) * viewdir
    base_radius = max((np.linalg.norm(vec0_proj) + np.linalg.norm(vec1_proj)) / 2.0, 0.1 * rot_weight)

    # In-plane basis. Prefer the direction towards keyframe 0.
    if np.linalg.norm(vec0_proj) > 1e-3:
        x_axis = normalize(vec0_proj)
    else:
        x_axis = normalize(np.cross(up_mod, viewdir))
    y_axis = normalize(np.cross(viewdir, x_axis))

    start_offset = rot_weight * 0.05
    total_height = rot_weight * 0.2
    num_rotations = 9

    def radius_curve(t):
        # Radius shrinks linearly with t.
        return base_radius * (0.12 - 0.09 * t)

    t_vals = np.linspace(0, 0.8, n_interp)
    heights = start_offset + t_vals * total_height
    angles = 2 * np.pi * num_rotations * t_vals

    focus_point = mid_point - viewdir * (total_height + 5.0)

    poses_list = []
    for t, height, angle in zip(t_vals, heights, angles):
        current_radius = radius_curve(t)
        spiral_x = current_radius * np.cos(angle)
        spiral_y = current_radius * np.sin(angle)
        position = mid_point + viewdir * height + spiral_x * x_axis + spiral_y * y_axis
        poses_list.append(viewmatrix(focus_point - position, up_mod, position))

    return np.stack(poses_list, axis=0)

def generate_parabolic_trajectory(poses, n_interp, rot_weight, viewdir, up_mod, control_offset=0.5):
    """Generate a quadratic Bezier camera path from keyframe 0 to keyframe 1.

    The Bezier control point is the midpoint of the first two keyframe
    positions displaced by ``control_offset`` along ``viewdir``, so the path
    bows in that direction. ``n_interp`` samples are taken with ``t`` in
    [0, 1], both ends included. Every pose is built with look direction
    ``-viewdir`` and up vector ``up_mod``.

    Args:
      poses: Keyframe poses. Only ``poses[0]`` and ``poses[1]`` are used.
      n_interp: Number of poses to generate.
      rot_weight: Not used. It is accepted so the signature matches the
        other generators.
      viewdir: Unit viewing direction.
      up_mod: Up vector passed to ``viewmatrix``.
      control_offset: Distance of the control point from the keyframe
        midpoint along ``viewdir``.

    Returns:
      ``n_interp`` camera matrices stacked along axis 0.

    Raises:
      ValueError: If fewer than two keyframe poses are given.
    """
    if len(poses) < 2:
        raise ValueError("parabolic trajectory requires at least two keyframe poses")

    p0 = poses[0][:3, -1]
    p1 = poses[1][:3, -1]
    control_point = (p0 + p1) / 2.0 + control_offset * viewdir

    poses_list = []
    for t in np.linspace(0, 1, n_interp):
        # Quadratic Bernstein blend of p0, control_point, p1.
        pos = (1 - t) ** 2 * p0 + 2 * (1 - t) * t * control_point + t ** 2 * p1
        lookat = pos - viewdir
        poses_list.append(viewmatrix(lookat - pos, up_mod, pos))
    return np.stack(poses_list, axis=0)

def generate_spline_trajectory(poses, n_interp, rot_weight, spline_degree, smoothness):
    """Resample the keyframes along a fitted spline.

    The keyframes are converted to points with ``poses_to_points`` (passing
    ``rot_weight`` as ``dist``). The points are resampled with ``interp`` at
    ``n_interp`` samples per gap between consecutive keyframes, using spline
    degree ``k=spline_degree`` and smoothing ``s=smoothness``. The result is
    converted back with ``points_to_poses``.
    """
    keyframe_points = poses_to_points(poses, dist=rot_weight)
    gap_count = keyframe_points.shape[0] - 1
    sample_count = n_interp * gap_count
    resampled_points = interp(
        keyframe_points,
        sample_count,
        k=spline_degree,
        s=smoothness,
    )
    return points_to_poses(resampled_points)