import multiprocessing as mp
import numpy as np
from baselines.common.vec_env.vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
import ctypes
from baselines import logger

from baselines.common.vec_env.util import dict_to_obs, obs_space_info, obs_to_dict

# Maps a NumPy scalar type to the ctypes type used to allocate the shared
# observation buffer for that dtype. Every pair has the same item size, so a
# buffer allocated with the ctypes type can be reinterpreted via
# ``np.frombuffer`` using the NumPy dtype.
_NP_TO_CT = {
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
    np.int8: ctypes.c_int8,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.uint8: ctypes.c_char,  # one byte per element
    np.bool_: ctypes.c_bool,
}


class ShmemVecEnv(VecEnv):
    """
    Optimized version of SubprocVecEnv that uses shared variables to communicate observations.

    Each environment runs in its own worker process (``_subproc_worker``).
    Commands and non-observation results travel over one pipe per worker;
    observations are written by the worker into per-environment shared
    arrays and read back by the parent in ``_decode_obses``.

    Every ``*_async`` request produces exactly one reply per worker, which the
    matching ``*_wait`` method consumes. A ``waiting_*`` flag per request kind
    records whether such a reply is still outstanding.
    """

    def __init__(self, env_fns, spaces=None, context='spawn'):
        """
        Start one worker process per environment factory.

        :param env_fns: sequence of zero-argument callables, each building one environment
        :param spaces: optional ``(observation_space, action_space)`` pair. If
            omitted, a dummy environment is built from ``env_fns[0]`` to read
            the spaces and is closed again right away.
        :param context: multiprocessing start method, passed to ``mp.get_context``
        """
        ctx = mp.get_context(context)
        if spaces:
            observation_space, action_space = spaces
        else:
            logger.log('Creating dummy env object to get spaces')
            with logger.scoped_configure(format_strs=[]):
                dummy = env_fns[0]()
                observation_space, action_space = dummy.observation_space, dummy.action_space
                dummy.close()
                del dummy
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
        self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
        # One dict of flat shared arrays per environment, keyed by observation key.
        self.obs_bufs = [
            {
                k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k])))
                for k in self.obs_keys
            }
            for _ in env_fns
        ]
        self.parent_pipes = []
        self.procs = []
        self.locks = [ctx.Lock() for _ in env_fns]  # Create a lock for each environment
        with clear_mpi_env_vars():
            for env_fn, obs_buf, lock in zip(env_fns, self.obs_bufs, self.locks):
                wrapped_fn = CloudpickleWrapper(env_fn)
                parent_pipe, child_pipe = ctx.Pipe()
                proc = ctx.Process(
                    target=_subproc_worker,
                    args=(child_pipe, parent_pipe, wrapped_fn, obs_buf,
                          self.obs_shapes, self.obs_dtypes, self.obs_keys, lock))
                proc.daemon = True
                self.procs.append(proc)
                self.parent_pipes.append(parent_pipe)
                proc.start()
                # The child end belongs to the worker; drop the parent's copy.
                child_pipe.close()
        self.waiting_step = False
        self.waiting_talk2Env = False
        # Kept for backward compatibility: no cost2Env request is implemented
        # by this class or by the worker, so this flag is never set.
        self.waiting_cost2Env = False
        self.waiting_update_monitor = False

        self.viewer = None

    def reset(self):
        """
        Reset every environment and return the stacked initial observations.

        Replies of requests that are still in flight are received and
        discarded first (with a warning), so that the replies read below
        belong to the 'reset' command.
        """
        if self.waiting_step:
            logger.warn('Called reset() while waiting for the step to complete')
            self.step_wait()
        if self.waiting_talk2Env:
            logger.warn('Called reset() while waiting for talk2Env to complete')
            self.talk2Env_wait()
        if self.waiting_update_monitor:
            logger.warn('Called reset() while waiting for update_monitor to complete')
            self.update_monitor_wait()
        for pipe in self.parent_pipes:
            pipe.send(('reset', None))
        return self._decode_obses([pipe.recv() for pipe in self.parent_pipes])

    def step_async(self, actions):
        """Send one action to each worker; collect the results with ``step_wait``."""
        assert len(actions) == len(self.parent_pipes)
        for pipe, act in zip(self.parent_pipes, actions):
            pipe.send(('step', act))
        self.waiting_step = True

    def step_wait(self):
        """
        Block until every worker has answered the pending 'step' command.

        :return: ``(obs, rews, dones, infos)`` where ``obs`` is read from the
            shared buffers, ``rews`` and ``dones`` are NumPy arrays and
            ``infos`` is a tuple with one entry per environment
        """
        outs = [pipe.recv() for pipe in self.parent_pipes]
        self.waiting_step = False
        obs, rews, dones, infos = zip(*outs)
        return self._decode_obses(obs), np.array(rews), np.array(dones), infos

    def talk2Env_async(self, data):
        """Send one item of ``data`` to each worker's ``env.talk2Env``."""
        assert len(data) == len(self.parent_pipes)
        for pipe, d in zip(self.parent_pipes, data):
            pipe.send(('talk2Env', d))
        # Tracked separately from waiting_step: the replies are not step
        # tuples and must not be consumed by step_wait().
        self.waiting_talk2Env = True

    def talk2Env_wait(self):
        """Block for the pending 'talk2Env' replies and return them as a list."""
        outs = [pipe.recv() for pipe in self.parent_pipes]  # pipe.recv() is a blocking call
        self.waiting_talk2Env = False
        return outs

    def update_monitor_async(self, data):
        """
        Send per-environment ``(obs, reward, done, info)`` tuples to the workers.

        :param data: ``(obs, reward, done, infos)``. ``obs`` is indexed by key
            and then by environment index; the number of environments is
            taken from ``len(obs['robot_node'])``. ``reward``, ``done`` and
            ``infos`` must each have one entry per environment.
        """
        obs, reward, done, infos = data
        # Split the batched observation into one dict per environment.
        observations = [{key: obs[key][index] for key in obs.keys()} for index in range(len(obs['robot_node']))]
        assert len(observations) == len(self.parent_pipes)
        assert len(reward) == len(self.parent_pipes)
        assert len(done) == len(self.parent_pipes)
        assert len(infos) == len(self.parent_pipes)
        for pipe, o, r, d, i in zip(self.parent_pipes, observations, reward, done, infos):
            pipe.send(('update_monitor', (o, r, d, i)))
        self.waiting_update_monitor = True

    def update_monitor_wait(self):
        """
        Block for the pending 'update_monitor' replies.

        :return: ``(obs, rews, dones, infos)`` in the same layout as ``step_wait``
        """
        outs = [pipe.recv() for pipe in self.parent_pipes]  # pipe.recv() is a blocking call
        self.waiting_update_monitor = False
        obs, rews, dones, infos = zip(*outs)

        return self._decode_obses(obs), np.array(rews), np.array(dones), infos

    def close_extras(self):
        """
        Shut down all workers: drain outstanding replies, send 'close', wait
        for each acknowledgement, close the pipes and join the processes.
        """
        # Consume replies that are still in flight; otherwise the recv()
        # below would read a stale reply instead of the 'close' acknowledgement.
        if self.waiting_step:
            self.step_wait()
        if self.waiting_talk2Env:
            self.talk2Env_wait()
        if self.waiting_update_monitor:
            self.update_monitor_wait()
        for pipe in self.parent_pipes:
            pipe.send(('close', None))
        for pipe in self.parent_pipes:
            pipe.recv()
            pipe.close()
        for proc in self.procs:
            proc.join()

    def get_images(self, mode='human'):
        """
        Ask every worker to render and return the list of results.

        ``mode`` is accepted but not forwarded; the worker always calls
        ``env.render(mode='rgb_array')``.
        """
        for pipe in self.parent_pipes:
            pipe.send(('render', None))
        return [pipe.recv() for pipe in self.parent_pipes]

    def _decode_obses(self, obs):
        """
        Build the batched observation from the shared buffers.

        ``obs`` (the values received over the pipes) is not used; the data is
        read from ``self.obs_bufs``. ``np.array`` copies the buffer views, so
        the result does not alias the shared memory.
        """
        result = {}
        for k in self.obs_keys:
            bufs = [b[k] for b in self.obs_bufs]
            o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[k]).reshape(self.obs_shapes[k]) for b in bufs]
            result[k] = np.array(o)
        return dict_to_obs(result)


def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys, lock):
    """
    Main loop of a worker process: builds one environment and serves the
    ``(cmd, data)`` requests arriving on ``pipe`` until 'close' is received.

    Observations are not sent through the pipe. They are copied into the
    shared buffers in ``obs_bufs`` and the reply carries ``None`` in their
    place. The environment is always closed on exit.
    """
    def _publish_obs(maybe_dict_obs):
        # Copy each observation component into its shared buffer while
        # holding this environment's lock. Implicitly returns None.
        flat = obs_to_dict(maybe_dict_obs)
        with lock:
            for key in keys:
                shared = np.frombuffer(obs_bufs[key].get_obj(), dtype=obs_dtypes[key])
                target = shared.reshape(obs_shapes[key])
                np.copyto(target, flat[key])

    env = env_fn_wrapper.x()
    # This process only uses its own end of the pipe.
    parent_pipe.close()
    try:
        while True:
            cmd, data = pipe.recv()

            if cmd == 'close':
                # Acknowledge, then leave the loop.
                pipe.send(None)
                break

            if cmd == 'reset':
                env.reset_state()
                initial_obs = env.reset()
                pipe.send(_publish_obs(initial_obs))
            elif cmd == 'step':
                next_obs, rew, finished, extra = env.step(data)
                if finished:
                    # Episode ended: flag it on the env and publish the
                    # first observation of the next episode instead.
                    env.needs_reset = True
                    next_obs = env.reset()
                placeholder = _publish_obs(next_obs)
                pipe.send((placeholder, rew, finished, extra))
            elif cmd == 'render':
                frame = env.render(mode='rgb_array')
                pipe.send(frame)
            elif cmd == 'talk2Env':
                answer = env.talk2Env(data)
                pipe.send(answer)
            elif cmd == 'update_monitor':
                payload = (data[0], data[1], data[2], data[3])
                next_obs, rew, finished, extra = env.update_monitor(payload)
                placeholder = _publish_obs(next_obs)
                pipe.send((placeholder, rew, finished, extra))
            else:
                raise RuntimeError('Got unrecognized cmd %s' % cmd)
    except KeyboardInterrupt:
        print('ShmemVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()