import numpy as np
from plyfile import PlyData
import open3d as o3d
import matplotlib.pyplot as plt

def sigmoid(x):
    """Numerically stable logistic sigmoid, 1 / (1 + exp(-x)).

    The naive formula evaluates exp(-x), which overflows (with a
    RuntimeWarning) for large negative x -- already below about -88 for
    float32 input. Here exp is only applied to -|x| <= 0, so it cannot
    overflow. Each side of zero uses the algebraically equivalent form that
    stays accurate in its own tail.

    Args:
        x: Scalar or array-like of logits.

    Returns:
        Values in [0, 1] with the same shape as ``x``. Scalar input gives a
        NumPy scalar and array input gives an ndarray; float32 input stays
        float32.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    # np.where evaluates both branches, but neither can overflow: 0 <= z <= 1.
    result = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    # Indexing with () unwraps a 0-d result into a NumPy scalar (as the naive
    # formula returns) and is a whole-array view for n-d input.
    return result[()]

def _vertex_colors(vertex_data, attribute_names, num_points):
    """Return an (N, 3) float64 RGB array in [0, 1] for the vertices, or None.

    Fields are tried in priority order:
      * 'score' -- min-max normalized to a grayscale (low = dark, high = bright);
      * 'label' -- 1 -> red, 0 -> blue, any other value -> black;
      * 'red'/'green'/'blue' -- divided by 255.
    Returns None when none of these fields are present.
    """
    if 'score' in attribute_names:
        scores = np.asarray(vertex_data['score'], dtype=np.float64)
        if scores.size == 0:
            # .min()/.max() raise on empty arrays.
            return np.zeros((0, 3))
        # The epsilon avoids division by zero when all scores are equal.
        score_norm = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
        return np.column_stack([score_norm, score_norm, score_norm])
    if 'label' in attribute_names:
        labels = vertex_data['label']
        colors = np.zeros((num_points, 3))
        colors[labels == 1] = [1.0, 0.0, 0.0]
        colors[labels == 0] = [0.0, 0.0, 1.0]
        return colors
    if all(name in attribute_names for name in ('red', 'green', 'blue')):
        rgb = np.column_stack(
            [vertex_data['red'], vertex_data['green'], vertex_data['blue']]
        )
        return rgb.astype(np.float64) / 255.0
    return None


def visualize_ply(ply_path, *, opacity_threshold=0.1):
    """Load a PLY point cloud and display it in an Open3D viewer window.

    Point colors come from ``_vertex_colors``. If no usable color field
    exists, the points are shown without explicit colors. If the file has an
    'opacity' field, it is passed through ``sigmoid`` and only points with
    sigmoid(opacity) >= ``opacity_threshold`` are displayed.

    Args:
        ply_path: Path to the .ply file.
        opacity_threshold: Minimum sigmoid(opacity) a point needs in order to
            be shown. Defaults to 0.1. Ignored when the file has no
            'opacity' field.
    """
    ply_data = PlyData.read(ply_path)
    vertex_data = ply_data['vertex'].data
    all_attributes = vertex_data.dtype.names
    print("所有属性名称:", all_attributes)

    # (N, 3) float64 coordinates, the layout Vector3dVector expects.
    points = np.column_stack(
        [vertex_data['x'], vertex_data['y'], vertex_data['z']]
    ).astype(np.float64)
    colors = _vertex_colors(vertex_data, all_attributes, points.shape[0])

    if 'opacity' in all_attributes:
        # Applying sigmoid implies opacity is stored as a logit (e.g. Gaussian
        # splatting exports) -- assumption, confirm for your data source.
        keep = sigmoid(vertex_data['opacity']) >= opacity_threshold
        points = points[keep]
        if colors is not None:
            colors = colors[keep]

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    if colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(colors)

    o3d.visualization.draw_geometries([pcd])

if __name__ == "__main__":
    import sys

    # Take the PLY path from the command line; alternatively hard-code a path
    # as the fallback value below.
    ply_path = sys.argv[1] if len(sys.argv) > 1 else ""
    if not ply_path:
        # An empty path would otherwise fail inside PlyData.read with an
        # unhelpful file error.
        sys.exit(f"usage: python {sys.argv[0]} path/to/file.ply")
    visualize_ply(ply_path)