# @package _global_

# Hydra defaults list: the configs named here are composed into this one.
# The order of the entries matters.
defaults:
  # `gaze` option of the /benchmark/ovmm config group.
  - /benchmark/ovmm: gaze
  # Base config for the `habitat_baselines` section that this file customizes.
  - /habitat_baselines: habitat_baselines_rl_config_base
  # Base entries for the two observation transforms that are configured under
  # habitat_baselines.rl.policy.obs_transforms further down in this file.
  - /habitat_baselines/rl/policy/obs_transforms:
    - resize_shortest_edge_base
    - center_cropper_base
  # The `@...` suffix places the `third_rgb_sensor` sensor config at
  # habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor.
  # NOTE(review): presumably an extra camera for evaluation videos -- confirm.
  - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor
  # `_self_` is last so that the values written in this file are merged after,
  # and therefore take precedence over, the entries above.
  - _self_

# Trainer-level settings: logging, checkpointing, environment count and
# training length.
habitat_baselines:
  verbose: False
  trainer_name: "ddppo"
  torch_gpu_id: 0
  # Output directories for TensorBoard logs and videos.
  tensorboard_dir: "tb"
  video_dir: "video_dir"
  video_fps: 30
  # NOTE(review): -1 presumably means "evaluate on all episodes" -- confirm.
  test_episode_count: -1
  # Same directory as checkpoint_folder below, so evaluation reads the
  # checkpoints that training writes.
  eval_ckpt_path_dir: "data/new_checkpoints"
  # 26 environments will just barely be below 16gb.
  # 20 environments will just barely be below 11gb.
  num_environments: 20
  checkpoint_folder: "data/new_checkpoints"
  # NOTE(review): num_updates is -1 (presumably disabled) so that
  # total_num_steps bounds the length of training -- confirm.
  num_updates: -1
  # Written with an explicitly signed exponent: YAML 1.1 resolvers (e.g. plain
  # PyYAML) only recognise floats of the form `4.0e+8`; `4.0e8` would be
  # loaded as a string by such parsers.
  total_num_steps: 4.0e+8
  log_interval: 10
  num_checkpoints: 20
  # Force PyTorch to be single threaded as
  # this improves performance considerably
  force_torch_single_threaded: True
  eval_keys_to_include_in_name: ["reward", "force", "success"]

  eval:
    # Evaluation videos are saved with the "disk" option.
    video_option: ["disk"]

  rl:
    policy:
      # Policy class and the kind of action distribution it outputs.
      name: PointNavResNetPolicy
      action_distribution_type: categorical
      no_downscaling: True
      # Standard-deviation settings for the action distribution.
      # NOTE(review): presumably only consulted for a Gaussian distribution,
      # not for the categorical one selected above -- confirm.
      action_dist:
        use_log_std: True
        clamp_std: True
        std_init: -1.0
        use_std_param: True
        max_log_std: 2
      # Observation transforms. Both operate on the same set of observation
      # keys (trans_keys), with channels-last layout.
      obs_transforms:
        # Resize with a target shortest-edge size of 120.
        resize_shortest_edge:
          type: ResizeShortestEdge
          size: 120
          channels_last: True
          trans_keys:
            - head_rgb
            - head_depth
            - head_panoptic
            - object_segmentation
            - receptacle_segmentation
            - start_recep_segmentation
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation
          # Subset of trans_keys that is flagged as semantic: every key
          # above except head_rgb and head_depth.
          # NOTE(review): presumably resized without blending label values
          # -- confirm against ResizeShortestEdge.
          semantic_keys:
            - head_panoptic
            - object_segmentation
            - receptacle_segmentation
            - start_recep_segmentation
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation
        # Center crop to height 160 x width 120.
        center_cropper:
          type: CenterCropper
          height: 160
          width: 120
          channels_last: True
          trans_keys:
            - head_rgb
            - head_depth
            - head_panoptic
            - object_segmentation
            - receptacle_segmentation
            - start_recep_segmentation
            - goal_recep_segmentation
            - ovmm_nav_goal_segmentation

    ppo:
      # PPO optimisation hyper-parameters.
      clip_param: 0.2
      ppo_epoch: 2
      num_mini_batch: 2
      value_loss_coef: 0.5
      entropy_coef: 0.001
      # lr and eps are written with a decimal point: YAML 1.1 resolvers
      # (e.g. plain PyYAML) load `3e-4` / `1e-5` as strings, whereas
      # `3.0e-4` / `1.0e-5` are resolved as floats by every loader.
      lr: 3.0e-4
      eps: 1.0e-5
      max_grad_norm: 0.2
      num_steps: 128
      # Generalised advantage estimation with its discount (gamma) and
      # smoothing (tau) factors.
      use_gae: True
      gamma: 0.99
      tau: 0.95
      # Both the clip range and the learning rate stay constant during training.
      use_linear_clip_decay: False
      use_linear_lr_decay: False
      reward_window_size: 50

      use_normalized_advantage: False

      hidden_size: 512

      # Use double buffered sampling, typically helps
      # when environment time is similar or larger than
      # policy inference time during rollout generation
      use_double_buffered_sampler: False

    # Settings for the distributed (DD-PPO) trainer and the policy network.
    ddppo:
      # NOTE(review): presumably the fraction of workers that must finish
      # their rollout before the remaining ones are preempted -- confirm.
      sync_frac: 0.6
      # The PyTorch distributed backend to use
      distrib_backend: NCCL
      # Path to the pretrained weights file. With both `pretrained` and
      # `pretrained_encoder` set to False below, it is presumably not loaded.
      # NOTE(review): the file name refers to resnet50 while `backbone` below
      # is resnet18 -- confirm they match before enabling either flag.
      pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth
      # Initialize with pretrained weights
      pretrained: False
      # Initialize just the visual encoder backbone with pretrained weights
      pretrained_encoder: False
      # Whether the visual encoder backbone will be trained.
      train_encoder: True
      # Whether to reset the critic linear layer
      reset_critic: True

      # Model parameters: visual encoder backbone and recurrent state encoder.
      backbone: resnet18
      rnn_type: LSTM
      num_recurrent_layers: 2
