#!/bin/bash
#
# Script to run the Genetic Algorithm (GA) for permutation optimization.
#
# Launches src/permutation_explorer_global.py in the background (via nohup)
# on a single GPU and returns immediately. Run it from the repository root:
# the Python script, dataset and output paths below are all relative to it.
#
# Optional environment overrides:
#   GPU_ID        GPU index exposed through CUDA_VISIBLE_DEVICES (default: 2)
#   WANDB_ENTITY  if set and non-empty, forwarded as --wandb_entity
#
# The Python process's stdout/stderr go to "${OUTPUT_DIR}/log.txt". The
# background PID and the log path are printed so the run can be monitored.

set -euo pipefail

# --- Configuration ---
# Dataset and Model parameters (should match the experiment setup)
# DATASET_BASE_NAME="square_mod19" # Example: relu, square_mod19, index
gpu_id="${GPU_ID:-2}"
# DATASET_BASE_NAME="square_mod19"
# DATASET_BASE_NAME="relu"
DATASET_BASE_NAME="prod"
# DATASET_BASE_NAME="index"
# DATASET_BASE_NAME="gcd"
TARGET_LEN=10
# DATASET_NAME="n=${TARGET_LEN}/data" # Example: n50, n31, m19
# DATASET_NAME="n=${TARGET_LEN}_m=8/data"
DATASET_NAME="n=5/data-inv-hard2"
DATASET_PATH_PREFIX="data/data/small/${DATASET_BASE_NAME}/${DATASET_NAME}" # Path to .train and .test files
# INPUT_PREFIX_LEN=$TARGET_LEN
INPUT_PREFIX_LEN=11
GPT2_N_HEAD=8
GPT2_N_LAYER=6
GPT2_N_EMBD=512 # Should be consistent with pre-trained or intended model
MAX_SEQ_LENGTH=128

# TrainingArguments (minimal set needed for GA evaluation context)
PER_DEVICE_EVAL_BATCH_SIZE=128 # Adjust based on GPU memory
N_EPOCHS=2
FP16=false # Set to true if using fp16, ensure GPU compatibility
DATALOADER_NUM_WORKERS=0 # Can increase if I/O is a bottleneck

# Permutation Exploration Parameters
M_PARAM=7

# DATASET_NAME contains '/', so OUTPUT_DIR is a nested directory path.
OUTPUT_DIR="data/results/small/${DATASET_BASE_NAME}/pe_${DATASET_NAME}_tl=${TARGET_LEN}_M=${M_PARAM}_n_epochs=${N_EPOCHS}_v1_fliped_hard2"
# OUTPUT_DIR="data/results/small/${DATASET_BASE_NAME}/pe_${DATASET_NAME}_tl=${TARGET_LEN}_M=${M_PARAM}_n_epochs=${N_EPOCHS}_v1_fliped_easy_init_intrablock_permutation"
# OUTPUT_DIR="data/results/small/${DATASET_BASE_NAME}/pe_${DATASET_NAME}_tl=${TARGET_LEN}_M=${M_PARAM}_n_epochs=${N_EPOCHS}_v1_fliped_easy_init_block_permutation"
LOG_FILE="${OUTPUT_DIR}/log.txt"

# WandB Logging
WANDB_PROJECT="permutation_exploration"
# WANDB_ENTITY="your_wandb_entity" # Optional: specify your wandb entity (also read from the environment)
# NOTE: the run name uses "tl<N>" while OUTPUT_DIR uses "tl=<N>"; kept as-is so
# existing run names stay stable.
WANDB_RUN_NAME="pe_${DATASET_NAME}_tl${TARGET_LEN}_M=${M_PARAM}_n_epochs=${N_EPOCHS}"

PY_SCRIPT="src/permutation_explorer_global.py"

# Fail fast in the foreground: once the job is backgrounded with nohup, a
# missing script would only show up inside the log file.
if [[ ! -f "${PY_SCRIPT}" ]]; then
  printf 'error: %s not found; run this script from the repository root\n' "${PY_SCRIPT}" >&2
  exit 1
fi

mkdir -p -- "${OUTPUT_DIR}"

# --- Activate Virtual Environment (if any) ---
# source /path/to/your/venv/bin/activate

# --- Run the GA optimization script ---
# Arguments are collected in an array so every value is passed as exactly one
# word, regardless of the characters it contains.
ga_args=(
  --dataset_name "${DATASET_NAME}"
  --dataset_path_prefix "${DATASET_PATH_PREFIX}"
  --target_len "${TARGET_LEN}"
  --input_prefix_len "${INPUT_PREFIX_LEN}"
  --gpt2_n_head "${GPT2_N_HEAD}"
  --gpt2_n_layer "${GPT2_N_LAYER}"
  --gpt2_n_embd "${GPT2_N_EMBD}"
  --max_seq_length "${MAX_SEQ_LENGTH}"
  --output_dir "${OUTPUT_DIR}"
  --per_device_eval_batch_size "${PER_DEVICE_EVAL_BATCH_SIZE}"
  # The train batch size intentionally reuses the eval batch size.
  --per_device_train_batch_size "${PER_DEVICE_EVAL_BATCH_SIZE}"
  --fp16 "${FP16}"
  --dataloader_num_workers "${DATALOADER_NUM_WORKERS}"
  --num_train_epochs "${N_EPOCHS}"
  --do_train false
  --do_eval false
  --permutation_select_num 1
  --remove_unused_columns false
  --m_param "${M_PARAM}"
  --wandb_project "${WANDB_PROJECT}"
  --wandb_run_name "${WANDB_RUN_NAME}"
)
if [[ -n "${WANDB_ENTITY:-}" ]]; then
  ga_args+=(--wandb_entity "${WANDB_ENTITY}")
fi

CUDA_VISIBLE_DEVICES="${gpu_id}" nohup python3 "${PY_SCRIPT}" "${ga_args[@]}" \
  > "${LOG_FILE}" 2>&1 &
printf 'Started %s (PID %s) on GPU %s; log: %s\n' \
  "${PY_SCRIPT}" "$!" "${gpu_id}" "${LOG_FILE}"