import glob
import os

import imageio
import matplotlib.pylab as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

# --- Run configuration -------------------------------------------------------
tar_path = "./tmp/"                   # scratch dir; created if missing and cleared of *.png before use
video_path = "./video_gen_from_pcd/"  # output dir; created if missing
dset = "pour"                         # selects which .npy file is loaded into `box`
traj = 0                              # first index into `box` when slicing the motion
vx = 40                               # NOTE(review): not used in the visible code
rot = True                            # NOTE(review): not used in the visible code
interval = 2                          # stride between rendered frames
frame = range(100, 150, interval)     # frame indices 100, 102, ..., 148

if dset == 'pour':
    # Indexed later as box[traj, frame, :, :], so this must be at least 4-D;
    # presumably (trajectory, frame, point, xyz) -- TODO confirm.
    box = np.load('/home/htxue/data/mit/VGPL-Dynamics-Prior/dataset/box_sampling/pour/sampling_8_interval_0.15/pkg.npy')
else:
    # Without this branch `box` would be unbound and the script would fail
    # much later with an unhelpful NameError.
    raise ValueError(f"unsupported dset {dset!r}; only 'pour' is configured")


def capture_motion_image(feature_list, output_pth='motion.png', angle=100, color=None):
    """Draw per-point motion trails between consecutive frames and save the figure.

    For each consecutive pair of frames ``(feature_list[i - 1], feature_list[i])``,
    a straight segment joins every point's position in frame ``i`` to its
    position in frame ``i - 1``. Points are matched by row index, so all frames
    should list the same points in the same order. The number of points drawn
    per frame is taken from the first frame.

    The data's second and third coordinates are swapped when plotted: data
    ``z`` goes on the plot's y axis and data ``y`` on the plot's (vertical)
    z axis.

    Args:
        feature_list: Sequence of frames. Each frame is indexed as
            ``frame[o][k]`` for point ``o`` and coordinate ``k`` in 0..2.
            Presumably each frame is an (N, 3) array -- TODO confirm with callers.
        output_pth: File path the rendered figure is saved to.
        angle: Currently unused; kept so existing callers keep working.
            NOTE(review): probably meant for ``ax3D.view_init`` -- confirm
            which angle (azimuth/elevation) before wiring it up.
        color: Matplotlib color for the trail segments. ``None`` keeps the
            historical default of red (``'r'``).
    """
    seg_color = 'r' if color is None else color
    n_frames = len(feature_list)
    # Point count comes from the first frame, matching the original behavior.
    n_points = len(feature_list[0]) if n_frames else 0

    fig = plt.figure()
    try:
        ax3D = fig.add_subplot(111, projection='3d')
        for i in range(1, n_frames):
            cur, prev = feature_list[i], feature_list[i - 1]
            for o in range(n_points):
                x_start, x_end = cur[o][0], prev[o][0]
                y_start, y_end = cur[o][1], prev[o][1]
                z_start, z_end = cur[o][2], prev[o][2]
                # Data y/z are swapped so data-y is drawn on the vertical axis.
                ax3D.plot(xs=[x_start, x_end], ys=[z_start, z_end],
                          zs=[y_start, y_end], c=seg_color)
        fig.savefig(output_pth)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)



# Prepare output directories; exist_ok avoids the check-then-create race.
os.makedirs(tar_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)

# Remove stale PNG frames left in the scratch directory by earlier runs.
# Deleting via glob/os.remove (instead of os.system("rm -rf ...")) needs no
# POSIX shell, never hands the path to a shell, and surfaces failures
# instead of silently ignoring them.
for stale_png in glob.glob(os.path.join(tar_path, "*.png")):
    if os.path.isfile(stale_png):
        os.remove(stale_png)

# Slice trajectory `traj` at the selected frames; the result is one point
# array per frame, which is what capture_motion_image iterates over.
box_motion = box[traj, frame, :, :]
capture_motion_image(box_motion, output_pth='motion.png')


