import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Tuple
from utils.stepfun import sample_np, sample
import scipy


def quad2rotation(q):
    """
    Convert a batch of quaternions to rotation matrices.

    Every step is a differentiable torch op, so gradients propagate back to
    ``q``. The quaternions are normalized first, so non-unit inputs are
    accepted.

    Args:
        q (tensor, batch_size*4): quaternions stored real part first, i.e.
            (r, x, y, z). Non-tensor inputs are converted with
            ``torch.tensor`` and moved to the GPU with ``.cuda()``.

    Returns:
        rot_mat (tensor, batch_size*3*3): rotation matrices with the dtype and
            device of the normalized quaternions.
    """
    if not isinstance(q, torch.Tensor):
        q = torch.tensor(q).cuda()

    length = torch.sqrt(
        q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]
    )
    unit = q / length[:, None]
    r, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    # Row-major entries of the standard unit-quaternion rotation matrix.
    entries = (
        1 - 2 * (y * y + z * z), 2 * (x * y - r * z), 2 * (x * z + r * y),
        2 * (x * y + r * z), 1 - 2 * (x * x + z * z), 2 * (y * z - r * x),
        2 * (x * z - r * y), 2 * (y * z + r * x), 1 - 2 * (x * x + y * y),
    )
    return torch.stack(entries, dim=-1).reshape(unit.shape[0], 3, 3)

def get_camera_from_tensor(inputs):
    """
    Assemble a 4x4 pose matrix from a packed quaternion + translation.

    Args:
        inputs: a 7-element vector (or a tensor of shape (1, 7)) laid out as
            [qr, qx, qy, qz, tx, ty, tz], quaternion real part first.
            Non-tensor inputs are converted with ``torch.tensor`` and moved to
            the GPU with ``.cuda()``.

    Returns:
        (4, 4) float32 tensor on the device of ``inputs``. The top-left 3x3
        block is ``quad2rotation`` of the quaternion, the first three rows of
        the last column are the translation, and the bottom row is
        [0, 0, 0, 1]. The matrix is presumably world-to-camera; TODO confirm
        with callers.

    Note:
        Only a single pose is supported. A batch of more than one quaternion
        cannot be written into the 3x3 slice.
    """
    if not isinstance(inputs, torch.Tensor):
        inputs = torch.tensor(inputs).cuda()
    if inputs.dim() == 1:
        inputs = inputs[None]

    quat, trans = inputs[:, :4], inputs[:, 4:]
    pose = torch.eye(4).to(inputs).float()
    pose[:3, :3] = quad2rotation(quat)
    pose[:3, 3] = trans
    return pose

def quadmultiply(q1, q2):
    """
    Hamilton product ``q1 * q2`` of quaternions stored real part first (w, x, y, z).

    The components are unbound along the last dimension (which must have size
    4) and recombined elementwise, so any leading batch shape is supported.
    """
    a0, a1, a2, a3 = q1.unbind(dim=-1)
    b0, b1, b2, b3 = q2.unbind(dim=-1)

    real = a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3
    i_part = a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2
    j_part = a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1
    k_part = a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0

    return torch.stack((real, i_part, j_part, k_part), dim=-1)

def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Elementwise ``sqrt(max(0, x))`` with a zero subgradient where ``x <= 0``.

    Only the strictly positive entries are passed through ``sqrt``. The rest
    stay at zero, so non-positive inputs never produce NaN values or gradients.
    Adapted from PyTorch3D's ``rotation_conversions`` module.
    """
    keep = x > 0
    out = torch.zeros_like(x)
    out[keep] = x[keep].sqrt()
    return out

def rotation2quad(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Adapted from PyTorch3D's ``matrix_to_quaternion``
    (``pytorch3d/transforms/rotation_conversions.py``).

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3). Non-tensor
            inputs (e.g. NumPy arrays or nested lists) are converted with
            ``torch.tensor`` and moved to the GPU with ``.cuda()``.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if ``matrix`` has fewer than two dimensions or its trailing
            two dimensions are not (3, 3).
    """
    # Convert before validating. The shape check calls ``Tensor.size(dim)``,
    # but a NumPy array's ``size`` is an int attribute. Checking first made the
    # conversion branch unreachable: it raised TypeError instead.
    if not isinstance(matrix, torch.Tensor):
        matrix = torch.tensor(matrix).cuda()

    if matrix.dim() < 2 or matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # q_abs[..., k] = 2 * |q_k| for k in (r, i, j, k), using the identities
    # 4 * q_k^2 = 1 +/- m00 +/- m11 +/- m22 for a unit quaternion.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row k is the quaternion scaled by 4 * q_k. Its diagonal entry is
    # 4 * q_k^2 == q_abs_k^2; the off-diagonal entries come from sums and
    # differences of symmetric matrix entries.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Dividing row k by 2 * q_abs_k (= 4 * |q_k|) recovers the quaternion up to
    # sign. The 0.1 floor only guards rows with tiny q_abs. Those rows are
    # never selected below: the four q_abs^2 values sum to 4, so the largest
    # q_abs is at least 1 for a valid rotation.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Keep the best-conditioned candidate: the row with the largest q_abs.
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))


def get_tensor_from_camera(RT, Tquad=False):
    """
    Pack a camera matrix's rotation and translation into one 7-vector.

    Args:
        RT: matrix whose top-left 3x3 block is the rotation and whose first
            three rows of column 3 are the translation. Only ``RT[:3, :3]`` and
            ``RT[:3, 3]`` are read. Non-tensor inputs are converted with
            ``torch.tensor`` and moved to the GPU with ``.cuda()``.
        Tquad: unused; accepted only so existing call sites keep working.

    Returns:
        1-D tensor [qr, qx, qy, qz, tx, ty, tz] (quaternion real part first),
        detached from the autograd graph.
    """
    if not isinstance(RT, torch.Tensor):
        RT = torch.tensor(RT).cuda()

    rotation_batch = RT[:3, :3].detach()[None]
    quaternion = rotation2quad(rotation_batch).squeeze()
    translation = RT[:3, 3].detach()
    return torch.cat((quaternion, translation))

def normalize(x):
    """Return ``x`` scaled to unit norm, using the norm from ``np.linalg.norm``."""
    length = np.linalg.norm(x)
    return x / length


def viewmatrix(lookdir, up, position, subtract_position=False):
    """Build a 3x4 look-at matrix whose columns are [x_axis, y_axis, z_axis, position].

    The z axis is the normalized ``lookdir``, or ``lookdir - position`` when
    ``subtract_position`` is set. The x axis is the normalized ``up x z`` and
    the y axis is the normalized ``z x x``, so the three axis columns are
    mutually orthogonal unit vectors for non-degenerate inputs.
    """
    def unit(v):
        return v / np.linalg.norm(v)

    forward = (lookdir - position) if subtract_position else lookdir
    z_axis = unit(forward)
    x_axis = unit(np.cross(up, z_axis))
    y_axis = unit(np.cross(z_axis, x_axis))
    return np.stack((x_axis, y_axis, z_axis, position), axis=1)


def poses_avg(poses):
    """Average pose: a ``viewmatrix`` built from the mean position, mean z axis and mean y axis.

    Args:
      poses: array of pose matrices indexed as ``poses[:, :3, col]``, where
        column 1 is the y axis, column 2 the z axis and column 3 the position.

    Returns:
      3x4 matrix from ``viewmatrix``.
    """
    mean_position = poses[:, :3, 3].mean(0)
    mean_z, mean_up = (poses[:, :3, col].mean(0) for col in (2, 1))
    return viewmatrix(mean_z, mean_up, mean_position)


def focus_point_fn(poses):
    """Least-squares point nearest to every camera's optical axis.

    Each pose contributes the line through its position (column 3) along its
    z axis (column 2). ``I - d d^T`` removes the component along ``d``, which
    is a projection when ``d`` is unit length (assumed). The returned point
    solves the normal equations ``mean(M^T M) p = mean(M^T M o)``.
    """
    axes = poses[:, :3, 2:3]
    centers = poses[:, :3, 3:4]
    perp = np.eye(3) - axes * np.transpose(axes, [0, 2, 1])
    normal = np.transpose(perp, [0, 2, 1]) @ perp
    lhs_inv = np.linalg.inv(normal.mean(0))
    rhs = (normal @ centers).mean(0)[:, 0]
    return lhs_inv @ rhs


def pad_poses(p):
    """Append the homogeneous row [0, 0, 0, 1] to [..., 3, 4] poses, giving [..., 4, 4].

    Only the first three rows and four columns of ``p`` are kept, so an
    already padded [..., 4, 4] input is re-padded rather than grown.
    """
    top = p[..., :3, :4]
    last_row = np.broadcast_to(np.array((0.0, 0.0, 0.0, 1.0)), p[..., :1, :4].shape)
    return np.concatenate((top, last_row), axis=-2)

def unpad_poses(p):
    """Drop the homogeneous row: keep the top-left 3x4 block of each [..., 4, 4] matrix."""
    first_three_rows = slice(None, 3)
    first_four_cols = slice(None, 4)
    return p[..., first_three_rows, first_four_cols]

def transform_poses_pca(poses):
    """Rotate, recenter and rescale poses so their principal axes lie on XYZ.

    Args:
      poses: (N, 3, 4) camera-to-world matrices. Only the top 3x4 of each
        matrix is read, so (N, 4, 4) also works.

    Returns:
      A tuple (poses_recentered, transform).
      - poses_recentered: (N, 3, 4) transformed poses. Their axis columns are
        rotated but keep their length; only positions are scaled.
      - transform: 4x4 matrix combining rotation, translation and uniform
        scale, which maps input world points into the new frame.
    """
    centers = poses[:, :3, 3]
    mean_center = centers.mean(axis=0)
    centered = centers - mean_center

    # Principal axes of the camera positions, largest variance first.
    evals, evecs = np.linalg.eig(centered.T @ centered)
    order = np.argsort(evals)[::-1]
    rot = evecs[:, order].T
    # Flip the last axis if needed so that rot is a proper rotation (det > 0).
    if np.linalg.det(rot) < 0:
        rot = np.diag(np.array([1, 1, -1])) @ rot

    transform = np.concatenate([rot, rot @ -mean_center[:, None]], -1)
    aligned = unpad_poses(transform @ pad_poses(poses))
    transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)

    # If the mean camera y-axis column points toward -z, rotate 180 degrees about x.
    if aligned.mean(axis=0)[2, 1] < 0:
        aligned = np.diag(np.array([1, -1, -1])) @ aligned
        transform = np.diag(np.array([1, -1, -1, 1])) @ transform

    # Uniformly scale so the largest absolute position coordinate becomes 1.
    scale = 1. / np.max(np.abs(aligned[:, :3, 3]))
    aligned[:, :3, 3] *= scale
    transform = np.diag(np.array([scale] * 3 + [1])) @ transform
    return aligned, transform


def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Express poses relative to their average pose.

    Returns:
      (recentered (N, 3, 4) poses, 4x4 transform). The transform is the
      inverse of the padded ``poses_avg`` matrix and was applied on the left.
    """
    avg_pose = pad_poses(poses_avg(poses))
    world_to_avg = np.linalg.inv(avg_pose)
    return unpad_poses(world_to_avg @ pad_poses(poses)), world_to_avg

def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
    """Render path on an ellipse around the scene's focus point.

    The ellipse is built in the PCA-aligned frame from ``transform_poses_pca``
    and mapped back afterwards.

    Args:
        views: cameras exposing ``R`` (3x3) and ``T`` (3,) arrays. For each
            view, the 4x4 matrix [R.T | T] is inverted and its y and z columns
            are negated.
        n_frames: number of poses on the path.
        const_speed: resample the ellipse parameter with ``sample_np`` so
            consecutive positions are roughly evenly spaced.
        z_variation: multiplier on the height oscillation (0 keeps z at 0).
        z_phase: phase offset of the height oscillation, as a fraction of a turn.

    Returns:
        list of 4x4 matrices, one per position. Each render pose is mapped back
        through ``inv(transform)``, has its y and z columns negated, and is
        then inverted.
    """
    def view_to_pose(view):
        pose = np.eye(4)
        pose[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
        pose = np.linalg.inv(pose)
        pose[:, 1:3] *= -1
        return pose

    poses, transform = transform_poses_pca(np.stack([view_to_pose(v) for v in views], 0))

    # Ellipse center: the focus point projected onto z = 0.
    center = focus_point_fn(poses)
    offset = np.array([center[0], center[1], 0])
    cam_positions = poses[:, :3, 3]
    # Semi-axes: 90th percentile of each coordinate's distance from the center.
    sc = np.percentile(np.abs(cam_positions - offset), 90, axis=0)
    low = -sc + offset
    high = sc + offset
    # Height range for the optional z oscillation.
    z_low = np.percentile(cam_positions, 10, axis=0)
    z_high = np.percentile(cam_positions, 90, axis=0)

    span = high - low
    z_span = z_high - z_low

    def ellipse(theta):
        xs = low[0] + span[0] * (np.cos(theta) * .5 + .5)
        ys = low[1] + span[1] * (np.sin(theta) * .5 + .5)
        zs = z_variation * (z_low[2] + z_span[2] *
                            (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5))
        return np.stack([xs, ys, zs], -1)

    theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
    positions = ellipse(theta)
    if const_speed:
        seg_lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        theta = sample_np(None, theta, np.log(seg_lengths), n_frames + 1)
        positions = ellipse(theta)
    # n_frames + 1 samples were drawn; drop the last one (with the linspace,
    # theta = 2*pi duplicates theta = 0).
    positions = positions[:-1]

    # Up vector: the coordinate axis closest to the mean camera y axis.
    mean_up = poses[:, :3, 1].mean(0)
    mean_up = mean_up / np.linalg.norm(mean_up)
    up_axis = np.argmax(np.abs(mean_up))
    up = np.eye(3)[up_axis] * np.sign(mean_up[up_axis])

    inv_transform = np.linalg.inv(transform)  # loop invariant
    render_poses = []
    for p in positions:
        pose = np.eye(4)
        pose[:3] = viewmatrix(p - center, up, p)
        pose = inv_transform @ pose
        pose[:3, 1:3] *= -1
        render_poses.append(np.linalg.inv(pose))
    return render_poses



def generate_spiral_path(poses_arr,
                         n_frames: int = 180,
                         n_rots: int = 2,
                         zrate: float = .5) -> np.ndarray:
    """Calculate a forward-facing spiral path for rendering.

    Args:
        poses_arr: (N, 17) array. Columns ``[:-2]`` reshape to (N, 3, 5) pose
            blocks, of which the first four columns are used. The last two
            columns are depth bounds (presumably the LLFF ``poses_bounds``
            layout; confirm with callers). ``poses_arr`` is not modified.
        n_frames: number of poses on the path.
        n_rots: number of full turns of the spiral.
        zrate: frequency of the depth oscillation relative to the rotation.

    Returns:
        (n_frames, 4, 4) array. Each render pose is mapped back through the
        recentering, has its y and z columns negated, has its translation
        divided by the rescale factor, and is then inverted.
    """
    poses = poses_arr[:, :-2].reshape([-1, 3, 5])
    # Copy: the slice is a view into ``poses_arr`` and is rescaled in place
    # below. Without the copy, the caller's bounds would be silently modified.
    bounds = poses_arr[:, -2:].copy()
    # Permute the camera axes: new x = old y, new y = -old x, z unchanged.
    fix_rotation = np.array([
        [0, -1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    poses = poses[:, :3, :4] @ fix_rotation

    # Rescale so the smallest bound becomes 1 / 0.75, then recenter on the average pose.
    scale = 1. / (bounds.min() * .75)
    poses[:, :3, 3] *= scale
    bounds *= scale
    poses, transform = recenter_poses(poses)

    # Focus depth: a weighted average of near and far bounds in inverse-depth space.
    close_depth, inf_depth = bounds.min() * .9, bounds.max() * 5.
    dt = .75
    focal = 1 / (((1 - dt) / close_depth + dt / inf_depth))

    # Spiral radii: 90th percentile of |camera position| per axis, plus a homogeneous 1.
    positions = poses[:, :3, 3]
    radii = np.percentile(np.abs(positions), 90, 0)
    radii = np.concatenate([radii, [1.]])

    cam2world = poses_avg(poses)
    up = poses[:, :3, 1].mean(0)
    # Loop invariants.
    lookat = cam2world @ [0, 0, -focal, 1.]
    inv_transform = np.linalg.inv(transform)

    render_poses = []
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
        t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
        position = cam2world @ t
        z_axis = position - lookat
        render_pose = np.eye(4)
        render_pose[:3] = viewmatrix(z_axis, up, position)
        # Undo the recentering, negate y/z, then undo the rescale.
        render_pose = inv_transform @ render_pose
        render_pose[:3, 1:3] *= -1
        render_pose[:3, 3] /= scale
        render_poses.append(np.linalg.inv(render_pose))
    return np.stack(render_poses, axis=0)



def generate_interpolated_path(
    views,
    n_interp,
    spline_degree=5,
    smoothness=0.03,
    rot_weight=0.1,
    lock_up=False,
    fixed_up_vector=None,
    lookahead_i=None,
    frames_per_colmap=None,
    const_speed=False,
    n_buffer=None,
    periodic=False,
    n_interp_as_total=False,
):
    """Creates a smooth spline path between input keyframe camera poses.

    Spline is calculated with poses in format (position, lookat-point, up-point).

    Args:
      views: keyframe cameras exposing ``R`` (3x3) and ``T`` (3,) arrays. For
        each view, the 4x4 matrix [R.T | T] is inverted and its y and z
        columns are negated.
      n_interp: frames per keyframe interval (or the total frame count when
        ``n_interp_as_total`` is set).
      spline_degree: polynomial degree of the B-spline, capped at n - 1.
      smoothness: parameter for spline smoothing; 0 forces exact interpolation.
      rot_weight: distance of the look-at and up points from the position,
        i.e. the relative weighting of rotation vs translation in the spline fit.
      lock_up: accepted for interface compatibility; currently ignored.
      fixed_up_vector: replace the interpolated `up` with a fixed vector.
      lookahead_i: look at the position ``lookahead_i`` frames ahead. Once that
        runs past the end of the path, the previous direction is kept.
      frames_per_colmap: conversion factor for the desired average velocity.
        When set, the frame count is recomputed from the path length.
      const_speed: resample the spline parameter so consecutive poses are
        nearly evenly spaced.
      n_buffer: number of buffer poses to insert at the start and end of the
        path, which helps keep the ends straight. ``None`` or 0 disables it.
      periodic: make the spline path periodic (perfect loop).
      n_interp_as_total: use n_interp as the total number of poses in the path
        rather than the number of poses to interpolate between each input.

    Returns:
      Array of 3x4 matrices from ``viewmatrix`` ([x, y, z, position] columns).
      They are not converted back to the input convention. The spline is
      evaluated at n_frames parameters: n_interp * (n - 1) for n keyframes
      (including buffer poses), or n_interp + 1 with ``n_interp_as_total``.
      Buffer removal and ``frames_per_colmap`` can change that count, and
      ``const_speed`` resamples n_frames + 1 parameters. The last pose is
      always dropped.
    """
    # ``import scipy`` alone does not load the ``interpolate`` subpackage on
    # older SciPy releases, so import it explicitly.
    from scipy import interpolate

    def view_to_pose(view):
        """4x4 pose from a view: invert [R.T | T], then negate the y/z columns."""
        pose = np.eye(4)
        pose[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
        pose = np.linalg.inv(pose)
        pose[:, 1:3] *= -1
        return pose

    poses = np.stack([view_to_pose(view) for view in views], 0)

    def poses_to_points(poses, dist):
        """Converts from pose matrices to (position, lookat, up) format."""
        pos = poses[:, :3, -1]
        lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
        up = poses[:, :3, -1] + dist * poses[:, :3, 1]
        return np.stack([pos, lookat, up], 1)

    def points_to_poses(points):
        """Converts from (position, lookat, up) format to pose matrices."""
        poses = []
        lookat = None
        for i, (pos, lookat_point, up_point) in enumerate(points):
            if lookahead_i is not None and i + lookahead_i < len(points):
                lookat = pos - points[i + lookahead_i][0]
            elif lookahead_i is None or lookat is None:
                # Either look-ahead is off, or its target lies past the end
                # before any direction was set (this used to raise
                # UnboundLocalError).
                lookat = pos - lookat_point
            # Otherwise the look-ahead target is past the end: keep the
            # previous direction.
            up = (up_point - pos) if fixed_up_vector is None else fixed_up_vector
            poses.append(viewmatrix(lookat, up, pos))
        return np.array(poses)

    def insert_buffer_poses(poses, n_buffer):
        """Insert extra poses at the start and end of the path."""

        def average_distance(points):
            distances = np.linalg.norm(points[1:] - points[0:-1], axis=-1)
            return np.mean(distances)

        def shift(pose, dz):
            result = np.copy(pose)
            z = result[:3, 2]
            z /= np.linalg.norm(z)
            # Move along the z axis. The look-at point sits at pos - dist * z,
            # so -z is forward.
            result[:3, 3] += z * dz
            return result

        dz = average_distance(poses[:, :3, 3])
        prefix = np.stack([shift(poses[0], (i + 1) * dz) for i in range(n_buffer)])
        prefix = prefix[::-1]  # reverse order so the path approaches poses[0]
        suffix = np.stack(
            [shift(poses[-1], -(i + 1) * dz) for i in range(n_buffer)]
        )
        return np.concatenate([prefix, poses, suffix])

    def remove_buffer_poses(poses, u, n_frames, u_keyframes, n_buffer):
        """Drop frames whose spline parameter lies outside the real keyframes."""
        u_keyframes = u_keyframes[n_buffer:-n_buffer]
        mask = (u >= u_keyframes[0]) & (u <= u_keyframes[-1])
        poses = poses[mask]
        u = u[mask]
        n_frames = len(poses)
        return poses, u, n_frames, u_keyframes

    def interp(points, u, k, s):
        """Runs multidimensional B-spline interpolation on the input points."""
        sh = points.shape
        pts = np.reshape(points, (sh[0], -1))
        k = min(k, sh[0] - 1)
        tck, u_keyframes = interpolate.splprep(pts.T, k=k, s=s, per=periodic)
        new_points = np.array(interpolate.splev(u, tck))
        new_points = np.reshape(new_points.T, (len(u), sh[1], sh[2]))
        return new_points, u_keyframes

    # ``None`` or 0 disables buffering. With 0, np.stack([]) and the
    # u_keyframes[0:-0] slice used to fail.
    use_buffer = bool(n_buffer)
    if use_buffer:
        poses = insert_buffer_poses(poses, n_buffer)
    points = poses_to_points(poses, dist=rot_weight)
    if n_interp_as_total:
        n_frames = n_interp + 1  # one extra, since the final pose is dropped
    else:
        n_frames = n_interp * (points.shape[0] - 1)
    u = np.linspace(0, 1, n_frames, endpoint=True)
    new_points, u_keyframes = interp(points, u=u, k=spline_degree, s=smoothness)
    poses = points_to_poses(new_points)
    if use_buffer:
        poses, u, n_frames, u_keyframes = remove_buffer_poses(
            poses, u, n_frames, u_keyframes, n_buffer
        )

    if frames_per_colmap is not None:
        # Recompute the frame count to reach the requested average velocity.
        positions = poses[:, :3, -1]
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        total_length_colmap = lengths.sum()
        print('old n_frames:', n_frames)
        print('total_length_colmap:', total_length_colmap)
        n_frames = int(total_length_colmap * frames_per_colmap)
        print('new n_frames:', n_frames)
        u = np.linspace(
            np.min(u_keyframes), np.max(u_keyframes), n_frames, endpoint=True
        )
        new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
        poses = points_to_poses(new_points)

    if const_speed:
        # Resample the spline parameter so consecutive poses are nearly evenly spaced.
        positions = poses[:, :3, -1]
        lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        # ``u`` and ``lengths`` are NumPy arrays, so use sample_np, the sampler
        # generate_ellipse_path uses for the same kind of inputs.
        u = sample_np(None, u, np.log(lengths), n_frames + 1)
        new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
        poses = points_to_poses(new_points)

    return poses[:-1]
