from dataclasses import dataclass
from io import BytesIO
from typing import Optional
import numpy as np
from agentlace.action import ActionServer
from multinav.deploy.common.trainer_bridge_common import make_action_config
import time
from threading import Thread
import os

# ROS 2 (rclpy)
import rclpy
from rclpy.action import ActionClient as RosActionClient
import sensor_msgs.msg as sm
import geometry_msgs.msg as gm
import nav_msgs.msg as nm
import gazebo_msgs.srv as gzs
from rclpy.node import Node
from cv_bridge import CvBridge
from rclpy.qos import QoSProfile, QoSHistoryPolicy, QoSReliabilityPolicy, QoSDurabilityPolicy

import cv2
from PIL import Image

# Stuck-detection window length: the number of most recent odometry
# forward-velocity (twist.linear.x) samples kept by RobotInterface. When the
# robot has been still across this window it is flagged as stuck so that a
# reset can be triggered.
X_TRACK = 5

@dataclass
class RobotConfig:
    """Topic names and robot flags loaded from the robot's YAML config file.

    ``RobotInterface.__init__`` builds this with ``RobotConfig(**parsed_yaml)``,
    so the YAML keys must match these field names exactly.
    """

    image_topic: str  # subscribed as sensor_msgs/Image (camera frames)
    imu_topic: str  # subscribed as sensor_msgs/Imu
    pose_rel_topic: str  # nav_msgs/Odometry topic feeding the "*_rel" pose and velocity observations
    pose_abs_topic: str  # geometry_msgs/PoseWithCovarianceStamped topic feeding the "*_abs" pose observations
    action_topic: str  # geometry_msgs/Twist topic; the node both publishes and listens on it
    teleop_action_topic: str  # NOTE(review): not referenced anywhere else in this file

    is_simulation: bool = False  # enables the Gazebo reset client and simulation-only commands
    model_name: Optional[str] = None  # "create" enables the iRobot Create setup; also used as the Gazebo entity name on reset

class RobotInterface(Node):
    def __init__(self, config_yaml: str, img_format="np"):
        """Create the ROS node, wire up its topics and start the action server.

        Args:
            config_yaml: Path to a YAML file whose keys match the fields of
                ``RobotConfig``.
            img_format: Stored as ``self.image_format``; it is not otherwise
                used by this constructor.
        """
        super().__init__("robot_interface")
        action_server_config = make_action_config()

        self.image_bridge = CvBridge()
        self.image_format = img_format
        self.num_steps = 0
        # Sliding window of recent forward velocities, oldest first. It starts
        # filled with ones, presumably so the robot is not reported as stuck
        # before real odometry has arrived.
        self.recent_x_avg = [1] * X_TRACK
        self.reset_needed = False
        self.reset_handled = True
        self.image_size = (64, 64)
        # Latest value of every observation served to action-server clients.
        # "crash" and "stuck" are 0-d boolean arrays updated in place by the
        # callbacks via ``.fill``.
        self._latest_obs = {
            "image": np.zeros((*self.image_size, 3), dtype=float),
            "imu": np.zeros((6,), dtype=float),
            "position_abs": np.zeros((3,), dtype=float),
            "position_rel": np.zeros((3,), dtype=float),
            "orientation_abs": np.zeros((4,), dtype=float),
            "orientation_rel": np.zeros((4,), dtype=float),
            "linear_velocity": np.zeros((3,), dtype=float),
            "angular_velocity": np.zeros((3,), dtype=float),
            "crash": np.zeros((), dtype=bool),
            "stuck": np.zeros((), dtype=bool),
            "latest_action": np.zeros((2,), dtype=float),  # latest action seen on the action topic
        }

        # Agentlace setup
        self.action_server = ActionServer(
            config=action_server_config,
            obs_callback=self.latest_obs,
            act_callback=self.act,
        )

        # yaml is only needed here, so it is imported locally.
        import yaml

        with open(config_yaml, "r") as f:
            # Keep the parsed mapping in its own name instead of rebinding the
            # ``config_yaml`` path argument.
            config_fields = yaml.safe_load(f)
        self.config = RobotConfig(**config_fields)

        # Ros setup: SHARED (all robots need these minimums: image, imu, odom, act)
        self.qos_profile = QoSProfile(
            reliability=QoSReliabilityPolicy.BEST_EFFORT,
            history=QoSHistoryPolicy.KEEP_LAST,
            durability=QoSDurabilityPolicy.VOLATILE,
            depth=10,
        )

        self.image_sub = self.create_subscription(
            sm.Image,
            self.config.image_topic,
            self._image_callback,
            10,
        )
        self.imu_sub = self.create_subscription(
            sm.Imu,
            self.config.imu_topic,
            self._imu_callback,
            self.qos_profile,
        )

        self.act_pub = self.create_publisher(
            gm.Twist,
            self.config.action_topic,
            10,
        )
        # Listen on the same topic we publish to so "latest_action" reflects
        # whatever Twist was most recently sent on it.
        self.act_sub = self.create_subscription(
            gm.Twist,
            self.config.action_topic,
            self._act_callback,
            10,
        )

        # Ros setup: Simulation Specific
        if self.config.is_simulation:
            self.gazebo_reset = self.create_client(
                gzs.SetEntityState, "/set_entity_state"
            )
            while not self.gazebo_reset.wait_for_service(timeout_sec=1.0):
                print("set entity state not yet available, waiting again...")

            # RobotConfig has no ``pose_topic`` field (reading it raised
            # AttributeError). Odometry is published on ``pose_rel_topic``,
            # the same topic/callback pairing used by ``_setup_create_vars``.
            self.odom_sub = self.create_subscription(
                nm.Odometry,
                self.config.pose_rel_topic,
                self._pose_callback_odom,
                self.qos_profile,
            )
        # Ros setup: iRobot Create Specific
        if self.config.model_name == "create":
            self._setup_create_vars()

        # Start
        self.action_server.start(threaded=True)


    # Setup Helper Functions (vary by robot type)
    def _setup_create_vars(self):
        """Create the iRobot Create specific subscriptions, publisher and action clients.

        Called from ``__init__`` when ``config.model_name == "create"``. The
        message and action packages below are imported only when this method
        runs.
        """
        # NOTE(review): Joy, Imu, String and GoalStatus are imported but not
        # used in this method.
        from sensor_msgs.msg import BatteryState, Joy, Imu
        from std_msgs.msg import String, Bool
        from action_msgs.msg import GoalStatus
        from irobot_create_msgs.msg import HazardDetectionVector, HazardDetection, Dock
        from irobot_create_msgs.action import Undock, DockServo
        from nav2_msgs.action import NavigateToPose

        def subscribe(msg_type, topic, handler):
            # Every Create subscription shares the node's best-effort QoS profile.
            return self.create_subscription(msg_type, topic, handler, self.qos_profile)

        # Keep these classes on the instance; the hazard callback and the
        # dock/undock commands look them up there.
        self.HazardDetection = HazardDetection
        self.Undock = Undock
        self.DockServo = DockServo

        # Pose comes from two sources on the Create: a PoseWithCovarianceStamped
        # topic (stored as the "*_abs" observations) and an Odometry topic
        # (stored as the "*_rel" observations).
        self.rel_pos_sub = subscribe(
            gm.PoseWithCovarianceStamped,
            self.config.pose_abs_topic,
            self._pose_callback_pose,
        )
        self.odom_sub = subscribe(
            nm.Odometry,
            self.config.pose_rel_topic,
            self._pose_callback_odom,
        )

        # Battery state
        self.battery_msg = BatteryState()
        self.battery_sub = subscribe(BatteryState, "/battery_state", self.battery_callback)

        # Hazard detections
        self.hazard_msg = HazardDetectionVector()
        self.hazard_sub = subscribe(HazardDetectionVector, "/hazard_detection", self.hazard_callback)

        # Dock status
        self.dock_msg = Dock()
        self.dock_sub = subscribe(Dock, "/dock", self.dock_callback)

        # Publisher for the dock lock flag; the stored message always carries True.
        lock_msg = Bool()
        lock_msg.data = True
        self.dock_lock_msg = lock_msg
        self.dock_lock_pub = self.create_publisher(Bool, "/dock_lock", 1)

        # Action clients
        self.undock_action_client = RosActionClient(self, Undock, 'undock')
        self.dock_action_client = RosActionClient(self, DockServo, 'dock')
        self.navigate_to_pose_client = RosActionClient(self, NavigateToPose, 'navigate_to_pose')

    def latest_obs(self, keys):
        # self._check_for_crash()
        return {key: self._latest_obs[key] for key in keys}

    # Key robot action server capabilities 
    def act(self, key, payload):
        if key == "action_vw":
            assert isinstance(payload, np.ndarray) and payload.shape == (
                2,
            ), f"Invalid action_vw: {payload}"
            twist = gm.Twist()
            twist.linear.x = float(payload[0])
            twist.angular.z = float(payload[1])
            self.act_pub.publish(twist)
            self.num_steps += 1
        
        elif key == "reset":
            print("Resetting")

            if self.config.is_simulation: 
                
                self.num_steps = 0
                
                # slow down bud 
                twist = gm.Twist()
                twist.linear.x = 0.0
                twist.angular.z = 0.0
                self.act_pub.publish(twist)

                self.reset_needed = True # I was using reset_needed as a check slightly differently before for sim 
                return self.gazebo_reset_environment(
                    payload["position"], payload["orientation"]
                )
            
            elif self.config.model_name == "create": # try to back up a couple steps! 
                self.act("action_vw", np.array([-0.1, 0])) # move directly backwards 
                self.act("action_vw", np.array([-0.1, 0])) # move directly backwards 
                
                time.sleep(0.5) # WAIT FOR IT TO BE DONE
                self._latest_obs["crash"].fill(False) # recovered now
                self.reset_needed = False
         
            

        elif key == "move_marker": # move marker in gazebo simulation
            assert self.config.is_simulation, "Goal marker is only supported in simulation"
            if "type" in payload.keys():
                cmd = f"gz marker -m 'action: ADD_MODIFY, type: {payload['type']}, id: 9," 
            else:
                cmd = f"gz marker -m 'action: ADD_MODIFY, type: SPHERE, id: 9," 
            cmd += f"scale: {{x: 0.2, y: 0.2, z: 0.2}}," 
            if "orientation" in payload.keys():
                cmd += f"pose: {{position: {{x: {payload['position'][0]:.2f}, y: {payload['position'][1]:.2f}, z: {payload['position'][2]:.2f}}}"
                cmd += f", orientation: {{x: {payload['orientation'][0]:.2f}, y: {payload['orientation'][1]:.2f},  z: {payload['orientation'][2]:.2f}, w: {payload['orientation'][3]:.2f}  }} }}'"
            else:
                cmd += f"pose: {{position: {{x: {payload['position'][0]:.2f}, y: {payload['position'][1]:.2f}, z: {payload['position'][2]:.2f}}} }}'"
            os.system(cmd)

        elif key == "dock":
            assert self.config.model_name == "create", "Docking only supported for iRobot Create"
            # os.system(f"ros2 action send_goal /dock irobot_create_msgs/action/DockServo \" {{}}\"")
            goal_msg = self.DockServo.Goal()
            self.dock_action_client.wait_for_server()
            return self.dock_action_client.send_goal_async(goal_msg)

        elif key == "undock":
            assert self.config.model_name == "create", "Docking only supported for irobot"
            # os.system(f"ros2 action send_goal /undock irobot_create_msgs/action/Undock \"{{}}\"")
            goal_msg = self.Undock.Goal()
            self.undock_action_client.wait_for_server()
            return self.undock_action_client.send_goal_async(goal_msg)

        else:
            raise NotImplementedError()

    def gazebo_reset_environment(self, position, quaternion):
        assert self.config.is_simulation, "Reset is only supported in simulation"
        srv = gzs.SetEntityState.Request()
        srv._state.name = self.config.model_name
        srv._state.pose.position.x = float(position[0])
        srv._state.pose.position.y = float(position[1])
        srv._state.pose.position.z = float(position[2]) 
        srv._state.pose.orientation.x = float(quaternion[0])
        srv._state.pose.orientation.y = float(quaternion[1])
        srv._state.pose.orientation.z = float(quaternion[2])
        srv._state.pose.orientation.w = float(quaternion[3])
        srv._state.reference_frame = "world"
        
        result = self.gazebo_reset.call(srv)
        self._latest_obs["crash"].fill(False)
        self._latest_obs["stuck"].fill(False)
        print("reset")
        # time.sleep(2)
        self.reset_needed = False
        return {"success": result.success}
    
    # ROS CALLBACK FUNCTIONS 
    def _image_callback(self, msg: sm.Image):
        """Convert an incoming image message, resize it and store it as the "image" observation."""
        frame = self.image_bridge.imgmsg_to_cv2(msg)
        resized = cv2.resize(frame, self.image_size)
        self._latest_obs["image"] = np.asarray(resized)

    def _imu_callback(self, msg: sm.Imu):
        # Load the imu
        self._latest_obs["imu"] = np.array(
            [
                msg.linear_acceleration.x,
                msg.linear_acceleration.y,
                msg.linear_acceleration.z,
                msg.angular_velocity.x,
                msg.angular_velocity.y,
                msg.angular_velocity.z,
            ],
            dtype=float,
        )

        if self.config.model_name == "create":
            imu_x = msg.linear_acceleration.x
            imu_y = msg.linear_acceleration.y
            imu_z = msg.linear_acceleration.z

            if abs(imu_x) > 4.0 or abs(imu_y) > 4.0 or abs(imu_z) > 4.0 and not self.reset_needed:
                print("Robot has been bumped: ", imu_x, imu_y, imu_z)
                self._latest_obs["crash"].fill(True)
                self.reset_needed = True

        elif self.config.model_name == "waffle_pi":
            if ( # not moving forward
                self.num_steps > 1
                and np.sqrt(msg.linear_acceleration.y**2 + msg.linear_acceleration.z**2) > 20 # fast crash
                and not self.reset_needed
                and not np.any(self._latest_obs["crash"]) # already marked as a crash, take a chill pill. don't bother. 
                ):
                print("looking stuck")
                self._latest_obs["crash"].fill(True)
        
    def _pose_callback_odom(self, msg: nm.Odometry):
        # Load the pose
        self._latest_obs["position_rel"] = np.array(
            [
                msg.pose.pose.position.x,
                msg.pose.pose.position.y,
                msg.pose.pose.position.z,
            ],
            dtype=np.float64,
        )
        self._latest_obs["orientation_rel"] = np.array(
            [
                msg.pose.pose.orientation.x,
                msg.pose.pose.orientation.y,
                msg.pose.pose.orientation.z,
                msg.pose.pose.orientation.w,
            ],
            dtype=np.float64,
        )
        self._latest_obs["linear_velocity"] = np.array(
            [
                msg.twist.twist.linear.x,
                msg.twist.twist.linear.y,
                msg.twist.twist.linear.z,
            ],
            dtype=np.float64,
        )
        self._latest_obs["angular_velocity"] = np.array(
            [
                msg.twist.twist.angular.x,
                msg.twist.twist.angular.y,
                msg.twist.twist.angular.z,
            ],
            dtype=np.float64,
        )

        self.recent_x_avg = self.recent_x_avg[1:]+[msg.twist.twist.linear.x]
        if (sum(self.recent_x_avg) / X_TRACK) < 0.000000001 and self.num_steps > 1: # stuck in place
            self._latest_obs["stuck"].fill(True)
    
    def _pose_callback_pose(self, msg: gm.PoseWithCovarianceStamped):
        # Load the pose... i think it should be the SAME as odom, just no twist.
        self._latest_obs["position_abs"] = np.array(
            [
                msg.pose.pose.position.x,
                msg.pose.pose.position.y,
                msg.pose.pose.position.z,
            ],
            dtype=np.float64,
        )
        self._latest_obs["orientation_abs"] = np.array(
            [
                msg.pose.pose.orientation.x,
                msg.pose.pose.orientation.y,
                msg.pose.pose.orientation.z,
                msg.pose.pose.orientation.w,
            ],
            dtype=np.float64,
        )

    def _act_callback(self, msg):
        self._latest_obs["latest_action"] = np.array([msg.linear.x, msg.angular.z])

    def battery_callback(self, battery_msg): 
        self.battery_msg = battery_msg
        print("Battery: ", round(self.battery_msg.percentage*100, 2))

        if self.battery_msg.percentage < 0.15: 
            self._latest_obs["stuck"].fill(True)
            
    def hazard_callback(self, hazard_msg):
        self.hazard_msg = hazard_msg
        for detection in self.hazard_msg.detections: 
            if detection.type in [self.HazardDetection.BUMP, self.HazardDetection.CLIFF, self.HazardDetection.STALL]:
                if not self.reset_needed: # we already KNOW man
                    print("Hazard bumped")
                    self._latest_obs["crash"].fill(True)
                    self.reset_needed = True

    def dock_callback(self, dock_msg):
        """Store the most recent message received on the ``/dock`` subscription."""

        self.dock_msg = dock_msg
    
    def dock_lock_callback(self):
        """Publish the stored dock-lock message on ``dock_lock_pub``.

        NOTE(review): nothing in this file registers this method with a timer
        or subscription, so it only runs if called explicitly.
        """
        self.dock_lock_pub.publish(self.dock_lock_msg)

if __name__ == "__main__":
    rclpy.init(args=None)

    # Runs against the real iRobot Create config; substitute
    # "sim/locobot.yaml" here to run against the simulated robot instead.
    config_dir = os.path.dirname(__file__)
    config_path = os.path.join(config_dir, "real/create.yaml")
    robot_interface = RobotInterface(config_path)

    # Spin the node on its own (non-daemon) thread so ROS callbacks keep
    # being serviced after the main thread reaches the end of the script.
    spin_thread = Thread(target=rclpy.spin, args=(robot_interface,))
    spin_thread.start()