seed: 0
name: 'train'
# logdir is deliberately left empty here; the main script is expected to
# fill it in at runtime.
logdir: ''

# Compute device settings.
device:
    cuda: true
    cuda_deterministic: true
    torch_threads: 4

# Replay buffer settings.
replay:
    capacity: 1000000
    # Lowercase canonical boolean (was `False`), consistent with the rest of
    # the file and portable across YAML 1.1 / 1.2 loaders.
    offload: false

# Training settings.
# Exponent-form numbers in this section are written with a decimal point and
# a signed exponent (e.g. 1.0e-4). YAML 1.1 loaders such as PyYAML's
# safe_load read forms like `1e-4` or `2e7` as *strings*, not floats.
train:
    n_rollout_threads: 16
    parallel_rollout: true
    # Equivalent to 2e7, written as a plain integer so every loader yields a number.
    num_env_steps: 20000000
    train_ratio: 128
    imagination_steps: 16
    prefill_steps: 5000
    batch_size: 16
    batch_length: 64
    burn_in_length: 0
    stoch_dyn_scale: 0.8
    stoch_rep_scale: 0.2
    free_bits: 1.0
    obs_scale: 1.0
    reward_scale: 1.0
    cont_scale: 1.0
    act_mask_scale: 1.0
    entropy_coef: 1.0e-2
    target_update_tau: 1.0
    gamma: 0.99
    gae_lambda: 0.95
    share_actors: false
    share_critics: false
    # PPO-specific parameters.
    ppo_epochs: 5
    clip_param: 0.2
    # Per-model optimizer settings; Adam is used by default.
    optim:
        world_model:
            lr: 1.0e-4
            eps: 1.0e-5
            use_max_grad_norm: true
            max_grad_norm: 1000.0
        critic:
            lr: 3.0e-5
            eps: 1.0e-8
            use_max_grad_norm: true
            max_grad_norm: 100.0
        actor:
            lr: 3.0e-5
            eps: 1.0e-8
            use_max_grad_norm: true
            max_grad_norm: 100.0
    checkpoint:
        save_interval: 50000
        from_checkpoint: ""

# Evaluation settings; presumably only consulted when use_eval is true — verify.
use_eval: true
eval:
    parallel_rollout: true
    n_rollout_threads: 4
    eval_interval: 10000
    eval_episode_num: 32

# Logging settings. The *_filter and log_keys_* values are regular
# expressions; '^$' matches only the empty string.
logging:
    log_interval: 5000
    rewards_reduce: 'mean'
    terminal_filter: '^((?!timer).)*$'
    use_wandb: true
    wandb_filter: '.*'
    wandb_config:
        project: 'DMAWM'
        job_type: 'dmawm'
        entity: ''
        notes: ''
    log_keys_sum: '^$'
    log_keys_avg: '^$'
    log_keys_max: '^$'
    timer: true

# World-model architecture settings.
world_model:
    encoder:
        act: SiLU
        hidden_dim: 1024
        hidden_layers: 2
        use_layernorm: true
        use_symlog: false
        # CNN settings
        kernel: 4
        depth: 32
        mults: [2, 3, 4, 4]
    # NOTE(review): every sub-config from `mlp` through `act_mask_predictor`
    # is nested under `rssm` (same indent as `mlp`); confirm this matches
    # what the loading code expects.
    rssm:
        deterministic_dim: 512
        stochastic_dim: 32
        classes: 32
        unimix: 0.01
        hidden_dim: 1024
        act: SiLU
        use_layernorm: true
        obs_layers: 2
        mlp:
            act: SiLU
            hidden_dim: 1024
            hidden_layers: 1
            use_layernorm: true
        rnn:
            act: SiLU
            use_layernorm: true
        # img_stoch_transformer is presumably only read when
        # use_img_stoch_transformer is true — verify.
        use_img_stoch_transformer: true
        img_stoch_transformer:
            num_layers: 3
            nhead: 8
            activation: gelu
            norm_first: true
        obs_predictor:
            act: SiLU
            hidden_dim: 1024
            hidden_layers: 2
            use_layernorm: true
            output: 'symlog_mse'
            # CNN settings
            depth: 32
            mults: [4, 4, 3, 2]
            kernel: 4
        global_agent_embedding_transformer:
            num_layers: 3
            nhead: 8
            activation: gelu
            norm_first: true
        reward_predictor:
            act: SiLU
            hidden_dim: 1024
            hidden_layers: 2
            use_layernorm: true
            out_scale: 0.0
            output: 'symexp_twohot'
            enable_feat_grad: true
        # cont_predictor is presumably only read when use_cont_predictor is
        # true — verify.
        use_cont_predictor: false
        cont_predictor:
            act: SiLU
            hidden_dim: 1024
            hidden_layers: 2
            use_layernorm: true
            output: 'binary'
            enable_feat_grad: true
        # act_mask_predictor is presumably only read when
        # use_act_mask_predictor is true — verify.
        use_act_mask_predictor: false
        act_mask_predictor:
            act: SiLU
            hidden_dim: 1024
            hidden_layers: 2
            use_layernorm: true
            output: 'binary'
            enable_feat_grad: true

# Actor network settings.
actor:
    act: SiLU
    hidden_dim: 256
    hidden_layers: 2
    use_layernorm: true
    out_scale: 0.01

# Critic network settings.
critic:
    act: SiLU
    hidden_dim: 256
    hidden_layers: 2
    use_layernorm: true
    out_scale: 0.0
    output: 'symexp_twohot'
