# @package _global_

# Hydra defaults list: base configs composed before this file's own values.
defaults:
  - /habitat: habitat_config_base
  - /habitat_baselines: habitat_baselines_rl_config_base
  # Base observation-transform configs; their fields are overridden under
  # habitat_baselines.rl.policy.obs_transforms further down in this file.
  - /habitat_baselines/rl/policy/obs_transforms: [resize_shortest_edge_base, center_cropper_base]
  # `_self_` is listed last so the values written in this file are merged
  # after (and therefore take precedence over) the entries above.
  - _self_


# Overrides applied on top of the habitat_baselines base config from `defaults`.
habitat_baselines:
  rl:
    # Policy selection and continuous (gaussian) action-distribution settings.
    policy:
      name: "PointNavResNetPolicy"
      no_downscaling: true
      action_distribution_type: "gaussian"
      action_dist:
        use_log_std: true
        clamp_std: true
        std_init: -1.0
        max_std: 1
        use_std_param: true
      # Settings for the two observation transforms named in `defaults`.
      obs_transforms:
        resize_shortest_edge:
          type: ResizeShortestEdge
          size: 120
          channels_last: true
          # Observation keys this transform is applied to.
          trans_keys:
            - head_rgb
            - head_depth
            - object_segmentation
            - start_recep_segmentation
            - head_panoptic
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation
          # Every trans_key except head_rgb and head_depth is also listed here.
          # NOTE(review): presumably these are label images that need a
          # non-interpolating resize; confirm against ResizeShortestEdge.
          semantic_keys:
            - head_panoptic
            - object_segmentation
            - start_recep_segmentation
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation
        center_cropper:
          type: CenterCropper
          height: 160
          width: 120
          channels_last: true
          # Same key list as resize_shortest_edge.trans_keys above.
          trans_keys:
            - head_rgb
            - head_depth
            - object_segmentation
            - start_recep_segmentation
            - head_panoptic
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation

    ppo:
      # PPO hyperparameters
      clip_param: 0.2
      ppo_epoch: 2
      num_mini_batch: 2
      value_loss_coef: 0.5
      entropy_coef: 0.001
      # Exponent floats are written with an explicit mantissa dot so that
      # plain YAML 1.1 loaders (e.g. PyYAML) resolve them as floats rather
      # than strings; the numeric values are unchanged.
      lr: 3.0e-4
      eps: 1.0e-5
      max_grad_norm: 0.2
      num_steps: 128
      use_gae: true
      gamma: 0.99
      tau: 0.95
      use_linear_clip_decay: false
      use_linear_lr_decay: false
      reward_window_size: 50
      use_normalized_advantage: false

      hidden_size: 512

      # Use double buffered sampling, typically helps
      # when environment time is similar or larger than
      # policy inference time during rollout generation
      use_double_buffered_sampler: false

    ddppo:
      sync_frac: 0.6
      # The PyTorch distributed backend to use
      distrib_backend: NCCL
      # Path to the pretrained weights file (only relevant when `pretrained`
      # or `pretrained_encoder` below is enabled; both are off here)
      pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth
      # Initialize with pretrained weights
      pretrained: false
      # Initialize just the visual encoder backbone with pretrained weights
      pretrained_encoder: false
      # Whether the visual encoder backbone will be trained.
      train_encoder: true
      # Whether to reset the critic linear layer
      reset_critic: true

      # Model parameters
      backbone: resnet18
      rnn_type: LSTM
      num_recurrent_layers: 2

# unused
habitat:
  simulator:
    agents:
      main_agent:
        # Articulated agent type selected for the simulator's main agent.
        articulated_agent_type: StretchRobot