from typing import Any, Callable, List, Optional

import pybullet as pb

from .sim_object import *
from .region import *


class Randomizer(SimObject):
    """Base class for a randomizer placed in the environment.

    A randomizer is itself a ``SimObject`` with a pose in the world frame.
    ``randomize`` repeatedly applies a user-supplied randomization function to
    a list of objects until an optional acceptance condition holds.
    """

    def __init__(
        self,
        name: str,
        position: np.ndarray = np.array([0, 0, 0]),
        orientation: np.ndarray = np.array([0, 0, 0, 1]),
        scale: float = 1,
    ) -> None:
        """Initializes the Randomizer class.

        Args:
            name: name of the randomizer
            position: position of the randomizer in the world frame
            orientation: quaternion orientation of the randomizer in the world frame
            scale: scale of the randomizer
        """
        super().__init__(name, position, orientation, scale)

    def randomize(
        self,
        objects: List[SimObject],
        randomize_function: Callable[[SimObject], None],
        condition_function: Optional[Callable[[SimObject], bool]] = None,
        sequencial_randomization: bool = False,
        max_retry: int = 1000,
    ) -> None:
        """Randomizes the objects in the environment.

        The condition is evaluated after ``randomize_function`` has been
        applied. When ``condition_function`` is ``None`` every sample is
        accepted, so each object is randomized exactly once.

        Args:
            objects: list of objects to randomize
            randomize_function: function that randomizes one object in place
            condition_function: optional predicate an object must satisfy for
                its randomized state to be accepted
            sequencial_randomization: if True, objects are randomized one at a
                time and each is resampled until it alone satisfies the
                condition (earlier objects keep their accepted state). If
                False, all objects are resampled together until every one of
                them satisfies the condition.
            max_retry: maximum number of retries after the first attempt
                (per object in sequential mode, per joint pass otherwise)

        Raises:
            RuntimeError: if the condition is still not satisfied after
                ``max_retry`` retries.
        """

        def is_accepted(obj: SimObject) -> bool:
            # Without a condition every randomized state is acceptable.
            return condition_function is None or condition_function(obj)

        if sequencial_randomization:
            for obj in objects:
                randomize_function(obj)
                # The retry budget is per object, which is what the error
                # message reports.
                retry = 0
                while not is_accepted(obj):
                    # ``>=`` (not ``==``) so a negative max_retry cannot loop forever.
                    if retry >= max_retry:
                        raise RuntimeError(
                            f"Failed to randomize object {obj.name} after {max_retry} retries"
                        )
                    randomize_function(obj)
                    retry += 1
            return

        # Joint mode: one initial pass plus ``max_retry`` retries, each pass
        # resampling every object before the whole set is checked.
        for _ in range(max_retry + 1):
            for obj in objects:
                randomize_function(obj)
            if all(is_accepted(obj) for obj in objects):
                return
        raise RuntimeError(
            f"Failed to randomize objects after {max_retry} retries")


class PoseRandomizer(Randomizer):
    """Randomizer that samples object poses inside an axis-aligned box.

    The position box (``position_min``/``position_max``) and the Euler-angle
    range (``euler_angle_min``/``euler_angle_max``) are expressed in the
    randomizer's local frame. Samples are converted to the world frame with
    ``convert_position_to_world_frame`` / ``convert_orientation_to_world_frame``
    before being applied to objects.
    """

    def __init__(
            self,
            name: str,
            position: np.ndarray = np.array([0, 0, 0]),
            orientation: np.ndarray = np.array([0, 0, 0, 1]),
            scale: float = 1,
            position_min: np.ndarray = np.array([0, 0, 0]),
            position_max: np.ndarray = np.array([0, 0, 0]),
            euler_angle_min: np.ndarray = np.array([0, 0, 0]),
            euler_angle_max: np.ndarray = np.array([0, 0, 0]),
    ) -> None:
        """Initializes the PoseRandomizer class.

        Args:
            name: name of the randomizer
            position: position of the randomizer in the world frame
            orientation: quaternion orientation of the randomizer in the world frame
            scale: scale of the randomizer
            position_min: lower corner of the sampling box (local frame)
            position_max: upper corner of the sampling box (local frame)
            euler_angle_min: lower bound of the sampled Euler angles, in
                radians as consumed by ``pb.getQuaternionFromEuler``
            euler_angle_max: upper bound of the sampled Euler angles (radians)
        """
        super().__init__(name, position, orientation, scale)
        # Keep private float copies. The defaults are module-level arrays
        # shared by every instance, and aliasing a caller's array would let an
        # in-place edit on either side silently change the sampling range.
        # The conversion also accepts plain lists/tuples as bounds.
        self.position_min = np.array(position_min, dtype=float)
        self.position_max = np.array(position_max, dtype=float)
        self.euler_angle_min = np.array(euler_angle_min, dtype=float)
        self.euler_angle_max = np.array(euler_angle_max, dtype=float)

    def get_random_position(self) -> np.ndarray:
        """Returns a uniformly sampled position of the box, in the world frame."""

        local_position = self.position_min + (
            self.position_max - self.position_min) * np.random.rand(3)
        return self.convert_position_to_world_frame(local_position)

    def get_random_orientation(self) -> np.ndarray:
        """Returns a random orientation in the world frame.

        Each Euler angle is sampled uniformly between its bounds, converted to
        a quaternion in the local frame, and then transformed to the world frame.
        """

        local_euler_angle = self.euler_angle_min + (
            self.euler_angle_max - self.euler_angle_min) * np.random.rand(3)
        local_orientation = pb.getQuaternionFromEuler(local_euler_angle)
        return self.convert_orientation_to_world_frame(local_orientation)

    def randomize_pose(
        self,
        objects: List[SimObject],
        condition_function: Optional[Callable[[SimObject], bool]] = None,
        sequencial_randomization: bool = True,
        max_retry: int = 1000,
    ) -> None:
        """Randomizes both position and orientation of the objects.

        See ``Randomizer.randomize`` for the meaning of the other arguments.
        """

        def pose_randomize_function(obj: SimObject) -> None:
            obj.set_pose(self.get_random_position(),
                         self.get_random_orientation())

        self.randomize(objects, pose_randomize_function, condition_function,
                       sequencial_randomization, max_retry)

    def randomize_position(
        self,
        objects: List[SimObject],
        condition_function: Optional[Callable[[SimObject], bool]] = None,
        sequencial_randomization: bool = True,
        max_retry: int = 1000,
    ) -> None:
        """Randomizes only the position of the objects.

        See ``Randomizer.randomize`` for the meaning of the other arguments.
        """

        def position_randomize_function(obj: SimObject) -> None:
            obj.set_position(self.get_random_position())

        self.randomize(objects, position_randomize_function,
                       condition_function, sequencial_randomization, max_retry)

    def randomize_orientation(
        self,
        objects: List[SimObject],
        condition_function: Optional[Callable[[SimObject], bool]] = None,
        sequencial_randomization: bool = True,
        max_retry: int = 1000,
    ) -> None:
        """Randomizes only the orientation of the objects.

        See ``Randomizer.randomize`` for the meaning of the other arguments.
        """

        def orientation_randomize_function(obj: SimObject) -> None:
            obj.set_orientation(self.get_random_orientation())

        self.randomize(objects, orientation_randomize_function,
                       condition_function, sequencial_randomization, max_retry)

    def visualize(
            self,
            rgb_color: np.ndarray = np.array([1, 1, 0]),
    ) -> None:
        """Draws the position sampling box as a wireframe in the PyBullet client.

        Args:
            rgb_color: RGB color of the box edges (PyBullet expects each
                component in [0, 1]).
        """
        super().visualize()

        bounds = (self.position_min, self.position_max)
        # Corner k takes the min (bit 0) or max (bit 1) bound on each axis
        # from the bits of k (x -> 4, y -> 2, z -> 1), so corner 0 is
        # position_min and corner 7 is position_max.
        corner_points = [
            np.array([bounds[ix][0], bounds[iy][1], bounds[iz][2]])
            for ix in (0, 1) for iy in (0, 1) for iz in (0, 1)
        ]

        world_corner_points = [
            self.convert_position_to_world_frame(point)
            for point in corner_points
        ]

        # Two corners share a box edge exactly when their indices differ in
        # one bit. This yields the 12 edges of the box.
        edges = [(k, k | bit) for k in range(8) for bit in (1, 2, 4)
                 if not k & bit]
        for start, end in edges:
            # lifeTime=0 keeps the line until it is explicitly removed.
            pb.addUserDebugLine(world_corner_points[start],
                                world_corner_points[end],
                                lineColorRGB=rgb_color,
                                lifeTime=0,
                                lineWidth=2)
