#!/bin/bash

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Experiment label; used as the sub-directory of ./output and as the prefix
# of every stage's output directory name.
exp_name='your_experiment_name_here'

# Tasks, trained one after another in the order listed here.
# You may change the order.
task_ids=(
  'chartqa'
  'docvqa'
  'iconqa'
  'medicalqa'
)

# Learning rates, index-aligned with task_ids (lrs[i] belongs to task_ids[i]).
lrs=(
  4e-5
  4e-5
  4e-5
  4e-5
)

# Model the first stage starts from; each later stage starts from the
# previous stage's output directory.
init_model_name_or_path='./checkpoints/llava-v1.5-7b'

# Running suffix: "-<task_id>" is appended once per stage.
output_name=''

# Print the task / learning-rate pairing, one line per task, so the schedule
# is visible in the log.
for idx in "${!task_ids[@]}"; do
  task_id=${task_ids[idx]}
  lr=${lrs[idx]}
  printf 'Task ID: %s, Learning Rate: %s\n' "$task_id" "$lr"
done

#######################################
# Launch one DeepSpeed fine-tuning stage.
# Arguments:
#   $1 - task id, forwarded as --cl_data_clss
#   $2 - learning rate, forwarded as --learning_rate
#   $3 - model path to start from, forwarded as --model_name_or_path
#   $4 - training JSON file, forwarded as --data_path
#   $5 - output directory, forwarded as --output_dir
# Returns:
#   The exit status of the deepspeed command.
#######################################
train_task() {
    local task_id=$1 lr=$2 model_path=$3 data_path=$4 output_dir=$5

    # Built as an array so every value reaches deepspeed as exactly one
    # argument, even when it is empty or contains whitespace.
    local -a train_args=(
        --deepspeed ./scripts/zero3.json
        --model_name_or_path "$model_path"
        --version v1
        --data_path "$data_path"
        --cl_data_clss "$task_id"
        --image_folder ./playground/data
        --vision_tower ./checkpoints/clip-vit-large-patch14-336
        --mm_projector_type mlp2x_gelu
        --mm_vision_select_layer -2
        --mm_use_im_start_end False
        --mm_use_im_patch_token False
        --image_aspect_ratio pad
        --group_by_modality_length True
        --bf16 True
        --output_dir "$output_dir"
        --num_train_epochs 2
        --per_device_train_batch_size 32
        --per_device_eval_batch_size 4
        --gradient_accumulation_steps 1
        --evaluation_strategy "no"
        --save_strategy "steps"
        --save_steps 50000
        --save_total_limit 1
        --learning_rate "$lr"
        --weight_decay 0.
        --warmup_ratio 0.03
        --lr_scheduler_type "cosine"
        --logging_steps 1
        --tf32 True
        --model_max_length 2048
        --gradient_checkpointing True
        --dataloader_num_workers 4
        --lazy_preprocess True
        --report_to wandb
    )

    deepspeed llava/train/train_mem.py "${train_args[@]}"
}

# The two arrays are read in parallel; a missing learning rate would otherwise
# only surface as a confusing argument-parsing error inside the trainer.
if (( ${#task_ids[@]} != ${#lrs[@]} )); then
    echo "task_ids has ${#task_ids[@]} entries but lrs has ${#lrs[@]}; they must match." >&2
    exit 1
fi

# Train the tasks one after another. Each stage starts from the model written
# by the previous stage, and its output directory name accumulates the ids of
# all tasks trained so far (e.g. <exp_name>-chartqa-docvqa).
for i in "${!task_ids[@]}"; do
    task_id="${task_ids[$i]}"
    lr_id="${lrs[$i]}"

    data_path="./json_files/${task_id}_train.json"
    output_name="${output_name}-${task_id}"
    output_dir="./output/${exp_name}/${exp_name}${output_name}"

    # Stop at the first failed stage: continuing would start the next task
    # from an output directory that the failed run may not have completed.
    train_task "$task_id" "$lr_id" "$init_model_name_or_path" "$data_path" "$output_dir" || {
        status=$?
        echo "Training failed for task '$task_id' (exit status $status); aborting remaining tasks." >&2
        exit "$status"
    }

    init_model_name_or_path="$output_dir"
done