#!/bin/bash
#SBATCH --account project_account
#SBATCH --job-name h-pawsx
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --ntasks-per-node=2  # Allocate 2 tasks on the node
#SBATCH --cpus-per-task=2     # Each task gets 2 CPUs (4 CPUs total)
#SBATCH --time=04:40:00
#SBATCH --cluster=
#SBATCH --partition=
#SBATCH --mail-type=BEGIN,END,FAIL

# --- Environment setup -------------------------------------------------------
# `conda init` only edits shell startup files; it does not make `conda activate`
# usable in the shell that is already running this script. Load conda's bash
# hook into the current shell instead, then activate the environment.
if command -v conda >/dev/null 2>&1; then
    eval "$(conda shell.bash hook)"
    # Keep going on failure (as the original script did), but say so loudly.
    conda activate fxt || echo "WARNING: could not activate conda env 'fxt'" >&2
    conda env list
else
    echo "WARNING: conda not found on PATH; continuing with the inherited environment" >&2
fi

# Record GPU state in the job log.
nvidia-smi

# Dump this script into the job log so the exact settings of the run are kept.
cat "$0"
echo "--------------------"

# Make the directory the job runs from importable by Python.
export PYTHONPATH="$PWD"
# Both caches point at ./cache, i.e. relative to the current working directory.
export HF_HOME="cache"
export WANDB_CACHE_DIR="cache"

# --- Run configuration -------------------------------------------------------
# Accelerate launcher settings: config file and number of processes.
config_file="configs/accelerate/gpu_1.yaml"
GPUS="1"

# Config file handed to the training script via --config_file.
C="configs/finetune/pawsx_routing_3_6_12_1bp.yml"

# Hyper-parameter grid; one run is launched per combination of these values.
LANGS=("en")
SEEDS=("42")
BSZS=("32")
LRS=("5e-5")

# Passed unchanged to every run.
gradient_accumulation_steps="2"


# --- Grid search -------------------------------------------------------------
# Launch one run of src/finetune/train_classification.py for every combination
# of seed, learning rate, batch size and language. Runs execute sequentially.
for SEED in "${SEEDS[@]}"; do
    echo "Starting with seed ${SEED}"

    for LR in "${LRS[@]}"; do
        for BSZ in "${BSZS[@]}"; do
            for language in "${LANGS[@]}"; do
                # The output directory name encodes language, seed, batch size and LR.
                work_dir="model_ckpts/downstream/your_username_pawsx_joint_input_gridsearch_routing_3x_6x_12x/${language}_pawsx_fixed_routing_seed${SEED}_bsz${BSZ}_lr${LR}_clip1.0_cosine_schedule"
                echo "$work_dir"

                # Bind to port 0 so the OS picks an unused port, then reuse that
                # number for the launcher. The socket is closed before the launcher
                # starts, so the port is only very likely (not guaranteed) to be free.
                echo 'Finding free port'
                PORT=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()')

                # Arguments are built as arrays: every value stays a single word and
                # no trailing line-continuation backslash is needed.
                launch_args=(
                    --main_process_port="$PORT"
                    --config_file="$config_file"
                    --num_processes="$GPUS"
                )
                train_args=(
                    --config_file "$C"
                    --work_dir "$work_dir"
                    --language "$language"
                    --lr "$LR"
                    --batch_size "$BSZ"
                    --seed "$SEED"
                    --gradient_accumulation_steps "$gradient_accumulation_steps"
                )

                accelerate launch "${launch_args[@]}" \
                    src/finetune/train_classification.py "${train_args[@]}"
            done
        done
    done
done