from examples.mujoco.mujoco_env import MujocoEnvFactory, MujocoEnvObsRmsPersistence
from gymnasium import Env

from tianshou.env import BaseVectorEnv
from tianshou.highlevel.env import VectorEnvType, EnvMode


def make_mujoco_env(
    task: str,
    seed: int,
    num_train_envs: int,
    num_test_envs: int,
    obs_norm: bool,
    create_watch_env: bool = False,
    *args,
    **kwargs
) -> tuple[Env, BaseVectorEnv, BaseVectorEnv, Env | None]:
    """Create the Mujoco environments for training, testing and (optionally) watching.

    If EnvPool is installed, the training/test environments automatically switch
    to EnvPool's Mujoco env. The watch env, if requested, is always built by a
    factory that does not use EnvPool, so that it can be rendered.

    :param task: the Mujoco task id.
    :param seed: base seed; the factories receive ``seed`` and
        ``seed + num_train_envs`` as their train and test seeds respectively.
    :param num_train_envs: number of training environments.
    :param num_test_envs: number of test environments.
    :param obs_norm: whether observation normalization is enabled. When true and a
        watch env is created, it is given the training envs' ``get_obs_rms()``.
    :param create_watch_env: whether to additionally create a watch env.
    :param args: unused; accepted for call-site compatibility.
    :param kwargs: unused; accepted for call-site compatibility.
    :return: a tuple of (single env, training envs, test envs, watch env); the
        watch env is ``None`` unless ``create_watch_env`` is true.
    """
    env_factory = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm)
    envs = env_factory.create_envs(num_train_envs, num_test_envs)

    watch_env = None
    if create_watch_env:
        # Apparently EnvPool envs cannot be rendered, so the watch env comes from
        # a factory variant that never installs an EnvPool factory.
        watch_env_factory = NoEnvPoolMujocoEnvFactory(
            task, seed, seed + num_train_envs, obs_norm=obs_norm
        )
        watch_env = watch_env_factory.create_env(EnvMode.WATCH)

        if obs_norm:
            # Share the training envs' observation running statistics with the
            # watch env (presumably so it normalizes observations identically).
            # NOTE(review): create_env yields a single env, not a vector env;
            # confirm it actually exposes set_obs_rms, otherwise this raises.
            watch_env.set_obs_rms(envs.train_envs.get_obs_rms())

    return envs.env, envs.train_envs, envs.test_envs, watch_env


class NoEnvPoolMujocoEnvFactory(MujocoEnvFactory):
    """A ``MujocoEnvFactory`` variant that is constructed without an EnvPool factory.

    Used for environments that must be renderable (e.g. the watch env), since
    EnvPool-backed envs apparently cannot render.
    """

    def __init__(
        self,
        task: str,
        train_seed: int,
        test_seed: int,
        obs_norm: bool = True,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
    ) -> None:
        """
        :param task: the Mujoco task id.
        :param train_seed: seed forwarded for the training envs.
        :param test_seed: seed forwarded for the test envs.
        :param obs_norm: whether observation normalization is enabled; stored
            on the instance as ``self.obs_norm``.
        :param venv_type: the vectorized env implementation to use.
        """
        # super(MujocoEnvFactory, self) resolves to the class AFTER
        # MujocoEnvFactory in the MRO, so MujocoEnvFactory.__init__ is skipped
        # on purpose and we can pass envpool_factory=None explicitly.
        next_in_mro_init = super(MujocoEnvFactory, self).__init__
        next_in_mro_init(
            task=task,
            train_seed=train_seed,
            test_seed=test_seed,
            venv_type=venv_type,
            envpool_factory=None,
        )
        # Because the MujocoEnvFactory initializer was bypassed, set obs_norm here
        # (presumably mirroring what that initializer would have stored).
        self.obs_norm = obs_norm