import numpy as np


# Wrapper for batching environments together
class EnvBatcher():
    """Run ``n`` independent copies of an environment side by side.

    Every environment carries a "sticky" done flag: once it reports
    termination it stays terminated until the next :meth:`reset`. Steps taken
    after termination report a zeroed observation and a zero reward for that
    environment. The reward of the terminating step itself is preserved.
    """

    def __init__(self, env_fn, n):
        """Create ``n`` environments by calling ``env_fn()`` ``n`` times.

        All environments start flagged as done, so stepping before
        :meth:`reset` yields zeroed observations and rewards.
        """
        self.n = n
        self.envs = [env_fn() for _ in range(n)]
        self.dones = [True] * n

    def reset(self):
        """Reset every environment and clear all done flags.

        Returns:
            ``np.ndarray`` stacking the per-environment reset observations.
        """
        observations = [env.reset() for env in self.envs]
        self.dones = [False] * self.n
        return np.array(observations)

    def render(self, *args, idx=None, **kwargs):
        """Forward ``render`` to one environment (``idx``) or to all of them.

        Returns the single render result when ``idx`` is given, otherwise a
        list with one result per environment.
        """
        if idx is None:
            return [env.render(*args, **kwargs) for env in self.envs]
        else:
            return self.envs[idx].render(*args, **kwargs)

    def step_single(self, action, idx):
        """Step only environment ``idx`` with ``action``.

        Returns:
            ``(obs, reward, done, {})``. If the environment had already
            terminated *before* this call, ``obs`` is zeroed and ``reward`` is
            0. The reward of the step that terminates it is kept, matching
            :meth:`step`.
        """
        # Capture the flag before stepping so the terminating step's reward
        # is not discarded (previously the updated flag was checked).
        was_done = self.dones[idx]
        obs, reward, done, _info = self.envs[idx].step(action)
        self.dones[idx] = done or was_done
        if was_done:
            # Blank outputs of an env that had already terminated, as step() does.
            return np.zeros_like(obs), 0, self.dones[idx], {}
        return obs, reward, self.dones[idx], {}

    def step(self, actions):
        """Step every environment with its corresponding action.

        Args:
            actions: one action per environment (length must equal ``n``).

        Returns:
            ``(observations, rewards, dones, infos)``. ``observations``,
            ``rewards`` and ``dones`` are stacked ``np.ndarray`` objects;
            ``infos`` is a tuple of the per-environment info objects.
            Environments that had terminated before this call get zeroed
            observations and zero rewards.

        Raises:
            ValueError: if the number of actions differs from ``n``.
        """
        actions = list(actions)
        if len(actions) != self.n:
            # zip() would silently truncate and shrink self.dones otherwise.
            raise ValueError(f"expected {self.n} actions, got {len(actions)}")

        # Indices of envs terminated *before* this step; computed first so the
        # reward of the terminating step is preserved.
        prev_done_idx = np.nonzero(np.array(self.dones))
        # NOTE: terminated envs are still stepped; their outputs are masked below.
        results = [env.step(action) for env, action in zip(self.envs, actions)]
        observations, rewards, dones, infos = zip(*results)

        # Env should remain terminated if previously terminated
        self.dones = [d or prev_d for d, prev_d in zip(dones, self.dones)]

        observations = np.array(observations)
        rewards = np.array(rewards)
        observations[prev_done_idx] = 0
        rewards[prev_done_idx] = 0
        return observations, rewards, np.array(self.dones), infos

    def close(self):
        """Close every wrapped environment."""
        for env in self.envs:
            env.close()
