# Hydra defaults list: config groups composed into this config, in order.
# Under Hydra's composition rules, later entries override earlier ones, and
# `_self_` (listed last) lets the values in this file override all of them.
defaults:
  - env: mjx_dmc  # env group; provides the ${env.*} values interpolated below
  - platform: torch
  - experiment_overrides: dmerl_default
  - _self_

hyperparameters:
  # env and run settings (mostly don't touch)
  # total_time_steps is written without `_` digit separators: those are a
  # YAML 1.1 extension, and a YAML 1.2 core-schema parser would load
  # `50_000_000` as a string. Also reused by temperature_decay_steps.
  total_time_steps: 50000000  # 50M
  normalize_env: true
  max_episode_steps: 1000
  eval_interval: 2
  num_eval: 20
  log_torso_com: false
  log_torso_com_num_envs: 30
  log_torso_com_stride: 1

  # optimization settings (seem very stable)
  # Floats use the `1.0e-4` form (decimal point + signed exponent) so YAML 1.1
  # loaders such as plain PyYAML read them as floats rather than strings.
  lr: 1.0e-4
  lr_decay_factor: 1.0
  weight_decay: 0.0
  use_current_critic_for_actor_samples: true
  temperature_lr: 1.0e-4
  lagrangian_lr: 1.0e-4
  temp_lagrangian_optim: "adam"
  temp_lagrangian_adam_gamma1: 0.9
  temp_lagrangian_adam_gamma2: 0.999
  use_temp_lagrangian_ema_optim: false
  use_temp_lagrangian_post_adam_ema: false
  temp_lagrangian_ema_decay: 0.99
  project_unit_ball: false
  project_only_if_exceeds: false
  anneal_lr: false
  max_grad_norm: 0.5
  polyak: 1.0  # maybe ablate?
  env_action_clip_value: 0.999
  action_clip_value: 0.999
  tanh_transform: false

  # problem discount settings (need tuning)
  gamma: 0.99  # discount factor
  lmbda: 0.95  # presumably the lambda-return mixing coefficient; confirm
  lmbda_min: 0.5  # irrelevant if no exploration noise is added

  # batch settings -- need tuning for MJX humanoid.
  # NOTE(review): the collected batch is presumably num_steps * num_envs
  # (128 * 1024) transitions, scaled by num_collection_step_factor; confirm.
  num_steps: 128
  num_collection_step_factor: 1
  num_mini_batches: 16
  num_envs: 1024
  num_epochs: 4
  use_lax_scan: true

  # exploration settings -- currently not touched.
  # max and min are equal, so the noise scale is presumably fixed at 1.0.
  exploration_noise_max: 1.0
  exploration_noise_min: 1.0
  exploration_base_envs: 0

  # critic architecture settings, plus the actor width
  # (need to be increased for MJX humanoid)
  critic_hidden_dim: 512
  actor_hidden_dim: 512
  vmin: ${env.vmin}  # taken from the selected env config group
  vmax: ${env.vmax}
  num_bins: 151
  hl_gauss: true
  use_critic_norm: true
  num_critic_encoder_layers: 2
  num_critic_head_layers: 2
  num_critic_pred_layers: 2
  # key is spelled "simplical"; renaming it requires updating its readers too
  use_simplical_embedding: false
  use_critic_skip: false

  # actor architecture settings (seem stable)
  use_actor_norm: true
  num_actor_layers: 3
  actor_min_std: 0.0
  use_actor_skip: false
  train_mode: "reparam"

  # actor & critic loss settings (seem remarkably stable)
  ## kl settings
  kl_start: 0.3
  kl_bound: 0.1  # switched to tighter bounds for MJX
  kl_bound_fisher_precond: true
  remove_fisher_precond: true
  reduce_kl: true
  reverse_kl: false  # previous default "false"
  use_W2_kl: false
  kl_action_rep: 4
  update_kl_lagrangian: true
  use_augmented_lagrangian_dual: false
  augmented_lagrangian_kl_coef: 1.0
  # one of: "full", "clipped", "kl_relu_clipped", "kl_bound_clipped", "value"
  actor_kl_clip_mode: "clipped"
  ## entropy settings
  ent_start: 0.01
  ent_target_mult: 3.0
  update_entropy_lagrangian: true
  augmented_lagrangian_entropy_coef: 1.0
  # optional: exponentially decay temperature instead of learning it
  use_temperature_decay: false
  temperature_decay_start: ${hyperparameters.ent_start}
  # Previously written `10e-6`, which is 1e-5 (10 * 10^-6), not 1e-6. The
  # value is kept; the `1.0e-5` form also loads as a float under plain PyYAML.
  # NOTE(review): confirm 1e-6 was not the intended end temperature.
  temperature_decay_end: 1.0e-5
  temperature_decay_steps: ${hyperparameters.total_time_steps}
  # optionally learn temperature/kl multipliers with a tiny MLP
  use_temp_lagrangian_mlp: false
  temp_lagrangian_hidden: 32
  # scale the effective learning rates for temperature and KL lagrangian updates
  temperature_lr_mult: 1.0
  lagrangian_lr_mult: 1.0
  ## auxiliary loss settings
  aux_loss_mult: 0.15
  aux_loss_alpha: 0.95

  # DIME diffusion settings
  diffusion:
    name: "dis"
    diff_steps: 8
    init_std: 2.5
    friction: 1.0
    use_friction_mlp: false
    friction_mlp_hidden: 128
    friction_mlp_layers: 3
    friction_num_time_hid: 32
    friction_num_time_out: 16
    friction_mlp_use_obs: true
    integrator: "EM"

    learn_forward: true
    learn_backward: false
    learn_prior: false
    learn_betas: false
    learn_friction: false
    learn_mass_matrix: false

    dt: 0.125
    learn_dt: false
    per_step_dt: false
    per_dim_friction: true
    use_step_size_scheduler: false

    score_model:
      use_path_gradient: false
      use_target_score: false
      num_layers: 4
      num_hid: 256
      num_time_hid: 32
      num_time_out: 16
      # Floats carry a decimal point and a signed exponent so YAML 1.1 loaders
      # (plain PyYAML) resolve them as floats rather than strings.
      outer_clip: 1.0e+4
      inner_clip: 1.0e+2
      weight_init: 1.0e-8
      bias_init: 0.0
      layer_norm: false
      layer_norm_type: "LayerNorm"
      langevin_param: false

    # `_target_` follows Hydra's instantiate convention (presumably built with
    # hydra.utils.instantiate); total_steps tracks diff_steps.
    dt_schedule:
      _target_: src.networks.diffusion.schedulers.get_cosine_schedule
      total_steps: ${hyperparameters.diffusion.diff_steps}
      min: 0.001
      s: 0.008
      pow: 2
    # Alternative schedules. To use one, replace the dt_schedule above rather
    # than adding a second `dt_schedule` key (duplicate keys are invalid YAML).
    # dt_schedule:
    #   _target_: src.networks.diffusion.schedulers.get_constant_schedule
    # dt_schedule:
    #   _target_: src.networks.diffusion.schedulers.get_linear_schedule
    #   total_steps: ${hyperparameters.diffusion.diff_steps}
    #   min: 0.01

  # Every coefficient listed here logs metrics under
  # eval/episode_return_ode_XXX, where XXX = coef * 100; the empty list below
  # therefore adds none of these metrics.
  # Examples:
  #   ode_coefs: [0.5, 1.0, 2.0]
  #   ode_coefs: [1.0]
  ode_coefs: []
# NOTE(review): this is a top-level key (column 0), not part of
# `hyperparameters`, even though it directly follows that block. Confirm that
# readers use cfg.measure_burnin, not cfg.hyperparameters.measure_burnin.
measure_burnin: 3

name: "reppo-dmerl-debug-1"
seed: 0
num_seeds: 1
tune: false

# --- checkpointing ---
# directory that checkpoints are written to
checkpoint_dir: "./checkpoints"
# checkpoint to load from when resuming training (null: none)
checkpoint_path: null
# save a checkpoint every N steps; 0 disables periodic saving
save_checkpoint_interval: 0
# whether to save a checkpoint once training ends
save_final_checkpoint: false
# whether to upload checkpoints to wandb as artifacts
wandb_upload_checkpoints: false

# --- evaluation-only mode ---
# true: skip training and only evaluate (requires checkpoint_path)
eval_only: false
# number of episodes run in eval-only mode
eval_episodes: 10

# --- trajectory tracking ---
# save end-effector trajectories during evaluation
save_trajectories: false
# where trajectory data and plots are written
trajectory_dir: "./trajectories"
# generate visualization plots (false: only save the data)
plot_trajectories: false

num_trials: 1
tags: ["experimental"]
wandb:
  # "online" logs to wandb; use "offline" to log locally without syncing,
  # or "disabled" to turn wandb off entirely
  mode: "online"
  entity: ""  # NOTE(review): empty presumably means the wandb default entity
  project: "dime_${env.name}"  # project name includes the env config's name
  project_suffix: ""

hydra:
  job:
    # Hydra switches the working directory to each run's output directory, so
    # relative paths such as checkpoint_dir / trajectory_dir resolve inside it.
    chdir: true
