#!/bin/bash

# Create the log directory up front; stop if it cannot be created.
mkdir -p slurmout || { echo "Error: cannot create directory 'slurmout'" >&2; exit 1; }

# Hyperparameter grid swept by the job array.
MODELS=("apple/OpenELM-3B" "Qwen/Qwen2-1.5B")
DATASETS=("imdb" "imdb-large")
weight_decays=(0.0 0.1 0.2 0.4 0.6 0.8 1.0 1.25 1.5 1.75 2.0)

#######################################
# Map a flat task index onto one (model, dataset, weight decay) combination.
# The model index varies fastest, then the dataset, then the weight decay.
# Globals:
#   MODELS, DATASETS, weight_decays (read)
#   model_index, dataset_index, weight_decay_index,
#   MODEL, DATASET, WEIGHT_DECAY (written)
# Arguments:
#   $1 - task index; a non-negative integer smaller than the grid size
# Outputs:
#   An error message on stderr when the index is invalid
# Returns:
#   0 on success, 1 when the index is malformed or out of range
#######################################
select_config() {
  local raw_id=$1
  local total=$(( ${#MODELS[@]} * ${#DATASETS[@]} * ${#weight_decays[@]} ))

  # Reject negatives and non-numeric input; bash arithmetic would otherwise
  # accept them and pick an unintended array element.
  if [[ ! "$raw_id" =~ ^[0-9]+$ ]]; then
    echo "Error: task id '$raw_id' is not a non-negative integer" >&2
    return 1
  fi

  # Force base 10 so an id with leading zeros is not parsed as octal.
  local task_id=$(( 10#$raw_id ))

  # An index past the grid would leave WEIGHT_DECAY empty.
  if (( task_id >= total )); then
    echo "Error: task id $task_id out of range (valid: 0-$(( total - 1 )))" >&2
    return 1
  fi

  # Calculate indices
  model_index=$(( task_id % ${#MODELS[@]} ))
  dataset_index=$(( (task_id / ${#MODELS[@]}) % ${#DATASETS[@]} ))
  weight_decay_index=$(( task_id / (${#MODELS[@]} * ${#DATASETS[@]}) ))

  # Select model, dataset, and weight decay
  MODEL=${MODELS[$model_index]}
  DATASET=${DATASETS[$dataset_index]}
  WEIGHT_DECAY=${weight_decays[$weight_decay_index]}
}

# When SLURM_ARRAY_TASK_ID is unset or empty (script run outside a job array),
# fall back to index 0, which is what the bare arithmetic expansion did before.
select_config "${SLURM_ARRAY_TASK_ID:-0}" || exit 1

# Seed is fixed for the time being.
SEED=1

# Build the per-run stdout/stderr file names and export them. The %A and %a
# tokens are kept literal here; they are Slurm filename patterns
# (array master job id and array task id).
SLURM_OUTPUT="slurmout/llm_wd_fairness_omega_wd${WEIGHT_DECAY}_sd${SEED}_%A_%a.out"
# Same path as the stdout file, with the extension swapped to .err.
SLURM_ERROR="${SLURM_OUTPUT%.out}.err"
export SLURM_OUTPUT SLURM_ERROR

# Announce the configuration for this run.
printf '%s\n' "Running training with weight decay: $WEIGHT_DECAY and seed: $SEED for model: $MODEL and dataset: $DATASET"

# Refuse to launch with an incomplete configuration: an empty value here
# (e.g. from an out-of-range array index) would otherwise reach train.py
# as a missing or blank argument.
: "${MODEL:?MODEL is not set}"
: "${DATASET:?DATASET is not set}"
: "${WEIGHT_DECAY:?WEIGHT_DECAY is not set}"
: "${SEED:?SEED is not set}"
: "${SLURM_OUTPUT:?SLURM_OUTPUT is not set}"
: "${SLURM_ERROR:?SLURM_ERROR is not set}"

# Change to the working directory; abort instead of running from the wrong place.
cd SOMEWHERE || { echo "Error: cannot cd to SOMEWHERE" >&2; exit 1; }

# The log paths are relative, so after the cd they presumably resolve against
# this directory; make sure their parent directories exist here as well.
mkdir -p -- "$(dirname -- "$SLURM_OUTPUT")" "$(dirname -- "$SLURM_ERROR")" \
  || { echo "Error: cannot create log directories" >&2; exit 1; }

# Initialize conda (important for SLURM scripts)
if ! command -v conda >/dev/null 2>&1; then
  echo "Error: conda not found on PATH" >&2
  exit 1
fi
eval "$(conda shell.bash hook)"

# Activate the conda environment
conda activate llm-wd || { echo "Error: cannot activate conda env 'llm-wd'" >&2; exit 1; }

# Launch training, sending stdout/stderr to the per-run files named in
# SLURM_OUTPUT and SLURM_ERROR. Every expansion is quoted so each value is
# passed as exactly one argument. srun is the last command, so its exit
# status becomes the script's exit status.
srun --output="$SLURM_OUTPUT" --error="$SLURM_ERROR" \
    python train.py \
    --weight-decay "$WEIGHT_DECAY" \
    --seed "$SEED" \
    --backbone "$MODEL" \
    --dataset "$DATASET" \
    --per-device-batch-size 16 \
    --gradient-accumulation-steps 4 \
    --max-length 64
