from typing import Any, Dict, Optional, Type, Union

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.ppo.policies import (
    ActorCriticCnnPolicy,
    ActorCriticPolicy,
)
from stable_baselines3.ppo.ppo import PPO
import gym

from ...utils.pre_conditions import (
    CNNGridWorldActionPreConditionWrapper,
    GridWorldActionPreConditionWrapper,
)
from .policy import ActorCriticPolicyPreConditionsWrapper


class PPOPreConditions(PPO):
    """PPO variant whose policy and environment are wrapped with action pre-conditions.

    The given policy is wrapped with ``ActorCriticPolicyPreConditionsWrapper`` and
    the environment (or each sub-environment of a ``VecEnv``) is wrapped with
    ``wrapper_class``. The wrapped single environment is forwarded to the policy
    through ``policy_kwargs["env"]`` unless the caller already supplied one.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        wrapper_class: Optional[Type[gym.core.Wrapper]] = None,
        **kwargs,
    ):
        """Wrap ``policy`` and ``env`` with pre-condition wrappers, then build PPO.

        :param policy: Policy passed to ``ActorCriticPolicyPreConditionsWrapper``;
            its result is what PPO receives as the policy.
        :param env: Environment instance, ``VecEnv``, or gym environment id.
            A string id is instantiated with ``gym.make`` before wrapping.
        :param policy_kwargs: Extra keyword arguments for the policy. The dict is
            copied, never mutated; ``"env"`` defaults to the wrapped single env.
        :param wrapper_class: Environment wrapper to apply. When ``None``,
            ``CNNGridWorldActionPreConditionWrapper`` is used if the wrapped policy
            subclasses ``ActorCriticCnnPolicy``, otherwise
            ``GridWorldActionPreConditionWrapper``.
        :param kwargs: Forwarded unchanged to ``PPO.__init__``.
        """
        # NOTE(review): the issubclass check below assumes the wrapper returns a
        # policy *class* (PPO expects a class or alias); confirm in .policy.
        policy = ActorCriticPolicyPreConditionsWrapper(policy)

        if wrapper_class is None:
            wrapper_class = (
                CNNGridWorldActionPreConditionWrapper
                if issubclass(policy, ActorCriticCnnPolicy)
                else GridWorldActionPreConditionWrapper
            )

        if isinstance(env, str):
            # A gym.Wrapper needs an environment object, not its registry id.
            env = gym.make(env)

        if isinstance(env, VecEnv):
            # Only VecEnvs that expose their sub-environments via `.envs`
            # (e.g. DummyVecEnv) can be wrapped in place here.
            # NOTE(review): the VecEnv's cached spaces/buffers were built from the
            # unwrapped envs; confirm the wrapper leaves the spaces unchanged.
            env.envs = [wrapper_class(e) for e in env.envs]
            single_env = env.envs[0]
        else:
            env = wrapper_class(env)
            single_env = env

        # Copy so the caller's dict is neither mutated nor shared between models.
        policy_kwargs = dict(policy_kwargs) if policy_kwargs is not None else {}
        policy_kwargs.setdefault("env", single_env)

        super().__init__(policy, env, policy_kwargs=policy_kwargs, **kwargs)
