from gym.vector.utils import (
    create_shared_memory,
    create_empty_array,
    write_to_shared_memory,
    read_from_shared_memory,
    concatenate,
    CloudpickleWrapper,
    clear_mpi_env_vars,
)

import sys


def _worker_shared_memory_2(
    index, env_fn, pipe, parent_pipe, shared_memory, error_queue
):
    assert shared_memory is not None
    env = env_fn()
    observation_space = env.observation_space
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                if done:
                    observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_check_observation_space":
                pipe.send((data == observation_space, True))
            elif command == "sim.get_state":
                state = env.sim.get_state().flatten()
                pipe.send((state, True))
            elif command == "sim.set_state":
                env.sim.set_state_from_flattened(data)
                pipe.send((None, True))
            else:
                raise RuntimeError(
                    "Received unknown command `{0}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_observation_space`}.".format(command)
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()


def sim_get_state(self, data=None):
    """Ask every worker for its flattened simulator state.

    Sends a ``("sim.get_state", data)`` message down each pipe in
    ``self.parent_pipes``, then collects one ``(state, success)`` reply per
    pipe, in the same order.

    Args:
        data: Payload forwarded unchanged with the command.

    Returns:
        A tuple with one state per pipe.

    Raises:
        Whatever ``self._assert_is_running`` or ``self._raise_if_errors``
        raise.
    """
    self._assert_is_running()

    # Fan the request out to all workers before waiting on any reply.
    request = ("sim.get_state", data)
    for worker_pipe in self.parent_pipes:
        worker_pipe.send(request)

    replies = [worker_pipe.recv() for worker_pipe in self.parent_pipes]
    states, successes = zip(*replies)
    self._raise_if_errors(successes)
    return states


def sim_get_state_sync(self, data=None):
    """Return the flattened simulator state of every in-process environment.

    Args:
        data: Ignored; accepted so the call signature matches
            ``sim_get_state``.

    Returns:
        A tuple holding ``env.sim.get_state().flatten()`` for each
        environment in ``self.envs``, in order. The tuple is empty when
        ``self.envs`` is empty.
    """
    # Built directly: the earlier zip/unzip round-trip with a dummy success
    # flag added nothing and failed to unpack when there were no envs.
    return tuple(env.sim.get_state().flatten() for env in self.envs)


def sim_set_state(self, states):
    """Send one flattened simulator state to each worker and wait for acks.

    The i-th state is sent to the i-th pipe in ``self.parent_pipes`` as a
    ``("sim.set_state", state)`` message; one ``(_, success)`` reply is then
    read from every pipe.

    Args:
        states: Iterable of states, one per worker. States beyond the number
            of workers are ignored.

    Returns:
        ``True`` once every worker has replied.

    Raises:
        ValueError: If fewer states than workers are supplied. Nothing is
            sent in that case.
        Whatever ``self._assert_is_running`` or ``self._raise_if_errors``
        raise.
    """
    self._assert_is_running()

    # Pair lazily so any iterable (including an unbounded one) is accepted.
    assignments = list(zip(self.parent_pipes, states))
    if len(assignments) < len(self.parent_pipes):
        # A worker that is sent no command never replies, so the recv()
        # below would block forever. Fail before anything is sent.
        raise ValueError(
            "Expected at least {0} states (one per worker), got {1}.".format(
                len(self.parent_pipes), len(assignments)
            )
        )

    for pipe, state in assignments:
        pipe.send(("sim.set_state", state))
    _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
    self._raise_if_errors(successes)
    return True


def sim_set_state_sync(self, states):
    """Load a flattened simulator state into each in-process environment.

    Environments in ``self.envs`` are paired with ``states`` positionally;
    pairing stops at the shorter of the two.

    Args:
        states: Iterable of states accepted by
            ``env.sim.set_state_from_flattened``.

    Returns:
        Always ``True``.
    """
    pairing = zip(self.envs, states)
    for sub_env, flat_state in pairing:
        simulator = sub_env.sim
        simulator.set_state_from_flattened(flat_state)
    return True


import functools as ft
import sys
from gym.vector import SyncVectorEnv

from envs.mujoco.ant_env import AntEnv


def make_model(contextualized_make_env, num_envs):
    """Build a ``SyncVectorEnv`` with simulator state accessors attached.

    Args:
        contextualized_make_env: Zero-argument environment factory; the same
            callable is used for every sub-environment.
        num_envs: Number of sub-environments to create.

    Returns:
        The vector env, with ``sim_get_state`` and ``sim_set_state`` set as
        instance attributes bound to the synchronous helpers.
    """
    env_factories = [contextualized_make_env for _ in range(num_envs)]
    # NOTE: only the synchronous backend is wired up here; the custom async
    # worker (`_worker_shared_memory_2`) is not passed to the vector env.
    model = SyncVectorEnv(env_factories)

    # Attach the helpers per instance, pre-binding `model` as `self`.
    accessors = (
        ("sim_get_state", sim_get_state_sync),
        ("sim_set_state", sim_set_state_sync),
    )
    for attr_name, helper in accessors:
        setattr(model, attr_name, ft.partial(helper, model))

    return model
