# NOTE(review): presumably seeds the random number generators so runs are
# reproducible; confirm against the script that loads this file.
seed: 42


# Models are instantiated using skrl's model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
#
#             obs[0]                  obs[1]
#               │                       │
#    ┏━━━━━━━━━━▼━━━━━━━━━━┓    ┏━━━━━━━▼━━━━━━━━┓
#    ┃ features_extractor  ┃    ┃ proprioception ┃
#    ┡━━━━━━━━━━━━━━━━━━━━━┩    ┡━━━━━━━━━━━━━━━━┩
#    │ conv2d(32, 8, 4)    │    │ linear(16)     │
#    │ relu                │    │ elu            │
#    │ conv2d(64, 4, 2)    │    │ linear(8)      │
#    │ relu                │    │ elu            │
#    │ conv2d(64, 3, 1)    │    └───────┬────────┘
#    │ relu                │            │
#    │ flatten             │            │
#    │ linear(512)         │            │
#    │ tanh                │            │
#    │ linear(16)          │            │
#    │ tanh                │            │
#    └──────────┬──────────┘            │
#               │                       │
#               └─▶(concatenate)◀───────┘
#                        │
#                 ┏━━━━━━▼━━━━━┓
#                 ┃     net    ┃
#                 ┡━━━━━━━━━━━━┩
#                 │ linear(32) │
#                 │ elu        │
#                 │ linear(32) │
#                 │ elu        │
#                 └──────┬─────┘
# shared                 │
# .......................│.......................
# non-shared             │
#            ┏━━━━━━━━━━━▼━━━━━━━━━━━┓
#            ┃  policy|value output  ┃
#            ┡━━━━━━━━━━━━━━━━━━━━━━━┩
#            │ linear(num_actions|1) │
#            └───────────┬───────────┘
#                        ▼
models:
  # Policy and value are not separate: per the diagram above, they share
  # everything down to `net` and differ only in the final output layer.
  separate: false

  # Policy model (see multicategorical_model parameters)
  policy:
    class: MultiCategoricalMixin
    unnormalized_log_prob: true
    network:
      # Branch fed by OBSERVATIONS[0]: three convolutions, then two dense layers.
      - name: features_extractor
        # PyTorch NHWC -> NCHW.
        # Warning: don't permute for JAX since it expects NHWC
        input: permute(OBSERVATIONS[0], (0, 3, 1, 2))
        layers:
          - conv2d:
              out_channels: 32
              kernel_size: 8
              stride: 4
              padding: 0
          - conv2d:
              out_channels: 64
              kernel_size: 4
              stride: 2
              padding: 0
          - conv2d:
              out_channels: 64
              kernel_size: 3
              stride: 1
              padding: 0
          - flatten
          - linear: 512
          - linear: 16
        # One entry per layer above; `null` is the entry for `flatten`.
        activations: [relu, relu, relu, null, tanh, tanh]
      # Branch fed by OBSERVATIONS[1].
      - name: proprioception
        input: OBSERVATIONS[1]
        layers: [16, 8]
        activations: elu
      # Trunk fed by the concatenated outputs of the two branches.
      - name: net
        input: concatenate([features_extractor, proprioception])
        layers: [32, 32]
        activations: elu
    output: ACTIONS

  # Value model (see deterministic_model parameters)
  value:
    class: DeterministicMixin
    clip_actions: false
    network:
      # Same three containers as the policy network.
      - name: features_extractor
        # PyTorch NHWC -> NCHW.
        # Warning: don't permute for JAX since it expects NHWC
        input: permute(OBSERVATIONS[0], (0, 3, 1, 2))
        layers:
          - conv2d:
              out_channels: 32
              kernel_size: 8
              stride: 4
              padding: 0
          - conv2d:
              out_channels: 64
              kernel_size: 4
              stride: 2
              padding: 0
          - conv2d:
              out_channels: 64
              kernel_size: 3
              stride: 1
              padding: 0
          - flatten
          - linear: 512
          - linear: 16
        # One entry per layer above; `null` is the entry for `flatten`.
        activations: [relu, relu, relu, null, tanh, tanh]
      - name: proprioception
        input: OBSERVATIONS[1]
        layers: [16, 8]
        activations: elu
      - name: net
        input: concatenate([features_extractor, proprioception])
        layers: [32, 32]
        activations: elu
    output: ONE


# Rollout memory
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory:
  # Memory class to instantiate (documented at the skrl URL given above).
  class: RandomMemory
  memory_size: -1  # automatically determined (same as agent:rollouts)


# PPO agent configuration (field names are from PPO_DEFAULT_CONFIG)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent:
  class: PPO

  # Rollout / update schedule
  # (exact semantics of every field: see PPO_DEFAULT_CONFIG in the skrl docs)
  rollouts: 64
  learning_epochs: 4
  mini_batches: 32

  # Discounting
  discount_factor: 0.99
  lambda: 0.95

  # Learning rate and its scheduler
  learning_rate: 1.0e-04
  learning_rate_scheduler: KLAdaptiveLR
  learning_rate_scheduler_kwargs:
    kl_threshold: 0.008

  # Preprocessors (no state preprocessor; values use RunningStandardScaler)
  state_preprocessor: null
  state_preprocessor_kwargs: null
  value_preprocessor: RunningStandardScaler
  value_preprocessor_kwargs: null

  # Warm-up
  random_timesteps: 0
  learning_starts: 0

  # Clipping
  grad_norm_clip: 1.0
  ratio_clip: 0.2
  value_clip: 0.2
  clip_predicted_values: true

  # Loss scales and agent-level KL threshold
  entropy_loss_scale: 0.0
  value_loss_scale: 1.0
  kl_threshold: 0.0

  # Miscellaneous
  rewards_shaper_scale: 1.0
  time_limit_bootstrap: false
  mixed_precision: false

  # logging and checkpoint
  experiment:
    directory: 'cartpole_camera_direct_tuple_multidiscrete'
    experiment_name: ''
    write_interval: auto
    checkpoint_interval: auto


# Sequential trainer
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
  # Trainer class to instantiate (documented at the skrl URL given above).
  class: SequentialTrainer
  # NOTE(review): presumably the total number of timesteps to train for;
  # confirm in the skrl trainer documentation.
  timesteps: 32000
  # NOTE(review): presumably the key of the environment's info dict whose
  # entries get logged; confirm in the skrl trainer documentation.
  environment_info: log
