import json
import h5py
import omnigibson as og
import torch as th
th.set_printoptions(precision=3, sci_mode=False)
import numpy as np
np.set_printoptions(precision=3, suppress=True)
from omnigibson.macros import create_module_macros
from omnigibson.action_primitives.starter_semantic_action_primitives import StarterSemanticActionPrimitives
import omnigibson.utils.transform_utils as T
from scipy.spatial.transform import Rotation as R
import omnigibson.lazy as lazy
from omnigibson import object_states


# Seed NumPy and PyTorch with the same value so random draws made through
# either library are repeatable across runs of this script.
seed = 0
np.random.seed(seed)
th.manual_seed(seed)

# Load the environment config from the hdf5 file. The config is stored as a
# JSON string in the attributes of the "data" group. The context manager
# closes the file as soon as the config is parsed; the original left the
# handle open for the whole (long-running) simulation session.
with h5py.File("/home/arpit/test_projects/OmniGibson/teleop_collected_data/r1_dishes_away.hdf5", "r") as f:
    config = json.loads(f["data"].attrs["config"])

# Custom overrides applied to the recorded config before the environment is built.
robot_cfg = config["robots"][0]

# Keep the recorded base pose; it is re-applied to the robot after loading,
# while the config itself spawns the robot at the origin.
reset_base_pose = (robot_cfg["position"], robot_cfg["orientation"])

config["scene"]["load_room_instances"] = [
    "kitchen_0",
    "dining_room_0",
    "entryway_0",
    "living_room_0",
]
robot_cfg["position"] = [0.0, 0.0, 0.0]
robot_cfg["orientation"] = [0.0, 0.0, 0.0, 1.0]

# Optional toggles, currently disabled:
# config["init_curobo"] = True
# config["env"]["flatten_obs_space"] = True

robot_cfg["reset_joint_pos"] = [
    # 6 virtual base joints
    0.0, 0.0, 0.0, 0.0, 0.0, -0.0,
    # 4 torso joints
    0.5, -1.0, -0.8, -0.0,
    # Remaining 16 entries; the paired values suggest interleaved left/right
    # arm joints followed by gripper fingers -- TODO confirm against the robot.
    -0.0, 0.0,
    1.8944, 1.8945,
    -0.9848, -0.9849,
    1.5612, 1.5621,
    0.9097, 0.9096,
    -1.5544, -1.5545,
    0.05, 0.05,
    0.05, 0.05,
]

# Sensor update: observation modalities and vision sensor image settings.
robot_cfg["obs_modalities"] = ["rgb", "depth_linear", "seg_instance"]
vision_kwargs = robot_cfg["sensor_config"]["VisionSensor"]["sensor_kwargs"]
vision_kwargs.update(
    image_height=256,
    image_width=256,
    horizontal_aperture=25.0,
)

# Build the simulation environment from the customised config.
env = og.Environment(configs=config)
robot = env.robots[0]

# Settings shared by every absolute joint-position controller below. Each
# controller gets its own copy so the entries stay independent dicts.
_joint_position_cfg = {
    "name": "JointController",
    "motor_type": "position",
    "use_delta_commands": False,
    "command_input_limits": None,
    "use_impedances": False,
}
# Settings shared by both binary gripper controllers.
_binary_gripper_cfg = {
    "name": "MultiFingerGripperController",
    "mode": "binary",
    "command_input_limits": (0.0, 1.0),
}

controller_config = {
    "base": {
        "name": "HolonomicBaseJointController",
        "motor_type": "position",
        "command_input_limits": None,
        "use_impedances": False,
    },
    "trunk": dict(_joint_position_cfg),
    "arm_left": dict(_joint_position_cfg),
    "arm_right": dict(_joint_position_cfg),
    "gripper_left": dict(_binary_gripper_cfg),
    "gripper_right": dict(_binary_gripper_cfg),
    "camera": dict(_joint_position_cfg),
}

robot.reload_controllers(controller_config=controller_config)
# Optional toggle, currently disabled:
# env.robots[0]._grasping_mode = "sticky"

# Move the robot to the base pose that was recorded in the original config,
# then step the simulator a few times.
recorded_position, recorded_orientation = reset_base_pose
robot.set_position_orientation(
    position=th.tensor(recorded_position),
    orientation=th.tensor(recorded_orientation),
)
for _ in range(5):
    og.sim.step()

primitive = StarterSemanticActionPrimitives(
    env,
    robot,
    enable_head_tracking=False,
    curobo_batch_size=10,
)

# Step the simulator for 100 more steps, then record the right end-effector
# pose (kept for inspection while debugging).
for _ in range(100):
    og.sim.step()
right_eef_pose = robot.links["right_eef_link"].get_position_orientation()
# breakpoint()


# Restore a previously saved simulator state.
import pickle

# NOTE(security): pickle can execute arbitrary code while loading; only point
# this at state files you created yourself.
# The context manager closes the file handle right after loading; the
# original `pickle.load(open(...))` never closed it.
with open("/home/arpit/test_projects/mimicgen/random_files/start_of_last_nav.pickle", "rb") as state_file:
    state = pickle.load(state_file)
og.sim.load_state(state)

# Step the simulator 20 times after loading the state, then record the right
# end-effector pose (kept for inspection while debugging).
for _ in range(20):
    og.sim.step()
current_right_eef_pose = robot.links["right_eef_link"].get_position_orientation()
# breakpoint()

# Hard-coded end-effector waypoints: 8 poses per arm, stored as a
# (positions, orientations) pair of tensors with shapes (8, 3) and (8, 4).
# Orientations are presumably (x, y, z, w) quaternions -- TODO confirm.
left_positions = th.tensor(
    [
        [7.071, -1.833, 1.599],
        [7.086, -1.879, 1.614],
        [7.095, -1.917, 1.580],
        [7.149, -1.971, 1.583],
        [7.120, -1.991, 1.589],
        [7.116, -1.992, 1.590],
        [7.118, -1.981, 1.577],
        [7.108, -1.845, 1.525],
    ]
)
left_orientations = th.tensor(
    [
        [0.769, -0.226, 0.069, 0.593],
        [0.794, -0.165, 0.091, 0.579],
        [0.825, -0.130, 0.091, 0.542],
        [0.834, -0.102, 0.125, 0.528],
        [0.824, -0.123, 0.116, 0.541],
        [0.821, -0.128, 0.111, 0.546],
        [0.846, -0.114, 0.106, 0.509],
        [0.822, -0.244, 0.067, 0.510],
    ]
)

# The right arm holds one constant pose, repeated for all 8 waypoints.
right_positions = th.tensor([[6.529, -1.769, 1.300]] * 8)
right_orientations = th.tensor([[-0.194, 0.804, -0.551, 0.116]] * 8)

eef_pose = {
    "left": (left_positions, left_orientations),
    "right": (right_positions, right_orientations),
}
# plate_603 is the reference object: the primitive tracks it and the
# navigation request below targets it.
plate_603 = env.scene.object_registry("name", "plate_603")
ref_obj = plate_603
primitive._tracking_object = ref_obj

# Register the plates as attached to the end-effector links: plate_602 on the
# right link and plate_601 on the left, each with a 0.9 scale value.
plate_601 = env.scene.object_registry("name", "plate_601")
plate_602 = env.scene.object_registry("name", "plate_602")
held_plates = (("right_eef_link", plate_602), ("left_eef_link", plate_601))
attached_obj = {link_name: plate.root_link for link_name, plate in held_plates}
attached_obj_scale = dict.fromkeys(attached_obj, 0.9)
primitive.attached_obj_info = dict(
    attached_obj=attached_obj,
    attached_obj_scale=attached_obj_scale,
)

# Refresh the motion generator's obstacles before requesting a plan.
primitive._motion_generator.update_obstacles()

# Only the left-arm waypoints are handed to the navigation primitive.
left_only_eef_pose = {"left": eef_pose["left"]}
action_generator = primitive._navigate_to_obj(
    obj=ref_obj,
    eef_pose=left_only_eef_pose,
    visibility_constraint=True,
)

# Debug pause before any planned action is executed.
breakpoint()

# Execute the yielded actions one by one; a yielded None ends execution early.
for mp_action in action_generator:
    if mp_action is None:
        break

    # Convert the action to a NumPy array on the CPU before stepping.
    mp_action = mp_action.cpu().numpy()
    obs, _, _, _, info = env.step(mp_action)

# Debug pause after execution so the final state can be inspected.
breakpoint()