#!/bin/bash

# Stop at the first command that returns a non-zero status.
set -e

# ---------------------------------------------------------------------------
# Experiment identity, data and output locations
# ---------------------------------------------------------------------------
readonly EXP_NAME="permutation_optimization"
# Pretrained checkpoint to start from (e.g. gpt2, gpt2-medium).
readonly BASE_MODEL="gpt2"

# Other experiment configurations; swap one EXP_ID/DATA_PATH pair in to use it.
# EXP_ID="index_n=13_m=4_8select_fixed"
# DATA_PATH="/mnt/nfs/data/small/index/n=13_m=4/data-inv" # CHANGE THIS to your data path prefix (e.g., ./data/wikitext-103/wiki)
# EXP_ID="relu_n=20_8select_fixed"
# DATA_PATH="/mnt/nfs/data/small/relu/n=20/data-inv"
readonly EXP_ID="square_mod19_n=20_2select_fixed_2stage"
readonly DATA_PATH="data/data/small/square_mod19/n=20/data-inv"
# Results land under a per-experiment directory derived from the names above.
readonly OUTPUT_DIR="data/results/small/${EXP_NAME}/square_mod19/${EXP_ID}"

# ---------------------------------------------------------------------------
# Sequence / batch sizes (forwarded as --max_sequence_length, --num_batch and
# --test_batch_size respectively)
# ---------------------------------------------------------------------------
readonly MAX_SEQ_LEN=50
readonly BATCH_SIZE=128
readonly TEST_BATCH_SIZE=100

# ---------------------------------------------------------------------------
# Training length: epochs by default. To train for a fixed step count
# instead, set TOTAL_STEPS (it takes precedence over TOTAL_EPOCHS).
# ---------------------------------------------------------------------------
readonly TOTAL_EPOCHS=5
# TOTAL_STEPS=1000

# ---------------------------------------------------------------------------
# Learning rates for the two alternating optimisation stages
# (per the original notes: transformer weights on even steps,
# permutation matrix P on odd steps).
# ---------------------------------------------------------------------------
readonly TRANSFORMER_LR=5e-5
readonly PERMUTATION_LR=1e-2

# GPU index exported as CUDA_VISIBLE_DEVICES for the training process.
readonly count=1

# Build the training-duration arguments as an array so each value stays a
# single, correctly-quoted argv entry.
# TOTAL_STEPS wins over TOTAL_EPOCHS when non-empty. TOTAL_STEPS may also come
# from the caller's environment, since it is only assigned when uncommented.
TRAINING_ARGS=()
if [[ -n "${TOTAL_STEPS:-}" ]]; then
    echo "Training for a total of ${TOTAL_STEPS} steps."
    TRAINING_ARGS+=(--max_steps "${TOTAL_STEPS}")
elif [[ -n "${TOTAL_EPOCHS:-}" ]]; then
    echo "Training for a total of ${TOTAL_EPOCHS} epochs."
    TRAINING_ARGS+=(--num_train_epochs "${TOTAL_EPOCHS}")
else
    echo "Error: Either TOTAL_EPOCHS or TOTAL_STEPS must be set in the script." >&2
    exit 1
fi

# Append learning rates and other necessary arguments from MyTrainingArguments
TRAINING_ARGS+=(--transformer_lr "${TRANSFORMER_LR}" --permutation_lr "${PERMUTATION_LR}")

# Ensure the output directory exists before the log and pid files are written.
mkdir -p -- "${OUTPUT_DIR}"
LOG_FILE="${OUTPUT_DIR}/train.log"

echo "Starting permutation training script..."

# Launch the trainer detached in the background:
#  - nohup keeps it alive if the launching terminal hangs up;
#  - stdin from /dev/null, stdout+stderr into LOG_FILE;
#  - the trailing '&' is what makes $! meaningful below.
# Argument names match the MyTrainingArguments fields of src/main_permutation.py.
# NOTE: set -e does not observe failures of a background job; check LOG_FILE.
CUDA_VISIBLE_DEVICES="${count}" nohup python3 src/main_permutation.py \
    --model_name_or_path "${BASE_MODEL}" \
    --data_path "${DATA_PATH}" \
    --output_dir "${OUTPUT_DIR}" \
    --exp_name "${EXP_NAME}" \
    --exp_id "${EXP_ID}" \
    --max_sequence_length "${MAX_SEQ_LEN}" \
    --num_batch "${BATCH_SIZE}" \
    --test_batch_size "${TEST_BATCH_SIZE}" \
    "${TRAINING_ARGS[@]}" \
    < /dev/null > "${LOG_FILE}" 2>&1 &

# Capture the PID of the job just started, before any other background command.
TRAIN_PID=$!
echo "${TRAIN_PID}" > "${OUTPUT_DIR}/pid.txt"

echo "Script started in background. Process ID: ${TRAIN_PID}"
echo "Logging to: ${LOG_FILE}"