import jax
import jax.numpy as jnp
import numpy as np
import os
import pickle
import glob
from tqdm import tqdm
from functools import partial
import time
import open3d as o3d

import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import util.latent_obj_util as loutil
import util.model_util as mutil
import util.transform_util as tutil
import modules.shakey_module as shakey_module
from util.reconstruction_util import create_fps_fcd_from_oriCORNs, create_scene_mesh_from_oriCORNs
from modules.traj_opt_module import evaluate_full_trajectory
import util.broad_phase as broad_phase

# ---------------------------------------------------------------------------
# Shared resources: pretrained models, robot kinematics, broad-phase helper.
# ---------------------------------------------------------------------------
models = mutil.Models().load_pretrained_models()

# Robot description. An alternative arm model is "assets/ur5/urdf/ur5.urdf".
urdf_dirs = "assets/ur5/urdf/shakey_open_rg6.urdf"
shakey = shakey_module.load_urdf_kinematics(urdf_dirs=urdf_dirs, models=models)

# NOTE(review): not referenced in the visible part of this script.
broad_phase_cls = broad_phase.BroadPhaseWarp()

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
interpolation_num = 14            # passed as samples_per_segment to the trajectory evaluator
reduce_k = 20                     # forwarded to the collision decoder in dec_func
jkey = jax.random.PRNGKey(0)
broadphase_type = 'timeoptbf_traj'  # forwarded to the collision decoder in dec_func


# ---------------------------------------------------------------------------
# Swept-volume input: two 6-DoF joint-space control points (start, end).
# Other endpoint pairs that have been tried:
#   (ones(6),        -2 * ones(6))
#   (zeros(6),       -pi/3 * ones(6))
#   (-pi/2 * ones(6), -pi/3 * ones(6))
# ---------------------------------------------------------------------------
q_control_points = jnp.stack((jnp.zeros(6), -2 * jnp.ones(6)), axis=0)

# Densify the control points into a trajectory (plus its derivatives), then run
# forward kinematics to get per-link poses along it.
interpolated_trajectory, vel, acc, jerk = evaluate_full_trajectory(
    q_control_points, samples_per_segment=interpolation_num
)
moving_obj_pqs = shakey.FK(interpolated_trajectory, oriCORN_out=False)
moving_obj = shakey.link_canonical_oriCORN

# Not decorated: the caller wraps it with ``jax.jit(dec_func)`` where it is used.
def dec_func(dummpy, qpoints, jkey):
    """Score query points against the robot's swept volume along the trajectory.

    Each query point is wrapped as a zero-latent oriCORN object translated to
    that point. The collision decoder (``models.apply('col_decoder', ...)``) is
    then evaluated, vmapped over the points, against the robot links
    ``moving_obj`` posed by ``moving_obj_pqs`` (module-level globals).

    Args:
        dummpy: Unused. Kept so the signature fits the ``dec`` callback of
            ``create_scene_mesh_from_oriCORNs`` (presumably that helper passes
            its first argument here -- TODO confirm).
        qpoints: Query points, assumed shape ``(..., N, 3)``; they are
            flattened to ``(-1, 3)`` internally.
        jkey: JAX PRNG key forwarded to the collision decoder.

    Returns:
        The decoder output maxed over its last three axes, reshaped to
        ``qpoints.shape[:-2] + (-1,)``. If the decoder yields one score per point
        after that max (assumed), this is ``qpoints.shape[:-1]``.
    """
    original_outer_shape = qpoints.shape[:-2]
    qpoints = qpoints.reshape(-1, 3)

    # option 1 - entire trajectory at once
    nq = qpoints.shape[0]
    point_oriCORN = loutil.LatentObjects(
        rel_fps=jnp.zeros((1, 3), dtype=jnp.float32),
        z=jnp.zeros((1,) + models.latent_shape[1:], dtype=jnp.float32),
    ).init_pos_zero()
    # One copy of the zero-latent point object per query point, moved onto it.
    point_oriCORN = point_oriCORN.extend_and_repeat_outer_shape(nq, 0)
    point_oriCORN = point_oriCORN.translate(qpoints)

    collision_loss_pair = jax.vmap(
        partial(
            models.apply,
            'col_decoder',
            latent_obj_B=moving_obj,
            pq_transform_B=moving_obj_pqs,
            broadphase_type=broadphase_type,
            reduce_k=reduce_k,
            path_check=True,
            jkey=jkey,
        )
    )(point_oriCORN)

    # option 2 - body segmentation (disabled).
    # NOTE(review): this references ``self.reduce_k`` and ``fixed_obj``, which do
    # not exist here. Assigning ``moving_obj_pqs`` inside this function would also
    # make it a local name and break option 1 above (UnboundLocalError).
    # moving_obj_pqs = jnp.stack([moving_obj_pqs[...,1:,:,:], moving_obj_pqs[...,:-1,:,:]], axis=-3)
    # moving_obj_pqs = jnp.moveaxis(moving_obj_pqs, -4, 0)
    # collision_loss_pair = \
    #     jax.vmap(partial(models.apply, 'col_decoder', fixed_obj.merge()[None], moving_obj, reduce_k=self.reduce_k, path_check=True, jkey=jkey))(pq_transform_B=moving_obj_pqs)
    # collision_loss_pair = jnp.moveaxis(collision_loss_pair.squeeze(-1), 0, -1)

    res = jnp.max(collision_loss_pair, axis=(-1, -2, -3))
    # Restore the caller's leading batch axes. Letting the trailing size be
    # inferred (-1) matches the old ``res.shape[-1:]`` whenever that worked, and
    # also handles batches where prod(original_outer_shape) > 1, e.g. (B, N, 3).
    res = res.reshape(original_outer_shape + (-1,))

    return res

# Quick decoder smoke test on random points (debugging aid):
#   qpoints = np.random.uniform(-0.3, 0.3, (1, 100, 3)).astype(np.float32)
#   res = dec_func(None, qpoints, jkey)

# Transform the link oriCORNs by the FK poses along the trajectory and show their
# transformed feature points as a point cloud.
moved_likes = moving_obj.apply_pq_z(moving_obj_pqs, models.rot_configs)

# Open3D's Vector3dVector takes an (N, 3) float64 array; convert the JAX array
# to host NumPy explicitly instead of relying on implicit conversion.
fps_points = np.asarray(moved_likes.fps_tf, dtype=np.float64).reshape(-1, 3)
pcd_fps = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(fps_points))

# Build a mesh of the swept volume by querying the jitted decoder callback.
mesh = create_scene_mesh_from_oriCORNs(
    None, dec=jax.jit(dec_func), visualize=False, density=100, ndiv=800, qp_bound=1.2
)
# Variant with the helper's own visualization and a tighter query bound:
#   create_scene_mesh_from_oriCORNs(None, dec=dec_func, visualize=True, density=100, ndiv=800, qp_bound=0.2)
o3d.visualization.draw_geometries([mesh, pcd_fps])

# NOTE(review): leftover debug marker; kept to preserve script output.
print(1)
