# @package _global_

# Shared hyper-parameters; other sections read them through ${scratch.*}.
scratch:
  # Input size, consumed by the resize transform and by the model's image_size.
  resolution: 1024

  # Data-loading settings.
  train_batch_size: 1
  num_train_workers: 10

  # Sampling limits (NOTE(review): presumably frames per clip and objects per
  # clip -- confirm against the sampler that consumes them).
  num_frames: 8
  max_num_objects: 3

  # Start values of the cosine LR schedules: base_lr for the default parameter
  # group, vision_lr for parameters matching image_encoder.*.
  base_lr: 5.0e-6
  vision_lr: 3.0e-6

  # trainer.max_epochs combines these two through the `times` resolver.
  num_epochs: 40
  phases_per_epoch: 1

# Dataset locations. Entries left as null must be filled in by the user.
dataset:
  # PATH to MOSE JPEGImages folder
  img_folder: null
  # PATH to MOSE Annotations folder
  gt_folder: null
  # Optional PATH to a file list containing the subset of videos to be used
  # for training
  file_list_txt: training/assets/MOSE_sample_train_list.txt
  multiplier: 2

# Video transforms
# The list order below is the order in which the transforms are composed.
# NOTE(review): `consistent_transform` presumably selects whether one random
# draw is shared by all frames of a clip -- confirm in
# training.dataset.transforms.
vos:
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        # Geometric augmentations.
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomAffine
          consistent_transform: true
          degrees: 25
          shear: 20
          image_interpolation: bilinear
        - _target_: training.dataset.transforms.RandomResizeAPI
          consistent_transform: true
          sizes: ${scratch.resolution}
          square: true
        # Photometric augmentations: first with consistent_transform enabled,
        # then a second jitter with it disabled.
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: true
          brightness: 0.1
          contrast: 0.03
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.RandomGrayscale
          consistent_transform: true
          p: 0.05
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: false
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        # Tensor conversion and per-channel normalization.
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean:
            - 0.485
            - 0.456
            - 0.406
          std:
            - 0.229
            - 0.224
            - 0.225

# Trainer definition. `model` and `data` are nested here, and the 2-space
# stanzas that follow this block are meant to be siblings of them under
# `trainer`.
trainer:
  _target_: training.trainer.Trainer
  mode: train_only
  # Combines num_epochs and phases_per_epoch through the `times` resolver.
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
        fpn_interp_model: nearest

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [64, 64]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [64, 64]
          rope_k_repeat: true
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          kv_in_dim: 64
      num_layers: 4

    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # Written with an explicit ".0" so YAML 1.1 loaders also read a float.
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: False

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 0.5  # 0.5*0.5 = 0.25 prob to use box instead of points
    prob_to_use_box_input_for_eval: 0.0
    prob_to_sample_from_gt_for_train: 0.1  # with a small prob, sampling correction points from GT mask instead of prediction errors
    num_frames_to_correct_for_train: 2  # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_eval: 1  # only iteratively sample on first frame
    rand_frames_to_correct_for_train: true  # random #init-cond-frame ~ 2
    add_all_frames_to_correct_as_cond: true  # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: true  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: true

  # Training data. Nested under `trainer` (sibling of `model`) so the 2-space
  # stanzas that follow stay children of `trainer`. Every interpolation points
  # at a key defined in this file (`scratch.*`, `dataset.*`, `vos.*`).
  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}  # Chunks a single epoch into smaller phases
      batch_sizes:  # List of batch sizes corresponding to each dataset
        - ${scratch.train_batch_size}  # Batch size of dataset 1
      datasets:
        # Video dataset built from the MOSE folders declared under `dataset`.
        # NOTE(review): PNGRawDataset, RepeatFactorWrapper and ConcatDataset are
        # not visible in this file -- confirm them against the training package.
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: true
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: ${dataset.img_folder}
                  gt_folder: ${dataset.gt_folder}
                  file_list_txt: ${dataset.file_list_txt}  # Optional
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: ${dataset.multiplier}
      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      # The partial's arguments must be nested below `collate_fn`; at the same
      # indent they would duplicate the `_target_` key of `train`.
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all  # same key as the `all` entry of the `loss` stanza

  # Optimization settings: mixed precision, optimizer, gradient clipping,
  # parameter-group modifiers and per-group schedules.
  optim:
    amp:
      enabled: true
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      # Layer decay on the image encoder trunk; parameters matching the
      # override pattern get the value 1.0 instead.
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: "image_encoder.trunk"
        overrides:
          - pattern: "*pos_embed*"
            value: 1.0

    options:
      lr:
        # Entry without param_names: cosine schedule starting at base_lr.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        # Entry restricted to image_encoder.*: cosine schedule starting at
        # vision_lr.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - "image_encoder.*"
      weight_decay:
        # Entry without param_names: constant 0.1.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        # Constant 0.0 for names matching *bias* and for LayerNorm modules.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - "*bias*"
          module_cls_names:
            - "torch.nn.LayerNorm"

  # Loss configuration; the single entry is keyed `all`.
  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      # Relative weights of the individual loss terms.
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      # IoU-head options.
      supervise_all_iou: true
      iou_use_l1_loss: true
      # Object-score options.
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  # Distributed-training settings.
  distributed:
    backend: nccl
    find_unused_parameters: true

  # Logging outputs; both directories live under launcher.experiment_log_dir.
  logging:
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true

  # initialize from a SAM 2 checkpoint
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    # 0 only last checkpoint is saved.
    save_freq: 0
    # Partial that loads the state dict built below into the model.
    model_weight_initializer:
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      _partial_: true
      strict: true
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        # PATH to SAM 2.1 checkpoint
        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt
        ckpt_state_dict_keys:
          - "model"

# Launch topology and output location.
launcher:
  num_nodes: 1
  gpus_per_node: 8
  # Path to log directory, defaults to ./sam2_logs/${config_name}
  experiment_log_dir: null

# SLURM args if running on a cluster
submitit:
  use_cluster: false
  name: null
  # Scheduler selection; all left unset here.
  partition: null
  account: null
  qos: null
  # Per-job resources and limits.
  cpus_per_task: 10
  timeout_hour: 24
  port_range:
    - 10000
    - 65000

