# Usage: bash ./bash/student_cifar100_3.sh
# Runs train_student_cifar100.py for several distillation methods over three trials.

# ---- Teacher network (keep exactly one assignment active) ----
# Alternatives: resnet32x4 wrn_40_2 resnet56 resnet110 ResNet50
teacher_model="vgg13"

# ---- Student network (keep exactly one assignment active) ----
# Alternatives: MobileNetV2 resnet8x4 resnet20 resnet32 wrn_40_1 wrn_16_2
#               ShuffleV1 ShuffleV2
student_model="vgg8"

# ---- Run settings ----
# Directory component used to build --base_path for the teacher below.
path="JPEG1_lr_0.1_alpha_20.0_lambda_0.5"
# Forwarded as --q_table_epoch. The variable name's spelling is kept as-is
# because the training commands below reference it.
q_table_epcoh="20"
# Exported to each run as CUDA_VISIBLE_DEVICES.
GPU_ID="0"

# Model families: each string lists model names that count as the same
# architecture for distillation purposes (space-separated, exact names).
same_arch_groups=(
    "resnet20 resnet56 resnet110 ResNet50 resnet32 resnet32x4 resnet8x4"
    "wrn_40_2 wrn_40_1 wrn_16_2"
    "vgg13 vgg8"
    "MobileNetV2"
    "ShuffleV1 ShuffleV2")

#######################################
# Look up the family a model belongs to.
# Globals:   same_arch_groups (read)
# Arguments: $1 - exact model name, e.g. "resnet32"
# Outputs:   the matching group string on stdout, or "unknown"
#######################################
get_arch() {
    local model=$1
    local group
    for group in "${same_arch_groups[@]}"; do
        # Pad with spaces so only a whole, literal name matches. An unanchored
        # regex would let "vgg1" hit "vgg13" and "" match every group.
        if [[ " $group " == *" $model "* ]]; then
            printf '%s\n' "$group"
            return 0
        fi
    done
    printf '%s\n' "unknown"
}

# Choose the --alpha_it value used for the itrd run, based on whether the
# teacher and student belong to the same architecture family.
teacher_arch=$(get_arch "$teacher_model")
student_arch=$(get_arch "$student_model")

# get_arch reports "unknown" for unlisted models; flag that explicitly.
if [[ "$teacher_arch" == "unknown" || "$student_arch" == "unknown" ]]; then
    echo "warning: unrecognised model(s): teacher=$teacher_model student=$student_model" >&2
fi

# Two unknown models both yield "unknown"; that must not count as a match.
if [[ "$teacher_arch" != "unknown" && "$teacher_arch" == "$student_arch" ]]; then
    echo "same architecture"
    alpha=1.01
else
    echo "different architecture"
    alpha=1.5
fi
echo "Teacher: $teacher_model, Student: $student_model, Alpha: $alpha"


#######################################
# Launch one student training run with the arguments common to every method.
# Globals:   GPU_ID, teacher_model, student_model, q_table_epcoh, path (read)
# Arguments: $1 - trial number; remaining args are method-specific flags,
#            inserted between --q_table_epoch and --base_path.
# Note: the teacher checkpoint path is always trial_1, whatever the student trial.
#######################################
run_distill() {
    local trial=$1
    shift
    CUDA_VISIBLE_DEVICES="${GPU_ID}" python3.9 train_student_cifar100.py \
        --trial "${trial}" --JPEG_enable --train_mode \
        --model_t "${teacher_model}" --model_s "${student_model}" \
        --q_table_epoch "${q_table_epcoh}" \
        "$@" \
        --base_path "./save/cifar100/teacher/${teacher_model}/${path}/trial_1"
}

# Methods: rkd, itrd, crd; each repeated for trials 1..3.
for trial in {1..3}; do
    run_distill "${trial}" --distill rkd -a 0 -b 1
    run_distill "${trial}" --distill itrd -a 0 -b 1 --alpha_it "${alpha}"
    run_distill "${trial}" --distill crd -a 0 -b 0.8
done
