from .continuous_wrapper import *
from .base_wrapper import *
from .vecenv import VecEnv
from .subproc_vecenv import SubProcVecEnv


def wrap_continuous_env(env, obs_norm, reward_scale):
    env = RewardShift(env, reward_scale)
    if obs_norm:
        return NormObs(env)
    return env


def get_env(env_id, env_param):
    env = gym.make(env_id)
    if str(env.__class__.__name__).find('TimeLimit') >= 0:
        env = TimeLimitAugment(env)
    env = BaseWrapper(env)
    if "rew_norm" in env_param:
        env = NormRet(env, **env_param["rew_norm"])
        del env_param["rew_norm"]

    ob_space = env.observation_space
    if len(ob_space.shape) == 3:
        env = wrap_deepmind(env, **env_param)
    else:
        env = wrap_continuous_env(env, **env_param)

    if isinstance(env.action_space, gym.spaces.Box):
        return NormAct(env)
    return env


def get_single_env(env_id, env_param):
    env = gym.make(env_id)
    if str(env.__class__.__name__).find('TimeLimit') >= 0:
        env = TimeLimitAugment(env)
    env = BaseWrapper(env)

    ob_space = env.observation_space
    if len(ob_space.shape) == 3:
        env = wrap_deepmind(env, **env_param)

    if "reward_scale" in env_param:
        env = RewardShift(env, env_param["reward_scale"])

    if isinstance(env.action_space, gym.spaces.Box):
        return NormAct(env)
    return env


def get_vec_env(env_id, env_param, vec_env_nums):
    vec_env = VecEnv(
        vec_env_nums, get_single_env,
        [env_id, env_param])

    if "obs_norm" in env_param and env_param["obs_norm"]:
        vec_env = NormObs(vec_env)
    return vec_env


def get_subprocvec_env(env_id, env_param, vec_env_nums, proc_nums):
    vec_env = SubProcVecEnv(
        proc_nums, vec_env_nums, get_single_env,
        [env_id, env_param])

    if "obs_norm" in env_param and env_param["obs_norm"]:
        vec_env = NormObs(vec_env)
    return vec_env
