'''
    Motion dataset for articulated shapes. Loads per-shape meshes, triangle-to-part labels and
    pre-rendered partial point clouds, and returns per-part poses, joint axes and pivot points.
'''

import os
# import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import os.path
import json
import numpy as np
import math
import sys
import torch
import vgtk.so3conv.functional as L
# import vgtk.pc as pctk
from scipy.spatial.transform import Rotation as sciR
from SPConvNets.datasets.part_transform import revoluteTransform
from SPConvNets.models.model_util import *
import random

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
# import provider
from torch.utils import data
from SPConvNets.models.common_utils import *
from SPConvNets.datasets.data_utils import *
import scipy.io as sio
import copy
import trimesh
import pyrender
# from model.utils import farthest_point_sampling

# padding 1
def padding_1(pos):
    """Append a trailing column of ones to `pos` (homogeneous-coordinate padding).

    Args:
        pos: 2-D array-like of shape (N, D).

    Returns:
        np.ndarray of shape (N, D + 1) whose last column is 1.0.
    """
    pos = np.asarray(pos)
    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same dtype (float64).
    # One pad entry per input row, so inputs with more than one row work as well.
    pad = np.ones((pos.shape[0], 1), dtype=float)
    return np.concatenate([pos, pad], axis=1)

# normalize point-cloud
def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it to fit the unit sphere.

    Args:
        pc: array-like of shape (N, D) holding N points.

    Returns:
        np.ndarray of shape (N, D): the centered points divided by the largest
        point-to-centroid distance. The input is not modified.
    """
    pc = np.asarray(pc)
    centered = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    # All points coincide with the centroid: dividing would yield NaNs, so return
    # the (all-zero) centered cloud unchanged.
    if radius == 0:
        return centered
    return centered / radius

# decode rotation info
def decode_rotation_info(rotate_info_encoding):
    """Decode a rotation encoding into a list of 3-D direction vectors.

    Encodings:
        0            -> no vectors.
        1            -> three random unit vectors in the xz-plane.
        2            -> three random unit vectors in the xy-plane.
        other <= 3   -> three random unit vectors in the yz-plane.
        4 / 5 / 6    -> the coordinate axes (x, z) / (x, y) / (y, z).
        other values in (3, 6] -> (y, z), same as 6.
        > 6          -> no vectors.

    The random vectors use angles drawn uniformly from [0, pi) via `np.random.rand`.

    Returns:
        list of np.ndarray, each of shape (3,).
    """
    if rotate_info_encoding == 0:
        return []

    if rotate_info_encoding <= 3:
        angles = np.reshape(np.array(np.random.rand(3)) * np.pi, (3, 1))
        cos_a, sin_a, zeros = np.cos(angles), np.sin(angles), np.zeros_like(angles)
        if rotate_info_encoding == 1:
            columns = [cos_a, zeros, sin_a]
        elif rotate_info_encoding == 2:
            columns = [cos_a, sin_a, zeros]
        else:
            columns = [zeros, cos_a, sin_a]
        line_vec = np.concatenate(columns, axis=-1)
        return [line_vec[0], line_vec[1], line_vec[2]]

    if rotate_info_encoding <= 6:
        # `np.float` was removed in NumPy 1.24; builtin `float` yields the same float64 dtype.
        # Fresh arrays are built on every call so callers may mutate the result safely.
        x_axis = np.array([1.0, 0.0, 0.0], dtype=float)
        y_axis = np.array([0.0, 1.0, 0.0], dtype=float)
        z_axis = np.array([0.0, 0.0, 1.0], dtype=float)
        if rotate_info_encoding == 4:
            return [x_axis, z_axis]
        if rotate_info_encoding == 5:
            return [x_axis, y_axis]
        return [y_axis, z_axis]

    return []

def rotate_by_vec_pts(un_w, p_x, bf_rotate_pos):
    """Rotate points by a random angle about the line through `p_x` with direction `un_w`.

    The angle is `theta = c * pi`, where `c` is drawn uniformly from [-0.25, 0.25) and then
    pushed away from zero by 0.1 (so 0.1*pi <= |theta| <= 0.35*pi). The rotation matrix is
    built with Rodrigues' formula from the normalized axis.

    Args:
        un_w: (3,) axis direction; it does not need to be normalized.
        p_x: (3,) a point on the rotation axis.
        bf_rotate_pos: (N, 3) points before rotation.

    Returns:
        (rotated points of shape (N, 3), 3x3 rotation matrix, pivot of shape (3, 1)),
        where the pivot is the point of the axis line closest to the origin.
    """

    def _closest_axis_point_to_origin(point, direction):
        # Solve for the parameter t minimizing |point + t * direction|.
        dot_pd = np.sum(point * direction).item()
        dot_dd = np.sum(direction ** 2).item()
        step = -dot_pd / (dot_dd + 1e-10)
        return np.reshape(point + direction * step, (1, 3))

    axis = un_w / np.sqrt(np.sum(un_w ** 2, axis=0))
    wx, wy, wz = float(axis[0]), float(axis[1]), float(axis[2])
    # Skew-symmetric cross-product matrix of the unit axis.
    skew = np.array(
        [[0, -wz, wy], [wz, 0, -wx], [-wy, wx, 0]]
    )

    half_range = 0.25
    min_magnitude = 0.1

    coeff = np.random.uniform(-half_range, half_range, (1,)).item()
    # Keep the angle away from zero so the rotation is never negligible.
    coeff = coeff - min_magnitude if coeff < 0 else coeff + min_magnitude
    theta = coeff * np.pi

    sin_theta = np.sin(theta)
    cos_theta = np.cos(theta)
    rotation_matrix = np.eye(3) + skew * sin_theta + (skew.dot(skew)) * (1. - cos_theta)

    pivot = _closest_axis_point_to_origin(p_x, un_w)

    # Rotate about the pivot: shift to the pivot, rotate, shift back.
    shifted = (bf_rotate_pos - pivot).transpose(1, 0)
    rotated = np.matmul(rotation_matrix, shifted).transpose(1, 0) + pivot

    return rotated, rotation_matrix, np.reshape(pivot, (3, 1))

def inv_triangles_to_seg_idx(triangles_to_seg_idx):
    """Invert a per-triangle segment labelling into a segment -> triangle-indices map.

    Args:
        triangles_to_seg_idx: 1-D array of length N_tri; element i is the segment id of
            triangle i (elements must support `.item()`).

    Returns:
        dict mapping each segment id (int) to an int64 array with the indices of its
        triangles in ascending order. Keys appear in order of first occurrence.
    """
    grouped = {}
    for tri_idx in range(triangles_to_seg_idx.shape[0]):
        seg_idx = int(triangles_to_seg_idx[tri_idx].item())
        grouped.setdefault(seg_idx, []).append(tri_idx)
    # `np.long` does not exist in NumPy 1.24-1.26 and is platform-dependent elsewhere;
    # int64 is explicit and is what torch expects for index tensors.
    return {
        seg_idx: np.array(tri_idxes, dtype=np.int64)  # N_seg_tri
        for seg_idx, tri_idxes in grouped.items()
    }

def reindex_triangles_vertices(vertices, triangles):
    """Compact a mesh so that only vertices referenced by `triangles` remain.

    New vertex indices are assigned in order of first appearance while scanning the
    triangles row by row.

    Args:
        vertices: (n_vert, 3) vertex coordinates.
        triangles: (n_tri, 3) integer vertex indices into `vertices`.

    Returns:
        new_vertices: (n_used_vert, 3) float64 array of the referenced vertices.
        new_triangles: array shaped and typed like `triangles`, holding the new indices.
    """
    old_to_new = {}
    for i_tri in range(triangles.shape[0]):
        tri = triangles[i_tri]
        for old_idx in (int(tri[0].item()), int(tri[1].item()), int(tri[2].item())):
            if old_idx not in old_to_new:
                old_to_new[old_idx] = len(old_to_new)

    # `np.float` / `np.long` were removed from NumPy 1.24; use explicit dtypes instead.
    new_vertices = np.zeros((len(old_to_new), 3), dtype=np.float64)
    for old_idx, new_idx in old_to_new.items():
        new_vertices[new_idx] = vertices[old_idx]

    new_triangles = np.zeros_like(triangles)
    for i_tri in range(triangles.shape[0]):
        tri = triangles[i_tri]
        new_triangles[i_tri] = np.array(
            [old_to_new[int(tri[0].item())], old_to_new[int(tri[1].item())], old_to_new[int(tri[2].item())]],
            dtype=np.int64,
        )
    # new_triangles: tot_n_tri x 3
    # new_vertices: tot_n_vert x 3
    return new_vertices, new_triangles

def ndc_depth_to_buffer(z, near, far):  # z in [-1, 1]
    """Map a normalized-device-coordinate depth `z` to a depth value between `near` and `far`.

    `z = -1` maps to `near` and `z = 1` maps to `far`.
    """
    depth_range = far - near
    return 2 * near * far / (near + far - z * depth_range)


def buffer_depth_to_ndc(d, near, far):  # d in (0, +inf)
    """Map a depth value `d` to normalized device coordinates.

    `d` is clipped to [1e-6, 1e6] before the conversion, so non-positive depths do not
    divide by zero. `d = near` maps to -1 and `d = far` maps to 1.
    """
    safe_depth = np.clip(d, a_min=1e-6, a_max=1e6)
    return ((near + far) - 2 * near * far / safe_depth) / (far - near)


class MotionDataset(data.Dataset):
    def __init__(
            self, root="./data/MDV02", npoints=512, split='train', nmask=10, shape_type="laptop", args=None, global_rot=0
    ):
        """Index the shape folders under `<root>/<shape_type>/<split>` and cache the settings.

        Args:
            root: dataset root directory.
            npoints: number of points kept per sample.
            split: name of the sub-folder to read (e.g. 'train' or 'test').
            nmask: accepted but not used by this constructor.
            shape_type: shape-category sub-folder under `root`.
            args: configuration object; `args.equi_settings` and `args.model.kanchor` are
                read, so despite the `None` default a real object must be passed.
            global_rot: stored as `self.global_rot`.
        """
        super(MotionDataset, self).__init__()

        # File names looked up inside every shape folder.
        self.mesh_fn = "summary.obj"
        self.surface_to_seg_fn = "sfs_idx_to_dof_name_idx.npy"
        self.attribute_fn = "motion_attributes.json"
        self.partial_data_fn = "rendered.npy"

        self.root = root
        self.npoints = npoints
        self.shape_type = shape_type
        self.args = args
        self.global_rot = global_rot
        self.split = split

        # Copy the equivariance-related settings onto the instance.
        equi_settings = self.args.equi_settings
        self.rot_factor = equi_settings.rot_factor
        self.no_articulation = equi_settings.no_articulation
        self.pre_compute_delta = equi_settings.pre_compute_delta
        self.use_multi_sample = equi_settings.use_multi_sample
        if self.use_multi_sample == 1:
            self.n_samples = equi_settings.n_samples
        else:
            self.n_samples = 1
        self.partial = equi_settings.partial
        self.est_normals = equi_settings.est_normals

        # Pre-computed deltas on the training split force a single sample per shape.
        if self.pre_compute_delta == 1 and self.split == 'train':
            self.use_multi_sample = 0
            self.n_samples = 1

        # Partial data multiplies the number of samples per shape by ten.
        if self.partial:
            self.n_samples *= 10

        print(f"no_articulation: {self.no_articulation}, equi_settings: {self.args.equi_settings.no_articulation}")

        self.shape_folders = []

        # One sub-folder per shape; hidden entries (leading ".") are skipped.
        self.shape_root = os.path.join(self.root, shape_type, split)
        folder_names = sorted(os.listdir(self.shape_root))
        self.shape_idxes = [name for name in folder_names if not name.startswith(".")]
        print(f"split: {self.split}, shape_idxes: {self.shape_idxes}")

        self.kanchor = args.model.kanchor
        self.anchors = L.get_anchors(self.kanchor)
        # The 60-anchor set is requested as in the original code, but its result is not stored.
        equi_anchors = L.get_anchors(60)


    def get_trans_encoding_to_trans_dir(self):
        trans_dir_to_trans_mode = {
            (0, 1, 2): 1,
            (0, 2): 2,
            (0, 1): 3,
            (1, 2): 4,
            (0,): 5,
            (1,): 6,
            (2,): 7,
            (): 0
        }
        self.trans_mode_to_trans_dir = {trans_dir_to_trans_mode[k]: k for k in trans_dir_to_trans_mode}
        self.base_transition_vec = [
            np.array([1.0, 0.0, 0.0], dtype=np.float),
            np.array([0.0, 1.0, 0.0], dtype=np.float),
            np.array([0.0, 0.0, 1.0], dtype=np.float),
        ]

    def get_test_data(self):
        """Return the `test_data` attribute.

        No method visible in this class assigns `self.test_data`; it has to be set
        elsewhere before this is called, otherwise an AttributeError is raised.
        """
        test_data = self.test_data
        return test_data

    def reindex_data_index(self):
        """Store in `self.new_idx_to_old_idx` a map from running position to entry of `self.data`.

        Entries of `self.data` are numbered 0, 1, 2, ... in iteration order. If an entry
        occurs more than once, only its last position is kept.
        """
        # Forward map first, so a repeated entry is overwritten by its latest position.
        position_of = {old_idx: position for position, old_idx in enumerate(self.data)}
        self.new_idx_to_old_idx = {position: old_idx for old_idx, position in position_of.items()}

    def reindex_shape_seg(self, shape_seg):
        old_seg_to_new_seg = {}
        ii = 0
        for i in range(shape_seg.shape[0]):
            old_seg = int(shape_seg[i].item())
            if old_seg not in old_seg_to_new_seg:
                old_seg_to_new_seg[old_seg] = ii
                ii += 1
            new_seg = old_seg_to_new_seg[old_seg]
            shape_seg[i] = new_seg
            # old_seg_to_new_seg[old_seg] = ii
            # ii += 1
        return shape_seg

    def transit_pos_by_transit_vec(self, trans_pos):
        tdir = np.random.uniform(-1.0, 1.0, (3,))
        tdir = tdir / (np.sqrt(np.sum(tdir ** 2)).item() + 1e-9)
        trans_scale = np.random.uniform(1.0, 2.0, (1,)).item()
        # for flow test...
        trans_pos_af_pos = trans_pos + tdir * 0.1 * trans_scale * 2
        return trans_pos_af_pos, tdir * 0.1 * trans_scale * 2

    def transit_pos_by_transit_vec_dir(self, trans_pos, tdir):
        # tdir = np.zeros((3,), dtype=np.float)
        # axis_dir = np.random.choice(3, 1).item()
        # tdir[int(axis_dir)] = 1.
        trans_scale = np.random.uniform(0.0, 1.0, (1,)).item()
        trans_pos_af_pos = trans_pos + tdir * 0.1 * trans_scale
        return trans_pos_af_pos, tdir * 0.1 * trans_scale

    def get_random_transition_dir_scale(self):
        tdir = np.random.uniform(-1.0, 1.0, (3,))
        tdir = tdir / (np.sqrt(np.sum(tdir ** 2)).item() + 1e-9)
        trans_scale = np.random.uniform(1.0, 2.0, (1,)).item()
        return tdir, trans_scale

    def decode_trans_dir(self, trans_encoding):
        """Return the base translation vectors selected by `trans_encoding`.

        Reads `self.trans_mode_to_trans_dir` and `self.base_transition_vec`, which are
        set by `get_trans_encoding_to_trans_dir`.
        """
        axis_ids = self.trans_mode_to_trans_dir[trans_encoding]
        return [self.base_transition_vec[axis_id] for axis_id in axis_ids]

    def get_rotation_from_anchor(self):
        ii = np.random.randint(0, self.kanchor, (1,)).item()
        ii = int(ii)
        R = self.anchors[ii]
        return R

    def get_whole_shape_by_idx(self, index):
        """Load a shape's mesh and return points sampled from it as a float tensor.

        Args:
            index: position used to pick the shape folder (see the review note below).

        Returns:
            torch.FloatTensor built from the first value returned by `sample_pts_from_mesh`
            (called with `npoints=self.npoints`).
        """
        # NOTE(review): the folder is taken from `self.shape_idxes[index + 1]`, whereas
        # `__getitem__` indexes the same list without the `+ 1`. This skips the first
        # folder and raises IndexError for the last valid index; confirm it is intended.
        shape_folder = os.path.join(self.shape_root, self.shape_idxes[index + 1])

        mesh_path = os.path.join(shape_folder, self.mesh_fn)
        surface_to_seg_path = os.path.join(shape_folder, self.surface_to_seg_fn)
        motion_attributes_path = os.path.join(shape_folder, self.attribute_fn)

        vertices, triangles = load_vertices_triangles(mesh_path)
        triangles_to_seg_idx, _seg_to_triangles = load_triangles_to_seg_idx(surface_to_seg_path)
        # The motion attributes are loaded as in the original code but are not used here.
        _motion_attributes = load_motion_attributes(motion_attributes_path)

        points, _pts_to_seg, _seg_to_pts = sample_pts_from_mesh(
            vertices, triangles, triangles_to_seg_idx, npoints=self.npoints
        )
        return torch.from_numpy(points).float()

    def get_shape_by_idx(self, index):
        """Return the sampled points of one shape grouped by segment and padded to equal size.

        Points are sampled from the mesh with `sample_pts_from_mesh`, split per segment
        (one segment per entry of the loaded motion attributes), and every segment is
        padded with copies of its own mean point up to the size of the largest segment.

        Args:
            index: position used to pick the shape folder (see the review note below).

        Returns:
            torch.FloatTensor of shape (n_seg, max_pts_per_seg, 3).
        """
        # NOTE(review): the folder is taken from `self.shape_idxes[index + 1]`, whereas
        # `__getitem__` indexes the same list without the `+ 1`; confirm it is intended.
        shp_idx = self.shape_idxes[index + 1]
        cur_folder = os.path.join(self.shape_root, shp_idx)

        cur_mesh_fn = os.path.join(cur_folder, self.mesh_fn)
        cur_surface_to_seg_fn = os.path.join(cur_folder, self.surface_to_seg_fn)
        cur_motion_attributes_fn = os.path.join(cur_folder, self.attribute_fn)

        cur_vertices, cur_triangles = load_vertices_triangles(cur_mesh_fn)
        cur_triangles_to_seg_idx, _ = load_triangles_to_seg_idx(cur_surface_to_seg_fn)
        cur_motion_attributes = load_motion_attributes(cur_motion_attributes_fn)

        sampled_pcts, _, seg_idx_to_sampled_pts = sample_pts_from_mesh(cur_vertices, cur_triangles,
                                                                       cur_triangles_to_seg_idx,
                                                                       npoints=self.npoints)

        # Collect the sampled points of each segmentation/part.
        per_seg_pts = []
        for i_seg in range(len(cur_motion_attributes)):
            # `np.long` is unavailable in NumPy 1.24-1.26; int64 is an explicit index dtype.
            cur_seg_pts_idxes = np.array(seg_idx_to_sampled_pts[i_seg], dtype=np.int64)
            per_seg_pts.append(sampled_pcts[cur_seg_pts_idxes])
        maxx_nn_pt = max(seg_pts.shape[0] for seg_pts in per_seg_pts)

        res_pts = []
        for seg_pts in per_seg_pts:
            n_missing = maxx_nn_pt - seg_pts.shape[0]
            if n_missing > 0:
                # Pad with the segment's mean point so all segments have equal length.
                seg_center = np.mean(seg_pts, axis=0, keepdims=True)
                seg_pts = np.concatenate([seg_pts, np.repeat(seg_center, n_missing, axis=0)], axis=0)
            res_pts.append(np.reshape(seg_pts, (1, maxx_nn_pt, 3)))

        res_pts = np.concatenate(res_pts, axis=0)
        return torch.from_numpy(res_pts).float()

    def refine_triangle_idxes_by_seg_idx(self, seg_idx_to_triangle_idxes, cur_triangles):
        res_triangles = []
        cur_triangles_to_seg_idx = []
        for seg_idx in seg_idx_to_triangle_idxes:
            # if seg_idx == 0:
            #     continue
            cur_triangle_idxes = np.array(seg_idx_to_triangle_idxes[seg_idx], dtype=np.long)
            cur_seg_triangles = cur_triangles[cur_triangle_idxes]
            res_triangles.append(cur_seg_triangles)
            cur_triangles_to_seg_idx += [seg_idx for _ in range(cur_triangle_idxes.shape[0])]
        res_triangles = np.concatenate(res_triangles, axis=0)
        cur_triangles_to_seg_idx = np.array(cur_triangles_to_seg_idx, dtype=np.long)
        return res_triangles, cur_triangles_to_seg_idx

    def __getitem__(self, index):
        """Load one pre-rendered sample and pack it into a dict of tensors.

        `index` is split into a shape index (`index // self.n_samples`), which selects the
        shape folder, and a sample index (`index % self.n_samples`), which is only reported
        back under the key 'sampled_idx'.

        The folder must contain the mesh, the triangle-to-segment labels and the rendered
        data file (`self.partial_data_fn`), a pickled dict with the keys 'seg_label_to_pts',
        'seg_label_to_new_pose', 'glb_pose', 'part_axis', 'part_pv_point', 'part_angles',
        'canon_seg_label_to_pts', 'canon_seg_label_to_new_pose' and 'canon_glb_pose'.

        Returns:
            dict of tensors; see the comments next to the `rt_dict` entries below. The
            exact shapes of the sub-sampled entries depend on the project helper
            `farthest_point_sampling`.
        """
        nparts = None

        shape_index, sample_index = index // self.n_samples, index % self.n_samples
        shp_idx = self.shape_idxes[shape_index]
        # Folder names are parsed with int(), so they must be integer strings.
        shp_idx_int, sample_idx_int = int(shp_idx), sample_index

        cur_folder = os.path.join(self.shape_root, shp_idx)

        cur_mesh_fn = os.path.join(cur_folder, self.mesh_fn)
        cur_surface_to_seg_fn = os.path.join(cur_folder, self.surface_to_seg_fn)
        cur_partial_data_fn = os.path.join(cur_folder, self.partial_data_fn)

        cur_vertices, cur_triangles = load_vertices_triangles(cur_mesh_fn)
        cur_triangles_to_seg_idx, seg_idx_to_triangle_idxes = load_triangles_to_seg_idx(cur_surface_to_seg_fn, nparts=nparts)
        # Regroup the triangles segment by segment.
        cur_triangles, cur_triangles_to_seg_idx = self.refine_triangle_idxes_by_seg_idx(seg_idx_to_triangle_idxes, cur_triangles)
        # NOTE(security): allow_pickle=True executes pickled code on load; only read
        # rendered files that come from a trusted source.
        cur_rendered_data = np.load(cur_partial_data_fn, allow_pickle=True).item()

        seg_label_to_pts = cur_rendered_data['seg_label_to_pts']
        seg_label_to_new_pose = cur_rendered_data['seg_label_to_new_pose']
        glb_pose = cur_rendered_data['glb_pose']
        part_axis = cur_rendered_data['part_axis']
        part_pv_point = cur_rendered_data['part_pv_point']
        part_angles = cur_rendered_data['part_angles']
        canon_seg_label_to_pts = cur_rendered_data['canon_seg_label_to_pts']
        # The next two entries are read (so a missing key still fails early) but unused below.
        canon_seg_label_to_new_pose = cur_rendered_data['canon_seg_label_to_new_pose']
        canon_glb_pose = cur_rendered_data['canon_glb_pose']

        seg_idx_to_triangle_idxes = inv_triangles_to_seg_idx(cur_triangles_to_seg_idx)

        # NOTE(review): the sampled points are normalized below but never used afterwards;
        # the call is kept so the sequence of helper calls is unchanged.
        sampled_pcts, pts_to_seg_idx, seg_idx_to_sampled_pts = sample_pts_from_mesh(cur_vertices, cur_triangles, cur_triangles_to_seg_idx, npoints=self.npoints)

        ''' Centralize points in the real canonical state '''
        # Normalize by bounding-box center and bounding-box diagonal length.
        boundary_pts = [np.min(sampled_pcts, axis=0), np.max(sampled_pcts, axis=0)]
        center_pt = (boundary_pts[0] + boundary_pts[1]) / 2
        length_bb = np.linalg.norm(boundary_pts[0] - boundary_pts[1])

        sampled_pcts = (sampled_pcts - center_pt.reshape(1, 3)) / length_bb

        boundary_pts = [np.min(cur_vertices, axis=0), np.max(cur_vertices, axis=0)]
        center_pt = (boundary_pts[0] + boundary_pts[1]) / 2
        length_bb = np.linalg.norm(boundary_pts[0] - boundary_pts[1])
        # todo: change the normalization strategy to pre-normalization
        cur_vertices = (cur_vertices - center_pt.reshape(1, 3)) / length_bb
        ''' Centralize points in the real canonical state '''

        tot_transformed_full_vertices = []
        tot_transformed_pts = []
        pts_to_seg_idx = []
        tot_transformation_mtx = []
        canon_transformed_pts = []

        # Bring the joint axes and pivot points into the globally posed frame.
        glb_rotation, glb_trans = glb_pose['rotation'], glb_pose['trans']
        part_axis = np.matmul(glb_rotation,  np.transpose(part_axis, (1, 0)))
        part_axis = np.transpose(part_axis, (1, 0))
        part_pv_point = np.matmul(glb_rotation, np.transpose(part_pv_point, (1, 0)))
        part_pv_point = np.transpose(part_pv_point, (1, 0)) + np.reshape(glb_trans, (1, 3))
        # Length of the pivot component that is orthogonal to the axis.
        part_pv_offset = part_pv_point - np.sum(part_pv_point * part_axis, axis=-1, keepdims=True) * part_axis
        part_pv_offset = np.sqrt(np.sum(part_pv_offset ** 2, axis=-1))

        cur_n_seg = len(seg_label_to_pts)
        # `np.float` was removed in NumPy 1.24; float64 is the dtype it used to alias.
        tot_transformation_mtx_segs = np.zeros((cur_n_seg, 4, 4), dtype=np.float64)
        part_state_rots = np.zeros((cur_n_seg, 3, 3), dtype=np.float64)
        part_state_trans_bbox = np.zeros((cur_n_seg, 3), dtype=np.float64)
        # The reference rotations/translations are never filled and stay zero.
        part_ref_rots = np.zeros((cur_n_seg, 3, 3), dtype=np.float64)
        part_ref_trans = np.zeros((cur_n_seg, 3), dtype=np.float64)
        part_ref_trans_bbox = np.zeros((cur_n_seg, 3), dtype=np.float64)

        # Segment labels index the per-segment arrays, so they must lie in [0, cur_n_seg).
        for seg_label in seg_label_to_pts:
            cur_seg_trans_pts = seg_label_to_pts[seg_label]
            cur_seg_pose = seg_label_to_new_pose[seg_label]
            tot_transformed_pts.append(cur_seg_trans_pts)  # partial transformed points
            pts_to_seg_idx += [seg_label for _ in range(cur_seg_trans_pts.shape[0])]

            canon_transformed_pts.append(canon_seg_label_to_pts[seg_label])

            ''' Get triangle idxes for this segmentation '''
            cur_seg_tri_idxes = seg_idx_to_triangle_idxes[seg_label]
            cur_seg_tri = cur_triangles[cur_seg_tri_idxes]
            cur_seg_tri_v1, cur_seg_tri_v2, cur_seg_tri_v3 = cur_vertices[cur_seg_tri[:, 0]], cur_vertices[
                cur_seg_tri[:, 1]], cur_vertices[cur_seg_tri[:, 2]]
            cur_seg_tri_vertices = np.concatenate([cur_seg_tri_v1, cur_seg_tri_v2, cur_seg_tri_v3], axis=0)
            cur_seg_tri_vertices_idxes = np.concatenate([cur_seg_tri[:, 0], cur_seg_tri[:, 1], cur_seg_tri[:, 2]],
                                                        axis=0)

            # Apply the segment's 4x4 pose to the vertices of its triangles.
            cur_seg_rot, cur_seg_trans = cur_seg_pose[:3, :3], cur_seg_pose[:3, 3]
            rot_cur_seg_tri_vertices = np.matmul(cur_seg_rot, np.transpose(cur_seg_tri_vertices, (1, 0))) + np.reshape(cur_seg_trans, (3, 1))
            rot_cur_seg_tri_vertices = np.transpose(rot_cur_seg_tri_vertices, (1, 0))

            tot_transformed_full_vertices.append(rot_cur_seg_tri_vertices)

            # Translation expressed relative to the posed segment's bounding-box center.
            cur_seg_pts_minn = np.min(rot_cur_seg_tri_vertices, axis=0)
            cur_seg_pts_maxx = np.max(rot_cur_seg_tri_vertices, axis=0)
            cur_seg_pts_bbox_center = (cur_seg_pts_minn + cur_seg_pts_maxx) / 2.
            cur_seg_trans_bbox = cur_seg_trans - cur_seg_pts_bbox_center

            part_state_trans_bbox[seg_label] = cur_seg_trans_bbox
            part_state_rots[seg_label] = cur_seg_rot
            tot_transformation_mtx_segs[seg_label] = cur_seg_pose
            # One copy of the segment pose per point of the segment.
            tot_transformation_mtx += [np.reshape(cur_seg_pose, (1, 4, 4)) for _ in range(cur_seg_trans_pts.shape[0])]

        tot_transformed_pts = np.concatenate(tot_transformed_pts, axis=0)
        tot_transformed_full_vertices = np.concatenate(tot_transformed_full_vertices, axis=0)
        tot_transformation_mtx = np.concatenate(tot_transformation_mtx, axis=0)
        canon_transformed_pts = np.concatenate(canon_transformed_pts, axis=0)
        # `np.long` is unavailable in NumPy 1.24-1.26; int64 matches torch's long.
        pts_to_seg_idx = np.array(pts_to_seg_idx, dtype=np.int64)

        gt_pose = tot_transformation_mtx

        if self.global_rot >= 0:
            # Center everything on the mean of the posed mesh vertices.
            af_glb_center_pt = np.mean(tot_transformed_full_vertices, axis=0)
            tot_transformed_pts = (tot_transformed_pts - af_glb_center_pt.reshape(1, 3))
            ''' Point normalization via centralization '''
            gt_pose[:, :3, 3] = gt_pose[:, :3, 3] - af_glb_center_pt
            tot_transformation_mtx_segs[:, :3, 3] = tot_transformation_mtx_segs[:, :3, 3] - af_glb_center_pt
            part_pv_point = part_pv_point - np.reshape(af_glb_center_pt, (1, 3))
            # The pivot moved, so its offset from the axis has to be recomputed.
            part_pv_offset = part_pv_point - np.sum(part_pv_point * part_axis, axis=-1, keepdims=True) * part_axis
            part_pv_offset = np.sqrt(np.sum(part_pv_offset ** 2, axis=-1))

        cur_pc = torch.from_numpy(tot_transformed_pts.astype(np.float32)).float()
        tot_transformed_pts = torch.from_numpy(tot_transformed_pts.astype(np.float32)).float()
        cur_label = torch.from_numpy(pts_to_seg_idx).long()
        tot_label = torch.from_numpy(pts_to_seg_idx).long()
        cur_pose = torch.from_numpy(gt_pose.astype(np.float32))
        cur_pose_segs = torch.from_numpy(tot_transformation_mtx_segs.astype(np.float32))
        cur_canon_transformed_pts = torch.from_numpy(canon_transformed_pts.astype(np.float32)).float()
        cur_part_state_rots = torch.from_numpy(part_state_rots.astype(np.float32)).float()
        cur_part_ref_rots = torch.from_numpy(part_ref_rots.astype(np.float32)).float()
        cur_part_ref_trans = torch.from_numpy(part_ref_trans.astype(np.float32)).float()
        cur_part_axis = torch.from_numpy(part_axis.astype(np.float32)).float()
        cur_part_pv_offset = torch.from_numpy(part_pv_offset.astype(np.float32)).float()
        cur_part_pv_point = torch.from_numpy(part_pv_point.astype(np.float32)).float()
        part_angles = np.array(part_angles, dtype=np.float32)
        part_angles = torch.from_numpy(part_angles.astype(np.float32)).float()

        cur_part_state_trans_bbox = torch.from_numpy(part_state_trans_bbox.astype(np.float32)).float()
        cur_part_ref_trans_bbox = torch.from_numpy(part_ref_trans_bbox.astype(np.float32)).float()

        # Two sub-samplings of the posed points: `self.npoints` points for the main
        # cloud and 4096 points for the "oorr" cloud.
        full_tot_transformed_pts = tot_transformed_pts.clone()
        fps_idx = farthest_point_sampling(cur_pc.unsqueeze(0), n_sampling=self.npoints)
        fps_idx = fps_idx[:self.npoints]
        fps_idx_oorr = farthest_point_sampling(cur_pc.unsqueeze(0), n_sampling=4096)

        tot_transformed_pts = tot_transformed_pts[fps_idx_oorr]
        tot_label = tot_label[fps_idx_oorr]
        cur_pc = cur_pc[fps_idx]
        cur_label = cur_label[fps_idx]
        cur_pose = cur_pose[fps_idx]

        # Same two sub-samplings for the canonical points.
        canon_fps_idx = farthest_point_sampling(cur_canon_transformed_pts.unsqueeze(0), n_sampling=self.npoints)
        canon_fps_idx = canon_fps_idx[:self.npoints]
        canon_fps_idx_oorr = farthest_point_sampling(cur_canon_transformed_pts.unsqueeze(0), n_sampling=4096)
        cur_oorr_canon_transformed_pts = cur_canon_transformed_pts[canon_fps_idx_oorr]
        cur_canon_transformed_pts = cur_canon_transformed_pts[canon_fps_idx]

        idx_arr = np.array([index], dtype=np.int64)
        idx_arr = torch.from_numpy(idx_arr).long()
        shp_idx_arr = torch.tensor([shp_idx_int], dtype=torch.long).long()
        sampled_idx_arr = torch.tensor([sample_idx_int], dtype=torch.long).long()

        rt_dict = {
            'pc': cur_pc.contiguous().transpose(0, 1).contiguous(),  # sub-sampled posed points, transposed
            'af_pc': cur_pc.contiguous().transpose(0, 1).contiguous(),  # same content as 'pc'
            'ori_pc': full_tot_transformed_pts.contiguous().transpose(0, 1).contiguous(),  # all posed points, transposed
            'canon_pc': cur_canon_transformed_pts,  # sub-sampled canonical points (not transposed)
            'oorr_pc': tot_transformed_pts.contiguous().transpose(0, 1).contiguous(),  # 3 x oorr_N
            'oorr_canon_pc': cur_oorr_canon_transformed_pts.contiguous(),
            'label': cur_label,  # segment label of every point in 'pc'
            'oorr_label': tot_label,  # oorr_N
            'pose': cur_pose,  # per-point 4x4 segment pose
            'pose_segs': cur_pose_segs,  # per-segment 4x4 pose
            'part_state_rots': cur_part_state_rots,
            'part_ref_rots': cur_part_ref_rots,  # all zeros
            'part_ref_trans': cur_part_ref_trans,  # all zeros
            'part_axis': cur_part_axis,  # ground-truth joint axes
            'idx': idx_arr,
            'shp_idx': shp_idx_arr,
            'sampled_idx': sampled_idx_arr,
            'part_state_trans_bbox': cur_part_state_trans_bbox,
            'part_ref_trans_bbox': cur_part_ref_trans_bbox,  # all zeros
            'part_pv_offset': cur_part_pv_offset,
            'part_pv_point': cur_part_pv_point,
            'part_angles': part_angles,
        }

        return rt_dict

    def __len__(self):
        """Return the number of items: one per (shape folder, sample index) pair."""
        n_shapes = len(self.shape_idxes)
        return n_shapes * self.n_samples

    def get_num_moving_parts_to_cnt(self):
        """Return the `num_mov_parts_to_cnt` attribute.

        The attribute is created by `reset_num_moving_parts_to_cnt`; calling this
        getter before the attribute exists raises AttributeError.
        """
        counts = self.num_mov_parts_to_cnt
        return counts

    def reset_num_moving_parts_to_cnt(self):
        """Replace `num_mov_parts_to_cnt` with a fresh, empty dict."""
        self.num_mov_parts_to_cnt = dict()


if __name__ == '__main__':
    # Smoke test: build the dataset and time loading of the first few samples.
    # The previous version of this block referenced `ModelNetDataset` (not defined in
    # this file) and batch helpers that `MotionDataset` does not provide.
    import time
    from types import SimpleNamespace

    # NOTE(review): placeholder settings covering exactly the fields that
    # MotionDataset.__init__ reads; replace them with the project's real configuration.
    args = SimpleNamespace(
        equi_settings=SimpleNamespace(
            rot_factor=0.5,
            no_articulation=0,
            pre_compute_delta=0,
            use_multi_sample=0,
            n_samples=1,
            partial=0,
            est_normals=0,
        ),
        model=SimpleNamespace(kanchor=60),
    )
    d = MotionDataset(root='./data/MDV02', split='test', args=args)
    print(len(d))

    tic = time.time()
    sample = None
    for i in range(min(10, len(d))):
        sample = d[i]
    print(time.time() - tic)

    if sample is not None:
        for key, value in sample.items():
            print(key, tuple(value.shape), value.dtype)