# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import inspect
import gin

# from mpi4py import MPI

currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(os.path.dirname(currentdir))
os.sys.path.insert(0, parentdir)

from a2perf.domains.quadruped_locomotion.motion_imitation.envs import \
  locomotion_gym_env
from a2perf.domains.quadruped_locomotion.motion_imitation.envs import \
  locomotion_gym_config
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  imitation_wrapper_env
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  observation_dictionary_to_array_wrapper
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  observation_dictionary_to_array_wrapper as obs_dict_to_array_wrapper
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  trajectory_generator_wrapper_env
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  simple_openloop
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  simple_forward_task
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  imitation_task
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.env_wrappers import \
  default_task

from a2perf.domains.quadruped_locomotion.motion_imitation.envs.sensors import \
  environment_sensors
from a2perf.domains.quadruped_locomotion.motion_imitation.envs.sensors import \
  sensor_wrappers
from motion_imitation.envs.sensors import robot_sensors
from motion_imitation.envs.utilities import \
  controllable_env_randomizer_from_config
from motion_imitation.robots import laikago
from motion_imitation.robots import a1
from motion_imitation.robots import robot_config

ENABLE_ENV_RANDOMIZER = True


def build_laikago_env(motor_control_mode, enable_rendering):
  sim_params = locomotion_gym_config.SimulationParameters()
  sim_params.enable_rendering = enable_rendering
  sim_params.motor_control_mode = motor_control_mode
  sim_params.reset_time = 2
  sim_params.num_action_repeat = 10
  sim_params.enable_action_interpolation = False
  sim_params.enable_action_filter = False
  sim_params.enable_clip_motor_commands = False

  gym_config = locomotion_gym_config.LocomotionGymConfig(
      simulation_parameters=sim_params)

  robot_class = laikago.Laikago

  sensors = [
      robot_sensors.MotorAngleSensor(num_motors=laikago.NUM_MOTORS),
      robot_sensors.IMUSensor(),
      environment_sensors.LastActionSensor(num_actions=laikago.NUM_MOTORS)
  ]

  task = default_task.DefaultTask()

  env = locomotion_gym_env.LocomotionGymEnv(gym_config=gym_config,
                                            robot_class=robot_class,
                                            robot_sensors=sensors, task=task)

  # env = observation_dictionary_to_array_wrapper.ObservationDictionaryToArrayWrapper(env)
  # env = trajectory_generator_wrapper_env.TrajectoryGeneratorWrapperEnv(env,
  #                                                                     trajectory_generator=simple_openloop.LaikagoPoseOffsetGenerator(action_limit=laikago.UPPER_BOUND))

  return env


@gin.configurable
def build_imitation_env(motion_files, mode, enable_rendering,
    num_parallel_envs,
    enable_randomizer=None,
    robot_class=laikago.Laikago,
    trajectory_generator=simple_openloop.LaikagoPoseOffsetGenerator(
        action_limit=laikago.UPPER_BOUND)):
  assert len(motion_files) > 0

  if enable_randomizer is None:
    enable_randomizer = ENABLE_ENV_RANDOMIZER and (mode != "test")

  curriculum_episode_length_start = 20
  curriculum_episode_length_end = 600

  sim_params = locomotion_gym_config.SimulationParameters()
  sim_params.enable_rendering = enable_rendering
  sim_params.allow_knee_contact = True
  sim_params.motor_control_mode = robot_config.MotorControlMode.POSITION

  gym_config = locomotion_gym_config.LocomotionGymConfig(
      simulation_parameters=sim_params)

  sensors = [
      sensor_wrappers.HistoricSensorWrapper(
          wrapped_sensor=robot_sensors.MotorAngleSensor(
              num_motors=laikago.NUM_MOTORS), num_history=3),
      sensor_wrappers.HistoricSensorWrapper(
          wrapped_sensor=robot_sensors.IMUSensor(), num_history=3),
      sensor_wrappers.HistoricSensorWrapper(
          wrapped_sensor=environment_sensors.LastActionSensor(
              num_actions=laikago.NUM_MOTORS), num_history=3)
  ]

  task = imitation_task.ImitationTask(ref_motion_filenames=motion_files,
                                      enable_cycle_sync=True,
                                      tar_frame_steps=[1, 2, 10, 30],
                                      ref_state_init_prob=0.9,
                                      warmup_time=0.25)

  randomizers = []
  if enable_randomizer:
    randomizer = controllable_env_randomizer_from_config.ControllableEnvRandomizerFromConfig(
        verbose=False)
    randomizers.append(randomizer)

  env = locomotion_gym_env.LocomotionGymEnv(gym_config=gym_config,
                                            robot_class=robot_class,
                                            env_randomizers=randomizers,
                                            robot_sensors=sensors, task=task)

  env = observation_dictionary_to_array_wrapper.ObservationDictionaryToArrayWrapper(
      env)
  env = trajectory_generator_wrapper_env.TrajectoryGeneratorWrapperEnv(env,
                                                                       trajectory_generator=trajectory_generator)

  if mode == "test":
    curriculum_episode_length_start = curriculum_episode_length_end

  env = imitation_wrapper_env.ImitationWrapperEnv(env,
                                                  episode_length_start=curriculum_episode_length_start,
                                                  episode_length_end=curriculum_episode_length_end,
                                                  curriculum_steps=30000000,
                                                  num_parallel_envs=num_parallel_envs)
  return env


def build_regular_env(robot_class,
    motor_control_mode,
    enable_rendering=False,
    on_rack=False,
    action_limit=(0.75, 0.75, 0.75),
    wrap_trajectory_generator=True):
  sim_params = locomotion_gym_config.SimulationParameters()
  sim_params.enable_rendering = enable_rendering
  sim_params.motor_control_mode = motor_control_mode
  sim_params.reset_time = 2
  sim_params.num_action_repeat = 10
  sim_params.enable_action_interpolation = False
  sim_params.enable_action_filter = False
  sim_params.enable_clip_motor_commands = False
  sim_params.robot_on_rack = on_rack

  gym_config = locomotion_gym_config.LocomotionGymConfig(
      simulation_parameters=sim_params)

  sensors = [
      robot_sensors.BaseDisplacementSensor(),
      robot_sensors.IMUSensor(),
      robot_sensors.MotorAngleSensor(num_motors=a1.NUM_MOTORS),
  ]

  task = simple_forward_task.SimpleForwardTask()

  env = locomotion_gym_env.LocomotionGymEnv(gym_config=gym_config,
                                            robot_class=robot_class,
                                            robot_sensors=sensors,
                                            task=task)

  env = obs_dict_to_array_wrapper.ObservationDictionaryToArrayWrapper(
      env)
  if (motor_control_mode
      == robot_config.MotorControlMode.POSITION) and wrap_trajectory_generator:
    if robot_class == laikago.Laikago:
      env = trajectory_generator_wrapper_env.TrajectoryGeneratorWrapperEnv(
          env,
          trajectory_generator=simple_openloop.LaikagoPoseOffsetGenerator(
              action_limit=action_limit))
    elif robot_class == a1.A1:
      env = trajectory_generator_wrapper_env.TrajectoryGeneratorWrapperEnv(
          env,
          trajectory_generator=simple_openloop.LaikagoPoseOffsetGenerator(
              action_limit=action_limit))
  return env
