import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import trimesh
import numpy as np

# Expected inputs (the names below are illustrative):
#   'points' - point coordinates of shape (N, 3), e.g. N = 50000, one [x, y, z] row per point.
#   'colors' - one RGB triple per point, e.g. [(R1, G1, B1), (R2, G2, B2), ...];
#              values_to_colors() below builds such colors from per-point scalar values.
def values_to_colors(values):
    """Map scalar values to RGB colors with a piecewise-linear blue-green-orange-red ramp.

    The ramp is stretched over the data's own range. Its stops sit at 0%, 40%,
    60%, 80% and 100% of ``[min(values) - 0.01, max(values)]`` and carry the
    colors blue, blue, green, orange (255, 165, 0) and red. Values between two
    stops are linearly interpolated, and the result is truncated to integers.

    Args:
        values: 1-D sequence or array of scalar values (one per point).

    Returns:
        ``(N, 3)`` integer array of RGB components in ``[0, 255]``, or an empty
        ``(0, 3)`` array when ``values`` is empty.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return np.empty((0, 3), dtype=int)

    # The 0.01 offset keeps the lowest stop just below the data minimum and
    # gives a non-zero span even when every value is identical.
    min_val, max_val = np.min(values) - 0.01, np.max(values)
    span = max_val - min_val
    # The last stop is max_val itself (not min_val + span), so the maximum lands exactly on it.
    color_stops = np.array([min_val, min_val + span * 0.4, min_val + span * 0.6,
                            min_val + span * 0.8, max_val])
    colors = np.array([[0, 0, 255], [0, 0, 255], [0, 255, 0], [255, 165, 0], [255, 0, 0]])

    # Segment index per value. digitize returns len(color_stops) for x == max_val.
    # Without clipping, that value would fall in a zero-width segment and get a
    # NaN color. Clipping keeps every value inside a real segment [i, i + 1].
    lower_idx = np.clip(np.digitize(values, color_stops) - 1, 0, len(color_stops) - 2)
    upper_idx = lower_idx + 1

    # Linear interpolation inside the segment. Zero-width segments can only occur
    # when 0.01 is below the float resolution of the values; map those to the segment start.
    seg_start = color_stops[lower_idx]
    seg_width = color_stops[upper_idx] - seg_start
    with np.errstate(divide='ignore', invalid='ignore'):
        t = np.where(seg_width > 0, (values - seg_start) / seg_width, 0.0)
    t = np.clip(t, 0.0, 1.0)[:, np.newaxis]
    interpolated_colors = (1 - t) * colors[lower_idx] + t * colors[upper_idx]

    return interpolated_colors.astype(int)


# Plot a point cloud (colored by per-point values) together with a human mesh and save the figure as a PNG.
def vis_plt(point_xyz, point_att, human_mesh, out_path='point_cloud_with_mesh.png'):
    """Render a value-colored point cloud together with a mesh and save it as a PNG.

    Args:
        point_xyz: Point positions, either an object exposing ``.vertices``
            (e.g. a ``trimesh.PointCloud``) or an ``(N, 3)`` array-like.
        point_att: ``N`` scalar values, one per point, mapped to colors with
            ``values_to_colors``.
        human_mesh: Mesh exposing ``.vertices`` ``(V, 3)`` and ``.faces`` ``(F, 3)``
            (e.g. a ``trimesh.Trimesh``), drawn as a translucent surface.
        out_path: Destination PNG path. Defaults to the original hard-coded file name.
    """
    # Accept either a PointCloud-like object or a raw coordinate array.
    xyz = np.asarray(getattr(point_xyz, 'vertices', point_xyz))
    point_colors = values_to_colors(point_att)

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # values_to_colors yields 0-255 integers; matplotlib expects 0-1 floats.
        ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=point_colors / 255.0, s=1)

        # trimesh.rendering only holds pyglet/OpenGL converters and has no
        # face_groups helper, so draw the mesh triangles with matplotlib directly.
        mesh_vertices = np.asarray(human_mesh.vertices)
        ax.plot_trisurf(mesh_vertices[:, 0], mesh_vertices[:, 1], mesh_vertices[:, 2],
                        triangles=np.asarray(human_mesh.faces),
                        color='lightgray', alpha=0.5, linewidth=0)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        # Fixed bounds can be set here with ax.set_xlim / set_ylim / set_zlim if needed.

        fig.savefig(out_path, dpi=300)
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)
