#!/bin/bash
# Script to run multiple PDE experiments in parallel with different parameters
# - Tests 4 equations and multiple depth/k_bar combinations
# - Runs up to NUM_GPUS experiments concurrently (one per GPU); a freed GPU picks up the next queued experiment
# - Sorts experiments by computational complexity (depth * k_bar) for optimal GPU utilization

# Configuration
# Number of GPUs to use; experiments are assigned device indices 0 to NUM_GPUS-1.
# Defaults to 7 and can be overridden from the environment (e.g. NUM_GPUS=4).
NUM_GPUS="${NUM_GPUS:-7}"
if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
    # A zero or non-numeric value would leave the scheduler with no GPU to
    # launch on and make it poll forever, so reject it up front.
    echo "ERROR: NUM_GPUS must be a positive integer, got '$NUM_GPUS'" >&2
    exit 1
fi

# Set environment variable to limit thread usage for better GPU utilization
export OMP_NUM_THREADS=1

# Data paths for different PDE equations.
# The three arrays are index-aligned: entry i of TRAIN_PATHS / VALID_PATHS
# belongs to entry i of EQUATION_NAMES.
declare -a EQUATION_NAMES=("heat" "burgers" "kdv" "ks")
declare -a TRAIN_PATHS=(
    "heat1d_8192_256_train_500.parquet"
    "burgers1d_8192_256_train_500.parquet"
    "kdv1d_1024_256_train_500.parquet"
    "ks1d_4096_256_train_500.parquet"
)
declare -a VALID_PATHS=(
    "heat1d_8192_256_valid_100.parquet"
    "burgers1d_8192_256_valid_100.parquet"
    "kdv1d_1024_256_valid_100.parquet"
    "ks1d_4096_256_valid_100_long.parquet"
)

# Depth vs k_bar combinations
# Key: recurrent_depth. Value: space-separated list of k_bar values to test.
# Every listed pair satisfies depth * k_bar <= 32 (the theoretical layer
# budget these sweeps are built around).
declare -A DEPTH_VS_KBAR=(
    ["1"]="1 2 4 8 16 32"
    ["2"]="1 2 4 8 16"
    ["4"]="1 2 4 8"
    ["8"]="1 2 4"
    ["16"]="1 2"
    ["32"]="1"
)

# Mapping of depth to kernel size
# For every entry depth * (kernel_size - 1) == 64.
# NOTE(review): presumably chosen so the stacked receptive field stays the
# same at every depth -- confirm against the model code.
declare -A DEPTH_TO_KERNEL_SIZE=(
    ["1"]="65"
    ["2"]="33"
    ["4"]="17"
    ["8"]="9"
    ["16"]="5"
    ["32"]="3"
)

# Function to run a single experiment
#
# Launches one foreground run of main.py with the shared
# experiment_effective_layers_residual config plus per-experiment overrides.
#
# Globals:
#   DEPTH_TO_KERNEL_SIZE (read) - maps recurrent depth to kernel size
# Arguments:
#   $1 - device value forwarded as `--set device=...` (the scheduler passes a GPU index)
#   $2 - training data path
#   $3 - validation data path
#   $4 - recurrent depth; must be a key of DEPTH_TO_KERNEL_SIZE
#   $5 - k_bar; also forwarded as recurrent_tbptt_steps
# Outputs:
#   A "Starting experiment" line on stdout, followed by main.py's own output.
#   An error line on stderr if the depth has no kernel-size mapping.
# Returns:
#   main.py's exit status, or 2 (nothing is launched) for an unmapped depth.
run_experiment() {
    local device=$1
    local train_path=$2
    local valid_path=$3
    local recurrent_depth=$4
    local k_bar=$5

    # Get kernel size based on depth. An empty subscript is a bash error for
    # associative arrays, so only look up when a depth was actually given.
    local kernel_size=""
    if [[ -n "$recurrent_depth" ]]; then
        kernel_size=${DEPTH_TO_KERNEL_SIZE[$recurrent_depth]:-}
    fi
    if [[ -z "$kernel_size" ]]; then
        # Refuse to start a run with `kernel_size=` empty.
        echo "ERROR: no kernel size defined for recurrent_depth='$recurrent_depth'" >&2
        return 2
    fi

    # Extract equation name from train path for logging
    # (everything before the first underscore, e.g. "heat1d").
    local equation_name=${train_path%%_*}

    echo "Starting experiment on GPU $device: equation=$equation_name, depth=$recurrent_depth, k_bar=$k_bar, kernel_size=$kernel_size"

    # Each override is kept as one quoted word so values containing
    # whitespace or glob characters reach main.py intact.
    local -a overrides=(
        --set "device=$device"
        --set "data.train_path=$train_path"
        --set "data.valid_path=$valid_path"
        --set "model.params.recurrent_depth=$recurrent_depth"
        --set "model.params.recurrent_kwargs.kernel_size=$kernel_size"
        --set "model.params.recurrent_k_distribution.k_bar=$k_bar"
        --set "model.params.recurrent_tbptt_steps=$k_bar"
    )

    # Run the experiment; its exit status becomes this function's status.
    python main.py --config configs/experiment_effective_layers_residual.yaml "${overrides[@]}"
}

# Main function to orchestrate all experiments
#
# Builds the full experiment list (every equation x every depth/k_bar pair),
# orders it by complexity (depth * k_bar, largest first) and keeps up to
# NUM_GPUS experiments running at once, one per GPU index.
#
# Globals (read):
#   NUM_GPUS, EQUATION_NAMES, TRAIN_PATHS, VALID_PATHS, DEPTH_VS_KBAR
# Calls:
#   run_experiment <gpu> <train_path> <valid_path> <depth> <k_bar>  (in the background)
# Outputs:
#   Progress lines on stdout; configuration errors and the failure summary on stderr.
# Returns:
#   0 if every experiment exited with status 0; 1 if any experiment failed
#   or NUM_GPUS is not a positive integer.
main() {
    # With zero (or a non-numeric number of) GPUs nothing could ever be
    # launched and the polling loop below would never terminate.
    if ! [[ "${NUM_GPUS:-}" =~ ^[1-9][0-9]*$ ]]; then
        echo "ERROR: NUM_GPUS must be a positive integer, got '${NUM_GPUS:-}'" >&2
        return 1
    fi

    # Keep all scratch variables out of the global scope.
    local eq_idx depth k_bar complexity config i

    # Fill the array with all experiment configurations and their complexity (depth * k_bar)
    local -a experiment_configs=()
    for ((eq_idx = 0; eq_idx < ${#EQUATION_NAMES[@]}; eq_idx++)); do
        for depth in "${!DEPTH_VS_KBAR[@]}"; do
            # Unquoted on purpose: the value is a space-separated list of k_bar values.
            for k_bar in ${DEPTH_VS_KBAR[$depth]}; do
                complexity=$((depth * k_bar))

                # Store as "complexity:eq_idx:depth:k_bar" for easy sorting
                experiment_configs+=("$complexity:$eq_idx:$depth:$k_bar")
            done
        done
    done

    # Sort by complexity (first field only) - highest complexity first.
    # mapfile reads one entry per line without touching the global IFS.
    local -a sorted_configs=()
    if ((${#experiment_configs[@]} > 0)); then
        mapfile -t sorted_configs < <(
            printf '%s\n' "${experiment_configs[@]}" | sort -t ':' -k1,1 -n -r
        )
    fi

    # Count total experiments
    local total_experiments=${#sorted_configs[@]}
    echo "Preparing to run $total_experiments total experiments"

    # Per-GPU state: PID of the running job, busy flag, and a human-readable
    # label of the running experiment (used in the failure summary).
    local -a gpu_pids=() gpu_busy=() gpu_labels=()
    for ((i = 0; i < NUM_GPUS; i++)); do
        gpu_pids[$i]=""
        gpu_busy[$i]=0
        gpu_labels[$i]=""
    done

    local active_experiments=0
    local completed_experiments=0
    local experiment_index=0
    local -a failed_labels=()

    echo "Starting experiment batches..."

    # Keep going until all experiments are completed
    while ((completed_experiments < total_experiments)); do
        # First, reap any experiments that have finished
        for ((i = 0; i < NUM_GPUS; i++)); do
            ((gpu_busy[$i] == 1)) || continue

            # kill -0 only probes for existence; still alive -> check again later
            if kill -0 "${gpu_pids[$i]}" 2>/dev/null; then
                continue
            fi

            # Process is done: collect its exit status
            if ! wait "${gpu_pids[$i]}"; then
                echo "WARNING: Process ${gpu_pids[$i]} exited with non-zero status"
                failed_labels+=("${gpu_labels[$i]}")
            fi
            completed_experiments=$((completed_experiments + 1))
            active_experiments=$((active_experiments - 1))
            gpu_busy[$i]=0
            echo "Completed experiment $completed_experiments/$total_experiments on GPU $i"
        done

        # Then start new experiments on free GPUs (lowest index first) while any remain
        for ((i = 0; i < NUM_GPUS && experiment_index < total_experiments; i++)); do
            ((gpu_busy[$i] == 0)) || continue

            config=${sorted_configs[$experiment_index]}
            IFS=':' read -r complexity eq_idx depth k_bar <<<"$config"

            # Start the experiment on this GPU
            run_experiment "$i" \
                "${TRAIN_PATHS[$eq_idx]}" \
                "${VALID_PATHS[$eq_idx]}" \
                "$depth" \
                "$k_bar" &

            # Store the process ID and mark GPU as busy
            gpu_pids[$i]=$!
            gpu_busy[$i]=1
            gpu_labels[$i]="equation=${EQUATION_NAMES[$eq_idx]}, depth=$depth, k_bar=$k_bar"
            active_experiments=$((active_experiments + 1))
            experiment_index=$((experiment_index + 1))

            echo "Scheduled experiment $experiment_index/$total_experiments on GPU $i (complexity: $complexity, active: $active_experiments)"
        done

        # Sleep briefly to avoid hammering the CPU with checks
        # (skipped once everything has been reaped).
        if ((completed_experiments < total_experiments)); then
            sleep 1
        fi
    done

    echo "All experiments completed. Ran $completed_experiments of $total_experiments experiments."

    # Surface failures in the exit status so callers and CI can detect them.
    if ((${#failed_labels[@]} > 0)); then
        echo "ERROR: ${#failed_labels[@]} of $total_experiments experiments exited with non-zero status:" >&2
        printf '  %s\n' "${failed_labels[@]}" >&2
        return 1
    fi
    return 0
}

# Entry point: hand control (and any command-line arguments, which main
# currently ignores) to main; its status becomes the script's exit status.
main "$@"