import numpy as np
from multiprocessing import Process, Pipe
from functools import partial


class ParallelEnvWrapper:
    """Run ``env_num`` copies of an environment, each in its own daemon subprocess.

    The parent talks to every worker over a ``multiprocessing.Pipe`` using
    ``(command, data)`` messages handled by ``worker``.  ``agent_num``,
    ``action_dim`` and ``obs_dim`` are queried from the first worker at
    construction time.

    Args:
        make_env: zero-argument callable building one environment; it is
            shipped to each worker through ``CloudpickleWrapper``.
        env_num: number of parallel environments (must be >= 1).
        seed: base seed; worker ``i`` receives ``seed + i``.  When ``None``,
            every worker receives ``None``.
    """

    def __init__(self, make_env, env_num, seed=None):
        if env_num < 1:
            raise ValueError(f"env_num must be >= 1, got {env_num!r}")
        self.env_num = env_num
        self.env_fns = [partial(make_env) for _ in range(env_num)]
        self.seed = seed
        self.closed = False

        self.parent_conns, self.child_conns = zip(*[Pipe() for _ in range(env_num)])
        self.ps = [
            Process(
                target=worker,
                args=(
                    child_conn,
                    parent_conn,
                    CloudpickleWrapper(env_fn),
                    # Previously ``seed + i`` crashed for the default seed=None.
                    None if seed is None else seed + i,
                ),
                daemon=True,
            )
            for i, (parent_conn, child_conn, env_fn) in enumerate(
                zip(self.parent_conns, self.child_conns, self.env_fns))
        ]

        for p in self.ps:
            p.start()

        # The workers own the child ends now; drop the parent's copies.
        for conn in self.child_conns:
            conn.close()

        self.parent_conns[0].send(('get_properties', None))
        self.agent_num, self.action_dim, self.obs_dim = self.parent_conns[0].recv()

    def reset(self):
        """Reset every environment and return the stacked observations.

        Returns:
            np.ndarray stacked along a new leading env axis
            (per the original annotation: [env_num, agent_num, obs_dim]).
        """
        for conn in self.parent_conns:
            conn.send(('reset', None))

        obs_list = [conn.recv() for conn in self.parent_conns]
        return np.stack(obs_list, axis=0)  # [env_num, agent_num, obs_dim]

    def step(self, actions):
        """Step every environment with its own action.

        Args:
            actions: one entry per environment, e.g. shape
                [env_num, agent_num, action_dim].

        Returns:
            obs: [env_num, agent_num, obs_dim]. For an env that finished this
                step, this is the observation after the worker's automatic reset.
            rewards: [env_num,]
            terminated: [env_num,]
            truncated: [env_num,]
            infos: list of per-env info objects

        Raises:
            ValueError: if the number of actions differs from the number of
                environments (otherwise the parent would block forever waiting
                on environments that were never stepped).
        """
        actions = list(actions)
        if len(actions) != len(self.parent_conns):
            raise ValueError(
                f"expected {len(self.parent_conns)} actions, got {len(actions)}")

        for conn, action in zip(self.parent_conns, actions):
            conn.send(('step', action))

        results = [conn.recv() for conn in self.parent_conns]
        obs_list, reward_list, terminated_list, truncated_list, info_list = zip(*results)

        return (
            np.stack(obs_list, axis=0),
            np.array(reward_list),
            np.array(terminated_list),
            np.array(truncated_list),
            list(info_list),
        )

    def close(self):
        """Ask every worker to shut down, wait for them and release the pipes.

        Safe to call more than once; later calls are no-ops.
        """
        if self.closed:
            return
        self.closed = True
        for conn in self.parent_conns:
            try:
                conn.send(('close', None))
            except OSError:
                # Worker already exited (e.g. crashed); nothing to notify.
                pass
        for p in self.ps:
            p.join()
        for conn in self.parent_conns:
            conn.close()


def worker(remote, parent_remote, env_fn_wrapper, seed):
    """Serve ``(command, data)`` requests for one environment until told to stop.

    Commands:
        'step'           -> send (obs, reward, terminated, truncated, info).
                            When the episode ended, ``obs`` is replaced by the
                            result of ``env.reset()``.
        'reset'          -> send ``env.reset()``.
        'get_properties' -> send (agent_num, action_dim, obs_dim).
        'close'          -> stop serving.

    The loop also stops when the parent end of the pipe goes away (EOF) or on
    KeyboardInterrupt. The environment and ``remote`` are always closed on exit.

    Args:
        remote: this worker's end of the pipe.
        parent_remote: the parent's end; closed here because only the parent
            should hold it.
        env_fn_wrapper: object whose ``.x`` builds the environment.
        seed: value assigned to ``env.seed``.

    Raises:
        NotImplementedError: on an unknown command.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        # NOTE(review): assigns an attribute rather than calling a seed()
        # method; confirm this matches the env class's API.
        env.seed = seed
        properties = (env.agent_num, env.action_dim, env.obs_dim)

        while True:
            try:
                cmd, data = remote.recv()
            except (EOFError, KeyboardInterrupt):
                # Parent closed its end or the user interrupted: shut down quietly.
                break

            if cmd == 'step':
                obs, reward, terminated, truncated, info = env.step(data)
                if terminated or truncated:
                    # Auto-reset so the parent always receives a usable observation.
                    obs = env.reset()
                remote.send((obs, reward, terminated, truncated, info))
            elif cmd == 'reset':
                remote.send(env.reset())
            elif cmd == 'get_properties':
                remote.send(properties)
            elif cmd == 'close':
                break
            else:
                raise NotImplementedError(f"Unknown command: {cmd!r}")
    finally:
        env.close()
        remote.close()


class CloudpickleWrapper:
    """Hold an object (typically an env factory) and pickle it with cloudpickle.

    Serialization goes through ``cloudpickle`` so that lambdas and closures
    survive the trip to a subprocess; deserialization uses the standard
    ``pickle`` module on the produced bytes.
    """

    def __init__(self, x):
        # The wrapped object; consumers read it back via ``.x``.
        self.x = x

    def __getstate__(self):
        """Return the wrapped object serialized by cloudpickle."""
        from cloudpickle import dumps
        return dumps(self.x)

    def __setstate__(self, ob):
        """Rebuild the wrapped object from the bytes produced by ``__getstate__``."""
        from pickle import loads
        restored = loads(ob)
        self.x = restored