---
# Top-level identifiers. `algorithm` carries the &algo anchor, which is
# aliased by root_dir/model_name below and by the `algorithm` key of the
# routine, strategy and trainer sections.
algorithm: &algo zero_mb
name: zero_mb
dynamics_name: magw

# NOTE(review): presumably the numeric precision in bits -- confirm with the
# code that consumes this file.
precision: 32

# model path: root_dir/model_name/name
# tensorboard path: root_dir/model_name/logs
# The two values below are placeholders only; the entry point re-specifies them.
root_dir: *algo
model_name: *algo

# Keys grouped under `routine`.
routine:
    algorithm: *algo

    # Floats are written with a decimal point and a signed exponent on purpose:
    # YAML 1.1 loaders such as PyYAML resolve a bare `5e7` / `1e5` to the
    # *string* "5e7" / "1e5", whereas `5.0e+7` is a float for 1.1 and 1.2 alike.
    MAX_STEPS: 5.0e+7
    # Anchored as &nsteps; trainer.n_steps and buffer.n_steps alias this value.
    n_steps: &nsteps 50
    LOG_PERIOD: 1.0e+5
    EVAL_PERIOD: null

    RECORD_VIDEO: true
    N_EVAL_EPISODES: 1
    # NOTE(review): presumably the recorded frame size -- confirm with consumer.
    size: [256, 256]

    n_imaginary_runs: 3
    n_imaginary_envs: 64
    n_imaginary_steps: 10
    imaginary_rollout: sim
    switch_model_at_every_step: false

# Keys grouped under `env`. Several values are anchored here and aliased by
# later sections, so edit them in this one place.
env:
    # Anchored as &env_name; no alias to it appears in this file.
    env_name: &env_name magw-staghunt
    # &nrunners / &nenvs are aliased by the trainer and buffer sections.
    n_runners: &nrunners 1
    n_envs: &nenvs 256
    max_episode_steps: 50
    reward_scale: 0.1
    population_size: 1
    # &td is aliased by buffer.timeout_done.
    timeout_done: &td false

    uid2aid:
        - 0
        - 1

# Explicit empty mapping: this file sets no keys under `agent`.
agent: {}

# Keys grouped under `monitor`.
monitor:
    use_tensorboard: true

# Keys grouped under `strategy`.
strategy:
    # Alias of the top-level &algo anchor (zero_mb).
    algorithm: *algo
    # Explicit empty mapping: no train_loop keys are set in this file.
    train_loop: {}

# Keys grouped under `model`: one `policy` and one `value` sub-mapping that
# share the same layout.
model:
    aid: 0
    # Anchored as &gamma; aliased by loss.stats.gamma, actor.rms.gamma and
    # buffer.gamma.
    gamma: &gamma 0.99

    policy:
        nn_id: policy
        units_list:
            - 64
            - 64
        w_init: orthogonal
        activation: tanh
        norm: null
        out_scale: 0.01
        # Anchored as &prnn; no alias to it appears in this file.
        rnn_type: &prnn null
        rnn_units: 64
    value:
        nn_id: value
        units_list:
            - 64
            - 64
        w_init: orthogonal
        activation: tanh
        norm: null
        # Anchored as &vrnn; no alias to it appears in this file.
        rnn_type: &vrnn null
        rnn_units: 64

# Keys grouped under `loss`.
loss:
    # Value-target and advantage hyperparameters.
    target_type: gae
    c_clip: 1
    rho_clip: 1
    adv_type: vtrace
    norm_adv: true

    prnn_bptt: 10
    vrnn_bptt: 10

    # Policy-optimization hyperparameters.
    pg_type: ppo
    ppo_clip_range: 0.2
    kl_type: reverse
    kl_coef: 0
    policy_sample_mask: true

    # Value-learning hyperparameters.
    value_loss: clip
    value_clip_range: 0.2
    value_sample_mask: false

    stats:
        # Alias of model.gamma (&gamma).
        gamma: *gamma
        # Anchored as &lam; aliased by buffer.lam.
        lam: &lam 0.95
        pg_coef: 1
        entropy_coef: 0.01
        value_coef: 1

# Keys grouped under `trainer`. Runner/env/step counts are aliases of the
# anchors defined in the `env` and `routine` sections.
trainer:
    algorithm: *algo
    aid: 0
    n_runners: *nrunners
    n_envs: *nenvs
    # &nepochs and &nmbs are anchored here; no alias to them appears in this file.
    n_epochs: &nepochs 4
    n_imaginary_epochs: 4
    n_mbs: &nmbs 1
    n_steps: *nsteps         # BPTT length

    theta_opt:
        opt_name: adam
        # Written with a decimal point on purpose: YAML 1.1 loaders such as
        # PyYAML resolve a bare `1e-3` / `1e-5` to a *string*, not a float.
        lr: 1.0e-3
        clip_norm: 0.5
        eps: 1.0e-5

# Keys grouped under `actor`; `rms` holds the normalization-related settings.
actor:
    update_obs_rms_at_execution: false
    update_obs_at_execution: false
    rms:
        obs_names:
            - obs
            - global_state
        obs_normalized_axis:
            - 0
            - 1
            - 2
        reward_normalized_axis:
            - 0
            - 1
            - 2
        normalize_obs: false
        obs_clip: 10
        normalize_reward: false
        update_reward_rms_in_time: false
        # Alias of model.gamma (&gamma).
        gamma: *gamma

# Keys grouped under `buffer`. Every sizing and discount value below is an
# alias of an anchor defined in the env, routine, model or loss section.
buffer:
    type: ac

    n_runners: *nrunners
    n_envs: *nenvs
    n_steps: *nsteps
    queue_size: 2
    timeout_done: *td

    gamma: *gamma
    lam: *lam

    # Names listed under sample_keys, one per line.
    sample_keys:
        - obs
        - action
        - value
        - reward
        - discount
        - reset
        - mu_logprob
        - mu_logits
        - state_reset
        - state
