import functools
import time
from collections import defaultdict, deque
import heapq

import elements
import numpy as np

from embodied.delays import state_independent

class _MissingEnvAttributeError(AttributeError, ValueError):
  """Raised by `Wrapper.__getattr__` when no wrapped env defines a name.

  It is an `AttributeError`, so `hasattr()` and `getattr(obj, name, default)`
  behave normally on wrappers. It is also a `ValueError`, so callers that catch
  the `ValueError` this lookup historically raised keep working.
  """


class Wrapper:
  """Base env wrapper that forwards unknown attribute lookups to `self.env`."""

  def __init__(self, env):
    self.env = env

  def __len__(self):
    return len(self.env)

  def __bool__(self):
    return bool(self.env)

  def __getattr__(self, name):
    # Only reached when normal attribute lookup fails. Dunder names are not
    # forwarded, so protocols such as copy/pickle don't reach into the inner
    # env. `env` itself is not forwarded either: on a half-initialised wrapper
    # (e.g. one created via `__new__`), looking it up would otherwise call
    # `__getattr__('env')` again and recurse forever.
    if name.startswith('__') or name == 'env':
      raise AttributeError(name)
    try:
      return getattr(self.env, name)
    except AttributeError:
      raise _MissingEnvAttributeError(name) from None

  def __hasattr__(self, name):
    # NOTE: Python never calls `__hasattr__` implicitly; this only runs when
    # called explicitly. `hasattr(wrapper, name)` goes through `__getattr__`.
    return hasattr(self.env, name)

class ObsInBuffer:
  """One observation plus the bookkeeping needed to delay it.

  Instances order by arrival: the smaller `receive_time` comes first, and ties
  are broken by the smaller `env_time`. This makes them usable as `heapq`
  entries.

  Args:
    env_time: Env step at which the observation was produced.
    receive_time: Env step from which the observation counts as arrived.
    time_step: The observation dict. If it already carries both
      'log/obs_env_time' and 'log/obs_receive_time', they must agree with
      `env_time` and `receive_time` (checked with `assert`).
  """

  def __init__(self,
               env_time: int,
               receive_time: int,
               time_step: dict):
    has_stamps = ('log/obs_env_time' in time_step
                  and 'log/obs_receive_time' in time_step)
    if has_stamps:
      assert env_time == time_step['log/obs_env_time'], "Something is wrong."
      assert receive_time == time_step['log/obs_receive_time'], "Something is wrong."
    self._env_time: int = env_time
    self._receive_time: int = receive_time
    self._time_step: dict = time_step

  @property
  def env_time(self) -> int:
    """Env step at which the observation was produced."""
    return self._env_time

  @property
  def receive_time(self) -> int:
    """Env step from which the observation counts as arrived."""
    return self._receive_time

  @property
  def time_step(self) -> dict:
    """The wrapped observation dict."""
    return self._time_step

  def __lt__(self, other):
    # Lexicographic: arrival time first, then production time.
    mine = (self._receive_time, self._env_time)
    theirs = (other.receive_time, other.env_time)
    return mine < theirs

class DelayedEnv(Wrapper):
  """Env wrapper that releases observations after a sampled delay.

  Each call to `act` steps the wrapped env and samples a delay from the
  `state_independent.<dist>` generator. It then queues the observation with
  `receive_time = env_time + delay` in a heap ordered by `(receive_time,
  env_time)` (see `ObsInBuffer`). A queued observation becomes available once
  its receive time is not later than the current env time. It is retrieved
  with `get_next_available_obs`.

  Args:
    env: The env to wrap.
    dist: Name of a generator class in `embodied.delays.state_independent`.
    maximum_delay: Stored and exposed via `maximum_delay`. NOTE(review): it is
      not enforced here; presumably the generator keeps delays within it.
    obs_window_size: Default window length for `get_observed_history`.
    **dist_kwargs: Keyword arguments for the generator constructor.
  """

  def __init__(self, env,
               dist: str,
               maximum_delay: int,
               obs_window_size: int,
               **dist_kwargs: dict):
    super().__init__(env)
    self._env_time: int = 0
    self._generator = getattr(state_independent, dist)(**dist_kwargs)
    self._maximum_delay: int = maximum_delay
    self._obs_window_size: int = obs_window_size
    # Heap of queued observations, ordered by (receive_time, env_time).
    self._obs_buffer: list[ObsInBuffer] = []
    # One slot per env step of the episode; None until that step's
    # observation is retrieved by `get_next_available_obs`.
    self._observed_history: list[ObsInBuffer] = []
    self._act_buffer: list[dict] = []
    self._done: bool = False
    self._latest_obs: 'ObsInBuffer | None' = None
    self._current_obs: 'ObsInBuffer | None' = None
    self._reset()

    self._obs_space: dict = self.env.obs_space.copy()
    self._obs_space['log/obs_env_time'] = elements.Space(np.int32, (), low=0, high=np.inf)
    self._obs_space['log/obs_receive_time'] = elements.Space(np.int32, (), low=0, high=np.inf)

  @property
  def generator(self):
    return self._generator

  @functools.cached_property
  def obs_space(self) -> dict:
    return self._obs_space

  @functools.cached_property
  def act_space(self) -> dict:
    return self.env.act_space

  @functools.cached_property
  def act_buffer_space(self) -> dict:
    """Action space plus the 'log/act_time' stamp stored in the history."""
    act_buffer_space: dict = self.env.act_space.copy()
    act_buffer_space['log/act_time'] = elements.Space(np.int32, (), low=-1, high=np.inf)
    return act_buffer_space

  @property
  def env_time(self) -> int:
    return self._env_time

  @property
  def maximum_delay(self) -> int:
    return self._maximum_delay

  @property
  def obs_window_size(self) -> int:
    return self._obs_window_size

  @property
  def obs_buffer(self) -> list[ObsInBuffer]:
    return self._obs_buffer

  def _act_wrapper(self, act: dict) -> dict:
    """Return a copy of `act` prepared for the action history.

    The copy has the action flattened and 'reset' as a bool array, and is
    stamped with the current env time. The caller's dict and array are not
    modified or aliased, so later changes to them cannot corrupt the history.
    """
    act = dict(act)
    act['action'] = np.array(act['action']).reshape(-1)
    act['reset'] = np.asarray(np.bool_(act['reset']))
    act['log/act_time'] = np.int32(self._env_time)
    return act

  def _obs_at_retrieving(self, obs: dict) -> dict:
    obs['log/cur_time'] = np.int32(self._env_time)
    return obs

  def _obs_at_acting(self, obs: dict, receive_time: int) -> dict:
    """Stamp `obs` in place with its production and receive times."""
    obs['log/obs_env_time'] = np.int32(self._env_time)
    obs['log/obs_receive_time'] = np.int32(receive_time)
    return obs

  def _reset(self):
    """Clear all per-episode state and restart env time at 0."""
    self._done = False
    self._obs_buffer = []
    self._observed_history = []
    # The history starts with one placeholder action stamped with time -1.
    self._act_buffer = [{
        k: (np.zeros(v.shape) if k != 'log/act_time' else np.int32(-1))
        for k, v in self.act_buffer_space.items()}]
    self._env_time = 0
    self._latest_obs = None
    self._current_obs = None

  def step(self, action):
    return self.env.step(action)

  def act(self, action) -> None:
    """Step the wrapped env and queue the resulting observation.

    A reset action clears all delay bookkeeping and queues the new episode's
    first observation at env time 0. Any other action is appended to the action
    history and advances env time by one.
    """
    # Read before stepping: inner wrappers may modify the action dict.
    is_reset = action['reset']
    obs = self.env.step(action)
    if is_reset:
      self._reset()
    else:
      self._act_buffer.append(self._act_wrapper(action))
      self._env_time += 1
    delay = self._generator.generate({**obs, 'time': self._env_time})
    receive_time = delay + self._env_time
    obs = self._obs_at_acting(obs, receive_time)
    if not is_reset:
      self._done = obs['is_last']
    entry = ObsInBuffer(self._env_time, receive_time, obs)
    self._current_obs = entry
    heapq.heappush(self._obs_buffer, entry)
    self._observed_history.append(None)
    return None

  def get_next_available_obs(self) -> dict:
    """Pop the earliest-arriving observation if it has arrived.

    Returns:
      Its observation dict, or None if nothing has arrived yet. The popped
      entry is recorded in the observed history. It also becomes the latest
      observation if it is newer (by env time) than the current latest one.
    """
    if not self.has_update():
      return None
    entry = heapq.heappop(self._obs_buffer)
    # Every `act` appends a slot, so the index env_time always exists.
    self._observed_history[entry.env_time] = entry
    if self._latest_obs is None or entry.env_time > self._latest_obs.env_time:
      self._latest_obs = entry
    return entry.time_step

  def has_update(self) -> bool:
    """Whether the earliest-arriving queued observation has arrived."""
    return (len(self._obs_buffer) > 0) and (self._obs_buffer[0].receive_time <= self._env_time)

  def get_latest_obs(self) -> dict:
    """Newest (by env time) observation retrieved by `get_next_available_obs`."""
    if self._latest_obs is None:
      return None
    return self._latest_obs.time_step

  def get_latest_arrived_obs(self) -> dict:
    """Freshest observation that has arrived but is still in the buffer.

    Among the queued observations whose receive time is not later than the
    current env time, the one with the largest env time is returned. Returns
    None if none has arrived. The buffer is a heap, so only its first element
    is ordered; it has to be scanned rather than read from the end.
    """
    freshest = max(
        (o for o in self._obs_buffer if o.receive_time <= self._env_time),
        key=lambda o: o.env_time, default=None)
    return None if freshest is None else freshest.time_step

  def get_current_obs(self) -> dict:
    """Observation produced by the most recent `act` (None before the first)."""
    if self._current_obs is None:
      return None
    return self._current_obs.time_step

  def get_action_history(self) -> list[dict]:
    return self._act_buffer

  def get_observed_history(self, d=None) -> list[dict]:
    """Retrieved observations among the last `d + 1` env steps.

    Args:
      d: Window size; defaults to `obs_window_size`.
    """
    if d is None:
      d = self._obs_window_size
    return [obs.time_step for obs in self._observed_history[-d - 1:] if obs is not None]

  @property
  def observed_history(self):
    return self._observed_history
  
class DelayedEnvWithExtendedObs(Wrapper):
  """Env wrapper that exposes delayed observations with their delay context.

  Each step samples a delay for the new observation from the
  `state_independent.<dist>` generator, with D = `maximum_delay`. Each
  observation key of the returned dict then holds:

  * with a `state_independent.Fixed` generator: the observation produced D
    steps earlier (zeros before the episode had D steps);
  * otherwise: the flattened observations of the last D + 1 steps,
    concatenated oldest first. Observations that have not arrived yet are
    replaced by zeros. If `include_masks`, an 'arrived_mask' entry flags which
    ones arrived.

  Optionally the dict also holds 'next_actions': the last D actions, flattened
  and concatenated oldest first. With `privilaged_decoder`, these entries are
  renamed with an 'encoder_' prefix, and the current undelayed observation is
  added under 'decoder_' keys. 'reward', 'is_first', 'is_last' and
  'is_terminal' are always taken from the current step.

  NOTE(review): the Fixed branch assumes the fixed delay equals
  `maximum_delay`; the generator's own value is not consulted — confirm.
  """

  _PASSTHROUGH_KEYS = ('reward', 'is_first', 'is_last', 'is_terminal')

  def __init__(self, env,
               dist: str,
               maximum_delay: int,
               privilaged_decoder: bool,
               include_actions: bool,
               include_masks: bool,
               **dist_kwargs: dict):
    super().__init__(env)

    self._generator = getattr(state_independent, dist)(**dist_kwargs)
    self._maximum_delay: int = maximum_delay
    D = maximum_delay

    self._obs_buffer: list[ObsInBuffer] = []
    self._act_buffer: list[np.ndarray] = []

    self._privilaged_decoder = privilaged_decoder
    self._include_actions = include_actions
    self._include_masks = include_masks

    # Start from the wrapped env's spaces. `self.obs_space` cannot be used
    # here: it reads `self._obs_space`, which does not exist yet, so the lookup
    # would fall through `Wrapper.__getattr__` to some inner wrapper's state.
    self._obs_space: dict = self.env.obs_space.copy()
    self._obs_keys = [k for k in self._obs_space if self._is_obs_key(k)]

    if self._include_actions:
      action_space = self.act_space['action']
      self._obs_space['next_actions'] = elements.Space(
          action_space.dtype,
          (int(np.prod(action_space.shape) * D),),
          low=self._tile_bound(action_space.low, D),
          high=self._tile_bound(action_space.high, D))

    self._deterministic_delay = type(self._generator) is state_independent.Fixed

    self._mask_key = 'arrived_mask'
    self._use_mask = self._include_masks and not self._deterministic_delay
    if self._use_mask:
      self._obs_space[self._mask_key] = elements.Space(np.int32, low=0, high=1)
      self._obs_keys.append(self._mask_key)

    if self._privilaged_decoder:
      encoder_keys = [k for k in self._obs_space if self._is_obs_key(k)]
      # The decoder receives the current raw observation, which has no
      # arrival mask, so the mask gets no decoder entry.
      self._obs_space.update({
          f'decoder_{k}': self._obs_space[k] for k in self._obs_keys
          if not (self._use_mask and k == self._mask_key)})

    if not self._deterministic_delay:
      for k in self._obs_keys:
        space = self._obs_space[k]
        self._obs_space[k] = elements.Space(
            space.dtype,
            (int(np.prod(space.shape) * (D + 1)),),
            low=self._tile_bound(space.low, D + 1),
            high=self._tile_bound(space.high, D + 1))

    if self._privilaged_decoder:
      for k in encoder_keys:
        self._obs_space[f'encoder_{k}'] = self._obs_space.pop(k)

    self._reset()

  @property
  def obs_space(self) -> dict:
    return self._obs_space

  @functools.cached_property
  def act_space(self) -> dict:
    return self.env.act_space

  def _is_obs_key(self, key):
    """Whether `key` is an observation entry that gets delayed."""
    return key not in self._PASSTHROUGH_KEYS and not key.startswith('log/')

  @staticmethod
  def _tile_bound(bound, reps):
    """Bounds for `reps` flattened copies concatenated one after another."""
    # Tile (not repeat) so element i of every copy gets bound i.
    return np.tile(np.ravel(bound), reps)

  def _reset(self):
    """Refill the buffers with D zero observations and D zero actions."""
    self._done = False
    D = self._maximum_delay
    # Placeholders get env times -D..-1 and count as received at time 0, so a
    # full window exists from the first step of the episode.
    self._obs_buffer = [
        ObsInBuffer(t - D, 0, {
            k: np.zeros(v.shape, v.dtype)
            for k, v in self.env.obs_space.items() if k in self._obs_keys})
        for t in range(D)]
    action_space = self.act_space['action']
    action_size = int(np.prod(action_space.shape))
    self._act_buffer = [np.zeros(action_size, action_space.dtype) for _ in range(D)]
    self._env_time = 0

  def _trim_buffers(self):
    """Drop entries that can no longer appear in a returned observation."""
    D = self._maximum_delay
    del self._obs_buffer[:max(len(self._obs_buffer) - (D + 1), 0)]
    del self._act_buffer[:max(len(self._act_buffer) - D, 0)]

  def _next_actions(self):
    """The last D actions, flattened and concatenated oldest first."""
    D = self._maximum_delay
    recent = self._act_buffer[max(len(self._act_buffer) - D, 0):]
    if not recent:
      return np.zeros((0,), self.act_space['action'].dtype)
    return np.concatenate(recent)

  def _fixed_delay_obs(self):
    """Observation produced D steps ago (it arrives exactly now)."""
    # `_reset` pads with D placeholders, so index -(D + 1) always exists.
    delayed = self._obs_buffer[-self._maximum_delay - 1]
    return {key: delayed.time_step[key] for key in self._obs_keys}

  def _window_obs(self):
    """Per key, the last D + 1 observations flattened, unarrived ones zeroed."""
    window = self._obs_buffer[-self._maximum_delay - 1:]
    pieces = {key: [] for key in self._obs_keys}
    for entry in window:
      arrived = entry.receive_time <= self._env_time
      for key in self._obs_keys:
        if self._use_mask and key == self._mask_key:
          pieces[key].append(np.array([arrived], np.int32))
        elif arrived:
          pieces[key].append(np.asarray(entry.time_step[key]).reshape(-1))
        else:
          space = self.env.obs_space[key]
          pieces[key].append(np.zeros(int(np.prod(space.shape)), space.dtype))
    return {key: np.concatenate(parts) for key, parts in pieces.items()}

  def step(self, action):
    obs = self.env.step(action)
    if action['reset']:
      self._reset()
    else:
      self._act_buffer.append(np.asarray(action['action']).reshape(-1))
      self._env_time += 1
    delay = self._generator.generate({**obs, 'time': self._env_time})
    receive_time = delay + self._env_time
    self._obs_buffer.append(ObsInBuffer(self._env_time, receive_time, obs))
    self._trim_buffers()

    extended_obs = {}
    if self._include_actions:
      extended_obs['next_actions'] = self._next_actions()
    if self._deterministic_delay:
      extended_obs.update(self._fixed_delay_obs())
    else:
      extended_obs.update(self._window_obs())
    if self._privilaged_decoder:
      extended_obs = {f'encoder_{k}': v for k, v in extended_obs.items()}
      extended_obs.update({f'decoder_{k}': v for k, v in obs.items() if k in self._obs_keys})
    for key in self._PASSTHROUGH_KEYS:
      extended_obs[key] = obs[key]
    return extended_obs

class DelayedEnvWithLatestObs(DelayedEnv):
  """`DelayedEnv` whose `step` returns the latest arrived observation.

  Each `step` calls `act`, then returns a copy of the observation chosen by
  `get_latest_arrived_obs`, or all zeros if none has arrived. 'is_first' and
  'is_last' are overwritten from the current, undelayed observation, and the
  current reward is added as 'collect_reward'.
  """

  def __init__(self, env,
               dist: str,
               maximum_delay: int,
               **dist_kwargs: dict):
    # The observation window size is set equal to the maximum delay.
    super().__init__(env, dist, maximum_delay, maximum_delay, **dist_kwargs)

    self._obs_space = self.obs_space.copy()
    self._obs_space['collect_reward'] = self.obs_space['reward']

  @property
  def obs_space(self) -> dict:
    return self._obs_space

  def step(self, action):
    self.act(action)
    latest_obs = self.get_latest_arrived_obs()
    current_obs = self.get_current_obs()
    if latest_obs is None:
      latest_obs = {k: np.zeros(v.shape, v.dtype) for k, v in self.obs_space.items()}
    else:
      # Shallow copy: this dict is still referenced by the observation buffer
      # and may be returned again on a later step. The keys set below, and any
      # changes made by outer wrappers, must not write into it.
      latest_obs = dict(latest_obs)
    latest_obs['is_first'] = current_obs['is_first']
    latest_obs['collect_reward'] = current_obs['reward']
    latest_obs['is_last'] = current_obs['is_last']
    # NOTE(review): 'reward' and 'is_terminal' still come from the delayed
    # observation. When the current step is terminal, the result can have
    # is_last=True with is_terminal=False — confirm this is intended.
    return latest_obs

class DelayedEnvWithWaitAction(Wrapper):
  """Env wrapper that inserts wait steps after every action.

  After each non-reset step, a delay is sampled from the
  `state_independent.<dist>` generator. The wrapped env is then stepped that
  many extra times with a fixed all-zeros "wait" action, stopping early if the
  episode ends. Rewards from the wait steps are added to the returned reward;
  their intermediate observations are discarded.
  """

  def __init__(self, env,
               dist: str,
               **dist_kwargs: dict):
    super().__init__(env)
    self._generator = getattr(state_independent, dist)(**dist_kwargs)
    # Change the wait action here manually for now.
    self._wait_action = {k: np.zeros(v.shape, v.dtype) for k, v in self.act_space.items()}

    self._env_time = 0

  def step(self, action):
    if action['reset']:
      self._env_time = 0
    obs = self.env.step(action)
    # Re-read after stepping: inner wrappers may have turned this into a reset.
    if action['reset']:
      return obs
    reward = obs['reward']
    delay = self._generator.generate({**obs, 'time': self._env_time})
    for _ in range(delay):
      if obs['is_last'] or obs['is_terminal']:
        break
      # Hand out a fresh dict each time. Inner wrappers may write into the
      # action they receive (e.g. set 'reset'), and that must not leak into
      # later wait steps.
      obs = self.env.step(dict(self._wait_action))
      reward += obs['reward']
    obs['reward'] = reward
    self._env_time += 1
    return obs

class ActionNoise(Wrapper):
  """Perturb one action key with clipped Gaussian noise.

  The noise standard deviation is `noise_fraction` times half the width of the
  key's [low, high] action range. The noisy action is clipped back into the
  range and cast to the space dtype.
  """

  def __init__(self, env, key='action', noise_fraction=0.1):
    super().__init__(env)
    self._noise_fraction = noise_fraction
    self._key = key

  def step(self, action):
    space = self.act_space[self._key]
    clean = action[self._key]
    scale = self._noise_fraction * (space.high - space.low) * 0.5
    perturbed = clean + np.random.normal(0, scale, size=clean.shape)
    perturbed = np.clip(perturbed, space.low, space.high)
    perturbed = np.asarray(perturbed, dtype=space.dtype)
    return self.env.step({**action, self._key: perturbed})
  
class TimeLimit(Wrapper):
  """Mark episodes as ended after a fixed number of steps.

  Args:
    env: The env to wrap.
    duration: Number of non-reset steps after which 'is_last' is forced to
      True. A falsy value (0 or None) disables the limit.
    reset: On the step after an episode ended (or on an explicit reset), if
      True, the wrapped env is stepped with reset=True. If False, it is stepped
      with reset=False and the returned observation is only marked 'is_first'.
  """

  def __init__(self, env, duration, reset=True):
    super().__init__(env)
    self._duration = duration
    self._reset = reset
    self._step = 0
    self._done = False

  def step(self, action):
    if action['reset'] or self._done:
      self._step = 0
      self._done = False
      # Build a new dict rather than updating `action` in place, so the
      # caller's action dict is left untouched.
      if self._reset:
        return self.env.step({**action, 'reset': True})
      obs = self.env.step({**action, 'reset': False})
      obs['is_first'] = True
      return obs
    self._step += 1
    obs = self.env.step(action)
    if self._duration and self._step >= self._duration:
      obs['is_last'] = True
    self._done = obs['is_last']
    return obs


class ActionRepeat(Wrapper):
  """Repeat each non-reset action `repeat` times and sum the rewards.

  Repetition stops early once an episode ends. If observations contain
  'success', the returned value is 1.0 when any repeated step reported it.

  Raises:
    ValueError: if `repeat` is smaller than 1 (no step would be taken and
      there would be no observation to return).
  """

  def __init__(self, env, repeat):
    super().__init__(env)
    if repeat < 1:
      raise ValueError(f'repeat must be at least 1, got {repeat}.')
    self._repeat = repeat

  def step(self, action):
    if action['reset']:
      return self.env.step(action)
    reward = 0.0
    success = False
    for _ in range(self._repeat):
      obs = self.env.step(action)
      reward += obs['reward']
      if 'success' in obs:
        success = success or obs['success']
      if obs['is_last'] or obs['is_terminal']:
        break
    obs['reward'] = np.float32(reward)
    if 'success' in obs:
      obs['success'] = np.float32(success)
    return obs


class ClipAction(Wrapper):
  """Clip one action key to [low, high] before passing it to the wrapped env."""

  def __init__(self, env, key='action', low=-1, high=1):
    super().__init__(env)
    self._key = key
    self._low, self._high = low, high

  def step(self, action):
    bounded = dict(action)
    bounded[self._key] = np.clip(action[self._key], self._low, self._high)
    return self.env.step(bounded)


class NormalizeAction(Wrapper):
  """Expose finitely bounded action dimensions on a [-1, 1] scale.

  Dimensions whose low and high bounds are both finite are rescaled linearly
  from [-1, 1] to [low, high] before reaching the wrapped env. All other
  dimensions pass through unchanged.
  """

  def __init__(self, env, key='action'):
    super().__init__(env)
    self._key = key
    self._space = env.act_space[key]
    bounded_below = np.isfinite(self._space.low)
    bounded_above = np.isfinite(self._space.high)
    self._mask = bounded_below & bounded_above
    self._low = np.where(self._mask, self._space.low, -1)
    self._high = np.where(self._mask, self._space.high, 1)

  @functools.cached_property
  def act_space(self):
    unit = np.ones_like(self._low)
    low = np.where(self._mask, -unit, self._low)
    high = np.where(self._mask, unit, self._high)
    unit_space = elements.Space(np.float32, self._space.shape, low, high)
    return {**self.env.act_space, self._key: unit_space}

  def step(self, action):
    scaled = action[self._key]
    span = self._high - self._low
    restored = (scaled + 1) / 2 * span + self._low
    restored = np.where(self._mask, restored, scaled)
    return self.env.step({**action, self._key: restored})


# class ExpandScalars(Wrapper):
#
#   def __init__(self, env):
#     super().__init__(env)
#     self._obs_expanded = []
#     self._obs_space = {}
#     for key, space in self.env.obs_space.items():
#       if space.shape == () and key != 'reward' and not space.discrete:
#         space = elements.Space(space.dtype, (1,), space.low, space.high)
#         self._obs_expanded.append(key)
#       self._obs_space[key] = space
#     self._act_expanded = []
#     self._act_space = {}
#     for key, space in self.env.act_space.items():
#       if space.shape == () and not space.discrete:
#         space = elements.Space(space.dtype, (1,), space.low, space.high)
#         self._act_expanded.append(key)
#       self._act_space[key] = space
#
#   @functools.cached_property
#   def obs_space(self):
#     return self._obs_space
#
#   @functools.cached_property
#   def act_space(self):
#     return self._act_space
#
#   def step(self, action):
#     action = {
#         key: np.squeeze(value, 0) if key in self._act_expanded else value
#         for key, value in action.items()}
#     obs = self.env.step(action)
#     obs = {
#         key: np.expand_dims(value, 0) if key in self._obs_expanded else value
#         for key, value in obs.items()}
#     return obs
#
#
# class FlattenTwoDimObs(Wrapper):
#
#   def __init__(self, env):
#     super().__init__(env)
#     self._keys = []
#     self._obs_space = {}
#     for key, space in self.env.obs_space.items():
#       if len(space.shape) == 2:
#         space = elements.Space(
#             space.dtype,
#             (int(np.prod(space.shape)),),
#             space.low.flatten(),
#             space.high.flatten())
#         self._keys.append(key)
#       self._obs_space[key] = space
#
#   @functools.cached_property
#   def obs_space(self):
#     return self._obs_space
#
#   def step(self, action):
#     obs = self.env.step(action).copy()
#     for key in self._keys:
#       obs[key] = obs[key].flatten()
#     return obs
#
#
# class FlattenTwoDimActions(Wrapper):
#
#   def __init__(self, env):
#     super().__init__(env)
#     self._origs = {}
#     self._act_space = {}
#     for key, space in self.env.act_space.items():
#       if len(space.shape) == 2:
#         space = elements.Space(
#             space.dtype,
#             (int(np.prod(space.shape)),),
#             space.low.flatten(),
#             space.high.flatten())
#         self._origs[key] = space.shape
#       self._act_space[key] = space
#
#   @functools.cached_property
#   def act_space(self):
#     return self._act_space
#
#   def step(self, action):
#     action = action.copy()
#     for key, shape in self._origs.items():
#       action[key] = action[key].reshape(shape)
#     return self.env.step(action)


class UnifyDtypes(Wrapper):
  """Expose spaces with unified dtypes: float32, uint8 or int32.

  Observations from the wrapped env are cast to the unified dtypes. Actions
  going into it are cast back to the dtypes the wrapped env declares.
  """

  def __init__(self, env):
    super().__init__(env)
    obs_spaces, _, obs_targets = self._convert(env.obs_space)
    act_spaces, act_sources, _ = self._convert(env.act_space)
    self._obs_space, self._obs_outer = obs_spaces, obs_targets
    self._act_space, self._act_inner = act_spaces, act_sources

  @property
  def obs_space(self):
    return self._obs_space

  @property
  def act_space(self):
    return self._act_space

  def step(self, action):
    inner_action = action.copy()
    inner_action.update(
        (key, np.asarray(action[key], dtype))
        for key, dtype in self._act_inner.items())
    obs = self.env.step(inner_action)
    for key, dtype in self._obs_outer.items():
      obs[key] = np.asarray(obs[key], dtype)
    return obs

  def _convert(self, spaces):
    """Return (converted spaces, original dtypes, unified dtypes) per key."""
    converted, sources, targets = {}, {}, {}
    for key, space in spaces.items():
      source = space.dtype
      if np.issubdtype(source, np.floating):
        target = np.float32
      elif np.issubdtype(source, np.uint8):
        target = np.uint8
      elif np.issubdtype(source, np.integer):
        target = np.int32
      else:
        target = source
      sources[key] = source
      targets[key] = target
      converted[key] = elements.Space(target, space.shape, space.low, space.high)
    return converted, sources, targets


class CheckSpaces(Wrapper):
  """Validate every action and observation against the wrapped env's spaces.

  Raises TypeError for values of an unsupported Python type, and ValueError
  for values that are not contained in their space.
  """

  def __init__(self, env):
    # Observation and action keys must not overlap.
    assert not (env.obs_space.keys() & env.act_space.keys()), (
        env.obs_space.keys(), env.act_space.keys())
    super().__init__(env)

  def step(self, action):
    for name, value in action.items():
      self._check(value, self.env.act_space[name], name)
    obs = self.env.step(action)
    for name in obs:
      self._check(obs[name], self.env.obs_space[name], name)
    return obs

  def _check(self, value, space, key):
    accepted = (np.ndarray, np.generic, list, tuple, int, float, bool)
    if not isinstance(value, accepted):
      raise TypeError(f'Invalid type {type(value)} for key {key}.')
    if value in space:
      return
    as_array = np.array(value)
    dtype, shape = as_array.dtype, as_array.shape
    lowest, highest = np.min(value), np.max(value)
    raise ValueError(
        f"Value for '{key}' with dtype {dtype}, shape {shape}, "
        f"lowest {lowest}, highest {highest} is not in {space}.")


class DiscretizeAction(Wrapper):
  """Replace a continuous action key with per-dimension bin indices.

  Each dimension takes an integer index into `bins` evenly spaced values in
  [-1, 1]. The wrapped env receives the corresponding continuous values.
  """

  def __init__(self, env, key='action', bins=5):
    super().__init__(env)
    # The action space is expected to be one-dimensional, shape (dims,).
    self._dims = np.squeeze(env.act_space[key].shape, 0).item()
    self._values = np.linspace(-1, 1, bins)
    self._key = key

  @functools.cached_property
  def act_space(self):
    num_bins = len(self._values)
    index_space = elements.Space(np.int32, self._dims, 0, num_bins)
    return {**self.env.act_space, self._key: index_space}

  def step(self, action):
    indices = action[self._key]
    values = np.take(self._values, indices)
    return self.env.step({**action, self._key: values})


class ResizeImage(Wrapper):
  """Resize image observations with nearest-neighbour sampling.

  Args:
    env: The env to wrap.
    size: Target (height, width). Every observation with at least two
      dimensions whose leading (height, width) differ from `size` is resized.
  """

  def __init__(self, env, size=(64, 64)):
    super().__init__(env)
    # Normalise to a tuple so comparing with `shape[:2]` and concatenating
    # with `shape[2:]` also work when a list is passed.
    self._size = tuple(size)
    self._keys = [
        k for k, v in env.obs_space.items()
        if len(v.shape) > 1 and v.shape[:2] != self._size]
    print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
    if self._keys:
      from PIL import Image
      self._Image = Image

  @functools.cached_property
  def obs_space(self):
    # Copy so the wrapped env's (possibly cached) space dict is not modified.
    spaces = dict(self.env.obs_space)
    for key in self._keys:
      shape = self._size + spaces[key].shape[2:]
      spaces[key] = elements.Space(np.uint8, shape)
    return spaces

  def step(self, action):
    obs = self.env.step(action)
    for key in self._keys:
      obs[key] = self._resize(obs[key])
    return obs

  def _resize(self, image):
    image = self._Image.fromarray(image)
    # PIL expects (width, height), whereas `self._size` is (height, width).
    height, width = self._size
    image = image.resize((width, height), self._Image.NEAREST)
    return np.array(image)


# class RenderImage(Wrapper):
#
#   def __init__(self, env, key='image'):
#     super().__init__(env)
#     self._key = key
#     self._shape = self.env.render().shape
#
#   @functools.cached_property
#   def obs_space(self):
#     spaces = self.env.obs_space
#     spaces[self._key] = elements.Space(np.uint8, self._shape)
#     return spaces
#
#   def step(self, action):
#     obs = self.env.step(action)
#     obs[self._key] = self.env.render()
#     return obs


class BackwardReturn(Wrapper):
  """Add 'bwreturn': a discounted running sum of past rewards.

  The discount is `1 - 1 / horizon`. The sum restarts whenever an observation
  has 'is_first' set.
  """

  def __init__(self, env, horizon):
    super().__init__(env)
    self._discount = 1 - 1 / horizon
    self._bwreturn = 0.0

  @functools.cached_property
  def obs_space(self):
    spaces = dict(self.env.obs_space)
    spaces['bwreturn'] = elements.Space(np.float32)
    return spaces

  def step(self, action):
    obs = self.env.step(action)
    # The decay factor is zero on the first step of an episode.
    decay = (1 - obs['is_first']) * self._discount
    self._bwreturn *= decay
    self._bwreturn += obs['reward']
    obs['bwreturn'] = np.float32(self._bwreturn)
    return obs


class AddObs(Wrapper):
  """Attach the constant `value` under `key` to every observation."""

  def __init__(self, env, key, value, space):
    super().__init__(env)
    self._key, self._value, self._space = key, value, space

  @functools.cached_property
  def obs_space(self):
    spaces = dict(self.env.obs_space)
    spaces[self._key] = self._space
    return spaces

  def step(self, action):
    result = self.env.step(action)
    result[self._key] = self._value
    return result


class RestartOnException(Wrapper):
  """Rebuild the env from its constructor when `step` raises.

  Args:
    ctor: Zero-argument callable that builds the env.
    exceptions: Exception type(s) that trigger a restart.
    window: Length in seconds of the window in which failures are counted.
      The count restarts at 1 when a failure happens more than `window`
      seconds after the window started.
    maxfails: Failures tolerated within one window. One more raises
      RuntimeError, chained to the original exception.
    wait: Seconds to sleep before rebuilding the env.
  """

  def __init__(
      self, ctor, exceptions=(Exception,), window=300, maxfails=2, wait=20):
    if not isinstance(exceptions, (tuple, list)):
      exceptions = [exceptions]
    self._ctor = ctor
    self._exceptions = tuple(exceptions)
    self._window = window
    self._maxfails = maxfails
    self._wait = wait
    self._last = time.time()
    self._fails = 0
    super().__init__(self._ctor())

  def step(self, action):
    try:
      return self.env.step(action)
    except self._exceptions as e:
      now = time.time()
      if now > self._last + self._window:
        self._last = now
        self._fails = 1
      else:
        self._fails += 1
      if self._fails > self._maxfails:
        raise RuntimeError('The env crashed too many times.') from e
      message = f'Restarting env after crash with {type(e).__name__}: {e}'
      print(message, flush=True)
      time.sleep(self._wait)
      self.env = self._ctor()
      # The new env must start with a reset. Build a new dict so the caller's
      # action is not modified.
      reset_action = {**action, 'reset': np.ones_like(action['reset'])}
      return self.env.step(reset_action)
