# coding=utf-8
# Copyright 2020 The Real-World RL Suite Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class to accumulate statistics during runs."""
import collections
import copy

import numpy as np


class StatisticsAccumulator(object):
  """Accumulate the statistics of an environment's real-world variables.

  Timesteps of the current episode are collected in an internal buffer via
  `push`. When an episode is accumulated, per-episode statistics are derived
  from that buffer and stored in `stat_buffers`, which can then be converted
  with `to_ndarray_dict`, written to disk and used by the Evaluators class.

  Pushed timesteps are expected to expose `first()`, `last()`, `reward` and an
  `observation` mapping; which observation keys are read depends on the flags
  given to the constructor ('constraints', 'safety_vars', 'multiobj').
  """

  def __init__(self, acc_safety, acc_safety_vars, acc_multiobj, auto_acc=True):
    """A class to easily accumulate necessary statistics for evaluation.

    Args:
      acc_safety: whether we should accumulate safety statistics.
      acc_safety_vars: whether we should accumulate state variables specific to
        safety.
      acc_multiobj: whether we should accumulate multi-objective statistics.
      auto_acc: whether to automatically accumulate (and then clear the episode
        buffer) when 'LAST' timesteps are pushed. When False, the caller is
        responsible for calling `accumulate` and `clear_buffer` at episode
        boundaries.
    """
    self._acc_safety = acc_safety
    self._acc_safety_vars = acc_safety_vars
    self._acc_multiobj = acc_multiobj
    self._auto_acc = auto_acc
    self._buffer = []  # Buffer of timesteps of current episode
    self._stat_buffers = dict()

  def push(self, timestep):
    """Pushes a new timestep onto the current episode's buffer.

    A deep copy of the timestep is stored, so later mutation of the caller's
    object does not affect the accumulated statistics.

    Args:
      timestep: the timestep to record; must provide a `last()` method.
    """
    local_ts = copy.deepcopy(timestep)
    self._buffer.append(local_ts)
    # Only accumulate automatically when requested at construction time.
    if self._auto_acc and local_ts.last():
      self.accumulate()
      self.clear_buffer()

  def clear_buffer(self):
    """Clears the buffer of timesteps."""
    self._buffer = []

  def accumulate(self):
    """Accumulates statistics for the current buffer into the stats buffer.

    Does nothing when the episode buffer is empty, so that no spurious
    zero-length episode is recorded.
    """
    if not self._buffer:
      return
    if self._acc_safety:
      self._acc_safety_stats()
    if self._acc_safety_vars:
      self._acc_safety_vars_stats()
    if self._acc_multiobj:
      self._acc_multiobj_stats()
    self._acc_return_stats()

  def _acc_safety_stats(self):
    """Generates safety-related statistics.

    Reads `observation['constraints']` from every buffered timestep; an entry
    that evaluates to False is counted as a violation.
    """
    constraint_array = np.array(
        [ts.observation['constraints'] for ts in self._buffer])
    # logical_not (rather than bitwise ~) also handles 0/1 integer constraints.
    violations = np.logical_not(constraint_array)
    if 'safety_stats' not in self._stat_buffers:
      self._stat_buffers['safety_stats'] = dict(
          total_violations=[],
          per_step_violations=np.zeros(constraint_array.shape))
    safety_stats = self._stat_buffers['safety_stats']
    # Accumulate the number of violations at each timestep in the episode.
    # NOTE: this requires every episode to have the same shape as the first
    # one accumulated; a mismatch raises before anything else is recorded.
    safety_stats['per_step_violations'] += violations
    # Accumulate the total number of violations of each constraint this episode
    safety_stats['total_violations'].append(np.sum(violations, axis=0))

  def _acc_safety_vars_stats(self):
    """Generates state-variable statistics to tune the safety constraints.

    This appends one ordered dict per episode to the 'safety_vars_stats' list.
    Each dict maps a safety variable name to its mean, std_dev, min and max
    over the timesteps of the episode. The variable names are taken from the
    first buffered timestep.
    """
    ep_stats = collections.OrderedDict()
    for key in self._buffer[0].observation['safety_vars']:
      values = np.array(
          [ts.observation['safety_vars'][key] for ts in self._buffer])
      ep_stats[key] = dict(
          mean=np.mean(values, axis=0),
          std_dev=np.std(values, axis=0),
          min=np.min(values, axis=0),
          max=np.max(values, axis=0))

    safety_vars_buffer = self._stat_buffers.get('safety_vars_stats', [])
    safety_vars_buffer.append(ep_stats)  # pytype: disable=attribute-error
    self._stat_buffers['safety_vars_stats'] = safety_vars_buffer

  def _acc_multiobj_stats(self):
    """Generates multiobj-related statistics.

    Sums `observation['multiobj']` over the timesteps of the episode and
    appends the result to the 'episode_totals' list of 'multiobj_stats'.
    """
    multiobj_array = np.array(
        [ts.observation['multiobj'] for ts in self._buffer])
    # Per-objective total over the whole episode.
    episode_totals = np.sum(multiobj_array, axis=0)
    multiobj_stats = self._stat_buffers.get('multiobj_stats',
                                            dict(episode_totals=[]))
    multiobj_stats['episode_totals'].append(episode_totals)
    self._stat_buffers['multiobj_stats'] = multiobj_stats

  def _acc_return_stats(self):
    """Generates per-episode return statistics.

    Sums the rewards of all buffered timesteps except those for which
    `first()` is true, and appends the sum to the 'episode_totals' list of
    'return_stats'.
    """
    # Skip the first ts as it has a reward of None
    rewards = [ts.reward for ts in self._buffer if not ts.first()]
    # Undiscounted return of the episode.
    episode_return = np.sum(np.array(rewards))
    return_stats = self._stat_buffers.get('return_stats',
                                          dict(episode_totals=[]))
    return_stats['episode_totals'].append(episode_return)
    self._stat_buffers['return_stats'] = return_stats

  def to_ndarray_dict(self):
    """Convert stats buffer to ndarrays to make disk writing more efficient.

    Returns:
      A deep copy of the stats buffer in which the per-episode lists of
      'safety_stats', 'multiobj_stats' and 'return_stats' are converted to
      ndarrays, and 'per_step_violations' is divided by the number of
      accumulated episodes. 'safety_vars_stats', if present, is left as a
      list of per-episode dicts.
    """
    buffers = copy.deepcopy(self.stat_buffers)
    if 'safety_stats' in buffers:
      safety_stats = buffers['safety_stats']
      safety_stats['total_violations'] = np.array(
          safety_stats['total_violations'])
      # One row of total_violations is stored per accumulated episode.
      n_episodes = safety_stats['total_violations'].shape[0]
      safety_stats['per_step_violations'] = np.array(
          safety_stats['per_step_violations']) / n_episodes
    if 'multiobj_stats' in buffers:
      buffers['multiobj_stats']['episode_totals'] = np.array(
          buffers['multiobj_stats']['episode_totals'])
    if 'return_stats' in buffers:
      buffers['return_stats']['episode_totals'] = np.array(
          buffers['return_stats']['episode_totals'])
    return buffers

  @property
  def stat_buffers(self):
    """The dict of accumulated statistics (the internal object, not a copy)."""
    return self._stat_buffers
