#!/usr/bin/env bash
#
# Launch the new permutation-exploration algorithm
# (src/permutation_explorer_new.py) for one task / permutation length,
# reporting to Weights & Biases.
#
# The main knobs can be overridden from the environment; when a variable is
# unset or empty the default shown below is used, e.g.
#   TASK=square_mod19 TARGET_LEN=6 CUDA_VISIBLE_DEVICES=0 ./this_script.sh
# Overridable: CUDA_VISIBLE_DEVICES, TASK, TARGET_LEN, T_CANDIDATES, TOP_K_LOCAL.
set -euo pipefail

# Print an error message to stderr and abort.
die() { printf 'error: %s\n' "$*" >&2; exit 1; }

# Basic script settings
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-7}" # GPU to use
# Other task previously used here: square_mod19
TASK="${TASK:-relu}"
TARGET_LEN="${TARGET_LEN:-8}" # N: length of the permutation

# --- New Algorithm Parameters ---
# Number of initial candidates. Original note: "set so that K! >= T".
# K is not defined in this script (default T=24 = 4!); presumably it is a
# length used inside the explorer — TODO confirm.
T_CANDIDATES="${T_CANDIDATES:-24}"
# Number of top-ranked permutations passed to the local-search stage
# (forwarded as --top_k_local_search).
TOP_K_LOCAL="${TOP_K_LOCAL:-2}"

# These values are interpolated into paths and CLI flags, so reject anything
# that is not a positive integer before building paths from them.
[[ "$TARGET_LEN" =~ ^[1-9][0-9]*$ ]] || die "TARGET_LEN must be a positive integer (got '$TARGET_LEN')"
[[ "$T_CANDIDATES" =~ ^[1-9][0-9]*$ ]] || die "T_CANDIDATES must be a positive integer (got '$T_CANDIDATES')"
[[ "$TOP_K_LOCAL" =~ ^[1-9][0-9]*$ ]] || die "TOP_K_LOCAL must be a positive integer (got '$TOP_K_LOCAL')"

# Output directory
OUTPUT_DIR="data/results/small/${TASK}/pe_n=${TARGET_LEN}/N=${TARGET_LEN}_new_alg_T${T_CANDIDATES}"
DATASET_PATH_PREFIX="data/data/small/${TASK}/n=${TARGET_LEN}/data"

# WandB settings (optional)
WANDB_PROJECT="permutation_exploration_new"
WANDB_RUN_NAME="new_alg_${TASK}_N${TARGET_LEN}_T${T_CANDIDATES}"

# Model parameters
GPT2_N_EMBD=256
GPT2_N_HEAD=1
GPT2_N_LAYER=1
MAX_SEQ_LENGTH=128
INPUT_PREFIX_LEN=${TARGET_LEN} # input prefix length tracks the permutation length

# Training parameters
PER_DEVICE_TRAIN_BATCH_SIZE=128
PER_DEVICE_EVAL_BATCH_SIZE=128
NUM_TRAIN_EPOCHS=1
LEARNING_RATE=5e-5
WEIGHT_DECAY=0
LOGGING_STEPS=100
EVAL_STEPS=500

# Create the output directory, then run the explorer with the settings above.
# Arguments are collected in an array so every value is passed as exactly one
# word, even if a path ever contains spaces or glob characters.
if ! mkdir -p -- "${OUTPUT_DIR}"; then
  printf 'error: cannot create output directory %s\n' "${OUTPUT_DIR}" >&2
  exit 1
fi

# NOTE(review): --dataset_name receives the same value as
# --dataset_path_prefix; kept as-is, confirm this is intended.
explorer_args=(
  --output_dir "${OUTPUT_DIR}"
  --dataset_path_prefix "${DATASET_PATH_PREFIX}"
  --dataset_name "${DATASET_PATH_PREFIX}"
  --target_len "${TARGET_LEN}"
  --initial_candidates_T "${T_CANDIDATES}"
  --top_k_local_search "${TOP_K_LOCAL}"
  --gpt2_n_embd "${GPT2_N_EMBD}"
  --gpt2_n_head "${GPT2_N_HEAD}"
  --gpt2_n_layer "${GPT2_N_LAYER}"
  --max_seq_length "${MAX_SEQ_LENGTH}"
  --input_prefix_len "${INPUT_PREFIX_LEN}"
  --do_train
  --do_eval
  --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE}"
  --per_device_eval_batch_size "${PER_DEVICE_EVAL_BATCH_SIZE}"
  --num_train_epochs "${NUM_TRAIN_EPOCHS}"
  --learning_rate "${LEARNING_RATE}"
  --weight_decay "${WEIGHT_DECAY}"
  --logging_steps "${LOGGING_STEPS}"
  --eval_steps "${EVAL_STEPS}"
  --save_strategy steps
  --save_steps "${EVAL_STEPS}" # checkpoint on the same cadence as evaluation
  --save_total_limit 1
  --fp16
  --report_to wandb
  --permutation_select_num 1
  --wandb_project "${WANDB_PROJECT}"
  --wandb_run_name "${WANDB_RUN_NAME}"
)

# Capture the exit status explicitly (this also works under `set -e`) so a
# failed run is reported instead of being announced as "finished".
explorer_status=0
python3 src/permutation_explorer_new.py "${explorer_args[@]}" || explorer_status=$?
if (( explorer_status != 0 )); then
  printf 'error: permutation_explorer_new.py exited with status %d; see %s\n' \
    "${explorer_status}" "${OUTPUT_DIR}" >&2
  exit "${explorer_status}"
fi

echo "Script finished. Results are in ${OUTPUT_DIR}"