from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pybullet as pb

from .utils import change_frame


class SimObject:
    """Class for a simulated object in the environment.

    Each object has a world-frame pose (position + quaternion orientation)
    and may own child objects. When the pose is changed with
    ``recursive=True``, every child keeps its pose relative to this object.
    """

    def __init__(
        self,
        name: str,
        position: Optional[np.ndarray] = None,
        orientation: Optional[np.ndarray] = None,
        scale: float = 1,
    ) -> None:
        """Initializes the SimObject class.

        Args:
            name: name of the object; also the key under which a parent
                stores it in ``child_objects``.
            position: position of the object in the world frame. Defaults to
                the origin ``[0, 0, 0]``.
            orientation: quaternion orientation of the object in the world
                frame. Defaults to ``[0, 0, 0, 1]`` (the identity, which
                implies (x, y, z, w) component order).
            scale: scale of the object
        """
        # Create fresh default arrays per instance. A numpy array used as a
        # default argument is evaluated once and would be shared (and
        # mutable) across every default-constructed object.
        if position is None:
            position = np.array([0, 0, 0])
        if orientation is None:
            orientation = np.array([0, 0, 0, 1])

        self.name = name
        self._position = position
        self._orientation = orientation
        self.scale = scale

        self.parent_object: Optional[SimObject] = None
        self.child_objects: Dict[str, SimObject] = {}

    def add_child_object(
        self,
        child_object: SimObject,
    ) -> None:
        """Attaches ``child_object`` as a child of this object.

        Children are keyed by ``name``; a child with the same name as an
        existing one replaces it (the replaced object is orphaned). A child
        that already belongs to another parent is detached from that parent
        first, so only one parent moves it.

        Args:
            child_object: the object to attach.
        """
        previous_parent = child_object.parent_object
        if previous_parent is not None and previous_parent is not self:
            # Remove the stale entry only if it really refers to this child.
            if previous_parent.child_objects.get(
                    child_object.name) is child_object:
                del previous_parent.child_objects[child_object.name]

        replaced = self.child_objects.get(child_object.name)
        if replaced is not None and replaced is not child_object:
            replaced.parent_object = None

        self.child_objects[child_object.name] = child_object
        child_object.parent_object = self

    def convert_pose_to_local_frame(
            self, world_position: np.ndarray,
            world_orientation: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Converts a pose from the world frame to the object's local frame.

        Delegates to ``change_frame`` with this object's current world pose
        as the frame argument.

        Args:
            world_position: position expressed in the world frame.
            world_orientation: quaternion expressed in the world frame.

        Returns:
            ``(local_position, local_orientation)`` as produced by
            ``change_frame``.
        """
        local_pose = change_frame(
            (world_position, world_orientation),
            (self.get_position(), self.get_orientation()))
        return local_pose

    def convert_position_to_local_frame(
            self, world_position: np.ndarray) -> np.ndarray:
        """Converts a world-frame position (a point) to the object's local frame.

        The position is converted as a pose with the identity orientation,
        and only the position part of the result is returned.
        """
        local_position, _ = self.convert_pose_to_local_frame(
            world_position, np.array([0, 0, 0, 1]))
        return local_position

    def convert_orientation_to_local_frame(
            self, world_orientation: np.ndarray) -> np.ndarray:
        """Converts a quaternion orientation from the world frame to the object's local frame.

        The orientation is converted as a pose at the world origin, and only
        the orientation part of the result is returned.
        """
        _, local_orientation = self.convert_pose_to_local_frame(
            np.array([0, 0, 0]), world_orientation)
        return local_orientation

    def convert_pose_to_world_frame(
            self, local_position: np.ndarray,
            local_orientation: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Converts a pose from the object's local frame to the world frame.

        Delegates to ``change_frame`` with the identity pose and this
        object's current world pose as the frame arguments.

        Args:
            local_position: position expressed in this object's frame.
            local_orientation: quaternion expressed in this object's frame.

        Returns:
            ``(world_position, world_orientation)`` as produced by
            ``change_frame``.
        """
        pose = change_frame((local_position, local_orientation),
                            (np.array([0, 0, 0]), np.array([0, 0, 0, 1])),
                            (self.get_position(), self.get_orientation()))
        return pose

    def convert_position_to_world_frame(
            self, local_position: np.ndarray) -> np.ndarray:
        """Converts a local-frame position (a point) to the world frame.

        The position is converted as a pose with the identity orientation,
        and only the position part of the result is returned.
        """
        world_position, _ = self.convert_pose_to_world_frame(
            local_position, np.array([0, 0, 0, 1]))
        return world_position

    def convert_orientation_to_world_frame(
            self, local_orientation: np.ndarray) -> np.ndarray:
        """Converts a quaternion orientation from the object's local frame to the world frame.

        The orientation is converted as a pose at the local origin, and only
        the orientation part of the result is returned.
        """
        _, world_orientation = self.convert_pose_to_world_frame(
            np.array([0, 0, 0]), local_orientation)
        return world_orientation

    def get_pose(self) -> Tuple[np.ndarray, np.ndarray]:
        """Returns the pose of the object as ``(position, orientation)``.

        Subclasses may override this to read the pose from elsewhere; the
        other getters and the frame conversions all go through it.
        """
        return self._position, self._orientation

    def get_position(self) -> np.ndarray:
        """Returns the position of the object in the world frame.
        """
        return self.get_pose()[0]

    def get_orientation(self) -> np.ndarray:
        """Returns the quaternion orientation of the object in the world frame.
        """
        return self.get_pose()[1]

    def set_pose(self,
                 position: np.ndarray,
                 orientation: np.ndarray,
                 recursive: bool = True) -> None:
        """Set the position and orientation of the object.

        Args:
            position: new position in the world frame.
            orientation: new quaternion orientation in the world frame.
            recursive: if True, every child (and, recursively, its children)
                is moved so that its pose relative to this object is kept.
        """
        if recursive:
            identity_pose = (np.array([0, 0, 0]), np.array([0, 0, 0, 1]))
            for child_object in self.child_objects.values():
                # Express the child relative to this object's *current* pose
                # (self is only updated after the loop), then re-express that
                # local pose relative to the new pose.
                child_local_pose = self.convert_pose_to_local_frame(
                    *child_object.get_pose())
                child_position, child_orientation = change_frame(
                    child_local_pose, identity_pose, (position, orientation))
                child_object.set_pose(child_position,
                                      child_orientation,
                                      recursive=True)
        self._position = position
        self._orientation = orientation

    def set_position(self,
                     position: np.ndarray,
                     recursive: bool = True) -> None:
        """Set the position of the object, keeping its orientation.
        """
        self.set_pose(position, self.get_orientation(), recursive)

    def set_orientation(self,
                        orientation: np.ndarray,
                        recursive: bool = True) -> None:
        """Set the orientation of the object, keeping its position.
        """
        self.set_pose(self.get_position(), orientation, recursive)

    def visualize(self, axis_length: float = 0.1) -> None:
        """Visualizes the frame of the object with pybullet debug lines.

        The local x, y and z axes are drawn from the object's origin in red,
        green and blue respectively.

        Args:
            axis_length: length of each drawn axis, in the same units as the
                object's position (default 0.1).
        """
        origin = self.convert_position_to_world_frame(np.array([0, 0, 0]))
        axes_and_colors = (
            (np.array([axis_length, 0, 0]), [1, 0, 0]),
            (np.array([0, axis_length, 0]), [0, 1, 0]),
            (np.array([0, 0, axis_length]), [0, 0, 1]),
        )
        for local_tip, color in axes_and_colors:
            tip = self.convert_position_to_world_frame(local_tip)
            # Positional arguments: line width 2, lifetime 0.
            pb.addUserDebugLine(origin, tip, color, 2, 0)


class RigidObject(SimObject):
    """Class for a rigid object in the environment, backed by a pybullet body.

    The authoritative pose lives in pybullet: ``get_pose`` reads it back from
    the engine and ``set_pose`` writes it with
    ``resetBasePositionAndOrientation``.
    """

    def __init__(
        self,
        name: str,
        visual_shape_id: int,
        collision_shape_id: int,
        position: Optional[np.ndarray] = None,
        orientation: Optional[np.ndarray] = None,
        scale: float = 1,
        mass: float = 0,
    ) -> None:
        """Initializes the RigidObject class.

        Args:
            name: name of the object
            visual_shape_id: visual shape ID of the object
            collision_shape_id: collision shape ID of the object
            position: position of the object in the world frame. Defaults to
                the origin ``[0, 0, 0]``.
            orientation: quaternion orientation of the object in the world
                frame. Defaults to the identity ``[0, 0, 0, 1]``.
            scale: scale of the object. NOTE(review): stored on the object
                but not passed to pybullet here; presumably any scaling is
                already baked into the shape IDs — confirm.
            mass: mass of the object (pybullet treats 0 as a static body)
        """
        # Fresh default arrays per instance instead of shared mutable
        # default-argument arrays.
        if position is None:
            position = np.array([0, 0, 0])
        if orientation is None:
            orientation = np.array([0, 0, 0, 1])

        super().__init__(name, position, orientation, scale)

        self.id = pb.createMultiBody(
            baseMass=mass,
            baseCollisionShapeIndex=collision_shape_id,
            baseVisualShapeIndex=visual_shape_id,
            basePosition=position,
            baseOrientation=orientation)

    def get_pose(self) -> Tuple[np.ndarray, np.ndarray]:
        """Returns the pose of the object, read from pybullet.

        Returns:
            ``(position, orientation)`` as numpy arrays. pybullet returns
            plain tuples; they are converted to match the annotated return
            type so callers can use array arithmetic (``tuple + offset``
            would concatenate instead of add).
        """
        position, orientation = pb.getBasePositionAndOrientation(self.id)
        # Cache the engine pose in the base-class fields, then return them.
        self._position = np.array(position)
        self._orientation = np.array(orientation)
        return super().get_pose()

    def is_collision(self, object: Optional[RigidObject] = None) -> bool:
        """Checks if the object is colliding with another object.

        Args:
            object: the body to test against; it must expose a pybullet
                ``id``. If None, every other body in the pybullet world is
                tested. (The name shadows the builtin but is kept so keyword
                callers keep working.)

        Returns:
            True if pybullet reports at least one closest point within a
            distance of 0 between this body and the other body/bodies.
        """
        if object is None:
            # Body unique IDs are not guaranteed to be 0..N-1 (for example
            # after a body has been removed), so map each serial index to
            # its unique ID instead of using the index directly.
            body_ids = (pb.getBodyUniqueId(i)
                        for i in range(pb.getNumBodies()))
            return any(
                len(pb.getClosestPoints(self.id, body_id, 0)) > 0
                for body_id in body_ids if body_id != self.id)
        return len(pb.getClosestPoints(self.id, object.id, 0)) > 0

    def set_pose(self,
                 position: np.ndarray,
                 orientation: np.ndarray,
                 recursive: bool = True) -> None:
        """Sets the position and orientation of the object.

        Args:
            position: new position in the world frame.
            orientation: new quaternion orientation in the world frame.
            recursive: if True, children are moved so that their pose
                relative to this object is kept.
        """
        # Update the hierarchy first. SimObject.set_pose expresses each child
        # relative to this object's *current* pose, which get_pose reads from
        # pybullet. Resetting pybullet beforehand would make that the new
        # pose, and the children would stay where they are.
        super().set_pose(position, orientation, recursive)
        pb.resetBasePositionAndOrientation(self.id, position, orientation)

    def set_position(self,
                     position: np.ndarray,
                     recursive: bool = True) -> None:
        """Sets the position of the object, keeping its orientation.
        """
        self.set_pose(position, self.get_orientation(), recursive)

    def set_orientation(self,
                        orientation: np.ndarray,
                        recursive: bool = True) -> None:
        """Sets the orientation of the object, keeping its position.
        """
        self.set_pose(self.get_position(), orientation, recursive)
