# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple sensors related to the robot."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import typing
import scipy.stats as stats

from robots import minitaur_pose_utils
from build_envs.sensors import sensor

# Type aliases used by the sensor signatures in this module.
# A 1-D collection of float readings.
_ARRAY = typing.Iterable[float]
# A scalar bound applied to every element, or one bound per element.
_FLOAT_OR_ARRAY = typing.Union[float, typing.Iterable[float]]
# A list of (field_name, dtype) entries describing an observation.
_DATATYPE_LIST = typing.Iterable[typing.Any]

# Desired direction
class DesiredDirection(sensor.BoxSpaceSensor):
    """A sensor that reads the desired direction (cos(Y), sin(Y)) of a robot."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -1.0,
                 upper_bound: _FLOAT_OR_ARRAY = 1.0,
                 name: typing.Text = "Desired_direction",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs DesiredDirection.

        Args:
          lower_bound: lower bound of each of the two direction components.
          upper_bound: upper bound of each of the two direction components.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        direction_shape = (2,)  # cos(Y), sin(Y)
        super().__init__(
            name=name,
            shape=direction_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        # The value is passed through exactly as the robot reports it.
        return self._robot.GetDesiredDirection()

# Desired turning direction
class DesiredTurningDirection(sensor.BoxSpaceSensor):
    """A sensor that reads the desired turning direction of a robot."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -1.0,
                 upper_bound: _FLOAT_OR_ARRAY = 1.0,
                 name: typing.Text = "Desired_turning_direction",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs DesiredTurningDirection.

        Args:
          lower_bound: the lower bound of the turning direction.
          upper_bound: the upper bound of the turning direction.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        # A single scalar; presumably one of -1, 0 or 1 (confirm with the robot).
        turning_shape = (1,)
        super().__init__(
            name=name,
            shape=turning_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        # The value is passed through exactly as the robot reports it.
        return self._robot.GetDesiredTurningDirection()

# Gravity vector
class IMUSensor(sensor.BoxSpaceSensor):
    """An IMU sensor that reads base orientations and angular velocities.

    Each requested channel contributes one entry to the observation vector,
    in the order the channels were given.
    """

    # channel -> (reads the rate vector?, axis index, transform or None,
    #             magnitude of the default symmetric bound).
    # Axis 0/1/2 indexes roll/pitch/yaw of the robot's readings.
    _CHANNEL_SPECS = {
        "R": (False, 0, None, 2.0 * np.pi),
        "P": (False, 1, None, 2.0 * np.pi),
        "Y": (False, 2, None, 2.0 * np.pi),
        "Rcos": (False, 0, np.cos, 1.),
        "Rsin": (False, 0, np.sin, 1.),
        "Pcos": (False, 1, np.cos, 1.),
        "Psin": (False, 1, np.sin, 1.),
        "Ycos": (False, 2, np.cos, 1.),
        "Ysin": (False, 2, np.sin, 1.),
        "dR": (True, 0, None, 2000.0 * np.pi),
        "dP": (True, 1, None, 2000.0 * np.pi),
        "dY": (True, 2, None, 2000.0 * np.pi),
    }

    def __init__(self,
                 channels: typing.Optional[typing.Iterable[typing.Text]] = None,
                 noisy_reading: bool = True,
                 lower_bound: typing.Optional[_FLOAT_OR_ARRAY] = None,
                 upper_bound: typing.Optional[_FLOAT_OR_ARRAY] = None,
                 name: typing.Text = "IMU",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs IMUSensor.

        It generates separate IMU value channels, e.g. IMU_R, IMU_P, IMU_dR, ...

        Args:
          channels: names of the values to observe, one entry each.
            - "R"/"P"/"Y" are base roll/pitch/yaw.
            - A "cos"/"sin" suffix (e.g. "Rcos") gives the cosine/sine of
              that angle.
            - "dR"/"dP"/"dY" are the corresponding angular rates.
            Defaults to ["R", "P", "Y"].
          noisy_reading: if True, read the robot's (possibly noisy) measured
            values; if False, read the true values.
          lower_bound: the lower bound of the IMU values. If None, a
            per-channel default is used: -2pi for angles, -1 for cos/sin
            channels, -2000pi for angular rates.
          upper_bound: the upper bound of the IMU values. If None, a
            per-channel default is used: 2pi for angles, 1 for cos/sin
            channels, 2000pi for angular rates.
          name: the name of the sensor.
          dtype: data type of sensor value.

        Raises:
          ValueError: if any channel name is not supported.
        """
        self._channels = channels if channels else ["R", "P", "Y"]
        self._num_channels = len(self._channels)
        self._noisy_reading = noisy_reading

        # Reject unknown channels up front. Otherwise they would shorten the
        # default bounds list and always read as 0.
        unknown = [c for c in self._channels if c not in self._CHANNEL_SPECS]
        if unknown:
            raise ValueError(
                "Unsupported IMU channel(s): {}".format(unknown))

        # Fill in per-channel defaults for whichever bound was not supplied.
        if lower_bound is None:
            lower_bound = [-self._CHANNEL_SPECS[c][3] for c in self._channels]
        if upper_bound is None:
            upper_bound = [self._CHANNEL_SPECS[c][3] for c in self._channels]

        super(IMUSensor, self).__init__(
            name=name,
            shape=(self._num_channels,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

        # One named field per channel, e.g. ("IMU_R", dtype).
        self._datatype = [("{}_{}".format(name, channel), self._dtype)
                          for channel in self._channels]

    def get_channels(self) -> typing.Iterable[typing.Text]:
        return self._channels

    def get_num_channels(self) -> int:
        return self._num_channels

    def get_observation_datatype(self) -> _DATATYPE_LIST:
        """Returns box-shape data type."""
        return self._datatype

    def _get_observation(self) -> _ARRAY:
        if self._noisy_reading:
            rpy = self._robot.GetBaseRollPitchYaw()
            drpy = self._robot.GetBaseRollPitchYawRate()
        else:
            rpy = self._robot.GetTrueBaseRollPitchYaw()
            drpy = self._robot.GetTrueBaseRollPitchYawRate()

        assert len(rpy) >= 3, rpy
        assert len(drpy) >= 3, drpy

        observations = np.zeros(self._num_channels)
        for i, channel in enumerate(self._channels):
            is_rate, axis, transform, _ = self._CHANNEL_SPECS[channel]
            value = (drpy if is_rate else rpy)[axis]
            observations[i] = value if transform is None else transform(value)
        return observations

# Base angular velocity
class BaseAngularVelocitySensor(sensor.BoxSpaceSensor):
    """A sensor that reads base angular velocities (and optionally angles).

    Each requested channel contributes one entry to the observation vector,
    in the order the channels were given.
    """

    # channel -> (reads the rate vector?, axis index, transform or None,
    #             magnitude of the default symmetric bound).
    # Axis 0/1/2 indexes roll/pitch/yaw of the robot's readings.
    _CHANNEL_SPECS = {
        "R": (False, 0, None, 2.0 * np.pi),
        "P": (False, 1, None, 2.0 * np.pi),
        "Y": (False, 2, None, 2.0 * np.pi),
        "Rcos": (False, 0, np.cos, 1.),
        "Rsin": (False, 0, np.sin, 1.),
        "Pcos": (False, 1, np.cos, 1.),
        "Psin": (False, 1, np.sin, 1.),
        "Ycos": (False, 2, np.cos, 1.),
        "Ysin": (False, 2, np.sin, 1.),
        "dR": (True, 0, None, 2000.0 * np.pi),
        "dP": (True, 1, None, 2000.0 * np.pi),
        "dY": (True, 2, None, 2000.0 * np.pi),
    }

    def __init__(self,
                 channels: typing.Optional[typing.Iterable[typing.Text]] = None,
                 noisy_reading: bool = True,
                 lower_bound: typing.Optional[_FLOAT_OR_ARRAY] = None,
                 upper_bound: typing.Optional[_FLOAT_OR_ARRAY] = None,
                 name: typing.Text = "AngularVelocities",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs BaseAngularVelocitySensor.

        It generates separate value channels, e.g. AngularVelocities_dR, ...

        Args:
          channels: names of the values to observe, one entry each.
            - "dR"/"dP"/"dY" are base roll/pitch/yaw rates.
            - "R"/"P"/"Y" are the angles themselves.
            - A "cos"/"sin" suffix (e.g. "Rcos") gives the cosine/sine of
              an angle.
            Defaults to ["dR", "dP", "dY"].
          noisy_reading: if True, read the robot's (possibly noisy) measured
            values; if False, read the true values.
          lower_bound: the lower bound of the values. If None, a per-channel
            default is used: -2pi for angles, -1 for cos/sin channels,
            -2000pi for angular rates.
          upper_bound: the upper bound of the values. If None, a per-channel
            default is used: 2pi for angles, 1 for cos/sin channels,
            2000pi for angular rates.
          name: the name of the sensor.
          dtype: data type of sensor value.

        Raises:
          ValueError: if any channel name is not supported.
        """
        self._channels = channels if channels else ["dR", "dP", "dY"]
        self._num_channels = len(self._channels)
        self._noisy_reading = noisy_reading

        # Reject unknown channels up front. Otherwise they would shorten the
        # default bounds list and always read as 0.
        unknown = [c for c in self._channels if c not in self._CHANNEL_SPECS]
        if unknown:
            raise ValueError(
                "Unsupported angular velocity channel(s): {}".format(unknown))

        # Fill in per-channel defaults for whichever bound was not supplied.
        if lower_bound is None:
            lower_bound = [-self._CHANNEL_SPECS[c][3] for c in self._channels]
        if upper_bound is None:
            upper_bound = [self._CHANNEL_SPECS[c][3] for c in self._channels]

        super(BaseAngularVelocitySensor, self).__init__(
            name=name,
            shape=(self._num_channels,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

        # One named field per channel, e.g. ("AngularVelocities_dR", dtype).
        self._datatype = [("{}_{}".format(name, channel), self._dtype)
                          for channel in self._channels]

    def get_channels(self) -> typing.Iterable[typing.Text]:
        return self._channels

    def get_num_channels(self) -> int:
        return self._num_channels

    def get_observation_datatype(self) -> _DATATYPE_LIST:
        """Returns box-shape data type."""
        return self._datatype

    def _get_observation(self) -> _ARRAY:
        if self._noisy_reading:
            rpy = self._robot.GetBaseRollPitchYaw()
            drpy = self._robot.GetBaseRollPitchYawRate()
        else:
            rpy = self._robot.GetTrueBaseRollPitchYaw()
            drpy = self._robot.GetTrueBaseRollPitchYawRate()

        assert len(rpy) >= 3, rpy
        assert len(drpy) >= 3, drpy

        observations = np.zeros(self._num_channels)
        for i, channel in enumerate(self._channels):
            is_rate, axis, transform, _ = self._CHANNEL_SPECS[channel]
            value = (drpy if is_rate else rpy)[axis]
            observations[i] = value if transform is None else transform(value)
        return observations

# Base linear velocity
class BaseLinearVelocitySensor(sensor.BoxSpaceSensor):
    """A sensor that reads the base linear velocity of the robot."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -2.0,
                 upper_bound: _FLOAT_OR_ARRAY = 2.0,
                 name: typing.Text = "LinearVelocity",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs BaseLinearVelocitySensor.

        Args:
          lower_bound: the lower bound of the base linear velocity.
          upper_bound: the upper bound of the base linear velocity.
          name: the name of the sensor
          dtype: data type of sensor value
        """
        super(BaseLinearVelocitySensor, self).__init__(
            name=name,
            shape=(3,),  # dx, dy, dz
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        # GetBaseVelocity() is unpacked as a pair. Only the first element is
        # observed; it is presumably the linear (dx, dy, dz) part — confirm.
        linear_velocity, _ = self._robot.GetBaseVelocity()
        return linear_velocity

# Joint position
class JointPositionSensor(sensor.BoxSpaceSensor):
    """A sensor that reads motor angles from the robot."""

    def __init__(self,
                 num_motors: int,
                 noisy_reading: bool = True,
                 observe_sine_cosine: bool = False,
                 lower_bound: _FLOAT_OR_ARRAY = -np.pi,
                 upper_bound: _FLOAT_OR_ARRAY = np.pi,
                 name: typing.Text = "JointPosition",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs JointPositionSensor.

        Args:
          num_motors: the number of motors in the robot.
          noisy_reading: if True, read the robot's (possibly noisy) measured
            motor angles; if False, read the true motor angles.
          observe_sine_cosine: if True, observe [cos(angles), sin(angles)]
            (2 * num_motors values bounded by [-1, 1]) for continuity.
          lower_bound: the lower bound of the motor angle. Ignored when
            observe_sine_cosine is True.
          upper_bound: the upper bound of the motor angle. Ignored when
            observe_sine_cosine is True.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        self._noisy_reading = noisy_reading
        self._observe_sine_cosine = observe_sine_cosine

        if observe_sine_cosine:
            # cos/sin values always lie in [-1, 1], whatever the angle bounds.
            num_values = self._num_motors * 2
            lower_bound = -np.ones(num_values)
            upper_bound = np.ones(num_values)
        else:
            num_values = self._num_motors

        super(JointPositionSensor, self).__init__(
            name=name,
            shape=(num_values,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        if self._noisy_reading:
            motor_angles = self._robot.GetMotorAngles()
        else:
            motor_angles = self._robot.GetTrueMotorAngles()

        if self._observe_sine_cosine:
            return np.hstack((np.cos(motor_angles), np.sin(motor_angles)))
        return motor_angles

# Joint velocity
class JointVelocitySensor(sensor.BoxSpaceSensor):
    """A sensor that reads motor velocities from the robot."""

    def __init__(self,
                 num_motors: int,
                 noisy_reading: bool = True,
                 observe_sine_cosine: bool = False,
                 lower_bound: _FLOAT_OR_ARRAY = -10.0,
                 upper_bound: _FLOAT_OR_ARRAY = 10.0,
                 name: typing.Text = "JointVelocity",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs JointVelocitySensor.

        Args:
          num_motors: the number of motors in the robot.
          noisy_reading: if True, read the robot's (possibly noisy) measured
            motor velocities; if False, read the true motor velocities.
          observe_sine_cosine: if True, observe [cos(v), sin(v)] of the
            velocities (2 * num_motors values bounded by [-1, 1]).
          lower_bound: the lower bound of the motor velocity. Ignored when
            observe_sine_cosine is True.
          upper_bound: the upper bound of the motor velocity. Ignored when
            observe_sine_cosine is True.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        self._noisy_reading = noisy_reading
        self._observe_sine_cosine = observe_sine_cosine

        if observe_sine_cosine:
            # cos/sin values always lie in [-1, 1].
            num_values = self._num_motors * 2
            lower_bound = -np.ones(num_values)
            upper_bound = np.ones(num_values)
        else:
            num_values = self._num_motors

        super(JointVelocitySensor, self).__init__(
            name=name,
            shape=(num_values,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        if self._noisy_reading:
            motor_velocities = self._robot.GetMotorVelocities()
        else:
            motor_velocities = self._robot.GetTrueMotorVelocities()

        if self._observe_sine_cosine:
            # NOTE(review): cos/sin of a velocity is unusual (it wraps every
            # 2*pi rad/s); kept for backward compatibility.
            return np.hstack((np.cos(motor_velocities), np.sin(motor_velocities)))
        return motor_velocities

# FTG phases
class FTGPhasesSensor(sensor.BoxSpaceSensor):
    """Reads the robot's FTG phase vector (8 values)."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -1.0,
                 upper_bound: _FLOAT_OR_ARRAY = 1.0,
                 name: typing.Text = "FTGPhases",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs FTGPhasesSensor.

        Args:
          lower_bound: the lower bound of the FTG phase values.
          upper_bound: the upper bound of the FTG phase values.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        # Fixed size of 8; presumably cos/sin of one phase per leg (confirm).
        phase_shape = (8,)
        super().__init__(
            name=name,
            shape=phase_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        phases = self._robot.GetFTGPhases()
        return phases

# FTG frequencies
class FTGFrequenciesSensor(sensor.BoxSpaceSensor):
    """Reads the robot's FTG frequency vector (4 values)."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -0.5,
                 upper_bound: _FLOAT_OR_ARRAY = 0.5,
                 name: typing.Text = "FTGFrequencies",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs FTGFrequenciesSensor.

        Args:
          lower_bound: the lower bound of the FTG frequency values.
          upper_bound: the upper bound of the FTG frequency values.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        # Fixed size of 4; presumably one frequency per leg (confirm).
        frequency_shape = (4,)
        super().__init__(
            name=name,
            shape=frequency_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        frequencies = self._robot.GetFTGFrequencies()
        return frequencies

# Base frequency
class BaseFrequencySensor(sensor.BoxSpaceSensor):
    """Reads the robot's scalar base frequency."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -0.0,
                 upper_bound: _FLOAT_OR_ARRAY = 2.0,
                 name: typing.Text = "BaseFrequency",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs BaseFrequencySensor.

        Args:
          lower_bound: the lower bound of the base frequency.
          upper_bound: the upper bound of the base frequency.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        scalar_shape = (1,)
        super().__init__(
            name=name,
            shape=scalar_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        frequency = self._robot.GetBaseFrequency()
        return frequency

# Joint position error history
class JointPositionErrorHistorySensor(sensor.BoxSpaceSensor):
    """A sensor that reads the joint position error history from the robot."""

    def __init__(self,
                 num_motors: int,
                 noisy_reading: bool = True,
                 observe_sine_cosine: bool = False,
                 lower_bound: _FLOAT_OR_ARRAY = -np.pi,
                 upper_bound: _FLOAT_OR_ARRAY = np.pi,
                 name: typing.Text = "JointPositionErrorHistory",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs JointPositionErrorHistorySensor.

        Args:
          num_motors: length of the error history vector reported by the robot.
          noisy_reading: stored for API compatibility; the same robot getter
            is used either way.
          observe_sine_cosine: if True, observe [cos(e), sin(e)] of the
            history values (2 * num_motors values bounded by [-1, 1]).
          lower_bound: the lower bound of the error values. Ignored when
            observe_sine_cosine is True.
          upper_bound: the upper bound of the error values. Ignored when
            observe_sine_cosine is True.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        self._noisy_reading = noisy_reading
        self._observe_sine_cosine = observe_sine_cosine

        if observe_sine_cosine:
            # cos/sin values always lie in [-1, 1].
            num_values = self._num_motors * 2
            lower_bound = -np.ones(num_values)
            upper_bound = np.ones(num_values)
        else:
            num_values = self._num_motors

        super(JointPositionErrorHistorySensor, self).__init__(
            name=name,
            shape=(num_values,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        # NOTE(review): no noise-free getter is used here, so noisy_reading
        # currently has no effect.
        error_history = self._robot.GetJointPositionErrorHistory()
        if self._observe_sine_cosine:
            # Produce the 2 * num_motors values declared in __init__. Returning
            # the raw history here would not match that shape.
            return np.hstack((np.cos(error_history), np.sin(error_history)))
        return error_history

# Joint velocity history
class JointVelocityHistorySensor(sensor.BoxSpaceSensor):
    """A sensor that reads the joint velocity history from the robot."""

    def __init__(self,
                 num_motors: int,
                 noisy_reading: bool = True,
                 observe_sine_cosine: bool = False,
                 lower_bound: _FLOAT_OR_ARRAY = -10.0,
                 upper_bound: _FLOAT_OR_ARRAY = 10.0,
                 name: typing.Text = "JointVelocityHistory",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs JointVelocityHistorySensor.

        Args:
          num_motors: length of the velocity history vector reported by the
            robot.
          noisy_reading: stored for API compatibility; the same robot getter
            is used either way.
          observe_sine_cosine: if True, observe [cos(v), sin(v)] of the
            history values (2 * num_motors values bounded by [-1, 1]).
          lower_bound: the lower bound of the velocity values. Ignored when
            observe_sine_cosine is True.
          upper_bound: the upper bound of the velocity values. Ignored when
            observe_sine_cosine is True.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        self._noisy_reading = noisy_reading
        self._observe_sine_cosine = observe_sine_cosine

        if observe_sine_cosine:
            # cos/sin values always lie in [-1, 1].
            num_values = self._num_motors * 2
            lower_bound = -np.ones(num_values)
            upper_bound = np.ones(num_values)
        else:
            num_values = self._num_motors

        super(JointVelocityHistorySensor, self).__init__(
            name=name,
            shape=(num_values,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        # Both the noisy and the true branches previously called the same
        # getter, so noisy_reading has no effect here.
        velocity_history = self._robot.GetJointVelocityHistory()
        if self._observe_sine_cosine:
            # Produce the 2 * num_motors values declared in __init__, matching
            # JointVelocitySensor's sine/cosine mode.
            return np.hstack((np.cos(velocity_history), np.sin(velocity_history)))
        return velocity_history

# Foot target history
class FootTargetHistorySensor(sensor.BoxSpaceSensor):
    """Reads the robot's foot-target history as a flat vector."""

    def __init__(self,
                 num_motors: int,
                 lower_bound: _FLOAT_OR_ARRAY = -1.0,
                 upper_bound: _FLOAT_OR_ARRAY = 1.0,
                 name: typing.Text = "FootTargetHistory",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs FootTargetHistorySensor.

        Args:
          num_motors: length of the observation vector.
          lower_bound: the lower bound of the foot-target history values.
          upper_bound: the upper bound of the foot-target history values.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        super().__init__(
            name=name,
            shape=(num_motors,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        history = self._robot.GetFootTargetHistory()
        return history

# Qdot target history
class QdotTargetHistorySensor(sensor.BoxSpaceSensor):
    """Reads the robot's joint-velocity (qdot) target history as a flat vector."""

    def __init__(self,
                 num_motors: int,
                 lower_bound: _FLOAT_OR_ARRAY = -5.0,
                 upper_bound: _FLOAT_OR_ARRAY = 5.0,
                 name: typing.Text = "QdotTargetHistory",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs QdotTargetHistorySensor.

        Args:
          num_motors: length of the observation vector.
          lower_bound: the lower bound of the qdot target history values.
          upper_bound: the upper bound of the qdot target history values.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        super().__init__(
            name=name,
            shape=(num_motors,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        history = self._robot.GetQdotTargetHistory()
        return history

# Torque target history
class TorqueTargetHistorySensor(sensor.BoxSpaceSensor):
    """Reads the robot's torque-target history as a flat vector."""

    def __init__(self,
                 num_motors: int,
                 lower_bound: _FLOAT_OR_ARRAY = -2.0,
                 upper_bound: _FLOAT_OR_ARRAY = 2.0,
                 name: typing.Text = "TorqueTargetHistory",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs TorqueTargetHistorySensor.

        Args:
          num_motors: length of the observation vector.
          lower_bound: the lower bound of the torque target history values.
          upper_bound: the upper bound of the torque target history values.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        self._num_motors = num_motors
        super().__init__(
            name=name,
            shape=(num_motors,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        history = self._robot.GetTorqueTargetHistory()
        return history

class PrivilegedInformationSensor(sensor.BoxSpaceSensor):
    """A sensor that reads privileged robot/environment information.

    The observation is the flat concatenation, in this order, of:
      - foot contact forces;
      - foot, shank and thigh contact states (as 0/1);
      - foot friction coefficients;
      - external force;
      - physical parameters.
    It is intended as input to an MLP.
    """

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -100,
                 upper_bound: _FLOAT_OR_ARRAY = 100,
                 name: typing.Text = "PrivilegedInfoSensor",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs PrivilegedInformationSensor.

        Args:
          lower_bound: currently ignored; every entry is bounded below by -1e10.
          upper_bound: currently ignored; every entry is bounded above by 1e10.
          name: the name of the sensor.
          dtype: data type of sensor value.
        """
        # Names of the observation segments, in concatenation order.
        self._channels = [
            "foot_contact_forces",
            "foot_contact_states",
            "shank_contact_states",
            "thigh_contact_states",
            "friction_coefficients",
            "external_force",
            "physical_parameters"]
        self.scanner_info = None
        self.position = None
        self.scanner_target = None
        # Presumably 5 per-leg quantities x 4 legs, plus 3 external-force
        # components and 26 physical parameters. Confirm against the robot.
        self.num_channels = 4 * 5 + 3 + 26
        # NOTE(review): the lower_bound/upper_bound arguments are overridden
        # here. This is kept as-is to preserve the existing observation space.
        lower_bound = [float(-1e10)] * self.num_channels
        upper_bound = [float(1e10)] * self.num_channels

        self._pybullet_client = None

        super(PrivilegedInformationSensor, self).__init__(
            name=name,
            shape=(self.num_channels,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def on_reset(self, env):
        del env

    def one_hot(self, _list):
        """Maps each element to 1 if truthy else 0."""
        return [1 if l else 0 for l in _list]

    def get_foot_contact_forces(self):
        return self._robot.GetFootContactForces()

    def get_foot_contact_states(self):
        return self.one_hot(self._robot.GetA1legsContacts()["foot_contacts"])

    def get_shank_contact_states(self):
        return self.one_hot(self._robot.GetA1legsContacts()["shank_contacts"])

    def get_thigh_contact_states(self):
        return self.one_hot(self._robot.GetA1legsContacts()["thigh_contacts"])

    def get_external_force(self):
        return self._robot.GetExternalForce()

    def get_friction_coefficient(self):
        return self._robot.GetFootFriction()

    def get_physical_parameters(self):
        return self._robot.GetPhysicalParameters()

    def _get_observation(self) -> _ARRAY:
        segments = [
            self.get_foot_contact_forces(),
            self.get_foot_contact_states(),
            self.get_shank_contact_states(),
            self.get_thigh_contact_states(),
            self.get_friction_coefficient(),
            self.get_external_force(),
            self.get_physical_parameters(),
        ]
        # Flatten and concatenate each segment. Chaining with `+` would add
        # element-wise (or fail to broadcast) as soon as any getter returned a
        # numpy array instead of a list.
        return np.concatenate([np.ravel(segment) for segment in segments])


class HeightSensor(sensor.BoxSpaceSensor):
    """A sensor that reads the terrain height around a robot.

    On every step the sensor
      1. samples a (shape[1] x shape[2]) height grid centred on the robot base,
         one sample every `_LOCAL_MAP_STRIDE` cells of the environment height
         map;
      2. turns that grid into a traversability score map (`heightmap2score`);
      3. samples the height at `_SCAN_POINTS_PER_FOOT` points on a circle of
         radius `_SCAN_RADIUS` around each foot, and the score map on a
         `score_length` x `score_width` patch centred on each foot.

    The observation is the list of scanned heights relative to the height
    directly under the corresponding foot. The per-foot scores are published
    on the environment as `env.foot_score_info`.
    """

    # Global height-map cells between neighbouring cells of the local maps.
    _LOCAL_MAP_STRIDE = 4
    # Height samples taken on the circle around each foot.
    _SCAN_POINTS_PER_FOOT = 9
    # Radius of that circle, in the same units as the robot's world positions.
    _SCAN_RADIUS = 0.1

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = None,
                 upper_bound: _FLOAT_OR_ARRAY = None,
                 name: typing.Text = "HeightSensor",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs HeightSensor.

    Args:
      lower_bound: the lower bound of each observed relative height. When None
        (the default), every channel is bounded below by -1e10.
      upper_bound: the upper bound of each observed relative height. When None
        (the default), every channel is bounded above by 1e10.
      name: the name of the sensor.
      dtype: data type of sensor value.
    """
        # (channels, rows, cols) of the local height map centred on the base.
        self.shape = (1, 20, 20)
        self.height_local_map = np.zeros(shape=self.shape)
        self.score = np.zeros(shape=(self.shape[1], self.shape[2]))
        self.map_score = 0

        # One channel per scan point: 36 = 4 feet x _SCAN_POINTS_PER_FOOT.
        # NOTE(review): assumes GetA1FootPositionsInWorldFrame yields 4 feet.
        self.num_channels = 36
        self.foot_score = [0] * self.num_channels
        self.foot_surrounding_height = [0] * self.num_channels
        self.scan_dots_location = [0] * self.num_channels

        # Size of the patch of score-map samples taken around each foot.
        self.score_length = 3
        self.score_width = 3

        # Cleared by height_map_converter when a queried point is off the map
        # or not finite; only re-armed by on_reset.
        self.is_safe = True

        if lower_bound is None:
            lower_bound = [float(-1e10)] * self.num_channels
        if upper_bound is None:
            upper_bound = [float(1e10)] * self.num_channels
        super(HeightSensor, self).__init__(
            name=name,
            shape=(self.num_channels,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def on_reset(self, env):
        """Clears the per-episode state: safety flag, maps and foot readings."""
        del env  # Unused.
        self.is_safe = True
        self.height_local_map = np.zeros(shape=self.shape)
        self.score = np.zeros(shape=(self.shape[1], self.shape[2]))
        self.foot_score = [0] * self.num_channels
        self.foot_surrounding_height = [0] * self.num_channels

    def height_map_converter(self, height_map, scale, x, y):
        """Converts world coordinates (x, y) to a (row, col) index into height_map.

        The map is centred on the world origin: world (0, 0) falls on cell
        (rows / 2, cols / 2), and `scale` is the number of cells per world unit.

        Args:
          height_map: 2-D array; only its shape is used.
          scale: cells per world unit.
          x: world x coordinate.
          y: world y coordinate.

        Returns:
          A (row, col) tuple of ints. If x or y is not finite, or the point
          falls outside the map, (0, 0) is returned and `self.is_safe` is
          cleared.
        """
        if not (np.isfinite(x) and np.isfinite(y)):
            self.is_safe = False
            return 0, 0
        # floor (not int truncation) so points just below the lower edge,
        # in [-1, 0), are reported as off-map instead of landing on cell 0.
        map_x = int(np.floor(x * scale + height_map.shape[0] / 2))
        map_y = int(np.floor(y * scale + height_map.shape[1] / 2))
        if not (0 <= map_x < height_map.shape[0] and
                0 <= map_y < height_map.shape[1]):
            self.is_safe = False
            return 0, 0
        return map_x, map_y

    def _local_map_index(self, map_ind, body_ind):
        """Maps a global height-map index to the nearest cell of the local maps.

        Local cell i was sampled ((-rows) // 2 + i) * _LOCAL_MAP_STRIDE global
        cells away from the base (likewise for columns), so this inverts that
        offset and clips the result onto the local grid.
        """
        rows, cols = self.shape[1], self.shape[2]
        stride = self._LOCAL_MAP_STRIDE
        row_origin = (-rows) // 2  # Offset of local row 0, e.g. -10 for 20 rows.
        col_origin = (-cols) // 2
        i = int(np.floor((map_ind[0] - body_ind[0]) / stride + 0.5)) - row_origin
        j = int(np.floor((map_ind[1] - body_ind[1]) / stride + 0.5)) - col_origin
        return int(np.clip(i, 0, rows - 1)), int(np.clip(j, 0, cols - 1))

    def on_step(self, env):
        """Refreshes the local maps and per-foot readings after a step.

        Reads `env.height_map` (for its shape) and `env.terrain_scale`
        (presumably world units per map cell, given how it is inverted for
        height_map_converter). Looks heights up in `self._robot.height_map`,
        and writes `env.foot_surrounding_height` and `env.foot_score_info`.
        """
        stride = self._LOCAL_MAP_STRIDE
        scale = 1 / env.terrain_scale
        self.scan_dots_location = []
        body_location = self._robot.GetBasePosition()
        body_location_ind = self.height_map_converter(
            env.height_map, scale, body_location[0], body_location[1])

        # 1. Local height grid centred on the base, `stride` global cells apart.
        for i in range(self.shape[1]):
            sample_x = (body_location[0] +
                        (-self.shape[1] // 2 + i) * env.terrain_scale * stride)
            for j in range(self.shape[2]):
                sample_y = (body_location[1] +
                            (-self.shape[2] // 2 + j) * env.terrain_scale * stride)
                self.height_local_map[0, i, j] = self._robot.height_map[
                    self.height_map_converter(env.height_map, scale,
                                              sample_x, sample_y)]

        # 2. Traversability score of the local grid.
        confidence_interval = self.get_confidence_bound()
        self.score = self.heightmap2score(self.height_local_map[0],
                                          self.shape[1], self.shape[2],
                                          confidence_interval)

        # 3. Per-foot height scan and score patch.
        num_scan = self._SCAN_POINTS_PER_FOOT
        gap_angle = 2 * np.pi / num_scan
        # Patch offsets (in local cells) centred on the foot, e.g. -1, 0, 1.
        x_offsets = [m - (self.score_length - 1) / 2
                     for m in range(self.score_length)]
        y_offsets = [n - (self.score_width - 1) / 2
                     for n in range(self.score_width)]
        foot_score = []
        for ind, (foot_x, foot_y) in enumerate(
                self._robot.GetA1FootPositionsInWorldFrame()):
            foothold_height = self._robot.height_map[
                self.height_map_converter(env.height_map, scale, foot_x, foot_y)]
            for k in range(num_scan):
                scan_x = foot_x + np.sin(gap_angle * k) * self._SCAN_RADIUS
                scan_y = foot_y + np.cos(gap_angle * k) * self._SCAN_RADIUS
                scan_height = self._robot.height_map[
                    self.height_map_converter(env.height_map, scale,
                                              scan_x, scan_y)]
                self.foot_surrounding_height[ind * num_scan + k] = (
                    scan_height - foothold_height)
                self.scan_dots_location.append([scan_x, scan_y, scan_height])

            for dx in x_offsets:
                patch_x = foot_x + stride * dx * env.terrain_scale
                for dy in y_offsets:
                    patch_y = foot_y + stride * dy * env.terrain_scale
                    patch_ind = self.height_map_converter(env.height_map, scale,
                                                          patch_x, patch_y)
                    foot_score.append(
                        self.score[self._local_map_index(patch_ind,
                                                         body_location_ind)])

        env.foot_surrounding_height = self.foot_surrounding_height
        env.foot_score_info = [foot_score, self.score_length, self.score_width]

    def get_confidence_bound(self, sigma=0.01):
        """Returns a random, non-negative per-cell confidence-interval width.

        For every cell of the local height map (mean mu), two samples are drawn
        from N(mu, sigma). The first is reflected about mu so it lies at or
        below mu, the second so it lies at or above mu, and upper - lower is
        returned. Each entry is therefore distributed as |e1| + |e2| with
        e1, e2 ~ N(0, sigma). Uses the global NumPy random state.

        Args:
          sigma: standard deviation of the sampled noise.
        """
        mu = np.squeeze(self.height_local_map[0])
        sample = np.random.normal(loc=mu, scale=sigma, size=np.shape(mu))
        # Reflect about mu (not about 0) so the sample stays near the height.
        x_lower = np.where(sample > mu, 2 * mu - sample, sample)
        sample = np.random.normal(loc=mu, scale=sigma, size=np.shape(mu))
        x_upper = np.where(sample < mu, 2 * mu - sample, sample)
        return x_upper - x_lower

    def heightmap2score(self, height, x_size, y_size, confidence_interval):
        """Converts a height grid into a traversability score grid in [0, 1].

        Args:
          height: 2-D array-like of heights; its top-left x_size x y_size
            region is used.
          x_size: number of rows of the score grid.
          y_size: number of columns of the score grid.
          confidence_interval: array broadcastable to (x_size, y_size), added
            with weight w_ci before clipping.

        Returns:
          An (x_size, y_size) float array clipped to [0, 1]. For each interior
          cell with height h, the raw score is
            w_sd * std(3x3 neighbourhood)
            + w_sl * mean over the 8 neighbours of |h - h_n| / grid distance
            + w_max * (max(neighbourhood) - h)
            + w_min * (h - min(neighbourhood)).
          Border cells have a raw score of 0.
        """
        w_sd = 5
        w_sl = 2
        w_max = 1
        w_min = 1
        w_ci = 1

        score = np.zeros((x_size, y_size))
        if x_size >= 3 and y_size >= 3:
            grid = np.asarray(height, dtype=float)[:x_size, :y_size]
            center = grid[1:-1, 1:-1]
            offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1)]
            # neighbours[k] holds, for each interior cell, the height at offsets[k].
            neighbours = np.stack([
                grid[1 + di:x_size - 1 + di, 1 + dj:y_size - 1 + dj]
                for di, dj in offsets
            ])
            slope = sum(
                np.abs(center - neighbours[k]) / np.hypot(di, dj)
                for k, (di, dj) in enumerate(offsets)
                if (di, dj) != (0, 0)) / 8.
            score[1:-1, 1:-1] = (w_sd * neighbours.std(axis=0) +
                                 w_sl * slope +
                                 w_max * (neighbours.max(axis=0) - center) +
                                 w_min * (center - neighbours.min(axis=0)))

        score += w_ci * confidence_interval
        return np.clip(score, a_min=0, a_max=1)

    def render_map(self):
        """Debug helper: plots the current score map as a 3-D surface."""
        import matplotlib.pyplot as plt
        # Imported for its side effect of registering the '3d' projection on
        # older matplotlib versions.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")
        # 'ij' indexing so grid_x/grid_y match score's (rows, cols) layout.
        grid_x, grid_y = np.meshgrid(np.arange(self.shape[1]),
                                     np.arange(self.shape[2]),
                                     indexing="ij")
        surface = np.squeeze(self.score)
        ax.plot_surface(grid_x, grid_y, surface, rstride=1, cstride=1)
        ax.contourf(grid_x, grid_y, surface, 20)
        plt.show()

    def _get_observation(self) -> _ARRAY:
        """Returns the scanned heights relative to each foot (see on_step)."""
        return self.foot_surrounding_height


class MinitaurLegPoseSensor(sensor.BoxSpaceSensor):
    """A sensor that reads leg_pose from the Minitaur robot."""

    def __init__(self,
                 num_motors: int,
                 noisy_reading: bool = True,
                 observe_sine_cosine: bool = False,
                 lower_bound: _FLOAT_OR_ARRAY = -np.pi,
                 upper_bound: _FLOAT_OR_ARRAY = np.pi,
                 name: typing.Text = "MinitaurLegPose",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs MinitaurLegPoseSensor.

    Args:
      num_motors: the number of motors in the robot
      noisy_reading: if True, read motor angles via the robot's GetMotorAngles()
        (presumably the noisy reading); if False, via GetTrueMotorAngles().
      observe_sine_cosine: whether to convert readings to sine/cosine values for
        continuity. The observation then has 2 * num_motors entries bounded by
        [-1, 1], and lower_bound/upper_bound are ignored.
      lower_bound: the lower bound of the leg pose (only used when
        observe_sine_cosine is False)
      upper_bound: the upper bound of the leg pose (only used when
        observe_sine_cosine is False)
      name: the name of the sensor
      dtype: data type of sensor value
    """
        self._num_motors = num_motors
        self._noisy_reading = noisy_reading
        self._observe_sine_cosine = observe_sine_cosine

        if observe_sine_cosine:
            # cos and sin of every leg-pose entry, each within [-1, 1].
            size = self._num_motors * 2
            lower_bound = -np.ones(size)
            upper_bound = np.ones(size)
        else:
            size = self._num_motors

        super(MinitaurLegPoseSensor, self).__init__(
            name=name,
            shape=(size,),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        """Returns the leg pose, or [cos(pose), sin(pose)] if configured so."""
        motor_angles = (
            self._robot.GetMotorAngles()
            if self._noisy_reading else self._robot.GetTrueMotorAngles())
        leg_pose = minitaur_pose_utils.motor_angles_to_leg_pose(motor_angles)
        if self._observe_sine_cosine:
            return np.hstack((np.cos(leg_pose), np.sin(leg_pose)))
        return leg_pose


class BaseDisplacementSensor(sensor.BoxSpaceSensor):
    """A sensor that reads the displacement of the robot base over one step."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -0.1,
                 upper_bound: _FLOAT_OR_ARRAY = 0.1,
                 convert_to_local_frame: bool = False,
                 name: typing.Text = "BaseDisplacement",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Constructs BaseDisplacementSensor.

    Args:
      lower_bound: the lower bound of the base displacement; either a scalar
        applied to every channel or one value per channel (x, y, z)
      upper_bound: the upper bound of the base displacement; either a scalar
        applied to every channel or one value per channel (x, y, z)
      convert_to_local_frame: whether to project dx, dy to local frame based on
        robot's current yaw angle. (Note that it's a projection onto 2D plane,
        and the roll, pitch of the robot is not considered.)
      name: the name of the sensor
      dtype: data type of sensor value
    """

        self._channels = ["x", "y", "z"]
        self._num_channels = len(self._channels)

        super(BaseDisplacementSensor, self).__init__(
            name=name,
            shape=(self._num_channels,),
            # np.full broadcasts a scalar to every channel and passes a
            # per-channel sequence through unchanged (the previous
            # np.array([bound] * 3) produced a 3x3 bound for sequences).
            lower_bound=np.full(self._num_channels, lower_bound),
            upper_bound=np.full(self._num_channels, upper_bound),
            dtype=dtype)

        datatype = [("{}_{}".format(name, channel), self._dtype)
                    for channel in self._channels]
        self._datatype = datatype
        self._convert_to_local_frame = convert_to_local_frame

        self._last_yaw = 0
        self._last_base_position = np.zeros(3)
        self._current_yaw = 0
        self._current_base_position = np.zeros(3)

    def get_channels(self) -> typing.Iterable[typing.Text]:
        """Returns channels (displacement in x, y, z direction)."""
        return self._channels

    def get_num_channels(self) -> int:
        """Returns number of channels."""
        return self._num_channels

    def get_observation_datatype(self) -> _DATATYPE_LIST:
        """See base class."""
        return self._datatype

    def _get_observation(self) -> _ARRAY:
        """Returns (dx, dy, dz) of the base between the last two steps.

        With convert_to_local_frame, dx and dy are rotated by -yaw, using the
        yaw recorded at the start of the step.
        """
        dx, dy, dz = self._current_base_position - self._last_base_position
        if self._convert_to_local_frame:
            dx_local = np.cos(self._last_yaw) * dx + np.sin(self._last_yaw) * dy
            dy_local = -np.sin(self._last_yaw) * dx + np.cos(self._last_yaw) * dy
            return np.array([dx_local, dy_local, dz])
        return np.array([dx, dy, dz])

    def on_reset(self, env):
        """Starts a new episode with zero displacement (last == current pose)."""
        self._current_base_position = np.array(self._robot.GetBasePosition())
        self._last_base_position = np.array(self._robot.GetBasePosition())
        self._current_yaw = self._robot.GetBaseRollPitchYaw()[2]
        self._last_yaw = self._robot.GetBaseRollPitchYaw()[2]

    def on_step(self, env):
        """Shifts the current pose to `last` and records the new current pose."""
        self._last_base_position = self._current_base_position
        self._current_base_position = np.array(self._robot.GetBasePosition())
        self._last_yaw = self._current_yaw
        self._current_yaw = self._robot.GetBaseRollPitchYaw()[2]


class BasePositionSensor(sensor.BoxSpaceSensor):
    """Observes the (x, y, z) base position of the robot."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -100,
                 upper_bound: _FLOAT_OR_ARRAY = 100,
                 name: typing.Text = "BasePosition",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Builds the sensor with a three-element observation space.

    Args:
      lower_bound: lower bound applied to the base position.
      upper_bound: upper bound applied to the base position.
      name: the name of the sensor
      dtype: data type of sensor value
    """
        observation_shape = (3,)  # x, y, z
        super().__init__(
            name=name,
            shape=observation_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        """Returns the base position exactly as the robot reports it."""
        base_position = self._robot.GetBasePosition()
        return base_position


class PoseSensor(sensor.BoxSpaceSensor):
    """Observes the planar pose (x, y, yaw) of a robot."""

    def __init__(self,
                 lower_bound: _FLOAT_OR_ARRAY = -100,
                 upper_bound: _FLOAT_OR_ARRAY = 100,
                 name: typing.Text = "PoseSensor",
                 dtype: typing.Type[typing.Any] = np.float64) -> None:
        """Builds the sensor with a three-element (x, y, yaw) observation space.

    Args:
      lower_bound: lower bound applied to the pose.
      upper_bound: upper bound applied to the pose.
      name: the name of the sensor.
      dtype: data type of sensor value.
    """
        observation_shape = (3,)  # x, y, orientation
        super().__init__(
            name=name,
            shape=observation_shape,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            dtype=dtype)

    def _get_observation(self) -> _ARRAY:
        """Returns [x, y] of the base followed by the true yaw angle."""
        planar_position = self._robot.GetBasePosition()[:2]
        yaw = self._robot.GetTrueBaseRollPitchYaw()[2]
        return np.concatenate((planar_position, (yaw,)))
