from typing import Dict, Optional, SupportsFloat, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv


class PDController(gym.Wrapper):
    """
    PD controller wrapper for MuJoCo environments.

    The wrapped environment's action space is in torque space; this wrapper
    lets the agent act in joint-position space instead. Each action passed to
    :meth:`step` is interpreted as the desired value of the last ``n_joints``
    entries of ``qpos`` and converted into a torque command with a PD law
    whose target joint velocity is zero::

        command = kp * (desired_qpos - qpos) - kd * qvel

    The command is clipped to ``[-1, 1]`` before being forwarded to the
    wrapped environment.

    See https://gymnasium.farama.org/environments/mujoco
    for env descriptions.

    :param env: Gym environment
    :param n_joints: Number of joints (the last ``n_joints`` entries of
        ``qpos`` / ``qvel`` are the controlled ones)
    :param kp: Proportional gain
    :param kd: Derivative gain (velocity gain)
    :param skip_pd: Skip PD for env that already provide a position interface.
        When True, actions are forwarded unchanged and ``env.unwrapped`` is
        not required to be a ``MujocoEnv``.
    """

    mujoco_env: MujocoEnv

    def __init__(self, env: gym.Env, n_joints: int, kp: float = 1, kd: float = 0.1, skip_pd: bool = False):
        super().__init__(env)
        self.kp = kp
        self.kd = kd
        self.n_joints = n_joints
        # Joint state cached after each reset/step; used as the current state
        # in the PD law on the next call to step().
        self.last_qpos = np.zeros(self.n_joints)
        self.last_qvel = np.zeros(self.n_joints)
        self.skip_pd = skip_pd
        if not skip_pd:
            # Direct access to the simulator state is only needed for PD.
            mujoco_env = env.unwrapped
            assert isinstance(mujoco_env, MujocoEnv)
            self.mujoco_env = mujoco_env

    @property
    def current_joint_pos(self) -> np.ndarray:
        """Last ``n_joints`` entries of the simulator's ``qpos`` (a slice, not a copy)."""
        return self.mujoco_env.data.qpos[-self.n_joints :]

    @property
    def current_joint_vel(self) -> np.ndarray:
        """Last ``n_joints`` entries of the simulator's ``qvel`` (a slice, not a copy)."""
        return self.mujoco_env.data.qvel[-self.n_joints :]

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, dict]:
        """
        Reset the wrapped env and cache the initial joint state for the PD law.

        :param seed: Forwarded to the wrapped env's ``reset``.
        :param options: Forwarded to the wrapped env's ``reset``.
        :return: The observation and the info dict returned by the wrapped env.
        """
        # Forward the wrapped env's info instead of discarding it.
        obs, info = self.env.reset(seed=seed, options=options)
        if not self.skip_pd:
            # Copy: the properties return slices of the simulator arrays, which
            # presumably get updated in place as the simulation advances.
            self.last_qpos = self.current_joint_pos.copy()
            self.last_qvel = self.current_joint_vel.copy()
        return obs, info

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, SupportsFloat, bool, bool, dict]:  # type: ignore[override]
        """
        Convert a desired joint position into a torque command and step the env.

        :param action: Desired joint positions (ignored as such and forwarded
            unchanged when ``skip_pd`` is True).
        :return: ``(obs, reward, terminated, truncated, info)`` from the wrapped env.
        """
        if not self.skip_pd:
            desired_qpos = action
            qpos_err = desired_qpos - self.last_qpos
            # desired qvel is zero
            qvel_err = -self.last_qvel
            action = self.kp * qpos_err + self.kd * qvel_err
            # Clip to correct action range.
            # NOTE(review): assumes the wrapped env's torque range is [-1, 1];
            # confirm against env.action_space for envs with other bounds.
            action = np.clip(action, -1.0, 1.0)
        obs, reward, terminated, truncated, info = self.env.step(action)
        if not self.skip_pd:
            # Cache the post-step state for the next PD computation.
            self.last_qpos = self.current_joint_pos.copy()
            self.last_qvel = self.current_joint_vel.copy()

        return obs, reward, terminated, truncated, info
