import rospy

import std_msgs
import tf2_ros
import tf2_geometry_msgs as tf2gm
from sensor_msgs.msg import NavSatFix
from geometry_msgs.msg import (
    TransformStamped,
    Transform,
    PoseStamped,
    Vector3,
    Point,
    Pose,
    Quaternion,
)
from nav_msgs.msg import Odometry
import pygeodesy as geodesy
from .state_estimator import StateEstimator
import numpy as np
import gtsam


def _gtsam_to_ros_orientation(orientation: gtsam.Rot3):
    """Convert a ``gtsam.Rot3`` into a ROS ``geometry_msgs/Quaternion``.

    Args:
        orientation: Rotation to convert.

    Returns:
        A ``Quaternion`` message with the same w/x/y/z components.
    """
    # Build the quaternion once instead of re-deriving it for every component.
    quat = orientation.toQuaternion()
    return Quaternion(w=quat.w(), x=quat.x(), y=quat.y(), z=quat.z())


def _ros_to_gtsam_orientation(orientation: Quaternion):
    """Convert a ROS ``geometry_msgs/Quaternion`` into a ``gtsam.Rot3``.

    ``gtsam.Rot3.Quaternion`` takes its components positionally in
    (w, x, y, z) order, while the ROS message exposes them as named fields.
    """
    qw, qx, qy, qz = orientation.w, orientation.x, orientation.y, orientation.z
    return gtsam.Rot3.Quaternion(qw, qx, qy, qz)


def _gtsam_to_ros_pose(pose: gtsam.Pose3):
    """Convert a ``gtsam.Pose3`` into a ROS ``geometry_msgs/Pose``.

    Args:
        pose: Pose whose translation becomes ``position`` and whose rotation
            becomes ``orientation``.

    Returns:
        The equivalent ``Pose`` message.
    """
    translation = pose.translation()
    return Pose(
        position=Point(x=translation[0], y=translation[1], z=translation[2]),
        # ``Pose`` names its rotation field ``orientation`` (``Transform`` is
        # the message that uses ``rotation``); an unknown keyword makes the
        # generated message constructor raise.
        orientation=_gtsam_to_ros_orientation(pose.rotation()),
    )


def _ros_to_gtsam_pose(pose: Pose):
    """Convert a ROS ``geometry_msgs/Pose`` into a ``gtsam.Pose3``.

    The orientation is converted with ``_ros_to_gtsam_orientation`` and the
    position is packed into a length-3 numpy array for the translation.
    """
    rotation = _ros_to_gtsam_orientation(pose.orientation)
    position = pose.position
    translation = np.array([position.x, position.y, position.z])
    return gtsam.Pose3(rotation, translation)


def _gtsam_to_ros_transform(pose: gtsam.Pose3):
    """Convert a ``gtsam.Pose3`` into a ROS ``geometry_msgs/Transform``.

    The pose translation fills ``translation`` (a ``Vector3``) and the pose
    rotation fills ``rotation`` (a ``Quaternion``).
    """
    offset = pose.translation()
    translation_msg = Vector3(x=offset[0], y=offset[1], z=offset[2])
    rotation_msg = _gtsam_to_ros_orientation(pose.rotation())
    return Transform(translation=translation_msg, rotation=rotation_msg)


class RosStateEstimator:
    """ROS wrapper that feeds GPS fixes and odometry into a ``StateEstimator``.

    Odometry messages update the latest odometry pose. Every third
    ``NavSatFix`` (once odometry has been received) is projected to UTM and
    added to the estimator together with the odometry delta since the
    previously processed fix. The optimized pose is broadcast as the
    ``fixed_frame -> robot_frame`` TF. Between fixes, the odometry callback
    keeps broadcasting that TF by composing the latest odometry pose with the
    ``odom_to_world`` correction computed at the last fix.
    """

    def __init__(self, navsatfix_topic, odom_topic):
        """Create the estimator, TF plumbing, and ROS subscriptions.

        Args:
            navsatfix_topic: Topic carrying ``sensor_msgs/NavSatFix`` messages.
            odom_topic: Topic carrying ``nav_msgs/Odometry`` messages.
        """
        self.state_estimator = StateEstimator(history=50, use_3d=True)
        self.tf_buffer = tf2_ros.Buffer()
        self.tf_listener = tf2_ros.TransformListener(buffer=self.tf_buffer)
        self.tf_pub = tf2_ros.TransformBroadcaster()

        self.fixed_frame = "map"
        self.robot_frame = "base_link"

        self.last_navsatfix = None
        # Odometry pose recorded when the last fix was processed; the next
        # fix's odometry measurement is the delta relative to this pose.
        self.odom_pose_at_last_navsatfix: gtsam.Pose3 = None

        self.last_odom_pose: gtsam.Pose3 = None
        self.utm_datum = None
        # Odometry-frame down vector [0, 0, -1] expressed in the robot frame.
        self.last_down = None
        self.last_pose = None
        # Correction such that fixed_T_base = odom_to_world * odom_T_base.
        self.odom_to_world: gtsam.Pose3 = None

        # Counts incoming fixes so only every third one is processed.
        self.i = 0

        # Hard-coded UTM (easting, northing) origin; only used to make the
        # debug print readable.
        self.zero_pos = np.array([564430.725, 4192022.969])

        # Subscribe last: callbacks may fire as soon as a subscriber exists,
        # and they read every attribute initialised above.
        self.navsatfix_sub = rospy.Subscriber(
            navsatfix_topic, NavSatFix, self.navsatfix_callback
        )
        self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.odom_callback)

    def _publish_robot_tf(self, stamp, pose: gtsam.Pose3):
        """Broadcast ``pose`` as the ``fixed_frame -> robot_frame`` TF at ``stamp``."""
        self.tf_pub.sendTransform(
            TransformStamped(
                header=std_msgs.msg.Header(frame_id=self.fixed_frame, stamp=stamp),
                child_frame_id=self.robot_frame,
                transform=_gtsam_to_ros_transform(pose),
            )
        )

    def navsatfix_callback(self, msg: NavSatFix):
        """Fuse a GPS fix with the odometry accumulated since the previous fix.

        Only every third fix is processed, and only after odometry has been
        received. The UTM datum of the first processed fix is reused for all
        later fixes. After optimization the newest pose is broadcast as TF
        and used to refresh ``odom_to_world``.
        """
        self.i += 1
        if self.i % 3 != 0:
            return

        if self.last_down is None:
            return

        # Snapshot once: the odometry callback may replace last_odom_pose while
        # this callback runs, and the delta, the stored reference pose and
        # odom_to_world must all refer to the same odometry sample.
        odom_pose = self.last_odom_pose
        if odom_pose is None:
            return

        # First, get the NavSatFix in UTM
        if self.utm_datum is None:
            utm: geodesy.Utm = geodesy.utm.toUtm8(msg.latitude, msg.longitude)
            self.utm_datum = utm.datum
        else:
            utm = geodesy.utm.toUtm8(msg.latitude, msg.longitude, datum=self.utm_datum)
            print(f"Got UTM: {np.array([utm.easting, utm.northing]) - self.zero_pos}")

        if self.odom_pose_at_last_navsatfix is None:
            # First processed fix: there is no earlier pose to form a delta from.
            pose_delta_gtsam = None
        else:
            # Relative motion (in the previous pose's frame) since the last fix.
            pose_delta_gtsam = self.odom_pose_at_last_navsatfix.inverse() * odom_pose

        self.state_estimator.add_measurements(
            gps=gtsam.Point3(utm.easting, utm.northing, msg.altitude),
            odom=pose_delta_gtsam,
            # NOTE(review): last_down is required above but a constant down
            # vector is passed here; confirm whether the measured direction
            # (gtsam.Unit3(self.last_down)) was intended.
            down=gtsam.Unit3(np.array([0, 0, -1])),
        )
        values = self.state_estimator.optimize()

        # Assumes the newest pose is stored under key size() - 1 -- TODO
        # confirm against StateEstimator's keying scheme.
        last_pose = values.atPose3(values.size() - 1)
        self._publish_robot_tf(msg.header.stamp, last_pose)

        self.odom_pose_at_last_navsatfix = odom_pose
        self.last_navsatfix = msg
        # last_pose is published as fixed_T_base, so the correction is
        # fixed_T_odom = fixed_T_base * (odom_T_base)^-1; odom_callback then
        # recovers fixed_T_base' = fixed_T_odom * odom_T_base'.
        self.odom_to_world = last_pose * odom_pose.inverse()

    def odom_callback(self, msg: Odometry):
        """Record the latest odometry pose and re-broadcast the corrected TF.

        Also caches the odometry-frame down vector expressed in the robot
        frame. Once a GPS fix has produced ``odom_to_world``, the fixed-frame
        pose is that correction composed with the new odometry pose.
        """
        odom_pose = _ros_to_gtsam_pose(msg.pose.pose)
        self.last_odom_pose = odom_pose
        # R^-1 * [0, 0, -1]: rotate the odometry-frame down vector into the
        # robot frame. Reuses the rotation already built from the same quaternion.
        self.last_down = odom_pose.rotation().inverse().matrix() @ np.array([0, 0, -1])

        odom_to_world = self.odom_to_world
        if odom_to_world is not None:
            self._publish_robot_tf(msg.header.stamp, odom_to_world * odom_pose)


def main():
    """Start the ``state_estimation`` node and spin until shutdown.

    The topics to subscribe to come from the private parameters
    ``~navsatfix_topic`` and ``~odom_topic``.
    """
    rospy.init_node("state_estimation")

    # Hold a reference to the estimator for as long as the node spins.
    _estimator = RosStateEstimator(
        navsatfix_topic=rospy.get_param("~navsatfix_topic"),
        odom_topic=rospy.get_param("~odom_topic"),
    )
    rospy.spin()


if __name__ == "__main__":
    main()
