import math
import time
from interbotix_xs_modules.locobot import InterbotixLocobotXS
import time
# rospy for the subscriber
import rospy
# ROS Image message
# from sensor_msgs.msg import Image
# ROS Image message -> OpenCV2 image converter
from cv_bridge import CvBridge
# OpenCV2 for saving an image
import cv2
import numpy as np
import pyrealsense2 as rs

from PIL import Image

from  locobot_env.resources.motor import Motor

class LoCoBotInterface():
    """Control/perception wrapper around an Interbotix LoCoBot WX200 arm.

    Owns the arm handle (``self.locobot``), a RealSense pipeline (``self.pipe``)
    and a ``Motor`` used by :meth:`reset`. Provides helpers to grab camera frames,
    locate blue / lime-green pixels in a cropped view, and move the end effector
    using radius + waist-angle motions.
    """

    # Directory every frame captured by get_image() is written to.
    CAMERA_IMAGE_DIR = "/home/locobot/Desktop/Real_locobot_env/goalrelabel_locobot_fullgcsl/examples/all_camera_images"
    # get_image() wraps its file index back to 0 once it exceeds this value.
    MAX_CAMERA_IMAGE_INDEX = 50000
    # np.save() target for the brightness-normalised image when no coloured pixel is found.
    NO_PIXELS_PATH = "/home/locobot/Desktop/Real_locobot_env/goalrelabel_locobot_fullgcsl/examples/no_pixels"

    # Constructor and move arm to resting position
    def __init__(self):
        """Connect to the arm, reset and start the RealSense pipeline, then go to rest."""
        print("THIS INTERFACE H")
        self.locobot = InterbotixLocobotXS(robot_model="locobot_wx200", arm_model="mobile_wx200")
        self.locobot.gripper.set_pressure(1)
        # Cached gripper state: open_gripper()/close_gripper() only send a command on change.
        self.open = False
        # current_trajectory, it and last_msg are not read in this class;
        # presumably used by callers — kept for compatibility.
        self.current_trajectory = []
        self.it = 0
        # Duration passed as moving_time to every arm command.
        self.moving_time = 1
        self.last_msg = None
        # Hardware-reset every attached RealSense device before starting a fresh pipeline.
        ctx = rs.context()
        for dev in ctx.query_devices():
            dev.hardware_reset()
        self.pipe = rs.pipeline()
        self.profile = self.pipe.start()
        # Index used for the file name of the next frame saved by get_image().
        self.camera_image = 0
        self.motor = Motor()

        print("MOTOR", self.motor.BOTTOM_POSITION)

        self.go_rest()
        # self.reset()

    def go_rest(self):
        """Move the arm to the resting pose in front of the base and open the gripper."""
        # [-0.0015339808305725455, -0.07056311517953873, 0.9464661478996277, 0.4770680367946625, 0.016873789951205254]
        # new_home_joints = [0.004601942375302315, -0.3160000443458557, 1.1612235307693481, 0.7639224529266357, -3.1047770977020264]
        rest_joints = [-0.09050486981868744, -0.07363107800483704, 1.0505341701507568, 0.4, -3.1293208599090576]
        self.locobot.arm.set_joint_positions(rest_joints, moving_time=self.moving_time)
        self.open_gripper()

    # Alternative (unused) implementation that reads the camera through ROS:
    # def get_image(self):
    #     image_topic = "/locobot/camera/color/image_raw"
    #     start = time.time()
    #     print("waiting for message...")
    #     msg = rospy.wait_for_message(image_topic, Image)
    #     print("Image took", time.time() - start)
    #     cv2_img = bridge.imgmsg_to_cv2(msg, "rgb8")
    #     return cv2_img

    def get_image(self):
        """Grab one colour frame from the RealSense pipeline, archive it and return it.

        Every frame is also saved as ``<CAMERA_IMAGE_DIR>/<index>.jpeg``. The index
        is reset to 0 once it exceeds ``MAX_CAMERA_IMAGE_INDEX``, so older files
        get overwritten.

        Returns:
            np.ndarray: the colour frame's pixel data.

        Raises:
            RuntimeError: if the pipeline delivered no colour frame.
        """
        frames = self.pipe.wait_for_frames()
        color_frame = frames.get_color_frame()
        if not color_frame:
            raise RuntimeError("No color frame received from the RealSense pipeline")

        rgb = np.array(color_frame.get_data())

        # NOTE: the first frame after start-up is usually of poor quality; it is
        # nevertheless returned and saved like any other frame.
        Image.fromarray(rgb).save(f"{self.CAMERA_IMAGE_DIR}/{self.camera_image}.jpeg")
        self.camera_image += 1
        if self.camera_image > self.MAX_CAMERA_IMAGE_INDEX:
            self.camera_image = 0

        return rgb

    def get_image_rgb(self):
        """Return the workspace crop of the current camera frame, rotated.

        Rows 170:520 and columns 555:850 are kept, then rows/columns are swapped
        and the result is flipped along axis 1.
        """
        rgb = self.get_image()[170:520, 555:850]  # old [160:500,550:830] # Plastic: [250:480,520:740]
        # rgb = np.array(cv2.resize(rgb, (145, 145)))
        rgb = np.flip(np.transpose(rgb, (1, 0, 2)), 1)
        return rgb

    @staticmethod
    def _color_mask(rgb, color):
        """Vectorised colour test over the last axis of ``rgb`` (read as R, G, B).

        Returns a boolean array with the channel axis dropped.
          * blue:  B > 100 and trunc(B) > 25 + trunc(G) + trunc(R)
          * green: G > 100 and trunc(G) > 50 + trunc(B)
        ``trunc`` reproduces the ``int()`` truncation of the original per-pixel
        predicates. float64 keeps the integer sums exact for float32 or uint8
        input, so uint8 values cannot wrap around.

        Raises:
            ValueError: if ``color`` is neither "blue" nor "green".
        """
        pixels = np.asarray(rgb, dtype=np.float64)
        red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
        red_t, green_t, blue_t = np.trunc(red), np.trunc(green), np.trunc(blue)
        with np.errstate(invalid="ignore"):
            if color == "blue":
                return (blue > 100) & (blue_t > 25 + green_t + red_t)
            if color == "green":
                return (green > 100) & (green_t > 50 + blue_t)
        raise ValueError(f"Unsupported color {color!r}; expected 'blue' or 'green'")

    def is_blue(self, pixel):
        """Return True if the (R, G, B) ``pixel`` is strongly blue (see _color_mask)."""
        return bool(self._color_mask(np.asarray(pixel)[np.newaxis], "blue")[0])

    def is_green(self, pixel):
        """Return True if the (R, G, B) ``pixel`` is lime green (see _color_mask)."""
        return bool(self._color_mask(np.asarray(pixel)[np.newaxis], "green")[0])

    def _average_color_pixel(self, rgb, color):
        """Return the (row, col) centroid of ``color`` pixels, or (None, None) if none.

        The image is first brightness-normalised so that its mean intensity is
        184. The input array is not modified. When nothing matches, a message is
        printed and the normalised image is saved to ``NO_PIXELS_PATH``.
        """
        scaled = np.asarray(rgb, dtype=np.float32)
        scaled = scaled * 184 / np.mean(scaled)

        rows, cols = np.nonzero(self._color_mask(scaled, color))
        count = rows.size
        if count == 0:
            print(f"No {color} pixels detected")
            np.save(self.NO_PIXELS_PATH, scaled)
            return None, None

        # x = np.clip(x_sum / L, 30, 125)
        # y = np.clip(y_sum / L, 30, 115)
        return int(rows.sum()) / count, int(cols.sum()) / count

    def get_blue_average_pixel(self, rgb):
        """Return the (row, col) centroid of blue pixels in ``rgb``, or (None, None)."""
        return self._average_color_pixel(rgb, "blue")

    def get_blue_average_position(self, rgb):
        """Return the normalised centroid of blue pixels as an array.

        Returns ``[None, None]`` (a list) when no blue pixel is found.
        """
        x, y = self.get_blue_average_pixel(rgb)
        if x is None or y is None:
            return [x, y]

        return np.array(self.normalize_pixel(x, y))

    def get_green_average_pixel(self, rgb):
        """Return the (row, col) centroid of lime-green pixels in ``rgb``, or (None, None)."""
        return self._average_color_pixel(rgb, "green")

    def get_green_average_position(self, rgb):
        """Return the normalised centroid of lime-green pixels as an array.

        Returns ``[None, None]`` (a list) when no green pixel is found.
        """
        x, y = self.get_green_average_pixel(rgb)
        if x is None or y is None:
            return [x, y]

        return np.array(self.normalize_pixel(x, y))

    def get_screen(self, rgb, color="blue"):
        """Return a mask image: 255 in every channel where ``rgb`` matches ``color``, else 0.

        Unlike the centroid helpers, ``rgb`` is NOT brightness-normalised first.
        ``color`` defaults to "blue", the colour this method has always used.
        """
        screen = np.full(rgb.shape, 0)
        screen[self._color_mask(rgb, color)] = 255
        return screen

    def get_point(self, x, y):
        """Map normalised image coordinates to a (query_x, query_y) arm target.

        ``x`` is clipped to [0.1, 0.8] and ``y`` to [0.1, 0.9]. Both are then
        mapped linearly: x=0.1 -> 0.35, x=0.8 -> 0.14; y=0.1 -> -0.15, y=0.9 -> 0.15.
        """
        min_x = 0.1
        max_x = 0.8
        min_y = 0.1
        max_y = 0.9

        x = np.clip(x, min_x, max_x)
        y = np.clip(y, min_y, max_y)

        query_x = 0.35 - (0.35 - 0.14) * (x - min_x) / (max_x - min_x)
        query_y = -0.15 + (0.15 + 0.15) * (y - min_y) / (max_y - min_y)

        return query_x, query_y

    def move_clear(self, color = "blue"):
        """Move the arm to (0.14, 0.15), then return the normalised centroid of ``color``.

        Any ``color`` other than "blue" is treated as green.
        """
        self.move_to_point(0.14, 0.15)

        if color == "blue":
            x, y = self.get_blue_average_position(self.get_image_rgb())
        else:
            x, y = self.get_green_average_position(self.get_image_rgb())
        return x, y

    def reset(self):
        """Reset the motor, return the arm to rest and make sure the gripper is open."""
        self.motor.reset()
        self.go_rest()
        self.open_gripper()

    # Previous reset strategy (unused): pick both objects and place them at fixed corners.
    # def reset(self):
    #     self.go_rest()

    #     x, y = self.get_blue_average_position(self.get_image_rgb())
    #     if x is None or y is None:
    #         x, y = self.move_clear("blue")
    #     qx, qy = self.get_point(x, y)

    #     self.move_to_point(qx, qy)
    #     self.grab_object()

    #     qx, qy = self.get_point(0, 0)

    #     self.move_to_point(qx, qy)
    #     self.leave_object()

    #     x, y = self.get_green_average_position(self.get_image_rgb())
    #     if x is None or y is None:
    #         x, y = self.move_clear("green")
    #     qx, qy = self.get_point(x, y)

    #     self.move_to_point(qx, qy)
    #     self.grab_object()

    #     qx, qy = self.get_point(0, 1)

    #     self.move_to_point(qx, qy)
    #     self.leave_object()

    #     self.go_rest()

    # Previous reset strategy (unused): tilt the box backwards so all objects slide to the back.
    # def reset(self):
    #     self.go_rest()
    #     self.open_gripper()
    #     # self.open_gripper()

    #     # avoid_stuck_sock = [-0.09050486981868744, -0.07363107800483704, 1.1075341701507568, -0.9004467725753784, -3.1293208599090576]
    #     # self.locobot.arm.set_joint_positions(avoid_stuck_sock, moving_time=1)

    #     # new_home_joints = [-0.09050486981868744, -0.07363107800483704, 1.1075341701507568, 0.5737088322639465, -3.1293208599090576]
    #     # self.locobot.arm.set_joint_positions(new_home_joints, moving_time=1)

    #     self.close_gripper()

    #     left_position = [-1.6526798725128174, -0.26384469866752625, 1.0047574043273926, 0.8927768468856812, -3.110913038253784]
    #     down_position = [-1.6539225339889526, 1.179631233215332, 0.5476311445236206, -0.21629129350185394, -3.113981008529663]
    #     below_box_position = [-1.2440584897994995, 1.167359471321106, 0.6105243563652039, -0.21629129350185394, -3.1185829639434814]
    #     push_box_mid_postion = [-0.8851069211959839, -0.19788353145122528, 1.1765632629394531, 0.25003886222839355, -3.1262528896331787]
    #     push_box_postion = [-0.5230874419212341, -0.13038836419582367, 0.9756118059158325, 0.23316508531570435, -3.1247189044952393]

    #     self.locobot.arm.set_joint_positions(left_position, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(down_position, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(below_box_position, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(push_box_mid_postion, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(push_box_postion, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(push_box_mid_postion, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(below_box_position, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(down_position, moving_time=self.moving_time)
    #     self.locobot.arm.set_joint_positions(left_position, moving_time=self.moving_time)

    #     self.go_rest()

    def normalize_pixel(self, i, j):
        """Scale pixel indices (i, j) by 280 and 340 respectively.

        NOTE(review): 280 x 340 is the size of the OLD crop [160:500, 550:830]
        after the transpose in get_image_rgb. The current crop [170:520, 555:850]
        yields 295 x 350, so results can slightly exceed 1. get_point()'s
        clipping seems calibrated against these values — confirm before changing.
        """
        return i / 280, j / 340

    def pixel_to_point(self, i, j):
        """Convert pixel (i, j) of a 145 x 145 image to real-world (x, y) coordinates.

        i=0 -> x=0.33 and i=144 -> x=0.11; j=0 -> y=-0.12 and j=144 -> y=0.12.
        NOTE(review): 145 matches the commented-out cv2.resize in get_image_rgb;
        the current crop is not resized — confirm this helper is still used.
        """
        max_x = 0.33
        min_x = 0.11
        max_y = 0.12
        min_y = -0.12
        image_size = 145
        x = max_x - i * (max_x - min_x) / (image_size - 1)
        y = min_y + j * (max_y - min_y) / (image_size - 1)

        return x, y

    def move_to_point(self, x, y):
        """Move the end effector to the point (x, y) with polar (radius, waist) motions.

        The move is done as a change of radius plus a waist rotation:
          * If the radius shrinks, retract first and then rotate.
          * Otherwise rotate first and then extend.
        This way the waist always turns at the smaller radius.

        Assumes x > 0: x == 0 raises ZeroDivisionError. Neither x nor y is
        clipped here.
        """
        x0 = self.locobot.arm.T_sb[0, 3]
        y0 = self.locobot.arm.T_sb[1, 3]
        r0 = math.hypot(x0, y0)

        r = math.hypot(x, y)
        # Empirical radius correction. The factor is 1 at x = 0.35 and rises to
        # 1.2 at x = 0.14 with |y| = 0.15.
        r *= 1 + 0.2 * (abs(y) / 0.15) * (0.35 - x) / (0.35 - 0.14)
        theta = math.atan(y / x)

        dr = r - r0

        # if (abs(y) > 0.1 and r < 0.25):
        #     dr *= 1.25

        # The waist is commanded to -theta; the sign presumably matches the
        # joint's rotation convention — confirm on hardware.
        if dr < 0:
            self.locobot.arm.set_ee_cartesian_trajectory(x = dr, moving_time = self.moving_time)
            self.locobot.arm.set_single_joint_position("waist", -theta, moving_time = self.moving_time)
        else:
            self.locobot.arm.set_single_joint_position("waist", -theta, moving_time = self.moving_time)
            self.locobot.arm.set_ee_cartesian_trajectory(x = dr, moving_time = self.moving_time)

    def grab_object(self):
        """From above an object: open, descend, close the gripper, then rise again.

        The descent is 0.1. When the gripper is close to the base (x0 < 0.22) it
        goes 0.02 deeper if |y0| < 0.08, and 0.01 deeper otherwise.
        """
        y0 = self.locobot.arm.T_sb[1, 3]
        x0 = self.locobot.arm.T_sb[0, 3]
        eps = 0

        if x0 < 0.22:
            eps = 0.02 if abs(y0) < 0.08 else 0.01

        d = 0.1 + eps

        self.open_gripper()
        self.locobot.arm.set_ee_cartesian_trajectory(z = -d, moving_time = self.moving_time)
        self.close_gripper()
        self.locobot.arm.set_ee_cartesian_trajectory(z = d, moving_time = self.moving_time)

    def leave_object(self):
        """Descend 0.09, open the gripper to release the object, then rise again."""
        d = 0.09

        self.locobot.arm.set_ee_cartesian_trajectory(z = -d, moving_time = self.moving_time)
        self.open_gripper()
        self.locobot.arm.set_ee_cartesian_trajectory(z = d, moving_time = self.moving_time)

    def move_up(self):
        """Raise the end effector by 0.005."""
        self.locobot.arm.set_ee_cartesian_trajectory(z = 0.005, moving_time = self.moving_time)

    def move_down(self):
        """Lower the end effector by 0.005."""
        self.locobot.arm.set_ee_cartesian_trajectory(z = -0.005, moving_time = self.moving_time)

    def move_backwards(self):
        """Move the end effector 0.005 towards the base.

        Negative x shrinks the radius, as in move_to_point.
        """
        self.locobot.arm.set_ee_cartesian_trajectory(x = -0.005, moving_time = self.moving_time)

    def move_forward(self):
        """Move the end effector 0.005 away from the base.

        Positive x grows the radius, as in move_to_point.
        """
        self.locobot.arm.set_ee_cartesian_trajectory(x = 0.005, moving_time = self.moving_time)

    def rotate_cw(self):
        """Rotate the waist clockwise by pi/60 rad (3 degrees) from its last command."""
        current_angle = self.locobot.arm.get_single_joint_command("waist")
        self.locobot.arm.set_single_joint_position("waist", current_angle - math.pi / 60.0, moving_time = self.moving_time)

    def rotate_ccw(self):
        """Rotate the waist counterclockwise by pi/60 rad (3 degrees) from its last command."""
        current_angle = self.locobot.arm.get_single_joint_command("waist")
        self.locobot.arm.set_single_joint_position("waist", current_angle + math.pi / 60.0, moving_time = self.moving_time)

    def open_gripper(self):
        """Open the gripper unless the cached state says it is already open."""
        if not self.open:
            self.open = True
            self.locobot.gripper.open()

    def close_gripper(self):
        """Close the gripper (and wait 1 s) unless the cached state says it is already closed."""
        if self.open:
            self.open = False
            self.locobot.gripper.close()
            time.sleep(1)

        # gripper_pos = self.locobot.gripper.core.joint_states.position[self.locobot.gripper.left_finger_index]
        # print(gripper_pos, self.locobot.gripper.left_finger_lower_limit)

    # def shake_arm(self) :
    #     self.move_up()
    #     self.locobot.gripper.close()
    #     self.move_forward()
    #     self.move_backwards()
    #     self.locobot.gripper.close()
    #     self.rotate_ccw()
    #     self.rotate_cw()
    #     self.locobot.gripper.close()