#!/usr/bin/env bash
# Hyperparameter sweep: run main_train.py once for every combination of the
# values below, spreading the runs over the listed devices with GNU parallel.

# Fail fast on errors, unset variables and failed pipeline stages.
set -euo pipefail

# Define configurations. Every array is swept; the trailing comment names the
# main_train.py flag that each value is passed to.
models=("mlp" "causal_transformer")              # --model
SMs=("random" "uniform" "skew")                  # --sample
split_sizes=(0.3)                                # --split
WDs=(0.3 0.5 1)                                  # --wd
Cs=(1 10)                                        # --wd_scale
LRs=(0.001 0.005 0.009 0.01)                     # --lr

GPs=("amodp" "modp" "dmodp" "sum_squares_modp")  # --group
batch_sizes=(64 128 512)                         # --batch_size
# Device IDs to run on; one concurrent job is run per entry.
devices=(0 1 2 3)
# Clean the logs folder so each sweep starts with an empty checkpoints/logs/.
rm -rf -- checkpoints/logs/*
mkdir -p checkpoints/logs

# Command file holding one shell command per job; it is fed to GNU parallel
# on stdin. Remove any stale copy left over from a previous sweep.
cmd_file="job_commands.txt"
rm -f -- "$cmd_file"

# Generate all commands: one shell command line per point of the
# cross-product of the sweep arrays, written to $cmd_file.
#
# Device assignment: each command picks its device at run time, inside the
# shell GNU parallel starts for that job. GNU parallel exports
# PARALLEL_JOBSLOT (1..-j) to every job, and no two concurrently running jobs
# share a slot, so mapping slot -> devices[slot-1] puts at most one job on
# each device at a time. A PARALLEL_SEQ % N scheme does not guarantee this:
# jobs finish out of order, so a newly started job can get the same residue
# as a job that is still running on that device.
# NOTE: requires a GNU parallel release that exports PARALLEL_JOBSLOT.
#
# The snippet expands here to e.g.
#   $(set -- 0 1 2 3; shift $((PARALLEL_JOBSLOT - 1)); echo "$1")
# and is evaluated later, in each job's shell (POSIX, no bash arrays needed).
pick_device="\$(set -- ${devices[*]}; shift \$((PARALLEL_JOBSLOT - 1)); echo \"\$1\")"

for gr in "${GPs[@]}"; do
  for model in "${models[@]}"; do
    for sam in "${SMs[@]}"; do
      for sp in "${split_sizes[@]}"; do
        for weightDecay in "${WDs[@]}"; do
          for lrnRate in "${LRs[@]}"; do
            for batch_size in "${batch_sizes[@]}"; do
              for c in "${Cs[@]}"; do
                # Every swept value, including the split, is part of the log
                # name, so no two runs can write to the same log file.
                log="checkpoints/logs/${model}_${sam}_${gr}_sp${sp}_bs${batch_size}_wd${weightDecay}_lr${lrnRate}_c${c}.log"
                printf 'python main_train.py --model %s --sample %s --group %s --split %s --wd %s --wd_scale %s --lr %s --batch_size %s --device %s > %s 2>&1\n' \
                  "$model" "$sam" "$gr" "$sp" "$weightDecay" "$c" \
                  "$lrnRate" "$batch_size" "$pick_device" "$log"
              done
            done
          done
        done
      done
    done
  done
done > "$cmd_file"  # open the file once instead of re-appending per line

# Run all commands in parallel using GNU Parallel, one concurrent job per
# device. --joblog records each job's exit status; --results keeps per-job
# stdout/stderr (mostly empty here, since every command redirects its own
# output into checkpoints/logs/).
if ! command -v parallel >/dev/null 2>&1; then
  echo "error: GNU parallel is not installed or not on PATH" >&2
  exit 1
fi
# -j 0 means "as many jobs as possible" to GNU parallel, not "no jobs".
if (( ${#devices[@]} == 0 )); then
  echo "error: devices array is empty" >&2
  exit 1
fi

# GNU parallel exits non-zero when any job fails (the status counts failed
# jobs). Capture it instead of reporting success unconditionally.
status=0
parallel --joblog parallel_log.txt --results parallel_results \
  -j "${#devices[@]}" < "$cmd_file" || status=$?

if (( status == 0 )); then
  echo "All jobs have been executed in parallel across ${#devices[@]} devices."
else
  echo "GNU parallel exited with status $status; see parallel_log.txt for failed jobs." >&2
  exit "$status"
fi
