# @package _global_

scratch:
  # Shared knobs; other sections pull them in through ${scratch.*}.
  # Input size: feeds the resize transforms and trainer.model.image_size.
  resolution: 512
  # Clip sampling (consumed by the dataset samplers).
  num_frames: 12
  max_num_objects: 3
  # Data loading.
  train_video_batch_size: 4
  num_train_workers: 15
  # Start values of the cosine LR schedules; vision_lr is the one
  # attached to the "image_encoder.*" parameter group.
  base_lr: 5.0e-5
  vision_lr: 3.0e-5
  # Schedule length.
  phases_per_epoch: 1
  num_epochs: 1000

dataset:
  # PATHS to Dataset
  # NOTE(review): nothing in the visible part of this file interpolates
  # ${dataset.folder} or ${dataset.multiplier}; the train dataset under
  # trainer.data hard-codes its own folder and multiplier - confirm intent.
  folder: null  # PATH to Med NPZ folder
  multiplier: 1

# Video transforms
vos:
  # Training pipeline, applied in list order: flip, affine, resize,
  # noise, blur, mosaic, tensor conversion, normalization.
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          image_interpolation: bilinear
          consistent_transform: true
        # resize target comes from the shared scratch.resolution
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        # no arguments given: noise and blur use their constructor defaults
        - _target_: training.dataset.transforms.RandomGaussianNoise
        - _target_: training.dataset.transforms.RandomGaussianBlur
        - _target_: training.dataset.transforms.RandomMosaicVideoAPI
          prob: 0.15
          grid_h: 2
          grid_w: 2
          use_random_hflip: false
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
  # Validation pipeline: resize, tensor conversion and the same
  # normalization constants as training; no augmentation entries.
  val_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

trainer:
  _target_: training.trainer.Trainer
  # run mode handed to the Trainer
  mode: train_only
  accelerator: cuda
  seed_value: 123
  # epoch budget is taken from the shared scratch section
  max_epochs: ${scratch.num_epochs}

  model:
    _target_: training.model.sam2_base.SAM2Train
    # Backbone: a Hiera trunk whose features go through an FPN neck.
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 96
        num_heads: 1
        stages: [1, 2, 7, 2]
        global_att_blocks: [5, 7, 9]
        window_pos_embed_bkg_spatial_size: [7, 7]
        # adapter options (adapter is switched off here)
        use_adapter: false
        adapter_bottleneck_dim: 64
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        d_model: 256
        backbone_channel_list: [768, 384, 192, 96]
        # output level 0 and 1 directly use the backbone features
        fpn_top_down_levels: [2, 3]
        fpn_interp_model: nearest
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      num_layers: 4
      # One layer definition, used for all num_layers layers (TODO confirm
      # how MemoryAttention replicates it).
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        d_model: 256
        dim_feedforward: 2048
        activation: relu
        dropout: 0.1
        pos_enc_at_attn: false
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          rope_theta: 10000.0
          # presumably resolution / 16 (512 / 16 = 32) - verify if resolution changes
          feat_sizes: [32, 32]
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: true
          # same value as memory_encoder.out_dim
          kv_in_dim: 64

    # Memory encoder: mask downsampler + fuser + sine position encoding.
    # out_dim (64) is the same value as memory_attention's cross-attention
    # kv_in_dim and this block's position_encoding.num_pos_feats.
    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # Written with a decimal point on purpose: YAML 1.1 loaders
          # (e.g. plain PyYAML) read a bare `1e-6` as a string, not a float.
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # image encoder freezing flag passed to SAM2Train
    freeze_image_encoder: true
    # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag (left disabled)
    # compile_image_encoder: False

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 1.0
    prob_to_use_box_input_for_eval: 0.0
    # with a small prob, sampling correction points from GT mask instead of prediction errors
    prob_to_sample_from_gt_for_train: 0.1
    # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_train: 2
    # only iteratively sample on first frame
    num_frames_to_correct_for_eval: 1
    # random #init-cond-frame ~ 2
    rand_frames_to_correct_for_train: true
    # when a frame receives a correction click, it becomes a conditioning frame
    # (even if it's not initially a conditioning frame)
    add_all_frames_to_correct_as_cond: true
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    # random 1~2
    rand_init_cond_frames_for_train: true
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    # only mask on the first frame
    num_init_cond_frames_for_eval: 1
    forward_backbone_per_frame_for_eval: true

  # Dataset / dataloader definitions for the train and val phases.
  # Both phases collate under the "all_baseline" key, the same key used
  # under trainer.loss.
  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_video_batch_size}
      datasets:
        - _target_: training.dataset.vos_dataset.VOSDataset
          transforms: ${vos.train_transforms}
          # This is the training split, so the dataset is flagged as such.
          # NOTE(review): assumes VOSDataset uses `training` to retry with
          # another sample on a load failure instead of raising, as in
          # upstream SAM 2 - confirm against training/dataset/vos_dataset.py.
          training: true
          video_dataset:
            _target_: training.dataset.vos_raw_dataset.NPZRawDataset
            # NOTE(review): machine-specific absolute path; dataset.folder
            # at the top of this file is not referenced here.
            folder: /home/lthpc/Datasets/Task10_Colon_test_v3/Training
          sampler:
            _target_: training.dataset.vos_sampler.RandomUniformSampler
            num_frames: ${scratch.num_frames}
            max_num_objects: ${scratch.max_num_objects}
          multiplier: 1

      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all_baseline

    # NOTE(review): trainer.mode is train_only, so this phase is presumably
    # not consumed in the current run - confirm in the Trainer.
    val:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - 1
      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              # PNG
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.val_transforms}
                # Validation split: not flagged as training, so a sample
                # that fails to load is presumably surfaced rather than
                # silently swapped for another one (same assumption about
                # VOSDataset as for the train split - TODO confirm).
                training: false
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: /home/lthpc/Datasets/Task02_Heart_png_sift_test/Test/image
                  gt_folder: /home/lthpc/Datasets/Task02_Heart_png_sift_test/Test/mask
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: 1
      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all_baseline

  optim:
    # mixed precision
    amp:
      enabled: true
      amp_dtype: bfloat16

    # no constructor arguments here; lr and weight_decay come from the
    # schedulers under `options`
    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    # layer-wise decay on the image encoder trunk; parameters matching
    # "*pos_embed*" are overridden to 1.0
    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: "image_encoder.trunk"
        overrides:
          - pattern: "*pos_embed*"
            value: 1.0

    options:
      lr:
        # scheduler without param_names: cosine from base_lr to base_lr / 10
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        # "image_encoder.*" parameters: cosine from vision_lr to vision_lr / 10
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - "image_encoder.*"
      weight_decay:
        # scheduler without param_names: constant 0.1
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        # constant 0.0 for "*bias*" parameters and torch.nn.LayerNorm modules
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - "*bias*"
          module_cls_names:
            - "torch.nn.LayerNorm"

  loss:
    # keyed by the same name the data collate_fn uses as dict_key
    all_baseline:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      # relative weight of each loss term
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      # IoU head supervision
      supervise_all_iou: true
      iou_use_l1_loss: true
      # object-score term and its focal parameters
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  distributed:
    # communication backend: gloo or nccl
    backend: nccl
    find_unused_parameters: true

  logging:
    # text logs and TensorBoard events both live under launcher.experiment_log_dir
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true

  # initialize from a SAM 2 checkpoint
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    # 0 only last checkpoint is saved.
    save_freq: 50
    model_weight_initializer:
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      _partial_: true
      # non-strict load with no explicit ignore lists
      strict: false
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        # PATH to SAM 2.1 checkpoint
        checkpoint_path: checkpoints/sam2.1_hiera_tiny.pt
        ckpt_state_dict_keys: ["model"]

launcher:
  # Path to log directory, defaults to ./sam2_logs/${config_name}
  experiment_log_dir: exp_log
  # job size
  num_nodes: 1
  gpus_per_node: 4

# SLURM args if running on a cluster
submitit:
  # cluster submission is switched off in this config
  use_cluster: false
  name: null
  partition: gpu_bwanggroup
  account: null
  qos: null
  cpus_per_task: 10
  timeout_hour: 24
  port_range:
    - 10000
    - 65000
