# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions to manipulate environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
from gymnasium import spaces

from a2perf.domains.quadruped_locomotion.motion_imitation.envs import gym_spaces


def flatten_observations(observation_dict, observation_excluded=()):
  """Flattens the observation dictionary to an array.

  If observation_excluded is passed in, it will still return a dictionary,
  which includes all the (key, observation_dict[key]) in observation_excluded,
  and ('other': the flattened array).

  Args:
    observation_dict: A dictionary of all the observations.
    observation_excluded: A list/tuple of all the keys of the observations to be
      ignored during flattening.

  Returns:
    An array or a dictionary of observations based on whether
      observation_excluded is empty.
  """
  if not isinstance(observation_excluded, (list, tuple)):
    observation_excluded = [observation_excluded]
  observations = []
  for key, value in observation_dict.items():
    if key not in observation_excluded:
      observations.append(np.asarray(value).flatten())
  flat_observations = np.concatenate(observations)
  if not observation_excluded:
    return flat_observations
  else:
    observation_dict_after_flatten = {"other": flat_observations}
    for key in observation_excluded:
      observation_dict_after_flatten[key] = observation_dict[key]
    return collections.OrderedDict(
        sorted(list(observation_dict_after_flatten.items())))


def flatten_observation_spaces(observation_spaces, observation_excluded=()):
  """Flattens the dictionary observation spaces to gym.spaces.Box.

  If observation_excluded is passed in, it will still return a dictionary,
  which includes all the (key, observation_spaces[key]) in observation_excluded,
  and ('other': the flattened Box space).

  Args:
    observation_spaces: A dictionary of all the observation spaces.
    observation_excluded: A list/tuple of all the keys of the observations to be
      ignored during flattening.

  Returns:
    A box space or a dictionary of observation spaces based on whether
      observation_excluded is empty.
  """
  if not isinstance(observation_excluded, (list, tuple)):
    observation_excluded = [observation_excluded]
  lower_bound = []
  upper_bound = []
  for key, value in observation_spaces.spaces.items():
    if key not in observation_excluded:
      lower_bound.append(np.asarray(value.low).flatten())
      upper_bound.append(np.asarray(value.high).flatten())
  lower_bound = np.concatenate(lower_bound)
  upper_bound = np.concatenate(upper_bound)
  observation_space = gym_spaces.Box(
      np.array(lower_bound), np.array(upper_bound), dtype=np.float32)
  if not observation_excluded:
    return observation_space
  else:
    observation_spaces_after_flatten = {"other": observation_space}
    for key in observation_excluded:
      observation_spaces_after_flatten[key] = observation_spaces[key]
    return spaces.Dict(observation_spaces_after_flatten)
