#!/bin/bash
#
# Launch a background training run for the permutation-exploration experiments
# (src/permutation_explorer_local.py).
#
# Usage: edit the settings below, then run this script. The Python job is
# started with nohup in the background; its stdout/stderr go to
# ${OUTPUT_DIR}/log.txt.

# Abort on failed commands, unset variables, and failed pipeline stages.
set -euo pipefail

# Basic script settings
export CUDA_VISIBLE_DEVICES=1 # Specify the GPU to use
# TASK=square_mod19
# TASK=relu
TASK=index
TARGET_LEN=13 # N: Specify the length of the permutation (e.g., 10)


DATASET_PATH_PREFIX="data/data/small/${TASK}/n=${TARGET_LEN}_m=8/data" # Please specify the dataset prefix
# OUTPUT_DIR="data/results/small/${TASK}/pe_n=${TARGET_LEN}/N=${TARGET_LEN}_v5_good_init_2_all_k"
OUTPUT_DIR="data/results/small/${TASK}/pe_n=${TARGET_LEN}_m=8/N=${TARGET_LEN}_v5_good"

# WandB settings (optional)
WANDB_PROJECT="permutation_exploration"
WANDB_RUN_NAME="v5_N${TARGET_LEN}_good_init"
# WANDB_RUN_NAME="v5_N${TARGET_LEN}_rand_init"

# Model parameters, passed as CLI flags to src/permutation_explorer_local.py.
# NOTE(review): an older comment pointed at permutation_explorer_v3.py and
# main_permutation_loss_analysis.py for their definitions — confirm which
# module currently defines them.
GPT2_N_EMBD=512
GPT2_N_HEAD=1
GPT2_N_LAYER=6
MAX_SEQ_LENGTH=128 # Adjust to the maximum sequence length of the dataset
INPUT_PREFIX_LEN=${TARGET_LEN} # Length of the input before the target

# Training parameters
PER_DEVICE_TRAIN_BATCH_SIZE=128
PER_DEVICE_EVAL_BATCH_SIZE=128
NUM_TRAIN_EPOCHS=10
LEARNING_RATE=5e-5
WEIGHT_DECAY=0
LOGGING_STEPS=100
EVAL_STEPS=500 # Adjust so as not to evaluate too frequently

# Create the output directory; ':?' aborts if OUTPUT_DIR is empty/unset so we
# never fall back to creating/writing relative to the current directory root.
mkdir -p -- "${OUTPUT_DIR:?OUTPUT_DIR must be set}"

log_file="${OUTPUT_DIR}/log.txt"

# Build the argument list as an array so every value is passed as exactly one
# word, and so optional flags can be toggled by (un)commenting a single line.
# NOTE(review): these look like Hugging Face TrainingArguments; there,
# --eval_steps only takes effect with a step-based eval strategy
# (--eval_strategy / --evaluation_strategy steps) — confirm the script sets it.
train_args=(
  --output_dir "${OUTPUT_DIR}"
  --dataset_name "${DATASET_PATH_PREFIX}"
  --dataset_path_prefix "${DATASET_PATH_PREFIX}"
  --target_len "${TARGET_LEN}"
  --gpt2_n_embd "${GPT2_N_EMBD}"
  --gpt2_n_head "${GPT2_N_HEAD}"
  --gpt2_n_layer "${GPT2_N_LAYER}"
  --max_seq_length "${MAX_SEQ_LENGTH}"
  --input_prefix_len "${INPUT_PREFIX_LEN}"
  --do_train
  --do_eval
  --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE}"
  --per_device_eval_batch_size "${PER_DEVICE_EVAL_BATCH_SIZE}"
  --num_train_epochs "${NUM_TRAIN_EPOCHS}"
  --learning_rate "${LEARNING_RATE}"
  --weight_decay "${WEIGHT_DECAY}"
  --logging_steps "${LOGGING_STEPS}"
  --permutation_select_num 1
  --eval_steps "${EVAL_STEPS}"
  --save_strategy steps
  --save_steps "${EVAL_STEPS}"
  --save_total_limit 1
  --fp16
  --report_to wandb
  --wandb_project "${WANDB_PROJECT}"
  --wandb_run_name "${WANDB_RUN_NAME}"
  # --dataloader_num_workers 4      # Set according to the environment
  # --remove_unused_columns False   # PermutationExperimentDataCollator may require permutation_idx
)

# Run detached from the terminal; all output goes to the log file.
nohup python3 src/permutation_explorer_local.py "${train_args[@]}" \
  > "${log_file}" 2>&1 &
train_pid=$!

# The job is still running in the background at this point — report how to
# follow it rather than claiming it has finished.
printf 'Training launched in background (PID %s).\n' "${train_pid}"
printf 'Log: %s\n' "${log_file}"
printf 'Results will be written to %s\n' "${OUTPUT_DIR}"