import numpy as np
# import open3d as o3d
# import time
# from scipy.spatial import cKDTree
import sys
# Make the 'stable_fast_3d' directory importable by appending it to the module
# search path.
# NOTE(review): the entry is a relative path, so it resolves against the
# process's current working directory, not this file's location. Presumably
# required by the SF3D_img2mesh import below -- confirm.
sys.path.append('stable_fast_3d')

from . import (
    # layerDiffusion_text2img,
    sd21_text2img,
    SF3D_img2mesh
)
# Module-level aliases for the two pipeline stages used by SF3D() below.
# Alternative text-to-image backend (disabled together with its import above):
# text2img = layerDiffusion_text2img.layerDiffusion_text2img
# Active text-to-image backend; called as text2img(prompt, save_dir, load_lora1).
text2img = sd21_text2img.sd21_text2img
# Image-to-mesh stage; called as img2mesh(ref_image, save_dir).
img2mesh = SF3D_img2mesh.SF3D_img2mesh

# def voxelize_mesh(mesh_vertices, voxel_size, max_distance):
#     # Determine the bounding box of the mesh
#     min_bounds = np.min(mesh_vertices, axis=0)
#     max_bounds = np.max(mesh_vertices, axis=0)

#     # Create a 3D grid of voxel centers
#     x_range = np.arange(min_bounds[0], max_bounds[0], voxel_size)
#     y_range = np.arange(min_bounds[1], max_bounds[1], voxel_size)
#     z_range = np.arange(min_bounds[2], max_bounds[2], voxel_size)

#     xv, yv, zv = np.meshgrid(x_range, y_range, z_range, indexing='ij')
#     voxel_centers = np.vstack((xv.flatten(), yv.flatten(), zv.flatten())).T

#     # Construct a KD-tree from the mesh vertices
#     tree = cKDTree(mesh_vertices)

#     # Find the distance from each voxel center to the nearest mesh vertex
#     distances, _ = tree.query(voxel_centers)

#     # Filter out voxel centers that are too far from the mesh vertices
#     mask = distances <= max_distance
#     voxel_centers = voxel_centers[mask]

#     return voxel_centers, min_bounds, max_bounds

# def is_point_inside_mesh_voxelized(points, voxel_centers, mesh_vertices, mesh_colors, voxel_size):
#     # Use KDTree to find nearest voxel centers for each point
#     tree = cKDTree(voxel_centers)
#     distances, _ = tree.query(points)

#     # Points are considered inside if they are within half the voxel size of a voxel center
#     inside_mask = distances < (voxel_size / 2)

#     # Find the nearest mesh vertices for the points inside the mesh
#     mesh_tree = cKDTree(mesh_vertices)
#     _, mesh_indices = mesh_tree.query(points[inside_mask])

#     # Assign colors to the points inside the mesh
#     point_colors = np.zeros_like(points)
#     point_colors[inside_mask] = mesh_colors[mesh_indices]

#     return inside_mask, point_colors

def SF3D(prompt, save_dir, load_lora1, *, skip=1):
    """Generate a colored point cloud from a text prompt.

    Pipeline: ``text2img`` produces a reference image for ``prompt``,
    ``img2mesh`` turns that image into a mesh, and the mesh vertices together
    with their per-vertex colors are returned as a point cloud.

    Args:
        prompt: Text prompt, forwarded unchanged to ``text2img``.
        save_dir: Forwarded unchanged to both ``text2img`` and ``img2mesh``
            (presumably the directory where intermediate outputs are
            written -- confirm against those modules).
        load_lora1: Forwarded unchanged to ``text2img`` (presumably selects
            a LoRA to load -- confirm against that module).
        skip: Keep only every ``skip``-th vertex (and its color). The default
            of 1 keeps every vertex, matching the previous hard-coded value.

    Returns:
        A tuple ``(coords, rgb, 0.8)``:

        * ``coords`` -- the (sub-sampled) mesh vertices rotated by +90
          degrees about the x-axis, i.e. ``(x, y, z) -> (x, -z, y)`` up to
          floating-point rounding.
        * ``rgb`` -- the first three vertex-color channels divided by 255
          (assumes the mesh stores 0-255 channel values).
        * ``0.8`` -- a fixed constant; its meaning is not visible in this
          file (presumably a scale/extent hint for the caller -- confirm).

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """
    # Validate before running the (expensive) generation stages.
    if skip < 1:
        raise ValueError(f"skip must be a positive integer, got {skip!r}")

    ref_image = text2img(prompt, save_dir, load_lora1)
    mesh = img2mesh(ref_image, save_dir)

    # Sub-sample vertices and colors with the same stride so they stay aligned.
    coords = mesh.vertices[::skip]
    # Drop any channel past the third (e.g. alpha) and rescale to [0, 1].
    rgb = np.asarray(mesh.visual.vertex_colors)[::skip, :3] / 255.0

    # Rotation by 90 degrees about the x-axis; multiplying row vectors by the
    # transpose applies the rotation to every vertex at once.
    angle_x = np.radians(90)
    cos_a, sin_a = np.cos(angle_x), np.sin(angle_x)
    rotation_matrix = np.array([
        [1, 0, 0],
        [0, cos_a, -sin_a],
        [0, sin_a, cos_a],
    ])
    coords = coords @ rotation_matrix.T

    # NOTE: an experimental voxel-based interior fill (see the commented-out
    # voxelize_mesh / is_point_inside_mesh_voxelized helpers at module level)
    # is disabled; only the mesh's own vertices are returned.
    return coords, rgb, 0.8
    