import logging
from typing import Any

import gymnasium as gym
from gymnasium import Env

from tianshou.env import BaseVectorEnv
from tianshou.highlevel.env import (
    EnvFactoryRegistered,
    EnvPoolFactory,
    VectorEnvType, EnvMode,
)

import env_.wrapper as wrapper_registry

# envpool does not ship mo-gymnasium environments, so the envpool fast path is
# permanently disabled; MOGymnasiumEnvFactory therefore passes
# ``envpool_factory=None`` to tianshou.
envpool_is_available = False

# Module-level logger.
log = logging.getLogger(__name__)


def make_mo_gymnasium_env(
    task: str,
    seed: int,
    training_num: int,
    test_num: int,
    create_watch_env: bool = False,
    *args,
    **make_kwargs,
) -> tuple[Env, BaseVectorEnv, BaseVectorEnv, BaseVectorEnv | None]:
    """Create the environments for a mo-gymnasium task.

    :param task: the registered environment id handed to the env factory.
    :param seed: seed for the training environments; the test environments
        are seeded with ``seed + training_num``.
    :param training_num: number of training environments.
    :param test_num: number of test environments.
    :param create_watch_env: whether to also create an environment for
        watching the agent.
    :param args: accepted for call-site compatibility only; they are ignored
        (a warning is logged when any are given).
    :param make_kwargs: forwarded unchanged to :class:`MOGymnasiumEnvFactory`
        (e.g. ``venv_type``, ``wrappers``, ``disable_timelimit``, ``params``
        or extra keyword arguments for environment creation).
    :return: a tuple of (single env, training envs, test envs, watch env).
        The watch env is expected to be ``None`` unless ``create_watch_env``
        is true (tianshou only builds it on request).
    """
    if args:
        # These have always been dropped; make that visible instead of silent.
        log.warning(
            "make_mo_gymnasium_env: ignoring unexpected positional arguments %r",
            args,
        )
    # Offset the test seed by the number of training envs, presumably so the
    # per-env seeds of the two sets do not collide — TODO confirm.
    env_factory = MOGymnasiumEnvFactory(task, seed, seed + training_num, **make_kwargs)
    envs = env_factory.create_envs(training_num, test_num, create_watch_env)
    return envs.env, envs.train_envs, envs.test_envs, envs.watch_env


class MOGymnasiumEnvFactory(EnvFactoryRegistered):
    """Environment factory for mo-gymnasium tasks.

    Extends :class:`EnvFactoryRegistered` with two optional post-processing
    steps applied to every environment it creates:

    * a sequence of wrappers, looked up by name in ``env_.wrapper`` and
      applied in order, and
    * removal of every ``gymnasium.wrappers.TimeLimit`` wrapper in the chain
      (enabled with ``disable_timelimit=True``).
    """

    def __init__(
        self,
        task: str,
        train_seed: int,
        test_seed: int,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
        wrappers: list[dict[str, Any]] | None = None,
        **make_kwargs: Any,
    ) -> None:
        """Configure the factory.

        :param task: the registered environment id.
        :param train_seed: seed for the training environments.
        :param test_seed: seed for the test environments.
        :param venv_type: the vectorised environment type.
        :param wrappers: wrapper specs applied, in order, to every created
            environment. Each spec is a mapping with a ``'class'`` key naming
            an attribute of ``env_.wrapper`` and an optional ``'kwargs'``
            mapping passed to that wrapper's constructor.
        :param make_kwargs: extra keyword arguments forwarded to
            :class:`EnvFactoryRegistered`. Two keys are consumed here instead:
            ``disable_timelimit`` (default ``False``) enables TimeLimit
            removal, and ``params`` (a list) whose first element, when present,
            replaces the remaining keyword arguments.
        """
        self._disable_timelimit = make_kwargs.pop('disable_timelimit', False)

        # NOTE(review): a non-empty ``params`` list *replaces* every other
        # make_kwargs entry with its first element (siblings are discarded,
        # not merged) — confirm this is intended. ``None`` is treated as empty.
        params = make_kwargs.pop('params', None)
        if params:
            make_kwargs = params[0]

        # Materialise once: create_env() iterates this for every environment,
        # so a one-shot iterable would otherwise only wrap the first env.
        self.wrappers = list(wrappers) if wrappers is not None else []

        super().__init__(
            task=task,
            train_seed=train_seed,
            test_seed=test_seed,
            venv_type=venv_type,
            envpool_factory=EnvPoolFactory() if envpool_is_available else None,
            **make_kwargs,
        )

    def create_env(self, mode: EnvMode) -> Env:
        """Create a single environment for the given mode.

        The base environment comes from the parent factory. The configured
        wrappers are then applied in order. Finally, if ``disable_timelimit``
        was set, all ``TimeLimit`` wrappers are stripped.

        :param mode: the mode the environment is created for.
        :return: the (wrapped) environment.
        """
        env = super().create_env(mode)
        for spec in self.wrappers:
            wrapper_cls = getattr(wrapper_registry, spec['class'])
            # A missing or empty 'kwargs' entry means "no constructor arguments".
            wrapper_kwargs = spec.get('kwargs') or {}
            env = wrapper_cls(env, **wrapper_kwargs)

        if self._disable_timelimit:
            env = self.remove_time_limit(env)
        return env

    def remove_time_limit(self, env: Env) -> Env:
        """Remove every ``TimeLimit`` wrapper from a wrapper chain.

        Non-TimeLimit wrappers are kept, and their ``env`` attribute is
        re-linked in place to the TimeLimit-free remainder of the chain.

        :param env: the outermost environment/wrapper.
        :return: the new outermost environment (``env`` itself unless the
            outermost layers were TimeLimit wrappers).
        """
        # Drop any TimeLimit layers sitting on top of the chain.
        while isinstance(env, gym.wrappers.TimeLimit):
            env = env.env
        if isinstance(env, gym.Wrapper):
            # Keep this wrapper but splice out TimeLimits further down.
            env.env = self.remove_time_limit(env.env)
        return env
