import glob
import os
import pickle

import imageio
import matplotlib.pylab as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

# ---- Output locations -------------------------------------------------------
# Per-frame PNG renders are written here, then stitched into a video.
tar_path = "./tmp/"
video_path = "./video_gen_from_pcd_new/"

# ---- What to render ---------------------------------------------------------
# NOTE(review): the main loop only handles 'pour', 'pour_gt', 'shake',
# 'pour_extra', 'granular_push_gt' and 'granular_push'; "fluid_shake" is not
# one of them — confirm the intended dataset name.
dset = "fluid_shake"
traj_list = [0]
root = '/home/htxue/data/mit/visual_dynamics/'

# ---- Rendering knobs --------------------------------------------------------
interval = 3   # stride between rendered frames
vx = 30        # only referenced by commented-out voxel paths in this file
rot = True     # not read anywhere in this file
angle = 0      # not read at module level (the plotting functions take their own `angle`)




def capture_scene_image(feature, output_pth='debug.png', angle=100, color=None, h=20, dataset=None):
    """Render a point cloud as a 3D scatter plot and save it to ``output_pth``.

    Args:
        feature: array-like indexed as ``feature[:, 0]``, ``feature[:, 1]`` and
            ``feature[:, 2]`` (x, y, z). Column 1 (y) is drawn on the plot's
            vertical axis and column 2 (z) on the plot's depth axis.
        output_pth: destination passed to ``savefig``.
        angle: azimuth in degrees passed to ``view_init``.
        color: forwarded to ``scatter(color=...)``; a single color or one per point.
        h: elevation in degrees passed to ``view_init``.
        dataset: dataset name used to choose fixed axis limits. Defaults to the
            module-level ``dset`` so existing callers behave exactly as before.
    """
    if dataset is None:
        dataset = dset

    x, y, z = feature[:, 0], feature[:, 1], feature[:, 2]
    fig = plt.figure()
    try:
        ax3D = fig.add_subplot(111, projection='3d')
        # The plot's third (vertical) axis receives data y, so the scene stands upright.
        ax3D.scatter(x, z, y, s=15, marker='o', color=color)
        ax3D.view_init(h, angle)
        # Labels name the data column shown on each plot axis (hidden by axis('off') anyway).
        ax3D.set_xlabel('x')
        ax3D.set_ylabel('z')
        ax3D.set_zlabel('y')
        ax3D.axis('off')
        # Fixed per-dataset limits keep the framing identical across all frames of a video.
        if 'pour' in dataset:
            ax3D.set_xlim3d(-1, 1)
            ax3D.set_ylim3d(-1, 1)
            ax3D.set_zlim3d(0.5, 2.5)
        if dataset == 'shake':
            ax3D.set_xlim3d(-1, 1)
            ax3D.set_ylim3d(-1, 1)
            ax3D.set_zlim3d(0, 2)
        if 'granular' in dataset:
            ax3D.set_xlim3d(-2, 2)
            ax3D.set_ylim3d(-4, 0)
            ax3D.set_zlim3d(0, 4)

        fig.savefig(output_pth)
    finally:
        # Release the figure even when saving fails, so long render loops
        # do not accumulate open figures.
        plt.close(fig)






def capture_motion_image(feature_list, output_pth='motion.png', angle=100, color=None):
    """Plot several frames of a point cloud and link each point to its previous position.

    Args:
        feature_list: sequence of arrays indexed as ``feature[:, 0..2]`` (x, y, z).
            Points are matched across consecutive frames by row position
            (assumes every frame lists the same points in the same order — TODO confirm).
        output_pth: destination passed to ``savefig``.
        angle: not used; kept for signature compatibility.
        color: ``None`` (matplotlib defaults), a single matplotlib color, or a
            per-point sequence indexed by row. It is forwarded to ``scatter`` as-is.
    """
    from matplotlib.colors import is_color_like

    # A per-point sequence is indexed per segment; a single color applies to all segments.
    per_point = color is not None and not is_color_like(color)

    fig = plt.figure()
    try:
        ax3D = fig.add_subplot(111, projection='3d')
        feature_prev = None
        for feature in feature_list:
            x, y, z = feature[:, 0], feature[:, 1], feature[:, 2]
            # Data y goes on the plot's vertical axis, so segments below use (x, z, y) too.
            ax3D.scatter(x, z, y, s=5, marker='o', color=color)
            if feature_prev is not None:
                for o, (cur, prev) in enumerate(zip(feature, feature_prev)):
                    if per_point:
                        line_kwargs = {'color': color[o]}
                    elif color is not None:
                        line_kwargs = {'color': color}
                    else:
                        line_kwargs = {}
                    ax3D.plot([cur[0], prev[0]], [cur[2], prev[2]], [cur[1], prev[1]], **line_kwargs)
            feature_prev = feature
        fig.savefig(output_pth)
    finally:
        plt.close(fig)



# Ensure both output directories exist (no-op when they already do).
os.makedirs(tar_path, exist_ok=True)
os.makedirs(video_path, exist_ok=True)

# Remove frame renders left over from a previous run, without shelling out to
# `rm -rf`. Like the shell glob, glob.glob skips dot-files; directories are left alone.
for stale_png in glob.glob(os.path.join(tar_path, "*.png")):
    if os.path.isfile(stale_png):
        os.remove(stale_png)


# Dataset names this script knows how to load and render.
_SUPPORTED_DSETS = ('pour', 'pour_gt', 'shake', 'pour_extra', 'granular_push_gt', 'granular_push')
# Upper bound (exclusive) of the frame indices rendered per trajectory.
NUM_FRAMES = 300


def _unsupported_dset_error():
    """Build the error raised when ``dset`` names no known dataset."""
    return ValueError(f"Unsupported dset {dset!r}; expected one of {_SUPPORTED_DSETS}")


def _load_trajectory(traj):
    """Load the box samples and particle point clouds for one trajectory of ``dset``.

    Args:
        traj: integer trajectory index.

    Returns:
        ``(box, par_fps)`` as loaded from disk. Rendering indexes ``par_fps[frame]``
        and ``box[...][frame]``.

    Raises:
        ValueError: if the module-level ``dset`` is not in ``_SUPPORTED_DSETS``.
    """
    if dset == 'pour':
        box = np.load(
            f'{root}VGPL-Dynamics-Prior/dataset/box_sampling/pour/sampling_8_interval_0.15/pkg.npy')
        # NOTE(review): unlike the other branches this path does not use `root`
        # (it lacks the 'visual_dynamics/' component) — confirm which is intended.
        par_fps = torch.load(f'/home/htxue/data/mit/VGPL-Dynamics-Prior/dataset/particle_fps/pour/fps_300/{traj}.bin')
    elif dset == 'pour_gt':
        box = np.load(
            f'{root}VGPL-Dynamics-Prior/dataset/box_sampling/pour/sampling_8_interval_0.15/pkg.npy')
        dump = np.load(f'{root}VGPL-Dynamics-Prior/dataset/GT_particle_fps/pour/fps_200/{traj}.npy')
        info_path = os.path.join("/home/htxue/datasets/data_FluidPour/{}/".format(str(traj)), "info.p")
        with open(info_path, 'rb') as f:
            info = pickle.load(f)['particles']
        # Gather info[t, dump[t, k], :] for every (t, k): dump's row t selects
        # entries along info's second axis for first-axis index t. Result is (T, K, ...).
        par_fps = info[np.arange(dump.shape[0])[:, None], dump]
    elif dset == 'shake':
        # NOTE(review): this path ends in a directory-like name with no .npy file
        # (and contains a doubled '/'); confirm it points at the intended file.
        box = np.load(
            f'{root}/VGPL-Dynamics-Prior/dataset/box_sampling/shake_extra')
        par_fps = torch.load(
            f'{root}VGPL-Dynamics-Prior/dataset/particle_fps_new/shake/fps_300/{traj}.bin')
    elif dset == 'pour_extra':
        # Box samples live in pkg{traj // 100}.npy and are indexed by traj % 100 when rendering.
        extra_id = traj // 100
        box = np.load(
            f'{root}VGPL-Dynamics-Prior/dataset/box_sampling/pour_extra/sampling_8_interval_0.15/pkg{extra_id}.npy')
        par_fps = torch.load(
            f'{root}VGPL-Dynamics-Prior/dataset/particle_fps_new/pour_extra/fps_250/{traj}.bin')
    elif dset == 'granular_push_gt':
        box = np.load(f'{root}VGPL-Dynamics-Prior/dataset/box_sampling/granular_push/sampling_8_interval_0.4/pkg.npy')
        with open(f'/home/htxue/datasets/data_GranularPushExtra/{traj}/info.p', 'rb') as f:
            par_fps = pickle.load(f)['particles']
    elif dset == 'granular_push':
        box = np.load(f'{root}VGPL-Dynamics-Prior/dataset/box_sampling/granular_push/sampling_8_interval_0.4/pkg.npy')
        par_fps = torch.load(
            f'{root}VGPL-Dynamics-Prior/dataset/particle_fps_new/granular_push/fps_300/{traj}.bin')
    else:
        raise _unsupported_dset_error()
    return box, par_fps


def _render_frames(box, par_fps, traj):
    """Render every ``interval``-th frame in [0, NUM_FRAMES) of ``traj`` to ``<tar_path>/<frame>.png``."""
    for frame_idx in tqdm(range(0, NUM_FRAMES, interval)):
        # Azimuth moves linearly from -100 degrees at frame 0 toward -50 degrees at NUM_FRAMES.
        ag = -(100 - 50 * frame_idx / NUM_FRAMES)
        view_kwargs = {}
        if dset in ('pour', 'pour_gt'):
            parts = [(box[traj][frame_idx], 'b'), (par_fps[frame_idx], 'r')]
        elif dset == 'shake':
            # The first 64 points of each particle frame are drawn as the cube ('c'), the rest as water ('r').
            particles = par_fps[frame_idx]
            parts = [(box[traj][frame_idx], 'b'), (particles[64:, :], 'r'), (particles[:64, :], 'c')]
        elif dset == 'pour_extra':
            parts = [(box[traj % 100][frame_idx], 'b'), (par_fps[frame_idx], 'r')]
        elif dset in ('granular_push_gt', 'granular_push'):
            parts = [(box[traj][frame_idx], 'b'), (par_fps[frame_idx], 'r')]
            view_kwargs['h'] = 100  # higher camera elevation for granular scenes
        else:
            raise _unsupported_dset_error()

        pcd_all = np.concatenate([points for points, _ in parts])
        colors = []
        for points, c in parts:
            colors += [c] * points.shape[0]
        capture_scene_image(pcd_all, output_pth=os.path.join(tar_path, f"{frame_idx}.png"),
                            angle=ag, color=colors, **view_kwargs)


def _write_video(traj):
    """Stitch the rendered PNGs of ``traj`` into ``<video_path>/<dset>_traj<traj>.mp4``."""
    images = []
    for frame_idx in tqdm(range(0, NUM_FRAMES, interval)):
        image = imageio.imread(os.path.join(tar_path, f"{frame_idx}.png"))
        # Pixels that are zero in every channel become white.
        image[(image == 0).all(-1), :] = 255
        images.append(image)

    # Frames are already 8-bit; stacking them directly avoids the lossy
    # uint8 -> float/255 -> *255 -> uint8 round trip, which truncates values.
    video = np.stack(images)
    print(video.shape)
    out_path = os.path.join(video_path, f'{dset}_traj{traj}.mp4')
    imageio.mimwrite(out_path, video.astype(np.uint8), quality=8)
    print('video writen to: ', out_path)


for traj in traj_list:
    box, par_fps = _load_trajectory(traj)
    _render_frames(box, par_fps, traj)
    _write_video(traj)
