import argparse

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d

# Visualise the SDF samples stored in an .npz archive as a coloured point
# cloud: every sample becomes a point tinted by its signed-distance value.
parser = argparse.ArgumentParser()
parser.add_argument("--npz-path", type=str, required=True)
args = parser.parse_args()

# Open the archive in a ``with`` block so the underlying file handle is
# released as soon as the arrays have been read into memory.
with np.load(args.npz_path) as data:
    pos = data["positive_sdf_samples"]
    print(pos.shape)
    # Only a missing key means "no negative samples". Any other failure
    # (e.g. a corrupt archive entry) must surface instead of being hidden.
    try:
        neg = data["negative_sdf_samples"]
    except KeyError:
        neg = None

if neg is None:
    print("no negative sdf samples")
    samples = pos
else:
    print(neg.shape)
    # A shape mismatch between the two arrays now raises here rather than
    # silently dropping the negative samples.
    samples = np.vstack((pos, neg))

# The first three columns are used as point coordinates, the fourth as the
# SDF value of that point.
points = samples[:, :3]
sdf = samples[:, 3]

# Keep only the samples whose first coordinate is <= 0, so just that half of
# the volume is displayed.
keep = points[:, 0] <= 0
points = points[keep]
sdf = sdf[keep]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)

# Map the SDF values onto the diverging "RdBu" colormap: sdf == 0 lands on
# the colormap midpoint (0.5) and sdf == +/-0.25 on its two ends.
# NOTE(review): values outside that range are presumably clamped to the end
# colours by matplotlib -- confirm if the exact saturation matters.
cmap = plt.get_cmap("RdBu")
normalized_sdf = 2 * sdf + 0.5
rgba = cmap(normalized_sdf)
colors = rgba[:, :3]  # drop the alpha channel; only RGB is passed on
pcd.colors = o3d.utility.Vector3dVector(colors)
print(points.shape, colors.shape)
print(pcd.has_colors())
o3d.visualization.draw_geometries([pcd], "sdf")
