from rl_training.wrappers import (
    ClipActionRewardIRL,
    VecEnvRewardIRL,
    Braxgymnax_localWrapper,
    PenalizeActions,
    NormalizeVecReward,
)

import jax.numpy as jnp



# Rollout sizing shared by several BRAX_CONFIG entries; named once so the
# derived update count cannot drift from the values it is computed from.
_TOTAL_TIMESTEPS = 5e7
_NUM_ENVS = 2048
_NUM_STEPS = 10
# Floor division with a float operand yields a float (2441.0), not an int.
_NUM_UPDATES = _TOTAL_TIMESTEPS // _NUM_ENVS // _NUM_STEPS

# Default training configuration; get_env() hands callers a copy of this dict
# with "ENV_NAME" replaced by the requested environment name.
BRAX_CONFIG = {
    # Optimiser / rollout sizing.
    "LR": 3e-4,
    "NUM_ENVS": _NUM_ENVS,
    "NUM_STEPS": _NUM_STEPS,
    "NUM_UPDATES": _NUM_UPDATES,
    # Starts out equal to NUM_UPDATES.
    "ORIG_NUM_UPDATES": _NUM_UPDATES,
    "TOTAL_TIMESTEPS": _TOTAL_TIMESTEPS,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 32,
    # Discounting, clipping and loss weighting coefficients.
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    # Network / environment settings and feature flags.
    "ACTIVATION": "tanh",
    "ENV_NAME": "brax",
    "ANNEAL_LR": False,
    "NORMALIZE_OBS": True,
    "DEBUG": False,
    "DISCRETE": False,
}



def get_env(env_name, backend="positional", indeces=None, normalize=False):
    """Build the wrapped environment and its training config for ``env_name``.

    Wrappers are applied in this order: ``Braxgymnax_localWrapper`` ->
    ``ClipActionRewardIRL`` -> (optional) ``PenalizeActions`` ->
    ``VecEnvRewardIRL`` -> (optional) ``NormalizeVecReward``.

    Args:
        env_name: Environment name; must be accepted by ``is_brax_env``.
        backend: Forwarded to ``Braxgymnax_localWrapper`` as ``backend``.
        indeces: When not ``None``, converted with ``jnp.array`` and passed to
            ``PenalizeActions`` (presumably the action indices to penalize --
            confirm against that wrapper).
        normalize: When truthy, wrap the environment in ``NormalizeVecReward``
            using the config's ``"GAMMA"`` value.

    Returns:
        A ``(env, env_params, config)`` tuple. ``env_params`` is always
        ``None``; ``config`` is a copy of ``BRAX_CONFIG`` whose ``"ENV_NAME"``
        is set to ``env_name``.

    Raises:
        NotImplementedError: If ``is_brax_env(env_name)`` is false.
    """
    print("Creating env", env_name)
    if not is_brax_env(env_name):
        raise NotImplementedError("Only brax is implemented")

    # Base environment with the innermost wrapper applied straight away.
    wrapped = ClipActionRewardIRL(
        Braxgymnax_localWrapper(env_name, backend=backend)
    )
    env_params = None

    # Fresh dict so the module-level defaults are never mutated.
    config = {**BRAX_CONFIG, "ENV_NAME": env_name}

    if indeces is not None:
        print("Penalizing actions", indeces)
        wrapped = PenalizeActions(wrapped, indeces=jnp.array(indeces))

    wrapped = VecEnvRewardIRL(wrapped)
    if normalize:
        wrapped = NormalizeVecReward(wrapped, config["GAMMA"])

    return wrapped, env_params, config



def is_brax_env(env_name):
    """Return True when ``env_name`` is one of the supported environment names.

    The recognised names are "hopper", "halfcheetah", "humanoid", "ant" and
    "walker2d"; matching is exact and case-sensitive.
    """
    # A tuple (not a set) so unhashable arguments still compare without raising.
    supported = ("hopper", "halfcheetah", "humanoid", "ant", "walker2d")
    return env_name in supported

