from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

from rlnav.robot.tasks.task import BaseNavTask, TaskConfig


def pose_distance(
    position: np.ndarray,
    orientation: np.ndarray,
    goal_position: np.ndarray,
    goal_orientation: np.ndarray,
    orientation_weight: float = 1.0,
):
    """
    Combined position + orientation distance between poses.

    Returns ``||position - goal_position|| + orientation_weight * theta``, where
    ``theta`` in [0, pi] is the rotation angle between the two quaternions
    (each normalized first, so non-unit quaternions are accepted).

    The last axis holds the vector / quaternion components; leading axes
    broadcast, so a single pose can be compared against an (N, ...) batch of
    goal poses, giving an (N,) result.
    """
    # Normalize both quaternions so the dot product is the cosine of half the
    # rotation angle between them.
    q1 = orientation / np.linalg.norm(orientation, axis=-1, keepdims=True)
    q2 = goal_orientation / np.linalg.norm(goal_orientation, axis=-1, keepdims=True)

    # abs() folds the q / -q double cover onto the same rotation; the clip
    # guards against rounding pushing the value just above 1, which would make
    # arccos return NaN.
    cos_half_angle = np.clip(np.abs(np.sum(q1 * q2, axis=-1)), 0.0, 1.0)
    d_quat = 2 * np.arccos(cos_half_angle)

    # Euclidean position distance
    d_pos = np.linalg.norm(position - goal_position, axis=-1)

    return d_pos + orientation_weight * d_quat


@dataclass
class CircuitTaskConfig(TaskConfig):
    """Configuration for CircuitTask; adds goal-sampling and goal-reaching knobs to TaskConfig."""

    # The goal counts as reached once pose_distance() to it drops below this
    # value (position distance plus orientation_weight * rotation angle in radians).
    reach_threshold: float = 0.2
    # Number of nearest dataset entries from which sample_goal draws its base index.
    goal_topk: int = 25


class CircuitTask(BaseNavTask):
    """
    Navigation task whose goals are poses sampled from a recorded circuit.

    The goal dataset is loaded with ``np.load(goal_file)`` and must provide the
    keys ``"data/position"``, ``"data/quaternion"`` and ``"data/image"``, all
    indexed by sample along their first axis.
    """

    def __init__(self, config: CircuitTaskConfig, goal_file: str):
        super().__init__(config)
        self.goal_data = np.load(goal_file)

    def sample_goal(self, observation: dict):
        """
        Sample a new goal configuration based on the current observation.

        The robot pose is read from ``observation["position"]`` and
        ``observation["orientation"]``. A base index is drawn from the
        ``config.goal_topk`` dataset entries closest to that pose, with
        probability proportional to ``exp(-distance)``. The goal index is the
        base index plus an exponentially distributed offset (mean 20 samples),
        wrapping around the end of the dataset.

        Returns:
            dict with ``base_idx`` and ``goal_idx`` (Python ints) and the goal
            sample's ``position``, ``orientation`` and ``image``.
        """
        position = observation["position"]
        orientation = observation["orientation"]

        # Fetch each array once; if goal_data is an NpzFile, every key lookup
        # re-reads the array from the archive.
        goal_positions = self.goal_data["data/position"]
        goal_orientations = self.goal_data["data/quaternion"]
        num_goals = len(goal_positions)

        # Find the distance to each point in the dataset
        distances = pose_distance(
            position, orientation, goal_positions, goal_orientations
        )

        # Select the K closest points. Clamp K to the dataset size so the
        # argpartition pivot index stays in range for small datasets.
        topk = min(self.config.goal_topk, num_goals)
        best_idcs = np.argpartition(distances, topk - 1)[:topk]

        # Softmax over negative distances. Shifting by the minimum leaves the
        # distribution unchanged but prevents every term underflowing to 0
        # (and the normalization producing NaN) when all points are far away.
        best_distances = distances[best_idcs]
        probs = np.exp(-(best_distances - best_distances.min()))
        probs /= np.sum(probs)

        # Select a base index from the top K points
        base_idx = int(np.random.choice(best_idcs, p=probs))

        # The goal index is the base index plus an exponentially-distributed
        # offset, wrapped so it always indexes into the dataset.
        goal_idx = (base_idx + int(np.random.exponential() * 20)) % num_goals

        return {
            "base_idx": base_idx,
            "goal_idx": goal_idx,
            "position": goal_positions[goal_idx],
            "orientation": goal_orientations[goal_idx],
            "image": self.goal_data["data/image"][goal_idx],
        }

    def compute_step(
        self,
        observation: dict,
        action: np.ndarray,
        next_observation: dict,
        goal: dict,
        num_steps: int,
    ) -> Tuple[float, bool, bool, dict]:
        """
        Compute information from a step with the task.

        Delegates to the base task for truncation/termination/info, then
        replaces the reward with a sparse one: 0.0 when the post-step pose is
        within ``config.reach_threshold`` of the goal (per ``pose_distance``),
        -1.0 otherwise. Reaching the goal also terminates the episode.

        Returns:
            (reward, truncated, terminated, info); ``info["reached_goal"]``
            records whether the goal was reached.
        """
        # Compute distance from the post-step pose to the goal
        distance_to_goal = pose_distance(
            next_observation["position"],
            next_observation["orientation"],
            goal["position"],
            goal["orientation"],
        )

        # Truncation/termination/info come from the base task. Its reward is
        # discarded below in favour of the goal-reaching reward.
        _, truncated, terminated, info = super().compute_step(
            observation, action, next_observation, goal, num_steps
        )

        reached = bool(distance_to_goal < self.config.reach_threshold)

        # 0/-1 reward based on reaching the goal
        reward = 0.0 if reached else -1.0
        truncated = truncated or False
        terminated = terminated or reached

        info["reached_goal"] = reached

        return reward, truncated, terminated, info

    def reset(self, observation: dict) -> Optional[dict]:
        """
        Reset the task to a new episode.

        This task keeps no per-episode state, so there is nothing to return.
        """
        return None

    def data_format(self, obs_format: dict):
        """
        Extend the base task's goal/info specs with this task's fields.

        Returns:
            (goal_spec, info_spec) dicts; the goal image spec is reused from
            ``obs_format["image"]``.
        """
        # NOTE(review): `tf` is not imported anywhere in this module; this
        # presumably needs `import tensorflow as tf` at the top — confirm.
        goal, info = super().data_format(obs_format)

        return {
            **goal,
            "base_idx": tf.TensorSpec((), tf.int32, name="base_idx"),
            "goal_idx": tf.TensorSpec((), tf.int32, name="goal_idx"),
            "position": tf.TensorSpec((3,), tf.float64, name="position"),
            "orientation": tf.TensorSpec((4,), tf.float64, name="orientation"),
            "image": obs_format["image"],
        }, {
            **info,
            "reached_goal": tf.TensorSpec((), tf.bool, name="reached_goal"),
        }


if __name__ == "__main__":
    # Visual sanity check: sample goals for a fixed robot pose and plot the
    # base sample and the chosen goal on the recorded circuit.
    import os

    import matplotlib.pyplot as plt

    # NOTE(review): assumes every TaskConfig field has a default value — confirm.
    task = CircuitTask(
        CircuitTaskConfig(),
        os.path.join(os.path.dirname(__file__), "../../../data/goal_loop.pkl.npz"),
    )

    robot_position = np.array([-4.0, 4.0, 0.0])
    robot_orientation = np.array([0.0, 0.0, 0.0, 1.0])
    observation = {"position": robot_position, "orientation": robot_orientation}

    def _yaw(quat):
        """Yaw angle of a quaternion; the formula below assumes (x, y, z, w) ordering."""
        x, y, z, w = quat
        return np.arctan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y**2 + z**2))

    robot_yaw = _yaw(robot_orientation)

    # Read the dataset arrays once rather than on every plotting call.
    path = task.goal_data["data/position"]
    images = task.goal_data["data/image"]

    for _ in range(10):
        goal = task.sample_goal(observation)
        goal_position = goal["position"]
        goal_yaw = _yaw(goal["orientation"])
        base_position = path[goal["base_idx"]]

        fig, axs = plt.subplot_mosaic(
            [
                ["A", "A"],
                ["B", "C"],
            ]
        )
        ax = axs["A"]
        ax.axis("equal")
        ax.plot(path[:, 0], path[:, 1], ".", label="path")
        ax.scatter(path[0, 0], path[0, 1], marker="o", c="yellow", s=100, label="begin")
        ax.scatter(path[-1, 0], path[-1, 1], marker="o", c="pink", s=100, label="end")
        ax.scatter(
            robot_position[0],
            robot_position[1],
            marker="o",
            c="r",
            s=100,
            zorder=10,
            label="robot",
        )
        ax.quiver(
            robot_position[0],
            robot_position[1],
            np.cos(robot_yaw),
            np.sin(robot_yaw),
            color="r",
        )
        ax.scatter(
            base_position[0],
            base_position[1],
            marker="o",
            c="g",
            s=100,
            zorder=0,
            label="goal base",
        )
        ax.scatter(
            goal_position[0],
            goal_position[1],
            marker="x",
            c="g",
            s=100,
            zorder=10,
            label="goal",
        )
        ax.quiver(
            goal_position[0],
            goal_position[1],
            np.cos(goal_yaw),
            np.sin(goal_yaw),
            color="g",
        )
        ax.legend()
        axs["B"].imshow(images[goal["base_idx"]])
        axs["C"].imshow(goal["image"])
        plt.show()
        plt.close(fig)
