import time

import cloudpickle
import elements
import numpy as np
import portal


class Driver:
  """Steps a batch of environments and hands transitions to callbacks.

  Environments run either in-process (``parallel=False``) or each in its own
  worker process driven over a pipe (``parallel=True``). Besides the plain
  synchronous loop (``__call__`` / ``_step``), the driver supports
  environments that deliver observations with a delay (``delayed_call`` /
  ``_delayed_step``). Those environments are queried for ``maximum_delay``,
  ``env_time``, ``obs_window_size`` and the history/buffer accessors used
  below.

  Args:
    make_env_fns: Non-empty sequence of zero-argument environment constructors.
    parallel: If true, every environment is created in a separate worker
      process; constructors are serialized with cloudpickle.
    **kwargs: Extra keyword arguments forwarded to the policy and to every
      registered callback.
  """

  def __init__(self, make_env_fns, parallel=True, **kwargs):
    assert len(make_env_fns) >= 1
    self.parallel = parallel
    self.kwargs = kwargs
    self.length = len(make_env_fns)
    if parallel:
      import multiprocessing as mp
      context = mp.get_context()
      self.pipes, pipes = zip(*[context.Pipe() for _ in range(self.length)])
      self.stop = context.Event()
      fns = [cloudpickle.dumps(fn) for fn in make_env_fns]
      self.procs = [
          portal.Process(self._env_server, self.stop, i, pipe, fn, start=True)
          for i, (fn, pipe) in enumerate(zip(fns, pipes))]
      self.pipes[0].send(('act_space',))
      self.act_space = self._receive(self.pipes[0])
    else:
      self.envs = [fn() for fn in make_env_fns]
      self.act_space = self.envs[0].act_space
    self.callbacks = []
    self.acts = None
    self.carry = None
    self.reset()

  def reset(self, init_policy=None):
    """Reset pending actions, the policy carry and the delay bookkeeping.

    Args:
      init_policy: Optional callable ``init_policy(batch_size)`` returning the
        initial carry; when omitted the carry becomes ``None``.
    """
    self.acts = {
        k: np.zeros((self.length,) + v.shape, v.dtype)
        for k, v in self.act_space.items()}
    # The first action sent to every environment is a reset.
    self.acts['reset'] = np.ones(self.length, bool)
    self.carry = init_policy and init_policy(self.length)
    self.reset_delay_variables()

  def _env_has_attr(self, name):
    """Return whether the first environment exposes attribute ``name``.

    In parallel mode this asks worker 0 through a dedicated ``has_attr``
    message, so probing an optional attribute never makes the worker raise
    (a worker error terminates all workers).
    """
    if self.parallel:
      self.pipes[0].send(('has_attr', name))
      return self._receive(self.pipes[0])
    return hasattr(self.envs[0], name)

  def reset_delay_variables(self):
    """Reset the per-episode bookkeeping used by ``_delayed_step``.

    The delay buffers are only (re)created when the environment exposes
    ``maximum_delay``; for ordinary environments they are left untouched.
    The cached latent state is always reset from ``self.carry``.
    """
    if self._env_has_attr('maximum_delay'):
      maximum_delay = self.get_maximum_delay()
      # One sliding arrival mask of length maximum_delay + 1 per environment,
      # plus the history of those masks (indexed by env time in ``store``).
      self._prev_mask = [
          np.zeros(maximum_delay + 1) for _ in range(self.length)]
      self._prev_masks = [[] for _ in range(self.length)]
      self._all_outs = []
      self._obs = []
      self._state_time = -1
    if self.carry is not None:
      self._state = self.carry, {**self.carry[1], 'logit': self.carry[1]['stoch']}
    else:
      self._state = None, None

  def close(self):
    """Kill the worker processes, or close the in-process environments."""
    if self.parallel:
      for proc in self.procs:
        proc.kill()
    else:
      for env in self.envs:
        env.close()

  def on_step(self, callback):
    """Register ``callback(transition, env_index, **kwargs)``."""
    self.callbacks.append(callback)

  def __call__(self, policy, steps=0, episodes=0):
    """Run ``_step`` until at least ``steps`` steps or ``episodes`` episodes."""
    step, episode = 0, 0
    while step < steps or episode < episodes:
      step, episode = self._step(policy, step, episode)

  def delayed_call(self, imagine, observe, policy, steps=0, episodes=0):
    """Run ``_delayed_step`` until at least ``steps`` steps or ``episodes``."""
    step, episode = 0, 0
    while step < steps or episode < episodes:
      step, episode = self._delayed_step(imagine, observe, policy, step, episode)

  def _step(self, policy, step, episode):
    """Send the pending actions, query the policy and emit transitions.

    ``policy(carry, obs, **kwargs)`` must return ``(carry, acts, outs)``.
    Keys starting with ``log/`` are removed from the observations before the
    policy sees them but are still included in the emitted transitions.

    Returns:
      The updated ``(step, episode)`` counters.
    """
    acts = self.acts
    assert all(len(x) == self.length for x in acts.values())
    assert all(isinstance(v, np.ndarray) for v in acts.values())
    acts = [{k: v[i] for k, v in acts.items()} for i in range(self.length)]
    if self.parallel:
      for pipe, act in zip(self.pipes, acts):
        pipe.send(('step', act))
      obs = [self._receive(pipe) for pipe in self.pipes]
    else:
      obs = [env.step(act) for env, act in zip(self.envs, acts)]
    obs = {k: np.stack([x[k] for x in obs]) for k in obs[0].keys()}
    logs = {k: v for k, v in obs.items() if k.startswith('log/')}
    obs = {k: v for k, v in obs.items() if not k.startswith('log/')}
    assert all(len(x) == self.length for x in obs.values()), obs
    self.carry, acts, outs = policy(self.carry, obs, **self.kwargs)
    assert all(k not in acts for k in outs), (
        list(outs.keys()), list(acts.keys()))
    if obs['is_last'].any():
      # Zero out actions of environments that just finished their episode.
      mask = ~obs['is_last']
      acts = {k: self._mask(v, mask) for k, v in acts.items()}
    self.acts = {**acts, 'reset': obs['is_last'].copy()}
    trans = {**obs, **acts, **outs, **logs}
    for i in range(self.length):
      trn = elements.tree.map(lambda x: x[i], trans)
      for fn in self.callbacks:
        fn(trn, i, **self.kwargs)
    step += len(obs['is_first'])
    episode += obs['is_last'].sum()
    return step, episode

  def _delayed_step(self, imagine, observe, policy, step, episode):
    """Advance all environments by one action when observations are delayed.

    Expected callables (signatures taken from the calls below):
      imagine(prev_action, carry) -> (carry, out, feat)
      observe(obs, prev_action, carry) -> (carry, out, feat)
      policy(next_actions, state, **kwargs) -> (carry, acts, out, _)

    Returns:
      The updated ``(step, episode)`` counters. ``episode`` grows by the
      number of environments whose episode ended in this step.
    """
    acts = self.acts
    assert all(len(x) == self.length for x in acts.values())
    assert all(isinstance(v, np.ndarray) for v in acts.values())
    # Drop any ``log/`` bookkeeping keys before sending actions out.
    acts = [
        {k: v[i] for k, v in acts.items() if not k.startswith('log/')}
        for i in range(self.length)]

    if self.parallel:
      for pipe, act in zip(self.pipes, acts):
        pipe.send(('act', act))
      for pipe in self.pipes:
        self._receive(pipe)
    else:
      for env, act in zip(self.envs, acts):
        env.act(act)

    # Per key: (envs, time, ...) as assembled by get_action_history.
    action_history = self.get_action_history()

    prev_obs = self.get_latest_obs()

    # Drain every observation that became available after this action.
    obs_received = []
    obs = self.get_next_available_obs()
    while obs is not None:
      obs_received.append(obs)
      obs = self.get_next_available_obs()

    # List over time; each entry is a dict of arrays stacked over envs.
    obs_history = self.get_obs_history()

    env_time = self.get_env_time()
    obs_window_size = self.get_obs_window_size()
    maximum_delay = self.get_maximum_delay()

    # Shift each environment's arrival mask by one slot and flag the slots of
    # the observations that arrived during this step.
    for i in range(self.length):
      self._prev_mask[i] = np.concatenate(
          [self._prev_mask[i][1:].copy(), np.array([0])])
      for obs in obs_received:
        offset = env_time - obs['log/obs_env_time'][i] + 1
        self._prev_mask[i][-offset] = 1
      self._prev_masks[i].append(self._prev_mask[i].copy())

    # Roll the latent state through the observed window: ``imagine`` over the
    # gaps between observations, then ``observe`` each observation.
    states = [self._state]
    for i, o_i in enumerate(obs_history):
      if i == 0:
        t1 = self._state_time
      else:
        t1 = obs_history[i-1]['log/obs_env_time'][0]
      t2 = o_i['log/obs_env_time'][0]
      prev_actions = {k: v[:, t1+1:t2] for k, v in action_history.items()}
      for j in range(t2-t1-1):
        prev_action = {k: v[:, j:j+1] for k, v in prev_actions.items()}
        carry, out, feat = imagine(prev_action, states[-1][0])
        states.append((carry, feat))
      prev_action = {k: v[:, t2] for k, v in action_history.items()}
      o_i = {k: v for k, v in o_i.items() if not k.startswith('log/')}
      carry, out, feat = observe(o_i, prev_action, states[-1][0])
      states.append((carry, feat))
    if env_time >= obs_window_size and len(states) > 1:
      # Once the window is full, move the cached start state forward to the
      # state produced by the first update above.
      self._state = states[1]
      self._state_time = env_time - obs_window_size
    if len(obs_history):
      t = obs_history[-1]['log/obs_env_time'][0]
    else:
      t = max(-1, env_time - obs_window_size - 1)
    # Actions issued after the newest observed time step.
    next_actions = {k: v[:, t+1:] for k, v in action_history.items()}

    carry, acts, out, _ = policy(next_actions, states[-1], **self.kwargs)
    self._all_outs.append(out)

    cur_obs = self.get_current_obs()

    if cur_obs['is_last'].any():
      # Zero out actions of environments that just finished their episode.
      mask = ~cur_obs['is_last']
      acts = {k: self._mask(v, mask) for k, v in acts.items()}
    self.acts = {**acts, 'reset': cur_obs['is_last'].copy()}
    acts = self.acts

    # NOTE(review): the merged obs_received list is not read after this
    # point; kept unchanged — confirm whether it was meant to feed storage.
    if prev_obs is not None:
      if len(obs_received) == 0:
        obs_received = [prev_obs.copy()]
      else:
        for i in range(len(obs_received)):
          if prev_obs['log/obs_env_time'][0] < obs_received[i]['log/obs_env_time'][0]:
            obs_received = obs_received[:i] + [prev_obs.copy()] + obs_received[i:]
            break
          if i == len(obs_received) - 1:
            obs_received.append(prev_obs.copy())

    if env_time >= maximum_delay:
      # Store the oldest observation of the current window.
      obs_to_store = {k: v for k, v in obs_history[0].items()}
      step += self.store(obs_to_store, acts)

    if cur_obs['is_last'].any():
      # Flush what is still pending: the rest of the window plus the
      # observations still buffered inside the environments.
      obs_buffer = self.get_obs_buffer()
      remaining_obs = obs_history[1:] + obs_buffer
      remaining_obs = sorted(
          remaining_obs, key=lambda x: x['log/obs_env_time'][0])
      for obs in remaining_obs:
        step += self.store(obs, acts)
      # Count finished episodes once (like ``_step``), not once per flushed
      # observation.
      episode += int(cur_obs['is_last'].sum())
      self.reset_delay_variables()
    return step, episode

  def get_action_history(self):
    """Return the action history as ``{key: array of shape (envs, time, ...)}``."""
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('action_history',))
      action_history = [self._receive(pipe) for pipe in self.pipes]
    else:
      action_history = [env.get_action_history() for env in self.envs]
    return {
        k: np.array([[entry[k] for entry in env_hist] for env_hist in action_history])
        for k in action_history[0][0]}

  def get_obs_history(self):
    """Return the observed history as a time-ordered list of env-stacked dicts."""
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('observed_history',))
      obs_history = [self._receive(pipe) for pipe in self.pipes]
    else:
      obs_history = [env.get_observed_history() for env in self.envs]
    return [
        {k: np.stack([obs_history[e][t][k] for e in range(len(obs_history))])
         for k in obs_history[0][0]}
        for t in range(len(obs_history[0]))]

  # Backward-compatible alias for the original (misspelled) method name.
  get_obs_histroy = get_obs_history

  def get_next_available_obs(self):
    """Pop the next available observation of every env, stacked over envs.

    Returns ``None`` when any environment has nothing available.
    """
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('next_available_obs',))
      obs = [self._receive(pipe) for pipe in self.pipes]
    else:
      obs = [env.get_next_available_obs() for env in self.envs]
    if None in obs:
      return None
    return {k: np.stack([x[k] for x in obs]) for k in obs[0].keys()}

  def get_current_obs(self):
    """Return every environment's current observation, stacked over envs."""
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('get_current_obs',))
      cur_obs = [self._receive(pipe) for pipe in self.pipes]
    else:
      cur_obs = [env.get_current_obs() for env in self.envs]
    return {k: np.stack([x[k] for x in cur_obs]) for k in cur_obs[0].keys()}

  def get_obs_buffer(self):
    """Return the envs' buffered (not yet delivered) observations.

    Each buffer entry's ``time_step`` is used; the result is a list over
    buffer positions of dicts stacked over envs.
    """
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('obs_buffer',))
      obs_buffer = [self._receive(pipe) for pipe in self.pipes]
    else:
      obs_buffer = [env.obs_buffer for env in self.envs]
    obs_buffer = [[obs.time_step for obs in env_buffer] for env_buffer in obs_buffer]
    return [
        {k: np.stack([obs_buffer[e][t][k] for e in range(len(obs_buffer))])
         for k in obs_buffer[0][0]}
        for t in range(len(obs_buffer[0]))]

  def get_latest_obs(self):
    """Return the latest observation of every env stacked, or ``None``."""
    if self.parallel:
      for pipe in self.pipes:
        pipe.send(('latest_obs',))
      prev_obs = [self._receive(pipe) for pipe in self.pipes]
    else:
      prev_obs = [env.get_latest_obs() for env in self.envs]
    if None in prev_obs:
      return None
    return {k: np.stack([x[k] for x in prev_obs]) for k in prev_obs[0].keys()}

  def get_env_time(self):
    """Return ``env_time`` of the first environment."""
    if self.parallel:
      self.pipes[0].send(('env_time',))
      return self._receive(self.pipes[0])
    return self.envs[0].env_time

  def get_obs_window_size(self):
    """Return ``obs_window_size`` of the first environment."""
    if self.parallel:
      self.pipes[0].send(('obs_window_size',))
      return self._receive(self.pipes[0])
    return self.envs[0].obs_window_size

  def get_maximum_delay(self):
    """Return ``maximum_delay`` of the first environment."""
    if self.parallel:
      self.pipes[0].send(('maximum_delay',))
      return self._receive(self.pipes[0])
    return self.envs[0].maximum_delay

  def store(self, obs, act):
    """Emit one training transition per environment for observation ``obs``.

    ``obs`` must carry ``log/obs_env_time`` (time step the observation
    belongs to) and ``log/obs_receive_time`` per environment. ``act`` holds
    the pending actions (first axis = environments); it is appended to the
    action history as the newest step. ``act`` itself is not modified.

    Returns:
      The number of transitions emitted (one per environment).
    """
    maximum_delay = self.get_maximum_delay()
    env_time = self.get_env_time()
    action_history = self.get_action_history()
    # Work on a copy so the caller's dict is untouched; one time stamp per env.
    act = {**act, 'log/act_time': np.full(self.length, env_time)}
    # History values are (envs, time, ...): add the pending action as the
    # newest time step of every environment.
    action_history = {
        k: np.append(v, act[k][:, np.newaxis], axis=1)
        for k, v in action_history.items()}
    for i in range(self.length):
      t1 = obs['log/obs_env_time'][i]
      t2 = obs['log/obs_receive_time'][i]
      assert t1 <= t2
      # a_t: the action following the observed step.
      acts = {k: v[i][t1+1] for k, v in action_history.items()}
      # a_t .. a_{t+d-1}: actions issued until the observation was received.
      next_actions = {
          k + '_next': v[i][t1+1:t2+1] for k, v in action_history.items()}
      d = next(iter(next_actions.values())).shape[0]
      # Right-pad every action sequence to maximum_delay steps.
      for k, v in next_actions.items():
        pds = maximum_delay - v.shape[0]
        assert len(v.shape) in [1, 2]
        if len(v.shape) == 1:  # e.g. action['reset']
          next_actions[k] = np.pad(v, pad_width=(0, pds), mode='constant')
        else:  # e.g. action['action']
          next_actions[k] = np.pad(v, pad_width=((0, pds), (0, 0)), mode='constant')
      obs_i = {k: v[i] for k, v in obs.items()}
      # True for the d real (unpadded) entries of the *_next sequences.
      mask = np.pad(np.ones(d), pad_width=(0, maximum_delay - d), mode='constant')
      outs_i = {k: v[i] for k, v in self._all_outs[t1].items()}
      logs_i = {k: v for k, v in obs_i.items() if k.startswith('log/')}
      obs_i = {k: v for k, v in obs_i.items() if not k.startswith('log/')}
      next_actions = {
          k: v for k, v in next_actions.items()
          if not k.startswith('log/') and k != 'reset_next'}
      acts = {
          k: v for k, v in acts.items()
          if not k.startswith('log/') and k != 'reset'}
      trn = {**obs_i, **acts, **outs_i, **logs_i,
             **next_actions,
             'delay_mask': mask.astype(bool),
             'prev_mask': self._prev_masks[i][t1].astype(bool),}
      for fn in self.callbacks:
        fn(trn, i, **self.kwargs)
    return self.length

  def _mask(self, value, mask):
    """Multiply ``value`` by ``mask`` broadcast over its trailing axes."""
    while mask.ndim < value.ndim:
      mask = mask[..., None]
    return value * mask.astype(value.dtype)

  def _receive(self, pipe):
    """Receive one reply from a worker; on any error kill all workers."""
    try:
      msg, arg = pipe.recv()
      if msg == 'error':
        raise RuntimeError(arg)
      assert msg == 'result'
      return arg
    except Exception:
      print('Terminating workers due to an exception.')
      for proc in self.procs:
        proc.kill()
      raise

  @staticmethod
  def _env_server(stop, envid, pipe, ctor):
    """Worker loop: build one environment and serve driver requests.

    Each request is a tuple ``(msg, *args)``; each reply is
    ``('result', value)``. On an unexpected exception the worker replies
    ``('error', exception)`` and re-raises. The loop exits when ``stop`` is
    set or the pipe reaches EOF.
    """
    env = None
    try:
      ctor = cloudpickle.loads(ctor)
      env = ctor()
      while not stop.is_set():
        if not pipe.poll(0.1):
          time.sleep(0.1)
          continue
        try:
          msg, *args = pipe.recv()
        except EOFError:
          return
        if msg == 'step':
          assert len(args) == 1
          act = args[0]
          obs = env.step(act)
          pipe.send(('result', obs))
        elif msg == 'obs_space':
          assert len(args) == 0
          pipe.send(('result', env.obs_space))
        elif msg == 'act_space':
          assert len(args) == 0
          pipe.send(('result', env.act_space))
        elif msg == 'has_attr':
          # Capability probe that cannot raise AttributeError in the worker.
          assert len(args) == 1
          pipe.send(('result', hasattr(env, args[0])))
        # delay related
        elif msg == 'act':
          assert len(args) == 1
          act = args[0]
          env.act(act)
          pipe.send(('result', None))
        elif msg == 'action_history':
          assert len(args) == 0
          act_history = env.get_action_history()
          pipe.send(('result', act_history))
        elif msg == 'latest_obs':
          assert len(args) == 0
          latest_obs = env.get_latest_obs()
          pipe.send(('result', latest_obs))
        elif msg == 'get_current_obs':
          assert len(args) == 0
          cur_obs = env.get_current_obs()
          pipe.send(('result', cur_obs))
        elif msg == 'next_available_obs':
          assert len(args) == 0
          next_available_obs = env.get_next_available_obs()
          pipe.send(('result', next_available_obs))
        elif msg == 'observed_history':
          assert len(args) <= 1
          if len(args):
            obs_window_size = args[0]
            obs_history = env.get_observed_history(obs_window_size)
          else:
            obs_history = env.get_observed_history()
          pipe.send(('result', obs_history))
        elif msg == 'env_time':
          assert len(args) == 0
          env_time = env.env_time
          pipe.send(('result', env_time))
        elif msg == 'obs_window_size':
          assert len(args) == 0
          obs_window_size = env.obs_window_size
          pipe.send(('result', obs_window_size))
        elif msg == 'maximum_delay':
          assert len(args) == 0
          maximum_delay = env.maximum_delay
          pipe.send(('result', maximum_delay))
        elif msg == 'obs_buffer':
          assert len(args) == 0
          obs_buffer = env.obs_buffer
          pipe.send(('result', obs_buffer))
        else:
          raise ValueError(f'Invalid message {msg}')
    except ConnectionResetError:
      print('Connection to driver lost')
    except Exception as e:
      pipe.send(('error', e))
      raise
    finally:
      if env is not None:
        try:
          env.close()
        except Exception:
          pass
      pipe.close()
