#!/bin/bash
# Launcher for the MTT_tag tagging recipe. The train.py and configs/ paths
# used below are relative, so run this from the directory that contains them.
set -euo pipefail

# Root of the Auden checkout; export AUDEN_ROOT beforehand to point elsewhere.
AUDEN_ROOT=${AUDEN_ROOT:-/apdcephfs_cq12/share_302080740/user/raytseng/research/Auden-refactor-online/Auden}
# Prepend Auden to PYTHONPATH. ${PYTHONPATH:+...} only appends ":<old value>"
# when PYTHONPATH is non-empty; a bare trailing ':' would add an empty entry,
# which Python interprets as the current working directory.
export PYTHONPATH="${AUDEN_ROOT}${PYTHONPATH:+:${PYTHONPATH}}"

# # ----------------train from scratch------------------
# torchrun --nproc_per_node=1 \
#         --master_port=25679 \
#          train.py \
#          exp_dir=exp/auden_MTT_tag_fromscratch_debug \
#          model.id2label_json=configs/MTT_tag/id2label_MTT_tag.json \
#          trainer.valid_interval=100 \
#          trainer.base_lr=0.025 \
#          data.max_duration=480 \
#          data.use_infinite_dataset=True \
#          ++data.train_data_config='configs/MTT_tag/train_data_config_MTT_tag.yaml' \
#          ++data.valid_sets='["/apdcephfs_cq7/share_1297902/common/allenycwang/data/music_test_data/manifests/MTT_tag_test.jsonl.gz"]' \
#          ++model.num_encoder_layers='[2,2,3,4,3,2]' \
#          ++model.feedforward_dim='[512,768,1024,1536,1024,768]' \
#          ++model.encoder_dim='[192,256,384,512,384,256]' \
#          ++model.config.fuse_encoder=true \
#          ++data.label_field=instrument \

# ----------------finetune from a pretrained checkpoint------------------
# Initializes from the CaptionStew audio-captioning checkpoint below (an
# egs/audio_captioning run, despite this section previously being labelled
# "from audioset"). Sets trainer.freeze_modules to encoder_embed and encoder,
# enables model.config.is_multilabel, and reads labels from the "music_tag"
# field. Single process, pinned to GPU 3.
pretrained_model_checkpoint=/apdcephfs_cq12/share_302080740/user/raytseng/research/Auden-refactor-online/Auden/egs/audio_captioning/exp/CaptionStew_Full_masked_captioning_AU_TU_5e-3_fp16_bsz5120_shuffled/checkpoint-200000.pt
# Overrides are key=value args (Hydra-style, judging by the ++ prefix). Words
# containing [ ] are single-quoted so the shell cannot treat the brackets as a
# glob; the argv strings train.py receives are unchanged.
# NOTE: keep the final override without a trailing backslash so lines added
# after it are not silently appended to this command.
CUDA_VISIBLE_DEVICES=3 torchrun --nproc_per_node=1 \
        --master_port=25683 \
         train.py \
         exp_dir=/apdcephfs_cq12/share_302080740/user/raytseng/research/exp/music/MTT_tag_CaptionStew_10M_captioningPa_TUAU_200k_4_5e-2_bsz1280_MeanPooling \
         model.id2label_json=configs/mtt_tag/id2label_mtt_tag.json \
         trainer.initialization.checkpoint="$pretrained_model_checkpoint" \
         trainer.valid_interval=200 \
         trainer.base_lr=0.045 \
         data.use_infinite_dataset=true \
         trainer.lr_steps_per_epoch=2000 \
         data.max_duration=2048 \
         'trainer.freeze_modules=[encoder_embed,encoder]' \
         ++data.train_data_config='configs/mtt_tag/train_data_config_mtt_tag.yaml' \
         ++data.valid_sets='["/apdcephfs_cq7/share_1297902/common/allenycwang/data/music_test_data/manifests/MTT_tag_test.jsonl.gz"]' \
         ++data.label_field=music_tag \
         model.config.is_multilabel=true