# @package _global_

# Shared hyper-parameters; other sections of this file read them via ${scratch.*}.
scratch:
  resolution: 512  # used for RandomResizeAPI `sizes` and the model `image_size`
  train_batch_size: 8  # the single entry of trainer.data.train.batch_sizes
  num_train_workers: 15  # trainer.data.train.num_workers
  num_frames: 8  # `num_frames` of the VOS sampler
  max_num_objects: 5  # `max_num_objects` of the VOS sampler
  base_lr: 5.0e-6  # start value of the first cosine LR scheduler
  vision_lr: 3.0e-6  # start value of the scheduler bound to `image_encoder.*`
  phases_per_epoch: 1
  num_epochs: 100  # feeds trainer.max_epochs together with phases_per_epoch

dataset:
  # PATHS to Dataset
  folder: /cluster/projects/bwanggroup/jma/data_sam2_3dmed2/npz_train/RECIST_train_npz # PATH to Med NPZ folder
  # NOTE(review): no `${dataset.multiplier}` interpolation is visible in this file;
  # the VOSDataset entry under trainer.data.train sets its own `multiplier: 1`.
  # Confirm which value is intended.
  multiplier: 2

# Video transforms, referenced from trainer.data.train as ${vos.train_transforms}
vos:
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        # --- geometric augmentations ---
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: true
        - _target_: training.dataset.transforms.RandomVerticalFlip
          consistent_transform: true
        # strong affine, configured with consistent_transform enabled
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          scale: [0.7, 1.4]
          image_interpolation: bilinear
          consistent_transform: true
        # mild affine with consistent_transform disabled, gated by p = 0.1
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 5
          shear: 5
          scale: [0.95, 1.05]
          translate: [0.05, 0.05]
          image_interpolation: bilinear
          consistent_transform: false
          p: 0.1
        # resize to a square of side scratch.resolution
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: true
        # --- noise / blur ---
        - _target_: training.dataset.transforms.RandomGaussianNoise
          consistent_transform: true
          p: 0.1
        - _target_: training.dataset.transforms.RandomGaussianBlur
          consistent_transform: true
          kernel_size: 5
          sigma: [0.5, 1.0]
          p: 0.2
        # --- photometric augmentations (hue jitter left unset in both) ---
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: true
          brightness: 0.2
          contrast: 0.2
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: false
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        # --- tensor conversion and normalization ---
        - _target_: training.dataset.transforms.ToTensorAPI
        # mean/std look like the commonly used ImageNet statistics — confirm
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

# Trainer configuration; `_target_` names the class training.trainer.Trainer.
trainer:
  _target_: training.trainer.Trainer
  mode: train_only
  # Derived from scratch.num_epochs and scratch.phases_per_epoch through the
  # `times` resolver (presumably their product — confirm against the resolver
  # registration, which is not part of this file).
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  # Model definition: image encoder (ViT trunk + neck), memory attention,
  # memory encoder, followed by the flat model flags and training-only options.
  model:
    _target_: training.model.efficienttam.EfficientTAMTrain
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 0
      trunk:
        _target_: sam2.modeling.backbones.vitdet.ViT
        patch_size: 16
        embed_dim: 192
        depth: 12
        num_heads: 3
        mlp_ratio: 4.0
        qkv_bias: true
        drop_path_rate: 0.0
        use_rel_pos: false
        window_size: 14
        window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
      neck:
        _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        # single entry, equal to trunk.embed_dim above
        backbone_channel_list: [192]
        neck_norm: LN

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          # 32 equals scratch.resolution (512) / trunk.patch_size (16);
          # presumably these must stay in sync — verify before changing either
          feat_sizes: [32, 32]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: true
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          # same value as memory_encoder.out_dim below
          kv_in_dim: 64
      num_layers: 4

    memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 10000
      mask_downsampler:
        _target_: sam2.modeling.memory_encoder.MaskDownSampler
        kernel_size: 3
        stride: 2
        padding: 1
      fuser:
        _target_: sam2.modeling.memory_encoder.Fuser
        layer:
          _target_: sam2.modeling.memory_encoder.CXBlock
          dim: 256
          kernel_size: 7
          padding: 3
          # written with a decimal point so YAML 1.1 loaders read a float,
          # not the string "1e-6"
          layer_scale_init_value: 1.0e-6
          use_dwconv: true  # depth-wise convs
        num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # SAM decoder / memory encoder: apply scaled sigmoid on mask logits for the
    # memory encoder, and directly feed an input mask through as the output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    # high-resolution features in SAM are switched off in this config
    use_high_res_features_in_sam: false
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: true
    # cross-attend to object pointers from other frames in the ViT encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: false
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: False

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 0.5  # 0.5*0.5 = 0.25 prob to use box instead of points
    prob_to_use_box_input_for_eval: 0.0
    prob_to_sample_from_gt_for_train: 0.1  # with a small prob, sampling correction points from GT mask instead of prediction errors
    num_frames_to_correct_for_train: 2  # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_eval: 1  # only iteratively sample on first frame
    rand_frames_to_correct_for_train: true  # random #init-cond-frame ~ 2
    add_all_frames_to_correct_as_cond: true  # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: true  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: true
    

  # Training data pipeline: a mixed dataset wrapping one NPZ-backed VOS dataset.
  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_batch_size}
      datasets:
        # Video Datasets
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              # CT
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: true
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.NPZRawDataset
                  # single source of truth: the path lives in the top-level
                  # `dataset` section instead of being duplicated here
                  folder: ${dataset.folder}
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                # NOTE(review): top-level dataset.multiplier is 2 but is not
                # referenced here; confirm that 1 is the intended value.
                multiplier: 1

      shuffle: true
      num_workers: ${scratch.num_train_workers}
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        # same key as the entry under trainer.loss
        dict_key: all

  # Optimization: AMP, optimizer, gradient clipping, layer-wise decay on the
  # trunk, and the per-parameter-group LR / weight-decay schedulers.
  optim:
    amp:
      enabled: true
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        # parameters matching this pattern get the override value 1.0
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        # entry without `param_names`: starts at base_lr
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        # entry restricted to `image_encoder.*`: starts at vision_lr
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - 'image_encoder.*'
      weight_decay:
        # entry without `param_names`: constant 0.1
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        # constant 0.0 for `*bias*` names and torch.nn.LayerNorm modules
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names:
            - 'torch.nn.LayerNorm'

  # Loss configuration. The single entry is keyed `all`, the same key that
  # trainer.data.train.collate_fn passes as `dict_key`.
  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      # relative weights of the individual loss terms
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      # same setting as trainer.model.pred_obj_scores
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  # Distributed-training settings
  distributed:
    backend: nccl
    find_unused_parameters: true

  # Logging; both output locations live under launcher.experiment_log_dir
  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: true
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # Checkpointing; model weights are initialized from the pretrained
  # EfficientTAM checkpoint referenced below
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 10  # 0 only last checkpoint is saved.
    model_weight_initializer:
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      _partial_: true
      strict: true
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        # PATH to the pretrained checkpoint (here: efficienttam_ti_512x512.pt)
        checkpoint_path: /cluster/projects/bwanggroup/jma/data_sam2_3dmed2/checkpoints/efficienttam_ti_512x512.pt
        ckpt_state_dict_keys: ['model']

# Launch topology and the experiment output root. trainer.logging and
# trainer.checkpoint build their paths from experiment_log_dir.
launcher:
  num_nodes: 1
  gpus_per_node: 4
  experiment_log_dir: /cluster/projects/bwanggroup/jma/data_sam2_3dmed2/exp_log/RECIST-EffTiny # Path to log directory, defaults to ./sam2_logs/${config_name}

# SLURM args if running on a cluster
# (cluster submission is switched off here: use_cluster is false and
# partition / account / qos / name are left null)
submitit:
  partition: null
  account: null
  qos: null
  cpus_per_task: 15  # same value as scratch.num_train_workers
  use_cluster: false
  timeout_hour: 24
  name: null
  port_range: [10000, 65000]

