import copy
import open3d as o3d
import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial import distance_matrix

import warp as wp
from warp.jax_experimental.custom_call import jax_kernel

def furthest_point_sampling(points, n_samples, initial_idx=None):
    """
    Performs furthest point sampling (FPS) on a set of points.

    The seed point ``initial_idx`` only initialises the distance field; it is
    not itself emitted. The first returned sample is the point furthest from
    the seed, and each later sample is the point furthest from all samples
    chosen so far.

    Args:
        points (array_like): Array of shape (N, D), points in D-dimensional space.
        n_samples (int): Number of points to sample. Values <= 0 yield empty
            arrays. If it exceeds the number of distinct points, the extra
            samples are duplicates of already-covered points.
        initial_idx (int, optional): Index of the seed point. If None, it is
            chosen with ``np.random.randint(N)``.

    Returns:
        tuple:
            sampled_pnts (ndarray): Array of shape (n_samples, D).
            sampled_idxs (ndarray): Integer array of shape (n_samples,), the
                indices into ``points`` of the sampled points.
    """
    points = np.asarray(points)
    N, D = points.shape

    if n_samples <= 0:
        # np.stack on an empty list would raise; return well-shaped empties instead.
        return np.empty((0, D), dtype=points.dtype), np.empty((0,), dtype=np.intp)

    if initial_idx is None:
        initial_idx = np.random.randint(N)

    sampled_idxs = np.empty(n_samples, dtype=np.intp)
    # distances[i]: distance from points[i] to the closest point selected so far
    # (initially the seed point).
    distances = distance_matrix(points, points[[initial_idx]]).ravel()

    for k in range(n_samples):
        idx = np.argmax(distances)
        sampled_idxs[k] = idx
        new_distances = distance_matrix(points, points[[idx]]).ravel()
        distances = np.minimum(distances, new_distances)

    return points[sampled_idxs], sampled_idxs

@wp.func
def mesh_query_point_fn(
    idx: wp.uint64,
    point: wp.vec3,
    max_distance: float,
):
    """Return wp.mesh_query_point's result for `point` against mesh `idx`, searching up to `max_distance`."""
    return wp.mesh_query_point(idx, point, max_distance)

def get_mesh_query_point():
    """
    Build and return a Warp kernel that queries points against a mesh.

    For each thread ``tid`` the kernel reads ``points[tid]``, runs
    ``wp.mesh_query_point`` with radius ``max_distance[0]`` and writes:

    - ``is_point_inside[tid]``: the query sign (negative inside, positive
      outside), or ``1.0`` when no surface point was found within the radius.
      Despite its name this is a float sign, not a boolean.
    - ``signed_dist[tid]``: the vector from the point to its closest surface
      point multiplied by that sign, or the zero vector when nothing was found.
      Assuming the sign is +/-1, its length is the unsigned surface distance.

    ``mesh_id_32[0]`` must hold the 64-bit Warp mesh id split into
    ``[high, low]`` 32-bit words, as produced by ``uint64_to_uint32_pair``.

    Returns:
        The ``wp.kernel`` object.
    """
    @wp.kernel
    def mesh_query_point(
        points: wp.array(dtype=wp.vec3),
        mesh_id_32: wp.array(dtype=wp.vec2ui),
        max_distance: wp.array(dtype=wp.float32),
        # outputs
        is_point_inside: wp.array(dtype=wp.float32),
        signed_dist: wp.array(dtype=wp.vec3),
    ):
        tid = wp.tid()  # get the thread index

        point = points[tid]
        # Reassemble the 64-bit mesh id from its [high, low] 32-bit words.
        mesh_id = (wp.uint64(mesh_id_32[0][0]) << wp.uint64(32)) | wp.uint64(mesh_id_32[0][1])
        collide_result = wp.mesh_query_point(mesh_id, point, max_distance[0])

        # Defaults kept when no surface point lies within max_distance:
        # the point is reported as outside with a zero offset.
        sign = 1.0
        delta = wp.vec3(0.0, 0.0, 0.0)
        if collide_result.result:
            cl_pt = wp.mesh_eval_position(
                mesh_id, collide_result.face, collide_result.u, collide_result.v
            )
            delta = cl_pt - point
            sign = collide_result.sign  # negative if inside, positive if outside

        signed_dist[tid] = delta * sign
        is_point_inside[tid] = sign

    return mesh_query_point

def uint64_to_uint32_pair(value):
    """
    Split a 64-bit unsigned integer into two 32-bit words.

    Args:
        value: non-negative integer, e.g. a Warp mesh id.

    Returns:
        jnp.ndarray of dtype uint32 whose last axis holds ``[high, low]``.
    """
    low_mask = 0xFFFFFFFF
    high_word = jnp.uint32(value >> 32)
    low_word = jnp.uint32(value & low_mask)
    words = [high_word, low_word]
    return jnp.stack(words, axis=-1).astype(jnp.uint32)


def decompose_mesh_to_spheres(mesh, voxel_interval=0.1, num_surface_samples=1000, num_of_pnts=None, visualize=False):
    """
    Decompose a (preferably watertight) mesh into spheres.

    Interior spheres are centred on the points of a regular grid, spanning the
    vertex bounding box, that a Warp closest-point query classifies as inside
    the mesh. Each interior sphere's radius is the distance from its centre to
    the closest surface point. Small surface spheres (radius 0.001) are then
    added at surface points chosen by furthest point sampling.

    Parameters:
      mesh: open3d.geometry.TriangleMesh input.
      voxel_interval: spacing of the query grid along each axis.
      num_surface_samples: number of surface spheres to add. It is overridden
        when ``num_of_pnts`` is given.
      num_of_pnts: optional total sphere budget. Surface spheres fill whatever
        remains after the interior spheres. Interior spheres are never
        truncated, so the result can exceed this budget.
      visualize: if True, display the mesh and spheres with ``visualize_spheres``.

    Returns:
      spheres: (N,4) numpy array where each row is [x, y, z, radius].
    """
    if not mesh.is_watertight():
        print("Warning: mesh is not watertight. There might be errors in the sphere decomposition.")

    mesh_wp = wp.Mesh(
        points=wp.array(np.array(mesh.vertices), dtype=wp.vec3),
        velocities=None,
        indices=wp.array(np.array(mesh.triangles).reshape(-1), dtype=int),
    )

    # The kernel takes the 64-bit Warp mesh id as a (1, 2) uint32 array [high, low].
    mesh_id = uint64_to_uint32_pair(mesh_wp.id)[None]
    query_fn = jax.jit(jax_kernel(get_mesh_query_point()))

    # Build the query grid in float32 so it matches the kernel's wp.vec3 input
    # regardless of JAX's x64 setting.
    vertices = jnp.asarray(np.asarray(mesh.vertices), dtype=jnp.float32)
    min_bound = vertices.min(axis=0)
    max_bound = vertices.max(axis=0)
    axes = [jnp.arange(min_bound[d], max_bound[d], voxel_interval) for d in range(3)]
    grid_points = jnp.stack(jnp.meshgrid(*axes, indexing='ij'), axis=-1).reshape(-1, 3)

    # Every grid point lies in the vertex bounding box, so it is within the box
    # diagonal of some vertex. A query radius at least that large therefore
    # always reaches the surface. A miss would keep the kernel's default sign
    # of +1 and wrongly drop deep interior points of large meshes.
    bbox_diagonal = float(jnp.linalg.norm(max_bound - min_bound))
    max_distance = max(10.0, bbox_diagonal + voxel_interval)

    outputs = query_fn(
        grid_points,                                    # wp.array(dtype=wp.vec3)
        mesh_id,                                        # wp.array(dtype=wp.vec2ui)
        jnp.array([max_distance], dtype=jnp.float32),   # wp.array(dtype=wp.float32)
    )
    jax.block_until_ready(outputs)
    inside_flags = outputs[0] < 0   # kernel writes the query sign; negative = inside
    closest_dir = outputs[1]        # offset to the closest surface point, times the sign

    internal_centers = grid_points[inside_flags]
    # Radius = distance to the closest surface point.
    sphere_radius = jnp.linalg.norm(closest_dir[inside_flags], axis=-1)
    internal_spheres = jnp.hstack([internal_centers, sphere_radius[:, None]])
    spheres = [internal_spheres]

    # With a total budget, surface spheres fill the remainder; a non-positive
    # remainder adds none.
    if num_of_pnts is not None:
        num_surface_samples = num_of_pnts - len(internal_spheres)
    if num_surface_samples > 0:
        # Oversample the surface uniformly, then keep a well-spread subset via FPS.
        pcd = mesh.sample_points_uniformly(number_of_points=num_surface_samples * 100)
        surface_points = np.asarray(pcd.points)
        surface_points, _ = furthest_point_sampling(surface_points, num_surface_samples)
        surface_spheres = np.hstack([
            surface_points,
            np.full((surface_points.shape[0], 1), 0.001),
        ])
        spheres.append(surface_spheres)

    # Combine both sets of spheres
    spheres = np.vstack(spheres)

    if visualize:
        visualize_spheres(mesh, spheres)

    return spheres

# -------------------------------
# Visualization Function for Spheres
# -------------------------------
def visualize_spheres(mesh, spheres, sphere_color=(0, 1, 0), mesh_color=(0.7, 0.7, 0.7), mesh_offset=(0, 0, 1)):
    """
    Visualizes the original mesh next to the spheres.

    Parameters:
      mesh: open3d.geometry.TriangleMesh of the original mesh. It is not
        modified; a translated, recoloured copy is shown.
      spheres: (N, 4) array-like where each row is [x, y, z, radius].
      sphere_color: 3 floats for sphere color.
      mesh_color: 3 floats for mesh color.
      mesh_offset: translation applied to the displayed mesh copy so it does
        not overlap the spheres (default: 1 unit along +z).
    """
    # Merge all spheres into a single TriangleMesh. One geometry is typically
    # far cheaper for the viewer than thousands of separate ones.
    merged_spheres = o3d.geometry.TriangleMesh()
    for x, y, z, r in np.asarray(spheres):
        sphere_mesh = o3d.geometry.TriangleMesh.create_sphere(radius=r, resolution=8)
        sphere_mesh.translate(np.array([x, y, z]))
        merged_spheres += sphere_mesh
    merged_spheres.paint_uniform_color(np.asarray(sphere_color, dtype=np.float64))
    merged_spheres.compute_vertex_normals()

    # Paint a copy of the original mesh and shift it aside for comparison.
    mesh_copy = copy.deepcopy(mesh)
    mesh_copy.paint_uniform_color(np.asarray(mesh_color, dtype=np.float64))
    mesh_copy.translate(np.asarray(mesh_offset, dtype=np.float64))

    geometries = [mesh_copy]
    if merged_spheres.has_vertices():
        geometries.append(merged_spheres)
    o3d.visualization.draw_geometries(geometries)

# -------------------------------
# Example Usage
# -------------------------------
if __name__ == "__main__":
    import sys

    # Mesh files to decompose: pass them on the command line, or edit this
    # default list (replace with your own mesh files).
    default_mesh_paths = [
        "/home/dongwon/research/object_set/EGAD/modified/L18_1.obj",
    ]
    mesh_paths = sys.argv[1:] or default_mesh_paths

    # Set hyperparameters
    voxel_interval = 0.01       # adjust based on mesh scale
    num_surface_samples = 100

    for mesh_path in mesh_paths:
        mesh = o3d.io.read_triangle_mesh(mesh_path)
        # read_triangle_mesh does not raise on a missing/unreadable file; it
        # returns an empty mesh, which would fail later inside the decomposition.
        if mesh.is_empty():
            print(f"Skipping {mesh_path}: could not load a mesh.")
            continue
        print("Watertight:", mesh.is_watertight())
        mesh.compute_vertex_normals()

        # Compute sphere decomposition (Nx4 array)
        spheres = decompose_mesh_to_spheres(mesh, voxel_interval, num_surface_samples)
        print("Number of spheres:", spheres.shape[0])
        print("First few spheres:\n", spheres[:5])

        # Visualize the mesh along with the generated spheres
        visualize_spheres(mesh, spheres)
