import numpy as np

try:
    import gymnasium as gym

    old_api = False

except Exception:
    import gym

    old_api = True


class SequentialMultiEnvWrapper(gym.Env):
    """Step several environments one after another behind a batched interface.

    Every environment is created in this process and stepped in a plain loop;
    per-environment observations, rewards and done flags are stacked along a
    new leading axis of length ``num_envs``.

    The batched ``action_space`` and ``observation_space`` are derived from
    the first environment only, by repeating its ``low``/``high`` bounds once
    per environment. The spaces are therefore expected to expose ``low``,
    ``high``, ``shape`` and ``dtype`` and to be identical across environments.
    """

    def __init__(self, env_fns, seeds):
        """Build the environments and the batched spaces.

        Args:
            env_fns: Iterable of zero-argument callables, each returning one
                environment. At least one is required, because the spaces are
                read from the first environment.
            seeds: Indexable collection of per-environment seeds, looked up by
                environment index in ``reset``. Only used with the new
                (gymnasium-style) API; the old API branch does not seed.
        """
        self.envs = [env_fn() for env_fn in env_fns]
        self.seeds = seeds
        self.num_envs = len(self.envs)

        template = self.envs[0]
        self.action_space = self._batched_box(template.action_space)
        self.observation_space = self._batched_box(template.observation_space)

    def _batched_box(self, space):
        """Return a Box that repeats ``space`` along a leading env axis."""
        return gym.spaces.Box(
            low=space.low[None].repeat(self.num_envs, axis=0),
            high=space.high[None].repeat(self.num_envs, axis=0),
            # Use the full per-env shape so the declared shape always matches
            # the repeated bounds, including for spaces that are not 1-D.
            shape=(self.num_envs, *space.shape),
            dtype=space.dtype,
        )

    def _reset_env(self, env, seed=None):
        """Reset one environment and return only its observation.

        The old API returns the observation directly, the new API returns an
        ``(observation, info)`` pair whose info is discarded here. ``seed`` is
        forwarded only under the new API and only when it is not ``None``.
        """
        if old_api:
            return env.reset()
        if seed is None:
            obs, _ = env.reset()
        else:
            obs, _ = env.reset(seed=seed)
        return obs

    def _reset_idx(self, idx):
        """Reset the environment at position ``idx`` (unseeded); return its observation."""
        return self._reset_env(self.envs[idx])

    def reset_where_done(self, observations, dones):
        """Reset every environment whose done flag is set.

        Both arguments are modified in place: for each finished environment
        the observation is replaced by the fresh reset observation and the
        done flag is cleared.

        Args:
            observations: Indexable, mutable batch of observations.
            dones: Indexable, mutable batch of done flags.

        Returns:
            The same ``(observations, dones)`` objects after the update.
        """
        for j, done in enumerate(dones):
            if done:
                observations[j], dones[j] = self._reset_idx(j), False
        return observations, dones

    def generate_masks(self, dones, infos):
        """Compute per-environment bootstrap masks.

        A mask is ``1.0`` while the episode is running or when it ended only
        because of a time limit, and ``0.0`` when it really terminated.

        Args:
            dones: Iterable of done flags.
            infos: Iterable of info dicts, one per done flag.

        Returns:
            A NumPy array with one float mask per ``(done, info)`` pair.
        """
        return np.array(
            [
                # Look at the flag's value, not merely its presence: the key
                # may be present but False when an episode terminates on the
                # very step the time limit is reached.
                1.0 if (not done or info.get("TimeLimit.truncated", False)) else 0.0
                for done, info in zip(dones, infos)
            ]
        )

    def reset(self):
        """Reset all environments and return their stacked observations.

        Under the new API each environment is reset with ``seeds[idx]``;
        under the old API the seeds are not used.
        """
        obs = []
        for idx, env in enumerate(self.envs):
            # Only touch self.seeds on the branch that actually uses it.
            seed = None if old_api else self.seeds[idx]
            obs.append(self._reset_env(env, seed))
        return np.stack(obs)

    def step(self, actions):
        """Step every environment with its own action.

        Args:
            actions: Iterable of actions, paired with the environments in
                order.

        Returns:
            A tuple ``(observations, rewards, dones, infos)`` where the first
            three are stacked NumPy arrays and ``infos`` is a list of the
            per-environment info dicts. Under the new API ``done`` is
            ``terminated or truncated``; when an episode was truncated without
            terminating, the returned info is a copy carrying
            ``"TimeLimit.truncated": True`` so that ``generate_masks`` can
            tell the two cases apart.
        """
        obs, rews, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            if old_api:
                ob, reward, done, info = env.step(action)
            else:
                ob, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                if truncated and not terminated:
                    # Preserve the truncation signal that collapsing into a
                    # single done flag would otherwise lose. Copy so the
                    # environment's own dict is left untouched.
                    info = {**info, "TimeLimit.truncated": True}

            obs.append(ob)
            rews.append(reward)
            dones.append(done)
            infos.append(info)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos
