# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: disable=attribute-error
"""Real robot interface of A1 robot."""

from absl import logging
import math
import re
import numpy as np
import time

from locomotion.robots import laikago_pose_utils
from locomotion.robots import a1
from locomotion.robots import a1_robot_velocity_estimator
from locomotion.robots import minitaur
from locomotion.robots import robot_config
from locomotion.envs import locomotion_gym_config
from robot_interface import RobotInterface  # pytype: disable=import-error

# Robot topology: NUM_LEGS legs with DOFS_PER_LEG motors each.
NUM_MOTORS = 12
NUM_LEGS = 4
DOFS_PER_LEG = 3

# Motor names, grouped per leg (FR, FL, RR, RL) in hip/upper/lower order.
MOTOR_NAMES = [
    "FR_hip_joint", "FR_upper_joint", "FR_lower_joint",
    "FL_hip_joint", "FL_upper_joint", "FL_lower_joint",
    "RR_hip_joint", "RR_upper_joint", "RR_lower_joint",
    "RL_hip_joint", "RL_upper_joint", "RL_lower_joint",
]

# Initial base positions (x, y, z), with and without the rack.
INIT_RACK_POSITION = [0, 0, 1]
INIT_POSITION = [0, 0, 0.48]

# Per-motor direction multipliers and angle offsets; all motors currently use
# direction +1 and a zero offset.
JOINT_DIRECTIONS = np.ones(NUM_MOTORS)
HIP_JOINT_OFFSET = 0.0
UPPER_LEG_JOINT_OFFSET = 0.0
KNEE_JOINT_OFFSET = 0.0
JOINT_OFFSETS = np.tile(
    [HIP_JOINT_OFFSET, UPPER_LEG_JOINT_OFFSET, KNEE_JOINT_OFFSET], NUM_LEGS)
PI = math.pi

MAX_MOTOR_ANGLE_CHANGE_PER_STEP = 0.2

# One (x, y, z) entry per leg.
# NOTE(review): presumably hip positions in the base frame, in the same leg
# order as MOTOR_NAMES -- confirm against the users of this constant.
_DEFAULT_HIP_POSITIONS = (
    (0.17, -0.135, 0),
    (0.17, 0.13, 0),
    (-0.195, -0.135, 0),
    (-0.195, 0.13, 0),
)

# PD gains for the three joint types of a leg.
ABDUCTION_P_GAIN = 100.0
ABDUCTION_D_GAIN = 1.0
HIP_P_GAIN = 100.0
HIP_D_GAIN = 2.0
KNEE_P_GAIN = 100.0
KNEE_D_GAIN = 2.0

# Names of the low-level command and state channels.
COMMAND_CHANNEL_NAME = "LCM_Low_Cmd"
STATE_CHANNEL_NAME = "LCM_Low_State"

# Default motor angles, based on the readings from Laikago's default pose:
# one (abduction, hip, knee) triple repeated for every leg.
INIT_MOTOR_ANGLES = np.tile(
    [
        laikago_pose_utils.LAIKAGO_DEFAULT_ABDUCTION_ANGLE,
        laikago_pose_utils.LAIKAGO_DEFAULT_HIP_ANGLE,
        laikago_pose_utils.LAIKAGO_DEFAULT_KNEE_ANGLE,
    ],
    NUM_LEGS,
)

# Regular expressions that recognise names by their naming scheme:
# "<prefix>_hip_<suffix>", "<prefix>_upper_<suffix>", "<prefix>_lower_<suffix>",
# "<prefix>_toe" with optional trailing digits, and "imu" with optional
# trailing digits.
# NOTE(review): presumably applied to URDF joint/link names by the parent
# robot class -- they are not used in this file; confirm.
HIP_NAME_PATTERN = re.compile(r"\w+_hip_\w+")
UPPER_NAME_PATTERN = re.compile(r"\w+_upper_\w+")
LOWER_NAME_PATTERN = re.compile(r"\w+_lower_\w+")
TOE_NAME_PATTERN = re.compile(r"\w+_toe\d*")
IMU_NAME_PATTERN = re.compile(r"imu\d*")

# URDF file of the robot.
# NOTE(review): the base directory this path is relative to is not visible
# here -- confirm against the URDF loader.
URDF_FILENAME = "a1/a1.urdf"

# NOTE(review): these look like tuple indices into pybullet contact-point
# records (body B id / link index on body A) -- not used in this file; confirm.
_BODY_B_FIELD_NUMBER = 2
_LINK_A_FIELD_NUMBER = 3


class A1Robot(a1.A1):
  """Interface for real A1 robot.

  Commands are sent to, and observations are read from, the hardware through
  `RobotInterface`. The state getters return the values cached by the last
  call to `ReceiveObservation`, except `GetBasePosition`, which queries the
  pybullet client.
  """
  # NOTE(review): presumably the body parameters used by the MPC controller;
  # 108 / 9.8 looks like a weight in newtons divided by gravity (i.e. a mass
  # in kg) -- confirm the units with the consumer of these constants.
  MPC_BODY_MASS = 108 / 9.8
  MPC_BODY_INERTIA = np.array((0.24, 0, 0, 0, 0.80, 0, 0, 0, 1.00))

  MPC_BODY_HEIGHT = 0.24

  # Foot force reading above which a foot is reported as being in contact.
  # NOTE(review): the unit of the raw footForce reading is not visible here.
  FOOT_CONTACT_FORCE_THRESHOLD = 20

  # Layout of the flat float32 command vector given to
  # RobotInterface.send_command: _COMMAND_STRIDE consecutive entries per motor.
  # Within a motor's entries, this class writes the target position, the
  # position gain, the damping gain and the torque at the offsets below.
  # NOTE(review): offset 2 is never written here (presumably the target
  # velocity) -- confirm against the robot interface.
  _COMMAND_STRIDE = 5
  _COMMAND_DIMENSION = NUM_MOTORS * _COMMAND_STRIDE
  _POSITION_OFFSET = 0
  _KP_OFFSET = 1
  _KD_OFFSET = 3
  _TORQUE_OFFSET = 4

  # Lower bound on the duration of a (non-skipped) reset motion, in seconds.
  _MIN_STANDUP_TIME = 2.0

  # Bounds of the per-motor action space, in motor order.
  ACTION_CONFIG = [
      locomotion_gym_config.ScalarField(name="FR_hip_motor",
                                        upper_bound=0.802851455917,
                                        lower_bound=-0.802851455917),
      locomotion_gym_config.ScalarField(name="FR_upper_joint",
                                        upper_bound=4.18879020479,
                                        lower_bound=-1.0471975512),
      locomotion_gym_config.ScalarField(name="FR_lower_joint",
                                        upper_bound=-0.916297857297,
                                        lower_bound=-2.69653369433),
      locomotion_gym_config.ScalarField(name="FL_hip_motor",
                                        upper_bound=0.802851455917,
                                        lower_bound=-0.802851455917),
      locomotion_gym_config.ScalarField(name="FL_upper_joint",
                                        upper_bound=4.18879020479,
                                        lower_bound=-1.0471975512),
      locomotion_gym_config.ScalarField(name="FL_lower_joint",
                                        upper_bound=-0.916297857297,
                                        lower_bound=-2.69653369433),
      locomotion_gym_config.ScalarField(name="RR_hip_motor",
                                        upper_bound=0.802851455917,
                                        lower_bound=-0.802851455917),
      locomotion_gym_config.ScalarField(name="RR_upper_joint",
                                        upper_bound=4.18879020479,
                                        lower_bound=-1.0471975512),
      locomotion_gym_config.ScalarField(name="RR_lower_joint",
                                        upper_bound=-0.916297857297,
                                        lower_bound=-2.69653369433),
      locomotion_gym_config.ScalarField(name="RL_hip_motor",
                                        upper_bound=0.802851455917,
                                        lower_bound=-0.802851455917),
      locomotion_gym_config.ScalarField(name="RL_upper_joint",
                                        upper_bound=4.18879020479,
                                        lower_bound=-1.0471975512),
      locomotion_gym_config.ScalarField(name="RL_lower_joint",
                                        upper_bound=-0.916297857297,
                                        lower_bound=-2.69653369433),
  ]

  def __init__(self, pybullet_client, time_step=0.001, **kwargs):
    """Initializes the robot class.

    Args:
      pybullet_client: The pybullet client, also forwarded to the parent class.
      time_step: Duration of one internal step in seconds; `_Nap` sleeps for
        the remainder of this period after each command.
      **kwargs: Forwarded to the parent class. `on_rack` is always overridden
        with True.
    """
    # Initialize pd gain vector: one (abduction, hip, knee) triple per leg.
    self.motor_kps = np.array([ABDUCTION_P_GAIN, HIP_P_GAIN, KNEE_P_GAIN] * 4)
    self.motor_kds = np.array([ABDUCTION_D_GAIN, HIP_D_GAIN, KNEE_D_GAIN] * 4)
    self._pybullet_client = pybullet_client
    self.time_step = time_step

    # Robot state variables. They are defined before the parent constructor
    # runs so that every getter finds its attribute.
    self._init_complete = False
    self._base_orientation = None
    self._base_orientation_rpy = None
    self._raw_state = None
    self._last_raw_state = None
    self._motor_angles = np.zeros(NUM_MOTORS)
    self._motor_velocities = np.zeros(NUM_MOTORS)
    self._joint_states = None
    self._last_reset_time = time.time()
    self._velocity_estimator = a1_robot_velocity_estimator.VelocityEstimator(
        self)

    # Initiate UDP for robot state and actions, starting with an all-zero
    # command.
    self._robot_interface = RobotInterface()
    self._robot_interface.send_command(
        np.zeros(self._COMMAND_DIMENSION, dtype=np.float32))

    kwargs['on_rack'] = True
    super().__init__(pybullet_client, time_step=time_step, **kwargs)
    self._init_complete = True
    self._last_step_time_wall = 0

  def ReceiveObservation(self):
    """Receives the latest state from the robot and caches it.

    Updates the cached raw state, base orientation (quaternion and
    roll/pitch/yaw), motor angles, motor velocities and joint states. Once the
    constructor has finished, the velocity estimator is updated with the new
    raw state as well.
    """
    state = self._robot_interface.receive_observation()
    self._raw_state = state
    # Convert quaternion from wxyz to xyzw, which is default for Pybullet.
    q = state.imu.quaternion
    self._base_orientation = np.array([q[1], q[2], q[3], q[0]])
    rpy = state.imu.rpy
    self._base_orientation_rpy = np.array([rpy[0], rpy[1], rpy[2]])
    self._motor_angles = np.array(
        [motor.q for motor in state.motorState[:NUM_MOTORS]])
    self._motor_velocities = np.array(
        [motor.dq for motor in state.motorState[:NUM_MOTORS]])
    # One (angle, velocity) row per motor.
    self._joint_states = np.array(
        list(zip(self._motor_angles, self._motor_velocities)))
    if self._init_complete:
      # Mirroring the hardware state into the simulation
      # (_SetRobotStateInSim) is currently not done here.
      self._velocity_estimator.update(self._raw_state)

  def _SetRobotStateInSim(self, motor_angles, motor_velocities):
    """Writes the given motor state into the pybullet model.

    The base pose is re-applied from the current getters and every motor joint
    is reset to the given angle and velocity.

    Args:
      motor_angles: Joint angles, indexed like `self._motor_id_list`.
      motor_velocities: Joint velocities, indexed like `self._motor_id_list`.
    """
    self._pybullet_client.resetBasePositionAndOrientation(
        self.quadruped, self.GetBasePosition(), self.GetBaseOrientation())
    for i, motor_id in enumerate(self._motor_id_list):
      self._pybullet_client.resetJointState(self.quadruped, motor_id,
                                            motor_angles[i],
                                            motor_velocities[i])

  def GetTrueMotorAngles(self):
    """Returns a copy of the last received motor angles."""
    return self._motor_angles.copy()

  def GetMotorAngles(self):
    """Returns the last received motor angles mapped by MapToMinusPiToPi."""
    return minitaur.MapToMinusPiToPi(self._motor_angles).copy()

  def GetMotorVelocities(self):
    """Returns a copy of the last received motor velocities."""
    return self._motor_velocities.copy()

  def GetBasePosition(self):
    """Returns the base position reported by the pybullet client."""
    return self._pybullet_client.getBasePositionAndOrientation(
        self.quadruped)[0]

  def GetBaseRollPitchYaw(self):
    """Returns a copy of the last received IMU roll, pitch and yaw."""
    return self._base_orientation_rpy.copy()

  def GetTrueBaseRollPitchYaw(self):
    """Returns a copy of the last received IMU roll, pitch and yaw."""
    return self._base_orientation_rpy.copy()

  def GetBaseRollPitchYawRate(self):
    """Returns the last received gyroscope reading."""
    return self.GetTrueBaseRollPitchYawRate()

  def GetTrueBaseRollPitchYawRate(self):
    """Returns the last received gyroscope reading as a new array."""
    return np.array(self._raw_state.imu.gyroscope)

  def GetBaseVelocity(self):
    """Returns a copy of the velocity estimator's current estimate."""
    return self._velocity_estimator.estimated_velocity.copy()

  def GetFootContacts(self):
    """Returns a boolean array that is True where a foot is in contact.

    A foot counts as in contact when its last received force reading exceeds
    FOOT_CONTACT_FORCE_THRESHOLD.
    """
    foot_forces = np.array(self._raw_state.footForce)
    return foot_forces > self.FOOT_CONTACT_FORCE_THRESHOLD

  def GetTimeSinceReset(self):
    """Returns the wall-clock seconds elapsed since the last reset."""
    return time.time() - self._last_reset_time

  def GetBaseOrientation(self):
    """Returns a copy of the last received base quaternion in xyzw order."""
    return self._base_orientation.copy()

  @property
  def motor_velocities(self):
    """A copy of the last received motor velocities."""
    return self._motor_velocities.copy()

  def ApplyAction(self, motor_commands, motor_control_mode=None):
    """Packs the motor commands into a command vector and sends it.

    The commands are forwarded without clipping.

    Args:
      motor_commands: np.array. In POSITION mode the target angle of each
        motor (sent together with this robot's PD gains), in TORQUE mode the
        torque of each motor, and in HYBRID mode the complete command vector,
        which is only cast to float32.
      motor_control_mode: A MotorControlMode enum. Falls back to
        `self._motor_control_mode` when None.

    Raises:
      ValueError: If the control mode is not POSITION, TORQUE or HYBRID.
    """
    if motor_control_mode is None:
      motor_control_mode = self._motor_control_mode

    stride = self._COMMAND_STRIDE
    command = np.zeros(self._COMMAND_DIMENSION, dtype=np.float32)
    if motor_control_mode == robot_config.MotorControlMode.POSITION:
      for motor_id in range(NUM_MOTORS):
        base = motor_id * stride
        command[base + self._POSITION_OFFSET] = motor_commands[motor_id]
        command[base + self._KP_OFFSET] = self.motor_kps[motor_id]
        command[base + self._KD_OFFSET] = self.motor_kds[motor_id]
    elif motor_control_mode == robot_config.MotorControlMode.TORQUE:
      for motor_id in range(NUM_MOTORS):
        command[motor_id * stride + self._TORQUE_OFFSET] = (
            motor_commands[motor_id])
    elif motor_control_mode == robot_config.MotorControlMode.HYBRID:
      command = np.array(motor_commands, dtype=np.float32)
    else:
      raise ValueError('Unknown motor control mode for A1 robot: {}.'.format(
          motor_control_mode))

    self._robot_interface.send_command(command)

  def Reset(self, reload_urdf=True, default_motor_angles=None, reset_time=3.0):
    """Reset the robot to default motor angles.

    After resetting the parent class, the motors are moved from their current
    angles to the target angles by linearly blending position commands.

    Args:
      reload_urdf: Forwarded to the parent class' Reset.
      default_motor_angles: Target motor angles as an array or any sequence
        accepted by np.asarray. Defaults to `a1.INIT_MOTOR_ANGLES` when None.
      reset_time: Requested duration of the motion in seconds. A negative
        value skips the motion entirely; any other value is raised to at
        least `_MIN_STANDUP_TIME`.
    """
    super().Reset(reload_urdf=reload_urdf,
                  default_motor_angles=default_motor_angles,
                  reset_time=-1)
    logging.warning(
        "About to reset the robot, make sure the robot is hang-up.")

    if default_motor_angles is None:
      default_motor_angles = a1.INIT_MOTOR_ANGLES
    # Plain lists/tuples cannot be scaled by the float blend ratio below, so
    # work on an array.
    default_motor_angles = np.asarray(default_motor_angles)

    current_motor_angles = self.GetMotorAngles()

    # A negative reset_time yields an empty blending loop below; otherwise the
    # motion never takes less than _MIN_STANDUP_TIME seconds.
    if reset_time < 0:
      standup_time = -1
    else:
      standup_time = max(reset_time, self._MIN_STANDUP_TIME)

    print('\nReal Robot Resetting...')
    print(f'- Current Angle: {current_motor_angles}\n'
          f'- Target Angle: {default_motor_angles}\n'
          f'- Reset Time: {standup_time}\n')

    # No explicit sleep here: _StepInternal paces each step through _Nap.
    control_period = self.time_step * self._action_repeat
    for t in np.arange(0, standup_time, control_period):
      blend_ratio = min(t / standup_time, 1)
      action = (blend_ratio * default_motor_angles +
                (1 - blend_ratio) * current_motor_angles)
      self.Step(action, robot_config.MotorControlMode.POSITION)

    if self._enable_action_filter:
      self._ResetActionFilter()

    self._velocity_estimator.reset()
    self._state_action_counter = 0
    self._step_counter = 0
    self._last_reset_time = time.time()

  def Terminate(self):
    """Marks the robot as no longer alive."""
    self._is_alive = False

  def _StepInternal(self, action, motor_control_mode=None):
    """Sends one action, waits out the time step and reads the new state.

    Args:
      action: The motor commands, passed on to `ApplyAction`.
      motor_control_mode: A MotorControlMode enum, passed on to `ApplyAction`.
    """
    self._last_step_time_wall = time.time()
    self.ApplyAction(action, motor_control_mode)
    # Sleep here so that the real robot is stepped at the same rate as the
    # simulation.
    self._Nap()
    self.ReceiveObservation()
    self._state_action_counter += 1

  def _Nap(self):
    """Sleep for the remainder of self.time_step."""
    elapsed = time.time() - self._last_step_time_wall
    sleep_time = self.time_step - elapsed
    if sleep_time > 0:
      # time.time() may go backwards when the system clock is adjusted, which
      # would make `elapsed` negative; never sleep longer than one time step.
      time.sleep(min(sleep_time, self.time_step))

