import open3d as o3d
import argparse
import os
import numpy as np
import torch
import pickle


# Command-line options.  Only the 'pour' and 'shake' datasets are handled by
# the loading code in this script, so ``--dset`` is restricted with
# ``choices``: an unsupported value now produces an argparse usage error
# immediately instead of a less helpful failure later on.
parser = argparse.ArgumentParser()
parser.add_argument("--dset", default='shake', type=str, choices=['pour', 'shake'],
                    help="dataset to load: 'pour' (FluidPour) or 'shake' (FluidShake)")
parser.add_argument("--traj", default=0, type=int,
                    help="trajectory index; selects the <traj>/info.p sub-directory")
parser.add_argument("--frame", default=0, type=int,
                    help="frame index")
parser.add_argument("--root", default='../',
                    help="root directory containing datasets/shape_info/")
args = parser.parse_args()

def load_shape_info(dset, root=None):
    """Parse the shape-info text file for the given dataset.

    Lines starting with ``asset`` are counted as meshes.  Every other
    non-blank line is expected to look like::

        [4.   0.02 4.  ] 0 [0.9 0.9 0.9]

    The bracketed triple of floats and the integer that follows it are
    extracted.  The trailing bracketed group is ignored.

    Args:
        dset: Dataset name, ``'pour'`` or ``'shake'``.
        root: Directory containing ``datasets/shape_info/``.  When omitted,
            the ``--root`` command-line value (``args.root``) is used.

    Returns:
        ``(mesh_num, shape_info)``.  ``mesh_num`` is the number of ``asset``
        lines.  ``shape_info`` is a list of ``[(x, y, z), visible]`` entries
        in file order.

    Raises:
        ValueError: if ``dset`` is not a supported dataset name.
    """
    file_names = {'pour': 'FluidPour.txt', 'shake': 'FluidShake.txt'}
    if dset not in file_names:
        raise ValueError(
            "unknown dset {!r}; expected one of {}".format(dset, sorted(file_names)))
    if root is None:
        root = args.root
    info_pth = os.path.join(root, 'datasets/shape_info/', file_names[dset])

    mesh_num = 0
    shape_info = []
    with open(info_pth) as f:
        for line in f:
            if line.startswith('asset'):
                mesh_num += 1
                continue
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            start = line.find('[')
            end_pos = line.find(']')
            x, y, z = (float(v) for v in line[start + 1:end_pos].split())
            # The visibility flag is the first token after the closing
            # bracket.  Splitting (rather than indexing a single character)
            # accepts multi-digit values and extra whitespace.
            visible = int(line[end_pos + 1:].split()[0])
            shape_info.append([(x, y, z), visible])
    return mesh_num, shape_info


msh_num, shape_info = load_shape_info(args.dset)

# Per-trajectory info.p pickle.  The path is hard-coded under the user's home
# directory (it does not use --root).  ``open`` does not expand "~", so it is
# expanded explicitly; otherwise a literal "~" folder under the current
# working directory would be searched.
if args.dset == 'pour':
    info_p_path = os.path.expanduser('~/datasets/data_FluidPour/{}/info.p'.format(args.traj))
elif args.dset == 'shake':
    info_p_path = os.path.expanduser('~/datasets/data_FluidShake/{}/info.p'.format(args.traj))
else:
    raise ValueError("unknown dset {!r}; expected 'pour' or 'shake'".format(args.dset))

# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(info_p_path, 'rb') as fh:
    info = pickle.load(fh)

shap_info = info['shape_states']  # presumably 300 * shape_number * 14 -- confirm
particle_info = info['particles']  # presumably particle_num * 3 -- confirm



