#! /usr/bin/python3
# Replay recorded end-effector/gripper trajectories on a real WidowX arm through
# the high-level WidowXClient interface.

import time
import argparse
import numpy as np

import torch
import torchvision.transforms as T
from transforms3d.euler import quat2euler
from transforms3d.quaternions import quat2mat, mat2quat
import transforms3d.quaternions as quat
import transforms3d.affines as aff

from widowx_envs.utils.exceptions import Environment_Exception
from widowx_envs.widowx_env_service import WidowXClient, WidowXStatus, show_video

# install from: https://github.com/youliangtan/edgeml
from edgeml.action import ActionClient, ActionServer, ActionConfig
from edgeml.internal.utils import mat_to_jpeg, jpeg_to_mat, compute_hash

# Preferred torch device: CUDA when a GPU is visible to torch, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def pose_to_matrix(position, quaternion):
    """Build a homogeneous transform from a position and a quaternion.

    Args:
        position: translation component (3 values).
        quaternion: rotation in transforms3d's (w, x, y, z) order.

    Returns:
        The transform produced by ``transforms3d.affines.compose`` with unit
        zoom (no scaling) and no shear.
    """
    rotation = quat.quat2mat(quaternion)  # transforms3d expects (w, x, y, z)
    unit_zoom = np.ones(3)
    return aff.compose(position, rotation, unit_zoom)

def matrix_to_pose(T):
    """Split a homogeneous transform into its position and quaternion.

    Args:
        T: affine matrix accepted by ``transforms3d.affines.decompose``.

    Returns:
        ``(position, quaternion)``. The quaternion is in (w, x, y, z) order.
        Zoom and shear components from the decomposition are discarded.
    """
    translation, rotation, _zoom, _shear = aff.decompose(T)
    return translation, quat.mat2quat(rotation)

def convert_to_robot_frame(ee_pos, ee_quat, base_pos, base_quat):
    """Express a world-frame end-effector pose in the robot base frame.

    Args:
        ee_pos, ee_quat: end-effector pose in the world frame
            (quaternion in (w, x, y, z) order).
        base_pos, base_quat: robot base pose in the world frame
            (quaternion in (w, x, y, z) order).

    Returns:
        The base-from-end-effector transform,
        ``inv(T_world_base) @ T_world_ee``.

    Side effect: prints the resulting position and quaternion.
    """
    world_from_base = pose_to_matrix(base_pos, base_quat)
    world_from_ee = pose_to_matrix(ee_pos, ee_quat)
    base_from_ee = np.linalg.inv(world_from_base) @ world_from_ee

    # Decomposed only for logging; callers receive the matrix itself.
    pos_in_base, quat_in_base = matrix_to_pose(base_from_ee)
    print("ee_pos_in_base", pos_in_base)
    print("ee_quat_in_base", quat_in_base)
    return base_from_ee

def _as_numpy(value):
    """Return *value* as a NumPy array, moving torch tensors to host memory first."""
    if isinstance(value, torch.Tensor):
        # detach/cpu make this work for grad-tracking or GPU tensors,
        # where a bare .numpy() would raise.
        return value.detach().cpu().numpy()
    return np.asarray(value)


def get_states_from_traj(traj, gripper_scale=7.5):
    """Convert a recorded trajectory into base-frame poses and gripper commands.

    Args:
        traj: iterable of per-step state mappings. Each mapping must provide
            ``'link_pose_p_world_based'`` and ``'link_pose_q_world_based'``
            (the first entry of each is used as the robot base pose),
            ``'ee_pos'``, ``'ee_quat'`` and ``'agent_qpos'``. Values may be
            torch tensors or array-likes.
        gripper_scale: factor applied to the sum of the last two
            ``agent_qpos`` entries before clipping to [0, 1]. The default 7.5
            is the original calibration constant.
            NOTE(review): origin of 7.5 not visible here; confirm.

    Returns:
        ``(ee_states, gripper_qpos)``: a list of base-frame end-effector
        transforms (from ``convert_to_robot_frame``) and a list of gripper
        values in [0, 1], one of each per step.
    """
    ee_states = []
    gripper_qpos = []
    for state in traj:
        base_pos = _as_numpy(state['link_pose_p_world_based'][0])
        base_quat = _as_numpy(state['link_pose_q_world_based'][0])
        ee_pos = _as_numpy(state['ee_pos'])
        ee_quat = _as_numpy(state['ee_quat'])
        print('ee_pos_in_world', ee_pos)
        print('ee_quat_in_world', ee_quat)
        T_base_ee = convert_to_robot_frame(ee_pos, ee_quat, base_pos, base_quat)
        ee_states.append(T_base_ee)

        # The last two joint positions are summed, so which finger is which
        # does not matter here.
        agent_qpos = state['agent_qpos']
        finger_a = _as_numpy(agent_qpos[-1])
        finger_b = _as_numpy(agent_qpos[-2])
        gripper = np.clip((finger_a + finger_b) * gripper_scale, 0, 1)
        gripper_qpos.append(gripper)
        print("gripper_qpos", gripper)

    return ee_states, gripper_qpos


class WidowXConfigs:
    """Default WidowX environment parameters and edgeml action-server config."""

    # Environment parameters handed to WidowXClient.init by main().
    DefaultEnvParams = dict(
        fix_zangle=0.1,
        move_duration=0.2,
        adaptive_wait=True,
        move_to_rand_start_freq=1,
        # Two rows of five values; presumably [lower, upper] workspace
        # bounds. Axis meaning is defined by widowx_envs. TODO confirm.
        override_workspace_boundaries=[
            [0.1, -0.15, -0.1, -1.57, 0],
            [0.45, 0.25, 0.25, 1.57, 0],
        ],
        # action_clipping="xyz",
        catch_environment_except=False,
        start_state=[
            0.11865137,
            -0.01696823,
            0.24405071,
            -0.03702571,
            -0.11837727,
            0.03907566,
            0.9994886,
        ],
        skip_move_to_neutral=False,
        return_full_image=False,
        # Only the blue camera is streamed. The alternatives are kept for
        # quick switching.
        camera_topics=[
            # dict(name="/D435/color/image_raw"),
            dict(name="/blue/image_raw"),
            # dict(name="/yellow/image_raw"),
        ],
    )

    # edgeml action-server config; broadcast runs one port above port_number.
    DefaultActionConfig = ActionConfig(
        port_number=5556,
        action_keys=["init", "move", "gripper", "reset", "step_action", "reboot_motor"],
        observation_keys=["image", "state", "full_image"],
        broadcast_port=5556 + 1,
    )

def _wait_for_observation(widowx_client, poll_interval=1.0):
    """Poll ``widowx_client.get_observation()`` until it is not None; return it."""
    obs = widowx_client.get_observation()
    while obs is None:
        print("Waiting for robot to be ready...")
        time.sleep(poll_interval)
        obs = widowx_client.get_observation()
    return obs


def main(args):
    """Replay every trajectory in the recorded file on the WidowX arm.

    Args:
        args: parsed CLI namespace with ``client`` (bool), ``ip`` (str) and
            ``port`` (int). Nothing is sent to the robot unless
            ``args.client`` is true.

    The client is always stopped on exit, including when Ctrl-C or an error
    occurs during initialization.
    """
    filename = "/home/liralab-widowx/estimate-intervention-real-robot/widowx_agent_states/proximity2_agent_state.pt"
    # weights_only=False unpickles arbitrary Python objects: only load trusted files.
    # map_location="cpu" lets tensors saved from a GPU session be loaded and
    # converted to NumPy on any machine.
    trajectories = torch.load(filename, map_location="cpu", weights_only=False)

    if not args.client:
        return

    widowx_client = WidowXClient(host=args.ip, port=args.port)
    try:
        # NOTE: this normally takes 10 seconds when first time init
        widowx_client.init(WidowXConfigs.DefaultEnvParams, image_size=256)
        # Ensure the robot is ready to be controlled.
        _wait_for_observation(widowx_client)

        # Reset so step_action works from this initial position.
        # NOTE(review): the intent is for qpos to be reset before continuing;
        # confirm reset() blocks until the move completes.
        widowx_client.reset()
        _wait_for_observation(widowx_client)
        show_video(widowx_client, duration=2.5)
        input("Press [Enter] to start the experiment.")

        for traj in trajectories:
            widowx_client.reset()
            print("Replaying trajectory...")
            input("Press [Enter] to start")
            ee_pose, gripper_qpos = get_states_from_traj(traj)
            # To subsample (e.g. every 3rd frame), zip ee_pose[::3], gripper_qpos[::3].
            for ee, gripper in zip(ee_pose, gripper_qpos):
                # Gate each command on the service producing an observation.
                _wait_for_observation(widowx_client)
                widowx_client.move(ee)
                widowx_client.move_gripper(gripper)

            input("Finished replaying trajectory. Press [Enter] to go to neutral pose.")
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        widowx_client.stop()
        print("Done all")



if __name__ == "__main__":
    # CLI: --client enables talking to the robot; --ip/--port locate the
    # WidowX service.
    cli = argparse.ArgumentParser()
    cli.add_argument("--client", action="store_true")
    cli.add_argument("--ip", type=str, default="localhost")
    cli.add_argument("--port", type=int, default=5556)
    main(cli.parse_args())