---
# Command the W&B agent runs for every sweep run.
#   ${program}               -> the `program` path declared below
#   ${args_no_boolean_flags} -> the sweep parameters as `--key=value` args,
#                               except booleans, which become a bare `--key`
#                               when true and are omitted when false
command:
  - "python3.10"
  - "${program}"
  - "${args_no_boolean_flags}"

# W&B entity (user or team) that owns the sweep. Explicitly null here (was a
# bare empty value with trailing whitespace, which parsed as null implicitly);
# fill in before creating the sweep.
# NOTE(review): presumably W&B falls back to the logged-in user's default
# entity when this is null -- confirm.
entity: null
# Exhaustive grid search: one run per combination of every `values` list under
# `parameters` (currently 3 x 3 x 3 x 3 x 3 = 243 runs).
method: grid
name: RL_M_OMD_linear_quadratic
program: mfax/algos/rl/algos/timed_m_omd.py

parameters:
  # --- logging ---
  debug:
    value: false
  log:
    value: true
  save:
    value: false
  wandb_project:
    value: mfax
  wandb_team:
    # Explicitly null (was a bare empty value, which parses as null anyway);
    # fill in the W&B team/entity to log under.
    # NOTE(review): the agent presumably renders a null value as
    # `--wandb_team=None` on the command line -- confirm the program treats
    # that as "unset".
    value: null
  wandb_group:
    value: mfax

  # --- environment ---
  task:
    value: linear_quadratic
  state_type:
    value: indices
  discount_factor:
    value: 0.99
  normalize_obs:
    value: true
  normalize_states:
    value: true
  partially_observable:
    value: true
  common_noise:
    value: true

  # --- m_omd hyperparameters ---
  algo:
    value: rl_m_omd
  seed:
    value: 0
  num_envs:
    value: 8
  num_agents_per_env:
    values: [8, 128, 1024]
  batch_size:
    value: 2048
  replay_buffer_capacity:
    value: 300000
  min_buffer_to_learn:
    value: 10000
  learn_every:
    value: 8
  epsilon_decay_duration_pct:
    value: 0.5
  epsilon_power:
    value: 1.0
  epsilon_start:
    value: 1.0
  epsilon_end:
    value: 0.1
  reset_replay_buffer_on_update:
    value: true
  q_net_type:
    value: discrete
  activation:
    value: relu
  update_target_every:
    value: 200
  lr:
    values: [0.0001, 0.001, 0.01]
  anneal_lr:
    value: true
  max_grad_norm:
    value: 1.0
  loss:
    value: mse
  # NOTE(review): presumably only read when `loss` selects a Huber loss; with
  # `loss: mse` above this likely has no effect -- confirm in the program.
  huber_loss_parameter:
    value: 1.0

  # --- Munchausen parameters ---
  # `tau` values are all written as floats. A bare 5 / 10 next to 0.05 parses
  # as int, so some runs would record `tau` as int and others as float. The
  # 0.05 entry already requires the program to accept floats, so 5.0 / 10.0
  # are accepted as well.
  tau:
    values: [0.05, 5.0, 10.0]
  alpha:
    values: [0.9, 0.95, 0.99]
  with_munchausen:
    value: true

  # --- iteration counts and evaluation frequency ---
  # NOTE(review): only `eval_frequency` reads as a frequency; the first two
  # keys appear to set the training budget -- confirm units in the program.
  num_iterations:
    value: 200
  num_updates_per_iteration:
    values: [50, 100, 200]
  eval_frequency:
    value: 20
