python train.py --multiprocessing_distributed \
    --train_data_path /path/to/avsbench/train/ \
    --train_pseudo_gt_path /path/to/train/pseudo_gt_masks \
    --test_data_path /path/to/avsbench/val/ \
    --test_gt_path /path/to/avsbench/val/gt_masks \
    --experiment_name wsavs_avsbench \
    --model 'wsavs' \
    --weight_msmil 1. \
    --weight_pixel 1. \
    --imgnet_type resnet50 --audnet_type resnet50 \
    --trainset 'avsbench' \
    --testset 'avsbench' \
    --epochs 20 \
    --batch_size 64 \
    --init_lr 0.0001 \
    --dropout_img 0.9 --dropout_aud 0