# Original code from nerfmm: https://github.com/ActiveVisionLab/nerfmm/blob/main/utils/lie_group_helper.py

import numpy as np
import torch
from scipy.spatial.transform import Rotation as RotLib

def replace_repeated_indices_with_random(selected_points_index, all_index_number):
    """Make an index selection duplicate-free by swapping repeats for unused indices.

    The first occurrence of every index stays in place; each later occurrence is
    overwritten with a distinct index from ``range(all_index_number)`` that does
    not appear anywhere in the selection, drawn in random order (``torch.randperm``).

    :param selected_points_index:  integer torch tensor of any shape (flattened here)
    :param all_index_number:       size of the index universe; valid indices are
                                   ``0 .. all_index_number - 1``
    :return:                       1D tensor of the same length with no repeated
                                   entries; the input tensor is left untouched
    :raises ValueError:            if there are fewer unused indices than repeats
    """
    # Work on a private flat copy: ``flatten`` may return the caller's tensor itself
    # (e.g. for 1D input), and the replacements below write in place.
    selected_points_index = selected_points_index.flatten().clone()
    device = selected_points_index.device

    # Indices never chosen are the candidates for replacing repeats. The mask lives on
    # the same device as the indices so GPU index tensors can be used to index it.
    mask = torch.ones(all_index_number, dtype=torch.bool, device=device)
    mask[selected_points_index] = False
    unused_indices = torch.nonzero(mask).view(-1)

    unique_indices, counts = torch.unique(selected_points_index, return_counts=True)
    if len(unique_indices) == len(selected_points_index):
        return selected_points_index  # already duplicate-free

    repeated_indices = unique_indices[counts > 1]
    # Shuffle so that repeats are replaced by random unused indices.
    unused_indices = unused_indices[torch.randperm(len(unused_indices), device=device)]
    next_unused = 0

    for index in repeated_indices:
        repeat_positions = (selected_points_index == index).nonzero(as_tuple=False).view(-1)
        # Keep the first occurrence, overwrite the remaining ones.
        for repeat_pos in repeat_positions[1:]:
            if next_unused >= len(unused_indices):
                raise ValueError("Not enough unused indices to replace repeats.")
            selected_points_index[repeat_pos] = unused_indices[next_unused]
            next_unused += 1

    return selected_points_index

def _farthest_point_sampling(points, selected_indices, num_new):
    """Greedily append ``num_new`` farthest-point picks to ``selected_indices``.

    Each new pick is the point whose distance to its nearest already-selected point
    is largest (ties go to the first such index, as with ``torch.argmax``).

    :param points:            (N, D) torch tensor
    :param selected_indices:  (M, ) int64 tensor of seed indices into ``points``
    :param num_new:           number of indices to append; <= 0 appends nothing
    :return:                  (M + max(num_new, 0), ) int64 tensor
    :raises ValueError:       if picks are requested but there are no seeds
    """
    if num_new <= 0:
        return selected_indices
    if selected_indices.numel() == 0:
        raise ValueError("Farthest point sampling needs at least one seed index.")

    def dist_to(index):
        return torch.norm(points - points[index], dim=1)

    # min_dist[j] = distance from point j to its nearest selected point. Updating it
    # incrementally avoids recomputing all N x M distances in Python per pick.
    min_dist = dist_to(selected_indices[0])
    for seed in selected_indices[1:]:
        min_dist = torch.minimum(min_dist, dist_to(seed))

    picks = [selected_indices]
    for _ in range(num_new):
        farthest = torch.argmax(min_dist)
        picks.append(farthest.view(1))
        min_dist = torch.minimum(min_dist, dist_to(farthest))
    return torch.cat(picks)


def select_k_uniform_points(points, k):
    """Pick ``k`` well-spread indices: one random seed plus farthest-point sampling.

    :param points:  (N, D) torch tensor
    :param k:       number of indices; values <= 1 return just the random seed
    :return:        (max(k, 1), ) int64 tensor on ``points.device``
    """
    seed = torch.randint(len(points), [1]).to(points.device)
    return _farthest_point_sampling(points, seed, k - 1)


def select_m_random_k_uniform_points(points, m, k):
    """Pick ``m`` random indices, then ``k`` more by farthest-point sampling.

    :param points:  (N, D) torch tensor
    :param m:       number of random seed indices (distinct, via ``torch.randperm``)
    :param k:       number of farthest-point indices appended after the seeds
    :return:        (min(m, N) + max(k, 0), ) int64 tensor on ``points.device``
    """
    seeds = torch.randperm(len(points))[:m].to(points.device)
    return _farthest_point_sampling(points, seeds, k)

def select_k_uniform_points_old(points, k):
    """Legacy selector: pick ``k`` points spread evenly in azimuth around one cone.

    Rows of ``points`` are scaled to unit length. The point whose polar angle (angle
    from +z, computed as ``acos(z)``) is closest to 55 degrees becomes the base point.
    ``k`` reference directions share the base point's polar angle and have evenly
    spaced azimuths. The nearest input point to each reference is selected, and
    repeated picks are swapped for random unused indices.

    :param points:  (N, 3) torch tensor; rows with zero norm become NaN when scaled
    :param k:       number of indices to select (k >= 1)
    :return:        (k, ) index tensor into ``points`` without repeats
    """
    points = points / points.norm(dim=1, keepdim=True)

    def select_base_point_closest_to_altitude(points, target_altitude_degrees=55):
        # "Altitude" here is the polar angle measured from +z: acos(z) for unit vectors.
        target_altitude = torch.tensor(target_altitude_degrees * torch.pi / 180).to(points.device)
        # Clamp: after normalization round-off can give |z| slightly above 1 (acos -> NaN).
        altitudes = torch.acos(points[:, 2].clamp(-1.0, 1.0))
        # Index of the point whose polar angle is closest to the target (55 deg by default).
        idx_closest = torch.argmin(torch.abs(altitudes - target_altitude))
        return points[idx_closest], idx_closest

    def generate_uniform_reference_points(base_point, k):
        # base_point is a unit [x, y, z] vector; keep its polar angle, vary the azimuth.
        altitude = torch.acos(base_point[2].clamp(-1.0, 1.0))
        # k evenly spaced azimuths in [0, 2*pi). The endpoint 2*pi is excluded because it
        # coincides with 0; linspace(0, 2*pi, k) would always yield a duplicate direction.
        azimuths = torch.linspace(0, 2 * torch.pi, steps=k + 1)[:-1].to(points.device)

        # Spherical -> Cartesian; equal z keeps every reference on the same cone.
        xs = torch.sin(altitude) * torch.cos(azimuths)
        ys = torch.sin(altitude) * torch.sin(azimuths)
        zs = torch.full_like(xs, base_point[2])
        return torch.stack([xs, ys, zs], dim=1)

    def find_closest_points(original_points, reference_points):
        closest_points = []
        for ref_point in reference_points:
            # Squared distances suffice for argmin (no square root needed).
            distances = torch.sum((original_points - ref_point) ** 2, dim=1)
            closest_points.append(torch.argmin(distances))
        return torch.stack(closest_points)

    base_point, base_point_idx = select_base_point_closest_to_altitude(points)
    reference_points = generate_uniform_reference_points(base_point, k)
    selected_points_index = find_closest_points(points, reference_points)
    # Several references may snap to the same input point; make the result unique.
    selected_points_index = replace_repeated_indices_with_random(selected_points_index, all_index_number=len(points))
    return selected_points_index


def lookat_to_c2w(target, position, up_axis=(0, 0, 1)):
    """Build a (3, 4) camera-to-world matrix for a camera at ``position`` looking at ``target``.

    The rotation columns are the camera right, up and -front axes, so the camera looks
    down its local -z axis. The last column is ``position``.

    :param target:    (3, ) torch tensor, point the camera looks at
    :param position:  (3, ) torch tensor, camera center in world coordinates
    :param up_axis:   world up direction as a 3-sequence or tensor (default +z)
    :return:          (3, 4) torch tensor [R | position], on ``target``'s device/dtype

    NOTE: if the viewing direction is parallel to ``up_axis`` the cross product is zero
    and the right/up axes degenerate to zero vectors.
    """
    # Follow the inputs' device/dtype instead of assuming CUDA.
    up_axis = torch.as_tensor(up_axis, dtype=target.dtype, device=target.device)
    front = target - position
    right = torch.linalg.cross(front, up_axis)
    up = torch.linalg.cross(right, front)

    front = torch.nn.functional.normalize(front, dim=0)
    right = torch.nn.functional.normalize(right, dim=0)
    up = torch.nn.functional.normalize(up, dim=0)

    R = torch.vstack([right, up, -front]).T  # columns: right, up, -front
    c2w = torch.cat([R, position.unsqueeze(1)], dim=1)  # (3, 4)
    return c2w

def SO3_to_quat(R):
    """Convert rotation matrices to quaternions via SciPy.

    The quaternion layout is SciPy's default ``as_quat`` order (scalar last: x, y, z, w).

    :param R:  (N, 3, 3) or (3, 3) np
    :return:   (N, 4, ) or (4, ) np
    """
    return RotLib.from_matrix(R).as_quat()


def quat_to_SO3(quat):
    """Convert quaternions (SciPy order x, y, z, w) to rotation matrices via SciPy.

    :param quat:    (N, 4, ) or (4, ) np
    :return:        (N, 3, 3) or (3, 3) np
    """
    return RotLib.from_quat(quat).as_matrix()


def convert3x4_4x4(input):
    """Append the homogeneous bottom row [0, 0, 0, 1] to 3x4 pose matrices.

    Handles a single matrix or a batch, as a torch tensor or a numpy array; the
    result keeps the input's dtype (and device, for tensors).

    :param input:  (N, 3, 4) or (3, 4) torch or np
    :return:       (N, 4, 4) or (4, 4) torch or np
    """
    batched = len(input.shape) == 3

    if not torch.is_tensor(input):
        if not batched:
            last_row = np.array([[0, 0, 0, 1]], dtype=input.dtype)
            return np.concatenate([input, last_row], axis=0)  # (4, 4)
        padded = np.concatenate([input, np.zeros_like(input[:, 0:1])], axis=1)  # (N, 4, 4)
        padded[:, 3, 3] = 1.0
        return padded

    if not batched:
        last_row = torch.tensor([[0, 0, 0, 1]], dtype=input.dtype, device=input.device)
        return torch.cat([input, last_row], dim=0)  # (4, 4)
    padded = torch.cat([input, torch.zeros_like(input[:, 0:1])], dim=1)  # (N, 4, 4)
    padded[:, 3, 3] = 1.0
    return padded


def vec2skew(v):
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    For the result ``K``, ``K @ u`` equals ``cross(v, u)``. The matrix is assembled
    from slices of ``v`` with ``torch.cat``/``torch.stack``, so gradients flow to ``v``.

    :param v:  (3, ) torch tensor
    :return:   (3, 3) torch tensor with the same dtype and device as ``v``
    """
    # Match v's dtype so e.g. half/double inputs are not silently promoted.
    zero = torch.zeros(1, dtype=v.dtype, device=v.device)
    skew_v0 = torch.cat([ zero,    -v[2:3],   v[1:2]])  # (3, )
    skew_v1 = torch.cat([ v[2:3],   zero,    -v[0:1]])  # (3, )
    skew_v2 = torch.cat([-v[1:2],   v[0:1],   zero])    # (3, )
    skew_v = torch.stack([skew_v0, skew_v1, skew_v2], dim=0)  # (3, 3)
    return skew_v


def Exp(r):
    """so(3) vector to SO(3) matrix via Rodrigues' formula.

    R = I + sin(theta)/theta * K + (1 - cos(theta))/theta^2 * K @ K, where
    theta = |r| and K = vec2skew(r). A tiny epsilon is added to theta so that
    r = 0 does not divide by zero; K is then zero and R is the identity.

    :param r: (3, ) axis-angle, torch tensor
    :return:  (3, 3)
    """
    theta = r.norm() + 1e-15
    K = vec2skew(r)  # (3, 3)
    K_sq = K @ K
    first_order = torch.sin(theta) / theta
    second_order = (1 - torch.cos(theta)) / theta**2
    identity = torch.eye(3, dtype=torch.float32, device=r.device)
    return identity + first_order * K + second_order * K_sq


def make_c2w(r, t):
    """Assemble a camera-to-world matrix [Exp(r) | t].

    :param r:  (3, ) axis-angle             torch tensor
    :param t:  (3, ) translation vector     torch tensor
    :return:   (3, 4) torch tensor (apply convert3x4_4x4 for the 4x4 homogeneous form)
    """
    rotation = Exp(r)            # (3, 3)
    translation = t[:, None]     # (3, 1)
    return torch.cat([rotation, translation], dim=1)  # (3, 4)

def c2w_rt(c2w):
    """Split a camera-to-world matrix into an axis-angle rotation and a translation.

    Uses the SO(3) log map:
        theta = arccos((trace(R) - 1) / 2)
        r     = theta / (2 sin theta) * [R21 - R12, R02 - R20, R10 - R01]

    :param c2w:  (3, 4) torch tensor [R | t]
    :return:     (r, t): r is the (3, ) axis-angle vector, t the (3, ) translation

    NOTE: near theta = pi both sin(theta) and the skew part of R vanish, so the axis
    cannot be recovered from this formula; results there are unreliable.
    """
    R = c2w[:, :3]
    t = c2w[:, 3]
    # Round-off can push the cosine slightly outside [-1, 1], where arccos is NaN.
    cos_theta = ((R.trace() - 1) / 2).clamp(-1.0, 1.0)
    theta = torch.arccos(cos_theta)
    w = torch.stack([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    # theta / (2 sin theta) is 0/0 at theta = 0 (identity rotation); its limit is 1/2
    # (series 1/2 + theta^2/12). Use a safe denominator so neither branch is NaN.
    small = theta < 1e-6
    sin_theta = torch.where(small, torch.ones_like(theta), torch.sin(theta))
    scale = torch.where(small, 0.5 + theta ** 2 / 12, theta / (2 * sin_theta))
    r = scale * w
    return r, t

def ellip_to_c2w(ellip, radius, focus_depth): # radius 12, focus depth 1 for st giles
    """Turn a direction parameter into a (3, 4) camera pose on a sphere.

    Inputs:
        ellip: (3, ) torch tensor; only its direction is used (normalized here)
        radius: scalar or (3, ) scale applied to the unit-sphere camera center
        focus_depth: float; the camera z axis points from (0, 0, -focus_depth)
            toward the camera center

    Outputs:
        pose: (3, 4) torch tensor with columns [x, y, z, center], on ellip's device/dtype

    NOTE: a center directly above the focus point makes z parallel to world +z, so
    the x axis degenerates (0 / 0).
    """
    # Project onto the unit sphere and read off spherical angles. Clamping protects
    # arccos/arcsin from round-off pushing their arguments just outside [-1, 1].
    ellip = ellip / ellip.norm()
    lat = torch.arccos(ellip[2].clamp(-1.0, 1.0))  # polar angle from +z
    # arcsin yields [-pi/2, pi/2], so the center has x >= 0 whenever sin(lat) > 0.
    longit = torch.arcsin((ellip[1] / torch.sin(lat)).clamp(-1.0, 1.0))

    lower_bound = torch.pi/3
    # NOTE(review): for lat > 2*pi/3 this maps lat to pi/3 - lat, which is negative;
    # kept unchanged — confirm this is the intended constraint.
    if lat > torch.pi - lower_bound:
        lat = lower_bound - lat

    center = torch.stack([torch.cos(longit)*torch.sin(lat), torch.sin(longit)*torch.sin(lat), torch.cos(lat)], dim=0) * radius
    # Camera z axis: unit vector from the focus point (0, 0, -focus_depth) to the center.
    focus_point = torch.zeros_like(center)
    focus_point[2] = -focus_depth
    z = center - focus_point
    z = z / torch.norm(z)

    # compute other axes as in @average_poses, with world +z as the reference up vector
    y_ = torch.tensor([0.0, 0.0, 1.0], dtype=z.dtype, device=z.device)  # (3)
    x = torch.linalg.cross(y_, z)  # (3)
    x = x / torch.norm(x)
    y = torch.linalg.cross(z, x)  # (3)

    pose = torch.stack([x, y, z, center], dim=1)  # (3, 4)
    return pose

def c2w_to_ellip(pose, radius):
    """Recover a unit direction parameter from the camera center of ``pose``.

    The center (column 3) is divided by ``radius`` and converted to spherical angles
    (polar angle from +z, longitude via arcsin), then back to a unit vector.

    :param pose:    (3, 4) torch tensor; only the translation column is read
    :param radius:  scalar (or broadcastable) sphere radius
    :return:        [ellip, radius] where ellip is a (3, ) unit-length tensor
    """
    center = pose[:, 3] / radius
    lat = torch.arccos(center[2].clip(-1, 1))
    sin_lat = torch.sin(lat)
    # At the poles sin(lat) == 0 and the longitude is undefined: use 0 there instead
    # of 0/0. Clipping keeps arcsin finite when the center is not exactly on the sphere.
    ratio = torch.where(sin_lat.abs() > 1e-12, center[1] / sin_lat, torch.zeros_like(sin_lat))
    longit = torch.arcsin(ratio.clip(-1, 1))

    ellip = torch.stack([torch.cos(longit)*torch.sin(lat), torch.sin(longit)*torch.sin(lat), torch.cos(lat)], dim=0)

    return [ellip, radius]

def identical(param, **kwargs):
    """Identity transform: return ``param`` unchanged.

    Extra keyword arguments are accepted only for call-signature compatibility and
    are ignored.
    """
    del kwargs  # intentionally unused
    return param
