from collections import OrderedDict

import numpy as np
from gym import spaces

from . import VecEnv


class DummyVecEnv(VecEnv):
  """
  Creates a simple vectorized wrapper for multiple environments.

  The environments are instantiated in the current process and stepped one
  after another in a plain Python loop. Observations, rewards, done flags and
  info dicts are stored in per-instance buffers; copies of the array buffers
  are handed back to the caller.

  :param env_fns: ([callable]) list of zero-argument callables, each returning
      one environment to vectorize
  """

  def __init__(self, env_fns):
    self.envs = [fn() for fn in env_fns]
    env = self.envs[0]
    VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
    shapes, dtypes = {}, {}
    self.keys = []
    obs_space = env.observation_space

    # The first environment is kept as a reference instance.
    self.dummy_env = env
    if isinstance(obs_space, spaces.Dict):
      assert isinstance(obs_space.spaces, OrderedDict)
      subspaces = obs_space.spaces
      # Both hooks are optional: expose them on the wrapper only when the
      # environment actually provides them. ``getattr`` with a default avoids
      # an AttributeError for Dict-observation environments lacking the hook.
      compute_reward = getattr(env, 'compute_reward', None)
      if compute_reward is not None:
        self.compute_reward = compute_reward
      goal_extraction_function = getattr(env, 'goal_extraction_function', None)
      if goal_extraction_function is not None:
        self.goal_extraction_function = goal_extraction_function
    else:
      # Non-dict observation spaces are stored under the single key ``None``.
      subspaces = {None: obs_space}

    for key, box in subspaces.items():
      shapes[key] = box.shape
      dtypes[key] = box.dtype
      self.keys.append(key)

    # One buffer per observation key, with a leading axis of size num_envs.
    self.buf_obs = {k: np.zeros((self.num_envs, ) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys}
    # ``np.bool`` was removed from NumPy (1.24); the builtin ``bool`` yields
    # the same boolean dtype on every NumPy version.
    self.buf_dones = np.zeros((self.num_envs, ), dtype=bool)
    self.buf_rews = np.zeros((self.num_envs, ), dtype=np.float32)
    self.buf_infos = [{} for _ in range(self.num_envs)]
    self.actions = None

  def step_async(self, actions):
    """
    Store the actions that the next call to ``step_wait`` will apply.

    :param actions: (sequence) one action per environment, indexed by
        environment position
    """
    self.actions = actions

  def step_wait(self):
    """
    Step every environment with the actions stored by ``step_async``.

    An environment that reports done is reset immediately, so the observation
    stored for it is the first observation of the new episode.

    :return: (observations, rewards, dones, infos) where the observations,
        rewards and dones are copies of the internal buffers and infos is a
        shallow copy of the list of info dicts
    """
    for env_idx in range(self.num_envs):
      obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
          self.envs[env_idx].step(self.actions[env_idx])
      if self.buf_dones[env_idx]:
        obs = self.envs[env_idx].reset()
      self._save_obs(env_idx, obs)
    return (self._copy_obs(), np.copy(self.buf_rews), np.copy(self.buf_dones), self.buf_infos.copy())

  def reset(self):
    """
    Reset every environment and return the initial observations.

    :return: (np.ndarray or dict) copy of the observation buffer; a dict keyed
        by observation name when the observation space is a ``spaces.Dict``
    """
    for env_idx in range(self.num_envs):
      obs = self.envs[env_idx].reset()
      self._save_obs(env_idx, obs)
    return self._copy_obs()

  def close(self):
    """
    Do nothing; the wrapped environments are not closed here.

    NOTE(review): callers that need the underlying environments released
    must close them themselves -- confirm whether this is intentional.
    """
    return

  def get_images(self):
    """
    Render each environment with ``mode='rgb_array'``.

    :return: ([object]) one render result per environment
    """
    return [env.render(mode='rgb_array') for env in self.envs]

  def render(self, *args, **kwargs):
    """
    Render the vectorized environment.

    With a single environment the call is forwarded directly to it; otherwise
    rendering is delegated to the base class implementation.
    """
    if self.num_envs == 1:
      return self.envs[0].render(*args, **kwargs)
    return super().render(*args, **kwargs)

  def _save_obs(self, env_idx, obs):
    """
    Write one environment's observation into the observation buffers.

    :param env_idx: (int) index of the environment the observation came from
    :param obs: (object) the observation; indexed by key when the observation
        space is a dict, stored whole otherwise
    """
    for key in self.keys:
      if key is None:
        self.buf_obs[key][env_idx] = obs
      else:
        self.buf_obs[key][env_idx] = obs[key]

  def _obs_from_buf(self):
    """
    Return the observation buffer in the shape callers expect (no copy).

    :return: (np.ndarray or dict) the bare array for non-dict observation
        spaces, the dict of arrays otherwise
    """
    if self.keys == [None]:
      return self.buf_obs[None]
    return self.buf_obs

  def _copy_obs(self):
    """
    Return a copy of the observation buffer so callers cannot alias it.

    :return: (np.ndarray or dict) copied array, or a new dict of copied arrays
    """
    if self.keys == [None]:
      return np.copy(self._obs_from_buf())
    return {k: np.copy(v) for k, v in self._obs_from_buf().items()}

  def get_attr(self, attr_name, indices=None):
    """Return attribute from vectorized environment (see base class)."""
    target_envs = self._get_target_envs(indices)
    return [getattr(env_i, attr_name) for env_i in target_envs]

  def set_attr(self, attr_name, value, indices=None):
    """Set attribute inside vectorized environments (see base class)."""
    target_envs = self._get_target_envs(indices)
    for env_i in target_envs:
      setattr(env_i, attr_name, value)

  def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
    """Call instance methods of vectorized environments."""
    target_envs = self._get_target_envs(indices)
    return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]

  def _get_target_envs(self, indices):
    """
    Resolve ``indices`` to the matching environment objects.

    :param indices: passed to ``self._get_indices`` (defined outside this
        class, presumably on the base class) to obtain an iterable of integer
        positions
    :return: ([object]) the selected environments
    """
    indices = self._get_indices(indices)
    return [self.envs[i] for i in indices]