#!/usr/bin/env bash
# Fine-tune (FT) / linear-probe (LP) sweep over pretrained checkpoints.
#
# Usage: <this-script> MASTER_PORT NPROC_PER_NODE TAG_PREFIX
#
# Requires bash, not POSIX sh: the script uses arrays, [[ ]] and
# ${var/pattern/replacement}. Under dash (the /bin/sh on Debian/Ubuntu)
# the array assignment below is a syntax error.
set -x

# Config files, relative to configs/. swin_yaml and wo_wd_swin_yaml are not
# referenced by the loop below.
swin_yaml=swin_base__800ep/customized_simmim_finetune__swin_base__img224_window7__800ep.yaml
vit_yaml=vit_base__800ep/customized_simmim_finetune__vit_base__img224__800ep.yaml
wo_wd_swin_yaml=swin_base__800ep/customized_wo_wd_simmim_finetune__swin_base__img224_window7__800ep.yaml
wo_wd_swin_small_yaml=swin_small/customized_simmim_finetune__swin_small__img224_window7.yaml
resnet_yaml=resnet18.yaml

# Checkpoints are read from ${root_path}<run>_<task>.pth.
root_path=/root-path/

# Run names to process (appended to root_path). Substrings of each name pick
# the config ("vit", "resnet18", otherwise swin-small) and the LR multiplier
# ("_sup", "_siam_", otherwise the default).
runs=(
    path1
)
# Positional arguments, forwarded to every launch below:
#   $1  value for torch.distributed.launch --master_port
#   $2  value for torch.distributed.launch --nproc_per_node
#   $3  tag prefix; runs are tagged "<prefix>_FT/..." or "<prefix>_LP/..."
# Fail fast with a usage message rather than launching with empty values.
usage="usage: $0 MASTER_PORT NPROC_PER_NODE TAG_PREFIX"
master_port=${1:?$usage}
nproc_per_node=${2:?$usage}
tag_prefix=${3:?$usage}
data_path=/st1/dataset/imagenet1k/raw-data/

#######################################
# Launch one distributed main_finetune.py run (weight decay fixed at 0.0).
# Globals:   master_port, nproc_per_node, data_path (read)
# Arguments: $1  config file, relative to configs/
#            $2  batch size
#            $3  task id (--task-id)
#            $4  learning-rate multiplier
#            $5  output tag
#            $6  pretrained checkpoint path
#            $7… extra flags, passed just before --pretrained
#                (e.g. --linear-probe)
# Returns:   exit status of the python launcher
#######################################
launch_finetune() {
    local cfg=$1 batch_size=$2 task_id=$3 lr_mult=$4 tag=$5 ckpt=$6
    shift 6
    python -m torch.distributed.launch \
        --master_port="$master_port" --nproc_per_node="$nproc_per_node" \
        main_finetune.py \
        --cfg "configs/${cfg}" \
        --data-path "$data_path" \
        --batch-size "$batch_size" --output ./ --tag "$tag" \
        --task-id "$task_id" \
        --weight-decay 0.0 \
        --lr-multiplier "$lr_mult" \
        "$@" \
        --pretrained "$ckpt"
}

# Quoted so run names are never word-split or glob-expanded.
for var in "${runs[@]}"; do
    # Model config (and batch size) chosen by substring of the run name.
    bs=128
    case $var in
        *vit*) config=${vit_yaml} ;;
        *resnet18*) config=${resnet_yaml}; bs=64 ;;
        *) config=${wo_wd_swin_small_yaml} ;;
    esac

    # Learning-rate multiplier chosen by substring of the run name.
    case $var in
        *_sup*) lrm=0.5 ;;
        *_siam_*) lrm=1.0 ;;
        *) lrm=5.0 ;;
    esac

    prt_args=${root_path}${var}   # checkpoint prefix: <prefix>_<task>.pth
    tid=9                         # --task-id used by all four runs
    fntid=8                       # task id of the second checkpoint evaluated
    lrmstr=${lrm/./p}             # e.g. 0.5 -> 0p5, safe inside a tag path
    lr_tag=lrmul${lrmstr}_wd0p0

    # FT on task ${tid} data from the "_0" checkpoint.
    launch_finetune "$config" "$bs" "$tid" "$lrm" \
        "${tag_prefix}_FT/${var}/${lr_tag}/t${tid}data_and_t0model" \
        "${prt_args}_0.pth"

    # LP on task ${tid} data from the "_${tid}" checkpoint.
    launch_finetune "$config" "$bs" "$tid" "$lrm" \
        "${tag_prefix}_LP/${var}/${lr_tag}/t${tid}data_and_t${tid}model" \
        "${prt_args}_${tid}.pth" \
        --linear-probe

    # FT on task ${tid} data from the "_${fntid}" checkpoint.
    launch_finetune "$config" "$bs" "$tid" "$lrm" \
        "${tag_prefix}_FT/${var}/${lr_tag}/t${tid}data_and_t${fntid}model" \
        "${prt_args}_${fntid}.pth"

    # LP on task ${tid} data from the "_${fntid}" checkpoint.
    launch_finetune "$config" "$bs" "$tid" "$lrm" \
        "${tag_prefix}_LP/${var}/${lr_tag}/t${tid}data_and_t${fntid}model" \
        "${prt_args}_${fntid}.pth" \
        --linear-probe
done