from stable_baselines3.common.env_util import make_vec_env

from wrapper import *

def init_halfcheetah_env_trairl(env_name, self, joints_status):
    """Create a HalfCheetah-v5 environment wrapped in ``DisabledHalfCheetah``.

    Args:
        env_name: Key used to select this environment's encoder from
            ``self.encoders``.
        self: Owner object that must expose ``reward_net`` and ``encoders``
            (passed explicitly rather than bound as a method).
        joints_status: Forwarded unchanged to ``DisabledHalfCheetah``;
            presumably describes which joints are disabled — TODO confirm
            against ``wrapper``.

    Returns:
        The ``DisabledHalfCheetah``-wrapped environment (rendering as
        ``"rgb_array"``).
    """
    base_env = gym.make("HalfCheetah-v5", render_mode="rgb_array")
    return DisabledHalfCheetah(
        base_env,
        self.reward_net,
        joints_status,
        self.encoders[env_name],
    )

def init_halfcheetah_env_transfer(env_name, self, **kwargs):
    """Create a HalfCheetah-v5 environment wrapped in ``CustomReward``.

    Args:
        env_name: Not used by this factory; presumably kept so all ``init_*``
            helpers share a common leading signature — confirm with callers.
        self: Owner object that must expose ``reward_net`` and
            ``target_encoder`` (passed explicitly rather than bound as a method).
        **kwargs: Extra keyword arguments forwarded to ``gym.make`` (mirrors
            ``init_ant_env``). ``render_mode`` defaults to ``"rgb_array"`` and
            may be overridden here.

    Returns:
        The ``CustomReward``-wrapped environment.
    """
    # Merge rather than pass render_mode separately, so a caller-supplied
    # render_mode overrides the default instead of raising a duplicate-keyword
    # TypeError.
    make_kwargs = {"render_mode": "rgb_array", **kwargs}
    env = gym.make("HalfCheetah-v5", **make_kwargs)
    return CustomReward(env, self.reward_net, self.target_encoder)

def init_ant_env_trairl(env_name, self, joints_status):
    """Create an Ant-v5 environment wrapped in ``DisabledAnt``.

    The underlying env is built with ``terminate_when_unhealthy=False`` and
    ``include_cfrc_ext_in_observation=False``.

    Args:
        env_name: Key used to select this environment's encoder from
            ``self.encoders``.
        self: Owner object that must expose ``reward_net`` and ``encoders``
            (passed explicitly rather than bound as a method).
        joints_status: Forwarded unchanged to ``DisabledAnt``; presumably
            describes which joints are disabled — TODO confirm against
            ``wrapper``.

    Returns:
        The ``DisabledAnt``-wrapped environment (rendering as ``"rgb_array"``).
    """
    ant_options = dict(
        render_mode="rgb_array",
        terminate_when_unhealthy=False,
        include_cfrc_ext_in_observation=False,
    )
    base_env = gym.make("Ant-v5", **ant_options)
    return DisabledAnt(
        base_env,
        self.reward_net,
        joints_status,
        self.encoders[env_name],
    )

def init_ant_env(env_name, self, **kwargs):
    """Create an Ant-v5 environment wrapped in ``CustomReward``.

    Args:
        env_name: Not used by this factory; presumably kept so all ``init_*``
            helpers share a common leading signature — confirm with callers.
        self: Owner object that must expose ``reward_net`` and
            ``target_encoder`` (passed explicitly rather than bound as a method).
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.
            ``render_mode`` defaults to ``"rgb_array"`` and may be overridden
            here.

    Returns:
        The ``CustomReward``-wrapped environment.
    """
    # Previously render_mode was hard-coded next to **kwargs, so passing
    # render_mode raised "got multiple values for keyword argument". Merging
    # keeps the old default while letting callers override it.
    make_kwargs = {"render_mode": "rgb_array", **kwargs}
    env = gym.make("Ant-v5", **make_kwargs)
    return CustomReward(env, self.reward_net, self.target_encoder)