"""The swing leg controller class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import numpy as np
from typing import Any, Mapping, Sequence, Tuple

from lib.env.locomotion.agents.whole_body_controller import gait_generator as gait_generator_lib
from lib.env.locomotion.agents.whole_body_controller import leg_controller

# The position correction coefficients in Raibert's formula: per-axis gains
# applied to the (target - measured) hip horizontal velocity error when the
# swing foot's landing position is computed.
_KP = np.array([0.01, 0.01, 0.01]) * 3
# At the end of swing, we leave a small clearance to prevent unexpected foot
# collision. The `_M` suffix suggests the unit is meters.
# NOTE(review): not referenced in the visible part of this file; the controller
# below receives its clearance through the `foot_clearance` constructor
# argument -- confirm whether callers pass this constant in.
_FOOT_CLEARANCE_M = 0.01


def _gen_parabola(phase: float, start: float, mid: float, end: float) -> float:
    """Evaluates the parabola through (0, start), (0.5, mid) and (1, end).

    The curve is y = a x^2 + b x + c. Substituting the three control points
    gives c = start together with the linear system
        a / 4 + b / 2 = mid - start
        a     + b     = end - start
    which is solved in closed form below.

    Args:
      phase: Normalized to [0, 1]. The x coordinate at which to evaluate.
      start: The y value at x == 0.
      mid: The y value at x == 0.5.
      end: The y value at x == 1.

    Returns:
      The y value at x == phase.
    """
    half = 0.5
    rise_to_mid = mid - start
    rise_to_end = end - start
    # Common denominator of both coefficients: 0.5^2 - 0.5 == -0.25.
    denominator = half ** 2 - half
    quadratic = (rise_to_mid - rise_to_end * half) / denominator
    linear = (rise_to_end * half ** 2 - rise_to_mid) / denominator

    # The constant term is simply the value at x == 0.
    return quadratic * phase ** 2 + linear * phase + start


def _gen_swing_foot_trajectory(
        input_phase: float, start_pos: Sequence[float],
        end_pos: Sequence[float]) -> Tuple[float, float, float]:
    """Generates the swing trajectory using a parabola.

    The horizontal (x, y) coordinates are interpolated linearly between
    `start_pos` and `end_pos`, while the height (z) follows a parabola through
    the start height, an apex 0.1 above the higher of the two end points, and
    the end height. Both use a warped phase rather than `input_phase` directly
    (see the comment in the body).

    Args:
      input_phase: the swing/stance phase value between [0, 1].
      start_pos: The foot's position at the beginning of swing cycle. Indexed
        as (x, y, z).
      end_pos: The foot's desired position at the end of swing cycle. Indexed
        as (x, y, z).

    Returns:
      The desired foot position (x, y, z) at the current phase.
    """
    # We augment the swing speed using the below formula. For the first half of
    # the swing cycle, the swing leg moves faster and finishes 80% of the full
    # swing trajectory. The rest 20% of trajectory takes another half swing
    # cycle. Intuitively, we want to move the swing foot quickly to the target
    # landing location and stay above the ground, in this way the control is
    # more robust to perturbations to the body that may cause the swing foot to
    # drop onto the ground earlier than expected. This is a common practice
    # similar to the MIT cheetah and Marc Raibert's original controllers.
    if input_phase <= 0.5:
        # Sine ramp: 0 at input_phase == 0, 0.8 at input_phase == 0.5.
        phase = 0.8 * math.sin(input_phase * math.pi)
    else:
        # Linear ramp: 0.8 at input_phase == 0.5, 1.0 at input_phase == 1.
        phase = 0.8 + (input_phase - 0.5) * 0.4

    x = (1 - phase) * start_pos[0] + phase * end_pos[0]
    y = (1 - phase) * start_pos[1] + phase * end_pos[1]

    # The apex of the height profile sits this far above the higher end point.
    max_clearance = 0.1
    mid = max(end_pos[2], start_pos[2]) + max_clearance
    z = _gen_parabola(phase, start_pos[2], mid, end_pos[2])

    return (x, y, z)


class RaibertSwingLegController(leg_controller.LegController):
    """Controls the swing leg position using Raibert's formula.

    For details, please refer to chapter 2 in "Legged Robots That Balance" by
    Marc Raibert. The key idea is to stabilize the swing foot's location based
    on the CoM moving speed.
    """

    def __init__(
            self,
            robot: Any,
            gait_generator: Any,
            state_estimator: Any,
            desired_speed: Tuple[float, float],
            desired_twisting_speed: float,
            desired_height: float,
            foot_clearance: float,
    ):
        """Initializes the class.

        Args:
          robot: A robot instance.
          gait_generator: Generates the stance/swing pattern.
          state_estimator: Estimates the CoM speeds.
          desired_speed: Behavior parameters. X-Y speed.
          desired_twisting_speed: Behavior control parameters.
          desired_height: Desired standing height.
          foot_clearance: The foot clearance on the ground at the end of the
            swing cycle. It is subtracted from `desired_height`.
        """
        self._robot = robot
        self._gait_generator = gait_generator
        self._state_estimator = state_estimator
        self._last_leg_state = gait_generator.desired_leg_state

        # Commands are stored as 3-vectors with a zero z component so they can
        # be combined directly with the other 3D quantities in get_action().
        speed_x, speed_y = desired_speed[0], desired_speed[1]
        self.desired_speed = np.array((speed_x, speed_y, 0))
        self.desired_twisting_speed = desired_twisting_speed

        # Vertical offset subtracted from the foot target in get_action().
        target_height = desired_height - foot_clearance
        self._desired_height = np.array((0, 0, target_height))

        # Both are populated by reset().
        self._joint_angles = None
        self._phase_switch_foot_local_position = None
        self.reset(0)

    def reset(self, current_time: float) -> None:
        """Re-synchronizes the controller with the gait generator and robot.

        Records the gait generator's desired leg states and the current foot
        positions in the base frame, and clears the stored joint angles.

        Args:
          current_time: The wall time in seconds. Unused.
        """
        del current_time
        generator = self._gait_generator
        self._last_leg_state = generator.desired_leg_state
        foot_positions = self._robot.GetFootPositionsInBaseFrame()
        self._phase_switch_foot_local_position = foot_positions
        self._joint_angles = {}

    def update(self, current_time: float) -> None:
        """Called at each control step.

        Args:
          current_time: The wall time in seconds. Unused.
        """
        del current_time
        current_leg_state = self._gait_generator.desired_leg_state

        # Whenever a leg has just switched into swing, remember where its foot
        # was so the swing trajectory can start from that position.
        for leg_id, state in enumerate(current_leg_state):
            entered_swing = (
                state == gait_generator_lib.LegState.SWING
                and state != self._last_leg_state[leg_id])
            if not entered_swing:
                continue
            foot_positions = self._robot.GetFootPositionsInBaseFrame()
            self._phase_switch_foot_local_position[leg_id] = (
                foot_positions[leg_id])

        self._last_leg_state = copy.deepcopy(current_leg_state)

    def get_action(self) -> Mapping[Any, Any]:
        """Computes the motor commands for the joints of the swing legs.

        Returns:
          A mapping from joint id to the 5-tuple
          (joint angle, position gain, 0, velocity gain, 0). Only joints whose
          leg is in the SWING desired state are included.
        """
        estimated_velocity = self._state_estimator.com_velocity_body_frame
        # Only the horizontal part of the CoM velocity is used.
        planar_velocity = np.array(
            (estimated_velocity[0], estimated_velocity[1], 0))

        _roll_rate, _pitch_rate, yaw_rate = (
            self._robot.GetBaseRollPitchYawRate())
        hip_positions = self._robot.GetHipPositionsInBaseFrame()

        for leg_id, leg_state in enumerate(self._gait_generator.leg_state):
            in_contact = leg_state in (
                gait_generator_lib.LegState.STANCE,
                gait_generator_lib.LegState.EARLY_CONTACT)
            if in_contact:
                continue

            # For now we did not consider the body pitch/roll and all
            # calculation is in the body frame. TODO(b/143378213): Calculate
            # the foot_target_position in world frame and then project back to
            # calculate the joint angles.
            hip_offset = hip_positions[leg_id]
            hip_x, hip_y = hip_offset[0], hip_offset[1]
            # Horizontal hip offset rotated by 90 degrees; scaled by a yaw rate
            # it gives the hip velocity caused by the body twisting.
            twist_direction = np.array((-hip_y, hip_x, 0))

            hip_velocity = planar_velocity + yaw_rate * twist_direction
            target_hip_velocity = (
                self.desired_speed
                + self.desired_twisting_speed * twist_direction)

            # Raibert's formula: a term proportional to the current hip
            # velocity and half the stance duration, minus a correction
            # proportional to the velocity error.
            stance_duration = self._gait_generator.stance_duration[leg_id]
            velocity_term = hip_velocity * stance_duration / 2
            correction_term = _KP * (target_hip_velocity - hip_velocity)
            landing_offset = velocity_term - correction_term
            foot_target_position = (
                landing_offset - self._desired_height
                + np.array((hip_x, hip_y, 0)))

            swing_phase = self._gait_generator.normalized_phase[leg_id]
            swing_start = self._phase_switch_foot_local_position[leg_id]
            foot_position = _gen_swing_foot_trajectory(
                swing_phase, swing_start, foot_target_position)

            joint_ids, joint_angles = (
                self._robot.ComputeMotorAnglesFromFootLocalPosition(
                    leg_id, foot_position))
            # Update the stored joint angles as needed, tagging each with the
            # leg it belongs to so it can be filtered by leg state below.
            self._joint_angles.update({
                joint_id: (joint_angle, leg_id)
                for joint_id, joint_angle in zip(joint_ids, joint_angles)
            })

        kps = self._robot.GetMotorPositionGains()
        kds = self._robot.GetMotorVelocityGains()
        # This is a hybrid action for PD control, emitted only for joints whose
        # leg is currently desired to be in swing.
        return {
            joint_id: (joint_angle, kps[joint_id], 0, kds[joint_id], 0)
            for joint_id, (joint_angle, leg_id) in self._joint_angles.items()
            if (self._gait_generator.desired_leg_state[leg_id]
                == gait_generator_lib.LegState.SWING)
        }
