# Run identification; project_name is the project used when logging to wandb.
project_name: v2i_fusion
name: HM-ViT

# Dataset selection and locations.
dataset: &dataset "V2XSim"  # one of: V2XReal, V2XSim
root_dir: '/media/user/Dataset/V2X-Sim-Reformat/train'
validate_dir: '/media/user/Dataset/V2X-Sim-Reformat/test'
dataset_mode: "v2i"  # one of: 'ic', 'v2i'
img_format: "npy"  # one of: 'jpeg', 'npy'
# Sensor modalities enabled per agent type.
modality:
  rsu_lidar: true
  rsu_camera: false
  cav_lidar: false
  cav_camera: true
# Disconnection rates for the ego agent and for the other agents (both 0 here).
ego_disconnected_rate: 0
agent_disconnected_rate: 0

yaml_parser: 'load_bev_params'
# Training loop settings. The &epoches and &max_cav anchors are aliased later
# in this file (lr_scheduler and model sections). eval_freq / save_freq are
# presumably counted in epochs -- confirm against the trainer.
train_params:
  batch_size: &batch_size 1
  epoches: &epoches 80
  eval_freq: 4
  save_freq: 1
  max_cav: &max_cav 3

# Fusion dataset class; options: IntermediateFusionDatasetV2XReal,
# IntermediateFusionDatasetV2XSim (keep consistent with `dataset` above).
fusion:
  core_method: "IntermediateFusionDatasetV2XSim"
  args: []

# Preprocessing settings.
preprocess:
  # options: BasePreprocessor, BevPreprocessor
  core_method: "BevPreprocessor"
  args:
    res: &res 0.15  # discretization resolution
    downsample_rate: &downsample_rate 4  # pixor downsample ratio
    bgr2rgb: true
    # Camera resize target; aliased as image_width / image_height in the model.
    resize_x: &image_width 512
    resize_y: &image_height 512
    # Per-channel normalization (the standard ImageNet mean / std values).
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
    voxel_size: &voxel_size [0.4, 0.4, 30]
  # Lidar range for each individual CAV, presumably
  # [x_min, y_min, z_min, x_max, y_max, z_max] -- confirm against the loader.
  cav_lidar_range: &cav_lidar [-38.4, -38.4, -10, 38.4, 38.4, 2]

# No data augmentation is applied.
data_augment: []

# Anchor-box / post-processing settings.
postprocess:
  core_method: "BevPostprocessor"  # BevPostprocessor is the supported option
  nms_thresh: 0.15
  anchor_args:
    # Aliases of the preprocessing values so both stages share one geometry.
    cav_lidar_range: *cav_lidar
    res: *res
    downsample_rate: *downsample_rate  # pixor downsample ratio
  target_args:
    score_threshold: 0.5
  order: "lwh"  # hwl or lwh
  # Maximum number of objects in a single frame; used so that all frames in a
  # batch have the same dimension.
  max_num: 300

# Model definition.
model:
  # Other core_method values listed alongside this one:
  # rsu_modality_intermediate_plainfuse, cam_intermediate_fuse_single,
  # pixor_intermediate
  core_method: hm_vit
  args:
    dataset: *dataset
    max_cav: *max_cav
    camera_encoder:
      num_layers: 34
      pretrained: true
      image_width: *image_width
      image_height: *image_height
      id_pick: [1, 2, 3]

    use_bn: true
    decode: false
    point_pillar_scatter:
      num_features: 64

    fax:
      # Channel width at each of the three scales (presumably the projected
      # ResNet camera-encoder features -- confirm against the model code).
      dim: [128, 128, 128]
      middle: [2, 2, 2]  # middle conv
      bev_embedding:
        sigma: 1.0
        bev_height: 256
        bev_width: 256
        h_meters: 100
        w_meters: 100
        offset: 0.0
        upsample_scales: [2, 4, 8]

      cross_view:  # cross-view attention
        image_height: *image_height
        image_width: *image_width
        no_image_features: false
        skip: true
        heads: [4, 4, 4]
        dim_head: [32, 32, 32]
        qkv_bias: true

      cross_view_swap:
        rel_pos_emb: false
        q_win_size: [[16, 16], [16, 16], [32, 32]]
        feat_win_size: [[8, 8], [8, 8], [16, 16]]
        bev_embedding_flag: [true, false, false]

      self_attn:
        dim_head: 32
        dropout: 0.1
        window_size: 32

    residual_vq:
      input_dim: 128
      accept_image_fmap: true
      # Presumably one codebook size per quantizer (num_quantizers) -- confirm.
      codebook_size_ls: [128, 128, 128]
      num_quantizers: 3

    img_decoder:
      input_dim: 128
      num_layer: 2
      num_ch_dec: [128, 128]

    fusion_net:
      input_dim: 128
      mlp_dim: 256
      agent_size: *max_cav
      window_size: 8
      dim_head: 32
      drop_out: 0.1
      depth: 3
      mask: false

    sttf: &sttf
      # m/pixel; equals bev_embedding h_meters / bev_height (100 / 256).
      resolution: 0.390625
      downsample_rate: 8
      use_roi_mask: true

    decoder:
      input_dim: 128
      num_layer: 1
      num_ch_dec: [64]

    head_dim: 64

# Loss configuration.
loss:
  core_method: segment_focal_loss
  args:
    gamma: 1.0  # alternative value noted previously: 2.0
    seg_weight: 2.0
    cmt_weight: 5.0


# Optimizer settings. Exponent floats are written with a decimal point
# (e.g. 2.0e-4): stock YAML 1.1 loaders such as PyYAML resolve `2e-4` (no '.')
# to the string "2e-4", not a float.
optimizer:
  core_method: AdamW
  lr: 2.0e-4
  args:
    eps: 1.0e-10
    weight_decay: 1.0e-2

# Learning-rate schedule. Supported core_method values: step, multistep,
# cosineannealwarm, Exponential. Floats keep a decimal point so stock YAML 1.1
# loaders (e.g. PyYAML) parse them as floats rather than strings.
lr_scheduler:
  core_method: cosineannealwarm
  epoches: *epoches
  warmup_lr: 2.0e-5
  warmup_epoches: 10
  lr_min: 5.0e-6
