from copy import deepcopy

import gym
import numpy as np
import torch

from droid.calibration.calibration_utils import load_calibration_info
from droid.camera_utils.info import camera_type_dict
from droid.camera_utils.wrappers.multi_camera_wrapper import MultiCameraWrapper
from droid.misc.parameters import hand_camera_id, nuc_ip
from droid.misc.time import time_ms
from droid.misc.transformations import change_pose_frame
from droid.robot_ik.robot_ik_solver import RobotIKSolver

from furniture_bench.robot.panda import Panda

class TOPPRAEnv(gym.Env):
    """Gym-style environment around a ``Panda`` robot created with ``toppra=True``.

    The environment owns the robot interface (``self.robot``), a
    ``MultiCameraWrapper`` used for camera reads, and the loaded camera
    calibration. Note that ``self.action_space`` stores the action-space
    *name* (a string), not a ``gym.spaces`` object.
    """

    def __init__(
        self,
        action_space="joint_position",
        gripper_action_space=None,
        camera_kwargs=None,
        do_reset=True,
        control_hz=15,
        toppra_last_vel=1.0,
        gain_scale=1.0,
        vel_gain_scale=1.0,
    ):
        """Create the robot and camera interfaces and optionally reset the robot.

        Args:
            action_space: One of ``"cartesian_position"``, ``"joint_position"``,
                ``"cartesian_velocity"`` or ``"joint_velocity"``.
            gripper_action_space: Stored on the instance; not otherwise used
                inside this class.
            camera_kwargs: Passed to ``MultiCameraWrapper``. ``None`` (the
                default) is replaced by a fresh empty dict.
            do_reset: When true, ``reset()`` is called at the end of
                construction.
            control_hz: Passed to the robot config as ``"hz"`` and, for
                velocity action spaces, to ``RobotIKSolver``.
            toppra_last_vel: Forwarded to ``robot.init_interpolator`` in
                ``reset()``.
            gain_scale: Forwarded to ``Panda`` as a keyword argument.
            vel_gain_scale: Forwarded to ``Panda`` as a keyword argument.
        """
        # Initialize Gym Environment
        super().__init__()

        # A fresh dict per instance: a `{}` default would be one object shared
        # by every environment created without explicit camera_kwargs.
        if camera_kwargs is None:
            camera_kwargs = {}

        # Define Action Space #
        assert action_space in ["cartesian_position", "joint_position", "cartesian_velocity", "joint_velocity"]
        self.action_space = action_space
        self.gripper_action_space = gripper_action_space
        self.check_action_range = "velocity" in action_space

        # Robot Configuration
        # Joint configuration the arm is moved to by reset().
        self.reset_joints = np.array([0, -1 / 5 * np.pi, 0, -4 / 5 * np.pi, 0, 3 / 5 * np.pi, 0.0])
        # Randomization bounds; currently unused because reset(randomize=True)
        # is not implemented.
        self.randomize_low = np.array([-0.1, -0.2, -0.1, -0.3, -0.3, -0.3])
        self.randomize_high = np.array([0.1, 0.2, 0.1, 0.3, 0.3, 0.3])
        self.DoF = 7 if ("cartesian" in action_space) else 8
        self.control_hz = control_hz
        self.toppra_last_vel = toppra_last_vel
        # The IK solver is only needed (and only created) for velocity action
        # spaces, where step_joint() converts velocities to joint deltas.
        if "velocity" in self.action_space:
            self._ik_solver = RobotIKSolver(control_hz=self.control_hz, relative_max_joint_delta=np.array([0.2] * 7))

        robot_config = {
            "server_ip": nuc_ip,
            "reset_joints": self.reset_joints.tolist(),
            "FR3": False,
            "hz": self.control_hz,
            "position_limits": [[-1.0, -1.0, 0.0], [1.0, 1.0, 1.2]],
            "motion_stopped_counter_threshold": 100,
        }
        panda_kwargs = {
            "binary_grasping": False,
            "gain_scale": gain_scale,
            "vel_gain_scale": vel_gain_scale,
        }
        self.robot = Panda(robot_config=robot_config, toppra=True, **panda_kwargs)

        # Create Cameras
        self.camera_reader = MultiCameraWrapper(camera_kwargs)
        self.calibration_dict = load_calibration_info()
        self.camera_type_dict = camera_type_dict

        # Reset Robot
        if do_reset:
            self.reset()

    def step_joint(self, action):
        """Send a chunk of joint-space commands to the robot.

        Args:
            action: Sequence of per-step commands. For a velocity action space
                each row is read as joint velocities followed by a gripper
                command in the last element; the velocities are converted to
                joint deltas and accumulated, starting from the currently
                measured joint positions, into joint-position targets. For
                any other action space ``action`` is forwarded unchanged.
        """
        if "velocity" in self.action_space:
            # action is a chunk of joint velocities
            q_pos = self.robot.get_state()[0].joint_positions

            positions = []
            for q_vel in action:
                # assuming last element is gripper
                joint_vel = q_vel[:-1]
                gripper_action = q_vel[-1]
                joint_delta = self._ik_solver.joint_velocity_to_delta(joint_vel)
                print("joint_vel", joint_vel)
                print("joint_delta", joint_delta)
                # Rebinds q_pos to a new array, so the measured state is not
                # modified in place.
                q_pos = q_pos + joint_delta
                positions.append(np.append(q_pos, gripper_action))
            action = np.array(positions)
        print("execute joint")
        print("action", action)

        # Make sure the tracking controller is running before sending targets.
        if not self.robot.arm.is_running_policy():
            self.robot.init_controller(None, None, "JOINT_POSITION_TRACKING")
        self.robot.execute_joint(action, joint_pos_track=True)

    def reset(self, randomize=False):
        """Open the gripper, move to ``self.reset_joints`` and restart control.

        Args:
            randomize: Not implemented; when true only a warning is printed
                and the reset proceeds without randomization.
        """
        self.robot.open_gripper(blocking=True)

        if randomize:
            # NOTE(review): the message names "PandaEnv" although this class is
            # TOPPRAEnv -- possibly copied from a sibling class; confirm
            # before rewording, in case logs are matched on this text.
            print("Warning: randomize=True is not implemented for Panda robot in PandaEnv.")

        self.robot.arm.move_to_joint_positions(self.reset_joints)

        self.robot.init_controller(None, None, "JOINT_POSITION_TRACKING")
        self.robot.init_interpolator("TOPPRA", toppra_last_vel=self.toppra_last_vel)

    def read_cameras(self):
        """Return the result of ``self.camera_reader.read_cameras()``."""
        return self.camera_reader.read_cameras()

    def get_state(self):
        """Read the robot state and return it with read timestamps.

        Returns:
            A ``(state_dict, timestamp_dict)`` tuple. ``state_dict`` is the
            ``__dict__`` of the state object returned by the robot (not a
            copy), so the ``ee_pos`` / ``ee_quat`` values filled in below are
            also set on that object. ``timestamp_dict`` holds ``read_start``
            and ``read_end`` values taken from ``time_ms()``.
        """
        read_start = time_ms()
        robot_state, _panda_error = self.robot.get_state()
        state_dict = robot_state.__dict__

        # When the robot did not report an end-effector pose, derive it from
        # the joint positions via the arm model's forward kinematics.
        if state_dict["ee_pos"] is None:
            joint_positions = state_dict["joint_positions"]
            pos, quat = self.robot.arm.robot_model.forward_kinematics(torch.from_numpy(joint_positions).float())
            state_dict["ee_pos"] = pos.numpy()
            state_dict["ee_quat"] = quat.numpy()

        timestamp_dict = {
            "read_start": read_start,
            "read_end": time_ms(),
        }
        return state_dict, timestamp_dict

    def get_camera_extrinsics(self, state_dict):
        """Return camera extrinsics with hand-camera entries re-expressed by the gripper pose.

        Args:
            state_dict: Robot state containing ``"ee_pos"`` and ``"ee_quat"``;
                only read when the calibration contains a hand camera.

        Returns:
            A deep copy of ``self.calibration_dict``. For every camera id
            containing ``hand_camera_id`` the stored calibration is kept under
            ``cam_id + "_gripper_offset"`` and ``cam_id`` is replaced by the
            result of ``change_pose_frame`` applied with the gripper pose.
        """
        # Adjust gripper camera by current pose
        extrinsics = deepcopy(self.calibration_dict)
        # Built lazily and only once: it does not depend on the camera, and
        # calibrations without a hand camera never touch state_dict.
        gripper_pose = None
        for cam_id in self.calibration_dict:
            if hand_camera_id not in cam_id:
                continue
            if gripper_pose is None:
                # NOTE(review): this frame is position + quaternion; confirm
                # change_pose_frame accepts a quaternion orientation here
                # rather than expecting Euler angles.
                gripper_pose = np.concatenate([state_dict["ee_pos"], state_dict["ee_quat"]])
            extrinsics[cam_id + "_gripper_offset"] = extrinsics[cam_id]
            extrinsics[cam_id] = change_pose_frame(extrinsics[cam_id], gripper_pose)
        return extrinsics

    def get_observation(self):
        """Collect robot state, camera readings and camera metadata.

        Returns:
            A dict with ``"timestamp"`` (sub-keys ``"robot_state"`` and
            ``"cameras"``), ``"robot_state"``, ``"camera_type"``,
            ``"camera_extrinsics"`` and ``"camera_intrinsics"``, plus every
            key of the camera observation dict merged in at the top level.
        """
        obs_dict = {"timestamp": {}}

        # Robot State #
        state_dict, timestamp_dict = self.get_state()
        obs_dict["robot_state"] = state_dict
        obs_dict["timestamp"]["robot_state"] = timestamp_dict

        # Camera Readings #
        camera_obs, camera_timestamp = self.read_cameras()
        obs_dict.update(camera_obs)
        obs_dict["timestamp"]["cameras"] = camera_timestamp

        # Camera Info #
        obs_dict["camera_type"] = deepcopy(self.camera_type_dict)
        obs_dict["camera_extrinsics"] = self.get_camera_extrinsics(state_dict)

        # One camera object may report intrinsics for several full camera
        # ids; only the "cameraMatrix" entry of each is kept.
        intrinsics = {}
        for cam in self.camera_reader.camera_dict.values():
            for full_cam_id, info in cam.get_intrinsics().items():
                intrinsics[full_cam_id] = info["cameraMatrix"]
        obs_dict["camera_intrinsics"] = intrinsics

        return obs_dict

    def get_total_completed_waypoints(self):
        """Return ``self.robot.get_total_completed_waypoints()``."""
        return self.robot.get_total_completed_waypoints()

    def get_replan_chunk_index(self):
        """Return ``self.robot.get_replan_chunk_index()``."""
        return self.robot.get_replan_chunk_index()

    def get_num_remaining_waypoints(self):
        """Return ``self.robot.get_num_remaining_waypoints()``."""
        return self.robot.get_num_remaining_waypoints()
