# Hydra defaults list: config groups are composed in the order listed.
# `_self_` is last, so values defined in this file take precedence over
# values for the same keys coming from the groups listed before it.
# NOTE(review): `experiment_overrides` is composed *before* `_self_`, so any
# hyperparameter it sets that is also set in this file is overridden by this
# file. If the group is meant to override these values, `_self_` would need
# to come before it — confirm the intended precedence.
defaults:
  - env: mjx_dmc  # env group; supplies env.vmin / env.vmax / env.name interpolated in this file
  - platform: torch
  - experiment_overrides: default
  - _self_

hyperparameters:
  # env and run settings (mostly don't touch)
  total_time_steps: 50_000_000
  normalize_env: true
  max_episode_steps: 1000
  eval_interval: 2
  num_eval: 20

  # optimization settings (seem very stable)
  # Scientific-notation floats carry a decimal point (3.0e-4, not 3e-4): a
  # plain YAML 1.1 loader such as PyYAML's safe_load reads "3e-4" as a string.
  lr: 3.0e-4
  temperature_lr: 3.0e-4
  temperature_lr_mult: 1.0
  lagrangian_lr: 3.0e-4
  lagrangian_lr_mult: 1.0
  anneal_lr: false
  max_grad_norm: 0.5
  polyak: 1.0  # maybe ablate?
  env_action_clip_value: 0.999
  action_clip_value: 0.999
  train_mode: "reparam"  # or "WPO"
  disable_wpo_fisher_preconditioning: false
  disable_temperature: false

  # problem discount settings (need tuning)
  gamma: 0.99
  lmbda: 0.95
  lmbda_min: 0.50  # irrelevant if no exploration noise is added

  # batch settings (need tuning for MJX humanoid)
  num_steps: 128
  num_mini_batches: 128
  num_envs: 1024
  num_epochs: 4
  use_lax_scan: true

  # exploration settings (currently not touched)
  # NOTE(review): max == min, so if the consumer schedules the noise between
  # these bounds it is effectively constant at 1.0 — confirm.
  exploration_noise_max: 1.0
  exploration_noise_min: 1.0
  exploration_base_envs: 0

  # critic architecture settings (need to be increased for MJX humanoid)
  critic_hidden_dim: 512
  vmin: ${env.vmin}  # taken from the selected env config group
  vmax: ${env.vmax}
  num_bins: 151
  hl_gauss: true
  use_critic_norm: true
  num_critic_encoder_layers: 2
  num_critic_head_layers: 2
  num_critic_pred_layers: 2
  use_simplical_embedding: false  # (sic) key spelling must match the code reading it
  use_critic_skip: false

  # actor architecture settings (seem stable)
  actor_hidden_dim: 512
  use_actor_norm: true
  num_actor_layers: 3
  actor_min_std: 0.0
  use_actor_skip: false

  # actor & critic loss settings (seem remarkably stable)
  ## kl settings
  kl_start: 0.01
  kl_bound: 0.1  # switched to tighter bounds for MJX
  reduce_kl: true
  reverse_kl: false  # previous default "false"
  update_kl_lagrangian: true
  # one of: "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
  actor_kl_clip_mode: "clipped"
  ## entropy settings
  ent_start: 0.01
  ent_target_mult: 0.5
  update_entropy_lagrangian: true
  ## auxiliary loss settings
  aux_loss_mult: 1.0

  # torso center-of-mass (COM) logging (MJX only)
  log_torso_com: false
  log_torso_com_num_envs: 30
  log_torso_com_stride: 1

# NOTE(review): not documented here; the name suggests a burn-in count of
# initial iterations excluded from measurements — confirm against the consumer.
measure_burnin: 3

name: "reppo"
seed: 0
num_seeds: 1
tune: false

# --- checkpointing ---
# Directory that checkpoints are written to.
checkpoint_dir: "./checkpoints"
# Checkpoint to load when resuming training; null means none is loaded.
checkpoint_path: null
# Write a checkpoint every N steps; 0 turns periodic saving off.
save_checkpoint_interval: 0
# Write one more checkpoint once training has finished.
save_final_checkpoint: false
# Push saved checkpoints to wandb as artifacts.
wandb_upload_checkpoints: false

# --- evaluation-only mode ---
# true: run evaluation only, no training (checkpoint_path must be set).
eval_only: false
# How many episodes to run when eval_only is true.
eval_episodes: 10

# --- trajectory tracking ---
# Record end-effector trajectories while evaluating.
save_trajectories: false
# Where trajectory data and plots are written.
trajectory_dir: "./trajectories"
# Render visualization plots; false keeps only the saved data.
plot_trajectories: false

num_trials: 1
tags:
  - "experimental"
wandb:
  # "online" syncs runs to wandb. Presumably forwarded to wandb.init(mode=...),
  # where "offline" and "disabled" are the other modes — confirm in the caller.
  mode: "online"
  entity: ""
  project: "dime_${env.name}"
  project_suffix: ""

hydra:
  job:
    # Hydra switches the working directory to the run's output directory, so
    # relative paths (e.g. checkpoint_dir, trajectory_dir) resolve inside it.
    chdir: true
