#!/usr/bin/env bash
# train_and_evaluate.sh
# Description: Sequentially run training tasks, then evaluate models in parallel
#
# This script uses bash-only features (arrays, ${var//pattern/replacement}),
# so it must be interpreted by bash; the shebang above guarantees that when the
# file is executed directly.
# Replace the YOUR_* placeholder values below before running.

# ----------------- Basic Configuration -----------------
# Model paths and names
MODEL_PATH="YOUR_MODEL_PATH"   # passed to train.py and evaluate.py as --model_path
MODEL_NAME="YOUR_MODEL_NAME"   # used in the training output path; passed to evaluate.py as --model_name

# Base output directories
BASE_OUTPUT_DIR="./outputs"                 # Training outputs root directory
EVAL_BASE_OUTPUT_DIR="./evaluation_results" # Evaluation results root directory

# GPU ID (exported as CUDA_VISIBLE_DEVICES for every training/evaluation process)
GPU_ID="YOUR_GPU_ID"

# Training parameters (forwarded to train.py as command-line options)
TRAIN_BATCH_SIZE=8   # --batch_size
EPOCHS=1             # --epochs
LEARNING_RATE=5e-5   # --learning_rate
ADV_STEPS=5          # --adv_steps
ADV_LR=0.05          # --adv_lr

# Dataset settings
TRAINING_DATASET="SNLI"  # Dataset for training (SNLI, MultiNLI, MTBench, SummEval)
EVALUATION_DATASETS=("SNLI") # Example: ("MTBench" "SNLI" "SummEval")

# Perturbation and Alpha parameters (consistent with training script).
# Training runs once per (epsilon, alpha) pair; both lists are also handed to
# evaluate.py as --epsilon_values / --alpha_values.
EPSILON_VALUES=(0.0 0.05 0.1 0.15 0.2 0.25)
ALPHA_VALUES=(0.0 0.2 0.4 0.6 0.8 1.0)


# Python environment path and the scripts it runs
PYTHON_EXEC="YOUR_PYTHON_EXEC"
TRAIN_SCRIPT="train.py"
EVAL_SCRIPT="evaluate.py"

# ----------------- Training Phase ------------------------
# Runs the training script once for every (epsilon, alpha) combination, one
# after another. Each run writes to
#   $BASE_OUTPUT_DIR/$MODEL_NAME/<dataset short name>/eps_<epsilon>/alpha_<alpha>
# where dots in the numeric values are replaced by underscores.
# A failed run is reported and counted, but the sweep continues with the next
# combination.
echo "========================================================"
echo "Starting Training Phase"
echo "Training Dataset: $TRAINING_DATASET"
echo "GPU ID: $GPU_ID"
echo "Time: $(date)"
echo "========================================================"

# Extract short name from training dataset name (unknown names are used as-is)
case "$TRAINING_DATASET" in
  "SNLI") DATASET_SHORT="snli" ;;
  "MultiNLI") DATASET_SHORT="multinli" ;;
  "MTBench") DATASET_SHORT="mtbench" ;;
  "SummEval") DATASET_SHORT="summeval" ;;
  *) DATASET_SHORT="$TRAINING_DATASET" ;;
esac

# Number of training runs that exited with a non-zero status
train_failures=0

# Training loop
for epsilon in "${EPSILON_VALUES[@]}"; do
    epsilon_formatted=${epsilon//./_}
    for alpha in "${ALPHA_VALUES[@]}"; do
        alpha_formatted=${alpha//./_}
        OUTPUT_DIR="$BASE_OUTPUT_DIR/$MODEL_NAME/$DATASET_SHORT/eps_${epsilon_formatted}/alpha_${alpha_formatted}"
        mkdir -p "$OUTPUT_DIR"

        echo "----------------------------------------"
        echo "Starting training: epsilon = $epsilon, alpha = $alpha"
        echo "Output directory: $OUTPUT_DIR"
        echo "Time: $(date)"
        echo "----------------------------------------"

        # Arguments are quoted so paths containing spaces or glob characters
        # reach the training script intact. $PYTHON_EXEC is deliberately left
        # unquoted so it may hold a multi-word command (e.g. "python -u").
        # shellcheck disable=SC2086
        CUDA_VISIBLE_DEVICES="$GPU_ID" $PYTHON_EXEC "$TRAIN_SCRIPT" \
            --output_dir "$OUTPUT_DIR" \
            --batch_size "$TRAIN_BATCH_SIZE" \
            --epochs "$EPOCHS" \
            --learning_rate "$LEARNING_RATE" \
            --model_path "$MODEL_PATH" \
            --alpha "$alpha" \
            --epsilon "$epsilon" \
            --dataset "$TRAINING_DATASET" \
            --adv_steps "$ADV_STEPS" \
            --adv_lr "$ADV_LR"
        train_status=$?

        echo "----------------------------------------"
        if [ "$train_status" -eq 0 ]; then
            echo "Training completed: epsilon = $epsilon, alpha = $alpha"
        else
            # Do not claim success when the training script failed.
            echo "Training FAILED (exit code $train_status): epsilon = $epsilon, alpha = $alpha" >&2
            train_failures=$((train_failures + 1))
        fi
        echo "Time: $(date)"
        echo "----------------------------------------"
        echo ""
    done
done
echo "========================================================"
if [ "$train_failures" -eq 0 ]; then
    echo "All training tasks completed!"
else
    echo "Training phase finished with $train_failures failed run(s)." >&2
fi
echo "Time: $(date)"
echo "========================================================"
echo ""

# Announce the evaluation phase. All configured evaluation datasets are listed,
# joined by spaces.
printf '%s\n' \
    "========================================================" \
    "Starting Evaluation Phase" \
    "Evaluation Datasets: ${EVALUATION_DATASETS[*]}" \
    "Training Dataset: $TRAINING_DATASET" \
    "Tasks will be split by Alpha values and run in parallel on GPU $GPU_ID" \
    "Time: $(date)" \
    "========================================================"

# How many alpha values are configured; used later to decide between
# sequential and parallel evaluation.
num_alpha_values=${#ALPHA_VALUES[@]}

# Space-separated epsilon list handed to the evaluation script.
epsilon_values_str=$(
    IFS=" "
    echo "${EPSILON_VALUES[*]}"
)

# Make sure the raw-results directory exists up front
# (subdirectories will be created by evaluate.py).
mkdir -p "$EVAL_BASE_OUTPUT_DIR/raw"

#######################################
# Run the evaluation script for a subset of alpha values.
# Globals (read):
#   GPU_ID, PYTHON_EXEC, EVAL_SCRIPT, MODEL_PATH, MODEL_NAME,
#   TRAINING_DATASET, EVALUATION_DATASETS, epsilon_values_str
# Arguments:
#   $@ - alpha values to evaluate in this process
# Outputs:
#   Progress messages on stdout; a failure message on stderr.
# Returns:
#   The exit status of the evaluation script, so callers (including `wait` on
#   a backgrounded invocation) can detect a failed evaluation.
#######################################
run_evaluation() {
    local -a alpha_values_to_eval=("$@")
    local -a epsilon_values_to_eval=()
    local alpha_values_str datasets_str
    local eval_status=0
    # $$ always expands to the PID of the top-level script, even inside a
    # background subshell. BASHPID is the PID of the shell actually running
    # this function; fall back to $$ where BASHPID is unavailable.
    local worker_pid="${BASHPID:-$$}"

    alpha_values_str=$(IFS=" "; echo "${alpha_values_to_eval[*]}")
    datasets_str=$(IFS=" "; echo "${EVALUATION_DATASETS[*]}") # Evaluate all specified datasets

    # epsilon_values_str is a whitespace-separated list; split it into separate
    # arguments without exposing it to pathname expansion.
    IFS=$' \t\n' read -r -a epsilon_values_to_eval <<< "${epsilon_values_str}"

    echo "Starting evaluation process (PID: ${worker_pid}) - Alpha values: ${alpha_values_str} - Datasets: ${datasets_str}"
    # $PYTHON_EXEC is deliberately left unquoted so it may hold a multi-word
    # command (e.g. "python -u").
    # shellcheck disable=SC2086
    CUDA_VISIBLE_DEVICES="$GPU_ID" $PYTHON_EXEC "$EVAL_SCRIPT" \
        --model_type qwen_lora \
        --model_path "${MODEL_PATH}" \
        --model_name "${MODEL_NAME}" \
        --datasets "${EVALUATION_DATASETS[@]}" \
        --dataset "${TRAINING_DATASET}" \
        --epsilon_values "${epsilon_values_to_eval[@]}" \
        --alpha_values "${alpha_values_to_eval[@]}"
    eval_status=$?

    if [ "$eval_status" -ne 0 ]; then
        echo "Evaluation process (PID: ${worker_pid}) FAILED with exit code ${eval_status} - Alpha values: ${alpha_values_str}" >&2
        return "$eval_status"
    fi

    echo "Evaluation process (PID: ${worker_pid}) complete - Alpha values: ${alpha_values_str}"
    return 0
}

# Decide how to run evaluation based on number of Alpha values:
#   0 values  -> skip evaluation
#   1 value   -> run a single evaluation in the foreground
#   otherwise -> split the list in two halves (first half gets the extra
#                element when the count is odd) and run both halves in
#                parallel as background jobs on the same GPU.
# eval_status stays 0 unless an evaluation reports failure; it becomes the
# exit status of the script so callers can detect a failed evaluation.
eval_status=0

if [ "${num_alpha_values:-0}" -eq 0 ]; then
    echo "No Alpha values specified, skipping evaluation phase."
elif [ "$num_alpha_values" -eq 1 ]; then
    echo "Only one Alpha value (${ALPHA_VALUES[0]}), running evaluation sequentially..."
    run_evaluation "${ALPHA_VALUES[@]}" || eval_status=$?
    if [ "$eval_status" -ne 0 ]; then
        echo "Warning: Evaluation process failed (Status code: $eval_status)." >&2
    fi
else
    echo "Multiple Alpha values detected, attempting to run two evaluation processes in parallel on GPU $GPU_ID..."
    # Split the Alpha values list
    midpoint=$(( (num_alpha_values + 1) / 2 ))
    alphas1=("${ALPHA_VALUES[@]:0:$midpoint}")
    alphas2=("${ALPHA_VALUES[@]:$midpoint}")

    # Run two evaluation processes in parallel
    run_evaluation "${alphas1[@]}" &
    pid1=$!
    echo "Started background evaluation process 1 (PID: $pid1) - Alpha values: ${alphas1[*]}"

    run_evaluation "${alphas2[@]}" &
    pid2=$!
    echo "Started background evaluation process 2 (PID: $pid2) - Alpha values: ${alphas2[*]}"

    # Wait for both background processes to finish; `wait <pid>` yields that
    # job's exit status.
    echo "Waiting for evaluation processes $pid1 and $pid2 to complete..."
    wait "$pid1"
    status1=$?
    wait "$pid2"
    status2=$?

    if [ "$status1" -eq 0 ] && [ "$status2" -eq 0 ]; then
        echo "All parallel evaluation processes completed successfully."
    else
        echo "Warning: At least one parallel evaluation process failed (Status codes: $status1, $status2)." >&2
        eval_status=1
    fi
fi

echo "========================================================"
if [ "$eval_status" -eq 0 ]; then
    echo "All evaluation tasks completed!"
else
    # Do not report success when an evaluation process failed.
    echo "Evaluation phase finished with failures." >&2
fi
echo "Time: $(date)"
echo "========================================================"

# Propagate evaluation failure to whoever invoked this script.
exit "$eval_status"
