#!/usr/bin/env bash
# Hyper-parameter sweep for train_model.py.
#
# Runs train_model.py (with --plot_classification --plot_pca) once for every
# combination of seed x num_layers x p_train x norm, writing each run to
#   outputs-layers=<L>-seed=<S>-norm=<LN|RMS>_p=<P>
#
# Usage:
#   ./this_script.sh [extra train_model.py args...]
#   Any arguments given to this script are appended to every
#   train_model.py invocation.
#
# Environment:
#   PYTHON  interpreter used to run train_model.py (default: python3)
#
# The sweep stops at the first failing run (set -e).
set -e            # exit on the first error
set -u            # treat unset variables as errors
set -o pipefail   # catch errors in pipelines

PYTHON_BIN=${PYTHON:-python3}

# Fixed part of every invocation. An array (not a string) keeps each
# argument intact without needing eval.
BASE_CMD=("$PYTHON_BIN" train_model.py --plot_classification --plot_pca)

# Sweep grid, iterated in this nesting order (outermost first).
SEEDS=(0 1 2 3 4)
NUM_LAYERS=(3 2 1)
P_TRAINS=(1 0.99 0.95 0.8 0.5)
NORM_TAGS=(LN RMS)   # LN: default layer-norm (no flag); RMS: --use_rms

for seed in "${SEEDS[@]}"; do
  for num_layers in "${NUM_LAYERS[@]}"; do
    for p_train in "${P_TRAINS[@]}"; do
      for norm_tag in "${NORM_TAGS[@]}"; do
        cmd=("${BASE_CMD[@]}"
          --num_layers "$num_layers" --seed "$seed" --p_train "$p_train")

        # Only RMS needs an explicit flag; LN relies on the default.
        if [[ "$norm_tag" == "RMS" ]]; then
          cmd+=(--use_rms)
        fi

        cmd+=(--output_path
          "outputs-layers=${num_layers}-seed=${seed}-norm=${norm_tag}_p=${p_train}")

        # Forward any extra arguments passed to this script.
        cmd+=("$@")

        # Log the exact command, shell-quoted so it can be copy-pasted.
        printf 'Running:'
        printf ' %q' "${cmd[@]}"
        printf '\n'

        "${cmd[@]}"
      done
    done
  done
done