#!/bin/bash
# Script to run recurrent ConvNeXt experiments with different configurations in parallel
# - Runs 4 PDE equations (heat, burgers, kdv, ks), each with every depth/k_bar/hidden_channels
#   combination listed in MODEL_CONFIGS
# - Runs at most one experiment per GPU and starts the next queued experiment as soon as a GPU frees up
# - Experiments are launched in list order (grouped by equation); they are not sorted by cost

# Configuration
# Number of GPUs to use (device indices 0 to NUM_GPUS-1). Defaults to 7 and
# can be overridden from the environment, e.g. `NUM_GPUS=4 bash <this script>`.
readonly NUM_GPUS=${NUM_GPUS:-7}

# Set environment variable to limit thread usage for better GPU utilization.
# Exported, so every launched experiment inherits OMP_NUM_THREADS=1.
export OMP_NUM_THREADS=1

# Data paths for different PDE equations.
# These three arrays are parallel: TRAIN_PATHS[i] and VALID_PATHS[i] belong to
# EQUATION_NAMES[i], so they must keep the same length and order.
declare -ra EQUATION_NAMES=("heat" "burgers" "kdv" "ks")
declare -ra TRAIN_PATHS=("heat1d_8192_256_train_500.parquet" "burgers1d_8192_256_train_500.parquet" "kdv1d_1024_256_train_500.parquet" "ks1d_4096_256_train_500.parquet")
declare -ra VALID_PATHS=("heat1d_8192_256_valid_100.parquet" "burgers1d_8192_256_valid_100.parquet" "kdv1d_1024_256_valid_100.parquet" "ks1d_4096_256_valid_100_long.parquet")

# Model configurations; every equation is run with every entry.
# Format: "recurrent_depth:k_bar:hidden_channels"
declare -ra MODEL_CONFIGS=(
    "1:32:64"
    "1:32:96"
    "1:32:128"
    "4:8:48"
    "4:8:72"
    "4:8:96"
)

#######################################
# Run a single experiment: launch `python main.py` (relative to the current
# working directory) on one GPU with the given data paths and model config.
# Arguments:
#   $1 - GPU index, passed to main.py as --set device=<index>
#   $2 - training data path (data.train_path)
#   $3 - validation data path (data.valid_path)
#   $4 - model config string "recurrent_depth:k_bar:hidden_channels"
# Outputs:
#   A "Starting experiment ..." line on stdout, followed by main.py's output.
# Returns:
#   The exit status of main.py.
#######################################
run_experiment() {
    local device=$1
    local train_path=$2
    local valid_path=$3
    local config=$4
    local recurrent_depth k_bar hidden_channels

    # Split "depth:k_bar:channels" into its three fields.
    IFS=':' read -r recurrent_depth k_bar hidden_channels <<< "$config"

    # File-name prefix up to the first '_' (e.g. "heat1d"); used only for logging.
    local equation_name=${train_path%%_*}

    echo "Starting experiment on GPU $device: equation=$equation_name, depth=$recurrent_depth, k_bar=$k_bar, hidden_channels=$hidden_channels"

    # Each override is quoted so values containing spaces or glob characters
    # reach main.py as a single argument. Being the last command, python's
    # exit status becomes the function's return status.
    python main.py --config configs/experiment_recurrent_convnext.yaml \
                   --set "device=$device" \
                   --set "data.train_path=$train_path" \
                   --set "data.valid_path=$valid_path" \
                   --set "model.params.recurrent_depth=$recurrent_depth" \
                   --set "model.params.recurrent_k_distribution.k_bar=$k_bar" \
                   --set "model.params.hidden_channels=$hidden_channels"
}

#######################################
# Main function to orchestrate all experiments.
# Runs every (equation, model config) combination with at most one experiment
# per GPU, starting the next queued experiment as soon as a GPU frees up.
# Experiments are launched in list order (equation-major); running jobs are
# polled once per second.
# Globals:
#   NUM_GPUS, EQUATION_NAMES, TRAIN_PATHS, VALID_PATHS, MODEL_CONFIGS (read)
# Outputs:
#   Progress messages on stdout; warnings and errors on stderr.
# Returns:
#   0 if every experiment exited with status 0; 1 if the configuration is
#   invalid or at least one experiment failed.
# NOTE(review): in a non-interactive bash, background jobs start with SIGINT
#   ignored, so Ctrl-C on this script may leave launched main.py processes
#   running; check for stragglers after interrupting.
#######################################
main() {
    local i eq_idx model_config config depth k_bar channels pid

    # A non-positive GPU count would make the scheduling loop below spin
    # forever without ever launching an experiment.
    if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
        echo "ERROR: NUM_GPUS must be a positive integer (got '$NUM_GPUS')" >&2
        return 1
    fi
    # TRAIN_PATHS/VALID_PATHS are indexed by the same eq_idx as EQUATION_NAMES.
    if (( ${#TRAIN_PATHS[@]} != ${#EQUATION_NAMES[@]} || ${#VALID_PATHS[@]} != ${#EQUATION_NAMES[@]} )); then
        echo "ERROR: EQUATION_NAMES, TRAIN_PATHS and VALID_PATHS must have the same length" >&2
        return 1
    fi

    # Work list: one "eq_idx:depth:k_bar:channels" entry per combination.
    local -a experiment_configs=()
    for ((eq_idx=0; eq_idx<${#EQUATION_NAMES[@]}; eq_idx++)); do
        for model_config in "${MODEL_CONFIGS[@]}"; do
            experiment_configs+=("$eq_idx:$model_config")
        done
    done

    local total_experiments=${#experiment_configs[@]}
    echo "Preparing to run $total_experiments total experiments"

    # gpu_pids[i] holds the PID of the experiment running on GPU i, or "" if free.
    local -a gpu_pids=()
    for ((i=0; i<NUM_GPUS; i++)); do
        gpu_pids[i]=""
    done

    local active_experiments=0
    local completed_experiments=0
    local failed_experiments=0
    local experiment_index=0

    echo "Starting experiment batches..."

    while (( completed_experiments < total_experiments )); do
        # Reap finished experiments and free their GPUs.
        for ((i=0; i<NUM_GPUS; i++)); do
            pid=${gpu_pids[i]}
            [[ -n "$pid" ]] || continue
            # kill -0 sends no signal; it only tests whether the process still exists.
            kill -0 "$pid" 2>/dev/null && continue
            # wait yields the finished background job's exit status.
            if ! wait "$pid"; then
                echo "WARNING: Process $pid exited with non-zero status" >&2
                failed_experiments=$((failed_experiments + 1))
            fi
            completed_experiments=$((completed_experiments + 1))
            active_experiments=$((active_experiments - 1))
            gpu_pids[i]=""
            echo "Completed experiment $completed_experiments/$total_experiments on GPU $i"
        done

        # Launch queued experiments on every free GPU, lowest index first.
        for ((i=0; i<NUM_GPUS && experiment_index<total_experiments; i++)); do
            [[ -z "${gpu_pids[i]}" ]] || continue
            config=${experiment_configs[experiment_index]}
            # eq_idx takes the first field; model_config keeps the rest, colons included.
            IFS=':' read -r eq_idx model_config <<< "$config"

            run_experiment "$i" \
                           "${TRAIN_PATHS[eq_idx]}" \
                           "${VALID_PATHS[eq_idx]}" \
                           "$model_config" &
            gpu_pids[i]=$!
            active_experiments=$((active_experiments + 1))
            experiment_index=$((experiment_index + 1))

            IFS=':' read -r depth k_bar channels <<< "$model_config"
            echo "Scheduled experiment $experiment_index/$total_experiments on GPU $i (equation: ${EQUATION_NAMES[eq_idx]}, depth: $depth, k_bar: $k_bar, channels: $channels, active: $active_experiments)"
        done

        # Sleep briefly to avoid hammering the CPU with checks
        sleep 1
    done

    echo "All experiments completed. Ran $completed_experiments of $total_experiments experiments."
    if (( failed_experiments > 0 )); then
        echo "WARNING: $failed_experiments of $total_experiments experiments exited with non-zero status" >&2
        return 1
    fi
    return 0
}

# Execute the main function, forwarding any command-line arguments. As the last
# command, main's return status becomes the script's exit status.
main "$@"