#!/bin/bash
# Fail fast: abort on command errors, on use of unset variables, and on
# failures in any stage of a pipeline.
set -euo pipefail

# Put the Auden checkout first on the Python module search path.
# The previous PYTHONPATH is appended only when it is set and non-empty: a bare
# "dir:" value carries an empty trailing entry, which Python resolves to the
# current working directory and would silently make imports depend on the
# directory the script is launched from.
export PYTHONPATH="/apdcephfs_cq12/share_302080740/user/raytseng/research/Auden-refactor-online/Auden${PYTHONPATH:+:${PYTHONPATH}}"

# NOTE(review): STEP and VARIANT are not referenced by any command visible in
# this script and are not exported — confirm whether they are still needed.
STEP=400k
VARIANT=1M

# ----------------train with CrossEntropy loss, initialised from a pretrained checkpoint------------------
# Launches train.py through torchrun with 2 processes on GPUs 2 and 3
# (rendezvous port 29512). The run loads $pretrained_model_checkpoint via
# trainer.initialization.checkpoint and lists encoder_embed and encoder in
# trainer.freeze_modules, so this is not a from-scratch run.
# NOTE(review): presumably the listed modules are kept frozen during training —
# confirm against train.py.
pretrained_model_checkpoint=/apdcephfs_cq12/share_302080740/user/raytseng/data/ckpt/averaged_iter460000_avg10.pt

# Config overrides handed to train.py. They live in an array so that every
# override is exactly one argument (including the one containing a space) and
# the command does not depend on a trailing line-continuation being followed by
# an empty line.
train_args=(
        exp_dir="/apdcephfs_cq12/share_302080740/user/raytseng/research/exp/speaker_id/voxceleb2_baseline_4_5e-2_bsz2560_MHAP"
        ++model.config.pooling=mhap
        model.config.loss_type=ce
        model.id2label_json=configs/voxceleb2/id2label_voxceleb2.json
        # Quoted so the brackets are never treated as a glob pattern.
        data.valid_sets='[/apdcephfs_cq12/share_302080740/user/raytseng/data/VoxCeleb2/manifest/voxceleb2_valid.jsonl.gz]'
        data.max_duration=1280
        data.use_infinite_dataset=true
        trainer.lr_steps_per_epoch=5700
        trainer.use_fp16=false
        "trainer.initialization.checkpoint=${pretrained_model_checkpoint}"
        'trainer.freeze_modules=[encoder_embed, encoder]'
        trainer.valid_interval=500
)

CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc_per_node=2 \
        --master-port=29512 \
        train.py "${train_args[@]}"
        


# ----------------train from scratch using AAM-Softmax loss------------------
# export CUDA_VISIBLE_DEVICES=0,1,2,3
# torchrun --nproc_per_node=4 \
#         --master-port=29501 \
#         train.py \
#         exp_dir=exp/voxceleb2_scratch_AAMS \
#         model.id2label_json=configs/voxceleb2/id2label_voxceleb2.json \
#         model.config.loss_type=aams \
#         +model.config.margin=0.2 \
#         +model.config.scale=64 \
#         data.max_duration=200 \
#         data.use_infinite_dataset=true \
#         trainer.use_fp16=false \
#         trainer.lr_steps_per_epoch=10000 \


# ----------------initialise from a pretrained ASR checkpoint------------------
# pretrained_checkpoint=/apdcephfs_cq10/share_1603164/user/yiwenyshao/independent/auden/egs/asr/exp/auden_zh_r3large_full_8gpu/averaged_iter1412000_avg10.pt
# # The ASR checkpoint is a large model, so set config_preset=large.
# export CUDA_VISIBLE_DEVICES=4,5,6,7
# torchrun --nproc_per_node=4 \
#         --master-port=29502 \
#         train.py \
#         exp_dir=exp/voxceleb2_asr_init \
#         model.id2label_json=configs/voxceleb2/id2label_voxceleb2.json \
#         model.config.loss_type=ce \
#         +model.config_preset=large \
#         data.max_duration=300 \
#         data.use_infinite_dataset=true \
#         trainer.use_fp16=false \
#         trainer.lr_steps_per_epoch=10000 \
#         trainer.base_lr=0.0045 \
#         trainer.initialization.checkpoint=$pretrained_checkpoint \
