import argparse
import os
import sys
import time

import jax
import jax.numpy as jnp
import numpy as np
import optax

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '.', '.'))
sys.path.append(project_root)
from Craftax.craftax.craftax_env import make_craftax_env_from_name

import wandb
from typing import NamedTuple

from flax.training import orbax_utils
from flax.training.train_state import TrainState
from orbax.checkpoint import (
    PyTreeCheckpointer,
    CheckpointManagerOptions,
    CheckpointManager,
)

from logz.batch_logging import batch_log, create_log_dict
from models.actor_critic import (
    ActorCritic,
    ActorCriticConv,
    ActorCriticConvSymbolicCraftax,
    ActorCriticMask,
    ActorCriticConvMask
)
from models.icm import ICMEncoder, ICMForward, ICMInverse
from wrappers import (
    LogWrapper,
    OptimisticResetVecEnvWrapper,
    BatchEnvWrapper,
    AutoResetEnvWrapper,
)

from Craftax.craftax.craftax_classic.stateinfo import state_info
from gymnax.environments import spaces


# import distrax
# from Craftax.craftax.craftax_classic.renderer import render_craftax_pixels
# Code adapted from the original implementation made by Chris Lu
# Original code located at https://github.com/luchris429/purejaxrl


class Transition(NamedTuple):
    """One rollout step recorded by the PPO data-collection loop.

    ``jax.lax.scan`` stacks these along a leading time axis; the update step
    then flattens time x env into a single batch for minibatching.
    """

    done: jnp.ndarray  # termination flags returned by ``env.step``
    action: jnp.ndarray  # actions sampled from the policy
    action_mask: jnp.ndarray  # action mask recorded by the rollout; None unless GOAL_TYPE is SGRL/SGRL_nopri
    epsilon: jnp.ndarray  # per-env mask epsilon recorded by the rollout
    value: jnp.ndarray  # critic output for ``obs``
    reward: jnp.ndarray  # rewards returned by ``env.step``
    log_prob: jnp.ndarray  # log-probability of ``action`` at collection time
    obs: jnp.ndarray  # observation the action was chosen from
    next_obs: jnp.ndarray  # observation returned after stepping the env
    info: dict  # per-step info pytree from the env wrappers (e.g. "returned_episode")
    feasible_action: jnp.ndarray  # mask entry at the chosen action; None unless GOAL_TYPE is SGRL/SGRL_nopri


# Goal types that run the action-masking pipeline: the goal generator returns
# ``(goal_embedding, action_mask)`` and the policy is called with the mask and
# a scheduled epsilon.
_SGRL_GOAL_TYPES = ("SGRL", "SGRL_nopri")

# Width of the goal embedding requested from the goal generator.
_GOAL_EMBED_DIM = 384


def _epsilon_schedule(update_step, config):
    """Return the mask epsilon scheduled for ``update_step`` (before clipping).

    ``config["EPSILON_TYPE"]`` selects the schedule. With ``T`` =
    ``NUM_UPDATES``, ``lo`` = ``MIN_EPSILON`` and ``hi`` = ``EPSILON``:

    * ``"linearann"``:    ``1 - (lo + (hi - lo) * (1 - t / T))``
    * ``"expann"``:       ``1 - (lo + (hi - lo) * exp(-t / DECAY_STEPS))``
    * ``"3stagelinear"``: ``max(lo, 1 - t / (0.4 T))`` while ``t < 0.4 T``,
      then ``max(lo, t / (0.4 T) - 1)``
    * ``"3stagecos"``:    ``max(lo, 0.5 + 0.5 cos(2 pi t / (0.8 T)))`` while
      ``t < 0.8 T``, then ``max(lo, 1)``

    Raises:
        ValueError: if ``EPSILON_TYPE`` is not one of the names above.
    """
    epsilon_type = config["EPSILON_TYPE"]
    if epsilon_type == "linearann":
        epsilon = config["MIN_EPSILON"] + (config["EPSILON"] - config["MIN_EPSILON"]) * (
            1 - update_step / config["NUM_UPDATES"]
        )
        return 1 - epsilon
    if epsilon_type == "expann":
        epsilon = config["MIN_EPSILON"] + (config["EPSILON"] - config["MIN_EPSILON"]) * jnp.exp(
            -update_step / config["DECAY_STEPS"]
        )
        return 1 - epsilon
    if epsilon_type == "3stagelinear":
        # Falls linearly from 1 over the first 40% of updates, then rises again.
        ramp = 0.4 * config["NUM_UPDATES"]
        y = jnp.where(
            update_step < ramp,
            1 - (1 / ramp) * update_step,
            (1 / ramp) * update_step - 1,
        )
        return jnp.maximum(config["MIN_EPSILON"], y)
    if epsilon_type == "3stagecos":
        # One full cosine period (1 -> 0 -> 1) over the first 80% of updates,
        # then held at 1.
        period = 0.8 * config["NUM_UPDATES"]
        y = jnp.where(
            update_step < period,
            0.5 + 0.5 * jnp.cos(2 * jnp.pi * (update_step / period)),
            1.0,
        )
        return jnp.maximum(config["MIN_EPSILON"], y)
    raise ValueError(
        f"Unknown EPSILON_TYPE {epsilon_type!r}; expected one of "
        "'linearann', 'expann', '3stagelinear', '3stagecos'"
    )


def _make_optimizer(config, lr_schedule):
    """Build the optax chain used to train the actor-critic parameters.

    Uses ``lr_schedule`` when ``ANNEAL_LR`` is set, otherwise the constant
    ``config["LR"]``. Adam betas and the explicit weight decay come from
    ``B1``/``B2``/``WD`` when ``L2`` is set, else ``0.9``/``0.999``/``0``.
    """
    b1, b2, wd = 0.9, 0.999, 0
    if config["L2"]:
        b1, b2, wd = config["B1"], config["B2"], config["WD"]

    clip = optax.clip_by_global_norm(config["MAX_GRAD_NORM"])
    decay = optax.add_decayed_weights(weight_decay=wd)
    if config["ANNEAL_LR"]:
        learning_rate = lr_schedule
        # NOTE(review): weight decay is added *before* gradient clipping here
        # but *after* it in the constant-LR branch. This only matters when
        # WD != 0 -- confirm which order is intended.
        preprocess = (decay, clip)
    else:
        learning_rate = config["LR"]
        preprocess = (clip, decay)
    # NOTE(review): ``weight_decay`` is not passed to optax.adamw, so its
    # library-default decoupled decay is applied on top of ``decay`` -- verify.
    return optax.chain(
        *preprocess,
        optax.adamw(learning_rate=learning_rate, eps=1e-5, b1=b1, b2=b2),
    )


def make_train(config):
    """Build a PPO training function for the Craftax env named in ``config``.

    Side effect: writes ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` into ``config``.

    For ``GOAL_TYPE`` in ``SGRL``/``SGRL_nopri`` the policy is called with an
    action mask (from the goal generator) and a scheduled epsilon. When
    ``USE_GOAL`` is set, each observation is flattened and concatenated with a
    goal embedding obtained from the goal generator.

    Args:
        config: upper-cased hyper-parameter dict (see the CLI flags).

    Returns:
        ``train(rng)``, which runs ``NUM_UPDATES`` PPO updates and returns
        ``{"runner_state": runner_state}``.

    Raises:
        ValueError: if ``NUM_MINIBATCHES`` does not evenly split
            ``NUM_STEPS * NUM_ENVS``, or if ``USE_GOAL`` is set while
            ``GOAL_TYPE`` names no goal generator (SGRL, SGRL_nopri or LM).
    """
    config["NUM_UPDATES"] = (
            config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
            config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Rollout batch: one sample per (step, env) pair.
    batch_size = config["NUM_STEPS"] * config["NUM_ENVS"]
    if config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] != batch_size:
        raise ValueError("batch size must be equal to number of steps * number of envs")

    is_sgrl = config["GOAL_TYPE"] in _SGRL_GOAL_TYPES
    if config["USE_GOAL"] and not (is_sgrl or config["GOAL_TYPE"] == "LM"):
        raise ValueError(
            "USE_GOAL requires GOAL_TYPE to be one of 'SGRL', 'SGRL_nopri', 'LM'; "
            f"got {config['GOAL_TYPE']!r}"
        )

    env = make_craftax_env_from_name(
        config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
    )
    env_params = env.default_params

    env = LogWrapper(env)

    # Host-side goal generator, queried through io_callback (None when unused).
    pixels_goal_generator = None
    if is_sgrl:
        from Craftax.craftax.craftax_classic.text_goal import GoalandMask
        pixels_goal_generator = GoalandMask(
            num_envs=config["NUM_ENVS"],
            goal_type=config["GOAL_TYPE"],
            lm_name=config["LM_NAME"],
            env_name=config["ENV_NAME"],
            alg_name=config["ALG_NAME"],
        )
    elif config["GOAL_TYPE"] == "LM":
        from Craftax.craftax.craftax_classic.text_goal import LlmGoal
        pixels_goal_generator = LlmGoal(
            lm_name=config["LM_NAME"],
            env_name=config["ENV_NAME"],
            num_envs=config["NUM_ENVS"],
            alg_name=config["ALG_NAME"],
            goal_type=config["GOAL_TYPE"]
        )

    if config["USE_GOAL"]:
        # Flattened 63*63*3 observation followed by the goal embedding.
        obs_space = spaces.Box(
            0.0,
            1.0,
            (63 * 63 * 3 + _GOAL_EMBED_DIM,),
            dtype=jnp.float32,
        )
    else:
        obs_space = env.observation_space(env_params)

    if config["USE_OPTIMISTIC_RESETS"]:
        env = OptimisticResetVecEnvWrapper(
            env,
            num_envs=config["NUM_ENVS"],
            reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
        )
    else:
        env = AutoResetEnvWrapper(env)
        env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])

    def linear_schedule(count):
        """Linearly decay LR from ``LR`` toward 0 over ``NUM_UPDATES`` updates.

        ``count`` is the optimizer step; each update performs
        ``NUM_MINIBATCHES * UPDATE_EPOCHS`` optimizer steps.
        """
        frac = (
                1.0
                - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
                / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        num_actions = env.action_space(env_params).n

        # INIT NETWORK
        if config["USE_GOAL"]:
            # NOTE(review): the network receives the env's raw observation
            # shape, not the goal-augmented ``obs_space`` -- confirm intended.
            env_obs_shape = env.observation_space(env_params).shape
            if is_sgrl:
                network = ActorCriticMask(num_actions, env_obs_shape, config["LAYER_SIZE"])
            else:
                network = ActorCriticConvSymbolicCraftax(
                    num_actions, env_obs_shape, config["LAYER_SIZE"]
                )
        elif "Pixels" in config["ENV_NAME"]:
            network_cls = ActorCriticConvMask if is_sgrl else ActorCriticConv
            network = network_cls(num_actions, config["LAYER_SIZE"])
        elif "Symbolic" in config["ENV_NAME"]:
            network = ActorCritic(num_actions, config["LAYER_SIZE"])
        else:
            raise ValueError(
                f"Cannot pick a network for ENV_NAME {config['ENV_NAME']!r}: "
                "expected 'Pixels' or 'Symbolic' in the name, or USE_GOAL"
            )

        def _apply_network(params, obs, action_mask, epsilon):
            """Run the policy/value network; mask-aware variants also take mask + epsilon."""
            if is_sgrl:
                return network.apply(params, obs, action_mask, epsilon)
            return network.apply(params, obs)

        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros((1, *obs_space.shape))
        epsilon = jnp.zeros(config["NUM_ENVS"])
        if is_sgrl:
            init_m = jnp.zeros((1, num_actions))
            network_params = network.init(_rng, init_x, init_m, epsilon)
        else:
            network_params = network.init(_rng, init_x)

        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=_make_optimizer(config, linear_schedule),
        )

        # Exploration state (unused placeholders, carried through the loop).
        ex_state = {
            "icm_encoder": None,
            "icm_forward": None,
            "icm_inverse": None,
            "e3b_matrix": None,
        }

        def _query_goals(env_state, num_envs):
            """Call the host-side goal generator on the batched env states.

            Returns the goal embeddings, plus per-env action masks for SGRL
            goal types. Result shapes/dtypes are identical at reset and step
            time so the scan carry stays type-stable.
            """
            view_arr = jax.vmap(state_info, in_axes=0)(env_state.env_state)
            goal_shape = jnp.zeros((config["NUM_ENVS"], _GOAL_EMBED_DIM), dtype=jnp.float32)
            if is_sgrl:
                result_shape = (
                    goal_shape,
                    jnp.zeros((config["NUM_ENVS"], num_actions), dtype=jnp.float32),
                )
            else:
                result_shape = goal_shape
            return jax.experimental.io_callback(
                pixels_goal_generator.get_pixels_goals,
                result_shape,
                view_arr,
                num_envs,
            )

        def _add_goals(obsv, env_state, action_mask):
            """Refresh goals/masks for ``env_state`` and build the policy input.

            Returns ``(obsv, action_mask)``: the mask is replaced only for SGRL
            goal types, and the observation is flattened and concatenated with
            the goal embedding only when ``USE_GOAL`` is set.
            """
            if is_sgrl:
                goal_encoder, action_mask = _query_goals(env_state, obsv.shape[0])
            elif config["USE_GOAL"]:
                goal_encoder = _query_goals(env_state, obsv.shape[0])
            if config["USE_GOAL"]:
                image_obsv = obsv.reshape((config["NUM_ENVS"], -1))
                obsv = jnp.concatenate([image_obsv, goal_encoder], axis=-1)
            return obsv, action_mask

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        obsv, env_state = env.reset(_rng, env_params)
        # The mask stays None (an empty pytree) unless an SGRL generator supplies one.
        obsv, action_mask = _add_goals(obsv, env_state, None)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                (
                    train_state,
                    env_state,
                    last_obs,
                    ex_state,
                    rng,
                    update_step,
                    action_mask,
                    epsilon
                ) = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                if is_sgrl:
                    epsilon_v = jnp.minimum(
                        config["EPSILON"], _epsilon_schedule(update_step, config)
                    )
                    epsilon = jnp.full(config["NUM_ENVS"], epsilon_v)
                pi, value = _apply_network(train_state.params, last_obs, action_mask, epsilon)

                action = pi.sample(seed=_rng)
                if is_sgrl:
                    # Mask entry of the sampled action, one per env. Index
                    # instead of squeeze() so NUM_ENVS == 1 keeps its axis.
                    feasible_action = jnp.take_along_axis(
                        action_mask, action[:, None], axis=1
                    )[:, 0]
                else:
                    feasible_action = None
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                obsv, env_state, reward, done, info = env.step(
                    _rng, env_state, action, env_params
                )
                obsv, next_action_mask = _add_goals(obsv, env_state, action_mask)

                transition = Transition(
                    done=done,
                    action=action,
                    # Record the mask the action was sampled under (not the
                    # refreshed one for ``obsv``), so the PPO loss re-evaluates
                    # ``obs`` with the same policy inputs as at collection time.
                    action_mask=action_mask,
                    epsilon=epsilon,
                    value=value,
                    reward=reward,
                    log_prob=log_prob,
                    obs=last_obs,
                    next_obs=obsv,
                    info=info,
                    feasible_action=feasible_action,
                )
                runner_state = (
                    train_state,
                    env_state,
                    obsv,
                    ex_state,
                    rng,
                    update_step,
                    next_action_mask,
                    epsilon
                )

                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            (
                train_state,
                env_state,
                last_obs,
                ex_state,
                rng,
                update_step,
                action_mask,
                epsilon
            ) = runner_state
            _, last_val = _apply_network(train_state.params, last_obs, action_mask, epsilon)

            def _calculate_gae(traj_batch, last_val):
                """Generalized Advantage Estimation over the time axis (reverse scan)."""
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                            delta
                            + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    # Policy/value network
                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = _apply_network(
                            params, traj_batch.obs, traj_batch.action_mask, traj_batch.epsilon
                        )
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                                value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                                0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                                jnp.clip(
                                    ratio,
                                    1.0 - config["CLIP_EPS"],
                                    1.0 + config["CLIP_EPS"],
                                )
                                * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                                loss_actor
                                + config["VF_COEF"] * value_loss
                                - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)

                    losses = (total_loss, 0)
                    return train_state, losses

                (
                    train_state,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                ) = update_state
                rng, _rng = jax.random.split(rng)
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge (time, env) into one batch axis, shuffle, then split
                # into NUM_MINIBATCHES equal minibatches.
                batch = jax.tree.map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree.map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree.map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, losses = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (
                    train_state,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                return update_state, losses

            update_state = (
                train_state,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )

            train_state = update_state[0]
            # Average each info entry over the steps where an episode ended.
            # NOTE(review): if no episode finished during this update this is
            # 0/0 and the logged metrics are NaN.
            metric = jax.tree.map(
                lambda x: (x * traj_batch.info["returned_episode"]).sum()
                          / traj_batch.info["returned_episode"].sum(),
                traj_batch.info,
            )

            rng = update_state[-1]

            if is_sgrl:
                metric["epsilon_mask"] = traj_batch.epsilon.mean()
                # Mean mask value at the sampled actions over all steps and envs.
                metric["feasible_action"] = traj_batch.feasible_action.mean()

            # wandb logging
            if config["DEBUG"] and config["USE_WANDB"]:
                def callback(metric, update_step):
                    to_log = create_log_dict(metric, config)
                    batch_log(update_step, to_log, config)

                jax.debug.callback(
                    callback,
                    metric,
                    update_step,
                )

            runner_state = (
                train_state,
                env_state,
                last_obs,
                ex_state,
                rng,
                update_step + 1,
                action_mask,
                epsilon
            )
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (
            train_state,
            env_state,
            obsv,
            ex_state,
            _rng,
            0,
            action_mask,
            epsilon
        )
        # NOTE(review): JAX dispatches asynchronously, so this interval may
        # cover compilation + dispatch rather than the full run.
        update_step_t1 = time.perf_counter()
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        update_step_t2 = time.perf_counter()
        jax.debug.print(f"Update time: {update_step_t2 - update_step_t1}")

        return {"runner_state": runner_state}

    return train


def run_ppo(config):
    """Run ``NUM_REPEATS`` PPO trainings and optionally checkpoint the policy.

    Args:
        config: an ``argparse.Namespace``; its attribute names are upper-cased
            to form the config dict passed to ``make_train``.
    """
    config = {k.upper(): v for k, v in config.__dict__.items()}

    if config["USE_WANDB"]:
        # NOTE(review): the project name is hard-coded to Craftax-Classic-Pixels
        # and ``WANDB_PROJECT`` is ignored -- confirm this is intended.
        project = "Craftax-Classic-Pixels-v1" + "-" + str(int(config["TOTAL_TIMESTEPS"] // 1e6)) + "M"
        # Runs are always logged in wandb offline mode.
        os.environ["WANDB_MODE"] = "offline"
        wandb.init(
            project=project,
            entity=config["WANDB_ENTITY"],
            config=config,
            group=config["ALG_NAME"],
            name=config["ENV_NAME"]
                 + "-"
                 + config["ALG_NAME"]
        )

    rng = jax.random.PRNGKey(config["SEED"])
    rngs = jax.random.split(rng, config["NUM_REPEATS"])
    t0 = time.time()
    train_fn = make_train(config)
    results = [train_fn(repeat_rng) for repeat_rng in rngs]
    # JAX dispatches asynchronously: wait for the results so the timing below
    # covers the actual computation.
    jax.block_until_ready(results)
    t1 = time.time()
    print("Time to run experiment", t1 - t0)
    # Every repeat consumes TOTAL_TIMESTEPS environment steps.
    print("SPS: ", config["TOTAL_TIMESTEPS"] * config["NUM_REPEATS"] / (t1 - t0))

    if config["USE_WANDB"]:

        def _save_network(rs_index, dir_name):
            """Checkpoint element ``rs_index`` of the first repeat's runner state."""
            # Index the first repeat directly: stacking results across repeats
            # would require identical static TrainState fields, which differ
            # per train() call (each builds its own optimizer).
            train_state = results[0]["runner_state"][rs_index]
            orbax_checkpointer = PyTreeCheckpointer()
            options = CheckpointManagerOptions(max_to_keep=1, create=True)
            path = os.path.join(wandb.run.dir, dir_name)
            checkpoint_manager = CheckpointManager(path, orbax_checkpointer, options)
            print(f"saved runner state to {path}")
            save_args = orbax_utils.save_args_from_target(train_state)
            checkpoint_manager.save(
                config["TOTAL_TIMESTEPS"],
                train_state,
                save_kwargs={"save_args": save_args},
            )

        if config["SAVE_POLICY"]:
            _save_network(0, "policies")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default="Craftax-Classic-Pixels-v1")
    parser.add_argument("--alg_name", type=str, default="PPO_LM_Bestgoal_reward")
    parser.add_argument("--num_envs", type=int, default=256)
    # Accepts scientific notation (e.g. 1e9). The default is written as an int
    # because argparse applies ``type`` only to string defaults; a float
    # default would make TOTAL_TIMESTEPS / NUM_UPDATES floats.
    parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=int(5e6))
    parser.add_argument("--lr", type=float, default=7e-4)
    parser.add_argument("--num_steps", type=int, default=16)
    parser.add_argument("--update_epochs", type=int, default=4)
    parser.add_argument("--num_minibatches", type=int, default=8)
    parser.add_argument("--gamma", type=float, default=0.97)
    parser.add_argument("--gae_lambda", type=float, default=0.8)
    parser.add_argument("--clip_eps", type=float, default=0.14)
    parser.add_argument("--ent_coef", type=float, default=0.01)
    parser.add_argument("--vf_coef", type=float, default=0.5)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--activation", type=str, default="tanh")
    parser.add_argument("--anneal_lr", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--save_policy", action="store_true")
    parser.add_argument("--num_repeats", type=int, default=1)
    parser.add_argument("--layer_size", type=int, default=512)
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument(
        "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
    )
    parser.add_argument("--optimistic_reset_ratio", type=int, default=16)

    # EXPLORATION
    parser.add_argument("--exploration_update_epochs", type=int, default=4)
    # L2 (custom Adam betas + explicit weight decay)
    parser.add_argument("--l2", action="store_true")
    parser.add_argument("--b1", type=float, default=0.99)
    parser.add_argument("--b2", type=float, default=0.99)
    parser.add_argument("--wd", type=float, default=1e-3)
    # LM / goal conditioning
    parser.add_argument("--lm_name", type=str, default="THUDM/GLM-4-9B-0414",
                        help="THUDM/glm-4-9b-chat, THUDM/GLM-4-9B-0414, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen3-8B")
    parser.add_argument("--goal_type", type=str, default="None",
                        help="None, LM, SGRL, SGRL_nopri")
    parser.add_argument("--sgt", type=int, default=1, help="short_goal_threshold")
    parser.add_argument("--lgt", type=int, default=1, help="long_goal_threshold")
    parser.add_argument("--epsilon", type=float, default=1)
    parser.add_argument("--min_epsilon", type=float, default=1e-6)
    parser.add_argument("--decay_steps", type=float, default=600.0)
    parser.add_argument("--use_goal", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--epsilon_type", type=str, default="3stagecos", help="3stagecos, linearann,expann,3stagelinear")

    args, rest_args = parser.parse_known_args(sys.argv[1:])

    if rest_args:
        raise ValueError(f"Unknown args {rest_args}")

    # Draw a random seed when none was given.
    if args.seed is None:
        args.seed = np.random.randint(2 ** 31)

    if args.jit:
        run_ppo(args)
    else:
        with jax.disable_jit():
            run_ppo(args)
