# @package _global_

# Hydra defaults list. Entries are composed in the order given, with later
# entries overriding earlier ones. The leading "/" makes each config-group
# path absolute.
defaults:
  - /env: zone_env
  - /alg: deep_ltl
  - /rl_alg: ppo
  # Compose this file's own keys last so the values set further down take
  # precedence over the group configs selected above. Hydra appends `_self_`
  # at the end implicitly when it is omitted; listing it makes the order
  # explicit.
  - _self_

# Curriculum factory. `_target_` is the dotted import path of the callable
# (Hydra `instantiate` convention); no arguments are supplied from this file.
curriculum:
  _target_: jaxltl.deep_ltl.curriculum.zone_env_curriculum.make

# Settings for the curriculum wrapper.
# NOTE(review): `episode_window` is presumably the number of recent episodes
# the wrapper aggregates over; confirm against the wrapper implementation.
curriculum_wrapper:
  episode_window: 512

# Architecture of the DeepLTL model. Every sub-tree that carries a `_target_`
# names the class used for that component.
model:
  _target_: "jaxltl.deep_ltl.model.deep_ltl.DeepLTLModel"
  # With recursion disabled, Hydra does not instantiate the nested `_target_`
  # nodes itself; they reach DeepLTLModel as plain config nodes.
  _recursive_: false

  # Sub-network configured as an MLP under the `env_net` key.
  # `act` is not a built-in OmegaConf resolver, so it has to be registered
  # by the project before this config is resolved.
  env_net:
    _target_: "jaxltl.networks.mlp.MLP"
    hidden_sizes:
      - 128
    out_size: 64
    activation: "${act:tanh}"

  # Continuous-action actor head.
  actor:
    _target_: "jaxltl.deep_ltl.model.actor.continuous_actor.ContinuousActor"
    hidden_sizes:
      - 64
      - 64
      - 64
    hidden_activation: "${act:relu}"
    output_activation: "${act:tanh}"
    state_dependent_std: true

  # Critic head; no `out_size` is set here.
  critic:
    _target_: "jaxltl.networks.mlp.MLP"
    hidden_sizes:
      - 64
      - 64
    activation: "${act:tanh}"

  # Sequence module settings. `embedding_dim` and the DeepSets `out_size`
  # are both 16.
  # NOTE(review): presumably these two must stay equal; confirm.
  sequence:
    embedding_dim: 16
    deep_sets:
      _target_: "jaxltl.networks.deep_sets.DeepSets"
      hidden_sizes:
        - 32
      out_size: 16
      activation: "${act:relu}"

# Values set under `rl_alg` by this file. The defaults list above selects
# `/rl_alg: ppo`, and these keys are merged with that selection.
# NOTE(review): key names follow the usual PPO hyperparameters; their exact
# semantics are defined by the code consuming the `ppo` config.
#
# Scientific-notation floats are written with an explicit mantissa dot and a
# signed exponent. Forms such as `5e6` are read as floats by OmegaConf but as
# strings by YAML 1.1 loaders like stock PyYAML; the dotted form is a float
# everywhere and yields the same value.
rl_alg:
  # Parsed as a float (5000000.0), not an int.
  total_timesteps: 5.0e+6
  num_envs: 16
  num_steps: 4096
  num_minibatches: 32
  update_epochs: 10
  gamma: 0.998
  gae_lambda: 0.95
  clip_eps: 0.2
  ent_coef: 0.003
  vf_coef: 0.5
  lr: 3.0e-4
  max_grad_norm: 0.5
  anneal_lr: false
  adam_eps: 1.0e-8
