#!/bin/bash
# SLURM batch job: single-node training run launched with `accelerate` (see below).
# Only comment lines may appear among the #SBATCH directives: sbatch stops reading
# directives at the first non-comment, non-blank line of the script.
#SBATCH --account project_account
#SBATCH --job-name 0.3scaledl
#SBATCH --nodes=1
# GPU count here should match GPUS further down in this script.
#SBATCH --gpus-per-node=2
# 10 tasks per node x 5 CPUs per task = 50 CPUs allocated on the node.
#SBATCH --ntasks-per-node=10
#SBATCH --cpus-per-task=5
#SBATCH --time=25:30:00
#SBATCH --mail-type=BEGIN,END,FAIL



# Activate the conda environment. Abort the job if this fails, rather than
# burning the allocation training under the wrong Python interpreter.
source ~/miniforge3/etc/profile.d/conda.sh || { echo 'Cannot source conda.sh' >&2; exit 1; }
conda activate fxt || { echo 'Failed to activate conda env fxt' >&2; exit 1; }
# Log the available environments and the GPU state into the job output.
conda env list
nvidia-smi

# Dump this script into the job log so the exact run configuration is recorded.
cat "$0"
echo "--------------------"

# Make the submission directory importable.
# NOTE(review): assumes the job is submitted from the repository root — confirm.
export PYTHONPATH="$PWD"
# Hugging Face and W&B share a single cache directory on scratch storage.
cache_dir=/fs/scratch/project_account/your_username/.cache/
export HF_HOME="$cache_dir"
export WANDB_CACHE_DIR="$cache_dir"

# Run inputs: checkpoint output root, training config, and GPU count.
exp_dir=/fs/scratch/project_account/your_username/experiments/fxt/model_ckpts/
C=configs/train/fxt_baseline_1_bp_6_priors_0.3_en_hard_no_binomial_lambda_1_scaled_std.yaml
GPUS=2
# The accelerate config file is chosen by GPU count.
accelerate_config_file="configs/accelerate/gpu_${GPUS}.yaml"
# Ask torch.distributed for its most verbose debug diagnostics.
export TORCH_DISTRIBUTED_DEBUG=DETAIL

echo 'Run training...'

# Arguments shared by both launch modes, kept in one array so the single-process
# and multi-GPU paths invoke the training script identically.
train_args=(--config_file "$C" --exp_dir "$exp_dir" --with_tracking True)
# To resume from a checkpoint, append e.g.:
#   --resume_from /fs/ess/project_account/your_username/experiments/fxt-base/model_ckpts/fxt_baseline_1_bp_1_prior_en_no_residual_0.2/_2025-02-03_18-10-25/step_42000

if [[ -z "$GPUS" ]]; then
    # No GPU count configured: run the training script directly in one process.
    python src/train/train.py "${train_args[@]}"
else
    echo 'Finding free port'
    # Binding to port 0 lets the OS pick an unused port for the accelerate main
    # process (presumably to avoid clashes with other jobs on the same node).
    # NOTE(review): the socket is closed before accelerate binds the port, so a
    # small race window exists.
    PORT=$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()') || {
        echo 'Failed to find a free port' >&2
        exit 1
    }
    accelerate launch --main_process_port="$PORT" --config_file="$accelerate_config_file" \
        --num_processes="$GPUS" src/train/train.py "${train_args[@]}"
fi

