---
# NOTE: root_dir and model_name are re-specified for all configs in run.train.py
# `algo` anchors the algorithm name; it is aliased below by root_dir and model_name,
# and further down by parameter_server.root_dir/model_name, trainer.algorithm
# and actor.algorithm.
algorithm: &algo zero2
name: zero2
version: 0

# NOTE(review): presumably the numeric precision in bits -- confirm against the consumer.
precision: 32

# NOTE(review): env.uid2aid lists two entries ([0, 1]); presumably these must
# stay consistent with n_agents -- confirm.
n_agents: 2

# model path: root_dir/model_name/name
# tensorboard path: root_dir/model_name/logs
# The values below are placeholders only; the entry point re-specifies them.
root_dir: *algo
model_name: *algo

# Controller settings.
controller:
    # Written with an explicit decimal point and a signed exponent: YAML 1.1
    # loaders (e.g. PyYAML) resolve a bare `1e6` to the string "1e6", not a float.
    store_period: 1.0e+6
    # NOTE(review): "priod" looks like a typo for "period". The key is kept
    # verbatim because the consuming code presumably looks it up by this exact
    # name -- rename here and in the consumer together, if at all.
    restart_runners_priod: null
    max_version_iterations: 1000
    max_steps_per_iteration: 1000
    # Anchored as `irms`; actor.rms.normalize_obs aliases this value.
    initialize_rms: &irms false

# Parameter-server settings.
parameter_server:
    # Both entries alias the top-level `algo` anchor (zero2).
    root_dir: *algo
    model_name: *algo

    # NOTE(review): presumably fractions in [0, 1] -- confirm against the consumer.
    train_from_scratch_frac: 1
    online_frac: 0.2

    payoff:
        step_size: 0.1
        sampling_strategy: fsp

# Resource settings keyed by role (runner / agent).
# NOTE(review): presumably forwarded as Ray resource requests -- confirm against
# the code that reads ray_config.
ray_config:
    runner:
        num_cpus: 1
    agent:
        num_gpus: 1

# Empty mapping: this file defines no monitor-specific settings.
monitor: {}

# Runner settings.
runner:
    # Anchored as `nrunners`; aliased by buffer.n_runners.
    n_runners: &nrunners 10
    # Anchored as `nsteps`; aliased by buffer.n_steps.
    n_steps: &nsteps 20
    push_every_episode: false

# Environment settings.
env:
    # Anchored as `env_name`; no alias to it appears elsewhere in this file.
    env_name: &env_name overcooked-asymmetric_advantages
    max_episode_steps: 400
    # Anchored as `nenvs`; aliased by buffer.n_envs.
    n_envs: &nenvs 4
    dense_reward: false
    featurize: true
    uid2aid:
        - 0
        - 1
    # Reward-shaping overrides, currently disabled:
    # layout_params:
    #     rew_shaping_params:
    #         PLACEMENT_IN_POT_REW: 1
    #         DISH_PICKUP_REWARD: 1
    #         SOUP_PICKUP_REWARD: 1
    #         DISH_DISP_DISTANCE_REW: 0
    #         POT_DISTANCE_REW: 0
    #         SOUP_DISTANCE_REW: 0

# Empty mapping: this file defines no agent-specific settings.
agent: {}

strategy:
    # The six anchored values below are aliased by the buffer section
    # (n_epochs, n_mbs, n_aux_epochs, n_pi, n_segs, n_aux_mbs_per_seg).
    train_loop:
        n_epochs: &nepochs 1
        n_mbs: &nmbs 1
        n_aux_epochs: &naepochs 9
        n_pi: &npi 16
        n_segs: &nsegs 16
        n_aux_mbs_per_seg: &nambs 2
        max_kl: 0            # stop early once max_kl is exceeded; 0 or null means unbounded

        # value_update accepts one of: once | reuse | null
        # "once" updates values at the end of each epoch.
        # "reuse" updates values using the values from training, which are staler than with "once".
        # null doesn't update values.
        value_update: null

# Model section: one sub-mapping per network (encoder, policy, value).
model:
    # `rnn` (null here) is aliased by buffer.rnn_type; `ss` is aliased by
    # trainer.sample_size and buffer.sample_size.
    rnn_type: &rnn null
    sample_size: &ss 16

    encoder:
        nn_id: cnn_impala
        out_size: 256

    # Recurrent sub-network, currently disabled:
    # rnn:
    #     nn_id: *rnn
    #     units: 256
    #     kernel_initializer: orthogonal
    #     use_ln: False

    policy:
        nn_id: policy
        units_list: [256]
        kernel_initializer: orthogonal
        activation: relu
        eval_act_temp: 1
        out_gain: 0.01

    value:
        nn_id: value
        units_list: [256]
        kernel_initializer: orthogonal
        activation: relu

# Loss settings (coefficients and clipping).
loss:
    entropy_coef: 0.01
    value_loss: clip
    clip_range: 0.05
    value_coef: 0.5

# Trainer settings.
trainer:
    # Aliases the top-level `algo` anchor (zero2).
    algorithm: *algo
    display_var: false

    sample_size: *ss         # BPTT length (aliases model.sample_size)
    store_state: true
    optimizer:
        opt_name: adam
        schedule_lr: false
        # Written with an explicit decimal point so that YAML 1.1 loaders
        # (e.g. PyYAML) resolve it to a float; a bare `1e-3` is read as the
        # string "1e-3". This also matches the `1.e-5` form used for epsilon.
        lr: 1.0e-3
        clip_norm: 0.1
        epsilon: 1.e-5
        weight_decay: 0

# Actor settings.
actor:
    # Aliases the top-level `algo` anchor (zero2).
    algorithm: *algo

    rms:
        obs_names: []
        # Aliases controller.initialize_rms.
        normalize_obs: *irms
        normalize_reward: false
        obs_normalized_axis: [0, 1]
        reward_normalized_axis: [0, 1]
        update_reward_rms_in_time: true
        # Anchored as `gamma`; aliased by buffer.gamma.
        gamma: &gamma 0.99

# Buffer settings. Most values alias anchors defined in the actor, runner, env,
# strategy.train_loop and model sections so they stay in sync.
buffer:
    type: ac
    use_dataset: false

    # PPO configs
    adv_type: gae  # nae or gae
    gamma: *gamma
    lam: 0.98
    n_runners: *nrunners
    n_envs: *nenvs
    n_steps: *nsteps
    n_epochs: *nepochs
    n_mbs: *nmbs  # number of minibatches
    fragment_size: null
    rnn_type: *rnn
    sample_size: *ss
    norm_adv: minibatch

    # PPG configs
    n_pi: *npi
    n_segs: *nsegs
    n_aux_mbs_per_seg: *nambs
    n_aux_epochs: *naepochs

    # mini-batch size = n_runners * n_envs * epslen / n_mbs
    # NOTE(review): `h` and `c` are listed although rnn_type is null in this
    # file; presumably they are ignored without an RNN -- confirm.
    sample_keys:
        - obs
        - action
        - value
        - traj_ret
        - advantage
        - logpi
        - mask
        - h
        - c

    # Disabled: no `&keys` anchor is defined in this file, so these aliases
    # would fail to resolve if uncommented as-is.
    # aux_compute_keys: *keys

    # aux_sample_keys: *keys
