#!/bin/bash

# This script launches a sweep of permutation-loss-analysis training runs
# (one per dataset x permutation number x epoch count), each as a background
# job on its own GPU, logging results to Weights & Biases.

# --- Common training parameters ---
# TODO: Adjust these common parameters as needed for your specific setup.

# Index of the first GPU to use. The main loop gives every run its own GPU and
# increments this after each launch, so the sweep needs one GPU per
# (dataset, permutation number, epoch count) combination.
device_id=0
# One run is launched per entry (e.g. NUM_EPOCHS=(2) for a single run).
NUM_EPOCHS=(1 2 4 8)
# Task name; used to build the results directory
#   data/results/small/${TASK_NAME}/<run_name>
TASK_NAME="prod"
# TASK_NAME="gcd"
# TASK_NAME="relu"
# TASK_NAME="square_mod19"
# TASK_NAME="index"
BATCH_SIZE=128       # per-device train batch size
LOGGING_STEPS=100
SAVE_STEPS=500       # Adjust if you want to save checkpoints
MAX_SEQ_LENGTH=128   # Should match tokenizer and model config
GPT2_N_EMBD=512
LEARNING_RATE=5e-5

# --- Permutation settings ---
# Passed to the training script as --permutation_type
# (alternatives used before: "all", "random_one", "family").
# PERMUTATION_TYPE="all"
# PERMUTATION_TYPE="random_one"
PERMUTATION_TYPE="family"
# Passed as --permutation_select_num; one run is launched per entry, so an
# empty list launches no runs.
# PERMUTATION_SELECT_NUMS=(0)
PERMUTATION_SELECT_NUMS=(8)
# PERMUTATION_SELECT_NUMS=()

# --- Weights & Biases ---
# TODO: Set the W&B project name that runs should be logged to.
WANDB_PROJECT_NAME="permutation_loss_experiments"
# Optional W&B entity: uncomment the line below, or export WANDB_ENTITY in the
# environment before running this script. When it is unset or empty, no
# --wandb_entity flag is passed to the training script.
# WANDB_ENTITY="your_wandb_entity"

# These settings are fixed for the whole sweep; guard against accidental
# reassignment further down. (device_id is deliberately left mutable.)
readonly TASK_NAME NUM_EPOCHS BATCH_SIZE LOGGING_STEPS SAVE_STEPS \
    MAX_SEQ_LENGTH GPT2_N_EMBD LEARNING_RATE PERMUTATION_TYPE \
    PERMUTATION_SELECT_NUMS WANDB_PROJECT_NAME

# --- Dataset-specific configurations ---
# TODO: CRITICAL - Adjust target_len and input_prefix_len for each dataset.
# These values depend on your data and tokenizer.
#
# Each entry maps a dataset name (used in run names and passed to the training
# script as --dataset_name) to a whitespace-separated triple:
#   "<path_prefix> <target_len> <input_prefix_len>"
#     target_len:       length (in tokens) of the sequence part to be permuted.
#     input_prefix_len: length (in tokens) of the static prefix before the
#                       permutable target.
# The main loop splits the value on whitespace, so paths must not contain
# spaces. Uncomment entries below to add datasets to the sweep.
declare -A dataset_configs=(
    # ["relu_n=5"]="/mnt/nfs/data/small/relu/n=5/data 5 5"
    # ["relu_n=13"]="data/data/small/relu/n=13/data 13 13"
    # ["relu_n=50"]="/mnt/nfs/data/small/relu/n=50/data 50 50"
    # ["relu_n=100"]="/mnt/nfs/data/small/relu/n=100/data 100 100"
    # ["square_mod19_n=5"]="/mnt/nfs/data/small/square_mod19/n=5/data 5 5"
    # ["square_mod19_n=13"]="data/data/small/square_mod19/n=13/data 13 13"
    # ["square_mod19_n=20"]="data/data/small/square_mod19/n=20/data 20 20"
    # ["index_n=13_m=2"]="data/data/small/index/n=13_m=2/data 13 13"
    # ["index_n=31_m=2"]="/mnt/nfs/data/small/index/n=31_m=2/data 31 31"
    # NOTE(review): the key says n=10 but the path is prod/n=5 — presumably
    # n=5 operands with a 10-token target; confirm.
    ["prod_n=10_inv_hard"]="data/data/small/prod/n=5/data-inv-hard 10 11"
)


# --- Main experiment loop ---
# Launch one background training job per (dataset, permutation number, epoch
# count) combination. device_id is incremented after every launch, so each run
# gets its own GPU. Each job's stdout/stderr is written to
# <output_dir>/nohup.out.
#
# A command started with '&' always reports exit status 0 at launch time, so a
# run's success or failure can only be observed by waiting on it. The script
# therefore waits for every launched job at the end and exits with status 1 if
# any run failed (or could not be launched).
# NOTE: this means the script blocks until the whole sweep has completed.

# Optional W&B entity, read from WANDB_ENTITY (set in this script or exported).
# An array keeps the flag and its value as separate, correctly quoted words and
# expands to nothing when no entity is configured. Loop-invariant, so built once.
wandb_entity_args=()
if [[ -n "${WANDB_ENTITY:-}" ]]; then
    wandb_entity_args=(--wandb_entity "${WANDB_ENTITY}")
fi

launched_pids=()   # PIDs of the background training jobs
launched_runs=()   # run names, parallel to launched_pids
failed=0           # runs that failed to launch or exited non-zero

for dataset_key in "${!dataset_configs[@]}"; do
    # Config value format: "<path_prefix> <target_len> <input_prefix_len>".
    # Any extra trailing tokens are ignored.
    read -r dataset_path_prefix target_len input_prefix_len _ \
        <<< "${dataset_configs[$dataset_key]}"
    if [[ -z "${input_prefix_len}" ]]; then
        echo "Error: malformed config for dataset '${dataset_key}' (expected 'path_prefix target_len input_prefix_len'). Skipping." >&2
        failed=$((failed + 1))
        continue
    fi

    echo "----------------------------------------------------------------------"
    echo "Starting experiments for dataset: ${dataset_key}"
    echo "Path prefix: ${dataset_path_prefix}, Target len: ${target_len}, Input prefix len: ${input_prefix_len}"
    echo "----------------------------------------------------------------------"

    for perm_num in "${PERMUTATION_SELECT_NUMS[@]}"; do
        for num_epochs in "${NUM_EPOCHS[@]}"; do
            run_name="${dataset_key}_perm${perm_num}_target${target_len}_prefix${input_prefix_len}_${num_epochs}epochs_bs${BATCH_SIZE}_${PERMUTATION_TYPE}_v5"
            output_dir="data/results/small/${TASK_NAME}/${run_name}"
            if ! mkdir -p -- "${output_dir}"; then
                echo "Error: could not create output directory ${output_dir}; skipping ${run_name}." >&2
                failed=$((failed + 1))
                continue
            fi

            echo ""
            echo ">>> Running: ${run_name}"
            echo "Output directory: ${output_dir}"
            echo "Permutation select num: ${perm_num}"
            echo ""

            CUDA_VISIBLE_DEVICES="${device_id}" nohup python3 src/main_permutation_loss_analysis.py \
                --output_dir "${output_dir}" \
                --remove_unused_columns False \
                --dataset_name "${dataset_key}" \
                --dataset_path_prefix "${dataset_path_prefix}" \
                --target_len "${target_len}" \
                --permutation_select_num "${perm_num}" \
                --permutation_type "${PERMUTATION_TYPE}" \
                --input_prefix_len "${input_prefix_len}" \
                --save_total_limit 1 \
                --gpt2_n_head 8 \
                --gpt2_n_layer 6 \
                --gpt2_n_embd "${GPT2_N_EMBD}" \
                --max_seq_length "${MAX_SEQ_LENGTH}" \
                --per_device_train_batch_size "${BATCH_SIZE}" \
                --num_train_epochs "${num_epochs}" \
                --logging_steps "${LOGGING_STEPS}" \
                --save_steps "${SAVE_STEPS}" \
                --learning_rate "${LEARNING_RATE}" \
                --report_to "wandb" \
                --wandb_project "${WANDB_PROJECT_NAME}" \
                --wandb_run_name "${run_name}" \
                "${wandb_entity_args[@]}" \
                --do_train \
                --do_eval > "${output_dir}/nohup.out" 2>&1 &
            launched_pids+=("$!")
            launched_runs+=("${run_name}")
            echo "<<< Launched: ${run_name} (PID $!, GPU ${device_id}, log: ${output_dir}/nohup.out)"
            device_id=$((device_id + 1))
        done
    done

    echo "----------------------------------------------------------------------"
    echo "Launched all experiments for dataset: ${dataset_key}"
    echo "----------------------------------------------------------------------"
done

# Reap every job and collect its real exit status.
echo ""
echo "Waiting for ${#launched_pids[@]} experiment(s) to finish..."
for i in "${!launched_pids[@]}"; do
    status=0
    wait "${launched_pids[i]}" || status=$?
    if (( status == 0 )); then
        echo "<<< Finished: ${launched_runs[i]}"
    else
        echo "Error during experiment: ${launched_runs[i]} (exit status ${status})." >&2
        failed=$((failed + 1))
    fi
done

echo ""
echo "----------------------------------------------------------------------"
if (( failed > 0 )); then
    echo "${failed} experiment(s) failed; see the errors above." >&2
    echo "----------------------------------------------------------------------"
    exit 1
fi
echo "All experiments finished successfully."
echo "----------------------------------------------------------------------"