import os
# JAX/XLA/CUDA runtime configuration. These variables are set before the
# JAX-dependent imports below (flax, scene utilities); presumably they must be
# in place before JAX initializes -- keep this block above those imports.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # per the JAX docs: do not preallocate a large GPU memory pool at startup
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # per the JAX docs: allocate and free device memory on demand
os.environ["CUDA_VISIBLE_DEVICES"] = "0"    # Expose a single GPU: NVISII crashes when it is shown multiple devices together with JAX.
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # request the CPU backend for JAX computations
import pybullet as pb
import numpy as np
import matplotlib.pyplot as plt
from functools import wraps
from time import time
import pandas as pd
from scipy.spatial.transform import Rotation as sciR
import pickle
import copy
from pathlib import Path
import flax
import lz4.frame
import h5py
import pkgutil

# add path LLMPlannerEnv
import sys
# Make the sibling ``LLMPlannerEnv`` directory and the directory two levels
# above this file importable.
# sys.path holds plain strings, so membership has to be tested with str(path):
# a Path object never compares equal to a str, which made the former
# ``Path not in sys.path`` check always true and appended a duplicate entry on
# every import/reload of this module.
LLMPlannerEnv_path = Path(__file__).parent.parent/'LLMPlannerEnv'
for _extra_path in (str(LLMPlannerEnv_path), str(Path(__file__).parent.parent)):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

import util.scene_util as sutil


from LLMPlannerEnv.Simulation.pybullet_env_real_object.envs.shop_env import ShopEnv
from LLMPlannerEnv.Simulation.pybullet_env_real_object.imm.pybullet_util.bullet_client import BulletClient

def timeit(f):
    """Decorator that prints the wall-clock duration of every call to ``f``.

    The wrapped function's return value is passed through unchanged. Nothing
    is printed when ``f`` raises, because the report happens only after a
    successful return.
    """

    @wraps(f)
    def timed(*args, **kw):
        started_at = time()
        outcome = f(*args, **kw)
        elapsed = time() - started_at
        print("func:%r took: %2.4f sec" % (f.__name__, elapsed))
        return outcome

    return timed

# Not referenced anywhere in the visible part of this file.
VISUALIZE = False

# Debug-visualizer camera pose, passed to resetDebugVisualizerCamera() in
# generate(). Previously tried values, kept for reference:
#   distance 17.79, target position [-4.61, -0.52, -9.58]
CAMERA_DISTANCE = 5.79
CAMERA_PITCH = -88.94
CAMERA_YAW = -270.0
CAMERA_TARGET_POSITION = [0, 0, 0]
# Simulation / planning settings. The complete mapping is handed to ShopEnv in
# generate(); generate() itself reads only sim_params["control_hz"] and
# sim_params["gravity"]. Each section is spelled out as its own private dict
# and then assembled into `config` in the original key order.

_SIM_PARAMS = {
    "delay": 0.0,
    "control_hz": 240,
    "gravity": -9.8,
}

_PROJECT_PARAMS = {
    "custom_urdf_path": "urdf",
}

_PROBLEM_PARAMS = {
    "scenario": "_predicate_gen",
}

# TODO: add ur5 + gripper
_PR2_PARAMS = {
    "path": "pr2/pr2.urdf",
    "pos": [0.0, -3, 0],
    "orn": [0.0, 0.0, 3.14159],
    "joint_index_last": 6,
    "joint_indices_arm": [1, 2, 3, 4, 5, 6],
    "link_index_endeffector_base": 6,
    # Alternative rest pose, kept for reference:
    # [0.6907909329066655, -1.0354177550427595, 1.1440116001596894, -1.4401206757476759,
    #  -2.0206557827891047, 2.606198161917956, -0.7516451911325208, -2.2794601891064987]
    "rest_pose": [
        0.0,      # base (fixed)
        3.14159,  # joint 0
        -2.094,   # joint 1
        2.07,     # joint 2
        -1.57,    # joint 3
        -1.57,    # joint 4
        0.0,      # joint 5 (try to reset to 480 after release)
        0.0,      # joint 6, end-effector base (fixed)
    ],
}

_MANIPULATION_PARAMS = {
    "main_hand": "right",
    "inverse_kinematics": {
        "max_num_iterations": 1000,
        "residual_threshold": 1e-12,
    },
    # Use utils.RRT_ITERATIONS instead.
    "rrt_trials": 10000,
    "sample_space": {
        # Sync with taskspace too.
        "center": [0.54, 0.04, 0.835],
        "half_ranges": [0.24, 0.15, 0.085],
        "yaw_range": [-3.14, 3.14],
    },
    "delay": 0.001,
    "resolutions": 0.05,
    "num_samples": 10,
}

_NAVIGATION_PARAMS = {
    "base_limits": [[-5, -5], [5, 5]],
    "resolutions": 0.5,
    "trials": 1,
    "num_samples": 50,
    "delay": 0.01,
    "pickle": {
        "default_open": "default_open_navigation_prm_nodes.pkl",
        "default_close": "default_close_navigation_prm_nodes.pkl",
        "empty": "empty_navigation_prm_nodes.pkl",
    },
}

_POSE_SAMPLER_PARAMS = {
    # A value of 500 was used for num_filter_trials_pick previously.
    "num_filter_trials_pick": 10,
    "num_filter_trials_place": 40,
    "num_filter_trials_sample": 2,
    "num_filter_trials_force_fetch": 2,
    "num_filter_trials_state_description": 12,
    "grasp_affordance_threshold": 0.67,
    "pos_offset": 0.8,
    "default_z": 0.42,
    "default_orn": [0.0, -1.57, 0.0],
}

config = {
    "sim_params": _SIM_PARAMS,
    "project_params": _PROJECT_PARAMS,
    "problem_params": _PROBLEM_PARAMS,
    "robot_params": {"pr2": _PR2_PARAMS},
    "manipulation_params": _MANIPULATION_PARAMS,
    "navigation_params": _NAVIGATION_PARAMS,
    "pose_sampler_params": _POSE_SAMPLER_PARAMS,
}

# Default location of the text file that lists one object path per line.
_DEFAULT_OBJ_LIST_PATH = '/home/dongwon/research/TriPlnShape/dataset/sdf_dirs.txt'


def _find_category_indices(obj_files, cat_obj_path):
    """Map each category to the index of the last ``obj_files`` entry containing its file stem (-1 if none)."""
    cat_obj_id = {ocat: -1 for ocat in cat_obj_path}
    for oidx, of in enumerate(obj_files):
        for ocat, stem in cat_obj_path.items():
            if stem in of:
                cat_obj_id[ocat] = oidx
    return cat_obj_id


def _show_debug_views(data_points, nrow=2, nview=2):
    """Plot up to ``nrow`` scenes x ``nview`` images taken from each scene's ``'rgbs'`` entry."""
    # Never index past what was actually generated (e.g. itr < 2).
    nrow = min(nrow, len(data_points))
    if nrow == 0:
        return
    plt.figure()
    for i in range(nrow):
        # Assumes every datapoint carries an indexable 'rgbs' entry -- TODO confirm.
        rgbs = data_points[i]['rgbs']
        for j in range(min(nview, len(rgbs))):
            plt.subplot(nrow, nview, i*nview+j+1)
            plt.imshow(rgbs[j])
            plt.axis('off')
    plt.show()


@timeit
def generate(save_base_dir:Path, data_id, max_obj_no, itr=10, debug = False,
             obj_list_path=_DEFAULT_OBJ_LIST_PATH):
    """Generate ``itr`` random scenes and save each as an lz4-compressed pickle.

    For every iteration the shop environment is reset with a random object
    count drawn from ``[2, max_obj_no]``, the objects ``obj1..objN`` plus
    ``table_right`` are converted to a scene through ``sutil.SceneConverter``,
    and the scene data (with ``nvren_info``, ``table_params`` and
    ``robot_params`` cleared and ``None`` entries dropped) is written to
    ``<save_base_dir>/<data_id>_<iteration:02d>_table_5_<max_obj_no>.lz4``.

    Args:
        save_base_dir: Output directory (``Path`` or ``str``); created if missing.
        data_id: Identifier used as the file-name prefix.
        max_obj_no: Upper bound (inclusive) on the number of objects per scene;
            must be at least 2 for the random draw to be valid.
        itr: Number of scenes to generate.
        debug: Connect to PyBullet in GUI mode and show a grid of rendered
            views once generation is done.
        obj_list_path: Text file listing one object path per line; used only
            to sanity-check that the wine and coke assets are listed.

    Returns:
        None. Results are written to disk.

    Raises:
        AssertionError: If no line of ``obj_list_path`` matches the wine or
            coke asset of the shop configuration.
    """
    save_base_dir = Path(save_base_dir)
    save_base_dir.mkdir(parents=True, exist_ok=True)

    with open(obj_list_path, 'r') as f:
        obj_files = [of.strip() for of in f]

    sim_id = pb.connect(pb.GUI if debug else pb.DIRECT)
    bc = BulletClient(sim_id)
    # try/finally guarantees the physics server connection is released even
    # when scene generation fails part-way through.
    try:
        control_dt = 1. / config["sim_params"]["control_hz"]
        bc.setTimeStep(control_dt)
        bc.setGravity(0, 0, config["sim_params"]["gravity"])

        bc.resetDebugVisualizerCamera(
            cameraDistance = CAMERA_DISTANCE,
            cameraYaw = CAMERA_YAW,
            cameraPitch = CAMERA_PITCH,
            cameraTargetPosition = CAMERA_TARGET_POSITION
        )
        bc.configureDebugVisualizer(bc.COV_ENABLE_RENDERING, 1)

        # NOTE(review): output is suppressed exactly when debug is on, which
        # looks inverted -- behaviour kept as-is; confirm against ShopEnv.
        env = ShopEnv(bc, config, suppress_output=debug)

        # Sanity check: the two kitchen assets must appear in the object list.
        kitchen_config = env.shop_config['kitchen_config']
        cat_obj_path = {
            ocat: kitchen_config[ocat]['path'].split('/')[-1].split('.')[0]
            for ocat in ('wine', 'coke')
        }
        cat_obj_id = _find_category_indices(obj_files, cat_obj_path)
        # The former `np.all(cat_obj_id.values() != -1)` compared a dict view
        # with an int, which is always True, so the check could never fail.
        missing = [ocat for ocat, oidx in cat_obj_id.items() if oidx == -1]
        assert not missing, f"no entry in {obj_list_path} matches categories: {missing}"

        data_points = []
        for inner_itr in range(itr):
            obj_no = np.random.randint(2, max_obj_no+1)
            env.reset(obj_no)

            obj_uids = [env.all_obj[f"obj{i+1}"].uid for i in range(obj_no)]
            table_uids = [env.all_obj['table_right'].uid]

            pb_scene = sutil.SceneConverter().construct_scene_from_pybullet(obj_uids, table_uids)

            # Close-up views are requested for 3 out of 5 scenes on average.
            close_view = np.random.choice([True, True, True, False, False])
            scene_data = pb_scene.generate_scene_data(max_obj_no=max_obj_no, close_view=close_view)

            numpy_datapoints = scene_data.replace(nvren_info=None, table_params=None, robot_params=None)
            numpy_datapoints = flax.serialization.to_state_dict(numpy_datapoints)
            numpy_datapoints = {k: v for k, v in numpy_datapoints.items() if v is not None}

            # Datapoints are only needed in memory for the debug plot below.
            if debug:
                data_points.append(numpy_datapoints)

            scene_data_save_path = save_base_dir/f"{data_id}_{inner_itr:02d}_table_5_{max_obj_no}.lz4"
            with lz4.frame.open(str(scene_data_save_path), mode='wb') as fp:
                fp.write(pickle.dumps(numpy_datapoints))

        if debug:
            _show_debug_views(data_points)
    finally:
        bc.disconnect()

if __name__ == "__main__":
    # Script entry point: fan scene generation out over Ray tasks, keeping at
    # most `ray_env_no` tasks in flight at a time.
    import datetime
    import ray

    max_obj_no = 8
    save_base_dir = f'predicate_scenedata/{max_obj_no}_5'
    os.makedirs(save_base_dir, exist_ok=True)

    num_scenes = 5000
    ray_env_no = 20
    inner_itr = 4
    # A per-run timestamp keeps file names from different runs apart.
    cur_timestep = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # Build the remote wrapper once instead of once per submitted task.
    remote_generate = ray.remote(generate)

    pending = []
    for scene_id_ in range(num_scenes):
        scene_id_str = cur_timestep + f"_{scene_id_:04d}"
        pending.append(
            remote_generate.remote(save_base_dir, scene_id_str, max_obj_no, itr=inner_itr, debug=False)
        )
        if len(pending) >= ray_env_no:
            ray.get(pending)
            pending = []

    # Wait for the final partial batch too; without this the driver could exit
    # while the last submitted tasks were still running.
    if pending:
        ray.get(pending)

    # test non-ray generation
    # generate(save_base_dir, scene_id_str, max_obj_no, itr=inner_itr, debug=False)