#!/bin/bash

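# Runs one training/evaluation experiment per permutation, in the order
# produced by src/sort_permutations.py from a per-permutation loss JSON
# (best loss first), and records each run's final test accuracy in a CSV.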

# --- Configuration ---
# Defaults for every run. On the command line, -j/-a/-n/-s/-p/-o override
# LOSS_JSON_FILE, COMMON_ARGS, PERMUTATION_NUM, MAIN_SCRIPT_PATH, PYTHON_CMD
# and BASE_SAVE_PATH_PREFIX respectively; everything else is only set here.
# NOTE: COMMON_ARGS and BASE_SAVE_PATH_PREFIX are expanded right here, before
# option parsing, so they embed the default PERMUTATION_NUM even if -n is used.

# GPU index exported as CUDA_VISIBLE_DEVICES for each training run.
gpu_id=3
RESULT_DIR="data/results/small/"

# DATA_PATH="/mnt/nfs/data/small/relu/n=5/data"
# DATA_PATH="/mnt/nfs/data/small/relu/n=50/data"
# DATA_PATH="/mnt/nfs/data/small/square_mod19/n=5/data"
DATA_PATH="data/data/small/index/n=13_m=2/data"

# n_exponent for make_perm_family (must match the one used to generate the
# initial loss file). Not currently forwarded to the main script.
# PERMUTATION_N_EXP=0
PERMUTATION_TYPE="family"
PERMUTATION_NUM=8
TARGET_LEN=13
EPOCHS=40
LAYERS=6
HEADS=1

# Path to the JSON file containing permutation losses.
# Example: {"0": 2.1, "1": 2.0}
# The commented alternatives that use PERMUTATION_N_EXP_SQUARE need that
# variable to be defined first; nothing in this script sets it.
# LOSS_JSON_FILE="/mnt/nfs/results/small/square_mod19/square_mod19_n=5_perm0_target5_prefix5_1epochs_fixed_batch_bs120/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="/mnt/nfs/results/small/relu/relu_n=5_perm${PERMUTATION_NUM}_target5_prefix5_${PERMUTATION_TYPE}/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="${RESULT_DIR}relu/relu_n=50_perm${PERMUTATION_N_EXP_SQUARE}_target50_prefix50_1epochs_fixed_batch/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="${RESULT_DIR}square_mod19/square_mod19_n=50_perm${PERMUTATION_N_EXP_SQUARE}_target50_prefix50_1epochs/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="${RESULT_DIR}index/index_n=13_m=4_perm${PERMUTATION_N_EXP_SQUARE}_target13_prefix13_1epochs/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="${RESULT_DIR}index/index_n=31_m=2_perm${PERMUTATION_N_EXP_SQUARE}_target31_prefix31_8epochs_bs128_family/all_permutations_eval_losses.json"
# LOSS_JSON_FILE="data/results/small/index/index_n=31_m=2_perm8_target31_prefix31_32epochs_bs128_family/all_permutations_eval_losses.json"
LOSS_JSON_FILE="data/results/small/index/index_n=13_m=2_perm8_target13_prefix13_2epochs_fixed_batch_bs128/all_permutations_eval_losses.json"
# Path to the main python script for training/evaluation
MAIN_SCRIPT_PATH="src/main_ranked_permutation_experiment.py"

# Python interpreter (may include extra words, e.g. "python3 -u")
PYTHON_CMD="python3"

# Prefix shared by every per-permutation output directory
# (<prefix>_rank<R>_perm<ID>) and by the results CSV
# (<prefix>_ranked_accuracies.txt).
# BASE_SAVE_PATH_PREFIX="${RESULT_DIR}square_mod19/loss_sorted_perm${PERMUTATION_NUM}_${PERMUTATION_TYPE}/ranked_exp"
# BASE_SAVE_PATH_PREFIX="${RESULT_DIR}square_mod19/loss_sorted_perm${PERMUTATION_N_EXP_SQUARE}_fixed_epochs/ranked_exp"
BASE_SAVE_PATH_PREFIX="${RESULT_DIR}index/loss_sorted_perm${PERMUTATION_NUM}_${PERMUTATION_TYPE}_n=${TARGET_LEN}_m=2/ranked_exp"

# Common arguments passed to MAIN_SCRIPT_PATH on every run.
# The string is split on whitespace, so individual values must not contain
# spaces. Do NOT include --target_len, --permutation_id, --permutation_num,
# --permutation_type or --save_path: those are appended per run.
# Example: COMMON_ARGS="--data_path ./data/my_data --exp_name my_experiment --epochs 10"

# COMMON_ARGS="--data_path $DATA_PATH --exp_name square_mod19_n=5_perm${PERMUTATION_NUM}_target5_prefix5_1epochs --epochs 10 --num_encoder_layers 1 --num_decoder_layers 6 --nhead 1 --target_len $TARGET_LEN"
# COMMON_ARGS="--data_path $DATA_PATH --exp_name relu_n=50_perm32_target50_prefix50_1epochs --epochs 2 --num_encoder_layers 1 --num_decoder_layers 1 --nhead 1"
# COMMON_ARGS="--data_path $DATA_PATH --exp_name square_mod19_n=50_perm32_target50_prefix50_2epochs --epochs 2 --num_encoder_layers 1 --num_decoder_layers 1 --nhead 1"
# COMMON_ARGS="--data_path $DATA_PATH --exp_name index_n=13_m=2_perm8_target13_prefix13_2epochs_fixed_batch_bs128 --epochs 2 --num_encoder_layers 1 --nhead 1"

# Experiment name derived from the settings above so it cannot drift out of
# sync with the dataset size, permutation count or epoch count.
# NOTE(review): "m=2" and "fixed_batch_bs128" are not derived from any variable
# here; confirm they match DATA_PATH and the batch size the main script uses.
EXP_NAME="index_n=${TARGET_LEN}_m=2_perm${PERMUTATION_NUM}_target${TARGET_LEN}_prefix${TARGET_LEN}_${EPOCHS}epochs_fixed_batch_bs128"
COMMON_ARGS="--data_path $DATA_PATH --exp_name $EXP_NAME --epochs $EPOCHS --num_encoder_layers $LAYERS --num_decoder_layers $LAYERS --nhead $HEADS"

# --- Argument Parsing ---
#######################################
# Print command-line help and exit.
# Globals:   LOSS_JSON_FILE, COMMON_ARGS, PERMUTATION_NUM, MAIN_SCRIPT_PATH,
#            PYTHON_CMD, BASE_SAVE_PATH_PREFIX (read; shown as defaults)
# Arguments: $1 - exit status (optional, default 1). Status 0 (explicit help
#            request) prints to stdout; any other status prints to stderr.
# Returns:   never; always exits with the given status
#######################################
usage() {
  local status=${1:-1}
  local fd=2
  if [[ "$status" == 0 ]]; then
    fd=1
  fi
  {
    echo "Usage: $0 [-j <loss_json_file>] [-a \"<common_args_for_main_py>\"] [-n <permutation_num>] [-s <main_script_path>] [-p <python_cmd>] [-o <base_output_dir_prefix>] [-h]"
    echo ""
    echo "  -j <loss_json_file>  : Path to the JSON file with permutation losses (Default: $LOSS_JSON_FILE)."
    echo "  -a \"<common_args>\"   : Quoted string of common arguments for the main training script"
    echo "                         (Default: $COMMON_ARGS)."
    echo "                         Example: \"--data_path ./data/data_sum --task sum --model gpt2 ...\""
    echo "  -n <perm_num>        : The number of permutations to generate (Default: $PERMUTATION_NUM)."
    echo "  -s <main_script>     : Path to the main Python script (Default: $MAIN_SCRIPT_PATH)."
    echo "  -p <python_cmd>      : Python command to use (Default: $PYTHON_CMD)."
    echo "  -o <output_prefix>   : Prefix for saving experiment outputs (Default: $BASE_SAVE_PATH_PREFIX)."
    echo "  -h                   : Show this help and exit."
    echo ""
    echo "  The values for -j, -a and -n must end up non-empty."
  } >&"$fd"
  exit "$status"
}

# Track which defaults were overridden so inconsistent combinations can be
# flagged after parsing.
perm_num_given=0
common_args_given=0
output_prefix_given=0
while getopts "j:a:n:s:p:o:h" opt; do
  case "$opt" in
    j) LOSS_JSON_FILE=$OPTARG ;;
    a) COMMON_ARGS=$OPTARG; common_args_given=1 ;;
    n) PERMUTATION_NUM=$OPTARG; perm_num_given=1 ;;
    s) MAIN_SCRIPT_PATH=$OPTARG ;;
    p) PYTHON_CMD=$OPTARG ;;
    o) BASE_SAVE_PATH_PREFIX=$OPTARG; output_prefix_given=1 ;;
    h) usage 0 ;;  # explicit help request
    *) usage ;;    # unknown option or missing option argument
  esac
done

if [ -z "$LOSS_JSON_FILE" ] || [ -z "$COMMON_ARGS" ] || [ -z "$PERMUTATION_NUM" ]; then
  echo "Error: Missing required arguments." >&2
  usage
fi

# The permutation number is forwarded as --permutation_num and embedded in
# output paths; reject anything that is not a plain non-negative integer.
if ! [[ "$PERMUTATION_NUM" =~ ^[0-9]+$ ]]; then
  echo "Error: Permutation number must be a non-negative integer, got '$PERMUTATION_NUM'." >&2
  usage
fi

# The default COMMON_ARGS (its --exp_name) and BASE_SAVE_PATH_PREFIX were fixed
# before option parsing and name the script's default permutation number, so
# overriding only -n leaves them describing the wrong run.
if (( perm_num_given && (! common_args_given || ! output_prefix_given) )); then
  echo "Warning: -n was given without both -a and -o; the default common args / output prefix still embed the script's default permutation number." >&2
fi

if [ ! -f "$LOSS_JSON_FILE" ]; then
  echo "Error: Loss JSON file not found at $LOSS_JSON_FILE" >&2
  exit 1
fi

if [ ! -f "$MAIN_SCRIPT_PATH" ]; then
  echo "Error: Main script not found at $MAIN_SCRIPT_PATH" >&2
  exit 1
fi

echo "--- Experiment Configuration ---"
echo "Loss JSON File: $LOSS_JSON_FILE"
echo "Main Script: $MAIN_SCRIPT_PATH"
echo "Python Command: $PYTHON_CMD"
echo "Base Save Path Prefix: $BASE_SAVE_PATH_PREFIX"
echo "Common Arguments for Main Script: $COMMON_ARGS"
echo "Permutation num: $PERMUTATION_NUM"
echo "------------------------------"
echo ""

# --- Main Logic ---

#######################################
# Extract the test accuracy reported by one training run.
# Arguments: $1 - captured stdout of the run
# Outputs:   the whitespace-trimmed text after the LAST "Final Test Accuracy:"
#            marker, or nothing when the marker is absent or has no value
#######################################
extract_accuracy() {
  local line value
  line=$(printf '%s\n' "$1" | grep -F 'Final Test Accuracy:' | tail -n 1)
  if [[ -z "$line" ]]; then
    return 0
  fi
  # Take everything after the marker, so colons earlier on the line
  # (e.g. timestamps) do not shift the field.
  value=${line##*"Final Test Accuracy:"}
  value=${value#"${value%%[![:space:]]*}"}   # trim leading whitespace
  value=${value%"${value##*[![:space:]]}"}   # trim trailing whitespace
  if [[ -n "$value" ]]; then
    printf '%s\n' "$value"
  fi
  return 0
}

# 1. Get sorted permutation IDs
echo "Step 1: Sorting permutations by loss from $LOSS_JSON_FILE..."
# The path is relative to the current directory, so run from the project root.
SORT_SCRIPT_PATH="src/sort_permutations.py"
if [ ! -f "$SORT_SCRIPT_PATH" ]; then
    echo "Error: sort_permutations.py not found at $SORT_SCRIPT_PATH. Make sure you run this script from the project root." >&2
    exit 1
fi

# PYTHON_CMD and COMMON_ARGS may hold several words. Split them into arrays on
# whitespace (like the old unquoted expansion did, but without pathname
# globbing). read returns non-zero at EOF with -d '', which is expected.
read -r -d '' -a python_cmd <<< "$PYTHON_CMD" || true
read -r -d '' -a common_args <<< "$COMMON_ARGS" || true

if ! SORTED_PERM_IDS=$("${python_cmd[@]}" "$SORT_SCRIPT_PATH" "$LOSS_JSON_FILE"); then
    echo "Error: Failed to sort permutations. Check output from sort_permutations.py." >&2
    exit 1
fi
if [ -z "$SORTED_PERM_IDS" ]; then
    echo "Error: No permutation IDs were output by sort_permutations.py. Is the JSON file empty or malformed?" >&2
    exit 1
fi

echo "Sorted Permutation IDs (best loss first):"
echo "$SORTED_PERM_IDS"
echo ""

# 2. Iterate through sorted IDs, run training/evaluation, and collect results
echo "Step 2: Running experiments for each permutation..."
RESULTS_FILE="${BASE_SAVE_PATH_PREFIX}_ranked_accuracies.txt"
# Create the results file's directory up front. The per-run `mkdir -p` in the
# loop below creates the same parent directory, but only after the header
# line would already have failed to write on a fresh run.
if ! mkdir -p -- "$(dirname -- "$RESULTS_FILE")" \
    || ! echo "Rank,PermutationID,Accuracy" > "$RESULTS_FILE"; then
    echo "Error: Could not create results file $RESULTS_FILE" >&2
    exit 1
fi
echo "Results will be saved to $RESULTS_FILE"
echo ""

RANK=0
while IFS= read -r PERM_ID; do
  if [ -z "$PERM_ID" ]; then
    continue # Skip empty lines if any
  fi

  RANK=$((RANK + 1))
  echo "--- Running for Rank $RANK, Permutation ID: $PERM_ID ---"

  # Define a specific save path for this run
  CURRENT_SAVE_PATH="${BASE_SAVE_PATH_PREFIX}_rank${RANK}_perm${PERM_ID}"
  mkdir -p -- "$CURRENT_SAVE_PATH"

  # Build the command as an array so each value stays one argument. The
  # per-run flags are appended after the common args.
  cmd=("${python_cmd[@]}" "$MAIN_SCRIPT_PATH" "${common_args[@]}"
    --target_len "$TARGET_LEN"
    --permutation_id "$PERM_ID"
    --permutation_num "$PERMUTATION_NUM"
    --permutation_type "$PERMUTATION_TYPE"
    --save_path "$CURRENT_SAVE_PATH")

  echo "Executing: CUDA_VISIBLE_DEVICES=$gpu_id ${cmd[*]}"

  # stdin comes from /dev/null so the child cannot consume the remaining
  # permutation IDs that feed this loop.
  OUTPUT=$(CUDA_VISIBLE_DEVICES=$gpu_id "${cmd[@]}" < /dev/null)
  EXIT_CODE=$?

  if [ "$EXIT_CODE" -ne 0 ]; then
    echo "Error: Experiment failed for Permutation ID $PERM_ID (Rank $RANK). Exit code: $EXIT_CODE"
    echo "Output:"
    echo "$OUTPUT"
    ACCURACY="ERROR"
  else
    ACCURACY=$(extract_accuracy "$OUTPUT")
    if [ -n "$ACCURACY" ]; then
      echo "Permutation ID $PERM_ID (Rank $RANK) completed. Accuracy: $ACCURACY"
    else
      echo "Warning: Could not parse accuracy for Permutation ID $PERM_ID (Rank $RANK)."
      echo "Full output:"
      echo "$OUTPUT"
      ACCURACY="PARSE_ERROR"
    fi
  fi

  echo "$RANK,$PERM_ID,$ACCURACY" >> "$RESULTS_FILE"
  echo "--------------------------------------------------"
  echo ""

done <<< "$SORTED_PERM_IDS"

echo "All experiments completed."
echo "Final ranked accuracies saved to: $RESULTS_FILE"

exit 0


# debug command
# CUDA_VISIBLE_DEVICES=7 python3 src/main_ranked_permutation_experiment.py --data_path /mnt/nfs/data/small/relu/n=50/data --exp_name relu_n=50_perm8_target50_prefix50_2epochs --epochs 2 --permutation_id 4 --permutation_num 3