import trimesh
import numpy as np
import json
import argparse


def sample_mesh_points(mesh_file: str, n_points: int = 1500):
    """Load a mesh from disk and draw random samples from its surface.

    Args:
        mesh_file: Path of the mesh file. It is loaded with
            ``trimesh.load(..., force='mesh')`` so a single mesh object is
            requested rather than a scene.
        n_points: Number of surface samples to draw.

    Returns:
        ``(points, colors)``. ``points`` is the position array produced by
        ``trimesh.sample.sample_surface``; trimesh documents it as
        ``(n_points, 3)`` floats. ``colors`` is an ``(n_points, 3)`` uint8
        array where every row is the constant grey ``(150, 150, 150)``.
    """
    surface_mesh = trimesh.load(mesh_file, force='mesh')
    # sample_surface also returns the index of the face each sample came from;
    # only the positions are needed here.
    points, _face_index = trimesh.sample.sample_surface(surface_mesh, n_points)
    colors = np.full((n_points, 3), 150, dtype=np.uint8)
    return points, colors


def sample_plane_points(size=(0.5, 0.5), resolution=200, y_level=-0.005):
    """Build a regular grid of points on the horizontal plane ``y = y_level``.

    The grid is centred on the origin. It spans ``size[0]`` along x and
    ``size[1]`` along z, with ``resolution`` evenly spaced samples per axis,
    giving ``resolution ** 2`` points in total.

    Returns:
        ``(points, colors)``. ``points`` is a ``(resolution**2, 3)`` float
        array of ``(x, y, z)`` rows ordered with x varying fastest, then z.
        ``colors`` is a matching uint8 array where every row is the constant
        grey ``(200, 200, 200)``.
    """
    half_x = size[0] / 2
    half_z = size[1] / 2
    xs = np.linspace(-half_x, half_x, resolution)
    zs = np.linspace(-half_z, half_z, resolution)

    # Default 'xy' indexing: grid_x[i, j] == xs[j] and grid_z[i, j] == zs[i],
    # so row-major ravel makes x the fastest-varying coordinate.
    grid_x, grid_z = np.meshgrid(xs, zs)
    flat_x = grid_x.ravel()
    flat_z = grid_z.ravel()
    flat_y = np.full(flat_x.shape, y_level, dtype=grid_x.dtype)

    points = np.column_stack((flat_x, flat_y, flat_z))
    colors = np.full((points.shape[0], 3), 200, dtype=np.uint8)
    return points, colors


def pointcloud_to_flat_color(pc: np.ndarray, pc_color: np.ndarray):
    """Convert a point cloud and its colours into JSON-friendly nested lists.

    Both arrays are reshaped to rows of three values and converted to Python
    lists. Each result is then wrapped in a one-element outer list.
    NOTE(review): the outer list presumably represents a list of
    frames/views expected by the consumer of the JSON. Confirm against the
    reader.

    Returns:
        ``([points_rows], [color_rows])``, where each ``*_rows`` is a list of
        3-element lists.
    """
    point_rows = pc.reshape(-1, 3).tolist()
    color_rows = pc_color.reshape(-1, 3).tolist()
    return [point_rows], [color_rows]


def _build_scene_payload(handle_pc, handle_color, plane_pc, plane_color):
    """Assemble the JSON-serialisable dict written by ``main``."""
    # Only the plane goes into the scene cloud; the mesh samples appear
    # solely under object_info.
    full_pc_list, img_color_list = pointcloud_to_flat_color(plane_pc, plane_color)
    return {
        "object_info": {
            "pc": handle_pc.tolist(),
            "pc_color": handle_color.tolist(),
        },
        "scene_info": {
            "full_pc": full_pc_list,
            "img_color": img_color_list,
        },
        # No grasps are produced here; the keys are emitted empty so the
        # file always has the same top-level layout.
        "grasp_info": {
            "grasp_poses": [],
            "grasp_conf": [],
        },
    }


def main(mesh_file: str, out_file: str, n_points: int = 5000):
    """Sample a mesh and a ground plane and write them to a JSON file.

    Output layout::

        {
          "object_info": {"pc": [[x, y, z], ...], "pc_color": [[r, g, b], ...]},
          "scene_info":  {"full_pc": [[[x, y, z], ...]],
                          "img_color": [[[r, g, b], ...]]},
          "grasp_info":  {"grasp_poses": [], "grasp_conf": []}
        }

    Args:
        mesh_file: Path of the mesh to sample (see ``sample_mesh_points``).
        out_file: Destination path of the JSON file. An existing file is
            overwritten.
        n_points: Number of surface points sampled from the mesh. NOTE: this
            default (5000) differs from the 1500 used by
            ``sample_mesh_points`` and the command-line parser. It is kept
            for backward compatibility.
    """
    handle_pc, handle_color = sample_mesh_points(mesh_file, n_points)
    plane_pc, plane_color = sample_plane_points()

    data = _build_scene_payload(handle_pc, handle_color, plane_pc, plane_color)

    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(data, f)


if __name__ == "__main__":
    # Command-line entry point: sample a mesh plus a ground plane and dump the
    # result as JSON via main().
    parser = argparse.ArgumentParser(
        description="Sample a mesh and a ground plane into a JSON scene file."
    )
    parser.add_argument("--mesh_file", type=str, required=True,
                        help="Path of the mesh to sample (loaded with trimesh).")
    # main() always opens out_file for writing. Without required=True, a
    # missing flag crashed later with a TypeError from open(None).
    parser.add_argument("--out_file", type=str, required=True,
                        help="Destination JSON path (overwritten if it exists).")
    parser.add_argument("--n_points", type=int, default=1500,
                        help="Number of surface points to sample from the mesh.")
    args = parser.parse_args()

    main(args.mesh_file, args.out_file, args.n_points)

