# @package _global_

# Hydra defaults list: the `ppo` config is composed together with this file.
defaults:
  - ppo

# AMP agent settings. This file also lists `ppo` in its defaults, so the keys
# below are merged with that config rather than standing alone.
algo:
  # Dotted path of the class Hydra instantiates for this node.
  _target_: phys_anim.agents.amp.AMP
  config:
    # Set up the discriminator structure.
    discriminator:
      _target_: phys_anim.agents.models.discriminator.JointDiscMLP
      # Hydra flag: do not recursively instantiate nested nodes of this entry.
      _recursive_: false
      config:
        initializer: default
        units: [1024, 512]
        # Resolved from the `env` config, which is defined outside this file.
        discriminator_obs_historical_steps: ${env.config.discriminator_obs_historical_steps}
        activation: relu
        # The two interpolations below point at `algo.config` keys that are not
        # set in this file (presumably inherited from `ppo` -- TODO confirm).
        normalize_obs: ${algo.config.normalize_obs}
        obs_clamp_value: ${algo.config.obs_clamp_value}
        use_layer_norm: false
        extra_inputs: null

    discriminator_optimizer:
      _target_: torch.optim.Adam
      # Hydra flag: do not recursively instantiate nested nodes of this entry.
      _recursive_: false
      # Written with an explicit decimal point so that YAML 1.1 loaders (e.g.
      # plain PyYAML) also resolve it as a float instead of the string "1e-4".
      lr: 1.0e-4
      betas: [0.9, 0.999]

    # No learning-rate scheduler is configured for the discriminator.
    discriminator_lr_scheduler: null

    # AMP parameters
    task_reward_w: 0.0
    use_rand_action_masks: true

    discriminator_weight_decay: 0.0001
    discriminator_logit_weight_decay: 0.01
    discriminator_batch_size: 4096
    discriminator_reward_w: 2.0
    discriminator_grad_penalty: 5
    discriminator_replay_keep_prob: 0.01
    discriminator_replay_size: 200000
    num_discriminator_mini_epochs: 3
