import os
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"    # NVISII will crash when showed multiple devices with jax.
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
import pybullet as pb
import numpy as np
import matplotlib.pyplot as plt
from functools import wraps
from time import time
import pandas as pd
from scipy.spatial.transform import Rotation as sciR
import pickle
from pathlib import Path
import flax
import lz4.frame
import pybullet_data

import sys
# Make the repository root (the parent of this file's directory) importable so the
# `util.*` imports below resolve when this script is run directly. sys.path holds
# strings, so the membership test must use str(...): a Path never equals a str entry.
_REPO_ROOT = str(Path(__file__).parent.parent)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import util.scene_util as sutil
from util.dotenv_util import OBJECT_PATH


class _SceneBody:
    """Tiny handle for a PyBullet body; code in this file only reads ``.uid``."""

    def __init__(self, uid):
        self.uid = uid


def _create_table(urdf_base_dir):
    """Create the static reference table near the world origin and return its handle."""
    table_mesh_filename = os.path.join(urdf_base_dir, 'kitchen/table/big_table_1.obj')
    table_mesh_vis = pb.createVisualShape(shapeType=pb.GEOM_MESH, fileName=table_mesh_filename)
    table_mesh_col = pb.createCollisionShape(shapeType=pb.GEOM_MESH, fileName=table_mesh_filename)
    # baseMass=0 makes the table a static body.
    table_uid = pb.createMultiBody(baseMass=0, baseCollisionShapeIndex=table_mesh_col, baseVisualShapeIndex=table_mesh_vis)
    pb.resetBasePositionAndOrientation(table_uid, [0, 0, 0.4], sciR.from_euler('z', np.pi/2).as_quat())
    return _SceneBody(table_uid)


def _create_object_shapes(urdf_base_dir, mesh_rel_path, scale):
    """Create and return ``(collision_id, visual_id)`` for one catalogue mesh at a uniform scale.

    The collision mesh is read from ``<mesh dir>/../cvx/32_64_1_v4/<same file name>``.
    """
    vis_shape_path = os.path.join(urdf_base_dir, mesh_rel_path)
    col_shape_path = str(Path(vis_shape_path).parent.parent / 'cvx/32_64_1_v4' / os.path.basename(vis_shape_path))
    mesh_scale = [scale, scale, scale]
    col_id = pb.createCollisionShape(shapeType=pb.GEOM_MESH, fileName=col_shape_path, meshScale=mesh_scale)
    vis_id = pb.createVisualShape(shapeType=pb.GEOM_MESH, fileName=vis_shape_path, meshScale=mesh_scale)
    return col_id, vis_id


def _drop_without_contact(col_id, vis_id, orientation, max_attempts=None):
    """Spawn a 0.1 kg body at random (x, y) until it has no contacts; return its uid.

    Positions are sampled in x in [0.15, 0.35] and y in [-0.2, 0.2] (frame of the
    untransformed table), at a fixed height z = 1.0. Bodies that touch anything are
    removed and the draw is retried. Returns None if ``max_attempts`` (when not
    None) is exhausted.
    """
    lower_bound = np.array([0.15, -0.2, 0.0])
    upper_bound = np.array([0.35, 0.2, 0.0])
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        # Three values are drawn (z is then overwritten) to keep the random stream unchanged.
        position = np.random.uniform(0, 1, (3,)) * (upper_bound - lower_bound) + lower_bound
        position[2] = 1.0
        uid = pb.createMultiBody(baseMass=0.1,
                                 baseCollisionShapeIndex=col_id,
                                 baseVisualShapeIndex=vis_id,
                                 basePosition=position,
                                 baseOrientation=orientation)
        pb.performCollisionDetection()
        if len(pb.getContactPoints(uid)) == 0:
            return uid
        pb.removeBody(uid)
    return None


def set_one_table(table_pos, table_quat, obj_no, max_place_attempts=None):
    """Build one table with ``obj_no`` random objects on it, then move everything to a pose.

    The table is created near the origin. Objects ("coke"/"wine", chosen with
    replacement) are dropped at random contact-free positions above it. Finally,
    every created body is re-posed by the transform (``table_pos``, ``table_quat``)
    using ``pb.multiplyTransforms``.

    Args:
        table_pos: translation of the scene transform.
        table_quat: rotation of the scene transform, as an (x, y, z, w) quaternion
            (callers pass ``scipy`` ``Rotation.as_quat()`` output).
        obj_no: number of objects to place.
        max_place_attempts: optional cap on random drops tried per object. The
            default ``None`` retries until success (the original behaviour).

    Returns:
        dict mapping ``"table_right"`` and ``"obj1"``, ``"obj2"``, ... to handles
        exposing the PyBullet body id as ``.uid``.

    Raises:
        RuntimeError: if ``max_place_attempts`` is set and an object could not be
            placed without contact within that many tries.
    """
    reference_object = "table_right"
    target_objects_mesh_dir = {
        "coke": "NOCS/modified/can-d3e24e7712e1e82dece466fd8a3f2b40.obj",
        "wine": "NOCS/modified/bottle-5470f2c11fd647a7c5af224a021be7c5.obj",
    }
    objects_scales = {
        'coke': 0.16,
        'wine': 0.25,
    }
    objects_orns = {
        'coke': sciR.from_euler('x', np.pi/2).as_quat(),
        'wine': sciR.from_euler('x', np.pi/2).as_quat(),
    }
    target_objects = np.random.choice(["coke", "wine"], obj_no, replace=True).tolist()

    urdf_base_dir = OBJECT_PATH
    kitchen = {reference_object: _create_table(urdf_base_dir)}

    # Shapes are created once per object type and shared by all of its bodies,
    # instead of being re-created (and leaked) on every rejected drop.
    shape_cache = {}
    for count, name in enumerate(target_objects, start=1):
        if name not in shape_cache:
            shape_cache[name] = _create_object_shapes(urdf_base_dir, target_objects_mesh_dir[name], objects_scales[name])
        col_id, vis_id = shape_cache[name]
        uid = _drop_without_contact(col_id, vis_id, objects_orns[name], max_place_attempts)
        if uid is None:
            raise RuntimeError(
                f"could not place '{name}' without contact in {max_place_attempts} attempts")
        kitchen[f'obj{count}'] = _SceneBody(uid)

    # Transform the whole table scene (table and objects) by the requested pose.
    for obj in kitchen.values():
        pos, orn = pb.getBasePositionAndOrientation(obj.uid)
        newpos, neworn = pb.multiplyTransforms(table_pos, table_quat, pos, orn)
        pb.resetBasePositionAndOrientation(obj.uid, newpos, neworn)

    return kitchen


def _random_yaw_quat(center_yaw, half_width=np.pi/10):
    """Return the quaternion of a z-rotation drawn uniformly from [center - half_width, center + half_width]."""
    yaw = np.random.uniform(center_yaw - half_width, center_yaw + half_width)
    return sciR.from_euler('z', yaw).as_quat()


def create_scene(obj_no, table_no=2, create_eval_scene=False, previous_kitchen=None):
    """Replace the previous scene with one or two freshly populated tables and let them settle.

    Args:
        obj_no: total number of objects across all tables. With two tables, the
            second ("far") table receives ``np.random.randint(1, obj_no)`` of them,
            so ``obj_no`` must be at least 2 in that case.
        table_no: 2 builds a second table facing the first; any other value builds one.
        create_eval_scene: use fixed table poses instead of randomized ones.
        previous_kitchen: dict returned by an earlier call; its bodies are removed first.

    Returns:
        dict of the near table's entries followed by the far table's entries, whose
        keys carry a trailing ``'2'``.
    """
    if not pb.getConnectionInfo()['isConnected']:
        pb.connect(pb.GUI)

    if previous_kitchen is not None:
        for body in previous_kitchen.values():
            pb.removeBody(body.uid)

    far_table = {}
    far_obj_no = 0
    if table_no == 2:
        far_obj_no = np.random.randint(1, obj_no)
        if create_eval_scene:
            far_pos = [3.0, 0, 0]
            far_quat = sciR.from_euler('z', np.pi).as_quat()
        else:
            far_pos = [np.random.uniform(1.5, 4.5), 0, 0]
            far_quat = _random_yaw_quat(np.pi)
        far_table = {
            f'{key}2': body
            for key, body in set_one_table(far_pos, far_quat, obj_no=far_obj_no).items()
        }

    if create_eval_scene:
        near_pos = [0.0, 0, 0]
        near_quat = sciR.from_euler('z', 0).as_quat()
    else:
        near_pos = [np.random.uniform(-0.1, -0.5), 0, 0]
        near_quat = _random_yaw_quat(0.0)
    near_table = set_one_table(near_pos, near_quat, obj_no=obj_no - far_obj_no)

    # Let the dropped objects fall and come to rest.
    for _ in range(200):
        pb.stepSimulation()

    return {**near_table, **far_table}


def _randomize_robot_pose(robot_id, joint_lower_limits, joint_upper_limits, base_pos_offset, max_tries=50):
    """Sample joints 1-6 and a base pose until the robot has no contacts.

    Returns True on success. After ``max_tries`` failures the last sampled
    (colliding) pose is left in place and False is returned.
    """
    for _ in range(max_tries):
        # One sample per joint index, each drawn from that joint's own limits.
        q = np.random.uniform(joint_lower_limits, joint_upper_limits)
        for joint_idx in range(1, 7):
            # Index by the joint itself: q[joint_idx - 1] would hand joint i a value
            # sampled from joint i-1's limit range.
            pb.resetJointState(robot_id, joint_idx, q[joint_idx])
        base_pos = np.random.uniform([-1.5, -1.5, 0.48], [3.5, 1.5, 0.55])
        base_quat = sciR.from_euler('xyz', [0, 0, np.random.uniform(-np.pi, np.pi)]).as_quat()
        pb.resetBasePositionAndOrientation(robot_id, base_pos + np.array(base_pos_offset), base_quat)

        pb.performCollisionDetection()
        if len(pb.getContactPoints(robot_id)) == 0:
            return True
    return False


def _save_datapoint(save_dir, data_id, inner_itr, max_obj_no, datapoint):
    """Pickle ``datapoint`` into an LZ4-compressed file inside ``save_dir``."""
    scene_data_save_path = save_dir / f"{data_id}_{inner_itr:02d}_table_5_{max_obj_no}.lz4"
    with lz4.frame.open(str(scene_data_save_path), mode='wb') as fp:
        fp.write(pickle.dumps(datapoint))


def _show_debug_views(data_points, nrow=2, nview=2):
    """Show up to ``nrow`` datapoints x ``nview`` of their ``'rgbs'`` images in one figure."""
    nrow = min(nrow, len(data_points))
    if nrow == 0:
        return
    plt.figure()
    for i in range(nrow):
        for j in range(nview):
            plt.subplot(nrow, nview, i * nview + j + 1)
            plt.imshow(data_points[i]['rgbs'][j])
            plt.axis('off')
    plt.show()


def generate(save_base_dir:Path, data_id, max_obj_no, table_no=2, create_eval_scene=False, itr=10, debug = False):
    """Generate ``itr`` random table scenes with a randomly posed robot and store their data.

    Opens its own PyBullet client (GUI when ``debug``, otherwise DIRECT). The client
    is disconnected on return or on error, except on the ``create_eval_scene`` early
    return, which leaves the scene live for the caller.

    Args:
        save_base_dir: existing directory (str or Path) that receives
            ``<data_id>_<iter:02d>_table_5_<max_obj_no>.lz4`` files.
        data_id: prefix for the output file names.
        max_obj_no: each scene gets ``np.random.randint(2, max_obj_no + 1)`` objects.
        table_no: forwarded to ``create_scene`` (2 means two tables).
        create_eval_scene: build a fixed-pose scene, recolour it, and return the
            converted scene immediately.
        itr: number of scenes to generate.
        debug: use the GUI, skip saving, and show the first rendered views.

    Returns:
        The scene built by ``sutil.SceneConverter`` when ``create_eval_scene`` is
        True, otherwise None.
    """
    pb.connect(pb.GUI if debug else pb.DIRECT)
    keep_connection = False  # set only on the eval-scene early return
    try:
        pb.setGravity(0, 0, -9.81)

        CAMERA_DISTANCE        = 5.79
        CAMERA_PITCH           = -88.94
        CAMERA_YAW             = -270.0
        CAMERA_TARGET_POSITION = [0, 0, 0]
        pb.resetDebugVisualizerCamera(
            cameraDistance=CAMERA_DISTANCE,
            cameraYaw=CAMERA_YAW,
            cameraPitch=CAMERA_PITCH,
            cameraTargetPosition=CAMERA_TARGET_POSITION,
        )

        # Floor plane, recoloured plain grey.
        plane_uid = pb.loadURDF(
            fileName=os.path.join(pybullet_data.getDataPath(), "plane.urdf"),
            basePosition=(0.0, 0.0, 0.0),
            baseOrientation=pb.getQuaternionFromEuler((0.0, 0.0, 0.0)),
            useFixedBase=True)
        pb.changeVisualShape(plane_uid, -1, textureUniqueId=-1, rgbaColor=[0.5, 0.5, 0.5, 1])

        robot_id = pb.loadURDF('assets/RobotBimanualV4/urdf/RobotBimanualV4.urdf')
        base_robot_pos_offset, _ = pb.getBasePositionAndOrientation(robot_id)

        # Entries 8 and 9 of getJointInfo are the joint's lower and upper limits.
        joint_infos = [pb.getJointInfo(robot_id, i) for i in range(pb.getNumJoints(robot_id))]
        joint_lower_limits = [info[8] for info in joint_infos]
        joint_upper_limits = [info[9] for info in joint_infos]

        # This pose is overwritten at the start of every iteration; the draw is kept
        # so the sequence of random numbers is unchanged.
        random_height = np.random.uniform(0.1, 0.5)
        pb.resetBasePositionAndOrientation(robot_id, np.array([0, 0, random_height]) + np.array(base_robot_pos_offset), [0, 0, 0, 1])

        save_dir = Path(save_base_dir)
        data_points = []  # only collected in debug mode, for the preview figure
        all_objs = None
        for inner_itr in range(itr):
            obj_no = np.random.randint(2, max_obj_no + 1)
            # Move the robot to x=10 (presumably clear of the tables) while the scene is rebuilt.
            pb.resetBasePositionAndOrientation(robot_id, np.array([10., 0, 0.52]) + np.array(base_robot_pos_offset), [0, 0, 0, 1])
            all_objs = create_scene(obj_no, table_no, create_eval_scene, previous_kitchen=all_objs)

            _randomize_robot_pose(robot_id, joint_lower_limits, joint_upper_limits, base_robot_pos_offset)

            obj_uids = [body.uid for key, body in all_objs.items() if 'obj' in key]
            table_uids = [body.uid for key, body in all_objs.items() if 'table' in key]
            pb_scene = sutil.SceneConverter().construct_scene_from_pybullet(obj_uids, table_uids, robot_id)

            if create_eval_scene:
                for uid in table_uids:  # brown tables
                    pb.changeVisualShape(uid, -1, rgbaColor=[185/255., 156/255., 107/255., 1])
                for uid in obj_uids:    # purple objects
                    pb.changeVisualShape(uid, -1, rgbaColor=[119/255., 0, 200/255., 1])
                keep_connection = True
                return pb_scene

            # Camera look-at targets: 1 (close view) or 5 default points at z=1.0,
            # plus every object position and every table position raised by 0.4.
            obj_positions = [pb.getBasePositionAndOrientation(uid)[0] for uid in obj_uids]
            table_positions = [np.array(pb.getBasePositionAndOrientation(uid)[0]) + np.array([0, 0, 0.4]) for uid in table_uids]
            close_view = np.random.choice([True, True, False, False, False])
            n_default_targets = 1 if close_view else 5
            view_base_target = np.zeros((n_default_targets, 3)) + np.array([0, 0, 1.0])
            view_base_target = np.array(view_base_target.tolist() + [*obj_positions, *table_positions])
            scene_data = pb_scene.generate_scene_data(max_obj_no=max_obj_no, close_view=close_view, view_base_target=view_base_target, robot_pbuid=robot_id)

            # Drop the non-array parameter fields and any None entries before serializing.
            numpy_datapoints = scene_data.replace(nvren_info=None, table_params=None, robot_params=None)
            numpy_datapoints = flax.serialization.to_state_dict(numpy_datapoints)
            numpy_datapoints = {k: v for k, v in numpy_datapoints.items() if v is not None}

            if debug:
                data_points.append(numpy_datapoints)
            else:
                _save_datapoint(save_dir, data_id, inner_itr, max_obj_no, numpy_datapoints)

        if debug:
            _show_debug_views(data_points)
    finally:
        if not keep_connection:
            pb.disconnect()

if __name__ == "__main__":
    import datetime
    import ray

    max_obj_no = 8
    save_base_dir = f'one_table_scenedata/{max_obj_no}_5'
    os.makedirs(save_base_dir, exist_ok=True)
    table_no = 1
    ray_env_no = 20      # generate() tasks awaited per batch
    inner_itr = 4        # scenes produced by each task
    n_scenes = 5000
    cur_timestep = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # Wrap once; re-wrapping on every iteration builds a new remote function each time.
    remote_generate = ray.remote(generate)
    pending = []
    for scene_id_ in range(n_scenes):
        scene_id_str = cur_timestep + f"_{scene_id_:04d}"
        pending.append(remote_generate.remote(save_base_dir, scene_id_str, max_obj_no, table_no=table_no, itr=inner_itr, debug=False))
        if len(pending) >= ray_env_no:
            ray.get(pending)
            pending = []

    # Wait for the final partial batch as well; otherwise the driver exits and
    # Ray tears those tasks down before they finish.
    if pending:
        ray.get(pending)

    # To test without Ray:
    # generate(save_base_dir, scene_id_str, max_obj_no, table_no=table_no, itr=inner_itr, debug=True)