import os
import pickle
import shutil

import imageio
import matplotlib.pylab as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

# Dataset tag; the axis-limit branches in the plotting helper key off this string.
dset = 'shake'

# Offset inside each block of 100 trajectory ids (used later as traj_id % 100).
base = 9
assert base < 100
# One trajectory per block of 100: base, base + 100, ..., base + 400.
traj_id_list = list(range(base, base + 5 * 100, 100))

# Render every `interval`-th step of the 300-step sequence.
interval = 5
# True: draw only box particles (no water) and write a *_box.mp4 file.
only_box = False

for traj_id in traj_id_list:

    print(traj_id)

    # Per-frame ground-truth particle arrays for this trajectory (indexed by frame below).
    # A context manager closes the file; pickle.load(open(...)) leaked the handle.
    with open(f'/home/htxue/datasets/data_FluidShakeExtra_new/{traj_id}/info.p', 'rb') as info_file:
        gt_particle = pickle.load(info_file)['particles']

    # Sampled box particles: one .npy per block of 100 trajectories (traj_id // 100).
    # Later indexed as box[traj_id % 100][frame].
    extra_id = traj_id // 100
    box = torch.from_numpy(np.load(
        '/home/htxue/data/mit/visual_dynamics/VGPL-Dynamics-Prior/dataset/box_sampling/shake_extra/sampling_8_interval_0.2/{}.npy'.format(extra_id)
    ))





    def capture_scene_image(feature, output_pth='debug.png', angle=100, color=None, h=30, close_sticks=False):
        """Render a 3-D scatter plot of a point cloud and save it to ``output_pth``.

        Args:
            feature: (N, >=3) indexable array of points; columns 0, 1, 2 are
                read as x, y, z.
            output_pth: destination image path passed to ``savefig``.
            angle: azimuth for ``view_init``.
            color: matplotlib colour spec, either one colour or one per point.
            h: elevation for ``view_init``.
            close_sticks: if True, remove tick marks (the grid stays on).

        Axis limits are chosen from the module-level ``dset`` string.
        """
        x, y, z = feature[:, 0], feature[:, 1], feature[:, 2]
        fig = plt.figure()
        ax3D = fig.add_subplot(111, projection='3d')
        # Data y is plotted on matplotlib's (vertical) z axis, so y and z are swapped.
        ax3D.scatter(x, z, y, s=15, marker='o', color=color)
        # Label each axis by the data coordinate it actually shows.
        ax3D.set_xlabel('x')
        ax3D.set_ylabel('z')
        ax3D.set_zlabel('y')
        # One view_init is enough; the original set the same view twice.
        ax3D.view_init(h, angle)

        if close_sticks:
            ax3D.set_xticks([])
            ax3D.set_yticks([])
            ax3D.set_zticks([])
            ax3D.grid(True)

        # Independent checks (not elif) on purpose: a later match overrides an earlier one,
        # matching the original behaviour for tags containing several keywords.
        if 'pour' in dset:
            ax3D.set_xlim3d(-1, 1)
            ax3D.set_ylim3d(-1, 1)
            ax3D.set_zlim3d(0.5, 2.5)
        if dset == 'shake':
            ax3D.set_xlim3d(-1, 1)
            ax3D.set_ylim3d(-1, 1)
            ax3D.set_zlim3d(0, 2)
        if 'granular' in dset:
            ax3D.set_xlim3d(-2, 2)
            ax3D.set_ylim3d(-4, 0)
            ax3D.set_zlim3d(0, 4)

        fig.savefig(output_pth)
        # Close this specific figure so figures don't pile up across hundreds of frames.
        plt.close(fig)



    def get_color_info(traj_id, scene_info=None):
        """Return which coloured boxes are present in trajectory ``traj_id``.

        Args:
            traj_id: trajectory index used to locate ``info.p`` on disk.
            scene_info: optional, already-loaded ``info.p['scene_params']``
                sequence. When given, no file is read.

        Returns:
            Tuple ``(v_red, v_green, v_yellow)``. Each entry is 1 if that colour is
            present, else 0.

        The case is chosen by ``len(scene_info)``:
            * 41: all three colours.
            * 31: two boxes, with colour-code digits at indices -5..-3 and -15..-13.
            * 21: one box, with colour-code digits at indices -5..-3.
            * any other length: no boxes.

        Raises:
            KeyError: a digit triple is not one of '100', '010', '110'.
        """
        if scene_info is None:
            info_path = '/home/htxue/datasets/data_FluidShakeExtra_new/{}/info.p'.format(traj_id)
            # Use a context manager so the file handle is closed (the original leaked it).
            with open(info_path, 'rb') as info_file:
                scene_info = pickle.load(info_file)['scene_params']

        color_dir = {
            '010': 'green',
            '100': 'red',
            '110': 'yellow'
        }
        order = ('red', 'green', 'yellow')

        def code_at(start):
            # Three consecutive values starting at `start`, each truncated to int, e.g. '100'.
            return ''.join(str(int(scene_info[start + k])) for k in range(3))

        n_params = len(scene_info)
        if n_params == 41:
            present = set(order)
        elif n_params == 31:
            present = {color_dir[code_at(-5)], color_dir[code_at(-15)]}
        elif n_params == 21:
            present = {color_dir[code_at(-5)]}
        else:
            present = set()

        v_red, v_green, v_yellow = (int(name in present) for name in order)
        return v_red, v_green, v_yellow







    path = "/home/htxue/data/mit/visual_dynamics/pixel-nerf/datasets/dynamics_new/shake_extra_new/{}/".format(traj_id)
    # Start each trajectory with an empty scratch directory for the per-frame PNGs.
    # shutil.rmtree avoids shelling out; ignore_errors mirrors `rm -rf` on a missing dir.
    shutil.rmtree('./tmp/', ignore_errors=True)
    os.makedirs('./tmp/')


    def get_image_to_tensor_balanced(image_size=0):
        """Build a torchvision transform that converts an image to a tensor.

        If ``image_size`` is positive, a ``Resize(image_size)`` step runs
        first. Otherwise the image is converted as-is.
        """
        resize_step = [transforms.Resize(image_size)] if image_size > 0 else []
        return transforms.Compose(resize_step + [transforms.ToTensor()])


    # Accumulates the per-frame HxWxC arrays read back from ./tmp/.
    frame = []
    image2tensor = get_image_to_tensor_balanced()

    # 0/1 flags for each coloured box; their sum is the number of boxes in the scene.
    v_red, v_green, v_yellow = get_color_info(traj_id)
    box_num = sum((v_red, v_green, v_yellow))

    for i in tqdm(range(0, 300, interval)):
        frame_dir = path + f"{i}/voxel_50/"

        # The last box_num * 125 ground-truth particles are drawn as the boxes.
        # The start index is computed explicitly because with box_num == 0 the
        # old `[-0:]` slice selected *every* particle instead of none.
        n_gt = gt_particle[i].shape[0]
        box_gt = torch.from_numpy(gt_particle[i][max(n_gt - box_num * 125, 0):])

        # Use one tensor for both the points and the colour count, so they cannot
        # disagree (the original counted colours from box[-1][i] instead).
        box_sample = box[traj_id % 100][i]

        particles = [box_gt, box_sample]
        colors = ['orange'] * box_gt.shape[0] + ['grey'] * box_sample.shape[0]

        if not only_box:
            water = torch.load(frame_dir + "water_pcd.bin")
            particles.append(water)
            colors += ['blue'] * water.shape[0]

        # Reconstructed per-colour box point clouds. Only load the ones present in the scene.
        for present, name in ((v_red, 'red'), (v_green, 'green'), (v_yellow, 'yellow')):
            if present:
                box_pcd = torch.load(frame_dir + f"{name}_box_pcd.bin")
                particles.append(box_pcd)
                colors += [name] * box_pcd.shape[0]

        # Camera azimuth drifts from 100 deg toward 120 deg over the sequence.
        capture_scene_image(torch.cat(particles), color=colors,
                            output_pth='./tmp/{}.png'.format(i), angle=100 + 20 * i / 300,
                            close_sticks=True, h=5 if only_box else 20)

    for i in range(0, 300, interval):
        print(i)
        image = imageio.imread("./tmp/" + str(i) + ".png")
        # Pixels that are zero in every channel are painted white.
        image[(image == 0).all(-1), :] = 255
        # ToTensor yields a (C, H, W) float tensor; move channels last for the video writer.
        # (The unused `C, H, W` unpacking from the original is dropped.)
        frame.append(image2tensor(image).permute(1, 2, 0).cpu().numpy())

    frames = np.stack(frame)
    print(frames.shape)
    save_name = 'nerf_shake/shake_{}_box.mp4'.format(traj_id) if only_box else 'nerf_shake/shake_{}.mp4'.format(traj_id)
    # Make sure the output folder exists before handing the path to the writer.
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    # Round before casting: a bare astype truncates, so values like 254.99997 that come
    # out of the /255 -> *255 round-trip would drop to 254.
    frames_u8 = np.clip(np.rint(frames * 255), 0, 255).astype(np.uint8)
    imageio.mimwrite(save_name, frames_u8, quality=8)


    # Remove this trajectory's scratch PNGs. ignore_errors mirrors `rm -rf` semantics.
    shutil.rmtree('./tmp/', ignore_errors=True)



