import gtsam
import numpy as np
import matplotlib.pyplot as plt
import time


# Standard deviations used both by the 3D factor graph (make_graph_3d) and by the
# 3D demo when it synthesizes measurements. The 2D path hard-codes its own values.
# NOTE(review): positions are presumably in metres — confirm.
GPS_NOISE_AMOUNT: float = 2.5  # sigma of each GPS position axis
ODOM_TRANSLATION_NOISE_AMOUNT: float = 0.03  # sigma of each odometry translation axis
ODOM_ROTATION_NOISE_AMOUNT: float = 0.01  # sigma of each odometry rotation axis (radians, Rot3.Expmap tangent)
DOWN_NOISE_AMOUNT: float = 0.1  # sigma of each of the two down-direction error components


def make_graph(
    odoms,
    gps,
    *,
    gps_sigma=2.0,
    odom_translation_sigma=0.1,
    odom_rotation_sigma=0.01,
):
    """Build the 2D sliding-window factor graph.

    Pose ``i`` (a ``gtsam.Pose2``) receives a position prior from ``gps[i]``.
    ``odoms[i]`` is a between-factor linking pose ``i`` to pose ``i + 1``.

    Args:
        odoms: relative ``gtsam.Pose2`` measurements; ``odoms[i]`` is pose
            ``i + 1`` expressed in the frame of pose ``i``.
        gps: ``gtsam.Point2`` position measurements, one per pose.
        gps_sigma: standard deviation of each GPS position axis.
        odom_translation_sigma: standard deviation of each odometry x/y axis.
        odom_rotation_sigma: standard deviation of the odometry heading change.

    Returns:
        A ``gtsam.NonlinearFactorGraph`` over keys ``0 .. len(gps) - 1``.
    """
    graph = gtsam.NonlinearFactorGraph()

    # The noise models are identical for every factor, so build them once.
    # GPS carries no heading information: an infinite sigma gives the heading
    # residual zero weight.
    gps_noise = gtsam.noiseModel.Diagonal.Sigmas([gps_sigma, gps_sigma, np.inf])
    odom_noise = gtsam.noiseModel.Diagonal.Sigmas(
        [odom_translation_sigma, odom_translation_sigma, odom_rotation_sigma]
    )

    for i, gps_meas in enumerate(gps):
        # The heading of the prior pose (0) is arbitrary because it has zero weight.
        graph.add(gtsam.PriorFactorPose2(i, gtsam.Pose2(0, gps_meas), gps_noise))

    for i, odom in enumerate(odoms):
        graph.add(gtsam.BetweenFactorPose2(i, i + 1, odom, odom_noise))

    return graph


def make_graph_3d(odoms, gps, down):
    """Build the 3D sliding-window factor graph.

    Pose ``i`` (a ``gtsam.Pose3``) receives three factors:
    - a translation prior from ``gps[i]``;
    - an attitude factor from ``down[i]``;
    - for ``i < len(odoms)``, a between-factor to pose ``i + 1`` from ``odoms[i]``.

    Args:
        odoms: relative ``gtsam.Pose3`` measurements; ``odoms[i]`` links pose
            ``i`` to pose ``i + 1``.
        gps: ``gtsam.Point3`` position measurements, one per pose.
        down: ``gtsam.Unit3`` measurements of the navigation-frame down
            direction expressed in the body frame, one per pose.

    Returns:
        A ``gtsam.NonlinearFactorGraph`` over keys ``0 .. len(gps) - 1``.
    """
    graph = gtsam.NonlinearFactorGraph()

    # The noise models are identical for every factor, so build them once.
    gps_noise = gtsam.noiseModel.Diagonal.Sigmas([GPS_NOISE_AMOUNT] * 3)
    # Pose3 tangent vectors are ordered (rotation, translation).
    odom_noise = gtsam.noiseModel.Diagonal.Sigmas(
        [ODOM_ROTATION_NOISE_AMOUNT] * 3 + [ODOM_TRANSLATION_NOISE_AMOUNT] * 3
    )
    down_noise = gtsam.noiseModel.Diagonal.Sigmas([DOWN_NOISE_AMOUNT] * 2)
    # Direction of "down" in the navigation frame.
    nav_down = gtsam.Unit3(np.array([0, 0, -1]))

    # GPS factors constrain 3D position only; the rotation of the prior pose is ignored.
    for i, gps_meas in enumerate(gps):
        graph.add(
            gtsam.PoseTranslationPrior3D(
                i, gtsam.Pose3(gtsam.Rot3(), gps_meas), gps_noise
            )
        )

    # Odometry factors constrain the relative 3D pose.
    for i, odom in enumerate(odoms):
        graph.add(gtsam.BetweenFactorPose3(i, i + 1, odom, odom_noise))

    # Down factors constrain roll and pitch. Pose3AttitudeFactor has zero error
    # when nRb * (4th arg) == (2nd arg). The 2nd argument is therefore the
    # nav-frame reference and the 4th the body-frame measurement. Swapping them
    # only works for yaw-only rotations.
    for i, down_meas in enumerate(down):
        graph.add(gtsam.Pose3AttitudeFactor(i, nav_down, down_noise, down_meas))

    return graph


def make_initial_values(initial_pose):
    """Return a new ``gtsam.Values`` that holds ``initial_pose`` under key 0."""
    seeded = gtsam.Values()
    seeded.insert(0, initial_pose)
    return seeded


def roll_values(values, history, use_3d):
    """Prepare initial values for the next step of the sliding window.

    While the window can still grow (``values.size() <= history``), a new key
    ``values.size()`` is appended, seeded with a copy of the newest pose.
    ``values`` itself is modified and returned.

    Once the window is full, a fresh ``gtsam.Values`` is returned instead:
    - every pose moves down one key and the pose at key 0 is dropped;
    - the newest pose is duplicated into the last key, so the size is unchanged.

    Args:
        values: ``gtsam.Values`` with keys ``0 .. values.size() - 1``; expected
            to be non-empty.
        history: window parameter; at most ``history + 1`` poses are kept.
        use_3d: read poses with ``atPose3`` if True, otherwise ``atPose2``.

    Returns:
        The initial values for the next optimization.
    """
    pose_at = values.atPose3 if use_3d else values.atPose2
    count = values.size()
    newest_key = count - 1

    if count > history:
        # Window full: shift everything down one key, then duplicate the
        # newest pose into the freed last key.
        shifted = gtsam.Values()
        for key in range(newest_key):
            shifted.insert(key, pose_at(key + 1))
        shifted.insert(newest_key, pose_at(newest_key))
        return shifted

    # Window still growing: the newest pose is the initial guess for the next one.
    values.insert(count, pose_at(newest_key))
    return values


class StateEstimator:
    """Sliding-window pose estimator built on GTSAM.

    The window holds:
    - the most recent ``history + 1`` GPS readings (and, in 3D, down readings);
    - the ``history`` odometry readings that link those poses.

    Each call to :meth:`optimize` rebuilds a factor graph over the window. The
    optimizer is warm-started from the previous solution, shifted by
    ``roll_values``. Readings that leave the window are simply dropped, not
    marginalized.
    """

    def __init__(self, history, use_3d):
        """Create an empty estimator.

        Args:
            history: number of odometry links kept; the window holds at most
                ``history + 1`` poses.
            use_3d: if True, estimate ``gtsam.Pose3`` states from GPS,
                odometry and down readings. Otherwise estimate ``gtsam.Pose2``
                states from GPS and odometry only.
        """
        self.gps_readings = []
        self.use_3d = use_3d
        # Always created (it stays empty in 2D) so the attribute exists in both modes.
        self.down_readings = []
        self.values = gtsam.Values()
        self.odom_readings = []
        self.history = history

    def optimize(self):
        """Optimize the current window and keep the result as the new estimate.

        Returns:
            The optimized ``gtsam.Values``; key ``size() - 1`` is the newest pose.
        """
        if self.use_3d:
            graph = make_graph_3d(
                self.odom_readings, self.gps_readings, self.down_readings
            )
        else:
            graph = make_graph(self.odom_readings, self.gps_readings)
        self.values = gtsam.LevenbergMarquardtOptimizer(graph, self.values).optimize()
        return self.values

    def add_measurements(self, gps, odom, down=None):
        """Append one time step of measurements and grow or slide the window.

        Args:
            gps: position measurement for the new pose (``gtsam.Point2`` in 2D,
                ``gtsam.Point3`` in 3D).
            odom: relative pose from the previous pose to the new one. Must be
                ``None`` for the very first reading and only then.
            down: 3D only — ``gtsam.Unit3`` down direction in the body frame
                (as generated by the 3D demo). Must be ``None`` in 2D.

        Raises:
            AssertionError: if ``down`` is given in 2D or missing in 3D, or if
                ``odom`` is present/absent on the wrong reading. All checks run
                before any state is modified, so a rejected call leaves the
                estimator unchanged.
        """
        assert self.use_3d == (down is not None), "Down estimate should only be provided in 3D"

        is_first = len(self.gps_readings) == 0
        # Compare against None explicitly: pose objects should not be tested for truthiness.
        if odom is not None:
            assert not is_first, "Odom should be unspecified for the first GPS reading"
        else:
            assert is_first, "Odom should only be unspecified for the first GPS reading"

        if odom is not None:
            self.odom_readings.append(odom)
        self.gps_readings.append(gps)
        if self.use_3d:
            self.down_readings.append(down)

        # Slide the window: keep history + 1 poses and the history links between them.
        if len(self.gps_readings) > self.history + 1:
            self.gps_readings.pop(0)
            self.odom_readings.pop(0)
            if self.use_3d:
                self.down_readings.pop(0)

        if self.values.size() == 0:
            # First pose: identity orientation at the GPS position.
            if self.use_3d:
                self.values.insert(0, gtsam.Pose3(gtsam.Rot3(), gps))
            else:
                self.values.insert(0, gtsam.Pose2(gtsam.Rot2(), gps))
        else:
            self.values = roll_values(self.values, self.history, self.use_3d)


def test_state_estimator():
    """Demo: run the 2D sliding-window estimator on a synthetic trajectory.

    Steps:
    1. Generate noisy GPS and odometry from a ground-truth path: 200 stationary
       poses followed by a sinusoidal track.
    2. Feed them to a 2D ``StateEstimator`` one step at a time, printing the
       runtime of each step.
    3. Plot the true path, the estimated path and the estimated headings.

    NOTE(review): no RNG seed is set, so every run differs. ``plt.show()`` may
    block until the window is closed, which makes this unsuitable for an
    automated test run despite its ``test_`` name.
    """
    ts = np.linspace(0, 25, 251)

    # Ground truth: 200 identity poses (stationary start), then the path
    # x = 5t, y = 3cos(2t), with heading atan2(dy/dt, dx/dt) = atan2(-6 sin(2t), 5).
    true_poses = [gtsam.Pose2()] * 200 + [
        gtsam.Pose2(t * 5, 3 * np.cos(2 * t), np.arctan2(-6 * np.sin(2 * t), 5))
        for t in ts
    ]

    # Simulated GPS: true position plus zero-mean Gaussian noise, sigma 2.0 per axis.
    gps = [
        gtsam.Point2(*(pose.translation() + np.random.normal(size=2) * 2.0))
        for pose in true_poses
    ]

    # Simulated odometry between consecutive true poses:
    # - heading change with sigma-0.01 noise;
    # - the next position in the previous pose's frame, with sigma-0.1 noise per axis.
    odoms = [
        gtsam.Pose2(
            p1.rotation().theta() - p0.rotation().theta() + np.random.normal() * 0.01,
            p0.transformTo(p1.translation()) + np.random.normal(size=2) * 0.1,
        )
        for p0, p1 in zip(true_poses[:-1], true_poses[1:])
    ]

    state_estimator = StateEstimator(
        history=50, use_3d=False
    )

    # Estimated x, y and heading of the newest pose after each step.
    xs = []
    ys = []
    thetas = []

    # The first reading has no predecessor, hence no odometry.
    state_estimator.add_measurements(gps=gps[0], odom=None)

    for _gps, _odom in zip(gps[1:], odoms):
        t0 = time.time()
        state_estimator.add_measurements(_gps, _odom)
        estimate = state_estimator.optimize()
        # The newest pose has the highest key in the window.
        current_pose = estimate.atPose2(estimate.size() - 1)
        xs.append(current_pose.translation()[0])
        ys.append(current_pose.translation()[1])
        thetas.append(current_pose.rotation().theta())
        print("Time taken: ", time.time() - t0)

    # t0 = time.time()
    # result = gtsam.LevenbergMarquardtOptimizer(self.graph, roll_values(result)).optimize()
    # print("Time taken: ", time.time() - t0)

    # plt.plot([pt[0] for pt in gps], [pt[1] for pt in gps])
    # True path.
    plt.plot(
        [pt.translation()[0] for pt in true_poses],
        [pt.translation()[1] for pt in true_poses],
    )
    # Estimated path.
    plt.plot(xs, ys)
    # Unit-length heading segments at every 10th estimate. With 2-row inputs,
    # plt.plot draws each column (start point, end point) as a separate line.
    plt.plot(
        [xs[::10], [xs[i] + np.cos(thetas[i]) for i in range(0, len(xs), 10)]],
        [ys[::10], [ys[i] + np.sin(thetas[i]) for i in range(0, len(xs), 10)]],
        color="red",
    )
    plt.axis("equal")
    plt.show()


def test_state_estimator_3d():
    """Demo: run the 3D sliding-window estimator on a synthetic trajectory.

    Uses the same ground-truth path as the 2D demo, built from ``gtsam.Pose3``
    (yaw-only rotation, z = 0).

    Steps:
    1. Simulate GPS, odometry and body-frame down readings with exactly the
       module-level sigmas that ``make_graph_3d`` assumes.
    2. Feed them to a 3D ``StateEstimator`` one step at a time, printing the
       runtime of each step.
    3. Plot the true path, the estimated path and the estimated body axes at
       every 10th pose.

    NOTE(review): no RNG seed is set, so every run differs. ``plt.show()`` may
    block until the window is closed.
    """
    ts = np.linspace(0, 25, 251)

    # Ground truth: 200 identity poses, then x = 5t, y = 3cos(2t), z = 0,
    # yawed along the direction of travel.
    true_poses = [gtsam.Pose3()] * 200 + [
        gtsam.Pose3(
            gtsam.Rot3.Yaw(np.arctan2(-6 * np.sin(2 * t), 5)),
            np.array([t * 5, 3 * np.cos(2 * t), 0]),
        )
        for t in ts
    ]

    # Simulated GPS: true position plus Gaussian noise with the graph's GPS sigma.
    gps = [
        gtsam.Point3(
            *(pose.translation() + np.random.normal(size=3, scale=GPS_NOISE_AMOUNT))
        )
        for pose in true_poses
    ]

    # Simulated odometry between consecutive true poses:
    # - the relative rotation, perturbed by a random rotation vector;
    # - the relative translation plus Gaussian noise.
    # Both noise levels match the sigmas make_graph_3d assumes, so the demo
    # exercises a correctly specified noise model.
    odoms = [
        gtsam.Pose3(
            gtsam.Rot3.Expmap(
                np.random.normal(size=3, scale=ODOM_ROTATION_NOISE_AMOUNT)
            )
            * p0.rotation().inverse()
            * p1.rotation(),
            p0.transformTo(p1.translation())
            + np.random.normal(size=3, scale=ODOM_TRANSLATION_NOISE_AMOUNT),
        )
        for p0, p1 in zip(true_poses[:-1], true_poses[1:])
    ]

    # Simulated down readings: the nav-frame down vector (0, 0, -1) expressed in
    # the (noisily rotated) body frame, i.e. R^T @ down.
    downs = [
        gtsam.Unit3(
            (
                gtsam.Rot3.Expmap(np.random.normal(size=3, scale=DOWN_NOISE_AMOUNT))
                * p.rotation()
            )
            .matrix()
            .T
            @ np.array([0, 0, -1])
        )
        for p in true_poses
    ]

    state_estimator = StateEstimator(
        history=50, use_3d=True
    )

    # Estimated position of the newest pose, plus the tips of its body x/y/z
    # axes drawn with length 3.
    xs = []
    ys = []
    zs = []
    rxs = []
    rys = []
    rzs = []

    # The first reading has no predecessor, hence no odometry.
    state_estimator.add_measurements(gps=gps[0], odom=None, down=downs[0])

    for _gps, _down, _odom in zip(
        gps[1:], downs[1:], odoms
    ):
        t0 = time.time()
        state_estimator.add_measurements(_gps, _odom, _down)
        estimate = state_estimator.optimize()
        # The newest pose has the highest key in the window.
        current_pose = estimate.atPose3(estimate.size() - 1)
        xs.append(current_pose.translation()[0])
        ys.append(current_pose.translation()[1])
        zs.append(current_pose.translation()[2])
        rxs.append(
            current_pose.translation() + current_pose.rotation().matrix()[:, 0] * 3
        )
        rys.append(
            current_pose.translation() + current_pose.rotation().matrix()[:, 1] * 3
        )
        rzs.append(
            current_pose.translation() + current_pose.rotation().matrix()[:, 2] * 3
        )
        print("Time taken: ", time.time() - t0)

    ax = plt.axes(projection="3d")
    # True path.
    ax.plot3D(
        [pt.translation()[0] for pt in true_poses],
        [pt.translation()[1] for pt in true_poses],
        [pt.translation()[2] for pt in true_poses],
    )
    # Estimated path.
    ax.plot3D(xs, ys, zs)
    ax.axis("equal")
    # Body axes at every 10th estimate: one segment from the position to each axis tip.
    for vec, color in ((rxs, "red"), (rys, "green"), (rzs, "blue")):
        xlines = np.array([xs[::10], [vec[i][0] for i in range(0, len(xs), 10)]])
        ylines = np.array([ys[::10], [vec[i][1] for i in range(0, len(xs), 10)]])
        zlines = np.array([zs[::10], [vec[i][2] for i in range(0, len(xs), 10)]])
        for line in zip(xlines.T, ylines.T, zlines.T):
            ax.plot3D(*line, color=color)
    plt.show()


if __name__ == "__main__":
    # Run the 2D demo, then the 3D demo; each one finishes with plt.show().
    for demo in (test_state_estimator, test_state_estimator_3d):
        demo()
