from .reach_avoid.pendulum_constraint import PendulumConstraint
from .reach_avoid.hopper_avoid_ceiling import HopperAvoidCeiling
from .reach_avoid.half_cheetah_avoid import HalfCheetahAvoid

from .wrappers import TransformObservation, AppendRegionOneHot

from functools import partial
import jax.numpy as jnp

def transform_observation(mean, variance, obs):
    """Center ``obs`` by ``mean`` and rescale it by ``variance``.

    Computes ``(obs - mean) / variance`` elementwise. Note that ``variance``
    is used directly as the divisor (no square root is taken), so callers
    effectively pass a per-feature scale rather than a true variance.

    Args:
        mean: Offset subtracted from ``obs``. Any value supporting broadcasting
            subtraction against ``obs`` (e.g. a scalar or a ``jnp`` array).
        variance: Divisor applied after centering.
        obs: The observation to transform.

    Returns:
        The centered and rescaled observation.
    """
    centered = obs - mean
    return centered / variance

# Names accepted by ``get_env`` via ``config["EXP_NAME"]``; used in the error message.
_KNOWN_ENVS = (
    'PendulumConstraint',
    'HopperAvoidCeiling',
    'HalfCheetahAvoid',
    'F16Avoid',
    'PointGoal',
)


def _wrap_env(env, mean, scale, append_region_one_hot):
    """Wrap ``env`` with observation normalization and, optionally, a region one-hot.

    Args:
        env: The base environment instance.
        mean: Offset passed to ``transform_observation``.
        scale: Divisor passed to ``transform_observation``.
        append_region_one_hot: If True, additionally wrap with ``AppendRegionOneHot``.

    Returns:
        The wrapped environment.
    """
    env = TransformObservation(env, partial(transform_observation, mean, scale))
    if append_region_one_hot:
        env = AppendRegionOneHot(env)
    return env


def get_env(config):
    """Build the environment selected by ``config["EXP_NAME"]``.

    Every environment is wrapped in ``TransformObservation`` using
    ``transform_observation`` with a fixed per-environment offset/divisor pair.
    All except ``PendulumConstraint`` and ``PointGoal`` are further wrapped in
    ``AppendRegionOneHot``.

    Args:
        config: Mapping containing the key ``"EXP_NAME"``.

    Returns:
        The wrapped environment instance.

    Raises:
        KeyError: If ``config`` has no ``"EXP_NAME"`` entry.
        ValueError: If ``config["EXP_NAME"]`` is not one of the supported names.
    """
    name = config["EXP_NAME"]

    if name == 'PendulumConstraint':
        # Identity normalization (zero offset, unit divisor) on a 3-vector.
        mean = jnp.array([0., 0., 0.])
        scale = jnp.array([1., 1., 1.])
        return _wrap_env(PendulumConstraint(), mean, scale, append_region_one_hot=False)

    if name == 'HopperAvoidCeiling':
        # Only the first feature is shifted (by 1); presumably observations have 12 features.
        mean = jnp.zeros(12, dtype=jnp.float32).at[0].set(1.)
        scale = jnp.ones(12, dtype=jnp.float32)
        return _wrap_env(HopperAvoidCeiling(), mean, scale, append_region_one_hot=True)

    if name == 'HalfCheetahAvoid':
        # First feature is shifted by 2.5 and divided by 3; the rest are left unchanged.
        mean = jnp.zeros(18, dtype=jnp.float32).at[0].set(2.5)
        scale = jnp.ones(18, dtype=jnp.float32).at[0].set(3.)
        return _wrap_env(HalfCheetahAvoid(), mean, scale, append_region_one_hot=True)

    if name == 'F16Avoid':
        # Imported only when selected (presumably to avoid loading its dependencies otherwise).
        from .reach_avoid.F16_avoid import F16Avoid
        mean = jnp.zeros(24, dtype=jnp.float32)
        scale = jnp.ones(24, dtype=jnp.float32)
        return _wrap_env(F16Avoid(), mean, scale, append_region_one_hot=True)

    if name == 'PointGoal':
        # Imported only when selected (presumably to avoid loading its dependencies otherwise).
        from .reach_avoid.point_goal_avoid import PointGoalJax
        mean = jnp.zeros(44, dtype=jnp.float32)
        scale = jnp.ones(44, dtype=jnp.float32)
        return _wrap_env(PointGoalJax(), mean, scale, append_region_one_hot=False)

    # ValueError subclasses Exception, so callers catching the old bare Exception still work.
    raise ValueError(
        f"No Given Environment: unknown EXP_NAME {name!r}; "
        f"expected one of {', '.join(_KNOWN_ENVS)}"
    )
