# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A dm_env.Environment subclass for control-specific environments."""

import abc
import collections
import contextlib
import dm_env
from dm_env import specs
import numpy as np

FLAT_OBSERVATION_KEY = 'observations'


class Environment(dm_env.Environment):
  """Reinforcement learning environment built from a `Physics` and a `Task`.

  The environment advances the physics by a fixed number of sub-steps for each
  agent action and delegates rewards, observations, termination and specs to
  the task.
  """

  def __init__(self,
               physics,
               task,
               time_limit=float('inf'),
               control_timestep=None,
               n_sub_steps=None,
               flat_observation=False):
    """Builds the environment.

    Args:
      physics: Instance of `Physics`.
      task: Instance of `Task`.
      time_limit: Optional maximum duration of an episode, in seconds. The
        default is infinity, i.e. episodes are never cut short by time.
      control_timestep: Optional interval between agent actions, in seconds.
      n_sub_steps: Optional number of physics steps taken per agent action
        ("action repeats"). Mutually exclusive with `control_timestep`; when
        neither is given a single physics step is taken per action.
      flat_observation: If True, observations are flattened and concatenated
        into a single numpy array.

    Raises:
      ValueError: If both `n_sub_steps` and `control_timestep` are supplied.
    """
    self._task = task
    self._physics = physics
    self._flat_observation = flat_observation

    sub_steps_given = n_sub_steps is not None
    control_given = control_timestep is not None
    if sub_steps_given and control_given:
      raise ValueError('Both n_sub_steps and control_timestep were supplied.')

    if sub_steps_given:
      self._n_sub_steps = n_sub_steps
    elif control_given:
      self._n_sub_steps = compute_n_steps(control_timestep,
                                          self._physics.timestep())
    else:
      self._n_sub_steps = 1

    # The limit is expressed in control steps; the physics timestep is only
    # queried when a finite time limit was requested.
    self._step_limit = (
        float('inf') if time_limit == float('inf') else
        time_limit / (self._physics.timestep() * self._n_sub_steps))

    # Forces the first call to `step` to start an episode via `reset`.
    self._reset_next_step = True
    self._step_count = 0

  def reset(self):
    """Begins a new episode.

    Returns:
      A `TimeStep` of type `FIRST` whose reward and discount are None.
    """
    self._step_count = 0
    self._reset_next_step = False

    # The task sets up the episode while the physics is being reset.
    with self._physics.reset_context():
      self._task.initialize_episode(self._physics)

    first_observation = self._task.get_observation(self._physics)
    if self._flat_observation:
      first_observation = flatten_observation(first_observation)

    return dm_env.TimeStep(
        dm_env.StepType.FIRST, None, None, first_observation)

  def step(self, action):
    """Applies `action`, advances the physics and returns a `TimeStep`.

    If the previous step ended the episode, or no episode has been started
    yet, `action` is ignored and the result of `reset()` is returned instead.

    Args:
      action: Action forwarded to the task's `before_step`.

    Returns:
      A `TimeStep` of type `LAST` when the step limit is reached (with a
      discount of 1.0) or when the task reports a termination discount;
      otherwise a `TimeStep` of type `MID` with a discount of 1.0.
    """
    if self._reset_next_step:
      return self.reset()

    self._task.before_step(action, self._physics)
    for _ in range(self._n_sub_steps):
      self._physics.step()
    self._task.after_step(self._physics)

    reward = self._task.get_reward(self._physics)
    observation = self._task.get_observation(self._physics)
    if self._flat_observation:
      observation = flatten_observation(observation)

    self._step_count += 1
    timed_out = self._step_count >= self._step_limit
    # The task is only asked about termination if the time limit has not
    # already ended the episode.
    discount = 1.0 if timed_out else self._task.get_termination(self._physics)

    if discount is None:
      # No termination signalled: the episode carries on.
      return dm_env.TimeStep(dm_env.StepType.MID, reward, 1.0, observation)

    self._reset_next_step = True
    return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, observation)

  def action_spec(self):
    """Returns the task's action specification for the current physics."""
    return self._task.action_spec(self._physics)

  def step_spec(self):
    """Returns the task's specification of the values returned by `step`."""
    return self._task.step_spec(self._physics)

  def observation_spec(self):
    """Returns the observation specification for this environment.

    The task's own `observation_spec` is used when it provides one. If the task
    raises `NotImplementedError`, the spec is inferred from a freshly queried
    observation (flattened first when `flat_observation` was requested).

    Returns:
      Either the value returned by the task's `observation_spec`, or a dict
      mapping observation names to `Array` specs carrying shape and dtype.
    """
    try:
      return self._task.observation_spec(self._physics)
    except NotImplementedError:
      sample = self._task.get_observation(self._physics)
      if self._flat_observation:
        sample = flatten_observation(sample)
      return _spec_from_observation(sample)

  @property
  def physics(self):
    """The `Physics` instance driven by this environment."""
    return self._physics

  @property
  def task(self):
    """The `Task` instance used by this environment."""
    return self._task

  def control_timestep(self):
    """Returns the interval between agent actions in seconds."""
    physics_timestep = self.physics.timestep()
    return physics_timestep * self._n_sub_steps


def compute_n_steps(control_timestep, physics_timestep, tolerance=1e-8):
  """Converts a control timestep into a whole number of physics timesteps.

  Args:
    control_timestep: Control time-step; expected to be an integer multiple of
      `physics_timestep`.
    physics_timestep: The time-step of the physics simulation.
    tolerance: Optional maximum allowed distance between the ratio of the two
      timesteps and its nearest integer.

  Returns:
    The number of physics timesteps in a single control timestep, as an `int`.

  Raises:
    ValueError: If `control_timestep` is smaller than `physics_timestep` or if
      `control_timestep` is not an integer multiple of `physics_timestep`.
  """
  if control_timestep < physics_timestep:
    too_small = (
        'Control timestep ({}) cannot be smaller than physics timestep ({}).'.
        format(control_timestep, physics_timestep))
    raise ValueError(too_small)

  # Compute the ratio once and compare it with its nearest integer.
  ratio = control_timestep / physics_timestep
  nearest = round(ratio)
  if abs(ratio - nearest) > tolerance:
    not_multiple = (
        'Control timestep ({}) must be an integer multiple of physics timestep '
        '({})'.format(control_timestep, physics_timestep))
    raise ValueError(not_multiple)
  return int(nearest)


def _spec_from_observation(observation):
  """Infers an observation spec from a sample observation.

  Args:
    observation: Mapping from observation names to values that expose `shape`
      and `dtype` attributes.

  Returns:
    A `collections.OrderedDict` with one `specs.Array` per entry, named after
    its key and following the iteration order of `observation`.
  """
  return collections.OrderedDict(
      (name, specs.Array(array.shape, array.dtype, name=name))
      for name, array in observation.items())

# Base class definitions for objects supplied to Environment.


class Physics(metaclass=abc.ABCMeta):
  """Abstract interface to a physics simulation.

  Subclasses must implement `step`, `time`, `timestep`, `reset` and
  `after_reset`; the remaining methods have default implementations.
  """

  # ---- Abstract interface -------------------------------------------------

  @abc.abstractmethod
  def step(self, n_sub_steps=1):
    """Advances the simulation state.

    Args:
      n_sub_steps: Optional number of consecutive updates to apply. Defaults
        to 1.
    """

  @abc.abstractmethod
  def time(self):
    """Returns the elapsed simulation time in seconds."""

  @abc.abstractmethod
  def timestep(self):
    """Returns the simulation timestep."""

  @abc.abstractmethod
  def reset(self):
    """Resets internal variables of the physics simulation."""

  @abc.abstractmethod
  def after_reset(self):
    """Runs after resetting internal variables of the physics simulation."""

  # ---- Defaults that subclasses may override ------------------------------

  def set_control(self, control):
    """Sets the control signal for the actuators.

    The default implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError('set_control is not supported.')

  def check_divergence(self):
    """Raises a `PhysicsError` if the simulation state is divergent.

    The default implementation is a no-op.
    """
    return None

  @contextlib.contextmanager
  def reset_context(self):
    """Context manager for resetting the simulation state.

    `reset()` is called on entry and `after_reset()` once the block finishes:

    ```python
    with physics.reset_context():
      # Set joint and object positions.
    physics.step()
    ```

    A `PhysicsError` raised by `reset()` is suppressed, so the block still
    runs. `after_reset()` is skipped if the block itself raises.

    Yields:
      The `Physics` instance.
    """
    try:
      self.reset()
    except PhysicsError:
      # Suppressed so that the caller's block still gets to run.
      pass
    yield self
    self.after_reset()


class PhysicsError(RuntimeError):
  """Raised if the state of the physics simulation becomes divergent.

  `Physics.check_divergence` is documented to raise this error, and
  `Physics.reset_context` suppresses it when it is raised by `Physics.reset`.
  """


class Task(metaclass=abc.ABCMeta):
  """Abstract definition of a task in a `control.Environment`.

  Subclasses must implement `initialize_episode`, `before_step`,
  `action_spec`, `get_observation` and `get_reward`. The remaining methods are
  optional hooks with default implementations.
  """

  # ---- Required methods ---------------------------------------------------

  @abc.abstractmethod
  def initialize_episode(self, physics):
    """Sets the state of the environment at the start of each episode.

    Called by `control.Environment` at the start of each episode *within*
    `physics.reset_context()` (see the documentation for `base.Physics`).

    Args:
      physics: Instance of `Physics`.
    """

  @abc.abstractmethod
  def before_step(self, action, physics):
    """Updates the task from the provided action.

    Called by `control.Environment` before stepping the physics engine.

    Args:
      action: numpy array or array-like action values, or a nested structure of
        such arrays. Should conform to the specification returned by
        `self.action_spec(physics)`.
      physics: Instance of `Physics`.
    """

  @abc.abstractmethod
  def action_spec(self, physics):
    """Returns a specification describing the valid actions for this task.

    Args:
      physics: Instance of `Physics`.

    Returns:
      A `BoundedArraySpec`, or a nested structure containing `BoundedArraySpec`s
      that describe the shapes, dtypes and elementwise lower and upper bounds
      for the action array(s) passed to `self.step`.
    """

  @abc.abstractmethod
  def get_observation(self, physics):
    """Returns an observation from the environment.

    Args:
      physics: Instance of `Physics`.
    """

  @abc.abstractmethod
  def get_reward(self, physics):
    """Returns a reward from the environment.

    Args:
      physics: Instance of `Physics`.
    """

  # ---- Optional hooks -----------------------------------------------------

  def after_step(self, physics):
    """Optional hook to update the task after the physics engine has stepped.

    Called by `control.Environment` after stepping the physics engine and before
    `control.Environment` calls `get_reward`, `get_observation` and
    `get_termination`.

    The default implementation is a no-op.

    Args:
      physics: Instance of `Physics`.
    """
    return None

  def get_termination(self, physics):
    """If the episode should end, returns a final discount, otherwise None.

    The default implementation always returns None.
    """
    return None

  def step_spec(self, physics):
    """Returns a specification describing the time_step for this task.

    The default implementation raises `NotImplementedError`.

    Args:
      physics: Instance of `Physics`.

    Returns:
      A `BoundedArraySpec`, or a nested structure containing `BoundedArraySpec`s
      that describe the shapes, dtypes and elementwise lower and upper bounds
      for the array(s) returned by `self.step`.
    """
    raise NotImplementedError

  def observation_spec(self, physics):
    """Optional method that returns the observation spec.

    The default implementation raises `NotImplementedError`, in which case the
    Environment infers the spec from the observation.

    Args:
      physics: Instance of `Physics`.

    Returns:
      A dict mapping observation name to `ArraySpec` containing observation
      shape and dtype.
    """
    raise NotImplementedError


def flatten_observation(observation, output_key=FLAT_OBSERVATION_KEY):
  """Flattens multiple observation values into a single numpy array.

  Each value is converted to a numpy array, flattened to one dimension and
  concatenated with the others. Values may therefore be numpy arrays of any
  shape, Python or numpy scalars, or array-like sequences.

  Args:
    observation: A mutable mapping from observation names to numpy arrays,
      scalars or array-likes.
    output_key: The key for the flattened observation array in the output.

  Returns:
    A mutable mapping of the same type as `observation`. This will contain a
    single key-value pair consisting of `output_key` and the flattened
    and concatenated observation array.

  Raises:
    ValueError: If `observation` is not a `collections.abc.MutableMapping`.
  """
  if not isinstance(observation, collections.abc.MutableMapping):
    raise ValueError('Can only flatten dict-like observations.')

  if isinstance(observation, collections.OrderedDict):
    # An OrderedDict already has a well-defined ordering; preserve it.
    keys = observation.keys()
  else:
    # Keep a consistent ordering for other mappings.
    keys = sorted(observation.keys())

  # `np.asarray` passes ndarrays through untouched and wraps scalars and
  # sequences, which have no usable `ravel()` of their own.
  observation_arrays = [np.asarray(observation[key]).ravel() for key in keys]
  return type(observation)([(output_key, np.concatenate(observation_arrays))])