#!/bin/bash
#SBATCH --job-name=baseline_difflogic
#SBATCH --output=/longer/term/storage/%u/jobs/baseline_difflogic_%A_%a.log
#SBATCH --cpus-per-task=8
#SBATCH --mem=60G
#SBATCH --gres=gpu:1
#SBATCH --nodelist=NODES
#SBATCH --time=6:00:00
#SBATCH --array=2-2%4

# Load project-wide settings into this shell. Later steps read
# SCRATCH_STORAGE_DIR and APPTAINER_CONTAINER_SCRATCH, which are presumably
# defined by this file (TODO confirm). The path is relative, so the job has to
# start in the project root.
if ! source helper_scripts/project_variables.sh; then
  echo "ERROR: could not source helper_scripts/project_variables.sh (cwd: ${PWD})" >&2
  exit 1
fi

# Fail fast: with an empty SCRATCH_STORAGE_DIR every path derived from it
# would silently resolve against the filesystem root.
: "${SCRATCH_STORAGE_DIR:?SCRATCH_STORAGE_DIR is not set after sourcing helper_scripts/project_variables.sh}"

# Ensure container is synced to scratch storage. The helper runs in a child
# bash process, so it cannot change variables in this shell; only its exit
# status is observable here.
echo "Syncing container to scratch storage..."
if ! bash helper_scripts/remote_sync_container.sh; then
  echo "ERROR: container sync failed; aborting before training starts" >&2
  exit 1
fi

# Process-wide settings inherited by the training command.
#   CUDA_VISIBLE_DEVICES   - expose only device index 0 to CUDA
#   MPLBACKEND             - matplotlib's non-interactive "agg" backend
#   PYTHONWARNINGS         - silence UserWarning-category warnings
#   MKL_THREADING_LAYER    - make MKL use the GNU threading layer
#   PYTHONUNBUFFERED       - unbuffered Python stdout/stderr
export \
  CUDA_VISIBLE_DEVICES="0" \
  MPLBACKEND="agg" \
  PYTHONWARNINGS="ignore::UserWarning" \
  MKL_THREADING_LAYER="GNU" \
  PYTHONUNBUFFERED="1"

# Weights & Biases cache/data/config/log locations, all under scratch storage.
export \
  WANDB_CACHE_DIR="${SCRATCH_STORAGE_DIR}/wandb_cache" \
  WANDB_DATA_DIR="${SCRATCH_STORAGE_DIR}/wandb_data" \
  WANDB_CONFIG_DIR="${SCRATCH_STORAGE_DIR}/wandb_config" \
  WANDB_DIR="${SCRATCH_STORAGE_DIR}/wandb_logs"

# Tokenizer cache and checkpoint output directories, also under scratch storage.
export \
  TOKENIZER_CACHE="${SCRATCH_STORAGE_DIR}/tokenizer_cache" \
  CHECKPOINT_DIR="${SCRATCH_STORAGE_DIR}/checkpoints/test"

# Config files, one per array task: the 0-based position in this list is the
# SLURM array task ID that runs it.
declare -a CONFIGS=(
    "configs/wmt/transformer_new.json"
    "configs/wmt/unsynced_gru_new.json"
    "configs/wmt/unsynced_lstm_new.json"
    "configs/wmt/unsynced_recurrent_difflogic_new.json"
    "configs/wmt/unsynced_rnn_new.json"
)

# Get the config file for this array task.
# The ID is validated explicitly: an empty subscript evaluates to 0 in bash
# (silently selecting the first config) and an out-of-range one expands to an
# empty string, so neither case would otherwise be noticed.
task_id=${SLURM_ARRAY_TASK_ID:-}
if [[ ! ${task_id} =~ ^[0-9]+$ ]] || (( 10#${task_id} >= ${#CONFIGS[@]} )); then
  echo "ERROR: SLURM_ARRAY_TASK_ID='${task_id}' is not a valid index into CONFIGS (expected 0-$(( ${#CONFIGS[@]} - 1 )))" >&2
  exit 1
fi
task_id=$(( 10#${task_id} ))  # force base 10 so a leading zero is not read as octal
CONFIG_FILE=${CONFIGS[task_id]}

# Both variables are consumed by the apptainer command below; abort with a
# clear message rather than launching with an empty bind source or image path.
: "${SCRATCH_STORAGE_DIR:?SCRATCH_STORAGE_DIR must be set before launching training}"
: "${APPTAINER_CONTAINER_SCRATCH:?APPTAINER_CONTAINER_SCRATCH must be set before launching training}"

echo "========================================="
echo "Array Task ID: $SLURM_ARRAY_TASK_ID"
echo "Config File: $CONFIG_FILE"
echo "Container: ${APPTAINER_CONTAINER_SCRATCH}"
echo "========================================="

# Run the training with the selected config. Scratch storage is bind-mounted
# at /data inside the container; --nv enables NVIDIA GPU support.
train_status=0
apptainer exec --nv \
  --bind "${SCRATCH_STORAGE_DIR}:/data" \
  "${APPTAINER_CONTAINER_SCRATCH}" \
  python3 src/main.py --config "${CONFIG_FILE}" || train_status=$?

# Propagate the training exit status so the scheduler records a failed run as
# failed instead of seeing the exit code of a trailing echo.
if (( train_status != 0 )); then
  echo "Training failed for $CONFIG_FILE (exit code ${train_status})" >&2
  exit "${train_status}"
fi

echo "Training completed for $CONFIG_FILE"
