# Hydra defaults list: the config groups composed into this config.
# Entries are merged in the order listed, so the order is significant.
# NOTE(review): experiment_overrides is listed before _self_, so any key set
# both there and in this file ends up with this file's value. Confirm that is
# intended (overrides usually need to come after _self_ to take effect).
defaults:
  - env: mjx_dmc
  - platform: torch
  - experiment_overrides: default
  # _self_ marks where this file's own keys are merged; listed last, they
  # take precedence over the groups above.
  - _self_

# Training hyperparameters; everything indented below belongs to this mapping.
hyperparameters:
  # env and run settings (mostly don't touch)
  total_time_steps: 50_000_000  # 50M
  normalize_env: true
  max_episode_steps: 1000
  eval_interval: 2
  num_eval: 20

  # optimization settings (seem very stable)
  # Learning rates are written as 3.0e-4 rather than 3e-4: stock YAML 1.1
  # loaders (e.g. PyYAML) only recognise exponent floats that contain a
  # decimal point, so the short form loads as a string outside OmegaConf.
  lr: 3.0e-4
  # NOTE(review): the key is spelled "lagragian" (sic). It is left unchanged
  # because whatever reads this config presumably looks up this exact name.
  temperature_lagragian_lr: 3.0e-4
  anneal_lr: false
  max_grad_norm: 0.5
  polyak: 1.0  # maybe ablate?
  env_action_clip_value: 0.999
  action_clip_value: 0.999

  # problem discount settings (need tuning)
  gamma: 0.99
  lmbda: 0.95
  lmbda_min: 0.50  # irrelevant if no exploration noise is added

  # batch settings (need tuning for MJX humanoid)
  num_steps: 128
  num_mini_batches: 128
  num_envs: 1024
  num_epochs: 4
  use_lax_scan: true

  # exploration settings (currently not touched)
  exploration_noise_max: 1.0
  exploration_noise_min: 1.0
  exploration_base_envs: 0

  # critic architecture settings (need to be increased for MJX humanoid)
  # NOTE(review): actor_hidden_dim is also set in this group, although the
  # remaining actor options live in the actor section further down.
  critic_hidden_dim: 512
  actor_hidden_dim: 512
  # vmin / vmax are interpolated from the selected env config.
  vmin: ${env.vmin}
  vmax: ${env.vmax}
  num_bins: 151
  hl_gauss: true
  use_critic_norm: true
  num_critic_encoder_layers: 2
  num_critic_head_layers: 2
  num_critic_pred_layers: 2
  # NOTE(review): the key is spelled "simplical" (sic); kept as-is because
  # the consuming code presumably reads this exact name.
  use_simplical_embedding: false
  use_critic_skip: false

  # actor architecture settings (seem stable)
  use_actor_norm: true
  num_actor_layers: 3
  actor_min_std: 0.0
  use_actor_skip: false

  # actor & critic loss settings (seem remarkably stable)

  # -- KL settings --
  kl_start: 0.01
  # switched to tighter bounds for MJX
  kl_bound: 0.1
  reduce_kl: true
  # previous default "false"
  reverse_kl: false
  kl_action_rep: 4
  update_kl_lagrangian: true
  # one of: "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
  actor_kl_clip_mode: 'clipped'

  # -- entropy settings --
  ent_start: 0.01
  ent_target_mult: 4.0
  update_entropy_lagrangian: true

  # -- auxiliary loss settings --
  aux_loss_mult: 1.0

  # torso COM logging (MJX only)
  log_torso_com: false
  log_torso_com_num_envs: 30
  log_torso_com_stride: 1

  # DIME diffusion settings
  diffusion:
    name: "dis"
    diff_steps: 8
    init_std: 2.5

    # friction options
    friction: 1.0
    use_friction_mlp: false
    friction_mlp_hidden: 64
    friction_mlp_layers: 2
    friction_num_time_hid: 32
    friction_num_time_out: 16
    friction_mlp_use_obs: true
    integrator: "EM"

    # which components are trainable
    learn_forward: true
    learn_backward: false
    learn_prior: false
    learn_betas: false
    learn_friction: true
    learn_mass_matrix: false

    # step-size options
    dt: 0.125
    learn_dt: true
    per_step_dt: true
    per_dim_friction: true
    use_step_size_scheduler: false

    score_model:
      use_path_gradient: false
      use_target_score: false
      num_layers: 4
      num_hid: 256
      num_time_hid: 32
      num_time_out: 16
      # Exponent floats are written with a decimal point and a signed
      # exponent so that stock YAML 1.1 loaders (e.g. PyYAML) parse them as
      # floats as well; forms such as 1e4 load as strings there.
      outer_clip: 1.0e+4
      inner_clip: 1.0e+2
      weight_init: 1.0e-8
      bias_init: 0.0
      layer_norm: false
      layer_norm_type: "LayerNorm"

    # Schedule factory; the _target_ key suggests it is built through
    # Hydra's instantiate mechanism - TODO confirm against the training code.
    dt_schedule:
      _target_: src.networks.diffusion.schedulers.get_cosine_schedule
      # kept in sync with diff_steps above through interpolation
      total_steps: ${hyperparameters.diffusion.diff_steps}
      min: 0.001
      s: 0.008
      pow: 2

  # Each coefficient is logged as eval/episode_return_ode_XXX,
  # where XXX is coef*100.
  # ode_coefs: [0.5, 1.0, 2.0]
  # ode_coefs: [1.0]
  ode_coefs: []
# ---- top-level run settings (siblings of `hyperparameters`) ----
measure_burnin: 3

name: 'reppo-dime-debug-1'
seed: 0
num_seeds: 1
tune: false

# Checkpointing
# NOTE(review): hydra.job.chdir is enabled at the bottom of this file, so the
# relative paths here presumably resolve inside Hydra's per-run output
# directory - confirm against how the training code builds its paths.
# Directory to save checkpoints
checkpoint_dir: './checkpoints'
# Path to load checkpoint from (for resuming training)
checkpoint_path: null
# Save checkpoint every N steps (0 = no periodic saving)
save_checkpoint_interval: 0
# Save checkpoint at the end of training
save_final_checkpoint: false
# Upload checkpoints to wandb as artifacts
wandb_upload_checkpoints: false

# Evaluation-only mode parameters
# Set to true to run evaluation only (requires checkpoint_path)
eval_only: false
# Number of episodes to run in eval-only mode
eval_episodes: 10

# Trajectory tracking parameters
# Save end-effector trajectories during evaluation
save_trajectories: false
# Directory to save trajectory data and plots
trajectory_dir: './trajectories'
# Generate visualization plots (set to false to only save data)
plot_trajectories: false

num_trials: 1
tags:
  - 'experimental'
# Weights & Biases logging options.
wandb:
  # 'online' activates wandb
  mode: 'online'
  entity: ''
  # project name is derived from the selected env's name
  project: 'dime_${env.name}'
  project_suffix: ''

# Hydra runtime options.
hydra:
  job:
    # Hydra changes the working directory to the job's output directory
    # before running, so relative paths used by the job resolve there.
    chdir: true
