#!/bin/bash
# Environment setup for the torchrun/train.py launch performed by this script.

# Fail fast: abort on command errors, on use of unset variables, and on
# failures inside pipelines.
set -euo pipefail

# Alternative PYTHONPATH that additionally puts the independent/auden checkout first:
# export PYTHONPATH=/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden:/apdcephfs_cq10/share_1603164/user/yiwenyshao/lhotse:/apdcephfs_cq7/share_1297902/common/allenycwang/Auden:$PYTHONPATH

# Prepend the lhotse and Auden checkouts to the module search path.
# ${PYTHONPATH:+:...} appends the caller's value only when it is set and
# non-empty, so no trailing ':' (an empty path entry, which Python may resolve
# to the current directory) is left behind, and the expansion is safe under
# `set -u`.
export PYTHONPATH="/apdcephfs_cq10/share_1603164/user/yiwenyshao/lhotse:/apdcephfs_cq7/share_1297902/common/allenycwang/Auden${PYTHONPATH:+:${PYTHONPATH}}"

# Make only GPU 0 visible to the processes started below.
export CUDA_VISIBLE_DEVICES=0
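# ----------------train from scratch------------------
# Disabled alternative to the fine-tuning run further down; uncomment the whole
# command to use it instead.
# NOTE(review): this command sets ++model.num_encoder_layers / feedforward_dim /
# encoder_dim, while the active fine-tuning command uses ++model.config.* for the
# same settings; confirm the correct key path before re-enabling.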
# torchrun --nproc_per_node=1 \
#         --master_port=25769 \
#          train.py \
#          exp_dir=exp/auden_mtt_top50_from_scratch \
#          model.id2label_json=configs/mtt_tag/id2label_mtt_tag.json \
#          model.config.is_multilabel=true \
#          trainer.valid_interval=100 \
#          trainer.base_lr=0.015 \
#          data.max_duration=480 \
#          data.use_infinite_dataset=False \
#          ++data.train_data_config='configs/mtt_tag/train_data_config_mtt_tag.yaml' \
#          ++data.valid_sets='["/apdcephfs_cq7/share_1297902/common/allenycwang/data/music_test_data/manifests/MTT_tag_valid.jsonl.gz"]' \
#          ++model.num_encoder_layers='[2,2,3,4,3,2]' \
#          ++model.feedforward_dim='[512,768,1024,1536,1024,768]' \
#          ++model.encoder_dim='[192,256,384,512,384,256]' \
#          ++model.config.fuse_encoder=true \
#          ++data.label_field=music_tag \

# ----------------finetune from audioset------------------
# Launches train.py through torchrun as a single process, initialising the
# encoder_embed and encoder modules from a pretrained checkpoint and listing
# the same two modules in trainer.freeze_modules.

# Checkpoint passed to trainer.initialization.checkpoint.
pretrained_model_checkpoint=/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden/egs/audio_tag/exp/audioset_bucket_2M_orig/averaged_epoch30_avg10.pt

# Command-line overrides handed to train.py, one per array element.
# Keeping them in an array avoids a chain of backslash continuations (the
# previous form ended with a dangling '\' that would swallow any line added
# after it) and lets the checkpoint path be expanded safely quoted.
# The '++' prefix is presumably Hydra's "add or override" syntax -- confirm
# against train.py's config loader.
train_overrides=(
  exp_dir=exp/auden_mtt_top50_audiosetinit_2
  model.id2label_json=configs/mtt_tag/id2label_mtt_tag.json
  model.config.is_multilabel=true
  "trainer.initialization.checkpoint=${pretrained_model_checkpoint}"
  trainer.initialization.init_modules='[encoder_embed, encoder]'
  trainer.valid_interval=100
  trainer.base_lr=0.0045
  trainer.lr_batches=100000
  trainer.lr_epochs=100
  data.max_duration=480
  data.use_infinite_dataset=False
  ++data.train_data_config='configs/mtt_tag/train_data_config_mtt_tag.yaml'
  ++data.valid_sets='["/apdcephfs_cq7/share_1297902/common/allenycwang/data/music_test_data/manifests/MTT_tag_valid.jsonl.gz"]'
  ++trainer.freeze_modules='[encoder_embed, encoder]'
  ++model.config.num_encoder_layers='[2,2,3,4,3,2]'
  ++model.config.feedforward_dim='[512,768,1024,1536,1024,768]'
  ++model.config.encoder_dim='[192,256,384,512,384,256]'
  ++model.config.fuse_encoder=true
  ++data.label_field=music_tag
)

torchrun --nproc_per_node=1 \
  --master_port=25756 \
  train.py \
  "${train_overrides[@]}"

