import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

import torch
import torch.nn.functional as F
import numpy as np
import imageio
import util
import warnings
from data import get_split_dataset
from render import NeRFRenderer
from model import make_model
from scipy.interpolate import CubicSpline
from tqdm import tqdm
import matplotlib.pylab as plt
from dotmap import DotMap
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time
import torchvision
import pickle

def extra_args(parser):
    """Register this script's extra command-line options on ``parser``.

    ``--split`` / ``--source`` (``-P``) / ``--scale`` select data and source
    views; ``--root`` and ``--save_path`` are combined with the run name into
    the output directory; ``--voxel_num`` is the query-grid resolution;
    ``--traj_num`` / ``--frame_num`` describe the dataset layout;
    ``--inter_chunk`` is the number of grid points per network query;
    ``--seg_method`` / ``--need_midfeature`` configure segmentation and the model.

    :param parser: an ``argparse.ArgumentParser`` (or compatible) instance.
    :return: the same parser, for chaining.
    """

    def str2bool(value):
        # ``type=bool`` is a trap: bool("False") is True, so any non-empty value
        # used to enable the flag. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("0", "false", "f", "no", "n", "off"):
            return False
        raise ValueError("expected a boolean value, got {!r}".format(value))

    parser.add_argument(
        "--split",
        type=str,
        default="all",
        help="Split of data to use train | val | test",
    )
    parser.add_argument(
        "--source",
        "-P",
        type=str,
        default="0 1 2 3 4 5",
        help="Source view(s) in image, in increasing order. -1 to do random",
    )
    parser.add_argument(
        "--scale", type=float, default=1.0, help="Video scale relative to input size"
    )
    parser.add_argument(
        "--root",
        type=str,
        default="/home/htxue/data/mit/visual_dynamics/pixel-nerf/"
    )
    parser.add_argument(
        "--voxel_num",
        type=int,
        default=100
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default='datasets/debug/'
    )
    parser.add_argument(
        "--traj_num",
        type=int,
        default=500
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=300
    )
    parser.add_argument(
        "--inter_chunk",
        type=int,
        default=4000
    )
    parser.add_argument(
        "--seg_method",
        type=str,
        default='rgb',
        help='method that is used to seg out the water'
    )
    parser.add_argument(
        "--need_midfeature",
        type=str2bool,
        default=True,
        help='the output of the net is (not) the mid feature'
    )
    parser.add_argument("--fps", type=int, default=30, help="FPS of video")
    return parser



# Parse CLI/config, force resuming from saved weights, and derive the output
# directory by plain concatenation: <root><save_path><name>/ (so --root and
# --save_path are expected to end with "/").
args, conf = util.args.parse_args(extra_args, default_conf="conf/default_mv.conf")
args.resume = True
args.save_path = "".join([args.root, args.save_path, args.name, "/"])


print(args)
# ``device`` is a module-level global; Scene2Particle.__init__ moves the model to it.
device = util.get_cuda(args.gpu_id[0])
dset = get_split_dataset(
    args.dataset_format,
    args.datadir,
    want_split=args.split,
    training=False,
    img_format='jpg',
)

def capture_scene_image(feature, output_pth='debug.png', angle=100, color=None):
    """Save a 3-D scatter plot of the first three columns of ``feature``.

    Column 0 goes on the plot's x axis, column 2 on its y axis and column 1 on
    its vertical (z) axis; the axis labels are the fixed strings 'x', 'y', 'z'.

    :param feature: array-like of shape (K, >=3).
    :param output_pth: path of the image written by ``plt.savefig``.
    :param angle: azimuth for ``view_init`` (elevation is fixed at 10).
    :param color: optional per-point colours passed as ``c=`` to ``scatter``.
    """
    col0 = feature[:, 0]
    col1 = feature[:, 1]
    col2 = feature[:, 2]
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.scatter(col0, col2, col1, s=5, marker='o', c=color)
    axes.view_init(10, angle)
    for set_label, text in ((axes.set_xlabel, 'x'),
                            (axes.set_ylabel, 'y'),
                            (axes.set_zlabel, 'z')):
        set_label(text)
    plt.savefig(output_pth)
    plt.close()

def rescale(pcl_mat, min, max, N):
    """Map grid indices to world coordinates: ``pcl_mat * (max - min) / N + min``.

    Index 0 maps to ``min`` and index ``N`` maps to ``max``, so ``N`` is the
    number of grid *intervals* per axis (one less than the samples per axis).

    :param pcl_mat: ``(..., 3)`` array/tensor of (x, y, z) grid indices.
    :param min: per-axis world coordinate of grid index 0.
    :param max: per-axis world coordinate of grid index ``N``.
    :param N: number of grid intervals along each axis.
    :return: world coordinates, same shape as ``pcl_mat``.
    :raises ValueError: if the three axes do not share the same voxel size.
    """
    # ``min``/``max`` shadow builtins but are part of the public signature.
    lo, hi = np.array(min), np.array(max)
    scale = (hi - lo) / N
    # The grid is expected to be isotropic (cubic voxels). Compare with a
    # tolerance: the per-axis spans are computed in floating point. Raise rather
    # than assert so the check survives ``python -O``.
    if not np.allclose(scale, scale[0]):
        raise ValueError("anisotropic voxel size {} (min={}, max={}, N={})".format(scale, min, max, N))
    return pcl_mat * scale + lo

def mat2pcl(N, voxel_mat, valid):
    """Convert a dense voxel grid into a sparse point cloud.

    :param N: grid side length; ``voxel_mat`` and ``valid`` must each hold
        ``N ** 3`` voxels (any layout that reshapes to ``(N, N, N)``).
    :param voxel_mat: per-voxel features (numpy array or tensor) with
        ``N ** 3 * C`` elements, reshaped to ``(N, N, N, C)``.
    :param valid: per-voxel mask (numpy array or tensor) with ``N ** 3``
        elements; non-zero entries mark voxels to keep.
    :return: tensor of shape ``(K, 3 + C)``: for each of the K selected voxels,
        its integer grid index along the three grid axes followed by its C
        features, all in the promoted dtype of indices and features.
    """
    if isinstance(voxel_mat, np.ndarray):
        voxel_mat = torch.from_numpy(voxel_mat)
    if isinstance(valid, np.ndarray):
        valid = torch.from_numpy(valid)
    # ``reshape`` (not ``view``) also accepts non-contiguous inputs.
    voxel_mat = voxel_mat.reshape(N, N, N, -1)
    valid = valid.reshape(N, N, N).bool()

    idx = valid.nonzero()  # (K, 3) int64 grid indices, row-major order
    pcl_feature = voxel_mat[valid]  # (K, C), same row-major order as idx

    # Promote explicitly: older torch versions refuse to cat int64 with float.
    dtype = torch.promote_types(idx.dtype, pcl_feature.dtype)
    return torch.cat([idx.to(dtype), pcl_feature.to(dtype)], dim=-1)

def load_shape_info(dset, root, id=None):
    """Parse a shape-description text file for the environment objects.

    The file is ``<root>/datasets/shape_info/<name>``. Lines starting with
    ``asset`` are counted as mesh shapes. Every other non-blank line is
    expected to look like ``[4.   0.02 4.  ] 0 [0.9 0.9 0.9]``: a bracketed
    ``x y z`` vector followed by an integer flag (stored as ``visible``). The
    trailing bracket is ignored.

    :param dset: one of 'pour', 'shake', 'pour_extra', 'shake_extra'.
    :param root: directory that contains ``datasets/shape_info/``.
    :param id: variant index; required for the ``*_extra`` datasets.
    :return: ``(mesh_num, shape_info)`` where ``shape_info`` is a list of
        ``[(x, y, z), visible]`` in file order.
    :raises ValueError: for an unknown ``dset`` or a missing ``id``.
    """
    # shape : some env objects like container and robo-arms
    if dset == 'pour':
        filename = 'FluidPour.txt'
    elif dset == 'shake':
        filename = 'FluidShake.txt'
    elif dset in ('pour_extra', 'shake_extra'):
        if id is None:
            raise ValueError("load_shape_info: 'id' is required for dset={!r}".format(dset))
        prefix = 'FluidPourExtra' if dset == 'pour_extra' else 'FluidShakeExtra'
        filename = f'shapes_{prefix}_{id}.txt'
    else:
        raise ValueError("load_shape_info: unknown dset {!r}".format(dset))
    info_pth = os.path.join(root, 'datasets/shape_info/', filename)

    mesh_num = 0
    shape_info = []
    with open(info_pth) as fh:
        for line in fh:
            if line.startswith('asset'):
                mesh_num += 1
                continue
            if not line.strip():
                continue  # tolerate blank / trailing lines
            start = line.find("[")
            end = line.find("]")
            # Slice between the brackets so "[ 4. ..." (leading space) also parses.
            x, y, z = (float(v) for v in line[start + 1:end].split())
            visible = int(line[end + 1:].split()[0])
            shape_info.append([(x, y, z), visible])
    return mesh_num, shape_info

def rotation_matrix_from_quaternion(quant):
    """Build a 3x3 rotation matrix from a quaternion stored as (x, y, z, w).

    The quaternion is normalised first. The returned matrix is the *transpose*
    of the conventional rotation matrix (meant for multiplying row vectors from
    the right-hand side); take ``.t()`` for the conventional column-vector form.

    :param quant: quaternion components in (x, y, z, w) order, as a numpy array,
        tensor or sequence with exactly 4 elements (any shape).
    :return: ``(3, 3)`` tensor on the input's device, in its floating dtype
        (integer input is converted to the default float dtype).
    :raises ValueError: if ``quant`` does not contain exactly 4 elements.
    """
    # Reference
    # http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToMatrix/index.htm
    quant = torch.as_tensor(quant)
    if not quant.is_floating_point():
        quant = quant.to(torch.get_default_dtype())
    flat = quant.reshape(-1)
    if flat.numel() != 4:
        raise ValueError("expected 4 quaternion components, got shape {}".format(tuple(quant.shape)))
    x, y, z, w = (flat / torch.norm(flat)).unbind()

    # Rows of the transposed conventional matrix (see docstring). Built from
    # the input's own scalars, so dtype and device follow the input.
    return torch.stack((
        torch.stack((1 - 2 * (y * y + z * z), 2 * (x * y + z * w), 2 * (x * z - y * w))),
        torch.stack((2 * (x * y - z * w), 1 - 2 * (x * x + z * z), 2 * (y * z + x * w))),
        torch.stack((2 * (x * z + y * w), 2 * (y * z - x * w), 1 - 2 * (x * x + y * y))),
    ))
def refine_by_container(name, pcl, box_list):
    """Keep only the points of ``pcl`` that lie inside at least one oriented box.

    :param name: unused; kept for call compatibility.
    :param pcl: ``(n, D)`` torch tensor with ``D >= 3``; columns 0-2 are the
        world-space xyz coordinates that are tested.
    :param box_list: iterable of ``(extent, center, quaternion)``:
        ``extent`` = full side lengths ``(x, y, z)`` in the box frame,
        ``center`` = box centre (3 values), ``quaternion`` = orientation in the
        (x, y, z, w) layout accepted by ``rotation_matrix_from_quaternion``.
    :return: the rows of ``pcl`` strictly inside any box (faces excluded).
    """
    n = pcl.shape[0]
    xyz = pcl[:, :3]
    valid = torch.zeros(n, dtype=torch.bool, device=pcl.device)
    for extent, center, quat in box_list:
        # rotation_matrix_from_quaternion returns the transposed matrix; .t()
        # gives the conventional box-to-world rotation R.
        rot = rotation_matrix_from_quaternion(quat).t()

        local = xyz - torch.as_tensor(center, device=xyz.device)
        if not local.is_floating_point():
            local = local.to(torch.get_default_dtype())
        # Align dtype/device so mixed float32/float64 inputs don't crash the matmul.
        rot_inv = torch.inverse(rot).to(device=local.device, dtype=local.dtype)
        # World -> box frame for every row p: rot_inv @ p  ==  p @ rot_inv^T.
        local = local @ rot_inv.t()

        half = torch.as_tensor(extent, dtype=local.dtype, device=local.device) / 2
        inside = ((local > -half) & (local < half)).all(dim=1)
        valid |= inside

    return pcl[valid]


class Scene2Particle:
    """Convert a trained (pixel-)NeRF scene into particle point clouds.

    For each dataset item, :meth:`get_particle` queries ``self.net`` on a dense
    ``(N + 1) ** 3`` grid spanning ``self.xyz_range``. Voxels whose density and
    colour pass per-scene thresholds become point clouds (water, boxes, sand, ...).
    Per-trajectory container poses (``self.center`` / ``self.quant``) are loaded
    from simulator ``info.p`` pickles and used to crop those point clouds.

    Relies on the module-level globals ``device`` (model device, used in
    ``__init__``) and ``f`` (open log file written by :meth:`get_particle`).
    """

    def __init__(self, dset, args, conf):
        """Build the network, the query grid and the container-pose tables.

        :param dset: dataset exposing ``z_near`` / ``z_far`` and item dicts with
            ``images`` / ``poses`` / ``focal`` (and optionally ``c``).
        :param args: parsed CLI args; ``args.name`` selects the scene type
            (pour / shake / granular, optionally ``extra`` / ``real``).
        :param conf: parsed config; ``conf["model"]`` is passed to ``make_model``.
        :raises ValueError: if ``args.name`` matches no known scene type.
        """
        self.dset = dset
        self.args = args
        self.conf = conf

        self.net = make_model(conf["model"], using_intermediate_feature=args.need_midfeature).to(device=device)
        self.net.load_weights(args)

        self.z_near, self.z_far = self.dset.z_near, self.dset.z_far

        self.N = args.voxel_num
        # [tx, ty, tz]: N + 1 world-space samples per axis.
        self.xyz_range = self._grid_axes(args.name, self.N)

        self.save_path = args.save_path

        self.traj_num, self.frame_num = args.traj_num, args.frame_num

        name = args.name
        info_p_path = None
        if "extra" not in name:
            if "pour" in name or "Pour" in name:
                info_p_path = '/home/htxue/datasets/data_FluidPour/'
                self.mesh_num, self.shape_info = load_shape_info('pour', '../')
            if "shake" in name or "Shake" in name:
                info_p_path = '/home/htxue/datasets/data_FluidShake/'
                self.mesh_num, self.shape_info = load_shape_info('shake', '../')
            if name == 'granular_push':
                info_p_path = '/home/htxue/datasets/data_GranularPushExtra/'
                # NOTE(review): reuses the FluidShake shape file; confirm intended.
                self.mesh_num, self.shape_info = load_shape_info('shake', '../')
        else:
            shape_kind = None
            if "pour" in name or "Pour" in name:
                info_p_path = '/home/htxue/datasets/data_FluidPourExtra/'
                shape_kind = 'pour_extra'
            if "shake" in name or "Shake" in name:
                info_p_path = '/home/htxue/datasets/data_FluidShakeExtra_new/'
                shape_kind = 'shake_extra'
            if shape_kind is not None:
                # One shape file per variant (5 variants).
                self.mesh_num = []
                self.shape_info = []
                for i in range(5):
                    mesh_n, shape_i = load_shape_info(shape_kind, '../', i)
                    self.mesh_num.append(mesh_n)
                    self.shape_info.append(shape_i)

        if info_p_path is None:
            raise ValueError("cannot locate simulator info.p files for scene name {!r}".format(name))
        self.center, self.quant = self._load_shape_states(info_p_path)

    @staticmethod
    def _grid_axes(name, n):
        """Return ``[tx, ty, tz]`` with ``n + 1`` world-space samples per axis.

        Later matches override earlier ones (e.g. 'granular' wins over 'pour').

        :raises ValueError: if ``name`` matches no known scene type.
        """
        bounds = None  # ((x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi))
        if 'real' not in name:
            if "pour" in name or "Pour" in name:
                bounds = ((-1.25, 1.25), (1, 3.5), (-1.25, 1.25))
            if "shake" in name or "Shake" in name:
                bounds = ((-1.6, 1.6), (-0.2, 3), (-1.6, 1.6))
        if 'real' in name and 'pour' in name:
            bounds = ((-10, 10), (-10, 10), (-10, 10))
        if 'granular' in name:
            bounds = ((-3, 3), (0, 6), (-3, 3))
        if bounds is None:
            raise ValueError("no voxel grid range defined for scene name {!r}".format(name))
        return [np.linspace(lo, hi, n + 1) for lo, hi in bounds]

    def _load_shape_states(self, info_p_path):
        """Load container centres and quaternions for every trajectory.

        Reads ``<info_p_path><i>/info.p`` for ``i < args.traj_num``. From
        ``info['shape_states']`` (frame x shape x 14), it skips the first
        ``mesh_num`` shapes and keeps columns 3:6 (centre) and 10:14 (quaternion).
        When ``self.mesh_num`` is a per-variant list (``*extra`` datasets),
        trajectory ``i`` uses variant ``i // 100``.

        :return: ``(center, quant)`` numpy arrays stacked over trajectories.
        """
        per_variant = isinstance(self.mesh_num, list)
        center_list = []
        quant_list = []
        for i in range(self.args.traj_num):
            # NOTE: pickle must only be used on trusted, locally generated files.
            with open(info_p_path + str(i) + "/info.p", 'rb') as fh:
                info = pickle.load(fh)
            shap_info = info['shape_states']  # frame * shape_number * 14
            mesh_n = self.mesh_num[i // 100] if per_variant else self.mesh_num
            center_list.append(shap_info[:, mesh_n:, 3:6])
            quant_list.append(shap_info[:, mesh_n:, 10:14])
        return np.stack(center_list), np.stack(quant_list)

    @staticmethod
    def _box_visibility(scene_info):
        """Decode which coloured boxes are present from FluidShakeExtra scene params.

        Returns ``[v_red, v_green, v_yellow]`` as 0/1 flags:
        - 41 params: all three boxes.
        - 31 params: two boxes, with RGB codes at indices -5..-3 and -15..-13.
        - 21 params: one box, with its code at indices -5..-3.
        - Any other length: no boxes.

        :raises KeyError: for an RGB code other than '010', '100' or '110'.
        """
        if len(scene_info) == 41:
            return [1, 1, 1]
        color_dir = {
            '010': 'green',
            '100': 'red',
            '110': 'yellow'
        }
        slot = {'red': 0, 'green': 1, 'yellow': 2}
        if len(scene_info) == 31:
            code_starts = (-5, -15)
        elif len(scene_info) == 21:
            code_starts = (-5,)
        else:
            code_starts = ()
        visible = [0, 0, 0]
        for start in code_starts:
            code = ''.join(str(int(scene_info[start + k])) for k in range(3))
            visible[slot[color_dir[code]]] = 1
        return visible

    def _shake_container_box(self, traj_id, frame_id):
        """Return the ``(extent, center, quaternion)`` box built from shapes 3 and 5-8.

        NOTE(review): shape indices are specific to the FluidShake layout.
        """
        centers = self.center[traj_id][frame_id]
        box_center = (centers[5] + centers[6]) / 2
        box_center[1] /= 2
        box_x = centers[6][0] - centers[5][0]
        box_z = centers[8][2] - centers[7][2]
        # Uses the already-halved centre height, as in the original computation.
        box_y = (box_center[1] - centers[3][1]) * 2
        return [
            (box_x, box_y, box_z),
            box_center,
            self.quant[traj_id][frame_id][5]
        ]

    def _pour_container_boxes(self, traj_id, frame_id):
        """Return the two crop boxes for FluidPour: one from shapes 14-16, one from shapes 1-3.

        For ``*extra`` datasets the shape sizes come from variant ``traj_id // 100``.
        NOTE(review): shape indices are specific to the FluidPour layout.
        """
        if 'extra' in self.args.name:
            shape_info = self.shape_info[traj_id // 100]
        else:
            shape_info = self.shape_info
        centers = self.center[traj_id][frame_id]
        quants = self.quant[traj_id][frame_id]

        center_box = (centers[15] + centers[16]) / 2
        center_box[1] += shape_info[15][0][1]
        box_1 = [
            (shape_info[14][0][0], shape_info[15][0][1] * 3, shape_info[14][0][2]),
            center_box,
            quants[14]
        ]
        box_2 = [
            (shape_info[1][0][0], shape_info[2][0][1], shape_info[1][0][2]),
            (centers[2] + centers[3]) / 2,
            quants[2]
        ]
        return [box_1, box_2]

    def _to_world(self, pcl_idx):
        """Convert (x, y, z) grid indices to world coordinates using ``self.xyz_range``."""
        mins = [axis[0] for axis in self.xyz_range]
        maxs = [axis[-1] for axis in self.xyz_range]
        # Each axis has len(axis) samples, i.e. len(axis) - 1 intervals; the last
        # index must land exactly on ``max`` (np.linspace spacing).
        n_intervals = len(self.xyz_range[0]) - 1
        return rescale(pcl_idx, mins, maxs, n_intervals)

    def _make_saving_path(self, traj_id, frame_id):
        """Create (if needed) and return ``<save_path>/<traj>/<frame>/voxel_<N>``."""
        saving_path = os.path.join(self.save_path, str(traj_id), str(frame_id),
                                   "voxel_{}".format(self.args.voxel_num))
        os.makedirs(saving_path, exist_ok=True)
        return saving_path

    def saving_pkg(self, pkg, pth):
        """Save every entry of ``pkg`` with ``torch.save`` as ``<pth>/<key>.bin``."""
        for key, value in pkg.items():
            torch.save(value, os.path.join(pth, "{}.bin".format(key)))

    def get_particle(self, subset_id, device, draw, sigma_min=5):
        """Reconstruct particle point clouds for one dataset item and save them.

        The item's source views are encoded with ``self.net``, then the network
        is queried on the ``self.xyz_range`` grid. Voxels are selected by density
        and by colour thresholds chosen from ``args.name``. Selected voxels are
        converted to world coordinates and optionally cropped to container
        boxes. Results are saved with :meth:`saving_pkg` under
        ``<save_path>/<traj_id>/<frame_id>/voxel_<N>``, and one summary line is
        written to the module-level log file ``f``.

        :param subset_id: flat index; ``traj_id = subset_id // frame_num`` and
            ``frame_id = subset_id % frame_num``.
        :param device: ignored (re-derived from ``args.gpu_id[0]``); kept for
            call compatibility.
        :param draw: if truthy, also save scatter-plot PNGs next to the data.
        :param sigma_min: density threshold used by the pour / shake branches.
        :return: the dict of saved entries.
        :raises ValueError: if no extraction branch matches ``args.name`` /
            ``args.seg_method``.
        """
        args = self.args
        data = self.dset[subset_id]
        device = util.get_cuda(args.gpu_id[0])

        images = data["images"]  # (NV, 3, H, W)
        poses = data["poses"]  # (NV, 4, 4)
        focal = data["focal"]
        if isinstance(focal, float):
            # Dataset implementations are not consistent about
            # returning float or scalar tensor in case of fx=fy
            focal = torch.tensor(focal, dtype=torch.float32)
        c = data.get("c")
        if c is not None:
            c = c.to(device=device).unsqueeze(0)
        NV, _, H, W = images.shape
        focal = focal.to(device=device)
        source = torch.tensor(list(map(int, args.source.split())), dtype=torch.long)
        random_source = len(source) == 1 and source[0] == -1
        assert not (source >= NV).any()
        src_view = torch.randint(0, NV, (1,)) if random_source else source

        if args.scale != 1.0:
            Ht = int(H * args.scale)
            Wt = int(W * args.scale)
            if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10:
                warnings.warn(
                    "Inexact scaling, please check {} times ({}, {}) is integral".format(
                        args.scale, H, W
                    )
                )

        self.net.encode(
            images[src_view].unsqueeze(0),
            poses[src_view].unsqueeze(0).to(device=device),
            focal,
            c=c,
        )

        # Flat layout used by every branch: subset_id = traj_id * frame_num + frame_id.
        traj_id = subset_id // self.frame_num
        frame_id = subset_id % self.frame_num
        n_pts = self.N + 1  # grid samples per axis

        # np.meshgrid's default 'xy' indexing orders the grid as (y, x, z); the
        # ``[:, [1, 0, 2]]`` reorderings below turn grid indices back into (x, y, z).
        tx, ty, tz = self.xyz_range
        query_pts = np.stack(np.meshgrid(tx, ty, tz), -1).astype(np.float32)
        flat = torch.from_numpy(query_pts.reshape([-1, 3])).to(args.gpu_id[0])
        chunk = args.inter_chunk

        # Initialised up front so the summary below works for every branch.
        pcl_water = pcl_box_red = pcl_box_green = pcl_box_yellow = None
        pkg = saving_path = None

        if args.seg_method == 'rgb':
            # Only the detached raw output ([1][0]) is kept, so no autograd graph is needed.
            with torch.no_grad():
                raw = torch.cat(
                    [
                        self.net(flat[None, i:i + chunk, :],
                                 viewdirs=torch.zeros_like(flat[i:i + chunk]))[1][0].detach().cpu()
                        for i in range(0, flat.shape[0], chunk)
                    ],
                    0,
                ).numpy()
            # Last raw channel is treated as density, first three as RGB; both
            # stay flattened in grid order, (n_pts ** 3,) and (n_pts ** 3, 3).
            sigma = np.maximum(raw[:, -1], 0.)
            color = np.maximum(raw[:, :3], 0.)
            feature = color

            if 'real' in args.name and 'pour' in args.name:
                sigma_valid = sigma > 10
                print(sigma_valid.sum())
                # Kept in raw grid-index space (no reorder / world conversion).
                pcl_all = mat2pcl(n_pts, feature, sigma_valid)
                pkg = {
                    "xyz_range": self.xyz_range,
                    "N": self.N,
                    "args": self.args,
                    "all_pcd": pcl_all
                }
                saving_path = self._make_saving_path(traj_id, frame_id)

            if 'granular' in args.name:
                with open(f'/home/htxue/datasets/data_GranularPushExtra/{traj_id}/info.p', 'rb') as fh:
                    particles = pickle.load(fh)['particles'][frame_id]
                x_min_sand, x_max_sand = particles[:, 0].min(), particles[:, 0].max()
                y_min_sand, y_max_sand = particles[:, 1].min(), particles[:, 1].max()
                z_min_sand, z_max_sand = particles[:, 2].min(), particles[:, 2].max()

                sigma_valid = sigma > 1
                feature_valid = (
                    (color[:, 0] < 0.8) * (color[:, 1] < 0.8) * (color[:, 2] < 0.8)
                    * (color[:, 0] > 0.2) * (color[:, 1] > 0.2) * (color[:, 2] > 0.2)
                    * ((color[:, 1] - color[:, 2]) > 0.05)
                )
                valid = sigma_valid * feature_valid

                pcl_granular = mat2pcl(n_pts, feature, valid)[:, [1, 0, 2]]  # num_valid_sand * 3
                pcl_all = mat2pcl(n_pts, feature, sigma_valid)[:, [1, 0, 2]]  # num_valid_all * 3
                pcl_granular[:, :3] = self._to_world(pcl_granular[:, :3])
                pcl_all[:, :3] = self._to_world(pcl_all[:, :3])

                # Keep only points inside the simulator's sand bounding box.
                valid_pos = (pcl_granular[:, 0] >= x_min_sand) * (pcl_granular[:, 0] <= x_max_sand) * \
                            (pcl_granular[:, 1] >= y_min_sand) * (pcl_granular[:, 1] <= y_max_sand) * \
                            (pcl_granular[:, 2] >= z_min_sand) * (pcl_granular[:, 2] <= z_max_sand)
                pcl_granular = pcl_granular[valid_pos]

                pkg = {
                    "xyz_range": self.xyz_range,
                    "N": self.N,
                    "args": self.args,
                    "granular_pcd": pcl_granular[pcl_granular[:, 1] > 0.99],  # above the desk
                    "all_pcd": pcl_all
                }
                print(pcl_granular.shape)

                saving_path = self._make_saving_path(traj_id, frame_id)
                if draw:
                    capture_scene_image(pcl_all, saving_path + "/all.png")
                    capture_scene_image(pcl_granular, saving_path + "/sand.png")

            if 'real' not in args.name and ('pour' in args.name or 'Pour' in args.name):
                feature_valid = (color[:, 2] > color[:, 1]) * (color[:, 2] > (color[:, 0] + 0.2))
                sigma_valid = sigma > sigma_min
                valid = feature_valid * sigma_valid

                pcl_water = mat2pcl(n_pts, feature, valid)[:, [1, 0, 2]]  # num_valid_water * 3
                pcl_all = mat2pcl(n_pts, feature, sigma_valid)[:, [1, 0, 2]]  # num_valid_all * 3
                pcl_water[:, :3] = self._to_world(pcl_water[:, :3])
                pcl_all[:, :3] = self._to_world(pcl_all[:, :3])

                pcl_water = refine_by_container('pour', pcl_water, self._pour_container_boxes(traj_id, frame_id))
                pkg = {
                    "xyz_range": self.xyz_range,
                    "N": self.N,
                    "args": self.args,
                    "water_pcd": pcl_water,
                    "all_pcd": pcl_all
                }

                saving_path = self._make_saving_path(traj_id, frame_id)
                if draw:
                    capture_scene_image(pcl_water, saving_path + "/vis_water.png")

            elif 'shake' in args.name or 'Shake' in args.name:
                box_1 = self._shake_container_box(traj_id, frame_id)
                sigma_valid = sigma > sigma_min
                feature_valid_water = (color[:, 2] > color[:, 1]) * (color[:, 2] > (color[:, 0] + 0.2))

                if 'extra' not in args.name:
                    feature_valid_box = (color[:, 2] < 0.1) * (np.abs(color[:, 0] - color[:, 1]) < 0.1) * (color[:, 0] > 0.)

                    pcl_water = mat2pcl(n_pts, feature, feature_valid_water * sigma_valid)[:, [1, 0, 2]]
                    pcl_box = mat2pcl(n_pts, feature, feature_valid_box * sigma_valid)[:, [1, 0, 2]]
                    pcl_water[:, :3] = self._to_world(pcl_water[:, :3])
                    pcl_box[:, :3] = self._to_world(pcl_box[:, :3])

                    pcl_water = refine_by_container('shake', pcl_water, [box_1])
                    pcl_box = refine_by_container('shake', pcl_box, [box_1])

                    pkg = {
                        "xyz_range": self.xyz_range,
                        "N": self.N,
                        "args": self.args,
                        "water_pcd": pcl_water,
                        "box_pcd": pcl_box
                    }

                    saving_path = self._make_saving_path(traj_id, frame_id)
                    if draw:
                        capture_scene_image(pcl_water, output_pth=saving_path + "/water.png")
                        capture_scene_image(pcl_box, output_pth=saving_path + "/box.png")
                        self.capture_scene_image(sigma_valid, feature, output_pth=saving_path + "/all.png", angle=180)

                else:
                    feature_valid_box_yellow = (color[:, 2] < 0.1) * (np.abs(color[:, 0] - color[:, 1]) < 0.1) * (color[:, 0] > 0.0)
                    feature_valid_box_red = ((color[:, 0] - color[:, 1]) > 0.3) * ((color[:, 0] - color[:, 2]) > 0.3)
                    feature_valid_box_green = ((color[:, 1] - color[:, 2]) > 0.3) * ((color[:, 1] - color[:, 0]) > 0.3)

                    pcd_all = mat2pcl(n_pts, feature, sigma_valid)

                    with open('/home/htxue/datasets/data_FluidShakeExtra_new/{}/info.p'.format(traj_id), 'rb') as fh:
                        scene_info = pickle.load(fh)['scene_params']
                    v_red, v_green, v_yellow = self._box_visibility(scene_info)

                    def extract_box(mask):
                        # Dense voxels of one box colour -> world coords -> cropped to the container.
                        pcl = mat2pcl(n_pts, feature, mask * sigma_valid)[:, [1, 0, 2]]
                        pcl = self._to_world(pcl)
                        return refine_by_container('shake', pcl, [box_1])

                    if v_red:
                        pcl_box_red = extract_box(feature_valid_box_red)
                    if v_green:
                        pcl_box_green = extract_box(feature_valid_box_green)
                    if v_yellow:
                        pcl_box_yellow = extract_box(feature_valid_box_yellow)

                    pcl_water = mat2pcl(n_pts, feature, feature_valid_water * sigma_valid)[:, [1, 0, 2]]
                    pcl_water[:, :3] = self._to_world(pcl_water[:, :3])
                    pcl_water = refine_by_container('shake', pcl_water, [box_1])

                    pkg = {
                        "xyz_range": self.xyz_range,
                        "N": self.N,
                        "args": self.args,
                        "water_pcd": pcl_water,
                        "red_box_pcd": pcl_box_red,
                        "green_box_pcd": pcl_box_green,
                        "yellow_box_pcd": pcl_box_yellow,
                        "pcd_all": pcd_all
                    }

                    saving_path = self._make_saving_path(traj_id, frame_id)
                    if draw:
                        capture_scene_image(pcl_water, output_pth=saving_path + "/water.png")
                        if v_red:
                            capture_scene_image(pcl_box_red, output_pth=saving_path + "/box_red.png")
                        if v_yellow:
                            capture_scene_image(pcl_box_yellow, output_pth=saving_path + "/box_yellow.png")
                        if v_green:
                            capture_scene_image(pcl_box_green, output_pth=saving_path + "/box_green.png")

        if pkg is None:
            raise ValueError(
                "no particle extraction branch for name={!r}, seg_method={!r}".format(
                    args.name, args.seg_method))

        self.saving_pkg(pkg, saving_path)

        water_shape = pcl_water.shape[0] if pcl_water is not None else 0
        red_shape = pcl_box_red.shape[0] if pcl_box_red is not None else 0
        green_shape = pcl_box_green.shape[0] if pcl_box_green is not None else 0
        yellow_shape = pcl_box_yellow.shape[0] if pcl_box_yellow is not None else 0

        summary = f"{traj_id} {frame_id} {water_shape} {red_shape} {green_shape} {yellow_shape}\n"
        f.write(summary)  # module-level log file
        print(summary)

        return pkg

    def capture_scene_image(self, valid, feature, elevator=10, azim=0, output_pth='./reconstr_scene.png', angle=100, h=10, order=(0, 1, 2)):
        """Scatter-plot selected voxels of the full grid in grid-index coordinates.

        :param valid: flat boolean mask over the ``(N + 1) ** 3`` grid.
        :param feature: per-voxel colours aligned with ``valid``.
        :param elevator: view elevation, used only for 'granular' scenes.
        :param azim: unused.
        :param output_pth: output image path.
        :param angle: view azimuth.
        :param h: view elevation for 'pour' / 'shake' scenes.
        :param order: unused.
        """
        t = [i for i in range(self.N + 1)]
        x, y, z = np.meshgrid(t, t, t)
        x, y, z = x.flatten(), y.flatten(), z.flatten()
        fig = plt.figure()
        ax3D = fig.add_subplot(111, projection='3d')
        if 'pour' in self.args.name:
            ax3D.scatter(x[valid], z[valid], y[valid], s=1, c=feature[valid], marker='o')
            ax3D.view_init(h, angle)
            ax3D.set_xlim3d(-10, 50)
            ax3D.set_ylim3d(-10, 50)
            ax3D.set_zlim3d(0, 60)

        if 'shake' in self.args.name:
            ax3D.scatter(x[valid], z[valid], y[valid], s=1, c=feature[valid], marker='o')
            ax3D.view_init(h, angle)

        if 'push' in self.args.name:
            ax3D.scatter(x[valid], z[valid], y[valid], s=1, c=feature[valid], marker='o')
            ax3D.view_init(10, angle)

        if 'granular' in self.args.name:
            ax3D.view_init(elevator, angle)
            ax3D.set_xlim3d(0, 30)
            ax3D.set_ylim3d(0, 30)
            ax3D.set_zlim3d(0, 30)

        ax3D.set_xlabel('x')
        ax3D.set_ylabel('y')
        ax3D.set_zlabel('z')
        plt.savefig(output_pth)
        plt.close()

    def gen_video(self, image_list, output_pth):
        """Placeholder: video generation is not implemented."""
        pass

pipline = Scene2Particle(dset, args, conf)

# ``f`` must stay a module-level name: Scene2Particle.get_particle writes one
# summary line per processed frame to it. ``with`` closes (and flushes) the log
# even if a frame raises; the output directory is created first so open() works.
os.makedirs(pipline.save_path, exist_ok=True)
with open(pipline.save_path + "box_info.txt", 'w') as f:
    # Full run: iterate over range(len(dset)). Currently a hand-picked debug
    # list of flat indices (presumably traj_id * 300 + frame_id with frame_num=300).
    for i in tqdm([316*300+291, 316*300+292, 316*300+293, 16*300+299]):
        # Save debug images only for every 50th index.
        vis = 1 if i % 50 == 0 else 0
        pipline.get_particle(subset_id=i, device=args.gpu_id, sigma_min=0.5, draw=vis)
