import os, sys
import numpy as np
import torch
import open3d as o3d

from . import pcd_utils


class Colors:
    """Named RGB triples (floats in [0, 1]) used with ``paint_uniform_color``."""

    # Each value is a plain [r, g, b] list.
    red = [0.8, 0.2, 0]
    green = [0, 0.7, 0.2]
    blue = [0, 0, 1]
    gold = [1, 0.706, 0]
    greenish = [0, 0.8, 0.506]


def visualize_point_tensor(
        points_list, R, t,
        colors_list=None,
        compute_bbox_list=None,
        additional_pcds=None,
        exit_after=False,
        convert_to_opengl_coords=True
    ):
    """Display point clouds together with a world frame, a camera frame and a unit box.

    Args:
        points_list: sequence of point sets. Each entry is a torch tensor or anything
            ``np.asarray`` accepts. Accepted layouts are [N, 3] and [3, N] (the latter
            is moved to [N, 3] when N > 3). Either may carry a leading batch axis of size 1.
        R: rotation applied to the camera coordinate frame about ``pcd_utils.origin``
            (presumably a 3x3 rotation matrix — TODO confirm).
        t: translation applied to the camera frame after the rotation.
        colors_list: optional per-cloud RGB colours. A ``None`` entry leaves that cloud
            uncoloured. When given, it must have one entry per point cloud.
        compute_bbox_list: optional per-cloud flags. A truthy entry draws a bounding box
            around that cloud. When given, it must have one entry per point cloud.
        additional_pcds: optional extra geometries to draw. The list itself is not
            modified. Its geometries are, however, transformed in place when
            ``convert_to_opengl_coords`` is True.
        exit_after: terminate the process after the viewer window is closed.
        convert_to_opengl_coords: convert every geometry to OpenGL coordinates
            (via ``pcd_utils``) before drawing.
    """
    for name, per_cloud in (("colors_list", colors_list),
                            ("compute_bbox_list", compute_bbox_list)):
        assert per_cloud is None or len(per_cloud) == len(points_list), \
            f"{name} must have one entry per point cloud"

    # Work on a fresh list so frames are not accumulated into a shared default
    # or into the caller's list across calls.
    additional = [] if additional_pcds is None else list(additional_pcds)

    # World frame
    additional.append(create_frame(size=1.0))

    # Camera frame: rotate about the origin, then move it to t.
    camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=1.0, origin=[0, 0, 0]
    )
    camera_frame.rotate(R, pcd_utils.origin)
    camera_frame.translate(t, relative=True)
    additional.append(camera_frame)

    # Unit bbox
    additional.append(create_unit_bbox())

    # Convert every entry of points_list to an o3d PointCloud (and optionally a bbox).
    pcds = []
    bboxes = []
    for i, points in enumerate(points_list):
        if torch.is_tensor(points):
            # detach() so tensors that require grad can be converted too.
            points_np = points.detach().cpu().numpy()
        else:
            points_np = np.asarray(points)

        if len(points_np.shape) == 3:
            # we then assume the first dimension is the batch_size
            points_np = points_np.squeeze(axis=0)

        if points_np.shape[1] > points_np.shape[0] and points_np.shape[0] == 3:
            points_np = np.moveaxis(points_np, 0, -1)  # [N, 3]

        if convert_to_opengl_coords:
            points_np = pcd_utils.transform_pointcloud_to_opengl_coords(points_np)

        pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_np))
        if colors_list is not None and colors_list[i] is not None:
            color_np = colors_list[i] * np.ones_like(points_np)
            pcd.colors = o3d.utility.Vector3dVector(color_np)
        pcds.append(pcd)

        if compute_bbox_list is not None and compute_bbox_list[i]:
            bbox = pcd_utils.BBox(points_np)
            bboxes.append(bbox.get_bbox_as_line_set())

    # The point clouds were converted above; convert the helper geometries as well.
    if convert_to_opengl_coords:
        for additional_pcd in additional:
            additional_pcd.transform(pcd_utils.T_opengl_cv_homogeneous)

    o3d.visualization.draw_geometries([*additional, *pcds, *bboxes])

    if exit_after:
        sys.exit()


def create_unit_bbox():
    """Return a ``pcd_utils.BBox`` built from the corners (-1, -1, -1) and (1, 1, 1)."""
    upper_corner = np.array([1, 1, 1])
    return pcd_utils.BBox.compute_bbox_from_min_point_and_max_point(
        -upper_corner, upper_corner
    )


def create_frame(size=1.0, origin=[0, 0, 0]):
    """Return an Open3D coordinate-frame mesh of the given ``size`` placed at ``origin``."""
    return o3d.geometry.TriangleMesh.create_coordinate_frame(size=size, origin=origin)


def create_lines_from_start_and_end_points(start_points, end_points, color=[201/255, 177/255, 14/255]):
    """Build an Open3D LineSet joining ``start_points[k]`` to ``end_points[k]``.

    Inputs laid out as [3, N] (with N > 3, judged from ``start_points``) are transposed
    to [N, 3] first. Every line gets the same ``color``.
    """
    n_rows, n_cols = start_points.shape[0], start_points.shape[1]
    if n_rows == 3 and n_cols > n_rows:
        start_points, end_points = start_points.transpose(), end_points.transpose()

    pair_count = start_points.shape[0]
    # Starts occupy indices [0, pair_count); the matching ends follow at +pair_count.
    stacked = np.concatenate((start_points, end_points), axis=0)
    segments = [[k, k + pair_count] for k in range(pair_count)]

    line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(stacked),
        lines=o3d.utility.Vector2iVector(segments),
    )
    line_set.colors = o3d.utility.Vector3dVector([color for _ in range(pair_count)])
    return line_set


def create_lines_from_view_vectors(
        view_vectors_original, 
        offsets_original, 
        dist_original, 
        R, t,
        return_geoms=False,
        convert_to_opengl_coords=False
    ):
    """Draw segments that start at ``offsets`` and advance ``dist`` along ``view_vectors``.

    ``view_vectors_original`` and ``offsets_original`` are laid out with the coordinate
    axis first (moved to [N, 3] here). ``dist_original`` is either 1-D [N] or has its
    leading axis moved last. Inputs are copied, never modified.

    With ``return_geoms`` True, returns ``(line_set, [offsets_pcd, end_points_pcd])``.
    These are converted to OpenGL coordinates only if ``convert_to_opengl_coords`` is
    True. Otherwise the geometries are always converted, shown with a camera frame
    posed by ``R``/``t``, and the process exits once the viewer closes.
    """
    directions = np.moveaxis(np.copy(view_vectors_original), 0, -1)  # [N, 3]
    starts = np.moveaxis(np.copy(offsets_original), 0, -1)           # [N, 3]
    lengths = np.copy(dist_original)
    if len(lengths.shape) == 1:
        lengths = lengths[:, np.newaxis]
    else:
        lengths = np.moveaxis(lengths, 0, -1)                         # [N, 1]

    num_lines = starts.shape[0]
    ends = starts + directions * lengths

    # Starts come first in the vertex array, so line k joins vertex k and k + num_lines.
    line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(np.concatenate((starts, ends), axis=0)),
        lines=o3d.utility.Vector2iVector([[k, k + num_lines] for k in range(num_lines)]),
    )
    line_set.colors = o3d.utility.Vector3dVector(
        [[201/255, 177/255, 14/255] for _ in range(num_lines)]
    )

    start_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(starts))
    start_cloud.paint_uniform_color(Colors.red)
    end_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(ends))
    end_cloud.paint_uniform_color(Colors.green)
    clouds = [start_cloud, end_cloud]

    # The interactive path always needs OpenGL coordinates for display.
    if convert_to_opengl_coords or not return_geoms:
        for geometry in (start_cloud, end_cloud, line_set):
            geometry.transform(pcd_utils.T_opengl_cv_homogeneous)

    if return_geoms:
        return line_set, clouds

    camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=1.0, origin=[0, 0, 0]
    )
    camera_frame.rotate(R, pcd_utils.origin)
    camera_frame.translate(t, relative=True)
    # Bring the camera frame into OpenGL coordinates to match the other geometries.
    camera_frame.rotate(pcd_utils.T_opengl_cv, pcd_utils.origin)

    o3d.visualization.draw_geometries([camera_frame, *clouds, line_set])
    exit()
      

def viz_and_exit(pcd_list):
    """Show ``pcd_list`` in an Open3D viewer, then terminate the process."""
    geometries = pcd_list
    o3d.visualization.draw_geometries(geometries)
    exit()


def visualize_mesh(mesh_path):
    """Load the triangle mesh at ``mesh_path`` and display it next to a size-1 world frame."""
    axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.0, origin=[0, 0, 0])
    loaded = o3d.io.read_triangle_mesh(mesh_path)
    o3d.visualization.draw_geometries([axes, loaded])


def visualize_grid(points_list, colors=None, exit_after=True):
    """Display point sets with a world frame, everything rotated by pi about the x axis.

    Each entry of ``points_list`` has its leading axis moved last before conversion
    (presumably [3, N] -> [N, 3] — TODO confirm). If ``colors`` is truthy, entry i is
    painted ``colors[i]``. The process exits afterwards when ``exit_after`` is True.
    """
    frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.5, origin=[0, 0, 0])
    frame = pcd_utils.rotate_around_axis(frame, axis_name="x", angle=-np.pi)

    clouds = []
    for idx, pts in enumerate(points_list):
        cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.moveaxis(pts, 0, -1)))
        cloud = pcd_utils.rotate_around_axis(cloud, "x", np.pi)
        if colors:
            cloud.paint_uniform_color(colors[idx])
        clouds.append(cloud)

    o3d.visualization.draw_geometries([frame, *clouds])
    if exit_after:
        exit()


def visualize_sphere():
    """Debug helper: triangulate the ``sphere_tsdf`` field with marching cubes and show it."""
    import marching_cubes as mcubes
    from utils.sdf_utils import sphere_tsdf

    resolution = 20  # samples per axis over [-1, 1)
    axis_samples = np.arange(-1, 1, 2.0 / resolution)
    X, Y, Z = np.meshgrid(axis_samples, axis_samples, axis_samples)

    # Extract the 0-isosurface of the sampled field.
    vertices, triangles = mcubes.marching_cubes(sphere_tsdf(X, Y, Z), 0)

    sphere = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(vertices),
        o3d.utility.Vector3iVector(triangles),
    )
    sphere.compute_vertex_normals()
    o3d.visualization.draw_geometries([sphere])


def merge_line_sets(line_sets):
    """Concatenate several Open3D LineSets into a single LineSet.

    Points, lines and colours are stacked in input order. Each set's line indices are
    shifted by the number of points preceding it, so they keep pointing at their own
    endpoints.
    """
    point_arrays = [np.asarray(ls.points) for ls in line_sets]
    line_arrays = [np.asarray(ls.lines) for ls in line_sets]
    color_arrays = [np.asarray(ls.colors) for ls in line_sets]

    merged_points = np.zeros((sum(a.shape[0] for a in point_arrays), 3), dtype=np.float64)
    merged_lines = np.zeros((sum(a.shape[0] for a in line_arrays), 2), dtype=np.int32)
    merged_colors = np.zeros((sum(a.shape[0] for a in color_arrays), 3), dtype=np.float64)

    p_start = l_start = c_start = 0
    for pts, segs, cols in zip(point_arrays, line_arrays, color_arrays):
        p_end = p_start + pts.shape[0]
        l_end = l_start + segs.shape[0]
        c_end = c_start + cols.shape[0]

        merged_points[p_start:p_end] = pts
        merged_lines[l_start:l_end] = segs + p_start
        merged_colors[c_start:c_end] = cols

        p_start, l_start, c_start = p_end, l_end, c_end

    merged = o3d.geometry.LineSet(
        o3d.utility.Vector3dVector(merged_points),
        o3d.utility.Vector2iVector(merged_lines),
    )
    merged.colors = o3d.utility.Vector3dVector(merged_colors)
    return merged


def merge_meshes(meshes):
    """Concatenate several Open3D triangle meshes into one coloured mesh.

    Vertices and triangles are stacked in input order. Each mesh's triangle indices are
    shifted by the number of vertices preceding it.

    Vertex colours stay aligned with vertices. A mesh whose ``vertex_colors`` does not
    have one colour per vertex (e.g. an uncoloured mesh) is painted red [1, 0, 0].
    Otherwise its colours would misalign every later mesh's colours.

    Args:
        meshes: sequence of ``o3d.geometry.TriangleMesh``. May be empty.

    Returns:
        o3d.geometry.TriangleMesh with per-vertex colours.
    """
    default_color = np.array([1.0, 0.0, 0.0])

    num_vertices = sum(np.asarray(m.vertices).shape[0] for m in meshes)
    num_triangles = sum(np.asarray(m.triangles).shape[0] for m in meshes)

    vertices = np.zeros((num_vertices, 3), dtype=np.float64)
    triangles = np.zeros((num_triangles, 3), dtype=np.int32)
    # Pre-fill with the fallback colour; meshes with matching colours overwrite their slice.
    vertex_colors = np.tile(default_color, (num_vertices, 1))

    vertex_offset = 0
    triangle_offset = 0
    for mesh in meshes:
        current_vertices = np.asarray(mesh.vertices)
        current_triangles = np.asarray(mesh.triangles)
        current_colors = np.asarray(mesh.vertex_colors)
        n_v = current_vertices.shape[0]
        n_t = current_triangles.shape[0]

        vertices[vertex_offset:vertex_offset + n_v] = current_vertices
        triangles[triangle_offset:triangle_offset + n_t] = current_triangles + vertex_offset
        if n_v and current_colors.shape[0] == n_v:
            vertex_colors[vertex_offset:vertex_offset + n_v] = current_colors

        vertex_offset += n_v
        triangle_offset += n_t

    merged = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(vertices),
        o3d.utility.Vector3iVector(triangles),
    )
    merged.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
    return merged

##############################
###############################
##############################

from scipy.spatial.transform import Rotation as R
import trimesh

def get_color(v, vmin=0.0, vmax=1.0):
    """Map a scalar onto a blue -> cyan -> green -> yellow -> red colour ramp.

    Args:
        v: value to colour. It is clamped to [vmin, vmax] first.
        vmin: value mapped to blue [0, 0, 1].
        vmax: value mapped to red [1, 0, 0].

    Returns:
        float32 numpy array [r, g, b] with channels in [0, 1]. If ``vmin == vmax`` the
        range is degenerate and blue (the colour of ``vmin``) is returned instead of
        dividing by zero.
    """
    c = np.array([1.0, 1.0, 1.0], dtype='float32')

    if v < vmin:
        v = vmin
    if v > vmax:
        v = vmax
    dv = vmax - vmin

    if dv == 0:
        # Degenerate range (e.g. all flow vectors have the same length).
        c[0] = 0
        c[1] = 0
        return c

    if v < (vmin + 0.25 * dv):
        # blue -> cyan: green rises
        c[0] = 0
        c[1] = 4 * (v - vmin) / dv
    elif v < vmin + 0.5 * dv:
        # cyan -> green: blue falls
        c[0] = 0
        c[2] = 1 + 4 * (vmin + 0.25 * dv - v) / dv
    elif v < vmin + 0.75 * dv:
        # green -> yellow: red rises
        c[0] = 4 * (v - vmin - 0.5 * dv) / dv
        c[2] = 0
    else:
        # yellow -> red: green falls
        c[1] = 1 + 4 * (vmin + 0.75 * dv - v) / dv
        c[2] = 0
    return c

def vis_error_map(vertices, faces, error_npy, error_min=0.0, error_max=0.2):
    """Build a triangle mesh coloured per vertex by an error value.

    The colour ramp is blue (``error_min``) -> cyan -> green -> yellow -> red
    (``error_max``). Errors outside the range are clamped, as ``get_color`` does, so
    the colour channels stay within [0, 1].

    Args:
        vertices: vertex positions, passed to ``o3d.utility.Vector3dVector``
            (presumably [V, 3] — TODO confirm).
        faces: triangle vertex indices, passed to ``o3d.utility.Vector3iVector``.
        error_npy: 1-D array with one error value per vertex.
        error_min: error mapped to blue. Defaults to 0.
        error_max: error mapped to red. Defaults to 0.2.

    Returns:
        o3d.geometry.TriangleMesh with ``vertex_colors`` set.

    Raises:
        ValueError: if ``error_max`` is not greater than ``error_min``.
    """
    error_dist = error_max - error_min
    if not error_dist > 0:
        raise ValueError("error_max must be greater than error_min")

    errors = np.clip(error_npy, error_min, error_max)
    error_map = np.ones((errors.shape[0], 3), dtype='float32')

    q1 = error_min + 0.25 * error_dist
    q2 = error_min + 0.5 * error_dist
    q3 = error_min + 0.75 * error_dist

    # [error_min, q1): blue -> cyan (green rises)
    mask = errors < q1
    error_map[mask, 0] = 0
    error_map[mask, 1] = 4 * (errors[mask] - error_min) / error_dist

    # [q1, q2): cyan -> green (blue falls)
    mask = (errors >= q1) & (errors < q2)
    error_map[mask, 0] = 0
    error_map[mask, 2] = 1 + 4 * (q1 - errors[mask]) / error_dist

    # [q2, q3): green -> yellow (red rises)
    mask = (errors >= q2) & (errors < q3)
    error_map[mask, 0] = 4 * (errors[mask] - error_min - 0.5 * error_dist) / error_dist
    error_map[mask, 2] = 0

    # [q3, error_max]: yellow -> red (green falls)
    mask = errors >= q3
    error_map[mask, 1] = 1 + 4 * (q3 - errors[mask]) / error_dist
    error_map[mask, 2] = 0

    mesh = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(vertices),
        o3d.utility.Vector3iVector(faces),
    )
    mesh.vertex_colors = o3d.utility.Vector3dVector(error_map)
    return mesh
    
def vector_magnitude(vec):
    """Return the Euclidean length of ``vec``.

    Squares every element, sums over all of them and takes the square root. For a
    multi-dimensional array this is the length of the flattened array.
    """
    squared = vec ** 2
    total = np.sum(squared)
    return np.sqrt(total)


def calculate_zy_rotation_for_arrow(vec):
    """Return rotations that turn the +z axis into the direction of ``vec``.

    A z-rotation first brings ``vec`` into the XZ plane (x >= 0). A y-rotation then
    measures its angle from +z. Applying ``Ry`` and then ``Rz`` to a z-aligned arrow
    (``Rz @ Ry @ [0, 0, 1]``) points the arrow along ``vec``.

    ``np.arctan2`` is used so that every quadrant is handled, including vectors with a
    negative z component. It also keeps axis-aligned or zero vectors finite instead of
    dividing by zero.

    Args:
        vec: 3-vector (array-like).

    Returns:
        Tuple ``(Rz, Ry)`` of 3x3 rotation matrices.
    """
    vec = np.asarray(vec, dtype=np.float64).reshape(-1)

    # Rotation about z that aligns the XY projection of vec with +x.
    gamma = np.arctan2(vec[1], vec[0])
    Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0],
                   [np.sin(gamma), np.cos(gamma), 0],
                   [0, 0, 1]])

    # vec expressed after undoing Rz: lies in the XZ plane with non-negative x.
    in_xz = Rz.T @ vec
    # Rotation about y from +z to that in-plane vector.
    beta = np.arctan2(in_xz[0], in_xz[2])
    Ry = np.array([[np.cos(beta), 0, np.sin(beta)],
                   [0, 1, 0],
                   [-np.sin(beta), 0, np.cos(beta)]])
    return Rz, Ry

def create_arrow(scale=10):
    """Create an Open3D arrow mesh whose proportions are fixed fractions of ``scale``.

    The cylinder is 0.8 * scale long with radius scale / 20. The cone is 0.2 * scale
    long with radius scale / 10.
    """
    return o3d.geometry.TriangleMesh.create_arrow(
        cone_radius=scale / 10,
        cylinder_radius=scale / 20,
        cone_height=scale * 0.2,
        cylinder_height=scale * 0.8,
    )

def get_arrow(origin=[0, 0, 0], end=None, vec=None):
    """Create an arrow mesh starting at ``origin``.

    The direction is taken from ``end`` (``end - origin``) if given, otherwise from
    ``vec``. The arrow's size equals the length of that direction. With neither given,
    an unrotated arrow of scale 10 is returned.

    Args:
        origin: start point [x, y, z]. Not modified.
        end: optional end point [x, y, z]. Takes precedence over ``vec``.
        vec: optional direction vector [i, j, k].

    Returns:
        o3d.geometry.TriangleMesh of the arrow.
    """
    scale = 10
    Ry = Rz = np.eye(3)
    if end is not None:
        vec = np.array(end) - np.array(origin)
    elif vec is not None:
        vec = np.array(vec)
    if vec is not None:
        scale = vector_magnitude(vec)
        Rz, Ry = calculate_zy_rotation_for_arrow(vec)

    mesh = create_arrow(scale)
    # Orient the arrow (Ry first, then Rz) about the origin, then move it to its start point.
    mesh.rotate(Ry, center=np.zeros(3))
    mesh.rotate(Rz, center=np.zeros(3))
    mesh.translate(origin)
    return mesh



def vis_flow_volume_arrow(flow_volume, flow_mask, dim=32, bbox_size=1.5):
    """Build one merged mesh of arrows visualising a flow field on a voxel grid.

    Row ``idx`` of ``flow_volume`` belongs to voxel ``(z, y, x)`` with
    ``idx = z * dim * dim + y * dim + x``. Its arrow is anchored at the voxel centre
    inside a cube of side ``bbox_size`` centred at the origin. Arrows have a fixed size.
    They point along the flow and are coloured by flow length.

    Args:
        flow_volume: per-voxel flow vectors (assumed [dim**3, 3] — TODO confirm).
        flow_mask: per-voxel selector. Arrows are drawn where ``flow_mask[idx]`` is
            truthy. The colour range uses the entries equal to 1.
        dim: grid resolution per axis (default 32).
        bbox_size: side length of the cube spanned by the grid (default 1.5).

    Returns:
        o3d.geometry.TriangleMesh with all arrows merged. If no mask entry equals 1,
        the result of ``merge_meshes([])`` is returned.
    """
    flow_length = np.sqrt(np.sum(flow_volume ** 2, axis=1))
    selected = flow_mask == 1
    if not np.any(selected):
        return merge_meshes([])
    min_len = flow_length[selected].min()
    max_len = flow_length[selected].max()

    rotation_center = np.zeros(3)
    arrows = []
    for idx in range(flow_volume.shape[0]):
        if not flow_mask[idx]:
            continue

        # The epsilon avoids degenerate angles for zero / axis-aligned flow vectors.
        Rz, Ry = calculate_zy_rotation_for_arrow(flow_volume[idx, :] + 1e-6)

        # Decode the flat index into voxel coordinates and compute the voxel centre.
        z, y, x = idx // (dim * dim), (idx // dim) % dim, idx % dim
        center = np.array([((x + 0.5) / dim - 0.5) * bbox_size,
                           ((y + 0.5) / dim - 0.5) * bbox_size,
                           ((z + 0.5) / dim - 0.5) * bbox_size], dtype='f4')

        arrow = o3d.geometry.TriangleMesh.create_arrow(
            cylinder_radius=0.007, cone_radius=0.014,
            cylinder_height=0.08, cone_height=0.04, resolution=10)
        arrow.rotate(Ry, center=rotation_center)
        arrow.rotate(Rz, center=rotation_center)
        to_center = np.eye(4)
        to_center[0:3, 3] = center
        arrow.transform(to_center)

        arrow.paint_uniform_color(get_color(flow_length[idx], min_len, max_len))
        arrows.append(arrow)

    return merge_meshes(arrows)

def vis_flow_surface_arrow(geometry, flow, mask):
    """Build one merged mesh of arrows visualising per-point flow on a surface.

    Each selected point gets a fixed-size arrow anchored at ``geometry[idx]``. The arrow
    points along ``flow[idx]`` and is coloured by flow length.

    Args:
        geometry: per-point anchor positions (assumed [N, 3] — TODO confirm).
        flow: per-point flow vectors; norms are taken along axis 1.
        mask: per-point selector. Arrows are drawn where ``mask[idx]`` is truthy. The
            colour range uses the entries equal to 1.

    Returns:
        o3d.geometry.TriangleMesh with all arrows merged. If no mask entry equals 1,
        the result of ``merge_meshes([])`` is returned.
    """
    flow_length = np.sqrt(np.sum(flow ** 2, axis=1))
    selected = mask == 1
    if not np.any(selected):
        return merge_meshes([])
    min_len = flow_length[selected].min()
    max_len = flow_length[selected].max()

    rotation_center = np.zeros(3)
    arrows = []
    for idx in range(flow.shape[0]):
        if not mask[idx]:
            continue

        # The epsilon avoids degenerate angles for zero / axis-aligned flow vectors.
        Rz, Ry = calculate_zy_rotation_for_arrow(flow[idx, :] + 1e-6)

        arrow = o3d.geometry.TriangleMesh.create_arrow(
            cylinder_radius=0.007, cone_radius=0.014,
            cylinder_height=0.08, cone_height=0.04, resolution=10)
        arrow.rotate(Ry, center=rotation_center)
        arrow.rotate(Rz, center=rotation_center)
        to_anchor = np.eye(4)
        to_anchor[0:3, 3] = geometry[idx, :]
        arrow.transform(to_anchor)

        arrow.paint_uniform_color(get_color(flow_length[idx], min_len, max_len))
        arrows.append(arrow)

    return merge_meshes(arrows)


#Farthest point sampling
import numpy

def l2_norm(x, y):
    """Return the squared Euclidean distance between ``x`` and ``y`` along the last axis.

    No square root is taken. This is enough for farthest point sampling, which only
    compares distances.

    Args:
        x (numpy.ndarray): (..., coord_dim), e.g. (batch_size, 1, coord_dim).
        y (numpy.ndarray): (..., coord_dim), broadcast-compatible with ``x``, e.g.
            (batch_size, num_point, coord_dim).

    Returns:
        numpy.ndarray: broadcast shape of ``x`` and ``y`` without the last axis,
        e.g. (batch_size, num_point).
    """
    return ((x - y) ** 2).sum(axis=-1)


def farthest_point_sampling(pts, k, initial_idx=None, metrics=l2_norm,
                            skip_initial=False, indices_dtype=numpy.int32,
                            distances_dtype=numpy.float32):
    """Batched farthest point sampling.

    Adapted from @Graipher's answer at
    https://codereview.stackexchange.com/questions/179561/farthest-point-algorithm-in-python

    Args:
        pts (numpy.ndarray): (num_point, coord_dim) or
            (batch_size, num_point, coord_dim). A 2-dim input is treated as a batch
            of size 1.
        k (int): number of points to sample.
        initial_idx (int): index of the first sampled point. ``None`` draws a random
            point index in ``[0, num_point)``, shared by every batch element, so the
            result is then not deterministic.
        metrics (callable): ``metrics(a, b)`` returning (batch_size, num_point)
            distances for ``a`` of shape (batch_size, 1, coord_dim) and ``b`` of shape
            (batch_size, num_point, coord_dim).
        skip_initial (bool): if True, the initial point is replaced by the point
            farthest from it before sampling continues. This reduces the dependence
            of the output on the starting index.
        indices_dtype: dtype of the returned indices.
        distances_dtype: accepted for backward compatibility. It is unused because
            the per-step distance matrix is not returned.

    Returns:
        numpy.ndarray: (batch_size, k) indices. ``pts[b, indices[b, j]]`` is the
        j-th sampled point of batch element ``b``.
    """
    if pts.ndim == 2:
        # insert batch_size axis
        pts = pts[None, ...]
    assert pts.ndim == 3
    batch_size, num_point, _ = pts.shape
    indices = np.zeros((batch_size, k), dtype=indices_dtype)

    # The random start must be drawn from the point axis, not the batch axis.
    if initial_idx is None:
        indices[:, 0] = np.random.randint(num_point)
    else:
        indices[:, 0] = initial_idx

    batch_indices = np.arange(batch_size)
    farthest_point = pts[batch_indices, indices[:, 0]]
    # min_distances[b, j]: distance from point j to its closest already-sampled point.
    min_distances = metrics(farthest_point[:, None, :], pts)

    if skip_initial:
        # Override 0-th `indices` by the farthest point of `initial_idx`
        indices[:, 0] = np.argmax(min_distances, axis=1)
        farthest_point = pts[batch_indices, indices[:, 0]]
        min_distances = metrics(farthest_point[:, None, :], pts)

    for i in range(1, k):
        indices[:, i] = np.argmax(min_distances, axis=1)
        farthest_point = pts[batch_indices, indices[:, i]]
        dist = metrics(farthest_point[:, None, :], pts)
        min_distances = np.minimum(min_distances, dist)
    return indices