#!/bin/bash
# Shell prologue: signal handling, error policy and process environment.

# Leave the script as soon as the user interrupts it (Ctrl-C).
trap 'exit' INT
# Abort on the first command that returns a non-zero status.
set -o errexit

# GPUs visible to the job, plus the DeepSpeed switch
# (presumably skips DeepSpeed's CUDA version check; confirm against DeepSpeed docs).
export CUDA_VISIBLE_DEVICES=0,1 DS_SKIP_CUDA_CHECK=1

# NCCL transport switches: turn off peer-to-peer and InfiniBand transports.
export NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1

# ---------------------------------------------------------------------------
# Experiment configuration.
# Boolean switches hold the literal words true/false; the rest of the script
# executes the value directly (`if ${flag}; then ...`), so keep them to those
# two words.
# ---------------------------------------------------------------------------

# Launcher settings handed to `accelerate launch`.
accelerate_config="config/accelerate_config/acc_config.yaml"
mixed_precision="fp16"

# Logging, checkpointing, validation cadence and reproducibility.
log_dir="logs/"
log_step=50
checkpoints_dir="checkpoints/"
save_step=5
val_step=1
seed=42
use_resume=false

# Optimisation hyper-parameters.
num_epochs=150
lr=0.0003
batch_size=64
gradient_accumulation_steps=1

# Dataset selection. subj_list is an array; add more subject ids as needed.
subj_list=(1)
nsddir="/opt/data/private/dataset/nsd"
space="MNI_2mm"
func="betas_fithrf_GLMdenoise_RR"

# Model / training options forwarded to the Python entry points.
clip_model="CLIP-ViT-H-14"
norm_nii=true
text_scale=1
mixup_pct=0
patch_size=12
num_blocks=12
patch_drop=0
attn_drop=0
block_drop=0
nii_mask="brain"
patch_type="conv"
use_image_aug=false
mixin=false
buffer_size=4096
local_loss=false
local_loss2=false  # requires accelerate v0.16.0+
gather_with_grad=false

# Two renderings of the subject list: space-separated (used on the command
# line) and comma-separated (embedded in the run name).
subj_list_param="${subj_list[*]}"
printf -v subj_list_string '%s,' "${subj_list[@]}"
subj_list_string="${subj_list_string%,}"

# Run name: encodes the main hyper-parameters so that log files and
# checkpoints of different configurations do not collide.
name="ddp2-0.01-qwen25vl-s${subj_list_string}-${clip_model}"
name+="-mask=${nii_mask}-blocks=${num_blocks}"
name+="-patch=${patch_type}_${patch_size}"
name+="-pt_dp=${patch_drop}-at_dp=${attn_drop}-bo_dp=${block_drop}"
name+="-txt_scl=${text_scale}-m_pct=${mixup_pct}"
name+="-bs=${batch_size}-lr=${lr}-bf=${buffer_size}"

# Append one suffix per enabled switch. Each entry is "<variable>:<suffix>";
# the variable's value (true/false) is executed as the condition, exactly as
# a plain `if ${variable}; then` would do.
for toggle in \
        norm_nii:norm_nii \
        mixin:mixin \
        use_image_aug:imgaug \
        local_loss:local_loss \
        local_loss2:local_loss2 \
        gather_with_grad:gather_with_grad
do
        toggle_var="${toggle%%:*}"
        if ${!toggle_var}; then
                name+="-${toggle##*:}"
        fi
done
unset toggle toggle_var


# ---------------------------------------------------------------------------
# Training run.
#
# The command line is assembled as a bash array and executed directly instead
# of being concatenated into a string and passed through `eval`. Every value
# therefore stays exactly one argument, even if it contains whitespace or
# shell metacharacters. Stdout of the run goes to output/train/${name}.log;
# stderr is left on the terminal.
#
# To run without accelerate, start the array with `python src/main.py` and
# drop the three launcher options.
# ---------------------------------------------------------------------------
train_log="output/train/${name}.log"
# The redirection below fails if the directory is missing, so create it first.
mkdir -p -- "${train_log%/*}"

# Random port for the main process. RANDOM yields 0..32767, so the port ends
# up in 1024..33791.
train_port=$((1024 + RANDOM % 64512))

train_cmd=(
        accelerate launch
        --main_process_port="${train_port}"
        --config_file "${accelerate_config}"
        --mixed_precision="${mixed_precision}"
        src/main.py
        --name "${name}"
        --mixed_precision "${mixed_precision}"
        --log_dir "${log_dir}"
        --log_step "${log_step}"
        --checkpoints_dir "${checkpoints_dir}"
        --save_step "${save_step}"
        --val_step "${val_step}"
        --seed "${seed}"
        --num_epochs "${num_epochs}"
        --lr "${lr}"
        --batch_size "${batch_size}"
        --gradient_accumulation_steps "${gradient_accumulation_steps}"
        --subj_list "${subj_list[@]}"
        --nsddir "${nsddir}"
        --space "${space}"
        --func "${func}"
        --clip_model "${clip_model}"
        --text_scale "${text_scale}"
        --mixup_pct "${mixup_pct}"
        --patch_size "${patch_size}"
        --num_blocks "${num_blocks}"
        --patch_type "${patch_type}"
        --patch_drop "${patch_drop}"
        --attn_drop "${attn_drop}"
        --block_drop "${block_drop}"
        --nii_mask "${nii_mask}"
        --buffer_size "${buffer_size}"
)

# Each switch variable shares its name with the CLI flag it enables; add the
# flag only when the variable is set to "true".
for train_flag in norm_nii mixin use_resume use_image_aug local_loss local_loss2 gather_with_grad; do
        if [[ "${!train_flag}" == true ]]; then
                train_cmd+=("--${train_flag}")
        fi
done
unset train_flag

# Show the exact command (shell-quoted) before running it.
printf '%q ' "${train_cmd[@]}"
printf '> %q\n' "${train_log}"
"${train_cmd[@]}" > "${train_log}"

# ---------------------------------------------------------------------------
# Evaluation run on the checkpoint selected by ${tag}.
#
# The command line is assembled as a bash array and executed directly instead
# of being concatenated into a string and passed through `eval`, so every
# value stays exactly one argument. Stdout goes to output/test/${name}.log;
# stderr is left on the terminal.
# ---------------------------------------------------------------------------
tag=last
test_log="output/test/${name}.log"
# The redirection below fails if the directory is missing, so create it first.
mkdir -p -- "${test_log%/*}"

test_cmd=(
        python src/test.py
        --name "${name}"
        --checkpoints_dir "${checkpoints_dir}"
        --seed "${seed}"
        --batch_size "${batch_size}"
        --subj_list "${subj_list[@]}"
        --nsddir "${nsddir}"
        --space "${space}"
        --func "${func}"
        --clip_model "${clip_model}"
        --patch_size "${patch_size}"
        --num_blocks "${num_blocks}"
        --patch_type "${patch_type}"
        --patch_drop "${patch_drop}"
        --attn_drop "${attn_drop}"
        --block_drop "${block_drop}"
        --nii_mask "${nii_mask}"
        --tag "${tag}"
)
# Forward the normalisation switch only when it is enabled.
if [[ "${norm_nii}" == true ]]; then
        test_cmd+=(--norm_nii)
fi

# Show the exact command (shell-quoted) before running it.
printf '%q ' "${test_cmd[@]}"
printf '> %q\n' "${test_log}"
"${test_cmd[@]}" > "${test_log}"