#!/bin/bash
#
# Fine-tune LaMed (Phi-3-mini backbone, LoRA, frozen vit3d vision tower) on
# three BraTS 3D VQA datasets (gli, goat, met). For each dataset:
#   1. train with `accelerate launch`,
#   2. merge the LoRA weights into a Hugging Face checkpoint (on CPU),
#   3. generate answers on the test split and score them.
#
# Usage: run_lamed_brats.sh GPU_IDS
#   GPU_IDS  passed to `accelerate launch --gpu_ids` for training and to
#            CUDA_VISIBLE_DEVICES for evaluation (e.g. "0" or "0,1").
#
# Run "accelerate config" first!
# Run from the directory containing LaMed/ and Bench/: all script paths are
# relative and the merge/eval steps use PYTHONPATH=.
#
# A failing step skips the remaining steps for that dataset (they consume its
# outputs); the other datasets still run. Exits 1 if any dataset failed.

set -uo pipefail

if [[ $# -lt 1 || -z "${1:-}" ]]; then
  printf 'Usage: %s GPU_IDS\n' "${0##*/}" >&2
  exit 2
fi
readonly gpu_ids=$1

#######################################
# Train, merge and evaluate LaMed on one BraTS VQA dataset.
# Globals:   gpu_ids (read)
# Arguments: $1 - dataset tag embedded in the JSON file names (gli/goat/met)
#            $2 - output directory for the training run; the merged model is
#                 written to $2/hf and evaluation results to $2/eval_vqa
# Returns:   0 on success, otherwise the exit status of the first failing step.
#######################################
run_experiment() {
  local -r tag=$1
  local -r output_dir=$2
  local -r prefix="brats_${tag}_3d_vqa_subjTrue"
  local -r suffix="updated_v3_seed0_multitask_fixed.json"
  local -r train_path="${prefix}_train_${suffix}"
  local -r val_path="${prefix}_val_${suffix}"
  local -r test_path="${prefix}_test_${suffix}"

  printf '==> [%s] training -> %s\n' "$tag" "$output_dir" >&2
  # --eval_steps / --logging_steps below 1 are interpreted by the HF Trainer
  # as a fraction of total training steps.
  accelerate launch --gpu_ids "$gpu_ids" LaMed/src/train/train.py \
    --version v0 \
    --model_name_or_path microsoft/Phi-3-mini-4k-instruct \
    --model_type phi3 \
    --multimodal True \
    --combined_projector True \
    --freeze_vision_tower True \
    --pretrain_mm_mlp_adapter ./LaMed/pretrained_model/M3D-LaMed-Phi-3-4B/mm_projector.bin \
    --vqa_data_train_path "$train_path" \
    --vqa_data_val_path "$val_path" \
    --vqa_data_test_path "$test_path" \
    --lora_enable True \
    --vision_tower vit3d \
    --pretrain_vision_model ./LaMed/pretrained_model/M3D-CLIP/pretrained_ViT.bin \
    --bf16 True \
    --output_dir "$output_dir" \
    --num_train_epochs 2 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --eval_accumulation_steps 1 \
    --eval_steps 0.5 \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 5e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 0.001 \
    --gradient_checkpointing False \
    --dataloader_pin_memory True \
    --dataloader_num_workers 8 \
    --report_to tensorboard || return

  printf '==> [%s] merging LoRA weights\n' "$tag" >&2
  # CUDA_VISIBLE_DEVICES="" hides all GPUs from the merge step.
  PYTHONPATH=. CUDA_VISIBLE_DEVICES="" python LaMed/src/utils/merge_lora_weights_and_save_hf_model.py \
    --version="" --model_type="phi3" \
    --model_with_lora="$output_dir/model_with_lora.bin" \
    --output_dir="$output_dir/hf" || return

  printf '==> [%s] evaluating\n' "$tag" >&2
  PYTHONPATH=. CUDA_VISIBLE_DEVICES="$gpu_ids" python Bench/eval/eval_vqa.py \
    --output_dir "$output_dir/eval_vqa" \
    --model_name_or_path "$output_dir/hf" \
    --vqa_data_test_path "$test_path" || return

  PYTHONPATH=. CUDA_VISIBLE_DEVICES="$gpu_ids" python Bench/eval/eval_vqa_utils.py \
    --output_dir "$output_dir/eval_vqa" \
    --gt_file "$test_path"
}

readonly out_root=./LaMed/output/LaMed-Phi3-4B-multimodal-combined-finetune-freeze-viz

failed=()
# NOTE(review): the gli run's directory says "again" rather than "gli"; kept
# as-is so results land in the same place as before.
run_experiment gli "${out_root}-again-new-dataset-0000" || failed+=(gli)
run_experiment goat "${out_root}-goat-new-dataset-0000" || failed+=(goat)
run_experiment met "${out_root}-met-new-dataset-0000" || failed+=(met)

if (( ${#failed[@]} > 0 )); then
  printf 'Failed datasets: %s\n' "${failed[*]}" >&2
  exit 1
fi
