# coding=utf-8

"""Wrapper around a Gym environment to add curiosity reward."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

from episodic_curiosity import episodic_memory
from episodic_curiosity import oracle
from rlb.utils import RewardForwardFilter, RunningMeanStd, SimpleWeightedMovingScalarMeanStd
from third_party.baselines import logger
from third_party.baselines.common.vec_env import VecEnv
from third_party.baselines.common.vec_env import VecEnvWrapper
import gin
import gym
import numpy as np
import cv2

from episodic_curiosity.curiosity_env_wrapper import *

del CuriosityEnvWrapper


@gin.configurable
class RLBEnvWrapper(VecEnvWrapper):
  """Environment wrapper that adds an exploration (curiosity) bonus reward.

  On every step the observations of the wrapped vectorized environment are
  optionally resized, an exploration bonus is computed ('rlb', 'oracle' or
  'none'), optionally normalized and clipped (for 'rlb'), and combined with the
  task reward as
  `scale_task_reward * task_reward + scale_surrogate_reward * bonus`.
  Per-step diagnostics are written into each environment's `info` dict.
  """

  def __init__(self,
               vec_env,
               vec_episodic_memory,
               observation_embedding_fn,
               intrinsic_reward_fn,
               rlb_image_shape,
               target_image_shape,
               exploration_reward = 'rlb',
               scale_task_reward = 1.0,
               scale_surrogate_reward = None,
               exploration_reward_min_step = 0,
               ir_normalize_type=0,
               ir_clip_low=None,
               name='',
               ):
    """Creates the wrapper.

    Args:
      vec_env: the vectorized environment to wrap.
      vec_episodic_memory: sequence with one episodic memory per environment.
        It is required (and its length is checked) when
        exploration_reward == 'rlb'. Items must provide get_data(),
        add(embedding, info) and reset().
      observation_embedding_fn: callable mapping a batch of frames to a batch
        of embeddings.
      intrinsic_reward_fn: callable (memory_set, embedded_observations) ->
        per-env bonus rewards.
      rlb_image_shape: shape that image-like observations are resized to
        before observers are notified and the bonus is computed. None
        disables this resize.
      target_image_shape: shape of the image observations returned to the
        agent. Defaults to rlb_image_shape when None.
      exploration_reward: one of 'rlb', 'oracle' or 'none'.
      scale_task_reward: multiplier applied to the task reward.
      scale_surrogate_reward: multiplier applied to the bonus reward. It must
        be set to a number (e.g. through gin). With the default None,
        step_wait fails with a TypeError once the step count reaches
        exploration_reward_min_step.
      exploration_reward_min_step: number of wrapper steps during which the
        bonus multiplier is forced to 0.
      ir_normalize_type: normalization of the 'rlb' bonus. The options are:
        0 = none;
        1 = divide by the std of RunningMeanStd fed with
            RewardForwardFilter(gamma=0.99) outputs;
        2 = divide by the std of RunningMeanStd fed with the raw bonus;
        3 = divide by the std of SimpleWeightedMovingScalarMeanStd
            (alpha=0.0001) fed with the raw bonus.
      ir_clip_low: optional lower clip bound applied to the (normalized)
        'rlb' bonus.
      name: label used in the periodic progress print.

    Raises:
      ValueError: if exploration_reward == 'rlb' and the episodic memories are
        missing or do not match the number of envs, or if ir_normalize_type
        is not one of 0, 1, 2, 3.
    """
    # Must stay the first statement so that locals() only holds the args.
    logger.info('RLBEnvWrapper args: {}'.format(locals()))
    if exploration_reward == 'rlb':
      if (vec_episodic_memory is None or
          len(vec_episodic_memory) != vec_env.num_envs):
        raise ValueError('Each env must have a unique episodic memory.')

    if target_image_shape is None:
      target_image_shape = rlb_image_shape

    if self._should_process_observation(vec_env.observation_space.shape):
      observation_space_shape = target_image_shape[:]
      # Builtin `float` instead of the `np.float` alias, which was removed in
      # NumPy 1.24. Both name the same dtype, so the declared space is
      # unchanged.
      observation_space = gym.spaces.Box(
          low=0, high=255, shape=observation_space_shape, dtype=float)
    else:
      observation_space = vec_env.observation_space

    VecEnvWrapper.__init__(self, vec_env, observation_space=observation_space)

    self._vec_episodic_memory = vec_episodic_memory
    self._observation_embedding_fn = observation_embedding_fn
    self._intrinsic_reward_fn = intrinsic_reward_fn
    self._rlb_image_shape = rlb_image_shape
    self._target_image_shape = target_image_shape

    self._exploration_reward = exploration_reward
    self._scale_task_reward = scale_task_reward
    self._scale_surrogate_reward = scale_surrogate_reward
    self._exploration_reward_min_step = exploration_reward_min_step

    # Oracle reward: one tracker per environment.
    self._oracles = [oracle.OracleExplorationReward()
                     for _ in range(self.venv.num_envs)]

    self._ir_normalize_type = ir_normalize_type
    if self._ir_normalize_type == 0:
      pass
    elif self._ir_normalize_type == 1:
      ir_normalize_gamma = 0.99
      self._irff = RewardForwardFilter(ir_normalize_gamma)
      self._irff_rms = RunningMeanStd()
    elif self._ir_normalize_type == 2:
      self._ir_rms = RunningMeanStd()
    elif self._ir_normalize_type == 3:
      self._ir_rms = SimpleWeightedMovingScalarMeanStd(alpha=0.0001)
    else:
      # Explicit error rather than `assert`, which is stripped under -O.
      raise ValueError(
          'Unknown ir_normalize_type: {}'.format(ir_normalize_type))

    self._ir_clip_low = ir_clip_low

    self._name = name

    # Cumulative task / bonus reward over the current episode of each env.
    self._episode_task_reward = [0.0] * self.venv.num_envs
    self._episode_bonus_reward = [0.0] * self.venv.num_envs

    # Moving statistics of per-episode task and exploration reward.
    self._stats_task_reward = MovingAverage(capacity=100)
    self._stats_bonus_reward = MovingAverage(capacity=100)

    # Number of step_wait calls so far (one call steps every env once).
    self._step_count = 0

    # Observers are notified each time a new time step is generated by the
    # environment.
    self._observers = []

    # Actions passed to the latest step_async call; forwarded to observers.
    self._last_actions = None

    # Per-env bonus rewards of the current episode; emitted into `info` when
    # the episode ends.
    self._bonus_reward_raw_history = [[] for _ in range(self.venv.num_envs)]
    self._bonus_reward_history = [[] for _ in range(self.venv.num_envs)]

  def _should_process_observation(self, obs_shape):
    """Returns True for observations that look like images (rank >= 3)."""
    return len(obs_shape) >= 3

  def add_observer(self, observer):
    """Registers an observer notified on every step_wait.

    The observer may define on_new_observation(observations, rewards, dones,
    infos) and/or on_new_observation2(observations, rewards, dones, infos,
    actions). The latter returns a truthy value when it updated the model.
    """
    self._observers.append(observer)

  def _preprocess_observation(self, observations):
    """Resizes image-like observations to rlb_image_shape.

    Observations are returned unchanged when they are not image-like or when
    rlb_image_shape is None. `resize_observation` is not defined in this
    class; presumably it comes from the star import of curiosity_env_wrapper.
    """
    if (not self._should_process_observation(observations[0].shape)) or self._rlb_image_shape is None:
      return observations

    return np.array(
        [resize_observation(obs, self._rlb_image_shape, None)
         for obs in observations])

  def _postprocess_observation(self, observations):
    """Resizes image-like observations to target_image_shape.

    Observations are returned unchanged when they are not image-like or when
    target_image_shape is None.
    """
    if (not self._should_process_observation(observations[0].shape)) or self._target_image_shape is None:
      return observations

    return np.array(
        [resize_observation(obs, self._target_image_shape, None)
         for obs in observations])

  def _compute_rlb_rewards(self, observations, infos, dones):
    """Computes the 'rlb' bonus and updates each env's episodic memory.

    Embeddings are computed from `info['frame']` when the first env provides
    one, otherwise from the observations. The bonus is computed against the
    memories as they were *before* this step. Afterwards every env's current
    embedding is added to its memory. For envs that reported done, the memory
    is reset first: the observation returned with done is treated as the
    first state of the new episode.

    Returns:
      Per-env bonus rewards as returned by intrinsic_reward_fn.
    """
    if infos[0].get('frame') is not None:
      frames = np.array([info['frame'] for info in infos])
    else:
      frames = observations
    embedded_observations = self._observation_embedding_fn(frames)
    memory_set = [
        self._vec_episodic_memory[k].get_data()
        for k in range(self.venv.num_envs)
    ]

    bonus_rewards = self._intrinsic_reward_fn(memory_set, embedded_observations)

    # Updates the episodic memory of every environment.
    for k in range(self.venv.num_envs):
      if dones[k]:
        self._vec_episodic_memory[k].reset()
      self._vec_episodic_memory[k].add(embedded_observations[k], infos[k])
    return bonus_rewards

  def _compute_oracle_reward(self, infos, dones):
    """Computes the oracle bonus from `info['position']` of each env.

    Each env's oracle is updated first; oracles of envs that reported done
    are reset afterwards, so the done step still receives its bonus.
    """
    bonus_rewards = [
        self._oracles[k].update_position(infos[k]['position'])
        for k in range(self.venv.num_envs)]
    bonus_rewards = np.array(bonus_rewards)

    for k in range(self.venv.num_envs):
      if dones[k]:
        self._oracles[k].reset()

    return bonus_rewards

  def step_async(self, actions):
    """Records the actions (for observers) and forwards them to venv."""
    self._last_actions = actions
    VecEnvWrapper.step_async(self, actions)

  def step_wait(self):
    """Overrides VecEnvWrapper.step_wait.

    Writes 'task_reward', 'task_observation', 'bonus_reward_raw' and
    'bonus_reward' into every info dict. On episode end it also writes
    'bonus_reward_raw_history' and 'bonus_reward_history'.

    Returns:
      (postprocessed_observations, combined_rewards, dones, infos).
    """
    observations, rewards, dones, infos = self.venv.step_wait()
    observations = self._preprocess_observation(observations)

    # Notify observers. `updated` records whether any observer reported a
    # model update on this step.
    updated = False
    for observer in self._observers:
      if hasattr(observer, 'on_new_observation'):
        observer.on_new_observation(observations, rewards, dones, infos)
      if hasattr(observer, 'on_new_observation2'):
        with logger.ProfileKV('on_new_observation2'):
          updated = observer.on_new_observation2(observations, rewards, dones, infos, self._last_actions) or updated

    self._step_count += 1

    if (self._step_count % 1000) == 0:
      print('RLBEnvWrapper [{}]: step={} task_reward={} bonus_reward={} scale_bonus={}'.format(
          self._name,
          self._step_count,
          self._stats_task_reward.mean(),
          self._stats_bonus_reward.mean(),
          self._scale_surrogate_reward))

    for i in range(self.venv.num_envs):
      infos[i]['task_reward'] = rewards[i]
      infos[i]['task_observation'] = observations[i]

    # Exploration bonus.
    if self._exploration_reward == 'rlb':
      bonus_rewards_raw = self._compute_rlb_rewards(
          observations, infos, dones)
      bonus_rewards_raw = np.nan_to_num(bonus_rewards_raw)

      if self._ir_normalize_type == 1:
        irffs = self._irff.update(bonus_rewards_raw)
        self._irff_rms.update(irffs.ravel())
        bonus_rewards = bonus_rewards_raw / np.sqrt(self._irff_rms.var)
      elif self._ir_normalize_type in [2, 3]:
        self._ir_rms.update(bonus_rewards_raw.ravel())
        bonus_rewards = bonus_rewards_raw / np.sqrt(self._ir_rms.var)
      else:
        # Validated in __init__; only 0 (no normalization) can remain.
        assert self._ir_normalize_type == 0
        bonus_rewards = bonus_rewards_raw

      if self._ir_clip_low is not None:
        bonus_rewards = np.clip(bonus_rewards, a_min=self._ir_clip_low, a_max=None)

    elif self._exploration_reward == 'oracle':
      with logger.ProfileKV('ir_oracle'):
        bonus_rewards_raw = self._compute_oracle_reward(infos, dones)
      bonus_rewards = bonus_rewards_raw
    elif self._exploration_reward == 'none':
      bonus_rewards_raw = np.zeros(self.venv.num_envs)
      bonus_rewards = bonus_rewards_raw
    else:
      raise ValueError('Unknown exploration reward: {}'.format(
          self._exploration_reward))

    for i in range(self.venv.num_envs):
      infos[i]['bonus_reward_raw'] = bonus_rewards_raw[i]
      infos[i]['bonus_reward'] = bonus_rewards[i]

    # Combined rewards.
    scale_surrogate_reward = self._scale_surrogate_reward
    if self._step_count < self._exploration_reward_min_step:
      # This can be used for online training during the first N steps,
      # the R network is totally random and the surrogate reward has no
      # meaning.
      scale_surrogate_reward = 0.0
    postprocessed_rewards = (self._scale_task_reward * rewards +
                             scale_surrogate_reward * bonus_rewards)

    # Update the statistics.
    for i in range(self.venv.num_envs):
      self._episode_task_reward[i] += rewards[i]
      self._episode_bonus_reward[i] += bonus_rewards[i]
      self._bonus_reward_raw_history[i].append(bonus_rewards_raw[i])
      self._bonus_reward_history[i].append(bonus_rewards[i])
      if dones[i]:
        self._stats_task_reward.add(self._episode_task_reward[i])
        self._stats_bonus_reward.add(self._episode_bonus_reward[i])
        self._episode_task_reward[i] = 0.0
        self._episode_bonus_reward[i] = 0.0
        infos[i]['bonus_reward_raw_history'] = self._bonus_reward_raw_history[i]
        infos[i]['bonus_reward_history'] = self._bonus_reward_history[i]
        self._bonus_reward_raw_history[i] = []
        self._bonus_reward_history[i] = []
      else:
        # A model update is only allowed on steps where every env finished
        # its episode.
        assert not updated, 'Model updated in the middle of the current episode. _step_count: {}'.format(self._step_count)

    # Resize the observations returned to the agent (the bonus itself is
    # only added to the reward, not to the observation).
    postprocessed_observations = self._postprocess_observation(observations)
    assert postprocessed_observations.dtype == np.uint8

    return postprocessed_observations, postprocessed_rewards, dones, infos

  def get_episodic_memory(self, k):
    """Returns the episodic memory for the k-th environment."""
    return self._vec_episodic_memory[k]

  def reset(self):
    """Overrides VecEnvWrapper.reset.

    Resets all envs and clears every per-episode piece of state: episodic
    memories, oracles, reward accumulators and bonus histories.
    """
    observations = self.venv.reset()
    observations = self._preprocess_observation(observations)
    postprocessed_observations = self._postprocess_observation(observations)

    # Clears the episodic memory of every environment.
    if self._vec_episodic_memory is not None:
      for memory in self._vec_episodic_memory:
        memory.reset()

    # A reset starts fresh episodes everywhere. Drop partial-episode state so
    # it does not leak into the next episode's statistics or oracle reward.
    for i in range(self.venv.num_envs):
      self._oracles[i].reset()
      self._episode_task_reward[i] = 0.0
      self._episode_bonus_reward[i] = 0.0
      self._bonus_reward_raw_history[i] = []
      self._bonus_reward_history[i] = []

    return postprocessed_observations

