#!/bin/bash

# Set default values from the positional arguments.
# Usage: <script> [task] [dataset] [device_num] [lever_lm_model]
#   task           - "vqa" or "caption" (default: caption)
#   dataset        - value passed to train.py as dataset= (default: coco2017)
#   device_num     - value passed to train.py as device_num= (default: 1)
#   lever_lm_model - train= config passed to train.py (default: query_img_ice_text)
task=${1:-caption}
dataset=${2:-coco2017}
device_num=${3:-1}
lever_lm_model=${4:-query_img_ice_text}

# Log which interpreter will run train.py (`command -v` is the portable
# builtin replacement for the external `which`).
command -v python || echo "warning: 'python' not found on PATH" >&2
#######################################
# Block until the process with the given PID is no longer running.
# Polls with `ps -p` and prints a progress line on every poll.
# Arguments:
#   $1 - PID to wait for
#   $2 - (optional) seconds to sleep between polls, default 60
# Returns:
#   0 once the process is gone; 1 if no PID was given.
#######################################
check_command_status() {
    local pid=$1
    local interval=${2:-60}

    if [[ -z "${pid}" ]]; then
        echo "check_command_status: missing PID" >&2
        return 1
    fi

    # Keep checking the process's status until it completes;
    # `ps -p` fails once the PID no longer exists.
    while ps -p "${pid}" > /dev/null; do
        echo "command ${pid} not done"
        sleep "${interval}"
    done
}

# PIDs of earlier processes to wait for before starting the training runs.
# The array is empty here, so the loop is a no-op; add PIDs to block until
# those processes have exited.
pids=()
for pid in "${pids[@]}"; do
    check_command_status "${pid}"
done

#######################################
# Run train.py once for the current task with the given data file.
# Globals (read):
#   task, dataset, device_num, lever_lm_model
# Arguments:
#   $1 - value for train.py's data_files= override
#   $2 - value for train.py's val_step= override
#   $3 - suffix used to build the experiment name (ex_name=)
# Outputs:
#   A "Begin" banner on stdout, then whatever train.py prints.
# Returns:
#   train.py's exit status; 1 (with a message on stderr) for an
#   unsupported task.
#######################################
run_train() {
    local data_file=$1
    local val_step=$2
    local ex_name_suffix=$3

    local ex_name_prefix="ab_${task}"
    local ex_name
    # Task-specific overrides appended after the common ones.
    local -a extra_args=()

    case "${task}" in
        vqa)
            ex_name="${ex_name_prefix}_${ex_name_suffix}_${lever_lm_model}"
            ;;
        caption)
            # Caption runs disable norm, freeze the listed sub-models and
            # enable the adapter; the experiment name records this.
            ex_name="${ex_name_prefix}_${ex_name_suffix}_non_norm_freeze_adapter_${lever_lm_model}"
            # NOTE(review): the key is spelled "adpter" — confirm it matches
            # the train.py config schema before renaming.
            extra_args=(
                "train.lever_lm_model.norm=false"
                "train.lever_lm_model.freeze_prefix_list=[img_model,sen_model]"
                "train.lever_lm_model.adpter=true"
            )
            ;;
        *)
            echo "run_train: unsupported task '${task}'" >&2
            return 1
            ;;
    esac

    echo "==========Begin: ${ex_name_prefix}_${ex_name_suffix}-ICLM: ${lever_lm_model}=========="
    python train.py train="${lever_lm_model}" \
        data_files="${data_file}" \
        epochs=20 \
        val_step="${val_step}" \
        ex_name="${ex_name}" \
        device_num="${device_num}" \
        dataset="${dataset}" \
        task="${task}" \
        "${extra_args[@]}"
}

# Ablation grid shared by both tasks. Each row is:
#   <sampler> <beam_size> <candidate_set_num> <sample_num> <val_step> <ex_name suffix>
# The first row is the baseline; later rows vary the sampler, beam size,
# candidate-set size, or sample count.
ablations=(
    "random    5  64  5000  80  baseline"
    "random    1  64  5000  80  1beam"
    "random    10 64  5000  80  10beam"
    "random    5  128 5000  80  128candidate"
    "text-sim  5  64  5000  80  text-sim"
    "image-sim 5  64  5000  80  img-sim"
    "random    5  64  10000 160 1wanchors"
)

# The data-file prefix is fixed per task; it does not follow ${dataset}.
case "${task}" in
    vqa)
        # VQA mode
        data_prefix="vqa-vqav2"
        ;;
    caption)
        # Caption mode
        data_prefix="caption-coco2017"
        ;;
    *)
        echo "Invalid task. Please choose 'vqa' or 'caption'." >&2
        exit 1
        ;;
esac

for ablation in "${ablations[@]}"; do
    read -r sampler beam_size candidate_set_num sample_num val_step ex_suffix <<< "${ablation}"
    data_file="${data_prefix}-only_y_loss-OpenFlamingo-9B-vitl-mpt7b-${sampler}-beam_size:${beam_size}-few_shot:2-candidate_set_num:${candidate_set_num}-sample_num:${sample_num}.json"
    run_train "${data_file}" "${val_step}" "${ex_suffix}"
done
