#!/bin/bash
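# Tip: to delay the start, launch e.g. `(sleep 3h; bash /path/to/run.sh)`.
# Paths below are relative, so run from the directory containing src/ and data/.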

# Experiment selection (change as appropriate).
# The permutation results file lives inside the base model's run directory,
# so both paths are derived from EXPERIMENT_DIR.
#
# Previously used experiments - copy the matching lines over the active
# settings below to switch:
#   relu:
#     EXPERIMENT_DIR="/mnt/nfs/results/small/relu/relu_gpt2_single_layer_digits=50_bs=128_inverse/"
#     DATASET_PATH="/mnt/nfs/data/small/relu/n=50/data"
#     MODEL_SAVE_DIR="/mnt/nfs/results/small/relu/all_permutations_trainer_n=50_16select"
#   square_mod19:
#     EXPERIMENT_DIR="/mnt/nfs/results/small/square_mod19/square_gpt2_mod19_L=6_digits=50_bs=128_inverse/"
#     DATASET_PATH="/mnt/nfs/data/small/square_mod19/n=50/data"
#     MODEL_SAVE_DIR="/mnt/nfs/results/small/square_mod19/all_permutations_trainer_n=50_16select"
# NOTE(review): INPUT_LEN / GENERATION_MAX_LENGTH below were set for the active
# index experiment; they probably need adjusting when switching - confirm.
EXPERIMENT_DIR="data/results/small/index/index_gpt2_longlearning_L=6_digits=13_m=4_bs=128_inverse/"
BASE_MODEL_PATH="${EXPERIMENT_DIR}"
PERMUTATION_RESULTS_PATH="${EXPERIMENT_DIR}16select_permutation_sparsity_results.pt"

DATASET_PATH="data/data/small/index/n=13_m=4/data" # training dataset
MODEL_SAVE_DIR="data/results/small/index/all_permutations_trainer_n=13_m=4_16select"

INPUT_LEN=13
GENERATION_MAX_LENGTH=27

# Trainer hyperparameters - keep in line with main.py (or best practice).

# Optimisation schedule
NUM_EPOCHS=40
LEARNING_RATE=5e-5
WARMUP_STEPS=0
WEIGHT_DECAY=0

# Batching
BATCH_SIZE=128 # passed on as per_device_train_batch_size
GRADIENT_ACCUMULATION_STEPS=1

# Logging / checkpoint retention
LOGGING_STEPS=500
SAVE_TOTAL_LIMIT=1
# SAVE_STEPS=500 # not passed: checkpoints are written with save_strategy="epoch"

# Mixed precision: the literal string "true" adds --fp16; any other value leaves it off
FP16_TRAINING=false

# ---------------------------------------------------------------------------
# Launch src/train_with_permutations.py in the background.
#
# The argument list is built from the settings above. The job runs under nohup
# with CUDA_VISIBLE_DEVICES set to a single device and stdout/stderr redirected
# to LOG_FILE. This script returns as soon as the job has been started; it does
# NOT wait for training to finish.
#
# Optional environment overrides:
#   CUDA_DEVICE  value exported as CUDA_VISIBLE_DEVICES for the job (default: 5)
#   LOG_FILE     file receiving the job's stdout/stderr (default: nohup.out)
# ---------------------------------------------------------------------------
set -euo pipefail

CUDA_DEVICE="${CUDA_DEVICE:-5}"
LOG_FILE="${LOG_FILE:-nohup.out}"
TRAIN_SCRIPT="src/train_with_permutations.py"

# The script path is relative to the current directory; fail fast here instead
# of letting the detached job die quietly into the log file.
if [[ ! -f "${TRAIN_SCRIPT}" ]]; then
  echo "Error: ${TRAIN_SCRIPT} not found (run from the directory containing src/)." >&2
  exit 1
fi

# Build the command as an array (no string concatenation + eval) so every value
# stays exactly one argument, even if a path contains spaces or metacharacters.
train_cmd=(
  python3 "${TRAIN_SCRIPT}"
  --permutation_results_path "${PERMUTATION_RESULTS_PATH}"
  --base_model_path "${BASE_MODEL_PATH}"
  --dataset_path "${DATASET_PATH}"
  --model_save_dir "${MODEL_SAVE_DIR}"
  --input_len "${INPUT_LEN}"
  --num_epochs "${NUM_EPOCHS}"
  --batch_size "${BATCH_SIZE}"
  --learning_rate "${LEARNING_RATE}"
  --warmup_steps "${WARMUP_STEPS}"
  --weight_decay "${WEIGHT_DECAY}"
  --logging_steps "${LOGGING_STEPS}"
  --save_total_limit "${SAVE_TOTAL_LIMIT}"
  --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS}"
  --generation_max_length "${GENERATION_MAX_LENGTH}"
)

# The flag must go into the argument list itself; tacking it onto a command
# string after the trailing "&" would run it as a separate (failing) command.
if [[ "${FP16_TRAINING}" == true ]]; then
  train_cmd+=(--fp16)
fi

# Show a copy-pasteable, shell-quoted equivalent of what is launched.
echo "Executing command:"
printf 'CUDA_VISIBLE_DEVICES=%q nohup' "${CUDA_DEVICE}"
printf ' %q' "${train_cmd[@]}"
printf ' > %q 2>&1 &\n' "${LOG_FILE}"

CUDA_VISIBLE_DEVICES="${CUDA_DEVICE}" nohup "${train_cmd[@]}" > "${LOG_FILE}" 2>&1 &
train_pid=$!

# The job is detached, so report where to follow it rather than claiming it finished.
echo "Permutation training started in background (PID ${train_pid}); output: ${LOG_FILE}"