"""Gait pattern planning module."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os
import inspect

currentdir = os.path.dirname(os.path.abspath(
    inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(os.path.dirname(currentdir))
os.sys.path.insert(0, parentdir)

import logging
import math

import numpy as np
from typing import Any, Sequence
from runtime.utils.mpc.lib import gait_generator


LAIKAGO_TROTTING = (
    gait_generator.LegState.SWING,
    gait_generator.LegState.STANCE,
    gait_generator.LegState.STANCE,
    gait_generator.LegState.SWING,
)

_NOMINAL_STANCE_DURATION = (0.3, 0.3, 0.3, 0.3)
_NOMINAL_DUTY_FACTOR = (0.5, 0.5, 0.5, 0.5)
_NOMINAL_CONTACT_DETECTION_PHASE = 0.1


class OpenloopGaitGenerator(gait_generator.GaitGenerator):
    """Generates openloop gaits for quadruped robots.

  A flexible open-loop gait generator. Each leg has its own cycle and duty
  factor. And the state of each leg alternates between stance and swing. One can
  easily formuate a set of common quadruped gaits like trotting, pacing,
  pronking, bounding, etc by tweaking the input parameters.
  """

    def __init__(
            self,
            robot: Any,
            stance_duration: Sequence[float] = _NOMINAL_STANCE_DURATION,
            duty_factor: Sequence[float] = _NOMINAL_DUTY_FACTOR,
            initial_leg_state: Sequence[gait_generator.LegState] = LAIKAGO_TROTTING,
            initial_leg_phase: Sequence[float] = (0, 0, 0, 0),
            contact_detection_phase_threshold:
            float = _NOMINAL_CONTACT_DETECTION_PHASE,
    ):
        """Initializes the class.

    Args:
      robot: A quadruped robot that at least implements the GetFootContacts API
        and num_legs property.
      stance_duration: The desired stance duration.
      duty_factor: The ratio  stance_duration / total_gait_cycle.
      initial_leg_state: The desired initial swing/stance state of legs indexed
        by their id.
      initial_leg_phase: The desired initial phase [0, 1] of the legs within the
        full swing + stance cycle.
      contact_detection_phase_threshold: Updates the state of each leg based on
        contact info, when the current normalized phase is greater than this
        threshold. This is essential to remove false positives in contact
        detection when phase switches. For example, a swing foot at at the
        beginning of the gait cycle might be still on the ground.
    """
        self._robot = robot
        self._stance_duration = stance_duration
        self._duty_factor = duty_factor
        self._swing_duration = np.array(stance_duration) / np.array(
            duty_factor) - np.array(stance_duration)
        if len(initial_leg_phase) != self._robot.num_legs:
            raise ValueError(
                "The number of leg phases should be the same as number of legs.")
        self._initial_leg_phase = initial_leg_phase
        if len(initial_leg_state) != self._robot.num_legs:
            raise ValueError(
                "The number of leg states should be the same of number of legs.")
        self._initial_leg_state = initial_leg_state
        self._next_leg_state = []
        # The ratio in cycle is duty factor if initial state of the leg is STANCE,
        # and 1 - duty_factory if the initial state of the leg is SWING.
        self._initial_state_ratio_in_cycle = []
        for state, duty in zip(initial_leg_state, duty_factor):
            if state == gait_generator.LegState.SWING:
                self._initial_state_ratio_in_cycle.append(1 - duty)
                self._next_leg_state.append(gait_generator.LegState.STANCE)
            else:
                self._initial_state_ratio_in_cycle.append(duty)
                self._next_leg_state.append(gait_generator.LegState.SWING)

        self._contact_detection_phase_threshold = contact_detection_phase_threshold

        # The normalized phase within swing or stance duration.
        self._normalized_phase = None
        self._leg_state = None
        self._desired_leg_state = None

        self.reset(0)

    def reset(self, current_time):
        # The normalized phase within swing or stance duration.
        self._normalized_phase = np.zeros(self._robot.num_legs)
        self._leg_state = list(self._initial_leg_state)
        self._desired_leg_state = list(self._initial_leg_state)

    @property
    def desired_leg_state(self) -> Sequence[gait_generator.LegState]:
        """The desired leg SWING/STANCE states.

    Returns:
      The SWING/STANCE states for all legs.

    """
        return self._desired_leg_state

    @property
    def leg_state(self) -> Sequence[gait_generator.LegState]:
        """The leg state after considering contact with ground.

    Returns:
      The actual state of each leg after accounting for contacts.
    """
        return self._leg_state

    @property
    def swing_duration(self) -> Sequence[float]:
        return self._swing_duration

    @property
    def stance_duration(self) -> Sequence[float]:
        return self._stance_duration

    @property
    def normalized_phase(self) -> Sequence[float]:
        """The phase within the current swing or stance cycle.

    Reflects the leg's phase within the curren swing or stance stage. For
    example, at the end of the current swing duration, the phase will
    be set to 1 for all swing legs. Same for stance legs.

    Returns:
      Normalized leg phase for all legs.

    """
        return self._normalized_phase

    def update(self, current_time):
        contact_state = self._robot.GetFootContacts()
        for leg_id in range(self._robot.num_legs):
            # Here is the explanation behind this logic: We use the phase within the
            # full swing/stance cycle to determine if a swing/stance switch occurs
            # for a leg. The threshold value is the "initial_state_ratio_in_cycle" as
            # explained before. If the current phase is less than the initial state
            # ratio, the leg is either in the initial state or has switched back after
            # one or more full cycles.
            full_cycle_period = (self._stance_duration[leg_id] /
                                 self._duty_factor[leg_id])
            # To account for the non-zero initial phase, we offset the time duration
            # with the effect time contribution from the initial leg phase.
            augmented_time = current_time + self._initial_leg_phase[
                leg_id] * full_cycle_period
            phase_in_full_cycle = math.fmod(augmented_time,
                                            full_cycle_period) / full_cycle_period
            ratio = self._initial_state_ratio_in_cycle[leg_id]
            if phase_in_full_cycle < ratio:
                self._desired_leg_state[leg_id] = self._initial_leg_state[leg_id]
                self._normalized_phase[leg_id] = phase_in_full_cycle / ratio
            else:
                # A phase switch happens for this leg.
                self._desired_leg_state[leg_id] = self._next_leg_state[leg_id]
                self._normalized_phase[leg_id] = (phase_in_full_cycle -
                                                  ratio) / (1 - ratio)

            self._leg_state[leg_id] = self._desired_leg_state[leg_id]

            # No contact detection at the beginning of each SWING/STANCE phase.
            if (self._normalized_phase[leg_id] <
                    self._contact_detection_phase_threshold):
                continue

            if (self._leg_state[leg_id] == gait_generator.LegState.SWING
                    and contact_state[leg_id]):
                logging.info("early touch down detected.")
                self._leg_state[leg_id] = gait_generator.LegState.EARLY_CONTACT
            if (self._leg_state[leg_id] == gait_generator.LegState.STANCE
                    and not contact_state[leg_id]):
                logging.info("lost contact detected.")
                self._leg_state[leg_id] = gait_generator.LegState.LOSE_CONTACT
