---
# Experiment identifier.
name: opv2v_mpda_protocol_pyramid_m1m3

# Dataset split directories.
root_dir: 'dataset/OPV2V/train'
validate_dir: 'dataset/OPV2V/validate'
test_dir: 'dataset/OPV2V/test'

# Presumably the name of the config-parsing routine the loader dispatches to — confirm.
yaml_parser: 'load_general_params'

train_setting:
  train_params:
    batch_size: 1
    epoches: 10
    eval_freq: 1
    save_freq: 1
    max_cav: 5

  optimizer:
    core_method: Adam
    lr: 0.001
    args:
      # Written with a decimal point so that stock YAML 1.1 resolvers
      # (e.g. PyYAML safe_load) read these as floats; a dot-less `1e-10`
      # is resolved as the string "1e-10" there.
      eps: 1.0e-10
      weight_decay: 1.0e-4

  lr_scheduler:
    core_method: multistep  # options: step, multistep, Exponential
    gamma: 0.1
    # NOTE(review): milestones 50 and 80 exceed `epoches: 10` and can never
    # be reached; confirm this schedule is intended.
    step_size: [10, 50, 80]



# Communication range between CAVs (presumably meters — confirm).
comm_range: 70
input_source: ["lidar"]
label_type: "lidar"
# Shared perception range, presumably [x_min, y_min, z_min, x_max, y_max, z_max];
# exported through the &cav_lidar anchor.
cav_lidar_range: &cav_lidar [-102.4, -51.2, -3, 102.4, 51.2, 1]

# add_data_extension: ['bev_visibility.png']

heter:
  assignment_path: "opencood/logs/heter_modality_assign/opv2v_4modality.json"
  ego_modality: &ego_modality "m1&m3"
  # Modality remapping: m2 is treated as m1 and m4 as m3 (presumably so that
  # only the m1/m3 model branches are exercised — confirm).
  mapping_dict: &mapping_dict
    m1: m1
    m2: m1
    m3: m3
    m4: m3
  lidar_channels_dict:
    m3: 32
  modality_setting:
    m1:
      sensor_type: &sensor_type_m1 'lidar'
      core_method: &core_method_m1 "point_pillar"

      # lidar requires preprocess
      preprocess:
        # options: BasePreprocessor, VoxelPreprocessor, SpVoxelPreprocessor, BevPreprocessor
        core_method: 'SpVoxelPreprocessor'
        args:
          voxel_size: &voxel_size_m1 [0.4, 0.4, 4]
          max_points_per_voxel: 32
          max_voxel_train: 32000
          max_voxel_test: 70000
        # lidar range for each individual cav.
        cav_lidar_range: *cav_lidar
    m3:
      sensor_type: &sensor_type_m3 'lidar'
      core_method: &core_method_m3 "second"

      # lidar requires preprocess
      preprocess:
        # options: BasePreprocessor, VoxelPreprocessor, SpVoxelPreprocessor, BevPreprocessor
        core_method: 'SpVoxelPreprocessor'
        args:
          voxel_size: &voxel_size_m3 [0.1, 0.1, 0.1]
          max_points_per_voxel: 5
          max_voxel_train: 32000
          max_voxel_test: 70000
        # lidar range for each individual cav.
        cav_lidar_range: *cav_lidar


fusion:
  # Alternative: core_method: 'intermediateheterpair'
  core_method: intermediatehetercontrastive
  dataset: opv2v
  args:
    proj_first: false
    # Place-holders. NOTE: a plain `None` is the string "None" in YAML,
    # not null; kept as-is because these are placeholders.
    grid_conf: None
    data_aug_conf: None

# data_augment: # not used by intermediate fusion
#   - NAME: random_world_flip
#     ALONG_AXIS_LIST: [ 'x' ]

#   - NAME: random_world_rotation
#     WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ]

#   - NAME: random_world_scaling
#     WORLD_SCALE_RANGE: [ 0.95, 1.05 ]

preprocess:
  # options: BasePreprocessor, VoxelPreprocessor, SpVoxelPreprocessor, BevPreprocessor
  core_method: 'SpVoxelPreprocessor'
  args:
    voxel_size: [0.4, 0.4, 4]  # used
    # The three values below are placeholders the original author marked as
    # unused; the per-modality preprocess settings under
    # heter.modality_setting presumably take effect instead — confirm.
    max_points_per_voxel: 1
    max_voxel_train: 1
    max_voxel_test: 1
  # lidar range for each individual cav.
  cav_lidar_range: *cav_lidar

# anchor box related
postprocess:
  # Contrastive variant of the voxel postprocessor; VoxelPostprocessor and
  # BevPostprocessor are the other supported choices.
  core_method: 'VoxelPostprocessorContrastive'
  # core_method: 'VoxelPostprocessor'
  gt_range: *cav_lidar
  anchor_args:
    cav_lidar_range: *cav_lidar
    l: 3.9
    w: 1.6
    h: 1.56
    r: &anchor_yaw [0, 90]  # anchor yaw angles
    feature_stride: 2
    num: &anchor_num 2  # one anchor per entry in r
  target_args:
    pos_threshold: 0.6
    neg_threshold: 0.45
    score_threshold: 0.2
  order: 'hwl' # hwl or lwh
  # Maximum number of objects in a single frame; used so that all frames in
  # a batch have the same dimension.
  max_num: 150
  nms_thresh: 0.15
  dir_args: &dir_args
    dir_offset: 0.7853  # approx. pi/4
    num_bins: 2
    anchor_yaw: *anchor_yaw

# model related
model:
  core_method: heter_mpda_protocol
  args:
    supervise_single: true
    mapping_dict: *mapping_dict
    lidar_range: *cav_lidar

    # Protocol (shared) feature parameters; the m1 and m3 adapters reference
    # these through the &protocol_dim, &gra_H and &gra_W anchors.
    pub:
      protocol_feat_dim: &protocol_dim 64
      granularity_H: &gra_H 0.8
      granularity_W: &gra_W 0.8

    m1:
      model_dir: stage0_m1_collab
      core_method: *core_method_m1
      sensor_type: *sensor_type_m1
      local_dim: &local_dim_m1 64

      encoder_args:
        voxel_size: *voxel_size_m1
        lidar_range: *cav_lidar
        pillar_vfe:
          use_norm: true
          with_distance: false
          use_absolute_xyz: true
          num_filters: [64]
        point_pillar_scatter:
          num_features: 64

      backbone_args:
        layer_nums: [3]
        layer_strides: [2]
        num_filters: [64]

      aligner_args:
        core_method: identity

      adapter_args:
        local_dim: *local_dim_m1
        local_range: *cav_lidar
        protocol_parameters:
          protocol_dim: *protocol_dim
          # Computed from H_uni: 128, W_uni: 256 over unify_range
          # [-102.4, -51.2, -3, 102.4, 51.2, 1]: 102.4 / 128 = 204.8 / 256 = 0.8.
          granularity_H: *gra_H
          granularity_W: *gra_W
        resizer:
          # input_channel:  256
          # output_channel: &input_dim_res 256
          wg_att:
            # input_dim: 256
            # mlp_dim: 256
            window_size: 8
            dim_head: 16
            drop_out: 0.1
            depth: 1
          residual:
            # input_dim: *input_dim_res
            depth: 2
        cdt:
          # NOTE(review): m3's cdt uses *protocol_dim instead; both resolve to
          # 64 here, but confirm which dimension cdt is meant to receive.
          input_dim: *local_dim_m1
          window_size: 8
          dim_head: 16
          heads: 8
          depth: 1

      fusion_net:
        method: pyramid
        args:
          resnext: true
          layer_nums: [3, 5, 8]
          layer_strides: [1, 2, 2]
          num_filters: [64, 128, 256]
          upsample_strides: [1, 2, 4]
          num_upsample_filter: [128, 128, 128]
          anchor_number: *anchor_num

          shrink_header:
            kernal_size: [ 3 ]
            stride: [ 1 ]
            padding: [ 1 ]
            dim: [ 256 ]
            input_dim: 384 # 128 * 3

      in_head: 256  # equals shrink_header dim

      anchor_number: *anchor_num
      dir_args: *dir_args

    m3:
      model_dir: stage0_m3_collab
      core_method: *core_method_m3
      sensor_type: *sensor_type_m3
      local_dim: &local_dim_m3 64

      encoder_args:
        voxel_size: *voxel_size_m3
        lidar_range: *cav_lidar
        mean_vfe:
          num_point_features: 4
        spconv:
          num_features_in: 4
          num_features_out: 64
        map2bev:
          feature_num: 128

      backbone_args:
        layer_nums: [3]
        layer_strides: [1]
        num_filters: [64]
        inplanes: 128

      aligner_args:
        core_method: identity

      adapter_args:
        local_dim: *local_dim_m3
        local_range: *cav_lidar
        protocol_parameters:
          protocol_dim: *protocol_dim
          # Computed from H_uni: 128, W_uni: 256 over unify_range
          # [-102.4, -51.2, -3, 102.4, 51.2, 1]: 102.4 / 128 = 204.8 / 256 = 0.8.
          granularity_H: *gra_H
          granularity_W: *gra_W

        resizer:
          # input_channel:  256
          # output_channel: &input_dim_res 256
          wg_att:
            # input_dim: *protocol_dim
            # mlp_dim: *protocol_dim
            window_size: 8
            dim_head: 16
            drop_out: 0.1
            depth: 1
          residual:
            # input_dim: *protocol_dim
            depth: 2
        cdt:
          # NOTE(review): m1's cdt uses *local_dim_m1 instead; both resolve to
          # 64 here, but confirm which dimension cdt is meant to receive.
          input_dim: *protocol_dim
          window_size: 8
          dim_head: 16
          heads: 8
          depth: 1

      fusion_net:
        method: pyramid
        args:
          resnext: true
          layer_nums: [3, 5, 8]
          layer_strides: [1, 2, 2]
          num_filters: [64, 128, 256]
          upsample_strides: [1, 2, 4]
          num_upsample_filter: [128, 128, 128]
          anchor_number: *anchor_num

          shrink_header:
            kernal_size: [ 3 ]
            stride: [ 1 ]
            padding: [ 1 ]
            dim: [ 256 ]
            input_dim: 384 # 128 * 3

      in_head: 256  # equals shrink_header dim

      anchor_number: *anchor_num
      dir_args: *dir_args


loss:
  core_method: mpda_loss_protocol
  args:
    # Contrastive term; tau is presumably the temperature — confirm.
    contrastive:
      tau: 0.1
      max_voxel: 30

    pos_cls_weight: 2.0
    # Classification branch.
    cls:
      type: "SigmoidFocalLoss"
      alpha: 0.25
      gamma: 2.0
      weight: 1.0
    # Box-regression branch.
    reg:
      type: "WeightedSmoothL1Loss"
      sigma: 3.0
      codewise: true
      weight: 2.0
    # Direction-bin classification branch; shares the postprocess dir_args.
    dir:
      type: "WeightedSoftmaxClassificationLoss"
      weight: 0.2
      args: *dir_args
    depth:
      weight: 1.0
    # Per-level pyramid supervision; [1, 2, 4] agrees with the cumulative
    # fusion_net layer_strides [1, 2, 2].
    pyramid:
      relative_downsample: [1, 2, 4]
      weight: [0.4, 0.2, 0.1]

