#!/usr/bin/env python
from __future__ import print_function
###
# install Kinova_gen3 software by following https://github.com/Kinovarobotics/ros_kortex
# type 'source devel/setup.bash' in Kinova_gen3 software workspace
# launch the driver: 'roslaunch kortex_driver kortex_driver.launch ip_address:=<ip.address> gripper:=robotiq_2f_85'
# source your python2 virtualenv: source your_venv/bin/activate
# change the directory to this package, then run 'chmod +x RL.py'
# type 'python RL.py'
###

import sys
import os
from ..RSI2.RL_env_RSI2 import RLEnvRSI2, Recorder
import rospy
import select, tty, termios
import time
import numpy as np
import pyrealsense2 as rs
import cv2
import gym
import copy
from kortex_driver.srv import *
from kortex_driver.msg import *
import sounddevice as sd
from PIL import Image, ImageEnhance


class ROSRLEnv(RLEnvRSI2):
    """RLEnvRSI2 variant that drives a physical Kinova arm through ROS.

    Images are read from a RealSense colour stream, the tool pose is read from
    the kortex driver's ``base_feedback`` topic, and motion is commanded through
    the kortex driver services and its cartesian velocity topic.
    """

    # The Kinova driver computes inverse kinematics for the gripper tip while
    # this environment uses the end-effector frame; the transform between the
    # two is 0.12 m along the z-axis.
    _GRIPPER_TIP_OFFSET = 0.12
    # Vertical travel (m) used by the manual grasp sequence.
    _GRASP_DESCENT = 0.08
    _GRASP_LIFT = 0.11

    def __init__(self):
        """Build the observation space and bring up the camera and ROS interfaces.

        A failure while initialising ROS is reported on stdout and recorded in
        ``self.is_init_success`` rather than raised.
        """
        RLEnvRSI2.__init__(self)

        parent_space = self.observation_space
        self.observation_space = gym.spaces.Dict({
            'image': parent_space['image'],
            'goal_sound': parent_space['goal_sound'],
            # this only acts as a placeholder
            'current_sound': parent_space['goal_sound'],
            'robot_pose': parent_space['robot_pose'],
            'ground_truth': parent_space['ground_truth'],
        })
        self.maxSteps = self.config.RLEnvMaxSteps
        self.desiredEndEffectorPos = np.array([0.0, 0.0, 0.0])
        self.odom = None

        self._init_realsense()
        self._init_ros_interfaces()

        if self.is_init_success:
            # Make sure to clear the robot's faults else it won't move if it's already in fault
            self.clear_faults()
            # Activate the action notifications
            self.subscribe_to_a_robot_notification()
            self.send_gripper_command(0.0)  # open gripper
            # Set the reference frame to "CARTESIAN_REFERENCE_FRAME_BASE"
            self.set_cartesian_reference_frame()

        rospy.on_shutdown(self.close_env)

        if self.config.record:
            self.episodeRecorder = Recorder()  # used for recording data
            self.episodeRecorder.saveTo = self.config.recordSaveDir
            self.episodeRecorder.episodeInitNum = self.config.recorderInitNum
            if self.config.loadAction:
                self.episodeRecorder.loadFrom = self.config.loadActionFile
                self.episodeRecorder.loadActions()

    def _init_realsense(self):
        """Start the RealSense colour stream and set the sensor brightness to 50."""
        if self.config.render:
            # Window used by get_image() to show the processed frames
            cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE)

        self.realSensePipeline = rs.pipeline()
        self.realSenseConfig = rs.config()

        # NOTE(review): resolve() presumably raises when no compatible device is
        # connected, which surfaces the problem before streaming starts -- confirm.
        pipeline_wrapper = rs.pipeline_wrapper(self.realSensePipeline)
        self.realSenseConfig.resolve(pipeline_wrapper)
        self.realSenseConfig.enable_stream(rs.stream.color, 640, 480, rs.format.rgb8, 30)

        # Start streaming
        self.realSensePipeline.start(self.realSenseConfig)
        # NOTE(review): sensor index 1 is presumably the RGB sensor -- confirm for the camera in use.
        self.realSenseSensor = self.realSensePipeline.get_active_profile().get_device().query_sensors()[1]
        # The option range is kept so close_env() can restore the default brightness.
        self.realSenseOption_range = self.realSenseSensor.get_option_range(rs.option.brightness)
        self.realSenseSensor.set_option(rs.option.brightness, 50)

    def _connect_service(self, suffix, service_class):
        """Wait for ``/<robot_name><suffix>`` and return a proxy for it."""
        full_name = '/' + self.robot_name + suffix
        rospy.wait_for_service(full_name)
        return rospy.ServiceProxy(full_name, service_class)

    def _init_ros_interfaces(self):
        """Create the ROS node, subscriber, publisher and service proxies.

        Sets ``self.is_init_success`` to True only when every step succeeded.
        """
        self.is_init_success = False
        self.action_topic_sub = None
        self.all_notifs_succeeded = True
        self.last_action_notif_type = None

        try:
            rospy.init_node('ROSEnv')

            # Get node params
            self.robot_name = rospy.get_param('~robot_name', "my_gen3")
            rospy.loginfo("Using robot_name " + self.robot_name)

            # Init the action topic subscriber
            self.action_topic_sub = rospy.Subscriber("/" + self.robot_name + "/action_topic", ActionNotification,
                                                     self.cb_action_topic)

            # Init the cartesian velocity publisher
            self.cartesian_vel_pub = rospy.Publisher("/" + self.robot_name + "/in/cartesian_velocity", TwistCommand,
                                                     queue_size=10)

            # Init the services. The proxies for clear_faults and
            # set_cartesian_reference_frame carry a "_service" suffix so they do
            # not shadow the methods of the same name on this instance.
            self.clear_faults_service = self._connect_service('/base/clear_faults', Base_ClearFaults)
            self.send_gripper_command_service = self._connect_service('/base/send_gripper_command',
                                                                      SendGripperCommand)
            self.robot_stop = self._connect_service('/base/stop', Stop)
            self.play_cartesian_trajectory = self._connect_service('/base/play_cartesian_trajectory',
                                                                   PlayCartesianTrajectory)
            self.play_joint_trajectory = self._connect_service('/base/play_joint_trajectory', PlayJointTrajectory)
            self.set_cartesian_reference_frame_service = self._connect_service(
                '/control_config/set_cartesian_reference_frame', SetCartesianReferenceFrame)
            self.activate_publishing_of_action_notification = self._connect_service(
                '/base/activate_publishing_of_action_topic', OnNotificationActionTopic)
        except Exception as exc:
            print("ROS subscriber or services initialization failed")
            print("Reason:", repr(exc))
        else:
            self.is_init_success = True

    @staticmethod
    def _read_line():
        """Read one line from stdin (``raw_input`` on Python 2, ``input`` on Python 3)."""
        try:
            reader = raw_input
        except NameError:
            reader = input
        return reader()

    def close_env(self):
        """Stop the robot, restore the camera brightness and stop the camera stream."""
        # The stop proxy does not exist when ROS initialisation failed early;
        # skipping it lets the camera still be released.
        if hasattr(self, 'robot_stop'):
            self.stop_robot()
        try:
            self.realSenseSensor.set_option(rs.option.brightness, self.realSenseOption_range.default)
        finally:
            self.realSensePipeline.stop()

    def getKey(self):
        """Block until one character is typed and return it, without requiring Enter."""
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Always restore the terminal, even if the read is interrupted
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
        return ch

    def get_image(self):
        """Grab a colour frame and return the cropped, resized and sharpened image.

        Returns:
            The processed RGB image as a numpy array.

        Raises:
            rospy.ROSInterruptException: if ROS shuts down before a frame arrives.
        """
        color_image = None
        while not rospy.is_shutdown():
            frames = self.realSensePipeline.wait_for_frames()
            color_frame = frames.get_color_frame()
            if color_frame:
                # Convert images to numpy arrays
                color_image = np.asanyarray(color_frame.get_data())
                break

        if color_image is None:
            raise rospy.ROSInterruptException("ROS shut down before a camera frame was received")

        # process the image, the image will be of size (height=480,width=640,3)
        # NOTE(review): externalCamRenderSize is presumably (height, width) since
        # cv2.resize expects (width, height) -- confirm against the config.
        render_size = self.config.externalCamRenderSize
        resized_img = cv2.resize(color_image, (render_size[1], render_size[0]))
        # Keep only columns 12..86 of the resized frame
        resized_img = resized_img[:, 12:87, :]
        img = cv2.resize(resized_img, (self.config.img_dim[2], self.config.img_dim[1]))
        # image sharpen
        enhancer = ImageEnhance.Sharpness(Image.fromarray(img))
        img = np.array(enhancer.enhance(2.5))
        if self.config.render:
            cv2.imshow('RealSense', cv2.cvtColor(cv2.resize(img, (224, 224)), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        return img

    def _select_goal_object(self):
        """Return the goal object index according to ``config.hideObj['mode']``.

        Raises:
            NotImplementedError: for an unknown mode.
        """
        hide_cfg = self.config.hideObj
        task_num = self.config.taskNum
        mode = hide_cfg['mode']

        if mode == 'random':
            prob = np.ones((task_num,)) / (task_num - hide_cfg['hideNum'])
            prob[self.hideObjIdx] = 0.
            return self.np_random.choice(task_num, replace=False, p=prob)

        if mode == 'fix':
            prob = np.ones((task_num,)) / (task_num - len(hide_cfg['hideIdx']))
            prob[hide_cfg['hideIdx']] = 0.
            return self.np_random.choice(task_num, replace=False, p=prob)

        # all objects are present
        if mode == 'none':
            if self.config.RLTrain or self.config.render:
                # NOTE(review): the goal is pinned to object 1 here; the random
                # draw self.np_random.randint(0, self.config.taskNum) was
                # commented out in the original -- confirm this is intended.
                return 1
            idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
            return 0 if len(idx) == 0 else int(idx.max() + 1)

        raise NotImplementedError

    def gen_obs(self):
        """Collect the current observation.

        On the first step of an episode this also picks the goal object,
        generates its sound feature and optionally plays the audio.

        Returns:
            ``(obs, s)`` where ``obs`` is the observation dict and ``s`` is
            ``{'eeState': <tool pose from get_odom()>}``.
        """
        image = self.get_image()
        self.saveEpisodeImage(image)

        # update odom
        self.odom = self.get_odom()
        s = {'eeState': self.odom}

        if self.envStepCounter == 0:
            self.goalObjIdx = self._select_goal_object()

            self.goal_sound, self.goal_audio = self.audio.genSoundFeat(objIndx=self.goalObjIdx, featType='MFCC',
                                                                       rand_fn=self.np_random.randint)

            self.ground_truth = np.int32(self.goalObjIdx)
            print('Goal object is------------------------------------', self.goalObjIdx)
            if self.config.render or not self.config.RLTrain:
                if self.goal_audio is not None:
                    sd.play(self.goal_audio, self.audio.fs)

        obs = {
            'image': np.transpose(image, (2, 0, 1)),  # for PyTorch convolution,
            'goal_sound': self.goal_sound,
            'current_sound': np.zeros(self.config.sound_dim),  # no use
            'robot_pose': np.array([s['eeState'][0], s['eeState'][1]]),
            'ground_truth': self.ground_truth,
        }

        if self.config.record:
            self.episodeRecorder.eePositionList.append(s['eeState'])
            # This environment tracks the commanded position on itself, so that
            # is what gets recorded alongside the measured pose.
            self.episodeRecorder.eePositionList_desired.append(copy.deepcopy(self.desiredEndEffectorPos))

        return obs, s

    def stop_robot(self):
        """Call the driver's stop service; a service failure is logged, not raised."""
        req = StopRequest()
        try:
            self.robot_stop(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call Stop")

    def envReset(self):
        """Stop the robot, move it to a new start position and return the first observation."""
        print('episode', self.episodeCounter)
        self.stop_robot()
        self.randomization()
        obs, _ = self.gen_obs()
        return obs

    def randomization(self):
        """Optionally wait for the operator, then move the arm to the episode start position.

        When not training, or on every 10th episode, the operator is asked to
        rearrange the blocks; any answer other than 'y' exits the program.
        """
        if not self.config.RLTrain or self.episodeCounter % 10 == 0:
            print("Manually randomize the blocks")
            if not rospy.is_shutdown():
                print("Press 'y' for the next episode and any other key for exiting\t")
                if self._read_line() != 'y':
                    sys.exit()

        if self.config.record and self.config.loadAction:
            self.desiredEndEffectorPos[0] = self.episodeRecorder.loadedAction[0, 0]
            self.desiredEndEffectorPos[1] = self.episodeRecorder.loadedAction[0, 1]
        else:
            # Sample the start position 5 cm inside the workspace limits
            self.desiredEndEffectorPos[0] = np.random.uniform(self.config.xMin + 0.05, self.config.xMax - 0.05)
            self.desiredEndEffectorPos[1] = np.random.uniform(self.config.yMin + 0.05, self.config.yMax - 0.05)
        # Kinova driver calculate inverse kinematics using gripper tip while we use end-effector frame,
        # the transform between gripper and the end-effector frame is 0.12m along z-axis
        self.desiredEndEffectorPos[2] = self.config.endEffectorHeight - self._GRIPPER_TIP_OFFSET
        self.send_cartesian_pose()
        self.wait_for_action_end_or_abort()

    def keyboardControl(self):
        """Read one key and map w/s/a/d to a 1 cm displacement.

        Returns:
            ``(dx, dy, dz, key)``; the displacement is zero for any other key.
        """
        key_to_delta = {
            'w': (-0.01, 0.0),
            's': (0.01, 0.0),
            'a': (0.0, -0.01),
            'd': (0.0, 0.01),
        }
        k = self.getKey()
        dx, dy = key_to_delta.get(k, (0.0, 0.0))
        dz = 0.0
        return dx, dy, dz, k

    def _move_vertically(self, dz):
        """Shift the desired z by ``dz``, send the pose and wait for the motion to finish."""
        self.desiredEndEffectorPos[2] = self.desiredEndEffectorPos[2] + dz
        self.send_cartesian_pose()
        self.wait_for_action_end_or_abort()

    def _manual_grasp(self):
        """Ask the operator whether to grasp and, if confirmed, run the grasp sequence."""
        if rospy.is_shutdown():
            return

        print("Press 'y' to grasp and any other key to not grasp\t")
        if self._read_line() != 'y':
            return

        # lower the gripper and close it
        self._move_vertically(-self._GRASP_DESCENT)
        self.send_gripper_command(0.6)

        # lift the gripper
        self._move_vertically(self._GRASP_LIFT)

        # open the gripper (optional)
        print("Press 'y' to NOT open the gripper and any other key to open\t")
        if self._read_line() != 'y':
            self.send_gripper_command(0.0)

        # lower the gripper
        self._move_vertically(-self._GRASP_LIFT)

        # open the gripper
        self.send_gripper_command(0.0)
        self.wait_for_action_end_or_abort()

        # lift the gripper
        self._move_vertically(self._GRASP_DESCENT)

    def step(self, action):
        """Apply one action to the robot and return the resulting transition.

        Args:
            action: indexable whose first two entries are the x and y commands;
                each is clipped to [-1, 1] and scaled to a 1 cm displacement.
                Ignored under manual control or when replaying loaded actions.

        Returns:
            ``(obs, reward, done, infoDict)``.
        """
        if self.config.RLManualControl:  # use keyboard to control the real robot
            dx, dy, dz, _ = self.keyboardControl()
        else:
            dv = 0.01  # was 0.02
            dx = float(np.clip(action[0], -1, +1)) * dv  # move in x direction
            dy = float(np.clip(action[1], -1, +1)) * dv  # move in y direction
            dz = 0.

        target = self.desiredEndEffectorPos
        if self.config.record and self.config.loadAction:  # replace network output with other actions if needed
            target[0] = self.episodeRecorder.loadedAction[self.envStepCounter, 0]
            target[1] = self.episodeRecorder.loadedAction[self.envStepCounter, 1]
        else:
            # x and y are kept inside the workspace limits; z is not clipped
            target[0] = np.clip(target[0] + dx, a_min=self.config.xMin, a_max=self.config.xMax)
            target[1] = np.clip(target[1] + dy, a_min=self.config.yMin, a_max=self.config.yMax)
            target[2] = target[2] + dz

        try:
            # apply action
            if self.config.RLManualControl:
                self.send_cartesian_pose()
                self.wait_for_action_end_or_abort()
            else:
                # The position error is sent directly as the velocity command
                vel_vector = self.desiredEndEffectorPos - self.odom[:3]
                self.send_cartesian_vel(vel_vector)
                rospy.sleep(self.config.ROSStepInterval)  # act as frame skip

            self.envStepCounter = self.envStepCounter + 1

            # get observations
            obs, s = self.gen_obs()

        except rospy.ROSInterruptException:
            sys.exit()

        if not self.config.RLTrain:
            print('Step', self.envStepCounter)

        r = [self.rewards()]  # calculate reward
        self.reward = sum(r)
        self.episodeReward = self.episodeReward + self.reward
        self.done = self.termination()

        infoDict = {}
        if self.done:
            if self.config.record:
                self.episodeRecorder.saveEpisode(self.episodeCounter)
            self.stop_robot()

            if not self.config.RLTrain:
                infoDict['goal_area_count'] = 0
                # lower the gripper and do the grasp
                self._manual_grasp()

        return obs, self.reward, self.done, infoDict  # reset will be called if done

    def termination(self):
        """Return True once the episode has reached ``self.maxSteps`` steps."""
        return self.envStepCounter >= self.maxSteps

    def rewards(self):
        """Return the step reward, which is always 0 in this environment."""
        return 0

    def cb_action_topic(self, notif):
        """Subscriber callback: remember the event type of the latest action notification."""
        self.last_action_notif_type = notif.action_event

    def wait_for_action_end_or_abort(self):
        """Poll until the last action notification is ACTION_END or ACTION_ABORT.

        Returns:
            True on ACTION_END; False on ACTION_ABORT or if ROS shuts down first.
        """
        while not rospy.is_shutdown():
            if self.last_action_notif_type == ActionEvent.ACTION_END:
                rospy.loginfo("Received ACTION_END notification")
                return True
            if self.last_action_notif_type == ActionEvent.ACTION_ABORT:
                rospy.loginfo("Received ACTION_ABORT notification")
                self.all_notifs_succeeded = False
                return False
            time.sleep(0.01)
        return False

    def clear_faults(self):
        """Call the clear-faults service.

        Returns:
            True on success (after a 2.5 s pause), False if the service call failed.
        """
        try:
            self.clear_faults_service()
        except rospy.ServiceException:
            rospy.logerr("Failed to call ClearFaults")
            return False
        else:
            rospy.loginfo("Cleared the faults successfully")
            rospy.sleep(2.5)
            return True

    def send_joint_angles(self, desired_joint_angle):
        """Send a joint-space target and wait for the motion to finish.

        Args:
            desired_joint_angle: sequence of joint angles in radians, one per
                joint, ordered by joint identifier starting at 0.

        Returns:
            The result of ``wait_for_action_end_or_abort()``, or False if the
            service call failed.
        """
        self.last_action_notif_type = None
        # Create the list of angles
        req = PlayJointTrajectoryRequest()
        for joint_id, angle_rad in enumerate(desired_joint_angle):
            temp_angle = JointAngle()
            temp_angle.joint_identifier = joint_id
            temp_angle.value = np.rad2deg(angle_rad)  # request values are sent in degrees
            req.input.joint_angles.joint_angles.append(temp_angle)

        # Send the angles
        rospy.loginfo("Sending joint angles...")
        try:
            self.play_joint_trajectory(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call PlayJointTrajectory")
            return False
        else:
            return self.wait_for_action_end_or_abort()

    def set_cartesian_reference_frame(self):
        """Set the cartesian reference frame to CARTESIAN_REFERENCE_FRAME_BASE.

        Returns:
            True on success, False if the service call failed.
        """
        self.last_action_notif_type = None
        # Prepare the request with the frame we want to set
        req = SetCartesianReferenceFrameRequest()
        req.input.reference_frame = CartesianReferenceFrame.CARTESIAN_REFERENCE_FRAME_BASE

        # Call the service
        try:
            self.set_cartesian_reference_frame_service(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call SetCartesianReferenceFrame")
            return False
        else:
            rospy.loginfo("Set the cartesian reference frame successfully")
            return True

    def subscribe_to_a_robot_notification(self):
        """Ask the driver to publish action notifications.

        Returns:
            True on success (after a 1 s pause), False if the service call failed.
        """
        # Activate the publishing of the ActionNotification
        req = OnNotificationActionTopicRequest()
        rospy.loginfo("Activating the action notifications...")
        try:
            self.activate_publishing_of_action_notification(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call OnNotificationActionTopic")
            return False
        else:
            rospy.loginfo("Successfully activated the Action Notifications!")

        rospy.sleep(1.0)

        return True

    def send_cartesian_vel(self, vel_vector):
        """Publish a planar twist command.

        Args:
            vel_vector: indexable whose first two entries are the x and y
                linear velocities; z and all angular components are sent as 0.
        """
        self.last_action_notif_type = None
        tc = TwistCommand()
        tc.reference_frame = 0
        tc.twist.linear_x = vel_vector[0]
        tc.twist.linear_y = vel_vector[1]
        tc.twist.linear_z = 0.0
        tc.twist.angular_x = 0.0
        tc.twist.angular_y = 0.0
        tc.twist.angular_z = 0.0
        tc.duration = 0

        self.cartesian_vel_pub.publish(tc)

    def send_cartesian_pose(self):
        """Send ``self.desiredEndEffectorPos`` as a cartesian target with a fixed orientation.

        The call does not wait for the motion; use
        ``wait_for_action_end_or_abort()`` afterwards.

        Returns:
            True if the service call succeeded, False otherwise.
        """
        self.last_action_notif_type = None
        req = PlayCartesianTrajectoryRequest()
        req.input.target_pose.x = self.desiredEndEffectorPos[0]
        req.input.target_pose.y = self.desiredEndEffectorPos[1]
        req.input.target_pose.z = self.desiredEndEffectorPos[2]
        # Fixed orientation: (0, 180, 90) degrees
        req.input.target_pose.theta_x = 0
        req.input.target_pose.theta_y = np.rad2deg(np.pi)
        req.input.target_pose.theta_z = np.rad2deg(np.pi / 2)

        pose_speed = CartesianSpeed()
        pose_speed.translation = 0.06
        pose_speed.orientation = 15

        # The constraint is a one_of in Protobuf. The one_of concept does not exist in ROS
        # To specify a one_of, create it and put it in the appropriate list of the oneof_type member of the ROS object :
        req.input.constraint.oneof_type.speed.append(pose_speed)

        # Call the service
        rospy.loginfo("Sending the robot to the cartesian pose...")
        try:
            self.play_cartesian_trajectory(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call PlayCartesianTrajectory")
            return False
        else:
            return True

    def send_gripper_command(self, value):
        """Command finger 0 of the gripper to ``value`` in position mode.

        Callers in this class use 0.0 to open the gripper and 0.6 to close it.

        Returns:
            True on success (after a 2 s pause), False if the service call failed.
        """
        # Initialize the request
        req = SendGripperCommandRequest()
        finger = Finger()
        finger.finger_identifier = 0
        finger.value = value
        req.input.gripper.finger.append(finger)
        req.input.mode = GripperMode.GRIPPER_POSITION

        rospy.loginfo("Sending the gripper command...")

        # Call the service
        try:
            self.send_gripper_command_service(req)
        except rospy.ServiceException:
            rospy.logerr("Failed to call SendGripperCommand")
            return False
        else:
            rospy.sleep(2.0)
            return True

    def get_odom(self):
        """Wait for one ``base_feedback`` message and return the tool pose.

        Returns:
            numpy array ``[x, y, z, theta_x, theta_y, theta_z]``.
        """
        feedback = rospy.wait_for_message("/" + self.robot_name + "/base_feedback", BaseCyclic_Feedback)
        base = feedback.base
        return np.array([
            base.tool_pose_x,
            base.tool_pose_y,
            base.tool_pose_z,
            base.tool_pose_theta_x,
            base.tool_pose_theta_y,
            base.tool_pose_theta_z,
        ])
