import numpy as np
from copy import deepcopy

from gym.vector.vector_env import VectorEnvWrapper
from gym.vector.utils import concatenate
from gym.vector.sync_vector_env import SyncVectorEnv


class VecSyncWrapper(SyncVectorEnv):
    """Synchronous vector environment used for debugging.

    Differs from gym's ``SyncVectorEnv`` in that a sub-environment reporting
    ``done`` is reset automatically inside ``step_wait`` only when
    ``syncEnv_if_respawn`` is true. It also offers a minimal
    ``execute`` / ``set_seed`` API for seeding the sub-environments.
    """

    def __init__(self, **kwargs):
        """Build the sub-environments.

        Keyword Args:
            env_fns: list of zero-argument callables, each returning one env.
            syncEnv_if_respawn: if true, reset a sub-env in ``step_wait`` as
                soon as it reports done.
            copy: forwarded to ``SyncVectorEnv`` (default ``True``).
            Any other keys (e.g. ``args``) are accepted and ignored.
        """
        # Leaving observation/action spaces unset lets SyncVectorEnv take them
        # from the first sub-env it builds, instead of constructing (and
        # leaking) extra throwaway envs here just to read their spaces.
        SyncVectorEnv.__init__(self,
                               env_fns=kwargs["env_fns"],
                               copy=kwargs.get("copy", True))
        self._if_respawn = kwargs["syncEnv_if_respawn"]

    def execute(self, fn_name: str, args: list = None):
        """Synchronous stand-in for the ``execute`` API of the async vector env.

        Only ``fn_name == "set_seed"`` is supported. In that case ``args[i]`` is
        passed as ``seed=`` to ``self.envs[i].seed``.

        Returns:
            A tuple ``(results, True)``, where ``results`` holds the per-env
            return values of ``seed``. The second element is always ``True``.

        Raises:
            ValueError: if ``fn_name`` is not supported.
        """
        if fn_name != "set_seed":
            raise ValueError(f"Unsupported fn_name: {fn_name!r}")
        res = [env.seed(seed=args[i]) for i, env in enumerate(self.envs)]
        return res, True

    def step(self, actions):
        """Step all sub-envs with ``actions`` and return the batched results."""
        self.step_async(actions)
        return self.step_wait()

    def step_wait(self):
        """Step every sub-env with the actions stored by ``step_async``.

        Returns:
            ``(observations, rewards, dones, infos)``. ``infos`` is a list with
            one entry per sub-env, with any ``"TimeLimit.truncated"`` key removed.
        """
        observations, infos = [], []
        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
            observation, self._rewards[i], self._dones[i], info = env.step(action)
            if self._if_respawn and self._dones[i]:
                observation = env.reset()
            # ``infos`` is a list of per-env dicts, so the key has to be removed
            # from each dict; a list-level membership test never matches.
            if isinstance(info, dict):
                info.pop("TimeLimit.truncated", None)
            observations.append(observation)
            infos.append(info)
        self.observations = concatenate(observations, self.observations, self.single_observation_space)
        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.copy(self._rewards),
            np.copy(self._dones),
            infos,
        )

    def reset(self, **kwargs):
        """Reset all sub-envs. Keyword arguments are accepted but ignored."""
        self.reset_async()
        return self.reset_wait()

    def set_seed(self, seed_list: list):
        """Seed sub-env ``i`` with ``seed_list[i]`` (a single pass over the envs)."""
        self.execute(fn_name="set_seed", args=seed_list)


def create_vector_environment(make_env, args: dict, if_train=True):
    """Build, seed and reset a wrapped synchronous vector environment.

    Args:
        make_env: zero-argument callable returning a single environment.
        args: config dict; reads ``num_envs``, ``max_episode_steps``,
            ``mjc_if_pomdp`` and ``env_seed``.
        if_train: forwarded as ``syncEnv_if_respawn`` to ``VecSyncWrapper``.

    Returns:
        A ``VecWrapper`` around a ``VecSyncWrapper``, already seeded and reset.
        Sub-env ``k`` is seeded with ``k + args["env_seed"]``.
    """
    n_envs = args["num_envs"]
    factories = [lambda: make_env() for _ in range(n_envs)]
    sync_env = VecSyncWrapper(env_fns=factories, args=args,
                              syncEnv_if_respawn=if_train)
    wrapped = VecWrapper(env=sync_env,
                         max_episode_steps=args["max_episode_steps"],
                         if_pomdp=args["mjc_if_pomdp"])
    seeds = [offset + args["env_seed"] for offset in range(args["num_envs"])]
    # Seed set for RecSim needs to be done externally
    wrapped.set_seed(seed_list=seeds)
    wrapped.reset()
    return wrapped


class VecWrapper(VectorEnvWrapper):
    """Vector-env wrapper adding an episode step limit and optional pixel observations.

    Args:
        env: the wrapped vector environment. ``render`` iterates over its
            ``envs`` attribute.
        max_episode_steps: once this many steps have been taken since the last
            reset, every sub-env is reported done. ``None`` disables the limit.
        if_pomdp: if true, observations are replaced by stacked RGB frames
            from ``render``.
    """

    def __init__(self, env, max_episode_steps, if_pomdp=False):
        super().__init__(env=env)
        self._ts = 0  # steps taken since the last reset
        self._if_pomdp = if_pomdp
        self._max_episode_steps = max_episode_steps

    def step_wait(self):
        """Collect the step results and apply the episode step limit.

        Returns:
            ``(obs, rewards, dones, infos)`` from the wrapped env. ``dones`` is
            set to all-True once ``max_episode_steps`` steps have elapsed, and
            ``obs`` is replaced by rendered frames when ``if_pomdp`` is set.
        """
        o, r, d, i = self.env.step_wait()
        # Count this step before checking, so an episode lasts exactly
        # ``max_episode_steps`` steps (matching gym's TimeLimit semantics).
        self._ts += 1
        if self._max_episode_steps is not None and self._ts >= self._max_episode_steps:
            # assumes ``d`` is a numpy array (as returned by VecSyncWrapper)
            d[:] = True
        if self._if_pomdp:
            o = self.render()
        return o, r, d, i

    def reset_wait(self):
        """Reset the step counter and return the initial observations."""
        self._ts = 0
        o = self.env.reset_wait()
        if self._if_pomdp:
            o = self.render()
        return o

    def render(self, mode='human'):
        """Return the RGB frames of all sub-envs stacked along a new leading axis.

        ``mode`` is accepted for API compatibility but ignored: each sub-env is
        always rendered with ``mode="rgb_array"``.
        """
        frames = [env.render(mode="rgb_array") for env in self.env.envs]
        return np.stack(frames)
