"""Action modes for H1."""
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

import warnings
from abc import abstractmethod, ABC
from typing import Optional

import numpy as np
from gymnasium import spaces
from mojo import Mojo

from bigym.const import TOLERANCE_ANGULAR
from bigym.utils.physics_utils import (
    is_target_reached,
)

if TYPE_CHECKING:
    from bigym.robots import Robot


class TargetStateNotReachedWarning(Warning):
    """Warning raised when the target state is not reached within the maximum steps."""

    pass


class PelvisDof(Enum):
    """Set of floating base DOFs."""

    X = "pelvis_x"
    Y = "pelvis_y"
    Z = "pelvis_z"
    RZ = "pelvis_rz"


DEFAULT_DOFS = [PelvisDof.X, PelvisDof.Y, PelvisDof.RZ]


class ActionMode(ABC):
    """Base action mode class used for controlling H1."""

    def __init__(
        self,
        floating_base: bool = True,
        floating_dofs: Optional[list[PelvisDof]] = None,
        wrist_dof: bool = True,
    ):
        """Init.

        :param floating_base: If True, then legs are frozen, and the robot base
            controlled by positional actuators.
            If False, then user has full control of legs (i.e. for whole-body control).
        :param floating_dofs: Set of floating DOFs. By default, it is: [X, Y, RZ].
        :param wrist_dof: If set to True, an additional wrist joint and actuator
            will be added to control the orientation of the grippers.
        """
        self._floating_base = floating_base
        self._floating_dofs = DEFAULT_DOFS if floating_dofs is None else floating_dofs
        self._wrist_dof = wrist_dof

        # Will be assigned later
        self._mojo: Optional[Mojo] = None
        self._robot: Optional[Robot] = None

    def bind_robot(self, robot: Robot, mojo: Mojo):
        """Bind action mode to robot."""
        self._robot = robot
        self._mojo = mojo

    @property
    def floating_base(self) -> bool:
        """Is floating base enabled."""
        return self._floating_base

    @property
    def floating_dofs(self) -> list[PelvisDof]:
        """Set of floating DOFs."""
        return self._floating_dofs

    @property
    def wrist_dof(self) -> bool:
        """Is wrist DOF added."""
        return self._wrist_dof

    @abstractmethod
    def action_space(
        self, action_scale: float, seed: Optional[int] = None
    ) -> spaces.Box:
        """The action space for this action mode."""
        pass

    @abstractmethod
    def step(self, action: np.ndarray):
        """Apply the control command and step the physics.

        Note: This function has the responsibility of calling `mujoco.mj_step`.

        :param action: The entire action passed to the action mode.
        """
        pass


class TorqueActionMode(ActionMode):
    """Control all joints through torque control.

    Enables the user to control joints using torque values.

    Notes:
        - Grippers are controlled in positional mode.
        - Joints of the 'floating_base' are always controlled in delta position mode.
    """

    def action_space(
        self, action_scale: float, seed: Optional[int] = None
    ) -> spaces.Box:
        """See base."""
        bounds = []
        if self.floating_base:
            action_bounds = self._robot.floating_base.get_action_bounds()
            action_bounds = [np.array(b) * action_scale for b in action_bounds]
            bounds.extend(action_bounds)
        for actuator in self._robot.limb_actuators:
            action_bounds = np.array(actuator.ctrlrange)
            bounds.append(action_bounds)
        for _ in self._robot.gripper_actuators:
            action_bounds = np.array(self._robot.get_gripper_control_range())
            bounds.append(action_bounds)
        bounds = np.array(bounds).copy().astype(np.float32)
        low, high = bounds.T
        return spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
            seed=seed,
        )

    def step(self, action: np.ndarray):
        """See base."""
        if self.floating_base:
            base_action = action[: self._robot.floating_base.dof_amount]
            action = action[self._robot.floating_base.dof_amount :]
            self._robot.floating_base.set_control(base_action)
        for i, actuator in enumerate(self._robot.limb_actuators):
            self._mojo.physics.bind(actuator).ctrl = action[i]
        gripper_actions = action[-len(self._robot.gripper_actuators) :]
        for action, actuator in zip(gripper_actions, self._robot.gripper_actuators):
            self._robot.set_gripper_control(actuator, action)
        self._mojo.step()


class JointPositionActionMode(ActionMode):
    """Control all joints through joint position.

    Allows to control joint positions, supporting both absolute and delta positions.
    For absolute control, set 'absolute' to True. If the floating base is enabled,
    only delta position control is applied to it.

    Notes:
        - `block_until_reached` does not guarantee reaching the target position because
          the target position could be unreachable due to collisions.
        - Joints of the `floating_base` are always controlled in delta position mode.
    """

    MAX_STEPS = 200

    def __init__(
        self,
        absolute: bool = False,
        block_until_reached: bool = False,
        floating_base: bool = True,
        floating_dofs: list[PelvisDof] = None,
        wrist_dof: bool = True,
    ):
        """See base.

        :param absolute: Use absolute or delta joint positions.
        :param block_until_reached: Continue stepping until the target
            position is reached or the step threshold is exceeded.
        """
        super().__init__(
            floating_base=floating_base,
            floating_dofs=floating_dofs,
            wrist_dof=wrist_dof,
        )
        self.absolute = absolute
        self.block_until_reached = block_until_reached

    def action_space(
        self, action_scale: float, seed: Optional[int] = None
    ) -> spaces.Box:
        """See base."""
        bounds = []
        if self.floating_base:
            action_bounds = self._robot.floating_base.get_action_bounds()
            action_bounds = [np.array(b) * action_scale for b in action_bounds]
            bounds.extend(action_bounds)
        for actuator in self._robot.limb_actuators:
            action_bounds = np.array(
                self._robot.get_limb_control_range(actuator, self.absolute)
            )
            action_bounds *= 1 if self.absolute else action_scale
            bounds.append(action_bounds)
        for _ in self._robot.gripper_actuators:
            action_bounds = np.array(self._robot.get_gripper_control_range())
            bounds.append(action_bounds)
        bounds = np.array(bounds).copy().astype(np.float32)
        low, high = bounds.T
        return spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
            seed=seed,
        )

    def step(self, action: np.ndarray):
        """See base."""
        if self.floating_base:
            base_action = action[: self._robot.floating_base.dof_amount]
            action = action[self._robot.floating_base.dof_amount :]
            self._robot.floating_base.set_control(base_action)
        for i, actuator in enumerate(self._robot.limb_actuators):
            actuator = self._mojo.physics.bind(actuator)
            actuator.ctrl = action[i] if self.absolute else actuator.ctrl + action[i]
        gripper_actions = action[-len(self._robot.gripper_actuators) :]
        for action, actuator in zip(gripper_actions, self._robot.gripper_actuators):
            self._robot.set_gripper_control(actuator, action)
        if self.block_until_reached:
            self._step_until_reached()
        else:
            self._mojo.step()

    def _step_until_reached(self):
        """Step physics until the target position is reached."""
        steps_counter = 0
        while True:
            self._mojo.step()
            steps_counter += 1
            if self._is_target_state_reached() or steps_counter >= self.MAX_STEPS:
                if steps_counter >= self.MAX_STEPS:
                    warnings.warn(
                        f"Failed to reach target state in " f"{self.MAX_STEPS} steps!",
                        TargetStateNotReachedWarning,
                    )
                break

    def _is_target_state_reached(self):
        if self.floating_base:
            if not self._robot.floating_base.is_target_reached:
                return False
        for actuator in self._robot.limb_actuators:
            if not is_target_reached(actuator, self._mojo.physics, TOLERANCE_ANGULAR):
                return False
        return True
